问题
参考 openai-cookbook/getting-started-with-redis-and-openai.ipynb, 使用本地的模型获取 embedding,本文简单使用:sentence-transformers/all-MiniLM-L6-v2。对文本进行 embedding 的方法很多,如:
- 简单可以用 Sentence-BERT
- 对 bert 输出进行 simcse 训练
- CoSENT(Cosine Sentence):CoSENT 模型提出了一种排序的损失函数,使训练过程更贴近预测,模型收敛速度和效果比 Sentence-BERT 更好
- BGE(BAAI general embedding):BGE 模型按照 retromae 方法进行预训练,参考论文,再使用对比学习 finetune 微调训练模型
- M3E: 使用 in-batch 负采样的对比学习的方式在句对数据集进行训练,为了保证 in-batch 负采样的效果,使用 A100 80G 来最大化 batch-size,并在共计 2200W+ 的句对数据集上训练了 1 epoch
流程
必备组件和数据:
- redis 使用 docker 拉取构建
- data 使用 Quora 中与 kaggle 问题相关的数据集 Kaggle related questions on Qoura – Questions
使用 redis 进行 embedding search 步骤:
- 创建索引
- 写入文档 embedding
- 获取 query 的 embedding
- 用 query 的 embedding 进行 knn 的 search
比如使用 user_query = r"How do I win a Kaggle competition?"
搜索的结果如下:
- How do I win a Kaggle competition? (Score: 1.0)
- How do I start a Kaggle competition? (Score: 0.916)
- How do you compete in Kaggle competitions? (Score: 0.903)
- What is the best way to approach a Kaggle competition? (Score: 0.9)
- How do I win the Kaggle modeling competition? (Score: 0.866)
- How can I get the full solution to a Kaggle competition? (Score: 0.852)
- How do Kaggle competitions work? (Score: 0.851)
- How do I prepare for Kaggle competitions? (Score: 0.848)
- How can I learn to solve Kaggle competitions? (Score: 0.846)
- What is the Kaggle Competition? (Score: 0.845)
代码
from ast import literal_eval
from typing import List

import numpy as np
import pandas as pd
import redis
import torch
from redis.commands.search.field import NumericField, TextField, VectorField
from redis.commands.search.indexDefinition import IndexDefinition, IndexType
from redis.commands.search.query import Query
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
#redis setup | |
HOST="172.xx.xx.xx" | |
PORT=12345 | |
DB=0 | |
PASSWORD="" | |
redis_client = redis.Redis(host=HOST, | |
port=PORT, | |
db=DB, | |
password=PASSWORD | |
) | |
if redis_client.ping() == True: | |
print("Connection successful!") | |
def index_documents(model: SentenceTransformer, client: redis.Redis, prefix: str, documents: pd.DataFrame):
    """Embed every question in *documents* and store each row as a Redis hash.

    Keys are ``{prefix}:{row_index}``.  The sentence embedding of the
    ``Questions`` column is serialized to a raw float32 byte string and stored
    under ``question_vector`` so RediSearch can index it as a vector field.
    """
    rows = documents.to_dict("records")
    for idx, row in tqdm(enumerate(rows), total=len(rows)):
        # RediSearch expects the vector as packed little-endian float32 bytes.
        vector_bytes = get_embeddings(row['Questions'], model).astype(dtype=np.float32).tobytes()
        client.hset(
            f"{prefix}:{idx}",
            mapping={
                "title": row['Questions'],
                "answer": row['Link'],
                "followers": row['Followers'],
                "answered": row['Answered'],
                "question_vector": vector_bytes,
            },
        )
def get_embeddings(text, model, batch_size: int = 32, device=None):
    """Encode *text* into L2-normalized embedding vector(s).

    Args:
        text: a sentence or list of sentences to embed.
        model: a SentenceTransformer (anything exposing a compatible ``encode``).
        batch_size: encoding batch size.  The original hard-coded 1, which is
            needlessly slow when *text* is a list.
        device: target device; ``None`` lets the model encode on whatever
            device it currently lives on.  (The original hard-coded "cpu",
            which silently defeated the ``model.to(cuda)`` done by callers.)

    Returns:
        A numpy array of unit-norm embeddings.
    """
    return model.encode(
        text,
        batch_size=batch_size,
        convert_to_numpy=True,
        device=device,
        normalize_embeddings=True,
    )
