12월, 2022의 게시물 표시

From In-context learning to RLHF (Feat. ChatGPT)

이미지
TL;DR 거시적인 발전 과정 : In-context Learning -> Instruction Tuning -> RLHF -> RLAIF In-Context Learning은 Large Scale 언어모델을 tuning하지 않고 새로운 task에 적용할 수 있는 직관적인 방법을 제시함 Instruction Tuning은 다양한 task를 Instruction + example의 템플릿으로 캐스팅하여 Implicit하게 multi-task로 tuning하며 결과적으로 Unseen Task를 더 잘 수행함  RLHF는 인간의 선호도라는 애매한 척도를 모델링하는 Reward Model과 강화 학습을 활용하여 언어 모델을 개선하는 방법을 제시함 RLAIF는 Human Labeling Cost를 없애고 RLHF에서 추가적으로 helpfulness와 harmlessness를 모두 개선할 수 있는 자동화된 파이프라인을 제시함  In-Context Learning & Instruction Tuning 오늘은 요즘 핫한 ChatGPT와 관련된 이야기를 하려고 한다. 바로 zero-shot의 가능성을 보여준 In-context Learning의 시작과 그것을 더 개선시킨 Instruction Tuning, 마지막으로 화룡정점을 찍은 Reinforcement Learning From Human Feedback에 대한 내용이다. 이후 Anthropic에서 RLHF를 시스템적으로 개선한 RLAIF라는 방식을 추가적으로 제안하기도 했다.  In-context Learning In-context learning은 GPT3 에서 소개되면서 pretraining-finetuning paradigm의 대안을 제시했다. 언어모델이 충분히 크고(도표에 의하면 6B 이상) 다량의 corpus로 학습했다면 사람의 자연어 instruction을 이해하고 바람직한(의도에 맞는) 텍스트를 생성할 수 있는 능력을 가지고 있다는 것이다. 예를 들어 언어모델을 QA 태스크에 명시적으

Contrastive Learning in Text Generation

이미지
Contrastive Learning Concept contrastive learning은 어떤 샘플에 대해서 같은 클래스에 속하는 샘플을 positive sample로 보고 가까워지게 하고, 그렇지 않은 샘플을 negative sample로 보고 멀어지게 하는 학습 방식을 말한다. 컨셉은 Triplet Loss를 사용한 Metric Learning과 유사하다. Triplet Loss는 아래와 같이 정의된다.  $$ \mathcal{L}(A, P, N) = \max \left( \|f(A) - f(P) \|^2 - \|f(A) - f(N) \|^2 + \alpha, 0 \right) $$ \( A\)는 Anchor, \(P\)는 positive sample, \(N\)은 negative sample을 뜻한다. 위의 loss는 \( \alpha \)라는 Margin 값을 갖는데, negative sample과의 거리가 positive sample과의 거리보다 \(\alpha\)이상 떨어져 있어야 한다는 뜻이다. 비슷한 맥락에서 Contrastive Loss도 positive sample을 가까워지게 하고, negative sample을 멀어지게 하기 위해 InfoNCE loss( Oord et al. )를 사용한다.  $$ \mathcal{L} = - \log \frac{\exp{(\cos(z_x, z_y) / \tau )}}{\sum_{y' \in \mathcal{B}} \exp{(\cos(z_x, z_{y'}) / \tau )}} $$ Triplet Loss와의 차이점은 negative sample을 in-batch로 가지고 올 수 있다는 점과 확률 분포의 관점에서 최적화를 한다는 것이다. InfoNCE loss는 이론적으로 anchor condition \(c\)와 target인 negative sample 또는 positive sample간의 mutual information을 근사한다.  $$ I(x,c) \geq \log(N) - \ma