Gradio 是一款相当直观和强大的库,它使得在 Python 中构建交互式的应用变得非常简单。这个库的主要优点是,你不需要有任何前端开发的知识就可以创建出让用户易于理解和使用的界面。
本应用使用 Gradio,让用户可以方便快捷地进行多类型文档的问答。用户只需上传他们的文档,然后输入他们想要查询的问题,应用就会解析文档内容并返回所查问题的答案。
另一方面,使用 langchain_community.document_loaders
可以解析多种类型的文档。无论文档是 PDF、csv、md 等,都可以很好地处理不同格式的文档并提取出其中的文本内容。下面是各个部分的代码。
不足的地方:
- 对于 csv, 问答效果不太好,在问答时要在 prompt 中添加表头信息然后用 few-shot 方法来提问
- 对于 xlsx 等文档不支持,可以修改下 loader 部分
1. 对不同类型的文档进行解析并存储到 faiss
# Module setup: environment configuration and imports for the document-QA app.
import os
# NOTE(review): hardcoding an API key in source is unsafe — prefer reading it
# from the environment or a secrets manager ("sk-xx" is a placeholder to replace).
os.environ["OPENAI_API_KEY"] = "sk-xx"
import gradio as gr
from langchain.llms import OpenAI
from langchain.chains import RetrievalQA
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
# One loader per supported document type (.pdf, .txt, .csv, .md).
from langchain_community.document_loaders import (PyPDFLoader,
                                                  TextLoader,
                                                  UnstructuredCSVLoader,
                                                  UnstructuredMarkdownLoader
                                                  )
# Shared embedding model for both indexing and querying: BGE large (Chinese, v1.5).
model_name = "BAAI/bge-large-zh-v1.5"
model_kwargs = {"device": "cuda"}  # switch to "cpu" when no GPU is available
# Normalized embeddings let FAISS inner-product search behave like cosine similarity.
encode_kwargs = {"normalize_embeddings": True, "batch_size": 32}

embeddings = HuggingFaceBgeEmbeddings(
    model_name=model_name,
    model_kwargs=model_kwargs,
    encode_kwargs=encode_kwargs,
)
def get_faiss_vectordb(file: str, embedding: HuggingFaceBgeEmbeddings):
    """Parse a document, split it into chunks, embed them, and index in FAISS.

    Args:
        file: Path to the document (.pdf, .csv, .txt or .md).
        embedding: Embedding model used to vectorize the chunks.

    Returns:
        The FAISS vector store built from the document (also persisted to disk).

    Raises:
        ValueError: If the file extension is not supported.
    """
    filename, file_extension = os.path.splitext(file)
    # BUG FIX: the original path was a constant string, so every uploaded file
    # overwrote the same index despite the "unique path" intent. Derive the
    # index path from the input file's base name instead.
    faiss_index_path = f"faiss_index_{os.path.basename(filename)}"

    # Dispatch table: pick the loader class by file extension.
    loader_by_ext = {
        ".pdf": PyPDFLoader,
        ".csv": UnstructuredCSVLoader,
        ".txt": TextLoader,
        ".md": UnstructuredMarkdownLoader,
    }
    loader_cls = loader_by_ext.get(file_extension)
    if loader_cls is None:
        # BUG FIX: the original `raise "..."` raises a string literal, which is
        # itself a TypeError in Python 3; raise a proper exception instead.
        raise ValueError(f"This document type is not supported: {file_extension}")
    loader = loader_cls(file_path=file)

    # Load the document using the selected loader.
    documents = loader.load()
    # Split into overlapping chunks. Separators are tried in order, so the
    # coarser "\n\n" must precede "\n" (the original order made "\n\n"
    # unreachable); "" is the last-resort character-level split.
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=100,
        separators=["\n\n", "\n", " ", ""],
    )
    doc_chunked = text_splitter.split_documents(documents=documents)

    # Build the index and persist it for later reuse.
    vectordb = FAISS.from_documents(doc_chunked, embedding)
    vectordb.save_local(faiss_index_path)
    return vectordb
2. 检索问题相关上下文并进行回答
def run_llm(file, query: str) -> tuple:
    """Answer `query` using the uploaded document as the knowledge source.

    Args:
        file: Uploaded file object from gradio; `file.name` is its path on disk.
        query: The user's question.

    Returns:
        A `(context, answer)` pair: the retrieved supporting context as display
        text, and the LLM's answer.
        (BUG FIX: the original annotation claimed `-> str` but a tuple is returned.)
    """
    # Build (and persist) the vector store for this document.
    vectordb = get_faiss_vectordb(file.name, embeddings)
    # Top-3 most similar chunks, shown to the user as supporting context.
    # BUG FIX: the original returned the raw list of Document objects, which a
    # gr.Text output renders as its repr; join the chunk texts instead.
    docs = vectordb.similarity_search(query, k=3)
    context = "\n\n".join(doc.page_content for doc in docs)

    openai_llm = OpenAI(temperature=0, verbose=True)
    # "stuff" chain: all retrieved chunks are stuffed into one prompt.
    retrieval_qa = RetrievalQA.from_chain_type(
        llm=openai_llm,
        chain_type="stuff",
        retriever=vectordb.as_retriever(search_kwargs={"k": 3}),
    )
    answer = retrieval_qa.run(query)
    return context, answer
3. gradio 构建 web 页面
# Web UI: file upload on the left, question / context / answer on the right.
with gr.Blocks() as demo:
    gr.HTML("""<h1 align="center">Docs QA</h1>""")
    with gr.Row():
        with gr.Column(scale=4):
            with gr.Column(scale=12):
                file = gr.File(label="上传文件")
            with gr.Column(min_width=32, scale=1):
                # NOTE(review): this button has no event handler bound — the
                # document is actually processed when "提交" is clicked. Either
                # wire it up or remove it from the layout.
                submitBtn = gr.Button("上传")
        with gr.Column(scale=1):
            query = gr.Text(label='要查询的问题')
            context = gr.Text(label="相关上下文", autoscroll=True)
            response = gr.Text(label="答案")
            btn = gr.Button("提交")
    btn.click(run_llm, inputs=[file, query], outputs=[context, response])

# BUG FIX: the app was built but never started — without launch() the script
# exits immediately and no server is run.
if __name__ == "__main__":
    demo.launch()
正文完