Chatglm 论文地址:GLM: General Language Model Pretraining with Autoregressive Blank Infilling.
1. 摘要
- 预训练模型: 包括自编码模型(如 BERT)、自回归模型(如 GPT)和编码器 - 解码器模型(如 T5),都无法在 NLP 中 NLU, 无条件生成 和有条件生成 三大类任务上表现非常好。
- 本文提出了一种基于 自回归空白填充 的 General Language Model(GLM)模型来处理上面的挑战。
- GLM 的改进之处: GLM 通过引入 2D 位置编码 和允许以任意顺序预测片段 来改进空白填充预训练。在 NLU 任务上,GLM 相较于 BERT 和 T5 实现了性能提升。
2. GLM 预训练框架
自回归空白填充
GLM 通过优化自回归填充目标进行训练。给定输入文本 x=[x1,⋯,xn],会抽样出多个文本片段 s1,⋯,sm,其中每个片段 si 对应 x 一系列连续的标记 [si,1,⋯,si,li]。每个片段都用单个 [MASK] 标记替换,形成一个损坏的文本 xcorrupt。模型以自回归方式从损坏的文本中预测片段中的缺失标记,这意味着在预测片段中的缺失标记时,模型可以访问损坏的文本和先前预测的片段。为了充分捕捉不同片段之间的相互依赖关系,作者随机排列了片段的顺序,类似于排列语言模型(Yang 等,2019)。形式上,令 Zm 是长度为 m 的索引序列 [1,2,⋯,m] 的所有可能排列的集合,sz<i 就是 [sz1,⋯,szi−1],我们定义预训练目标为:
max
表示最大化给定输入 \boldsymbol{x}{\text{corrupt}} 和先前预测的片段 \mathbf{s}{z<i} 条件下对 s_{zi} 的概率的对数和。Z_m 是长度为 m 的索引序列的所有可能排列的集合。
论文采用以下技术实现了自回归的空白填充目标。输入 \boldsymbol{x} 被分为两部分:Part A 是损坏的文本 x_{\text{corrupt}},Part B 包含被 mask 的片段。Part A 中的 token 可以相互关注,但不能关注 B 中的任何 token。Part B 中的 token 可以关注 Part A 和 B 中的先行 token,但不能关注 B 中的任何后续 token。为了实现自回归生成,每个片段都用特殊标记 [START] 和[END]进行填充,分别用于输入和输出。这样,GLM 模型在一个统一的模型中自动学习了一个双向编码器(用于 Part A)和一个单向解码器(用于 Part B)。GLM 的实现如图 2 所示。

如上图 2 是 GLM 实现,最主要的部分。
a. 原始文本是 [x_1, x_2,x_3,x_4,x_5,x_6], 然后用采样得到两个片段 [x_3] 和 [x_5, x_6]
b. 原始的文本采样的部分用 [MASK] 代替,就是遮盖住,这就是 Part A.
采样的两个片段组成 Part B,再随机打乱 Part B。
c. GLM 自回归生成 Part B。每个片段用用 [START] 打头和 [END] 结尾。这就组成了整个输入和预测部分。如图 c 中所示,当输入蓝色的时候,要预测即输出就是黄色部分。怎么区分,使用 START 标记。
2D 的位置编码:
1. 区分位置和 span 的范围 | |
2. 区分预测和不同 span |
d. 自注意力掩码. 灰色表示看不到的。Part A 能看到自己但是不能看到 B. Part B 能看到自己前行的部分和 Part A.
3. tokenizer
chatglm3 的 tokenizer 位于:tokenization_chatglm.py.tokenizer 预训练后的文件就是tokenizer.model。
from sentencepiece import SentencePieceProcessor | |
spe_path = 'tokenizer.model' | |
spe = SentencePieceProcessor(model_file=spe_path) | |
print(spe.vocab_size())#64789 | |
print(spe.bos_id())#1 | |
print(spe.eos_id())#2 | |
print(spe.unk_id())#0 |
其中特殊的 token 如下:
[MASK]
: 跟普通的 bert 一样[gMASK]
: 后续为 auto-regressive 生成, 就是前面 Part A 后面接 Part B 时要加上。[sMASK]
: 句子级别 masksop
表示每个 auto-regressive 补全片段的开始,eop
表示补全片段的结束。
role_special_tokens = ["<|system|>", "<|user|>", "<|assistant|>", "<|observation|>"] | |
special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "sop", "eop"] + role_special_tokens |
核心 tokenizer 部分:
def tokenize(self, s: str, encode_special_tokens=False): | |
if encode_special_tokens: | |
# 如果 encode_special_tokens 为 True,则进行处理特殊标记的分词 | |
last_index = 0 | |
t = [] # 用于存储分词后的结果 | |
for match in re.finditer(self.role_special_token_expression, s): | |
# 使用正则表达式匹配特殊标记 | |
if last_index < match.start(): | |
# 将特殊标记之前的部分进行分词,并加入结果列表 t 中 | |
t.extend(self.sp_model.EncodeAsPieces(s[last_index:match.start()])) | |
# 将匹配到的特殊标记加入结果列表 t 中 | |
t.append(s[match.start():match.end()]) | |
last_index = match.end() | |
if last_index < len(s): | |
# 处理最后一个特殊标记之后的部分 | |
t.extend(self.sp_model.EncodeAsPieces(s[last_index:])) | |
return t | |
else: | |
# 如果 encode_special_tokens 为 False,则直接使用 sp_model 进行分词 | |
return self.sp_model.EncodeAsPieces(s) |
然后是 ChatGLMTokenizer,主要是 ChatGLM tokenizer 特殊 token 设置,token,id 转换,chat 结构设置。
4. model
1. rmsnorm
Bert 中使用的Layer Normalization 计算公式:
y = \frac{x – \text{Mean}(x)}{\sqrt{\text{Var}(x) + \epsilon}} \cdot W + B \tag{2}
RMSNorm 计算公式:
y = \frac{x}{\sqrt{\text{Var}(x) + \epsilon}} \cdot W \tag{3}
RMSNorm 是对 Layer Normalization 的简化形式,省略了均值的计算和平移的操作,而是使用每个位置上的方差进行缩放。这种归一化方法更适用于文本序列等应用。
chatglm 中 rmsnorm代码如下:
class RMSNorm(torch.nn.Module): | |
def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs): | |
super().__init__() | |
self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype)) | |
self.eps = eps | |
def forward(self, hidden_states: torch.Tensor): | |
input_dtype = hidden_states.dtype | |
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) | |
hidden_states = hidden_states * torch.rsqrt(variance + self.eps) | |
return (self.weight * hidden_states).to(input_dtype) |
Inference
[1] 清华团队发布 ChatGLM2-6B,该款版本有何亮点?
[2] 【报告笔记】大规模语言模型系列技术:以 GLM-130B 为例
[3] 【报告】从 GLM-130B 到 ChatGLM:大模型预训练与微调
[4] 如何正确的构建 input_ids、attention_mask、position_ids 和 labels@Porraio