DPO 训练

1. 介绍

论文地址:Direct Preference Optimization: Your Language Model is Secretly a Reward Model.

代码仓库: direct-preference-optimization

DPO 训练

主流的 LLM,如 LLama、ChatGLM、Qwen 的训练流程是预训练、SFT、RLHF 3 个步骤:

  1. Pretrain(预训练): 在该阶段,模型会在大量无标签的文本数据上进行训练,以学习语言的基本模式和结构。
  2. SFT(Supervised Fine-Tuning,监督微调):通过人类提供的监督信号,如标签或其他形式的指导,对模型进行微调以提高性能。
  3. RLHF(Reinforcement Learning from Human Feedback,基于人类反馈的强化学习): 在该阶段,模型会根据通过人类反馈收集的奖励信号进行进一步的微调。这一步骤的目标是使模型更好地满足用户的需求和期望。

相比 RLHF 方法需要先拟合一个奖励模型, 再使用强化学习来微调语言模型, 这个过程计算量大、训练时间长, 而且需要维护多个模型(奖励模型、策略模型等), 对显存要求较高。而 DPO 方法更直接:

  1. 不需要先拟合奖励模型, 而是直接优化语言模型本身, 使其生成的输出最符合人类的偏好数据。
  2. DPO 采用一个简单的分类目标, 隐式地拟合了一个奖励模型, 该模型的最优策略可以直接从优化后的语言模型中解析得到。

DPO 只需要两个 model, 而且一个是不更新权重,有点像知识蒸馏中的 tether 和 student。

论文中的 loss 如原文公式 7 所示:
$$
\mathcal{L}{\mathrm{DPO}}\left(\pi{\theta} ; \pi_{\mathrm{ref}}\right)=-\mathbb{E}{\left(x, y{w}, y_{l}\right) \sim \mathcal{D}}\left[\log \sigma\left(\beta \log \frac{\pi_{\theta}\left(y_{w} \mid x\right)}{\pi_{\mathrm{ref}}\left(y_{w} \mid x\right)}-\beta \log \frac{\pi_{\theta}\left(y_{l} \mid x\right)}{\pi_{\mathrm{ref}}\left(y_{l} \mid x\right)}\right)\right] \tag{7}
$$

这里,$\pi_{\theta}$ 就是优化的语言模型,$\pi_{\text{ref}}$ 是同样的模型但是不更新权重,只做推理。$w$ 是 win 的意思,就是偏好数据中较好的一条。假定 $\beta=1$,去掉外面的 sigmoid_log 就是:
$$
\left[\operatorname{log}\pi_{\theta}\left(y_{w}\right)-\log \pi_{\text{ref}}\left(y_{w}\right)\right]-\left[\operatorname{log}\pi_{\theta}\left(y_{l}\right)-\log \pi_{\text{ref}}\left(y_{l}\right)\right] \tag{l1}
$$

  1. 如果左边变大(奖励模型输出的好响应概率增加), 右边变小(对坏响应的概率降低), 这是理想情况。模型会更倾向于生成好的响应, 同时坏响应的概率降低。
  2. 如果左边变小(好响应概率下降), 右边变得更小(坏响应概率下降的幅度更大), 在生成时模型仍然会更倾向于生成好响应, 尽管好响应的绝对概率降低了。
  3. 如果左边变得更大(好响应概率大幅增加), 右边只变大了一点点(坏响应概率略微增加), 情况类似于 2), 生成时模型会更多地偏向于好响应。

总的来说, 不论奖励模型输出如何变化, 关键是好响应的概率相对于坏响应的概率更高, 模型在生成时就会更倾向于输出好的响应。即使好响应概率下降, 只要坏响应概率下降的幅度更大, 最终的生成结果仍会更好。具体参考——DPO: Direct Preference Optimization 论文解读及代码实践

2. loss 实现

