1. Build the dataset
{"query": "Five women walk along a beach wearing flip-flops.", "pos": ["Some women with flip-flops on, are walking along the beach"], "neg": ["The 4 women are sitting on the beach.", "There was a reform in 1996.", "She's not going to court to clear her record.","The man is talking about hawaii.","A woman is standing outside.","The battle was over. ","A group of people plays volleyball."]}
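Each line of the training file is a JSON object with three fields:
query: the query text
pos: a list of positive samples
neg: a list of negative samples
In practice, you can build the training set by running retrieval (recall) and mining hard negatives from the results, then writing each example in this format.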
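import json
import os

output_file = "./data/train.jsonl"
train_path = "./data.jsonl"

# Make sure the output directory exists before writing.
os.makedirs(os.path.dirname(output_file), exist_ok=True)

with open(train_path, "r", encoding="utf-8") as f_in, open(output_file, "w", encoding="utf-8") as f_out:
    for line in f_in:
        line = line.strip()
        if line:
            data = json.loads(line)
            # The query is built from the category name and its description.
            query = data['category_name'] + ":" + data['category_description']
            # Placeholders: replace with the real positives and mined negatives.
            pos = ["positive sample 1", "positive sample 2"]
            neg = ["negative sample 1", "negative sample 2"]
            jsonl_data = {
                "query": query,
                "pos": pos,
                "neg": neg
            }
            # One JSON object per line; keep non-ASCII text readable.
            json.dump(jsonl_data, f_out, ensure_ascii=False)
            f_out.write("\n")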
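The hard-negative mining mentioned above can be sketched roughly as follows. This is a minimal illustration with NumPy, not the method used in this post: the function name mine_hard_negatives, its parameters, and the rank_range default are all hypothetical, and the embeddings here are random stand-ins for vectors you would get from an embedding model. The idea is to rank the corpus by similarity to the query and take the highest-ranked passages that are not known positives.

```python
import numpy as np


def mine_hard_negatives(query_emb, corpus_emb, positive_ids, num_neg=7, rank_range=(2, 200)):
    """Pick hard negatives for one query from a band of its similarity ranking.

    query_emb:    (d,) L2-normalized query vector
    corpus_emb:   (n, d) L2-normalized passage vectors
    positive_ids: set of corpus indices known to be positives for this query
    rank_range:   (start, end) slice of the ranking to draw negatives from
    """
    scores = corpus_emb @ query_emb      # cosine similarity for normalized vectors
    ranking = np.argsort(-scores)        # most similar passage first
    start, end = rank_range
    candidates = [int(i) for i in ranking[start:end] if int(i) not in positive_ids]
    return candidates[:num_neg]


# Toy data: 50 random passages, with passage 3 acting as the known positive.
rng = np.random.default_rng(0)
corpus = rng.normal(size=(50, 8))
corpus /= np.linalg.norm(corpus, axis=1, keepdims=True)
query = corpus[3] + 0.1 * rng.normal(size=8)
query /= np.linalg.norm(query)

neg_ids = mine_hard_negatives(query, corpus, positive_ids={3}, num_neg=5, rank_range=(0, 20))
print(neg_ids)
```

Starting the band a few ranks below the top (a non-zero start in rank_range) is one way to lower the chance of picking unlabeled positives as negatives. The passages at neg_ids would then go into the "neg" list of the training example.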
2. Training
Install FlagEmbedding with pip:
pip install -U FlagEmbedding
pip install faiss-gpu
Multi-GPU training:
!torchrun --nproc_per_node 4 \
    -m FlagEmbedding.baai_general_embedding.finetune.run \
    --output_dir output \
    --model_name_or_path ./models/bge-large-zh-v1___5 \
    --train_data ./data/train.jsonl \
    --learning_rate 1e-5 \
    --fp16 \
    --report_to tensorboard \
    --num_train_epochs 100 \
    --per_device_train_batch_size 256 \
    --dataloader_drop_last True \
    --normlized True \
    --temperature 0.02 \
    --query_max_len 128 \
    --passage_max_len 512 \
    --train_group_size 2 \
    --negatives_cross_device \
    --logging_steps 50 \
    --query_instruction_for_retrieval "后续文本跟问题相关吗?"
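To give a feel for what --temperature and --train_group_size control, here is a toy NumPy sketch of an InfoNCE-style contrastive loss over in-batch passages. It assumes the fine-tuning objective has this general shape and that a group size of 2 means one positive plus one negative per query; the function name contrastive_loss and the passage layout are illustrative assumptions, not code taken from FlagEmbedding.

```python
import numpy as np


def contrastive_loss(q, p, group_size, temperature):
    """InfoNCE-style loss where every passage in the batch is a candidate.

    q: (B, d) normalized query embeddings
    p: (B * group_size, d) normalized passage embeddings, laid out as
       [pos_0, neg_0, ..., pos_1, neg_1, ...]
    """
    logits = (q @ p.T) / temperature                    # (B, B * group_size)
    targets = np.arange(q.shape[0]) * group_size        # column of each query's positive
    logits = logits - logits.max(axis=1, keepdims=True) # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return float(-log_probs[np.arange(q.shape[0]), targets].mean())


# Toy batch: 4 queries, each with 1 positive and 1 negative (group_size = 2).
basis = np.eye(8)
q = basis[:4]                          # query i is basis vector i
p = np.empty((8, 8))
p[0::2] = basis[:4]                    # positive i matches query i exactly
p[1::2] = basis[4:]                    # negatives are orthogonal to every query

loss_sharp = contrastive_loss(q, p, group_size=2, temperature=0.02)
loss_soft = contrastive_loss(q, p, group_size=2, temperature=1.0)
print(loss_sharp, loss_soft)
```

In this toy setup the positives are already perfectly separated, so the low temperature drives the loss to essentially zero while temperature 1.0 leaves it well above zero. A small temperature amplifies differences in cosine similarity, which is why it is usually paired with normalized embeddings.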
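3. Inference
from FlagEmbedding import FlagModel

# Load the fine-tuned model from the training output directory.
# Use the same query instruction that was used during training.
model = FlagModel('./output',
                  query_instruction_for_retrieval="后续文本跟问题相关吗?",
                  use_fp16=True)

queries = ["query1", "query2"]
passages = ["text1", "text2"]

# encode_queries prepends the instruction to each query; encode does not.
q_embeddings = model.encode_queries(queries)
p_embeddings = model.encode(passages)

# Inner-product similarity: one row per query, one column per passage.
scores = q_embeddings @ p_embeddings.T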
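Once you have the scores matrix, picking the best passages per query is a sort along each row. A minimal sketch, using a hand-written matrix as a stand-in for q_embeddings @ p_embeddings.T; the helper name top_k is hypothetical:

```python
import numpy as np


def top_k(scores, k):
    """Return, for each query row, the indices of the k highest-scoring passages."""
    order = np.argsort(-scores, axis=1)   # sort each row in descending order
    return order[:, :k]


# Stand-in for the scores matrix: 2 queries x 3 passages.
scores = np.array([[0.2, 0.9, 0.5],
                   [0.7, 0.1, 0.4]])
ranked = top_k(scores, k=2)
print(ranked)
```

Each row of ranked lists passage indices from most to least similar for that query, so passages[ranked[i][0]] would be the best match for queries[i].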
For more usage options, see baai_general_embedding#usage. For online serving, convert the model to ONNX or .