LLM 논문을 분야 가리지 않고 여럿 읽던 중에 Nvidia의 경량화 관련 논문이 나와서 가볍게 읽어보고 리뷰해보려 한다. 이 기법은 nvidia에서 nemotron을 경량화하여minitron을 만드는 데 사용했다고 한다.
Abstract&Introduction
저자는 이 논문에서 현존하는 LLM을 기존보다 작은(<3%)의 데이터셋으로 경량화하고 retraining하는 방법론을 제안 경량화 방법으로는 pruning을 사용했으며 retraining에는 knowledge distillation(KD)를 사용했으며 기존 Nemotron4-15B모델을 8B와 4B로 경량화하였다.
연구의 시작점은 다음과 같은데 LLAMA, GPT 등 최근 LLM 모델들은 대게 7B, 13B, 70B 등 여러 parameter size의 모델을 scratch 부터 학습한다. 저자는 이 방식보다는 하나의 big model을 training 시킨 후 작은 모델로 retraining 시키는 것을 목표로 했다.
LLM은 확실한 retraining 방법론이 없어 새로운 방법이 없어 pruning을 통해 head, neuron, embedding channel 등 다양한 공간을 drop해보며 최선의 방법을 찾아냈으며 저자의 방법을 통해 경량화된 minitron-8B는 Nemotron-3 8B보다 좋은 성능을 보여주었다.
Pruning Methodology
위의 3개(Emb, Head, CH)를 적절한 방식으로 importance를 계산, sorting하고 낮은 부분은 버린다. 그리고 버리기 전에 그 잘려진 정보를 원래 층에 residual하게 넣어준다.
Importance Analysis
최근 LLM pruning에는 cosine similarity, perplexity to calibration dataset을 사용해 importance를 계산하는데 저자는 이와 다르게 activation based importance estimation strategy를 세웠다.
Importance 계산을 위해 여러 모듈(head, neuron, embed, layer)들을 scoring해줘야 하는데 이때 small calibration dataset(1024)으로 forward pass 후 사용한 가중치를 이용했다. 이 논문에서는 크게 두 가지 방식으로 importance를 계산했다.
width
width 축(head, neuron, emb)에서는 forward path의 activation(output of XW^T)을 계산해 점수를 매겼다.(코드가 안 나와 있어서 자세히는 알 수 없음)
layer
layer 축에는 perplexity 혹은 BI(1- attention 전후의 cos similarity)를 사용해 점수를 매겼다. PPL 계산은 각 층을 하나씩 빼 보면서 score를 기록했다.
Obtaining a Pruned Model
그리고 MHA, MLP, embed 축에서 적절히 trim하고 MHA는 residual connection을 이용해 정보를 보존한다고 한다. (논문에 자세한 과정이 적혀 있지 않고 코드도 비공개라 정확히 어디를 얼마나 그리고 어떤 식으로 이러 붙였는지는 확인할 수 없었다.)
Retraining
여기서 retraining은 Pruning으로 떨어진 성능을 다시 끌어올리는 작업을 의미한다. 즉 uncompressed model과 compressed 모델이 teacher와 student 역할을 한다.
L_logits 계산은KD와 같은데 한 가지 추가된게 모든seq, len에 대한 weight(즉 모든 가중치)를 동일하게 하기 위한 loss(L_is)를 도입했다는 점
최종적으로는 L = L_CLM + Llogits + α × Lis; (L_CLM은 student와 gt 간의 CE loss) 이러한 loss를 사용한다.
Experiments and Results
성능 비교를 위해 사용한 모델은 Nemotron-4 15B와 이를 사용해 4B와 8B로 경량화 한 모델을 사용했다. Nemotran-4 15B는 8T token으로 학습되었고 retraining에는 1.8B token(400 step)이 사용되었다.
Training token을 중점적으로 보면 된다.를 경량화한 것이므로 성능은 전체적으로 그에 비해 낮으나 상대적으로 적은(<5T)의 데이터로 학습된 모델보다는 좋은 성능을 보인다.
또한 상당히 낮은 토큰 수로도 현재 SOTA SLM 모델과 유사한 성능을 낼 수 있었다.
성능은 Random init→Prune and training→KD 순으로 어느 정도어느정도 선형성을 띤다.