class DPOLoss(nn.Module):
    """DPO loss function."""

    def __init__(self, reduction="mean", beta=0.001):
        super().__init__()
        self.reduction = reduction
        self.beta = beta

    def forward(self, logits, logits_ref, labels, cu_lengths=None):
        # if cu_lengths is None, assume that all examples belong to the same conversation
        if cu_lengths is None:
            cu_lengths = [0, logits.size(0)]

        device = logits.device
        losses = []
        rewards = []
        for start, end in zip(cu_lengths[:-1], cu_lengths[1:]):
            pairs = torch.combinations(torch.arange(end - start, device=device), 2)
            pos_ids, neg_ids = pairs[:, 0], pairs[:, 1]

            # compute logprob of pos and neg examples
            pos_logits = logits[start + pos_ids]
            neg_logits = logits[start + neg_ids]

            pos_logprob = F.log_softmax(pos_logits, dim=-1)
            neg_logprob = F.log_softmax(neg_logits, dim=-1)

            pos_logprob = torch.gather(pos_logprob, 2, labels[pos_ids].unsqueeze(-1))
            neg_logprob = torch.gather(neg_logprob, 2, labels[neg_ids].unsqueeze(-1))

            # we need to compute the logprob of the reference examples
            pos_logits_ref = logits_ref[start + pos_ids]
            neg_logits_ref = logits_ref[start + neg_ids]

            pos_logprob_ref = F.log_softmax(pos_logits_ref, dim=-1)
            neg_logprob_ref = F.log_softmax(neg_logits_ref, dim=-1)

            pos_logprob_ref = torch.gather(pos_logprob_ref, 2, labels[pos_ids].unsqueeze(-1))
            neg_logprob_ref = torch.gather(neg_logprob_ref, 2, labels[neg_ids].unsqueeze(-1))

            # compute loss and reward
            pi_logratios = pos_logprob.mean() - neg_logprob.mean()
            ref_logratios = pos_logprob_ref.mean() - neg_logprob_ref.mean()

            loss = -F.logsigmoid(self.beta * (pi_logratios - ref_logratios))
            reward = self.beta * (pi_logratios - ref_logratios)

            losses.append(loss)
            rewards.append(reward)

        return sum(losses)/len(losses), sum(rewards)/len(rewards)

DPOLoss解释,代码来源

  1. __init__方法初始化了 reduction 参数 (决定如何减少损失值的维度) 和beta参数(一个缩放系数)。
  2. forward方法是核心逻辑所在。它接收以下输入:
  • logits: 当前模型输出的 logits
  • logits_ref: 参考模型输出的 logits
  • labels: 正确的标签
  • cu_lengths: 可选参数, 表示不同对话上下文的长度
  1. 如果 cu_lengthsNone, 则假设所有样本属于同一个对话上下文。
  2. 对于每个对话上下文, 代码计算该上下文中所有可能的正例对 (pos_ids) 和负例对 (neg_ids) 的组合。
  3. 计算正例和负例在当前模型和参考模型下的对数概率:pos_logprobneg_logprobpos_logprob_refneg_logprob_ref
  4. 计算当前模型和参考模型在正例和负例上的对数概率比值:pi_logratiosref_logratios
  5. 计算损失为:-F.logsigmoid(beta * (pi_logratios - ref_logratios))
  • 这里的目标是最大化pi_logratios - ref_logratios, 即让当前模型在正例上的概率相对于负例的概率, 相比参考模型有所提高。
  • logsigmoid函数用于将这个目标值约束在 (0,1) 范围内, 并取负值作为损失。
  1. 计算奖励为:beta * (pi_logratios - ref_logratios)
  2. 对所有上下文的损失和奖励取平均, 作为最终的输出。

3. DPO 训练 Mistralx7B

这里数据采用 comparison_gpt4_data_zh 部分数据训练,使用 trl 的框架进行训练from trl import DPOrainer

原始数据如下:

{'instruction': '什么是三原色?', 'input': '','output': [' 三原色是红、蓝、黄。这些颜色被称为原色,因为它们不能通过混合其他颜色来创建,所有其他颜色都可以通过将它们按不同比例组合而成。在用于光的加色系统中,原色是红色、绿色和蓝色 (RGB)。',' 红色、黄色和绿色。']}

要转换为如下格式输入:

<|im_start|>system
You are a helpful chatbot that will do its best not to say anything so stupid that people tweet about it.<|im_end|>
<|im_start|>user
How are you?<|im_end|>
<|im_start|>assistant
I'm doing great!<|im_end|>

用这个方法转换一下:

