Path: blob/main/intermediate_source/flask_rest_api_tutorial.py
# -*- coding: utf-8 -*-
"""
Deploying PyTorch in Python via a REST API with Flask
========================================================
**Author**: `Avinash Sajjanshetty <https://avi.im>`_

In this tutorial, we will deploy a PyTorch model using Flask and expose a
REST API for model inference. In particular, we will deploy a pretrained
DenseNet 121 model that classifies images.

.. tip:: All the code used here is released under MIT license and is available on `Github <https://github.com/avinassh/pytorch-flask-api>`_.

This represents the first in a series of tutorials on deploying PyTorch models
in production. Using Flask in this way is by far the easiest way to start
serving your PyTorch models, but it will not work for a use case
with high performance requirements. For that:

- If you're already familiar with TorchScript, you can jump straight into our
  `Loading a TorchScript Model in C++ <https://pytorch.org/tutorials/advanced/cpp_export.html>`_ tutorial.

- If you first need a refresher on TorchScript, check out our
  `Intro to TorchScript <https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html>`_ tutorial.
"""


######################################################################
# API Definition
# --------------
#
# We will first define our API endpoints and the request and response types. Our
# API endpoint will be at ``/predict``, which takes HTTP POST requests with a
# ``file`` parameter containing the image. The response will be a JSON
# response containing the prediction:
#
# .. code-block:: sh
#
#    {"class_id": "n02124075", "class_name": "Egyptian_cat"}
#

######################################################################
# Dependencies
# ------------
#
# Install the required dependencies by running the following command:
#
# .. code-block:: sh
#
#    pip install Flask==2.0.1 torchvision==0.10.0


######################################################################
# Simple Web Server
# -----------------
#
# Following is a simple web server, taken from Flask's documentation:


from flask import Flask
app = Flask(__name__)


@app.route('/')
def hello():
    return 'Hello World!'

###############################################################################
# We will also change the response type, so that it returns a JSON response
# containing the ImageNet class id and name. The updated ``app.py`` file will
# now be:

from flask import Flask, jsonify
app = Flask(__name__)

@app.route('/predict', methods=['POST'])
def predict():
    return jsonify({'class_id': 'IMAGE_NET_XXX', 'class_name': 'Cat'})


######################################################################
# Inference
# -----------------
#
# In the next sections we will focus on writing the inference code. This will
# involve two parts: one where we prepare the image so that it can be fed
# to DenseNet, and one where we write the code to get the actual prediction
# from the model.
#
# Preparing the image
# ~~~~~~~~~~~~~~~~~~~
#
# The DenseNet model requires the image to be a 3-channel RGB image of size
# 224 x 224. We will also normalize the image tensor with the required mean
# and standard deviation values. You can read more about it
# `here <https://pytorch.org/vision/stable/models.html>`_.
#
# We will use ``transforms`` from the ``torchvision`` library and build a
# transform pipeline, which transforms our images as required.
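As an aside on the pipeline above: when ``transforms.Resize`` is given a single integer, it scales the shorter side of the image to that length and the longer side proportionally, and ``CenterCrop(224)`` then cuts out the central 224 x 224 region. A small sketch of that size arithmetic, independent of torchvision (the helper name is ours, for illustration only):

```python
def resized_size(height, width, shorter=255):
    """Size after a torchvision-style Resize(shorter): the shorter side
    is scaled to `shorter`, the longer side is scaled proportionally."""
    if height <= width:
        return shorter, round(width * shorter / height)
    return round(height * shorter / width), shorter

# A 480x640 photo becomes 255x340 after Resize(255); CenterCrop(224)
# then yields a 224x224 input regardless of the original aspect ratio.
print(resized_size(480, 640))  # (255, 340)
```

This is why the pipeline resizes to 255 first: the crop always has a full 224 x 224 region to cut from, whatever the original aspect ratio.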
# You can read more about transforms `here <https://pytorch.org/vision/stable/transforms.html>`_.

import io

import torchvision.transforms as transforms
from PIL import Image

def transform_image(image_bytes):
    my_transforms = transforms.Compose([transforms.Resize(255),
                                        transforms.CenterCrop(224),
                                        transforms.ToTensor(),
                                        transforms.Normalize(
                                            [0.485, 0.456, 0.406],
                                            [0.229, 0.224, 0.225])])
    image = Image.open(io.BytesIO(image_bytes))
    return my_transforms(image).unsqueeze(0)

######################################################################
# The above method takes image data in bytes, applies the series of transforms
# and returns a tensor. To test the above method, read an image file in
# bytes mode (first replacing `../_static/img/sample_file.jpeg` with the actual
# path to the file on your computer) and see if you get a tensor back:

with open("../_static/img/sample_file.jpeg", 'rb') as f:
    image_bytes = f.read()
    tensor = transform_image(image_bytes=image_bytes)
    print(tensor)

######################################################################
# Prediction
# ~~~~~~~~~~~~~~~~~~~
#
# Now we will use a pretrained DenseNet 121 model to predict the image class. We
# will use one from the ``torchvision`` library, load the model and get an
# inference. While we'll be using a pretrained model in this example, you can
# use this same approach for your own models.
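The prediction function below ends with ``outputs.max(1)``, which returns both the largest logit and its index along the class dimension; that index is the predicted class. Stripped of tensors, the reduction is a plain argmax, sketched here in pure Python (the helper is ours, for illustration only):

```python
def argmax(logits):
    """Index of the largest value -- what `outputs.max(1)` yields as
    `y_hat` for each row of the model's output."""
    best_idx = 0
    for i, value in enumerate(logits):
        if value > logits[best_idx]:
            best_idx = i
    return best_idx

# Three fake class logits: class 2 has the highest score.
print(argmax([0.1, -1.3, 2.7]))  # 2
```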
# See more about loading your
# models in this :doc:`tutorial </beginner/saving_loading_models>`.

from torchvision import models

# Make sure to set `weights` as `'IMAGENET1K_V1'` to use the pretrained weights:
model = models.densenet121(weights='IMAGENET1K_V1')
# Since we are using our model only for inference, switch to `eval` mode:
model.eval()


def get_prediction(image_bytes):
    tensor = transform_image(image_bytes=image_bytes)
    outputs = model.forward(tensor)
    _, y_hat = outputs.max(1)
    return y_hat

######################################################################
# The tensor ``y_hat`` will contain the index of the predicted class id.
# However, we need a human readable class name. For that we need a class id
# to name mapping. Download
# `this file <https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json>`_
# as ``imagenet_class_index.json`` and remember where you saved it (or, if you
# are following the exact steps in this tutorial, save it in
# `tutorials/_static`). This file contains the mapping of ImageNet class id to
# ImageNet class name.
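For reference, ``imagenet_class_index.json`` maps each class index, stored as a string key, to a ``[class_id, class_name]`` pair. A tiny hand-copied fragment shows the shape of the lookup (only these two entries are reproduced here):

```python
# A hand-copied fragment of imagenet_class_index.json:
imagenet_fragment = {
    "0": ["n01440764", "tench"],
    "285": ["n02124075", "Egyptian_cat"],
}

# The predicted index must be converted to a string before the lookup,
# because JSON object keys are always strings.
predicted_idx = str(285)
print(imagenet_fragment[predicted_idx])  # ['n02124075', 'Egyptian_cat']
```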
# We will load this JSON file and get the class name of
# the predicted index.

import json

imagenet_class_index = json.load(open('../_static/imagenet_class_index.json'))

def get_prediction(image_bytes):
    tensor = transform_image(image_bytes=image_bytes)
    outputs = model.forward(tensor)
    _, y_hat = outputs.max(1)
    predicted_idx = str(y_hat.item())
    return imagenet_class_index[predicted_idx]


######################################################################
# Before looking up the ``imagenet_class_index`` dictionary, we first convert
# the tensor value to a string, since the keys in the
# ``imagenet_class_index`` dictionary are strings.
# We will test our above method:


with open("../_static/img/sample_file.jpeg", 'rb') as f:
    image_bytes = f.read()
    print(get_prediction(image_bytes=image_bytes))

######################################################################
# You should get a response like this:

['n02124075', 'Egyptian_cat']

######################################################################
# The first item in the array is the ImageNet class id and the second item is
# the human readable name.
#

######################################################################
# Integrating the model in our API Server
# ---------------------------------------
#
# In this final part we will add our model to our Flask API server. Since
# our API server is supposed to take an image file, we will update our ``predict``
# method to read files from the requests:
#
# .. code-block:: python
#
#    from flask import request
#
#    @app.route('/predict', methods=['POST'])
#    def predict():
#        if request.method == 'POST':
#            # we will get the file from the request
#            file = request.files['file']
#            # convert that to bytes
#            img_bytes = file.read()
#            class_id, class_name = get_prediction(image_bytes=img_bytes)
#            return jsonify({'class_id': class_id, 'class_name': class_name})
#
#
######################################################################
# The ``app.py`` file is now complete. Following is the full version; replace
# the paths with the paths where you saved your files and it should run:
#
# .. code-block:: python
#
#    import io
#    import json
#
#    from torchvision import models
#    import torchvision.transforms as transforms
#    from PIL import Image
#    from flask import Flask, jsonify, request
#
#
#    app = Flask(__name__)
#    imagenet_class_index = json.load(open('<PATH/TO/.json/FILE>/imagenet_class_index.json'))
#    model = models.densenet121(weights='IMAGENET1K_V1')
#    model.eval()
#
#
#    def transform_image(image_bytes):
#        my_transforms = transforms.Compose([transforms.Resize(255),
#                                            transforms.CenterCrop(224),
#                                            transforms.ToTensor(),
#                                            transforms.Normalize(
#                                                [0.485, 0.456, 0.406],
#                                                [0.229, 0.224, 0.225])])
#        image = Image.open(io.BytesIO(image_bytes))
#        return my_transforms(image).unsqueeze(0)
#
#
#    def get_prediction(image_bytes):
#        tensor = transform_image(image_bytes=image_bytes)
#        outputs = model.forward(tensor)
#        _, y_hat = outputs.max(1)
#        predicted_idx = str(y_hat.item())
#        return imagenet_class_index[predicted_idx]
#
#
#    @app.route('/predict', methods=['POST'])
#    def predict():
#        if request.method == 'POST':
#            file = request.files['file']
#            img_bytes = file.read()
#            class_id, class_name = get_prediction(image_bytes=img_bytes)
#            return jsonify({'class_id': class_id, 'class_name': class_name})
#
#
#    if __name__ == '__main__':
#        app.run()
#
#
######################################################################
# Let's test our web server! Run:
#
# .. code-block:: sh
#
#    FLASK_ENV=development FLASK_APP=app.py flask run
#
#######################################################################
# We can use the
# `requests <https://pypi.org/project/requests/>`_
# library to send a POST request to our app:
#
# .. code-block:: python
#
#    import requests
#
#    resp = requests.post("http://localhost:5000/predict",
#                         files={"file": open('<PATH/TO/.jpg/FILE>/cat.jpg','rb')})
#

#######################################################################
# Printing `resp.json()` will now show the following:
#
# .. code-block:: sh
#
#    {"class_id": "n02124075", "class_name": "Egyptian_cat"}
#
######################################################################
# Next steps
# --------------
#
# The server we wrote is quite trivial and may not do everything
# you need for your production application. So, here are some things you
# can do to make it better:
#
# - The endpoint ``/predict`` assumes that there will always be an image file
#   in the request. This may not hold true for all requests. Our user may
#   send the image with a different parameter, or send no image at all.
#
# - The user may send non-image type files too. Since we are not handling
#   errors, this will break our server. Adding an explicit error handling
#   path that will throw an exception would allow us to better handle
#   the bad inputs.
#
# - Even though the model can recognize a large number of classes of images,
#   it may not be able to recognize all images. Enhance the implementation
#   to handle cases when the model does not recognize anything in the image.
#
# - We run the Flask server in the development mode, which is not suitable for
#   deploying in production.
#   You can check out `this tutorial <https://flask.palletsprojects.com/en/1.1.x/tutorial/deploy/>`_
#   for deploying a Flask server in production.
#
# - You can also add a UI by creating a page with a form which takes the image and
#   displays the prediction. Check out the `demo <https://pytorch-imagenet.herokuapp.com/>`_
#   of a similar project and its `source code <https://github.com/avinassh/pytorch-flask-api-heroku>`_.
#
# - In this tutorial, we only showed how to build a service that could return predictions for
#   a single image at a time. We could modify our service to be able to return predictions for
#   multiple images at once. In addition, the `service-streamer <https://github.com/ShannonAI/service-streamer>`_
#   library automatically queues requests to your service and samples them into mini-batches
#   that can be fed into your model. You can check out `this tutorial <https://github.com/ShannonAI/service-streamer/wiki/Vision-Recognition-Service-with-Flask-and-service-streamer>`_.
#
# - Finally, we encourage you to check out our other tutorials on deploying PyTorch models
#   linked-to at the top of the page.
#
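As a starting point for the first two improvements above, the uploaded file can be validated before inference runs. A minimal sketch, assuming we only accept a few common image extensions (the helper name and the extension list are our own choices, not part of the tutorial's code):

```python
ALLOWED_EXTENSIONS = {'jpg', 'jpeg', 'png'}

def allowed_file(filename):
    """Reject uploads whose filename is missing an extension or whose
    extension is not one we expect to decode as an image."""
    return '.' in filename and \
        filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS

# Inside the route you could then do something like:
#     file = request.files.get('file')
#     if file is None or not allowed_file(file.filename):
#         return jsonify({'error': 'expected an image file'}), 400
print(allowed_file('cat.jpg'), allowed_file('notes.txt'))  # True False
```

The extension check is only a first filter; a malformed file with a valid name will still raise inside ``Image.open``, so wrapping the inference call in a ``try``/``except`` that returns a 4xx response is the natural next step.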