From In-context learning to RLHF (Feat. ChatGPT)

이미지
TL;DR 거시적인 발전 과정 : In-context Learning -> Instruction Tuning -> RLHF -> RLAIF In-Context Learning은 Large Scale 언어모델을 tuning하지 않고 새로운 task에 적용할 수 있는 직관적인 방법을 제시함 Instruction Tuning은 다양한 task를 Instruction + example의 템플릿으로 캐스팅하여 Implicit하게 multi-task로 tuning하며 결과적으로 Unseen Task를 더 잘 수행함  RLHF는 인간의 선호도라는 애매한 척도를 모델링하는 Reward Model과 강화 학습을 활용하여 언어 모델을 개선하는 방법을 제시함 RLAIF는 Human Labeling Cost를 없애고 RLHF에서 추가적으로 helpfulness와 harmlessness를 모두 개선할 수 있는 자동화된 파이프라인을 제시함  In-Context Learning & Instruction Tuning 오늘은 요즘 핫한 ChatGPT와 관련된 이야기를 하려고 한다. 바로 zero-shot의 가능성을 보여준 In-context Learning의 시작과 그것을 더 개선시킨 Instruction Tuning, 마지막으로 화룡정점을 찍은 Reinforcement Learning From Human Feedback에 대한 내용이다. 이후 Anthropic에서 RLHF를 시스템적으로 개선한 RLAIF라는 방식을 추가적으로 제안하기도 했다.  In-context Learning In-context learning은 GPT3 에서 소개되면서 pretraining-finetuning paradigm의 대안을 제시했다. 언어모델이 충분히 크고(도표에 의하면 6B 이상) 다량의 corpus로 학습했다면 사람의 자연어 instruction을 이해하고 바람직한(의도에 맞는) 텍스트를 생성할 수 있는 능력을 가지고 있다는 것이다. 예를 들어 언어모델을 QA 태스크에 명시적으

lassl을 이용한 언어모델 사전학습 (Feat. T5, UL2)

업데이트

[22/10/4] 모두의말뭉치 + alpha로 학습한 KoUL2 모델을 huggingface hub에 릴리즈했습니다!

----------------------------------------

이 포스트에서는 lassl 오픈소스를 사용해서 실제로 한국어 코퍼스를 활용한 사전학습 모델을 만드는 법을 다룹니다. 제가 참여해서 구현한 T5와 UL2의 구현 방식, 고민한 내용들을 공유하고자 합니다. Lassl 소스코드는 여기에서 확인하실 수 있습니다.


TFRC 프로그램 소개 및 신청하기

1. 다음 사이트에 가서 TRC 프로그램의 form을 작성하고 신청합니다. 

https://sites.research.google/trc/about/ 

2.  며칠 지나면 다음과 같은 메일이 옵니다.


3. console.cloud.google.com에서 프로젝트를 생성한 뒤 그 프로젝트 ID를 메일 내 링크를 통해 입력하면 됩니다. 그러면 다음과 같은 승인 메일이 며칠 내에 옵니다.


4. 이제 TPU instance를 무료로 사용할 수 있습니다!


GCP TPU 인스턴스와 디스크 만들기

1. 첫 번째 단계로 gcloud cli를 설치해야 합니다. 해당 내용은 플랫폼 별로 자세하게 설명되어 있으니 여기를 참조해주세요. 다만 아래에 설명하는 명령어들은 Unix 계열이기 때문에 Window를 사용하시는 경우 WSL이나 도커를 활용하시는 것이 따라하시기 편할 것 같습니다.

2. 다음 명령어로 디스크와 tpu 인스턴스를 생성합니다. 디스크는 선택사항이지만 편하게 쓰려면 붙이시면 좋습니다. 사용 시 경험상 매일 1000원 안팎으로 청구됩니다. 다만 기본 인스턴스에도 100GB가 있기 때문에 잘 아껴 쓰면 디스크를 붙이지 않아도 됩니다. 
export GCP_DISK_NAME=lassl-disk
export GCP_INSTANCE_NAME=lassl-tpu
export GCP_PROJECT=<your_project_id> # i.e. fast-cascade-123456
export GCP_ZONE=europe-west4-a
gcloud compute disks create $GCP_DISK_NAME \
--project=$GCP_PROJECT \
--type=pd-standard \
--size=500GB \
--zone=$GCP_ZONE
gcloud alpha compute tpus tpu-vm create $GCP_INSTANCE_NAME \
--accelerator-type v3-8 \
--version tpu-vm-pt-1.10 \
--zone $GCP_ZONE \
--project $GCP_PROJECT \
--data-disk source="projects/${GCP_PROJECT}/zones/${GCP_ZONE}/disks/${GCP_DISK_NAME}"

