对 ChatGLM-6B 做 LoRA Fine-tuning

    • 搭建依赖环境
    • 加载模型和 Tokenizer
    • 分析模型结构
    • 配置 LoRA
    • 构建数据集
      • 定义常量
      • 测试 Tokenizer 的编解码
      • 定义 Prompt
      • 构建 Attention Mask 和 Position IDs
      • 创建数据集
    • 开始训练
    • 预测
    • 保存训练模型
    • 重载训练后的模型

ChatGLM-6B 是一个支持中英双语的对话语言模型,基于 GLM (General Language Model)。它只有 62 亿个参数,量化后最低 (INT4 量化) 只需要 6GB 的显存,完全可以部署到消费级显卡上。在实际使用这个模型一段时间以后,我们发现模型的对话表现能力确实非常不错。那么,基于这个模型做 Fine-tuning 就非常有价值了。

声明:

本文提供的所有技术信息,都基于 THUDM/chatglm-6b 的历史版本:
096f3de6b4959ce38bef7bb05f3129c931a3084e

源码地址:

  • GitHub
  • gitee

搭建依赖环境

安装 PyTorch 环境:

pip install torch torchvision torchaudio

按照 ChatGLM-6B 的官方指导,安装软件依赖环境:

pip install protobuf==3.20.0 transformers==4.26.1 icetk cpm_kernels

为了做 LoRA,还要安装 peft

pip install peft

加载模型和 Tokenizer

from transformers import AutoTokenizer, AutoModelcheckpoint = "THUDM/chatglm-6b"revision = "096f3de6b4959ce38bef7bb05f3129c931a3084e"model = AutoModel.from_pretrained(checkpoint, revision=revision, trust_remote_code=True)tokenizer = AutoTokenizer.from_pretrained(checkpoint, revision=revision, trust_remote_code=True)

分析模型结构

模型加载完后,我们可以打印这个 modeltokenizer,建立对模型的基本认知。

首先打印model

print(model)

得到如下结果:

ChatGLMForConditionalGeneration((transformer): ChatGLMModel((word_embeddings): Embedding(150528, 4096)(layers): ModuleList((0-27): 28 x GLMBlock((input_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)(attention): SelfAttention((rotary_emb): RotaryEmbedding()(query_key_value): Linear(in_features=4096, out_features=12288, bias=True)(dense): Linear(in_features=4096, out_features=4096, bias=True))(post_attention_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)(mlp): GLU((dense_h_to_4h): Linear(in_features=4096, out_features=16384, bias=True)(dense_4h_to_h): Linear(in_features=16384, out_features=4096, bias=True))))(final_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True))(lm_head): Linear(in_features=4096, out_features=150528, bias=False))

简单分析这个模型结构,至少可以得到如下一些信息:

  • 模型使用了 Transformer 结构,因此可以使用 LoRA 进行 Fine-tuning
  • 从 Word Embedding 层可以看出,词汇表大小是 150528
  • LoRA 可以操作的目标是:query_key_value

再打印tokenizer:

print(tokenizer)

得到如下结果(为了便于阅读,已对结果做了分行处理):

ChatGLMTokenizer(name_or_path='THUDM/chatglm-6b', vocab_size=150344, model_max_length=2048, is_fast=False, padding_side='left', truncation_side='right', special_tokens={'bos_token': '', 'eos_token': '', 'unk_token': '', 'pad_token': '', 'mask_token': '[MASK]'})

