Embedding text without OpenAI, plus Redis vector search

Problem

Following openai-cookbook/getting-started-with-redis-and-openai.ipynb, this post obtains embeddings with a local model instead of the OpenAI API; for simplicity it uses sentence-transformers/all-MiniLM-L6-v2. There are many ways to embed text, for example (a minimal sketch of option 1 follows the list):

  1. The simplest option is Sentence-BERT.
  2. Train SimCSE on top of BERT outputs (contrastive learning on the encoder's sentence representations).
  3. CoSENT (Cosine Sentence): CoSENT introduces a ranking-style loss that brings training closer to the prediction objective, so it converges faster and performs better than Sentence-BERT.
  4. BGE (BAAI General Embedding): BGE is pre-trained following the RetroMAE recipe (see the paper) and then fine-tuned with contrastive learning.
  5. M3E: trained with in-batch negative-sampling contrastive learning on sentence-pair datasets; to make the in-batch negatives effective, an A100 80G was used to maximize the batch size, training for 1 epoch on more than 22 million sentence pairs.
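
As a reference point for option 1, here is a minimal sketch using sentence-transformers with the same all-MiniLM-L6-v2 model used later in this post; the two example sentences are only for illustration:

from sentence_transformers import SentenceTransformer, util

# load the model (downloads from the Hugging Face Hub on first use, or point to a local copy)
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

sentences = [
    "How do I win a Kaggle competition?",
    "How do I start a Kaggle competition?",
]
# normalized sentence embeddings, shape (2, 384)
embeddings = model.encode(sentences, convert_to_numpy=True, normalize_embeddings=True)

# cosine similarity between the two sentences
print(util.cos_sim(embeddings[0], embeddings[1]))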

Workflow

Required components and data: a running Redis instance with the RediSearch module, the sentence-transformers/all-MiniLM-L6-v2 model (a local copy works), and the Kaggle related questions on Qoura - Questions.csv dataset.

Steps for embedding search with Redis:


  1. Create the index.
  2. Write the document embeddings.
  3. Get the embedding of the query.
  4. Run a KNN search with the query embedding (see the query sketch after this list).
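
Step 4 uses RediSearch's KNN query syntax, which requires query dialect 2. A minimal sketch of the query shape used in the Code section below (the index and field names follow that section; query_vector is the name of the query parameter holding the embedding bytes):

from redis.commands.search.query import Query

k = 10
# "*" means no pre-filter; it can be replaced by a filter expression for a hybrid query
base_query = f"*=>[KNN {k} @question_vector $query_vector AS vector_score]"
query = (
    Query(base_query)
    .return_fields("title", "answer", "vector_score")
    .sort_by("vector_score")  # COSINE distance: smaller means more similar
    .paging(0, k)
    .dialect(2)               # the KNN syntax requires DIALECT 2
)
# results = redis_client.ft(INDEX_NAME).search(query, {"query_vector": embedding_bytes})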

For example, searching with user_query = r"How do I win a Kaggle competition?" returns the following results (the printed Score is 1 - vector_score, i.e. the cosine similarity, since the index uses the COSINE distance metric on normalized embeddings):

  1. How do I win a Kaggle competition? (Score: 1.0)
  2. How do I start a Kaggle competition? (Score: 0.916)
  3. How do you compete in Kaggle competitions? (Score: 0.903)
  4. What is the best way to approach a Kaggle competition? (Score: 0.9)
  5. How do I win the Kaggle modeling competition? (Score: 0.866)
  6. How can I get the full solution to a Kaggle competition? (Score: 0.852)
  7. How do Kaggle competitions work? (Score: 0.851)
  8. How do I prepare for Kaggle competitions? (Score: 0.848)
  9. How can I learn to solve Kaggle competitions? (Score: 0.846)
  10. What is the Kaggle Competition? (Score: 0.845)

Code

import numpy as np
import pandas as pd
import redis
import torch
from sentence_transformers import SentenceTransformer
from redis.commands.search.indexDefinition import IndexDefinition, IndexType
from redis.commands.search.query import Query
from redis.commands.search.field import TextField, VectorField, NumericField
from tqdm import tqdm
from typing import List

# Redis setup
HOST="172.xx.xx.xx"
PORT=12345
DB=0
PASSWORD=""
redis_client = redis.Redis(host=HOST,
                           port=PORT,
                           db=DB,
                           password=PASSWORD
                           )

if redis_client.ping():
    print("Connection successful!")

def index_documents(model: SentenceTransformer, client: redis.Redis, prefix: str, documents: pd.DataFrame):
    records = documents.to_dict("records")
    for i, doc in tqdm(enumerate(records), total=len(records)):
        key = f"{prefix}:{str(i)}"
        # embed the question text and serialize it to float32 bytes for Redis
        question_embedding = get_embeddings(doc['Questions'], model).astype(dtype=np.float32).tobytes()

        mapping_dict = {"title": doc['Questions'],
            "answer": doc['Link'],
            "followers": doc['Followers'],
            "answered": doc['Answered'],
            "question_vector": question_embedding
        }

        client.hset(key, mapping = mapping_dict)


def get_embeddings(text, model):
    emb = model.encode(
        text,
        batch_size=1,
        convert_to_numpy=True,
        device="cpu",
        normalize_embeddings=True
    )

    return emb


def search_redis(
    model: SentenceTransformer,
    redis_client: redis.Redis,
    user_query: str,
    index_name: str = "embeddings-index",
    vector_field: str = "question_vector",
    return_fields: list = ["title", "answer",  "vector_score"],
    hybrid_fields = "*",
    k: int = 20,
    print_results: bool = True,
) -> List[dict]:

    # Creates embedding vector from user query
    embedded_query = get_embeddings(user_query, model)

    # print(type(embedded_query))
    # print(embedded_query.shape)

    # Prepare the Query
    base_query = f'{hybrid_fields}=>[KNN {k} @{vector_field} $query_vector AS vector_score]'
    query = (Query(base_query)
         .return_fields(*return_fields)
         .sort_by("vector_score")
         .paging(0, k)
         .dialect(2)
    )
    # print(embedded_query.astype(dtype=np.float32))
    params_dict = {"query_vector": embedded_query.astype(dtype=np.float32).tobytes()}

    # perform vector search
    results = redis_client.ft(index_name).search(query, params_dict)
    if print_results:
        for i, article in enumerate(results.docs, start=1):
            # vector_score is the COSINE distance, so 1 - distance gives the cosine similarity
            score = 1 - float(article.vector_score)
            print(f"{i}. {article.title} (Score: {round(score, 3)})")
    return results.docs



def create_hybrid_field(field_name: str, value: str) -> str:
    return f'@{field_name}:"{value}"'


if __name__ == '__main__':
    # EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
    EMBEDDING_MODEL_NAME = r"D:\models\all-MiniLM-L6-v2"

    model = SentenceTransformer(EMBEDDING_MODEL_NAME)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # load the data
    data = pd.read_csv(r"./data/Kaggle related questions on Qoura - Questions.csv")
    print(data.columns.tolist())
    print(f"data frame shape: {data.shape[0]}, {data.shape[1]}")
    # 1. create the index
    VECTOR_DIM = 384  # title text embedding dim
    VECTOR_NUMBER = len(data)
    INDEX_NAME = "embeddings-index"
    PREFIX = "doc"
    DISTANCE_METRIC = "COSINE"

    # define the RediSearch fields (names must match the hash fields written in index_documents)
    title = TextField(name="title")
    answer = TextField(name="answer")
    followers = NumericField("followers")  # follower count
    answered = NumericField("answered")  # whether the question was answered

    # VectorField takes three arguments: name, algorithm, attributes
    title_emb = VectorField("question_vector", "HNSW",
                            {
                                "TYPE": "FLOAT32",
                                "DIM": VECTOR_DIM,
                                "DISTANCE_METRIC": DISTANCE_METRIC,
                                "INITIAL_CAP": VECTOR_NUMBER
                            }
                            )
    fields = [title, answer, followers, answered, title_emb]

    HNSW_INDEX_NAME = INDEX_NAME + "_HNSW"

    try:
        redis_client.ft(HNSW_INDEX_NAME).info()
        print("Index already exists")
    except redis.exceptions.ResponseError:
        # Create RediSearch Index
        redis_client.ft(HNSW_INDEX_NAME).create_index(
            fields=fields,
            definition=IndexDefinition(prefix=[PREFIX], index_type=IndexType.HASH)
        )

    # 2. write the document embeddings
    index_documents(model, redis_client, PREFIX, data)
    print(f"Loaded {redis_client.info()['db0']['keys']} documents in Redis search index with name: {HNSW_INDEX_NAME}")

    # 3. & 4. embed the query and run the KNN search
    user_query = r"How do I win a Kaggle competition?"
    results = search_redis(model, redis_client, user_query, index_name=HNSW_INDEX_NAME, k=10)
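
Note that the create_hybrid_field helper above is defined but never called. As a sketch, it can be passed through search_redis's hybrid_fields parameter to pre-filter documents before the KNN step; the filter on the title field below is just an illustration:

# hybrid search: restrict the KNN candidates to documents whose title contains "Kaggle"
hybrid_results = search_redis(
    model,
    redis_client,
    user_query,
    index_name=HNSW_INDEX_NAME,
    hybrid_fields=create_hybrid_field("title", "Kaggle"),
    k=10,
)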

References

[1] Vector Similarity

[2] openai-cookbook/getting-started-with-redis-and-openai.ipynb

[3] 2023-Kaggle-LECR-Top3-TrainCode

[4] text2vec

[5] FlagEmbedding | BGE | vector embeddings