3. 디스크를 생성한 경우 다음 명령어로 포맷 및 마운트할 수 있습니다.

# Connect tpu-vm node
gcloud alpha compute tpus tpu-vm ssh $GCP_INSTANCE_NAME
# Check that your disk is visible and get its name
lsblk
# Mount disk in the data folder
sudo mkfs.ext4 -m 0 -E lazy_itable_init=0,lazy_journal_init=0,discard /dev/sdb
sudo mkdir -p workspace
sudo mount -o discard,defaults /dev/sdb workspace
sudo chmod a+w workspace

4. 다음으로는 학습 데이터를 vm 안에 옮겨야 합니다. GCP_USER_NAME은 gcloud_cli를 설치한 그 머신(도커든 wsl이든, unix 계열이면 원래 계정)의 유저네임으로 설정하시면 됩니다. GCP_INSTANCE_PORT는 gcloud console에 들어가서 external_ip를 찾아서 입력하시면 됩니다.  

데이터가 로컬에 있는 경우는 아래 명령어와 같이 rsync를 통해 옮길 수 있고, 구글 드라이브에서 직접 다운받을 수 있는 경우에는 tpu vm에 접속한 상태로 gdown을 사용해서 다운 받는 편이 훨씬 빠릅니다. 

export GCP_USER_NAME=<your_user_name>
export GCP_INSTANCE_PORT=<tpu_vm_external_ip>
rsync -avP -e "ssh -i ~/.ssh/google_compute_engine" corpora $GCP_USER_NAME@$GCP_INSTANCE_PORT:/home/$GCP_USER_NAME/workspace/
cd ~/workspace/lassl/
pip3 install . 


5.  train_tokenizer, serialize_corpora를 먼저 완료한 후에 pretrain_language_model을 실행할 수 있습니다. 원하는 모델 타입의 config 파일을 설정 한 뒤에 아래 커맨드로 학습을 실행하면 됩니다. 
export XRT_TPU_CONFIG="localservice;0;localhost:51011"
python3 xla_spawn.py --num_cores 8 pretrain_language_model.py --config_path <config_path> # i.e. configs/train-plm-with-tpu-t5.yaml

Troubleshooting

  • 학습 시 torch_xla 관련 이슈
    • core dumped 뜨면서 프로세스 꺼진다면 torch==1.9.0 및 xla==1.9.0으로 버전 낮춰보기
    • sudo bash /var/scripts/docker-login.sh sudo docker rm libtpu || true 
      sudo docker create --name libtpu gcr.io/cloud-tpu-v2-images/libtpu:pytorch-1.9 "/bin/bash" sudo docker cp libtpu:libtpu.so /lib 
      sudo pip3 uninstall --yes torch torch_xla torchvision 
      sudo pip3 install torch==1.9.0 
      sudo pip3 install torchvision==0.10.0 
      sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-1.9-cp38-cp38-linux_x86_64.whl

    • pretrain_language_model.py 실행 전에 xla 환경변수가 설정되지 않은 경우 에러 발생함
    • export XRT_TPU_CONFIG="localservice;0;localhost:51011"
    • symbol을 못찾겠다고 뜰 때 : torch version을 맞추기(downgrade or upgrade)
    • libtpu error는 torch_xla 버전과 현재 설치된 libtpu 버전이 호환되지 않아 생기는 이슈이므로 적절한 version을 설치해주면 됨. 참고
  • distutils 이 version이 없다고 나오는 경우
    • pip install setuptools==59.5.0


사전 학습의 논리적 순서