def search_redis(
    model: SentenceTransformer,
    redis_client: redis.Redis,
    user_query: str,
    index_name: str = "embeddings-index",
    vector_field: str = "question_vector",
    return_fields: list = None,
    hybrid_fields="*",
    k: int = 20,
    print_results: bool = True,
) -> List[dict]:
    """Run a KNN vector search over *vector_field* for *user_query*.

    Args:
        model: the SentenceTransformer used to embed the query (must match
            the model used at index time).
        redis_client: connected Redis client with the RediSearch module.
        user_query: free-text query to embed and search with.
        index_name: name of the RediSearch index to query.
        vector_field: name of the indexed vector field.
        return_fields: hash fields to return; defaults to
            ``["title", "answer", "vector_score"]``.  (Was a mutable default
            argument — replaced with a ``None`` sentinel.)
        hybrid_fields: pre-filter expression; "*" means no filtering.
        k: number of nearest neighbours to retrieve.
        print_results: when True, print each hit with its similarity score.

    Returns:
        The list of matching documents (``results.docs``).
    """
    if return_fields is None:
        return_fields = ["title", "answer", "vector_score"]

    # Embed the query with the same model/normalization used at index time.
    embedded_query = get_embeddings(user_query, model)

    # vector_score holds the cosine *distance*: 0.0 means identical vectors.
    base_query = f'{hybrid_fields}=>[KNN {k} @{vector_field} $query_vector AS vector_score]'
    query = (
        Query(base_query)
        .return_fields(*return_fields)
        .sort_by("vector_score")
        .paging(0, k)
        .dialect(2)  # DIALECT 2 is required for parameterized vector queries
    )
    params_dict = {"query_vector": embedded_query.astype(dtype=np.float32).tobytes()}

    # Perform the vector search.
    results = redis_client.ft(index_name).search(query, params_dict)
    if print_results:
        for i, article in enumerate(results.docs):
            # Convert cosine distance back to similarity for display.
            score = 1 - float(article.vector_score)
            print(f"{i}. {article.title} (Score: {round(score, 3)})")
    return results.docs
def create_hybrid_field(field_name: str, value: str) -> str:
    """Return a RediSearch filter clause matching *value* in text field *field_name*."""
    return f'@{field_name}:"{value}"'


# NOTE(review): in the extracted source this guard was fused onto the return
# line above, which is a syntax error — restored onto its own line here.
if __name__ == '__main__':
    pass  # script entry point; the main body continues below
# EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2" | |
EMBEDDING_MODEL_NAME = r"D:\models\all-MiniLM-L6-v2" | |
model = SentenceTransformer(EMBEDDING_MODEL_NAME) | |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | |
model.to(device) | |
#读取数据 | |
data = pd.read_csv(r"./data/Kaggle related questions on Qoura - Questions.csv") | |
print(data.columns.tolist()) | |
print(f"data frame shape: {data.shape[0]}, {data.shape[1]}") | |
#1. 建立索引 | |
VECTOR_DIM = 384 # title text embedding dim | |
VECTOR_NUMBER = len(data) | |
INDEX_NAME = "embeddings-index" | |
PREFIX = "doc" | |
DISTANCE_METRIC = "COSINE" | |
# 定义 rediSearch 字段 选择 | |
question = TextField(name='question') | |
answer = TextField(name='answer') | |
followers = NumericField("followers") #跟帖数 | |
answered = NumericField("answered")# 是否回答 | |
# 需要 3 个字段 name algorithm attributes | |
title_emb = VectorField("question_vector", "HNSW", | |
{ | |
"TYPE": "FLOAT32", | |
"DIM": VECTOR_DIM, | |
"DISTANCE_METRIC": DISTANCE_METRIC, | |
"INITIAL_CAP": VECTOR_NUMBER | |
} | |
) | |
fields = [question, answer, followers, answered, title_emb] | |
HNSW_INDEX_NAME = INDEX_NAME + "_HNSW" | |
try: | |
redis_client.ft(HNSW_INDEX_NAME).info() | |
print("Index already exists") | |
except: | |
# Create RediSearch Index | |
redis_client.ft(HNSW_INDEX_NAME).create_index( | |
fields=fields, | |
definition=IndexDefinition(prefix=[PREFIX], index_type=IndexType.HASH) | |
) | |
#2. 写入 doc | |
index_documents(model, redis_client, PREFIX, data) | |
print(f"Loaded {redis_client.info()['db0']['keys']} documents in Redis search index with name: {HNSW_INDEX_NAME}") | |
#3.4 对 query 进行 embedding 和查询 | |
user_query = r"How do I win a Kaggle competition?" | |
results = search_redis(model, redis_client, user_query, index_name=HNSW_INDEX_NAME, k=10) |
Reference
[2] openai-cookbook/getting-started-with-redis-and-openai.ipynb
[3] 2023-Kaggle-LECR-Top3-TrainCode
[4] text2vec
正文完
发表至: NLP
2023-12-21