currybab's devlog

Llama Pro 정리

해당 paper에서는 catastrophic forgetting을 효과적이고 효율적으로 해결하기 위한 구조로 Llama Pro 모델을 제안한다. catastrpohic forgetting이란 LLM을 training할 때에 오래된 정보를 잊는 것을 의미한다. 예를 들어 code-llama는 llama2를 code-specific한 데이터셋으로 추가 훈련을 시킨 모델인데 코딩을 더 잘 이해하지만 그만큼 llama2에 비해서 일반적인 벤치마크에서 성능이 감소하였다. 해당 논문에서는 효과적인 post training 방법으로 block expansion을 제안한다.

attention_layer

사전 지식: LLaMA 블록

라마 블록은 Multi Head Self Attention 블록과 (SwiGlu와 residual connection이 있는) position-wise FFN으로 이루어져있다. 라마 블록의 입력을 $x$, 출력을 $y$라고 하면, $$ x\prime = x + MHSA(RMSNorm(x)) $$ $$ y = x\prime + FFN(RMSNorm(x\prime)) $$ 입력 $x$가 sequence length $n$과 hidden dimension $d$를 가지고 있으면 $n \times d$ 차원을 갖게 된다. 출력 $y$ 역시도 같은 차원을 가진다. MHSA는 다음과 같이 정의된다 $$ MHSA(Q,K,V) = Concat(head_1,…,head_h)W^{O} $$ $$ head_{i} = Attention(x W_{i}^{Q}, x W_{i}^{K}, x W_{i}^{V}) $$ $$ Attention(Q_i, K_i, V_i) = Softmax(\frac{Q_{i}K_{i}^{T}}{\sqrt{d_{k}}}) V_{i} $$ FFN 블록에서 라마는 SwiGLU 활성화 함수를 사용한다. $\otimes$는 element-wise multiplication을 의미한다. $$ SwiGLU(x, W, V) = SiLU(xW) \otimes (xV) $$ $$ FFN(x) = SwiGLU(x, W_1, W_2)W_3 $$ $$ SiLU(x) = x \otimes \sigma(x) $$

Block Expansion

full_model

identity block이 추가된 이후에 모델이 기존 모델과 같은 값을 내어야한다. 즉 identity block은 $ \phi(x) = x $를 만족한다.(입력과 출력이 동일) 또한 이 블록을 기존 블록들에 교차해서 넣는다.

Shen 등에 따르면 identity block에서 Norm 모듈의 scale parameter을 0으로 초기화 하는 것을 제안했는데 라마 프로에서는 이 방법이 잘 work하지 않았다. 이유로는 역전파 동안 손실 함수 L의 기울기가 RMSNorm 가중치 w에 대해 0이 되기 때문으로 이것이 RMSNorm의 훈련을 막기 때문이다. $$ \frac{\partial L}{\partial w} = \frac{\partial L}{\partial y} \frac{\partial FFN(RMSNorm(x \prime))}{\partial RMSNorm(x \prime)} \frac{\partial RMSNorm(x \prime)}{\partial w} = 0 $$ 그래서 라마프로에서는 RMSNorm을 변형하는 대신에 $W^{O}$항(o_proj)과 $W_{3}$항(down_proj)의 weight를 0으로 초기화 하였다. 이렇게 함으로써 초기에 잔여 연결만을 통과시킴으로써 기존 모델과 동일한 출력을 가질 수 있었다.

모델이 추가적인 도메인 지식을 수용하면서 일반 지식을 유지할 수 있는 능력을 향상시키기 위해서 LLM의 블록 수를 증가시키기 위해 블록 확장을 사용하였다. 또한 원래의 블록을 freeze하고 새로 추가된 모델만을 파인튜닝함으로써 모델의 일반적인 능력을 보존하였다.

Training

Pretrain Detail

SFT Detail

Ablation Study

#Train