lassl을 기준으로 사전학습이 진행되는 순서에 대해서 설명하면, 먼저 적절한 형태(sent_text)로 corpus를 수집하고 해당 corpus를 활용해서 tokenizer를 학습합니다. 만일 이미 학습된 tokenizer가 있다면 이 과정은 생략할 수 있습니다. 이어서 Processor를 활용해서 사전학습 방식(i.e. BERT, T5, ...)에 맞게 corpus를 tokenize한 후 적절한 사이즈로 잘라서 huggingface dataset 형태로 적재합니다. 대부분의 경우 사전학습에 사용할 데이터 샘플의 시퀀스 길이에서 특수토큰(</s>, <s>) 등의 개수를 뺀 길이로 전 처리를 해둡니다. 마지막으로 실제로 학습할 때는 이 데이터를 읽어들여서 Collator에서 해당 모델에 맞게 batch를 dynamic하게 재구성하게 됩니다. 예를 들어 BERT라면 MLM을 하기 위해 random token을 masking하고, next sentence prediction을 위해 [sep]을 사이에 두고 두 문장을 이어서 준비할 것입니다. 만약 T5라면 sentinel token을 활용하여 span corruption을 수행하여 input과 label을 만들 것입니다. lassl 내에서 실제로 파일이 실행되는 순서는 train_tokenizer.py -> serialize_corpora.py -> pretrain_language_model.py 입니다. 더 자세한 설명은 lassl Readme에 잘 나와 있으니 참고해 주시면 좋겠습니다.


T5 및 UL2 구현 방식 

T5, UL2의 사전학습 방식

T5는 Text-To-Text Transfer Transformer의 약자로 seq2seq 계열 태스크에서는 스탠다드로 여겨지고 있는 사전학습 모델입니다. T5는 사전학습을 위해 span corruption을 사용합니다. BART나 BERT와 유사하게 masking을 하는데 독특한 점은 한 샘플 내에서 여러개의 span이 만들어지는 경우 유니크한 마스크로 치환한다는 점입니다. 예를 들어 A Brown Fox Jumps Over The Wall 이라는 문장에서 A Brown과 The Wall이 치환된다고 했을 때 input은 <extra_id_0> Jumps Over <extra_id_1></s>가 되고 label은 <extra_id_0> A Brown <extra_id_1> The Wall <extra_id_2></s>가 됩니다. label에 extra_id_2가 추가된 것은 1번 mask에 대한 예측이 끝났음을 나타내며, </s>는 샘플이 끝났음을 나타냅니다. 

UL2는 T5의 디노이징 함수를 다른 세팅으로 여러 개 만드는 Mixture-of-denoisers 방식을 제안하였습니다. 논문에서 제안한 방식은 r-denoiser 2개 세팅, x-denoiser 4개 세팅, s-denoiser 1개 세팅으로 총 7가지 세팅입니다. 자세한 내용은 이 포스트에 잘 설명되어 있습니다. 구현 방식은 T5 denoiser를 여러 개 만들어 두고 batch 내에서 sample마다 alternating 하는 방식이라고 생각하시면 됩니다. 따라서 denoiser 여러 개를 관리하기 위한 사소한 추가사항 외에는 T5의 구현과 동일하다고 보실 수 있습니다.

새로운 모델을 추가하기 위해 구현해야 할 부분

lassl은 huggingface Trainer 및 Datasets를 기반으로 구현되어 있어서 주요한 학습 로직이나 데이터셋 관리는 건드리지 않아도 됩니다. 따라서 저희가 구현해야 하는 부분은 모델 별로 어떻게 corpus를 처리할지(serialize_corpora)와 학습 시점에 어떻게 샘플들을 가공해서 실제 training에 필요한 input, output, mask 등을 생성할지(collator)에 해당하는 부분입니다. 

1. Processor 구현

Processor는 raw text를 가공하여 collator에서 처리하기 적절한 길이로 미리 tokenizing 해둡니다. 여기에서 사용하는 tokenizer는 train_tokenizers.py를 사용해서 직접 학습해도 되고 huggingface hub에 있는 tokenizer를 활용해도 됩니다. lassl의 핵심 디자인 패턴은 정적인 배치처리를 추구한다는 점입니다. tokenizing 할 때 미리 input과 label을 만들어 둔다면 하나의 샘플에서 얻을 수 있는 다양한 sementic을 잃어버리기 때문입니다. 

T5의 경우 corruption을 통해 어느정도의 비율(예를 들어 15%)의 token을 노이즈로 취급할 것인지를 사전에 정의할 수 있고, mean_span_length를 정의하여 전체 token의 개수를 계산 할 수 있습니다. 예를 들어 512개의 길이를 갖는 샘플을 이용하여 사전학습을 수행하고 싶고, 15% 마스킹, 평균 스팬 길이가 3이라면 77개의 토큰이 마스킹 되고, 스팬은 26개가 필요합니다. 이제 마스킹을 한 뒤에 input의 길이와 label의 길이 또한 예측할 수 있습니다. input의 길이는 512 - 77 + 26 + 1(</s>) = 462이 되고  label의 길이는 77 + 26 + 1(last sentinel) + 1(</s>) = 105 이 됩니다. 