def chatml_format(example):
    # 格式化输入
    if 'input' in example and len(example['input']) > 0:
        message = {"role": "system", "content": example['input']}
        system = tokenizer.apply_chat_template([message], tokenize=False)
    else:
        system = ""

    # 格式化指令
    if 'instruction' in example:
        message = {"role": "user", "content": example['instruction']}
        prompt = tokenizer.apply_chat_template([message], tokenize=False, add_generation_prompt=True)
    else:
        prompt = ""

    # 格式化输出
    if 'output' in example:
        chosen = example['output'][0] + "<|im_end|>\n"
        if len(example['output']) > 1:
            rejected = example['output'][1] + "<|im_end|>\n"
        else:
            rejected = "<|im_end|>\n"
    else:
        chosen = "<|im_end|>\n"
        rejected = "<|im_end|>\n"

    return {
        "prompt": system + prompt,
        "chosen": chosen,
        "rejected": rejected,
    }

取两条大概样子如下:

{'prompt': ['<|im_start|>user\n 什么是三原色?<|im_end|>\n<|im_start|>assistant\n',
  '<|im_start|>user\n 解释为什么下面的分数等于 1/4\n4/16<|im_end|>\n<|im_start|>assistant\n'],
 'chosen': ['三原色是红、蓝、黄。这些颜色被称为原色,因为它们不能通过混合其他颜色来创建,所有其他颜色都可以通过将它们按不同比例组合而成。在用于光的加色系统中,原色是红色、绿色和蓝色 (RGB)。<|im_end|>\n',
  '分数 4/16 等于 1/4,因为分子和分母都可以被 4 整除。将顶部和底部数字都除以 4 得到分数 1/4。<|im_end|>\n'],
 'rejected': ['红色、黄色和绿色。<|im_end|>\n', '1/4 与 1/4 相同。<|im_end|>\n']}

DPO 训练需要两个模型:

  1. model: 这是要被微调的主模型, 在训练过程中, 该模型的参数会根据损失函数进行更新。
  2. ref_model: 这是一个参考模型, 通常与 model 具有相同的体系结构和初始权重。它被用于计算与目标输出的对数似然概率, 而不会被训练或更新参数。

在 DPO 训练中, 会先使用 ref_model 生成参考输出的对数似然概率 (ref_logits)。然后,model 会生成它自己的输出对数似然概率(model_logits)。

接下来, 使用一个损失函数 (如交叉熵损失) 来衡量 model_logits 与目标输出的差异, 并将这个损失值与 ref_logits 结合, 计算最终的 DPO 损失。

通过最小化这个 DPO 损失,model可以被引导产生更接近目标输出的预测, 同时也尽量不偏离 ref_model 的初始预测, 从而达到微调的目的。

因此,ref_model扮演了一个提供参考输出概率分布的角色, 用于计算损失和引导主模型 model 的训练, 而它自身的参数在训练过程中保持不变。

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    load_in_4bit=True
)
model.config.use_cache = False

# Reference model
ref_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    load_in_4bit=True
)

设置好 DPOtrainer 就可以训练了:

# LoRA configuration
peft_config = LoraConfig(
    r=16,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=['k_proj', 'gate_proj', 'v_proj', 'up_proj', 'q_proj', 'o_proj', 'down_proj']
)

# Model to fine-tune
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    load_in_4bit=True
)
model.config.use_cache = False

# Reference model
ref_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    load_in_4bit=True
)

# Training arguments
training_args = TrainingArguments(
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    learning_rate=5e-5,
    lr_scheduler_type="cosine",
    max_steps=200,
    save_strategy="no",
    logging_steps=1,
    output_dir=new_model,
    optim="paged_adamw_32bit",
    warmup_steps=100,
    bf16=True,
#     report_to="wandb",
    report_to="tensorboard",

)
dpo_trainer = DPOTrainer(
    model,
    ref_model,
    args=training_args,
    train_dataset=train_dataset, 
    eval_dataset=eval_dataset, 
    tokenizer=tokenizer,
    peft_config=peft_config,
    beta=0.1,
    max_prompt_length=1024,
    max_length=1536,
    force_use_ref_model=True
)
dpo_trainer.train()

Inference

[1] DPO(Direct Preference Optimization):LLM 的直接偏好优化

[2] 使用 DPO 微调 Llama 2

[3] LLM DPO

[4] fine-tuning-mistral-7b-with-dpo

[5] dpo-in-kaggle-working)

正文完
 
admin
版权声明:本站原创文章,由 admin 2024-03-23发表,共计7308字。
转载说明:除特殊说明外本站文章皆由CC-4.0协议发布,转载请联系tensortimes@gmail.com。