BERT模型的两种部署方式flask和fastapi

1.flask 部署

这里用训练完的 BERT-biLSTM+CRF 模型对 text 解码部分如下,功能函数是跟 flask 和 fastapi 没关系的。比较容易出错的地方:

  1. 如果不用 DataLoader, 就要对 model 输入进行 .unsqueeze(0) 来构建 batch_size 这个维度
  2. 如果用 DataLoader,就跟训练时的验证部分差不多,就不多少了。

下面就是公共的功能部分代码:

import json
import torch
import torch.nn as nn
from torchcrf import CRF
import uvicorn
from fastapi import FastAPI, Request
from flask import Flask, jsonify, request
from transformers import AutoTokenizer, AutoModel, AutoConfig


class BertNer(nn.Module):
    def __init__(self, model_name, max_seq_len, num_labels=131):
        super(BertNer, self).__init__()
        self.bert = AutoModel.from_pretrained(model_name)
        self.bert_config = AutoConfig.from_pretrained(model_name)
        self.lstm_hidden = 128
        self.max_seq_len = max_seq_len
        self.bilstm = nn.LSTM(self.bert_config.hidden_size, self.lstm_hidden,
                              num_layers=1, bidirectional=True, batch_first=True,
                              )
        self.dropout = nn.Dropout(0.1)
        self.linear = nn.Linear(self.lstm_hidden*2, num_labels)
        self.crf = CRF(num_labels, batch_first=True)


    def decode(self, input_ids, attention_mask):
        bert_output = self.bert(input_ids, attention_mask)
        seq_out = bert_output[0] #bs, seq_len, hidden_dim
        batch_size = seq_out.size(0)
        seq_out, _ = self.bilstm(seq_out) #bs, seq_len, lstm_hidden*2
        seq_out = seq_out.view(batch_size, self.max_seq_len, -1)
        seq_out = self.linear(self.dropout(seq_out))

        logits = self.crf.decode(seq_out, mask=attention_mask.bool())

        return logits

def test_fn(model, tokenizer, text, id2label, max_len, device):
    model.eval()

    if len(text) > max_len - 2:
        text = text[:max_len - 2]

    tmp_input_ids = tokenizer.convert_tokens_to_ids(["[CLS]"] + text + ["SEP"])
    attention_mask = [1] * len(tmp_input_ids)  # 1 不 mask
    input_ids = tmp_input_ids + [0] * (max_len - len(tmp_input_ids))
    attention_mask = attention_mask + [0] * (max_len - len(tmp_input_ids))

    input_ids = torch.tensor(input_ids, dtype=torch.long)
    attention_mask = torch.tensor(attention_mask, dtype=torch.long)

    input_ids = input_ids.unsqueeze(0)
    attention_mask = attention_mask.unsqueeze(0)

    input_ids = input_ids.to(device)
    attention_mask = attention_mask.to(device)
    logits = model.decode(input_ids,
                          attention_mask)

    attention_mask = attention_mask.squeeze().detach().cpu().numpy()


    length = sum(attention_mask)
    logit = logits[0][1:length]
    logit = [id2label[i] for i in logit]

    return logit




device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# device = torch.device('cpu')
model_name = r"hfl/chinese-bert-wwm-ext"
max_seq_len = 256
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = BertNer(model_name, max_seq_len)
model.load_state_dict(torch.load('./output/best_ner.bin'))
model.to(device)

with open('./output/label2id.json', 'r') as f:
    label2id = json.load(f)
id2label = {v: k for k, v in label2id.items()}

那么实现一个预测的 api, 用 flask 怎么部署呢?

  1. 实例化 app
  2. 指定路由和方法@app.route("/predict", methods=['POST'])
  3. request 得到数据进行处理
  4. 功能函数 test_fn 来进行预测得到结果 jsonify 后返回。
app = Flask(__name__)
@app.route("/predict", methods=['POST'])
def predict():
    text = request.json["text"]
    text = [i for i in text]

    try:
        out = test_fn(model, tokenizer, text, id2label, max_seq_len, device)
        return jsonify({"text": ''.join(text),"result":''.join(out)
                        })
    except Exception as e:
        print(e)
        return jsonify({"result": "Model Failed"})

if __name__ == '__main__':
    app.run("0.0.0.0", port=8000)

对于 flask 部署,要采用 post 方法。另外就是你如果没有写 @app.post("/"), 你直接访问对应地址会 404, 使用地址 +/predict 就好了。

使用 curl 和 request 来发送请求:

  1. curl -X POST http://0.0.0.0:8000/predict -H 'Content-Type: application/json' -d '{"text":" 同梅西关系如何?C 罗一句话点出重点:当我夺冠时他一定很痛苦!"}'
  2. 这里使用 postman 测试, request 请求 http://0.0.0.0:8000/predict,然后在raw 中填入 {"text": "同梅西关系如何?C 罗一句话点出重点:当我夺冠时他一定很痛苦!"}' 接口,如下图所示。
BERT 模型的两种部署方式 flask 和 fastapi

返回的 json(这里为了展示将结果 join 连接了一下):

{
    "result": "OOOOOOOOOOOOOOOOOOOOB- 竞赛行为_夺冠 I - 竞赛行为_夺冠 OOOOOOOOO",
    "text": "同梅西关系如何?C 罗一句话点出重点:当我夺冠时他一定很痛苦!"
}

2.fastapi 部署

fastapi 部署也类似,只是要使用 async...await 来实现处理异步操作。

在函数定义时,我们使用 async def 来定义一个异步函数。在这种情况下,predict 函数是一个异步函数,它可以处理异步请求。

predict 函数中,我们使用 await 来等待异步操作完成。具体来说,await request.json() 表示等待获取请求的 JSON 数据,这是一个异步操作,它会返回一个包含请求数据的字典。

通过使用 asyncawait,我们可以在异步函数中进行非阻塞的操作,从而充分利用异步处理的优势,提高应用程序的性能和响应能力。

fastapi 可以自动实现文档功能,游览器中输入 http://0.0.0.0:8000/docs#/ 点击 try it out -> execute 看到:

BERT 模型的两种部署方式 flask 和 fastapi
app = FastAPI()
@app.post("/predict")
async def predict(request: Request):
    data = await request.json()
    text = data["text"]
    text = [i for i in text]
    try:
        out = test_fn(model, tokenizer, text, id2label, max_seq_len, device)
        return {"text": ''.join(text),"result":''.join(out)}
    except Exception as e:
        print(e)
        return {"result": "Model Failed"}


if __name__ == '__main__':
    uvicorn.run(app, host="0.0.0.0", port=8000)
#uvicorn app:app --host 0.0.0.0 --port 8000

在 postman 中请求也一样,返回结果如下:

{
    "text": "同梅西关系如何?C 罗一句话点出重点:当我夺冠时他一定很痛苦!",
    "result": "OOOOOOOOOOOOOOOOOOOOB- 竞赛行为_夺冠 I - 竞赛行为_夺冠 OOOOOOOOO"
}

 

参考

flask 中 gunicorn 的使用

 
正文完
 
admin
版权声明:本站原创文章,由 admin 2023-11-26发表,共计4276字。
转载说明:除特殊说明外本站文章皆由CC-4.0协议发布,转载请联系tensortimes@gmail.com。
评论(没有评论)
验证码