여기에서 우리는 학습 시에 사용되는 샘플의 길이(input과 label 중 큰 것)가 사전에 정의된 512이길 원합니다. 그렇게 하기 위해서는 주어진 세팅 (0.15, 3)으로 corruption 했을 때 둘 중 긴것의 길이가 512가 나오도록 해야 합니다. 위의 corruption rate으로는 반드시 input이 더 길 수밖에 없으므로 input을 기준으로 살펴보겠습니다. 이 길이를 만들기 위해 t5 구현체를 살펴보면 corruption 후의 input 길이를 반환하는 함수를 만들고 pre-noise input 길이를 1씩 늘려 가면서 corruption 후에 512 길이가 나올 때까지 반복하는 방식으로 적절한 길이를 계산합니다. 


 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
def compute_indv_chunk_size(target_length, noise_density, mean_span_length):
    '''pre-corruption token length approximation for T5 and UL2'''
    def _tokens_length_to_inputs_length_targets_length(tokens_length):
        '''
            https://github.com/google-research/text-to-text-transfer-transformer/blob/c3be7cf1c20e5f6d83e6de99377b653a3a0bc44a/t5/data/preprocessors.py#L2648 
        '''
        # setting mean_span_length to None means prefix-lm that masks last 25% tokens
        num_noise_tokens = int(round(tokens_length * noise_density))
        num_nonnoise_tokens = tokens_length - num_noise_tokens
        num_noise_spans = 1 if mean_span_length is None else int(round(num_noise_tokens / mean_span_length))
        # inputs contain all nonnoise tokens, sentinels for all noise spans
        # and one EOS token.
        return (
            num_nonnoise_tokens +
            num_noise_spans + 1,
            num_noise_tokens +
            num_noise_spans + 1)
    
    tokens_length = target_length - 1
    while (_tokens_length_to_inputs_length_targets_length(tokens_length + 1)[0]
            <= target_length):
        tokens_length += 1
    inputs_length, targets_length = _tokens_length_to_inputs_length_targets_length (tokens_length)
    if mean_span_length is None:
        mean_span_length = targets_length - 2
    return tokens_length, targets_length, mean_span_length


2. Collator 구현하기

collator 구현에서 핵심은 noise mask를 만들고, 그 noise mask 부분을 적절한 sentinel로 치환하는 것입니다. 먼저 노이즈 마스크를 만드는 핵심은 noise span의 개수가 사전에 정해져 있다는 점과, non-noise 부분과 noise 부분이 번갈아가며 등장해야 한다는 것입니다. 따라서 노이즈 토큰과 non-noise 토큰을 정해진 span 개수만큼 random하게 쪼개고, non-noise 부분부터 번갈아가며 나오도록 interleave하는 방식을 사용합니다. 이제 만들어진 노이즈 마스크를 활용해서 sentinel token으로 치환해야 합니다. t5 tokenizer를 뜯어보면 <extra_id_0>이 가장 큰 id를 가지고 있고 id가 1씩 감소할 때마다 다음 extra_id가 등장하는 식입니다. 따라서 noise span을 만날 때마다 한개씩 감소시키는 방식으로 extra_id를 배치할 수 있습니다. ul2의 경우에는 더 극단적인 노이즈 함수를 사용하기 때문에 extra_id가 더 필요합니다. 이 경우 [new_id_1]~[new_id_27]을 만들어서 new_id_27이 가장 큰 id를 갖도록 구현하였습니다. 따라서 사용되는 token은 [new_id_27] ... [new_id_1] <extra_id_0> ... <extra_id_99> 순입니다.

UL2에서 원본 구현체와 다소 달라진 부분이 있는데, 원본 구현체는 [NLU], [NLG], [S2S]와 같은 토큰을 따로 특수토큰으로 처리하지 않고 사용하였습니다. 하지만 lassl에서 동일하게 진행할 경우 [NLU], [S2S]는 5개의 subword, [NLG]의 경우 4개의 subword로 쪼개지게 되어 로직에 문제가 생기는 상황이 있었습니다. 따라서 이 이슈를 방지하기 위해 각각의 task prefix를 특수토큰으로 추가하여 사용하였습니다.


 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