这里有几个可以关注的点:

  • 词汇表大小vocab_size150344
  • 不是一个 fast Tokenizer(is_fast 的值是 False
  • 特殊 token 包括:bos eos padmask

为什么 model 中的词汇表大小是 150528,而 tokenizer 中定义的词汇表大小却是 150344 呢?读者可以带着这个疑问去读一读模型项目的源码,看看能不能找到答案。

配置 LoRA

借助 peft 库,我们可以很方便地对模型注入 LoRA。

from peft import LoraConfig, get_peft_model, TaskTypedef load_lora_config(model):config = LoraConfig(task_type=TaskType.CAUSAL_LM, inference_mode=False,r=8, lora_alpha=32, lora_dropout=0.1,target_modules=["query_key_value"])return get_peft_model(model, config)model = load_lora_config(model)

打印可训练的参数量:

model.print_trainable_parameters()

得到如下结果:

trainable params: 3670016 || all params: 6258876416 || trainable%: 0.05863697820615348

可以看到,总的参数量是 6,258,876,416,可训练的参数量是 3,670,016,占比 0.0586% 左右。训练参数量只是百万级别的,可谓相当友好了!另外需要注意的一点是,ChatGLM-6B 是一个因果语言模型 (Causal Language Model),因此我们这里选择的任务类型是 CAUSAL_LM

构建数据集

定义常量

构建之前,我们先定义几个特殊 Token 常量:

bos = tokenizer.bos_token_ideop = tokenizer.eop_token_idpad = tokenizer.pad_token_idmask = tokenizer.mask_token_idgmask = tokenizer.sp_tokenizer[tokenizer.gMASK_token]

将这几个值打印出来:

print("bos = ", bos)print("eop = ", eop)print("pad = ", pad)print("mask = ", mask)print("gmask = ", gmask)

得到如下结果:

bos =150004eop =150005pad =20003mask =150000gmask =150001

我们也可以直接用这个常量结果替换动态计算的部分。常量修改后的结果变成:

bos = 150004eop = 150005pad = 20003mask = 150000gmask = 150001

除了上面定义的 Token 常量,我们还需要定义模型训练绑定的设备名,以及最大输入长度和最大输出长度等,如下:

device = "cuda"max_src_length = 200max_dst_length = 500

开发者可以结合自己的显卡性能和要处理的数据集特点来确定这些最大长度。

测试 Tokenizer 的编解码

我们可以先做个简单的测试:

text = "AI探险家"print(tokenizer.encode(text, add_special_tokens = True))print(tokenizer.encode(text, add_special_tokens = False))

输出结果是:

[26738, 98715, 83920, 150001, 150004][26738, 98715, 83920]

从这个结果可以看出,“AI探险家”这几个字的裸编码是 [26738, 98715, 83920]。为什么是这样呢?我们可以对每一个数值再解码,看看输出结果:

print(tokenizer.decode([26738]))print(tokenizer.decode([98715]))print(tokenizer.decode([83920]))

输出结果是:

AI探险家

观察这个结果,读者应该能对词汇表建立基本的认知了。读者如果有兴趣,还可以分别针对 “A” “I” “探” “险” 这几个字分别编码,看看编码结果是什么。

另外,当 add_special_tokens = True 时,编码结果会在末尾添加 150001150004,也就是 gmaskbos。请注意,我们的训练数据,要按照如下编码要求进行构造:

[token, ..., token, gmask, bos, token, ... token, eop]

因此,前半部分文本的编码可以直接让 add_special_tokens = True,后半部分文本的编码则让 add_special_tokens = False,最后再拼接一个 eop

定义 Prompt

我们 Fine-tuning 的任务是问答任务(简称 QA),因此一个简单的 Prompt 是这样的:

PROMPT_PATTERN = "问:{}\n答: "

{}里填入 QA 训练集的问题文本。在显存有限的情况下,如果不对长文本做限制处理,很容易出现类似 CUDA out of memory 这样的报错。处理长文本,在给定编码后的数组上限时,可能存在这么几种方式:

  • 截断末尾超出部分的编码
  • 截断前面超出部分的编码
  • 丢掉训练样本

每一种方式都有各自的优劣,开发者可以根据自身数据的特点自行选择一种处理方式。当然,如果你的显存够大,也可以不处理。本文以上述第一种方式进行处理。
为了不把 PROMPT_PATTERN 中的 \n答: 这几个字截断掉,我们将整个 PROMPT_PATTERN 拆成两部分:

PROMPT_PATTERN = "问:{}"SEP_PATTERN = "\n答: "

基于这份 Prompt 模板,我们定义下面三个辅助方法:

def create_prompt(question):return PROMPT_PATTERN.format(question), SEP_PATTERNdef create_prompt_ids(tokenizer, question, max_src_length):prompt, sep = create_prompt(question)sep_ids = tokenizer.encode(sep, add_special_tokens = True)sep_len = len(sep_ids)special_tokens_num = 2prompt_ids = tokenizer.encode(prompt, max_length = max_src_length - (sep_len - special_tokens_num),truncation = True,add_special_tokens = False)return prompt_ids + sep_idsdef create_inputs_and_labels(tokenizer, question, answer, device):prompt = create_prompt_ids(tokenizer, question, max_src_length)completion = tokenizer.encode(answer, max_length = max_dst_length,truncation = True,add_special_tokens = False)inputs = prompt + completion + [eop]labels = [-100] * len(prompt) + completion + [eop] inputs = torch.tensor(inputs, dtype=torch.long, device=device)labels = torch.tensor(labels, dtype=torch.long, device=device)return inputs, labels

值得注意的两点:

  • create_prompt_ids 这个函数实现可以看出,我们编码分隔符 SEP_PATTERN 时自动添加了前面所述的 2 个特殊 Token。
  • create_inputs_and_labels 的函数实现中,我们将 labels 无需处理的部分用数值 -100 来表示。因为 ChatGLMForConditionalGeneration 内部在计算损失函数的时候,用的是 torch.nn.CrossEntropyLoss。该函数的参数之一 ignore_index 默认值是 -100。这就让我们在计算损失函数时,无需考虑非标识部分的数值。

构建 Attention Mask 和 Position IDs

def get_attention_mask(tokenizer, input_ids, device):seq = input_ids.tolist()context_len = seq.index(bos)seq_len = len(seq)attention_mask = torch.ones((seq_len, seq_len), device=device)attention_mask.tril_()attention_mask[..., :context_len] = 1attention_mask.unsqueeze_(0)attention_mask = (attention_mask < 0.5).bool()return attention_maskdef get_position_ids(tokenizer, input_ids, device, position_encoding_2d=True):seq = input_ids.tolist()context_len = seq.index(bos)seq_len = len(seq)mask_token = mask if mask in seq else gmaskuse_gmask = False if mask in seq else gmaskmask_position = seq.index(mask_token)if position_encoding_2d:position_ids = torch.arange(seq_len, dtype=torch.long, device=device)if not use_gmask:position_ids[context_len:] = mask_positionblock_position_ids = torch.cat((torch.zeros(context_len, dtype=torch.long, device=device),torch.arange(seq_len - context_len, dtype=torch.long, device=device) + 1))position_ids = torch.stack((position_ids, block_position_ids), dim=0)else:position_ids = torch.arange(seq_len, dtype=torch.long, device=device)if not use_gmask:position_ids[context_len:] = mask_positionreturn position_ids

在这个通用实现中,我们针对 maskgmask 两种情况做了区分,同时也对是否执行 position_encoding_2d 分情况处理。本文的 QA 任务采用的是 gmask,并且使用 position_encoding_2d = True

我们可以构建下面的问答,来验证下这几个函数的输出:

test_data = {"question": "AI探险家帅不帅?","answer": "非常帅!"}inputs, labels = create_inputs_and_labels(tokenizer, **test_data, device=device)attention_mask = get_attention_mask(tokenizer, inputs, device=device)position_ids = get_position_ids(tokenizer, inputs, device=device)print("inputs: \n", inputs.tolist())print("\nlabels: \n", labels.tolist())print("\nposition_ids: \n", position_ids.tolist())print("\nattention_mask: \n", attention_mask.tolist())

输出结果(为了便于阅读,已对输出进行格式化操作):

inputs:[20005, 84286, 20012, 31943, 98715, 83920, 87359, 83848, 87359, 20031, 20005, 20004, 87342, 20012, 150001, 150004, 20005, 84122, 87359, 20035, 150005]labels:[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 20005, 84122, 87359, 20035, 150005]position_ids:[ [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20],[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,0,0,0,0,0,1,2,3,4,5] ]attention_mask:[[ [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True],[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True],[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True],[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True],[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True],[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True],[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True],[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True],[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True],[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True],[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True],[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True],[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True],[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True],[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True],[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True],[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True, True],[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True, True],[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, True],[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True],[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False]]]

结合论文观察数据,基本符合预期。

创建数据集

我们先定义具有如下格式的训练数据:

train_data = [{"question": "问题1", "answer": "答案1"},{"question": "问题2", "answer": "答案2"},]

定义好格式后,我们先创建一个 QADataset 类,如下:

from torch.utils.data import Datasetclass QADataset(Dataset):def __init__(self, data, tokenizer) -> None:super().__init__()self.data = dataself.tokenizer = tokenizer def __getitem__(self, index):item_data = self.data[index]tokenizer = self.tokenizerinput_ids, labels = create_inputs_and_labels(tokenizer, device=device,**item_data)attention_mask = get_attention_mask(tokenizer, input_ids, device)position_ids = get_position_ids(tokenizer, input_ids, device)return {"input_ids": input_ids,"labels": labels,"attention_mask": attention_mask,"position_ids": position_ids}def __len__(self):return len(self.data)

然后创建一个 Data Collator:

def collate_fn(batch):input_ids = []attention_mask = []labels = []position_ids = []for obj in batch:input_ids.append(obj['input_ids'])labels.append(obj['labels'])attention_mask.append(obj['attention_mask'])position_ids.append(obj['position_ids'])return {'input_ids': torch.stack(input_ids),'attention_mask': torch.stack(attention_mask), 'labels': torch.stack(labels),'position_ids':torch.stack(position_ids)}

开始训练

from transformers import TrainingArguments, Trainermodel.to(device)training_args = TrainingArguments("output",fp16 =True,save_steps = 500,save_total_limit = 3,gradient_accumulation_steps=1,per_device_train_batch_size = 1,learning_rate = 1e-4,max_steps=1500,logging_steps=50,remove_unused_columns=False,seed=0,data_seed=0,group_by_length=False,dataloader_pin_memory=False)class ModifiedTrainer(Trainer):def compute_loss(self, model, inputs, return_outputs=False):return model(input_ids=inputs["input_ids"],attention_mask=inputs["attention_mask"],position_ids=inputs["position_ids"],labels=inputs["labels"],).losstrain_dataset = QADataset(train_data, tokenizer=tokenizer)trainer = ModifiedTrainer(model=model,train_dataset=train_dataset,args=training_args,data_collator=collate_fn,tokenizer=tokenizer)trainer.train()

预测

response, history = model.chat(tokenizer, "AI探险家的颜值如何?", history=[])print(response)

保存训练模型

import osdef save_tuned_parameters(model, path):saved_params = {k: v.to(device)for k, v in model.named_parameters()if v.requires_grad}torch.save(saved_params, path)save_tuned_parameters(model, os.path.join("/path/to/output", "chatglm-6b-lora.pt"))

重载训练后的模型

checkpoint = "THUDM/chatglm-6b"revision = "096f3de6b4959ce38bef7bb05f3129c931a3084e"model = AutoModel.from_pretrained(checkpoint, revision=revision, trust_remote_code=True)tokenizer = AutoTokenizer.from_pretrained(checkpoint, revision=revision, trust_remote_code=True)model = load_lora_config(model)model.load_state_dict(torch.load(f"/path/to/output/chatglm-6b-lora.pt"), strict=False)model.half().cuda().eval()response, history = model.chat(tokenizer, "AI探险家的颜值如何?", history=[])print(response)