1.flask 部署
这里用训练完的 BERT-biLSTM+CRF 模型对 text 解码部分如下,功能函数是跟 flask 和 fastapi 没关系的。比较容易出错的地方:
- 如果不用 DataLoader, 就要对 model 输入进行
.unsqueeze(0)
来构建 batch_size 这个维度 - 如果用 DataLoader,就跟训练时的验证部分差不多,就不多少了。
下面就是公共的功能部分代码:
import json | |
import torch | |
import torch.nn as nn | |
from torchcrf import CRF | |
import uvicorn | |
from fastapi import FastAPI, Request | |
from flask import Flask, jsonify, request | |
from transformers import AutoTokenizer, AutoModel, AutoConfig | |
class BertNer(nn.Module): | |
def __init__(self, model_name, max_seq_len, num_labels=131): | |
super(BertNer, self).__init__() | |
self.bert = AutoModel.from_pretrained(model_name) | |
self.bert_config = AutoConfig.from_pretrained(model_name) | |
self.lstm_hidden = 128 | |
self.max_seq_len = max_seq_len | |
self.bilstm = nn.LSTM(self.bert_config.hidden_size, self.lstm_hidden, | |
num_layers=1, bidirectional=True, batch_first=True, | |
) | |
self.dropout = nn.Dropout(0.1) | |
self.linear = nn.Linear(self.lstm_hidden*2, num_labels) | |
self.crf = CRF(num_labels, batch_first=True) | |
def decode(self, input_ids, attention_mask): | |
bert_output = self.bert(input_ids, attention_mask) | |
seq_out = bert_output[0] #bs, seq_len, hidden_dim | |
batch_size = seq_out.size(0) | |
seq_out, _ = self.bilstm(seq_out) #bs, seq_len, lstm_hidden*2 | |
seq_out = seq_out.view(batch_size, self.max_seq_len, -1) | |
seq_out = self.linear(self.dropout(seq_out)) | |
logits = self.crf.decode(seq_out, mask=attention_mask.bool()) | |
return logits | |
def test_fn(model, tokenizer, text, id2label, max_len, device): | |
model.eval() | |
if len(text) > max_len - 2: | |
text = text[:max_len - 2] | |
tmp_input_ids = tokenizer.convert_tokens_to_ids(["[CLS]"] + text + ["SEP"]) | |
attention_mask = [1] * len(tmp_input_ids) # 1 不 mask | |
input_ids = tmp_input_ids + [0] * (max_len - len(tmp_input_ids)) | |
attention_mask = attention_mask + [0] * (max_len - len(tmp_input_ids)) | |
input_ids = torch.tensor(input_ids, dtype=torch.long) | |
attention_mask = torch.tensor(attention_mask, dtype=torch.long) | |
input_ids = input_ids.unsqueeze(0) | |
attention_mask = attention_mask.unsqueeze(0) | |
input_ids = input_ids.to(device) | |
attention_mask = attention_mask.to(device) | |
logits = model.decode(input_ids, | |
attention_mask) | |
attention_mask = attention_mask.squeeze().detach().cpu().numpy() | |
length = sum(attention_mask) | |
logit = logits[0][1:length] | |
logit = [id2label[i] for i in logit] | |
return logit | |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | |
# device = torch.device('cpu') | |
model_name = r"hfl/chinese-bert-wwm-ext" | |
max_seq_len = 256 | |
tokenizer = AutoTokenizer.from_pretrained(model_name) | |
model = BertNer(model_name, max_seq_len) | |
model.load_state_dict(torch.load('./output/best_ner.bin')) | |
model.to(device) | |
with open('./output/label2id.json', 'r') as f: | |
label2id = json.load(f) | |
id2label = {v: k for k, v in label2id.items()} |
那么实现一个预测的 api, 用 flask 怎么部署呢?
- 实例化 app
- 指定路由和方法
@app.route("/predict", methods=['POST'])
- 对
request
得到数据进行处理 - 功能函数
test_fn
来进行预测得到结果jsonify
后返回。
app = Flask(__name__) | |
@app.route("/predict", methods=['POST']) | |
def predict(): | |
text = request.json["text"] | |
text = [i for i in text] | |
try: | |
out = test_fn(model, tokenizer, text, id2label, max_seq_len, device) | |
return jsonify({"text": ''.join(text),"result":''.join(out) | |
}) | |
except Exception as e: | |
print(e) | |
return jsonify({"result": "Model Failed"}) | |
if __name__ == '__main__': | |
app.run("0.0.0.0", port=8000) |
对于 flask 部署,要采用 post 方法。另外就是你如果没有写 @app.post("/")
, 你直接访问对应地址会 404, 使用地址 +/predict
就好了。
使用 curl 和 request 来发送请求:
curl -X POST http://0.0.0.0:8000/predict -H 'Content-Type: application/json' -d '{"text":" 同梅西关系如何?C 罗一句话点出重点:当我夺冠时他一定很痛苦!"}'
- 这里使用 postman 测试, request 请求
http://0.0.0.0:8000/predict
,然后在raw
中填入{"text": "同梅西关系如何?C 罗一句话点出重点:当我夺冠时他一定很痛苦!"}'
接口,如下图所示。

返回的 json(这里为了展示将结果 join 连接了一下):
{ | |
"result": "OOOOOOOOOOOOOOOOOOOOB- 竞赛行为_夺冠 I - 竞赛行为_夺冠 OOOOOOOOO", | |
"text": "同梅西关系如何?C 罗一句话点出重点:当我夺冠时他一定很痛苦!" | |
} |
2.fastapi 部署
fastapi 部署也类似,只是要使用 async...await
来实现处理异步操作。
在函数定义时,我们使用 async def
来定义一个异步函数。在这种情况下,predict
函数是一个异步函数,它可以处理异步请求。
在 predict
函数中,我们使用 await
来等待异步操作完成。具体来说,await request.json()
表示等待获取请求的 JSON 数据,这是一个异步操作,它会返回一个包含请求数据的字典。
通过使用 async
和 await
,我们可以在异步函数中进行非阻塞的操作,从而充分利用异步处理的优势,提高应用程序的性能和响应能力。
fastapi 可以自动实现文档功能,游览器中输入 http://0.0.0.0:8000/docs#/
点击 try it out -> execute
看到:

app = FastAPI() | |
@app.post("/predict") | |
async def predict(request: Request): | |
data = await request.json() | |
text = data["text"] | |
text = [i for i in text] | |
try: | |
out = test_fn(model, tokenizer, text, id2label, max_seq_len, device) | |
return {"text": ''.join(text),"result":''.join(out)} | |
except Exception as e: | |
print(e) | |
return {"result": "Model Failed"} | |
if __name__ == '__main__': | |
uvicorn.run(app, host="0.0.0.0", port=8000) | |
#uvicorn app:app --host 0.0.0.0 --port 8000 |
在 postman 中请求也一样,返回结果如下:
{ | |
"text": "同梅西关系如何?C 罗一句话点出重点:当我夺冠时他一定很痛苦!", | |
"result": "OOOOOOOOOOOOOOOOOOOOB- 竞赛行为_夺冠 I - 竞赛行为_夺冠 OOOOOOOOO" | |
} |
参考
正文完