def random_spans_noise_mask(noise_density: float, mean_span_length : float, length : int) -> torch.BoolTensor:
    ''' pytorch-ported version of https://github.com/google-research/text-to-text-transfer-transformer/blob/bb545f19ec221e6203dd05505573fbc0c0a9001f/t5/data/preprocessors.py#L2901'''
    orig_len = length
    length = max(length, 2) # set minumum to 2 to avoid degeneracy
    num_noise_tokens = round(noise_density * length)
    num_noise_tokens = min(max(num_noise_tokens, 1), length-1) # set maximum to length-1 
    num_noise_spans = round(num_noise_tokens / mean_span_length)
    num_noise_spans = max(num_noise_spans, 1) # set minumum to 1
    num_nonnoise_tokens = length - num_noise_tokens

    def _random_segmentation(num_items, num_segments):
        # affected by global seed
        bars = torch.arange(num_items-1) < num_segments-1
        bars = bars[torch.randperm(bars.size(0))]
        bars = torch.cat((torch.tensor([0]), bars), dim=0) # to make segment 0 nonzero
        segment_id = torch.cumsum(bars, dim=0)
        segment_length = torch.zeros(num_segments, dtype=torch.long).scatter_add(0, segment_id, torch.ones_like(segment_id))
        return segment_length 
    
    noise_span_lengths = _random_segmentation(num_noise_tokens, num_noise_spans)
    nonnoise_span_lengths = _random_segmentation(num_nonnoise_tokens, num_noise_spans)
    interleaved_span_lengths = torch.stack((nonnoise_span_lengths, noise_span_lengths), dim=1).reshape(-1)
    span_starts = torch.cumsum(interleaved_span_lengths, dim=0)[:-1]
    span_start_indicator = torch.zeros(length).long().scatter(0, span_starts, torch.ones_like(span_starts))
    span_num = torch.cumsum(span_start_indicator, dim=0)
    is_noise = span_num % 2 == 1
    return is_noise[:orig_len]

def noise_span_to_unique_sentinel(tokenizer, tokens, noise_mask, append_last_sentinel=False, denoiser_prefix : Optional[str] = None, first_extra_id : str = "<extra_id_0>") -> torch.LongTensor:
    ''' pytorch-ported version of https://github.com/google-research/text-to-text-transfer-transformer/blob/bb545f19ec221e6203dd05505573fbc0c0a9001f/t5/data/preprocessors.py#L3074'''
    if not isinstance(tokens, torch.Tensor):
        tokens = torch.tensor(tokens)
    
    # sample consecutive substring from tokens if len(tokens) > len(noise_mask)
    # in case of T5, these two should match. In case of UL2, due to use of several denoisers, number of tokens could be larger than length of noise masks.
    if len(tokens) > len(noise_mask):
        offset = len(tokens) - len(noise_mask)
        random.seed(tokens[0].item()) # seed that makes same example to match in both making inputs and targets
        start_idx = random.randint(0,offset)
        tokens = tokens[start_idx : start_idx + len(noise_mask)]
        assert len(tokens) == len(noise_mask)

    prev_token_is_noise = torch.cat((torch.tensor([0]), noise_mask[:-1]), dim=0).bool()
    first_noise_tokens = torch.logical_and(
        noise_mask, torch.logical_not(prev_token_is_noise))
    subsequent_noise_tokens = torch.logical_and(noise_mask, prev_token_is_noise)
    sentinel = tokenizer.get_vocab()[first_extra_id] + 1 - torch.cumsum(first_noise_tokens.long(), dim=0)
    tokens = torch.where(first_noise_tokens, sentinel, tokens)
    ret = torch.masked_select(tokens, torch.logical_not(subsequent_noise_tokens))
    if append_last_sentinel: # target masking needs additional sentinel token at last position
        last_sentinel_id = sentinel.min().reshape(-1) - 1
        ret = torch.cat((ret, last_sentinel_id), dim=0)
    ret = torch.cat((ret, torch.tensor([tokenizer.eos_token_id], dtype=torch.long)), dim=0) # add eos token
    
    if denoiser_prefix:
    # used only for UL2, which prepends one of [S2S], [NLG], [NLU] during training. 
    # These tokens are not treated as special tokens but they are tokenized as normal tokens.
        denoiser_prefix_enc = torch.tensor(tokenizer.encode(denoiser_prefix)[:1], dtype=torch.long)
        ret = torch.cat((denoiser_prefix_enc, ret), dim=0)
    return ret

댓글

이 블로그의 인기 게시물

From In-context learning to RLHF (Feat. ChatGPT)

Wasserstein Auto-encoders (vs VAE)