Project author: kemingy

Project description:
Serving the deep learning models easily.
Language: Python
Repository: git://github.com/kemingy/ventu.git
Created: 2020-03-24T08:03:45Z
Community: https://github.com/kemingy/ventu


Ventu

(badges: PyPI · Python versions · Python Test · Python document · Language grade: Python)

Serving the deep learning models easily.

Attention

This project is just a proof of concept. Check out MOSEC for production usage.

Install

```sh
pip install ventu
```

Features

  • you only need to implement the Model (preprocess, postprocess, inference or batch_inference)
  • request & response data validation using pydantic
  • API documentation generated with SpecTree (when running with run_http)
  • backend service built on falcon, supporting both JSON and msgpack
  • dynamic batching with batching over a Unix domain socket or TCP
    • errors in one request won't affect the others in the same batch
    • load balancing
  • works with any model runtime
  • health check
  • monitoring metrics (Prometheus)
    • if you have multiple workers, remember to set the prometheus_multiproc_dir environment variable to a writable directory (see the sketch after this list)
  • inference warm-up
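
For the multi-worker metrics bullet above, a minimal sketch of preparing that directory (the path is illustrative; the variable has to be visible before the metrics client starts in each worker, so in practice it is often exported in the shell that launches gunicorn):

```python
import os

# illustrative path; any directory writable by all workers will do
os.environ.setdefault('prometheus_multiproc_dir', '/tmp/ventu_metrics')
os.makedirs(os.environ['prometheus_multiproc_dir'], exist_ok=True)
```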

How to use

  • define your request data schema and response data schema with pydantic
    • add examples to schema.Config.schema_extra['example'] for warm-up and health check (optional)
  • inherit ventu.Ventu and implement the preprocess and postprocess methods
  • for a standalone HTTP service, implement the inference method and run with run_http
  • for a worker behind the dynamic batching service, implement the batch_inference method and run with run_unix or run_tcp

Check the documentation for API details.
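
A minimal sketch of these steps, assuming nothing beyond the API shown in the example below (the Echo names and the trivial echo "model" are made up for illustration):

```python
from pydantic import BaseModel
from ventu import Ventu


# hypothetical request/response schemas
class EchoReq(BaseModel):
    text: str

    class Config:
        # example used for warm-up and health check
        schema_extra = {'example': {'text': 'ping'}}


class EchoResp(BaseModel):
    text: str


class EchoModel(Ventu):
    def preprocess(self, data: EchoReq):
        return data.text

    def inference(self, data):
        # a real model would run here; this one just echoes the input
        return data

    def postprocess(self, data):
        # must match the response schema
        return {'text': data}


if __name__ == '__main__':
    # standalone HTTP service; a batching worker would implement
    # batch_inference and call run_unix/run_tcp instead
    EchoModel(EchoReq, EchoResp).run_http('localhost', 8080)
```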

Example

The demo code can be found in examples.

Service

Install the requirements: `pip install numpy torch transformers httpx`

```python
import argparse
import logging

import numpy as np
import torch
from pydantic import BaseModel, confloat, constr
from transformers import DistilBertForSequenceClassification, DistilBertTokenizer

from ventu import Ventu


# request schema used for validation
class Req(BaseModel):
    # the input sentence should be at least 2 characters
    text: constr(min_length=2)

    class Config:
        # examples used for health check and warm-up
        schema_extra = {
            'example': {'text': 'my cat is very cute'},
            'batch_size': 16,
        }


# response schema used for validation
class Resp(BaseModel):
    positive: confloat(ge=0, le=1)
    negative: confloat(ge=0, le=1)


class ModelInference(Ventu):
    def __init__(self, *args, **kwargs):
        # initialize the super class with request & response schemas and configs
        super().__init__(*args, **kwargs)
        # initialize the model and other tools
        self.tokenizer = DistilBertTokenizer.from_pretrained(
            'distilbert-base-uncased')
        self.model = DistilBertForSequenceClassification.from_pretrained(
            'distilbert-base-uncased-finetuned-sst-2-english')

    def preprocess(self, data: Req):
        # preprocess one request (as defined in the request schema)
        tokens = self.tokenizer.encode(data.text, add_special_tokens=True)
        return tokens

    def batch_inference(self, data):
        # batch inference is used in `unix`/`tcp` (socket) mode
        data = [torch.tensor(token) for token in data]
        with torch.no_grad():
            result = self.model(
                torch.nn.utils.rnn.pad_sequence(data, batch_first=True))[0]
        return result.numpy()

    def inference(self, data):
        # inference is used in `http` mode
        with torch.no_grad():
            result = self.model(torch.tensor(data).unsqueeze(0))[0]
        return result.numpy()[0]

    def postprocess(self, data):
        # postprocess the result into the response schema (softmax over logits)
        scores = (np.exp(data) / np.exp(data).sum(-1, keepdims=True)).tolist()
        return {'negative': scores[0], 'positive': scores[1]}


def create_model():
    logger = logging.getLogger()
    formatter = logging.Formatter(
        fmt='%(asctime)s - %(levelname)s - %(module)s - %(message)s')
    handler = logging.StreamHandler()
    handler.setFormatter(formatter)
    logger.setLevel(logging.DEBUG)
    logger.addHandler(handler)
    model = ModelInference(Req, Resp, use_msgpack=True)
    return model


def create_app():
    """for gunicorn"""
    return create_model().app


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Ventu service')
    parser.add_argument('--mode', '-m', default='http', choices=('http', 'unix', 'tcp'))
    parser.add_argument('--host', default='localhost')
    parser.add_argument('--port', '-p', default=8080, type=int)
    parser.add_argument('--socket', '-s', default='batching.socket')
    args = parser.parse_args()
    model = create_model()
    if args.mode == 'unix':
        model.run_unix(args.socket)
    elif args.mode == 'tcp':
        model.run_tcp(args.host, args.port)
    else:
        model.run_http(args.host, args.port)
```

You can run this script as:

  • a single-threaded HTTP service: python examples/app.py
  • an HTTP service with multiple workers: gunicorn -w 2 -b localhost:8080 'examples.app:create_app()'
    • when running as an HTTP service, you can check the following endpoints (a quick JSON check is sketched after this list):
      • /metrics Prometheus metrics
      • /health health check
      • /inference inference
      • /apidoc/redoc or /apidoc/swagger OpenAPI document
  • an inference worker behind the dynamic batching service: python examples/app.py -m unix (Unix domain socket) or python examples/app.py -m tcp --host localhost --port 8888 (TCP) (the batching service needs to be running first)
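
Since the falcon backend accepts plain JSON as well as msgpack, a couple of httpx calls are enough to exercise the endpoints above; this is only a sketch, assuming the example service is running on localhost:8080:

```python
import httpx

# health check (the schema example doubles as the test payload)
print(httpx.get('http://localhost:8080/health').status_code)

# JSON inference request matching the Req schema
resp = httpx.post('http://localhost:8080/inference',
                  json={'text': 'my cat is very cute'})
print(resp.status_code, resp.json())
```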

Client

```python
from concurrent import futures

import httpx
import msgpack

URL = 'http://localhost:8080/inference'
HEADER = {'Content-Type': 'application/msgpack'}
packer = msgpack.Packer(
    autoreset=True,
    use_bin_type=True,
)


def request(text):
    # send a msgpack-encoded body (newer httpx versions prefer `content=`
    # over `data=` for raw bytes)
    return httpx.post(URL, data=packer.pack({'text': text}), headers=HEADER)


if __name__ == "__main__":
    # fire the requests concurrently so they can be batched on the server side
    with futures.ThreadPoolExecutor() as executor:
        text = [
            'They are smart',
            'what is your problem?',
            'I hate that!',
            'x',  # too short: fails the `constr(min_length=2)` validation
        ]
        results = executor.map(request, text)
        for i, resp in enumerate(results):
            print(
                f'>> {text[i]} -> [{resp.status_code}]\n'
                f'{msgpack.unpackb(resp.content)}'
            )
```
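
Note that the last input, 'x', is shorter than the constr(min_length=2) limit in the request schema, so that request should come back as a validation error while the other three are answered normally, demonstrating the per-request error isolation mentioned in the features above.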