1.flask 部署
这里用训练完的 BERT-biLSTM+CRF 模型对 text 解码部分如下,功能函数是跟 flask 和 fastapi 没关系的。比较容易出错的地方:
- 如果不用 DataLoader, 就要对 model 输入进行
.unsqueeze(0)
来构建 batch_size 这个维度 - 如果用 DataLoader,就跟训练时的验证部分差不多,就不多少了。
下面就是公共的功能部分代码:
import json
import torch
import torch.nn as nn
from torchcrf import CRF
import uvicorn
from fastapi import FastAPI, Request
from flask import Flask, jsonify, request
from transformers import AutoTokenizer, AutoModel, AutoConfig
class BertNer(nn.Module):
def __init__(self, model_name, max_seq_len, num_labels=131):
super(BertNer, self).__init__()
self.bert = AutoModel.from_pretrained(model_name)
self.bert_config = AutoConfig.from_pretrained(model_name)
self.lstm_hidden = 128
self.max_seq_len = max_seq_len
self.bilstm = nn.LSTM(self.bert_config.hidden_size, self.lstm_hidden,
num_layers=1, bidirectional=True, batch_first=True,
)
self.dropout = nn.Dropout(0.1)
self.linear = nn.Linear(self.lstm_hidden*2, num_labels)
self.crf = CRF(num_labels, batch_first=True)
def decode(self, input_ids, attention_mask):
bert_output = self.bert(input_ids, attention_mask)
seq_out = bert_output[0] #bs, seq_len, hidden_dim
batch_size = seq_out.size(0)
seq_out, _ = self.bilstm(seq_out) #bs, seq_len, lstm_hidden*2
seq_out = seq_out.view(batch_size, self.max_seq_len, -1)
seq_out = self.linear(self.dropout(seq_out))
logits = self.crf.decode(seq_out, mask=attention_mask.bool())
return logits
def test_fn(model, tokenizer, text, id2label, max_len, device):
model.eval()
if len(text) > max_len - 2:
text = text[:max_len - 2]
tmp_input_ids = tokenizer.convert_tokens_to_ids(["[CLS]"] + text + ["SEP"])
attention_mask = [1] * len(tmp_input_ids) # 1 不 mask
input_ids = tmp_input_ids + [0] * (max_len - len(tmp_input_ids))
attention_mask = attention_mask + [0] * (max_len - len(tmp_input_ids))
input_ids = torch.tensor(input_ids, dtype=torch.long)
attention_mask = torch.tensor(attention_mask, dtype=torch.long)
input_ids = input_ids.unsqueeze(0)
attention_mask = attention_mask.unsqueeze(0)
input_ids = input_ids.to(device)
attention_mask = attention_mask.to(device)
logits = model.decode(input_ids,
attention_mask)
attention_mask = attention_mask.squeeze().detach().cpu().numpy()
length = sum(attention_mask)
logit = logits[0][1:length]
logit = [id2label[i] for i in logit]
return logit
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# device = torch.device('cpu')
model_name = r"hfl/chinese-bert-wwm-ext"
max_seq_len = 256
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = BertNer(model_name, max_seq_len)
model.load_state_dict(torch.load('./output/best_ner.bin'))
model.to(device)
with open('./output/label2id.json', 'r') as f:
label2id = json.load(f)
id2label = {v: k for k, v in label2id.items()}
那么实现一个预测的 api, 用 flask 怎么部署呢?
- 实例化 app
- 指定路由和方法
@app.route("/predict", methods=['POST'])
- 对
request
得到数据进行处理 - 功能函数
test_fn
来进行预测得到结果jsonify
后返回。
app = Flask(__name__)
@app.route("/predict", methods=['POST'])
def predict():
text = request.json["text"]
text = [i for i in text]
try:
out = test_fn(model, tokenizer, text, id2label, max_seq_len, device)
return jsonify({"text": ''.join(text),"result":''.join(out)
})
except Exception as e:
print(e)
return jsonify({"result": "Model Failed"})
if __name__ == '__main__':
app.run("0.0.0.0", port=8000)
对于 flask 部署,要采用 post 方法。另外就是你如果没有写 @app.post("/")
, 你直接访问对应地址会 404, 使用地址 +/predict
就好了。
使用 curl 和 request 来发送请求:
curl -X POST http://0.0.0.0:8000/predict -H 'Content-Type: application/json' -d '{"text":" 同梅西关系如何?C 罗一句话点出重点:当我夺冠时他一定很痛苦!"}'
- 这里使用 postman 测试, request 请求
http://0.0.0.0:8000/predict
,然后在raw
中填入{"text": "同梅西关系如何?C 罗一句话点出重点:当我夺冠时他一定很痛苦!"}'
接口,如下图所示。
返回的 json(这里为了展示将结果 join 连接了一下):
{
"result": "OOOOOOOOOOOOOOOOOOOOB- 竞赛行为_夺冠 I - 竞赛行为_夺冠 OOOOOOOOO",
"text": "同梅西关系如何?C 罗一句话点出重点:当我夺冠时他一定很痛苦!"
}
2.fastapi 部署
fastapi 部署也类似,只是要使用 async...await
来实现处理异步操作。
在函数定义时,我们使用 async def
来定义一个异步函数。在这种情况下,predict
函数是一个异步函数,它可以处理异步请求。
在 predict
函数中,我们使用 await
来等待异步操作完成。具体来说,await request.json()
表示等待获取请求的 JSON 数据,这是一个异步操作,它会返回一个包含请求数据的字典。
通过使用 async
和 await
,我们可以在异步函数中进行非阻塞的操作,从而充分利用异步处理的优势,提高应用程序的性能和响应能力。
fastapi 可以自动实现文档功能,游览器中输入 http://0.0.0.0:8000/docs#/
点击 try it out -> execute
看到:
app = FastAPI()
@app.post("/predict")
async def predict(request: Request):
data = await request.json()
text = data["text"]
text = [i for i in text]
try:
out = test_fn(model, tokenizer, text, id2label, max_seq_len, device)
return {"text": ''.join(text),"result":''.join(out)}
except Exception as e:
print(e)
return {"result": "Model Failed"}
if __name__ == '__main__':
uvicorn.run(app, host="0.0.0.0", port=8000)
#uvicorn app:app --host 0.0.0.0 --port 8000
在 postman 中请求也一样,返回结果如下:
{
"text": "同梅西关系如何?C 罗一句话点出重点:当我夺冠时他一定很痛苦!",
"result": "OOOOOOOOOOOOOOOOOOOOB- 竞赛行为_夺冠 I - 竞赛行为_夺冠 OOOOOOOOO"
}
参考
正文完