1. Build the dataset
{"query": "Five women walk along a beach wearing flip-flops.", "pos": ["Some women with flip-flops on, are walking along the beach"], "neg": ["The 4 women are sitting on the beach.", "There was a reform in 1996.", "She's not going to court to clear her record.","The man is talking about hawaii.","A woman is standing outside.","The battle was over. ","A group of people plays volleyball."]}
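Each line of the training file is one JSON object like the one above:
query: the query text
pos: a list of positive samples
neg: a list of negative samples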
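Before training, it can help to confirm that every line matches this shape. The following is a minimal sketch, not part of FlagEmbedding: `validate_record` is a hypothetical helper, and the rule that both `pos` and `neg` must be non-empty is an assumption made here for illustration.

```python
import json


def validate_record(line: str) -> dict:
    """Parse one JSONL line and check that it has the query/pos/neg shape."""
    record = json.loads(line)
    query = record.get("query")
    if not isinstance(query, str) or not query.strip():
        raise ValueError("'query' must be a non-empty string")
    for key in ("pos", "neg"):
        values = record.get(key)
        # Assumption: each record carries at least one positive and one negative.
        if not isinstance(values, list) or not values:
            raise ValueError(f"'{key}' must be a non-empty list")
        if not all(isinstance(v, str) for v in values):
            raise ValueError(f"'{key}' must contain only strings")
    return record


sample = '{"query": "q", "pos": ["p1"], "neg": ["n1", "n2"]}'
record = validate_record(sample)
```

Running this check over every line of the file catches malformed records before a long training job starts.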
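In practice, you can build the training set by retrieving candidates (recall) and mining hard negatives from them, then writing the results in this format.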
    import json

    output_file = "./data/train.jsonl"  # training file in query/pos/neg format
    train_path = "./data.jsonl"         # raw input, one JSON object per line

    with open(train_path, "r", encoding="utf-8") as f_in, \
            open(output_file, "w", encoding="utf-8") as f_out:
        for line in f_in:
            line = line.strip()
            if line:
                data = json.loads(line)
                # Build the query from the category name and description.
                query = data['category_name'] + ":" + data['category_description']
                # Placeholders: replace with the real positives and negatives.
                pos = ["positive sample 1", "positive sample 2"]
                neg = ["negative sample 1", "negative sample 2"]
                jsonl_data = {
                    "query": query,
                    "pos": pos,
                    "neg": neg
                }
                # ensure_ascii=False keeps non-ASCII text readable in the output.
                json.dump(jsonl_data, f_out, ensure_ascii=False)
                f_out.write("\n")
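The script above leaves `pos` and `neg` as placeholders. One common way to fill `neg` is to rank candidate passages by similarity to the query and keep the top-scoring ones that are not labeled positive. The sketch below illustrates that idea in plain Python under stated assumptions: `mine_hard_negatives`, its parameters, and the toy vectors are all hypothetical, and the embeddings are assumed to come from whatever retrieval model you already have.

```python
from typing import List, Sequence


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    """Inner product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def mine_hard_negatives(
    query_vec: Sequence[float],
    passage_vecs: Sequence[Sequence[float]],
    passages: Sequence[str],
    positives: Sequence[str],
    num_neg: int = 2,
    skip_top: int = 0,
) -> List[str]:
    """Return the highest-scoring passages that are not labeled positive."""
    positive_set = set(positives)
    # Rank passage indices by similarity to the query, best first.
    ranked = sorted(
        range(len(passages)),
        key=lambda i: dot(query_vec, passage_vecs[i]),
        reverse=True,
    )
    # skip_top drops the very top hits, which may be unlabeled positives.
    candidates = [passages[i] for i in ranked[skip_top:] if passages[i] not in positive_set]
    return candidates[:num_neg]


# Toy data (hypothetical): four passages with 2-dimensional embeddings.
passages = ["A", "B", "C", "D"]
passage_vecs = [[0.9, 0.1], [0.8, 0.2], [0.1, 0.9], [0.7, 0.3]]
neg = mine_hard_negatives([1.0, 0.0], passage_vecs, passages, positives=["A"])
```

Negatives mined this way are harder than random ones because they already look relevant to the retriever, which is the point of the exercise.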
2. Training
Install FlagEmbedding with pip:
    pip install -U FlagEmbedding
    pip install faiss-gpu
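Multi-GPU training (the leading ! runs the command from a notebook cell; --normlized is the flag's actual spelling in FlagEmbedding, not a typo to correct):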
    !torchrun --nproc_per_node 4 \
        -m FlagEmbedding.baai_general_embedding.finetune.run \
        --output_dir output \
        --model_name_or_path ./models/bge-large-zh-v1___5 \
        --train_data ./data/train.jsonl \
        --learning_rate 1e-5 \
        --fp16 \
        --report_to tensorboard \
        --num_train_epochs 100 \
        --per_device_train_batch_size 256 \
        --dataloader_drop_last True \
        --normlized True \
        --temperature 0.02 \
        --query_max_len 128 \
        --passage_max_len 512 \
        --train_group_size 2 \
        --negatives_cross_device \
        --logging_steps 50 \
        --query_instruction_for_retrieval "后续文本跟问题相关吗?"
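To see why a small value such as `--temperature 0.02` matters, note that normalized embeddings give similarities in [-1, 1], and dividing by the temperature stretches those scores before the softmax. The sketch below is a plain-Python illustration of a temperature-scaled contrastive (InfoNCE-style) loss, not FlagEmbedding's own implementation; the function name and the example scores are hypothetical.

```python
import math
from typing import Sequence


def contrastive_loss(scores: Sequence[float], positive_index: int, temperature: float) -> float:
    """Cross-entropy over similarity scores divided by the temperature."""
    logits = [s / temperature for s in scores]
    max_logit = max(logits)  # subtract the max for numerical stability
    log_sum_exp = max_logit + math.log(sum(math.exp(l - max_logit) for l in logits))
    return log_sum_exp - logits[positive_index]


# Hypothetical cosine similarities: the positive first, then one negative.
scores = [0.80, 0.75]
loss_sharp = contrastive_loss(scores, 0, temperature=0.02)
loss_soft = contrastive_loss(scores, 0, temperature=1.0)
```

With these scores the positive leads by only 0.05, yet the low-temperature loss comes out much smaller than the temperature-1.0 loss, because the scaled gap (2.5 in logit space) is large enough for the softmax to separate the two candidates.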
3. Inference
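    from FlagEmbedding import FlagModel

    # Load the fine-tuned model from the training output directory.
    # Use the same query instruction that was used during training.
    model = FlagModel('./output',
                      query_instruction_for_retrieval="后续文本跟问题相关吗?",
                      use_fp16=True)

    queries = ["query1", "query2"]
    passages = ["text1", "text2"]

    # encode_queries prepends the instruction to each query; encode does not.
    q_embeddings = model.encode_queries(queries)
    p_embeddings = model.encode(passages)

    # Similarity matrix: rows are queries, columns are passages.
    scores = q_embeddings @ p_embeddings.T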
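Each row of `scores` holds one query's similarity to every passage, so retrieval reduces to sorting that row. A minimal sketch of the ranking step follows; `rank_passages` and the example numbers are hypothetical, and plain lists stand in for the array so the snippet runs on its own (a row of the real `scores` array should index the same way).

```python
from typing import List, Sequence, Tuple


def rank_passages(
    score_row: Sequence[float],
    passages: Sequence[str],
    top_k: int = 3,
) -> List[Tuple[str, float]]:
    """Return the top_k (passage, score) pairs for one query, best first."""
    order = sorted(range(len(passages)), key=lambda i: score_row[i], reverse=True)
    return [(passages[i], float(score_row[i])) for i in order[:top_k]]


# Hypothetical scores for two queries against three passages.
example_scores = [[0.2, 0.9, 0.5], [0.7, 0.1, 0.3]]
example_passages = ["text1", "text2", "text3"]
best = [rank_passages(row, example_passages, top_k=2) for row in example_scores]
```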
For more usage options, see baai_general_embedding#usage. For online serving, convert the model to ONNX or [...].