1. 介绍
论文地址:Direct Preference Optimization: Your Language Model is Secretly a Reward Model.
代码仓库: direct-preference-optimization
主流的 LLM,如 LLama、ChatGLM、Qwen 的训练流程是预训练、SFT、RLHF 3 个步骤:
- Pretrain(预训练): 在该阶段,模型会在大量无标签的文本数据上进行训练,以学习语言的基本模式和结构。
- SFT(Supervised Fine-Tuning,监督微调):通过人类提供的监督信号,如标签或其他形式的指导,对模型进行微调以提高性能。
- RLHF(Reinforcement Learning from Human Feedback,基于人类反馈的强化学习): 在该阶段,模型会根据通过人类反馈收集的奖励信号进行进一步的微调。这一步骤的目标是使模型更好地满足用户的需求和期望。
相比 RLHF 方法需要先拟合一个奖励模型, 再使用强化学习来微调语言模型, 这个过程计算量大、训练时间长, 而且需要维护多个模型(奖励模型、策略模型等), 对显存要求较高。而 DPO 方法更直接:
- 不需要先拟合奖励模型, 而是直接优化语言模型本身, 使其生成的输出最符合人类的偏好数据。
- DPO 采用一个简单的分类目标, 隐式地拟合了一个奖励模型, 该模型的最优策略可以直接从优化后的语言模型中解析得到。
DPO 只需要两个 model, 而且一个是不更新权重,有点像知识蒸馏中的 tether 和 student。
论文中的 loss 如原文公式 7 所示:
$$
\mathcal{L}{\mathrm{DPO}}\left(\pi{\theta} ; \pi_{\mathrm{ref}}\right)=-\mathbb{E}{\left(x, y{w}, y_{l}\right) \sim \mathcal{D}}\left[\log \sigma\left(\beta \log \frac{\pi_{\theta}\left(y_{w} \mid x\right)}{\pi_{\mathrm{ref}}\left(y_{w} \mid x\right)}-\beta \log \frac{\pi_{\theta}\left(y_{l} \mid x\right)}{\pi_{\mathrm{ref}}\left(y_{l} \mid x\right)}\right)\right] \tag{7}
$$
这里,$\pi_{\theta}$ 就是优化的语言模型,$\pi_{\text{ref}}$ 是同样的模型但是不更新权重,只做推理。$w$ 是 win 的意思,就是偏好数据中较好的一条。假定 $\beta=1$,去掉外面的 sigmoid_log 就是:
$$
\left[\operatorname{log}\pi_{\theta}\left(y_{w}\right)-\log \pi_{\text{ref}}\left(y_{w}\right)\right]-\left[\operatorname{log}\pi_{\theta}\left(y_{l}\right)-\log \pi_{\text{ref}}\left(y_{l}\right)\right] \tag{l1}
$$
- 如果左边变大(奖励模型输出的好响应概率增加), 右边变小(对坏响应的概率降低), 这是理想情况。模型会更倾向于生成好的响应, 同时坏响应的概率降低。
- 如果左边变小(好响应概率下降), 右边变得更小(坏响应概率下降的幅度更大), 在生成时模型仍然会更倾向于生成好响应, 尽管好响应的绝对概率降低了。
- 如果左边变得更大(好响应概率大幅增加), 右边只变大了一点点(坏响应概率略微增加), 情况类似于 2), 生成时模型会更多地偏向于好响应。
总的来说, 不论奖励模型输出如何变化, 关键是好响应的概率相对于坏响应的概率更高, 模型在生成时就会更倾向于输出好的响应。即使好响应概率下降, 只要坏响应概率下降的幅度更大, 最终的生成结果仍会更好。具体参考——DPO: Direct Preference Optimization 论文解读及代码实践
2. loss 实现
class DPOLoss(nn.Module):
"""DPO loss function."""
def __init__(self, reduction="mean", beta=0.001):
super().__init__()
self.reduction = reduction
self.beta = beta
def forward(self, logits, logits_ref, labels, cu_lengths=None):
# if cu_lengths is None, assume that all examples belong to the same conversation
if cu_lengths is None:
cu_lengths = [0, logits.size(0)]
device = logits.device
losses = []
rewards = []
for start, end in zip(cu_lengths[:-1], cu_lengths[1:]):
pairs = torch.combinations(torch.arange(end - start, device=device), 2)
pos_ids, neg_ids = pairs[:, 0], pairs[:, 1]
# compute logprob of pos and neg examples
pos_logits = logits[start + pos_ids]
neg_logits = logits[start + neg_ids]
pos_logprob = F.log_softmax(pos_logits, dim=-1)
neg_logprob = F.log_softmax(neg_logits, dim=-1)
pos_logprob = torch.gather(pos_logprob, 2, labels[pos_ids].unsqueeze(-1))
neg_logprob = torch.gather(neg_logprob, 2, labels[neg_ids].unsqueeze(-1))
# we need to compute the logprob of the reference examples
pos_logits_ref = logits_ref[start + pos_ids]
neg_logits_ref = logits_ref[start + neg_ids]
pos_logprob_ref = F.log_softmax(pos_logits_ref, dim=-1)
neg_logprob_ref = F.log_softmax(neg_logits_ref, dim=-1)
pos_logprob_ref = torch.gather(pos_logprob_ref, 2, labels[pos_ids].unsqueeze(-1))
neg_logprob_ref = torch.gather(neg_logprob_ref, 2, labels[neg_ids].unsqueeze(-1))
# compute loss and reward
pi_logratios = pos_logprob.mean() - neg_logprob.mean()
ref_logratios = pos_logprob_ref.mean() - neg_logprob_ref.mean()
loss = -F.logsigmoid(self.beta * (pi_logratios - ref_logratios))
reward = self.beta * (pi_logratios - ref_logratios)
losses.append(loss)
rewards.append(reward)
return sum(losses)/len(losses), sum(rewards)/len(rewards)
DPOLoss
解释,代码来源。
__init__
方法初始化了reduction
参数 (决定如何减少损失值的维度) 和beta
参数(一个缩放系数)。forward
方法是核心逻辑所在。它接收以下输入:
logits
: 当前模型输出的 logitslogits_ref
: 参考模型输出的 logitslabels
: 正确的标签cu_lengths
: 可选参数, 表示不同对话上下文的长度
- 如果
cu_lengths
为None
, 则假设所有样本属于同一个对话上下文。 - 对于每个对话上下文, 代码计算该上下文中所有可能的正例对 (
pos_ids
) 和负例对 (neg_ids
) 的组合。 - 计算正例和负例在当前模型和参考模型下的对数概率:
pos_logprob
、neg_logprob
、pos_logprob_ref
和neg_logprob_ref
。 - 计算当前模型和参考模型在正例和负例上的对数概率比值:
pi_logratios
和ref_logratios
。 - 计算损失为:
-F.logsigmoid(beta * (pi_logratios - ref_logratios))
- 这里的目标是最大化
pi_logratios - ref_logratios
, 即让当前模型在正例上的概率相对于负例的概率, 相比参考模型有所提高。 logsigmoid
函数用于将这个目标值约束在 (0,1) 范围内, 并取负值作为损失。
- 计算奖励为:
beta * (pi_logratios - ref_logratios)
- 对所有上下文的损失和奖励取平均, 作为最终的输出。
3. DPO 训练 Mistralx7B
这里数据采用 comparison_gpt4_data_zh 部分数据训练,使用 trl 的框架进行训练from trl import DPOrainer
。
原始数据如下:
{'instruction': '什么是三原色?', 'input': '','output': [' 三原色是红、蓝、黄。这些颜色被称为原色,因为它们不能通过混合其他颜色来创建,所有其他颜色都可以通过将它们按不同比例组合而成。在用于光的加色系统中,原色是红色、绿色和蓝色 (RGB)。',' 红色、黄色和绿色。']}
要转换为如下格式输入:
<|im_start|>system
You are a helpful chatbot that will do its best not to say anything so stupid that people tweet about it.<|im_end|>
<|im_start|>user
How are you?<|im_end|>
<|im_start|>assistant
I'm doing great!<|im_end|>
用这个方法转换一下:
def chatml_format(example):
# 格式化输入
if 'input' in example and len(example['input']) > 0:
message = {"role": "system", "content": example['input']}
system = tokenizer.apply_chat_template([message], tokenize=False)
else:
system = ""
# 格式化指令
if 'instruction' in example:
message = {"role": "user", "content": example['instruction']}
prompt = tokenizer.apply_chat_template([message], tokenize=False, add_generation_prompt=True)
else:
prompt = ""
# 格式化输出
if 'output' in example:
chosen = example['output'][0] + "<|im_end|>\n"
if len(example['output']) > 1:
rejected = example['output'][1] + "<|im_end|>\n"
else:
rejected = "<|im_end|>\n"
else:
chosen = "<|im_end|>\n"
rejected = "<|im_end|>\n"
return {
"prompt": system + prompt,
"chosen": chosen,
"rejected": rejected,
}
取两条大概样子如下:
{'prompt': ['<|im_start|>user\n 什么是三原色?<|im_end|>\n<|im_start|>assistant\n',
'<|im_start|>user\n 解释为什么下面的分数等于 1/4\n4/16<|im_end|>\n<|im_start|>assistant\n'],
'chosen': ['三原色是红、蓝、黄。这些颜色被称为原色,因为它们不能通过混合其他颜色来创建,所有其他颜色都可以通过将它们按不同比例组合而成。在用于光的加色系统中,原色是红色、绿色和蓝色 (RGB)。<|im_end|>\n',
'分数 4/16 等于 1/4,因为分子和分母都可以被 4 整除。将顶部和底部数字都除以 4 得到分数 1/4。<|im_end|>\n'],
'rejected': ['红色、黄色和绿色。<|im_end|>\n', '1/4 与 1/4 相同。<|im_end|>\n']}
DPO 训练需要两个模型:
model
: 这是要被微调的主模型, 在训练过程中, 该模型的参数会根据损失函数进行更新。ref_model
: 这是一个参考模型, 通常与model
具有相同的体系结构和初始权重。它被用于计算与目标输出的对数似然概率, 而不会被训练或更新参数。
在 DPO 训练中, 会先使用 ref_model
生成参考输出的对数似然概率 (ref_logits
)。然后,model
会生成它自己的输出对数似然概率(model_logits
)。
接下来, 使用一个损失函数 (如交叉熵损失) 来衡量 model_logits
与目标输出的差异, 并将这个损失值与 ref_logits
结合, 计算最终的 DPO 损失。
通过最小化这个 DPO 损失,model
可以被引导产生更接近目标输出的预测, 同时也尽量不偏离 ref_model
的初始预测, 从而达到微调的目的。
因此,ref_model
扮演了一个提供参考输出概率分布的角色, 用于计算损失和引导主模型 model
的训练, 而它自身的参数在训练过程中保持不变。
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float16,
load_in_4bit=True
)
model.config.use_cache = False
# Reference model
ref_model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float16,
load_in_4bit=True
)
设置好 DPOtrainer 就可以训练了:
# LoRA configuration
peft_config = LoraConfig(
r=16,
lora_alpha=16,
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
target_modules=['k_proj', 'gate_proj', 'v_proj', 'up_proj', 'q_proj', 'o_proj', 'down_proj']
)
# Model to fine-tune
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float16,
load_in_4bit=True
)
model.config.use_cache = False
# Reference model
ref_model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float16,
load_in_4bit=True
)
# Training arguments
training_args = TrainingArguments(
per_device_train_batch_size=4,
gradient_accumulation_steps=4,
gradient_checkpointing=True,
learning_rate=5e-5,
lr_scheduler_type="cosine",
max_steps=200,
save_strategy="no",
logging_steps=1,
output_dir=new_model,
optim="paged_adamw_32bit",
warmup_steps=100,
bf16=True,
# report_to="wandb",
report_to="tensorboard",
)
dpo_trainer = DPOTrainer(
model,
ref_model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
tokenizer=tokenizer,
peft_config=peft_config,
beta=0.1,
max_prompt_length=1024,
max_length=1536,
force_use_ref_model=True
)
dpo_trainer.train()
Inference
[1] DPO(Direct Preference Optimization):LLM 的直接偏好优化
[3] LLM DPO