Fine-tuning FlagEmbedding

1. Build the dataset

{"query": "Five women walk along a beach wearing flip-flops.", "pos": ["Some women with flip-flops on, are walking along the beach"], "neg": ["The 4 women are sitting on the beach.", "There was a reform in 1996.", "She's not going to court to clear her record.","The man is talking about hawaii.","A woman is standing outside.","The battle was over. ","A group of people plays volleyball."]}

A sample line from toy_finetune_data.jsonl. Each line of the training file is a JSON object with three fields:

  1. query: the query text
  2. pos: a list of positive passages
  3. neg: a list of negative passages

Concretely, you can build the training set by retrieving candidates and mining hard negatives for each query, then writing the results in this format (a minimal mining sketch follows the conversion script below).

import json

output_file = "./data/train.jsonl"
train_path = "./data.jsonl"

with open(train_path, "r", encoding="utf-8") as f_in, open(output_file, "w", encoding="utf-8") as f_out:
    for line in f_in:
        line = line.strip()
        if line:
            data = json.loads(line)
            # Build the query text from the source record's fields.
            query = data['category_name'] + ":" + data['category_description']
            # Placeholders: fill these with real positives/negatives per query.
            pos = ["positive sample 1", "positive sample 2"]
            neg = ["negative sample 1", "negative sample 2"]

            jsonl_data = {
                "query": query,
                "pos": pos,
                "neg": neg
            }

            # One JSON object per line; keep non-ASCII text readable.
            json.dump(jsonl_data, f_out, ensure_ascii=False)
            f_out.write("\n")
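
For the hard-negative mining step, here is a minimal sketch: embed the corpus with the base model, rank passages per query, and keep high-scoring passages that are not labeled positive as negatives. The corpus, query list, positives, and the [1:50] ranking window below are illustrative assumptions, not from the original pipeline (the FlagEmbedding repository also ships a dedicated hard-negative mining script you can use instead):

from FlagEmbedding import FlagModel
import numpy as np

# Hypothetical data; replace with your own corpus and labeled positives.
corpus = ["passage 1", "passage 2", "passage 3"]
queries = ["query 1"]
positives = [{"passage 1"}]  # set of known positives per query

model = FlagModel('./models/bge-large-zh-v1___5', use_fp16=True)
p_emb = model.encode(corpus)           # (num_passages, dim), normalized
q_emb = model.encode_queries(queries)  # (num_queries, dim)

scores = q_emb @ p_emb.T               # inner product = cosine similarity
for qi, query in enumerate(queries):
    ranked = np.argsort(-scores[qi])
    # Take highly ranked but non-positive passages as hard negatives;
    # skipping the very top ranks helps avoid unlabeled positives.
    hard_negs = [corpus[i] for i in ranked[1:50] if corpus[i] not in positives[qi]]
    print(query, hard_negs[:7])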

2. Training

Install FlagEmbedding and faiss with pip:

pip install -U FlagEmbedding
pip install faiss-gpu

Multi-GPU training:

!torchrun --nproc_per_node 4 \
-m FlagEmbedding.baai_general_embedding.finetune.run \
--output_dir output \
--model_name_or_path ./models/bge-large-zh-v1___5 \
--train_data ./data/train.jsonl \
--learning_rate 1e-5 \
--fp16 \
--report_to tensorboard \
--num_train_epochs 100 \
--per_device_train_batch_size 256 \
--dataloader_drop_last True \
--normlized True \
--temperature 0.02 \
--query_max_len 128 \
--passage_max_len 512 \
--train_group_size 2 \
--negatives_cross_device \
--logging_steps 50 \
--query_instruction_for_retrieval "后续文本跟问题相关吗?"

Note that --normlized (not --normalized) is the flag's actual spelling in this version of FlagEmbedding's fine-tuning arguments.
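
For context on --temperature and --train_group_size: each query is trained against one positive and train_group_size - 1 negatives (plus in-batch negatives, shared across GPUs when --negatives_cross_device is set), using a contrastive InfoNCE-style loss whose logits are divided by the temperature. Below is a minimal PyTorch sketch of that objective, as an illustration rather than FlagEmbedding's actual code:

import torch
import torch.nn.functional as F

def contrastive_loss(q_emb, p_emb, temperature=0.02):
    # q_emb: (B, d) normalized query embeddings.
    # p_emb: (B * group_size, d) normalized passage embeddings, where
    # row i * group_size is query i's positive and all other rows act as negatives.
    scores = q_emb @ p_emb.T / temperature
    group_size = p_emb.size(0) // q_emb.size(0)
    target = torch.arange(q_emb.size(0), device=q_emb.device) * group_size
    return F.cross_entropy(scores, target)

# Toy usage: batch of 4 queries, group_size 2 (one positive + one negative each).
q = F.normalize(torch.randn(4, 8), dim=-1)
p = F.normalize(torch.randn(8, 8), dim=-1)
print(contrastive_loss(q, p).item())

A low temperature such as 0.02 sharpens the softmax, so hard negatives that score close to the positive are penalized strongly.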

3. Inference

from FlagEmbedding import FlagModel

# Load the fine-tuned model from the training output directory.
model = FlagModel('./output',
                  query_instruction_for_retrieval="后续文本跟问题相关吗?",
                  use_fp16=True)

queries = ["query1", "query2"]
passages = ["text1", "text2"]
# encode_queries prepends the retrieval instruction to each query;
# encode embeds passages as-is.
q_embeddings = model.encode_queries(queries)
p_embeddings = model.encode(passages)
# Embeddings are normalized, so the inner product gives cosine similarity.
scores = q_embeddings @ p_embeddings.T
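
Since faiss-gpu was installed earlier, here is a sketch of scaling this up with a faiss index, continuing from the embeddings computed above (the flat inner-product index and top_k value are illustrative choices):

import faiss
import numpy as np

# Embeddings are normalized, so inner product equals cosine similarity.
dim = p_embeddings.shape[1]
index = faiss.IndexFlatIP(dim)
index.add(p_embeddings.astype(np.float32))

top_k = 2
sims, ids = index.search(q_embeddings.astype(np.float32), top_k)
for qi, query in enumerate(queries):
    for sim, pid in zip(sims[qi], ids[qi]):
        print(query, "->", passages[pid], float(sim))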

For more usage examples, see baai_general_embedding#usage. For online serving, you can convert the model to ONNX or …
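
If you do want ONNX, one hedged option is the optimum package (not used in the original post); FlagEmbedding's output directory is a standard Hugging Face checkpoint, so it can be exported directly:

# pip install optimum[onnxruntime]
from optimum.onnxruntime import ORTModelForFeatureExtraction
from transformers import AutoTokenizer

# Export the fine-tuned checkpoint to ONNX and save it with its tokenizer.
ort_model = ORTModelForFeatureExtraction.from_pretrained("./output", export=True)
tokenizer = AutoTokenizer.from_pretrained("./output")
ort_model.save_pretrained("./output-onnx")
tokenizer.save_pretrained("./output-onnx")

At inference time you then need to reproduce FlagModel's pooling yourself (BGE models use the [CLS] token embedding, L2-normalized).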
