GitHub Repository: pytorch/tutorials
Path: blob/main/intermediate_source/flask_rest_api_tutorial.py
# -*- coding: utf-8 -*-
"""
Deploying PyTorch in Python via a REST API with Flask
========================================================
**Author**: `Avinash Sajjanshetty <https://avi.im>`_

In this tutorial, we will deploy a PyTorch model using Flask and expose a
REST API for model inference. In particular, we will deploy a pretrained
DenseNet 121 model which classifies the image.

.. tip:: All the code used here is released under the MIT license and is available on `GitHub <https://github.com/avinassh/pytorch-flask-api>`_.

This is the first in a series of tutorials on deploying PyTorch models
in production. Using Flask in this way is by far the easiest way to start
serving your PyTorch models, but it is not suited to use cases
with high performance requirements. For that:

- If you're already familiar with TorchScript, you can jump straight into our
  `Loading a TorchScript Model in C++ <https://pytorch.org/tutorials/advanced/cpp_export.html>`_ tutorial.

- If you first need a refresher on TorchScript, check out our
  `Intro to TorchScript <https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html>`_ tutorial.
"""


######################################################################
# API Definition
# --------------
#
# We will first define our API endpoint and its request and response types.
# The endpoint will be at ``/predict`` and will take HTTP POST requests with a
# ``file`` parameter that contains the image. The response will be a JSON
# document containing the prediction:
#
# .. code-block:: sh
#
#    {"class_id": "n02124075", "class_name": "Egyptian_cat"}
#
#

######################################################################
# Dependencies
# ------------
#
# Install the required dependencies by running the following command:
#
# .. code-block:: sh
#
#    pip install Flask==2.0.1 torchvision==0.10.0


######################################################################
# Simple Web Server
# -----------------
#
# The following is a simple web server, taken from Flask's documentation:


from flask import Flask
app = Flask(__name__)


@app.route('/')
def hello():
    return 'Hello World!'

###############################################################################
# To match our API definition, we will serve the ``/predict`` route for POST
# requests and change the response type so that it returns a JSON response
# containing the ImageNet class id and name. The updated ``app.py`` file will
# now be:

from flask import Flask, jsonify
app = Flask(__name__)

@app.route('/predict', methods=['POST'])
def predict():
    return jsonify({'class_id': 'IMAGE_NET_XXX', 'class_name': 'Cat'})


######################################################################
# Inference
# -----------------
#
# In the next sections we will focus on writing the inference code. This
# involves two parts: first we prepare the image so that it can be fed
# to DenseNet, and then we write the code to get the actual prediction
# from the model.
#
# Preparing the image
# ~~~~~~~~~~~~~~~~~~~
#
# The DenseNet model requires the image to be a 3-channel RGB image of size
# 224 x 224. We will also normalize the image tensor with the required mean
# and standard deviation values. You can read more about it
# `here <https://pytorch.org/vision/stable/models.html>`_.
#
# We will use ``transforms`` from the ``torchvision`` library to build a
# transform pipeline, which transforms our images as required. You
# can read more about transforms `here <https://pytorch.org/vision/stable/transforms.html>`_.

import io

import torchvision.transforms as transforms
from PIL import Image

def transform_image(image_bytes):
    my_transforms = transforms.Compose([transforms.Resize(255),
                                        transforms.CenterCrop(224),
                                        transforms.ToTensor(),
                                        transforms.Normalize(
                                            [0.485, 0.456, 0.406],
                                            [0.229, 0.224, 0.225])])
    image = Image.open(io.BytesIO(image_bytes))
    return my_transforms(image).unsqueeze(0)

######################################################################
# The above method takes image data in bytes, applies the series of transforms
# and returns a tensor. To test the above method, read an image file in
# bytes mode (first replacing ``../_static/img/sample_file.jpeg`` with the actual
# path to the file on your computer) and see if you get a tensor back:

with open("../_static/img/sample_file.jpeg", 'rb') as f:
    image_bytes = f.read()
    tensor = transform_image(image_bytes=image_bytes)
    print(tensor)
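
######################################################################
# To make the normalization step concrete, here is a minimal pure-Python
# sketch of the per-channel arithmetic, ``(value - mean) / std``, using the
# same mean and standard deviation values as the pipeline above. The helper
# name ``normalize_pixel`` is hypothetical and exists only for illustration;
# the sketch assumes the pixel values have already been scaled to ``[0, 1]``,
# as ``ToTensor`` does for 8-bit images.

```python
# Per-channel statistics used in the transform pipeline above.
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]


def normalize_pixel(rgb):
    """Apply (value - mean) / std to one RGB pixel with values in [0, 1]."""
    return [(value - m) / s for value, m, s in zip(rgb, MEAN, STD)]


# A pixel equal to the channel means maps to zero in every channel.
print(normalize_pixel(MEAN))
# A pure white pixel ends up above 2 in every channel.
print(normalize_pixel([1.0, 1.0, 1.0]))
```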

######################################################################
# Prediction
# ~~~~~~~~~~~~~~~~~~~
#
# Now we will use a pretrained DenseNet 121 model to predict the image class.
# We will take one from the ``torchvision`` library, load the model and run
# inference. While we'll be using a pretrained model in this example, you can
# use this same approach for your own models. See more about loading your
# models in this :doc:`tutorial </beginner/saving_loading_models>`.

from torchvision import models

# Make sure to set `weights` as `'IMAGENET1K_V1'` to use the pretrained weights:
model = models.densenet121(weights='IMAGENET1K_V1')
# Since we are using our model only for inference, switch to `eval` mode:
model.eval()


def get_prediction(image_bytes):
    tensor = transform_image(image_bytes=image_bytes)
    outputs = model.forward(tensor)
    _, y_hat = outputs.max(1)
    return y_hat

######################################################################
# The tensor ``y_hat`` will contain the index of the predicted class id.
# However, we need a human-readable class name. For that we need a mapping
# from class id to name. Download
# `this file <https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json>`_
# as ``imagenet_class_index.json`` and remember where you saved it (or, if you
# are following the exact steps in this tutorial, save it in
# ``tutorials/_static``). This file contains the mapping of ImageNet class id to
# ImageNet class name. We will load this JSON file and get the class name of
# the predicted index.

import json

# Use a context manager so the file handle is closed after loading.
with open('../_static/imagenet_class_index.json') as f:
    imagenet_class_index = json.load(f)

def get_prediction(image_bytes):
    tensor = transform_image(image_bytes=image_bytes)
    outputs = model.forward(tensor)
    _, y_hat = outputs.max(1)
    predicted_idx = str(y_hat.item())
    return imagenet_class_index[predicted_idx]


######################################################################
# Before indexing into the ``imagenet_class_index`` dictionary, we convert
# the tensor value to a string, since the keys in the
# ``imagenet_class_index`` dictionary are strings.
# Let's test the method above:


with open("../_static/img/sample_file.jpeg", 'rb') as f:
    image_bytes = f.read()
    print(get_prediction(image_bytes=image_bytes))

######################################################################
# You should get a response like this:

['n02124075', 'Egyptian_cat']

######################################################################
# The first item in the array is the ImageNet class id and the second item is
# the human-readable name.
#
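
######################################################################
# The string conversion in ``get_prediction`` matters because JSON object keys
# are always decoded as strings, so a dictionary loaded from a JSON file is
# keyed by ``"0"``, ``"1"``, and so on rather than by integers. The sketch
# below illustrates this with a one-entry, in-memory stand-in for
# ``imagenet_class_index.json``; the index ``285`` for this class is an
# assumption made for illustration, and the names ``raw`` and ``class_index``
# are hypothetical.

```python
import json

# A tiny in-memory stand-in for imagenet_class_index.json (illustrative only).
raw = '{"285": ["n02124075", "Egyptian_cat"]}'
class_index = json.loads(raw)

predicted = 285  # stands in for the Python int that ``y_hat.item()`` returns

# JSON object keys are decoded as strings, so an int lookup misses.
print(predicted in class_index)
print(str(predicted) in class_index)

# Converting the index to a string first makes the lookup succeed.
class_id, class_name = class_index[str(predicted)]
print(class_id, class_name)
```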

######################################################################
# Integrating the model in our API Server
# ---------------------------------------
#
# In this final part we will add our model to our Flask API server. Since
# our API server is supposed to take an image file, we will update our ``predict``
# method to read the file from the request:
#
# .. code-block:: python
#
#    from flask import request
#
#    @app.route('/predict', methods=['POST'])
#    def predict():
#        if request.method == 'POST':
#            # we will get the file from the request
#            file = request.files['file']
#            # convert that to bytes
#            img_bytes = file.read()
#            class_id, class_name = get_prediction(image_bytes=img_bytes)
#            return jsonify({'class_id': class_id, 'class_name': class_name})
#
#
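
######################################################################
# The handler above indexes ``request.files['file']`` directly, so a request
# without that field fails inside the view. One possible guard is to factor
# the checks into a plain function that the view calls before running
# inference. The sketch below is framework-agnostic: ``validate_upload`` and
# its error messages are hypothetical, and it only assumes that its argument
# behaves like a mapping from field name to a file-like object with
# ``read()``, which is how ``request.files`` behaves in Flask. The magic-byte
# check for JPEG and PNG is a cheap illustrative filter, not a substitute for
# actually decoding the image.

```python
import io

# Leading bytes ("magic numbers") of the two formats this sketch accepts.
JPEG_MAGIC = b"\xff\xd8\xff"
PNG_MAGIC = b"\x89PNG\r\n\x1a\n"


def validate_upload(files, field="file"):
    """Return (image_bytes, None) on success or (None, error_message)."""
    upload = files.get(field)
    if upload is None:
        return None, "missing form field '%s'" % field
    data = upload.read()
    if not data:
        return None, "uploaded file is empty"
    if not (data.startswith(JPEG_MAGIC) or data.startswith(PNG_MAGIC)):
        return None, "uploaded file is not a JPEG or PNG image"
    return data, None


# Simulated uploads; in the Flask view this would be ``request.files``.
print(validate_upload({}))
print(validate_upload({"file": io.BytesIO(b"not an image")}))
ok_bytes, error = validate_upload({"file": io.BytesIO(JPEG_MAGIC + b"\x00\x01")})
print(len(ok_bytes), error)
```

######################################################################
# In the view, a non-``None`` error message could then be returned as
# ``jsonify({'error': message}), 400`` instead of calling ``get_prediction``.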
######################################################################
# The ``app.py`` file is now complete. The full version follows; replace
# the paths with the paths where you saved your files and it should run:
#
# .. code-block:: python
#
#    import io
#    import json
#
#    from torchvision import models
#    import torchvision.transforms as transforms
#    from PIL import Image
#    from flask import Flask, jsonify, request
#
#
#    app = Flask(__name__)
#    with open('<PATH/TO/.json/FILE>/imagenet_class_index.json') as f:
#        imagenet_class_index = json.load(f)
#    model = models.densenet121(weights='IMAGENET1K_V1')
#    model.eval()
#
#
#    def transform_image(image_bytes):
#        my_transforms = transforms.Compose([transforms.Resize(255),
#                                            transforms.CenterCrop(224),
#                                            transforms.ToTensor(),
#                                            transforms.Normalize(
#                                                [0.485, 0.456, 0.406],
#                                                [0.229, 0.224, 0.225])])
#        image = Image.open(io.BytesIO(image_bytes))
#        return my_transforms(image).unsqueeze(0)
#
#
#    def get_prediction(image_bytes):
#        tensor = transform_image(image_bytes=image_bytes)
#        outputs = model.forward(tensor)
#        _, y_hat = outputs.max(1)
#        predicted_idx = str(y_hat.item())
#        return imagenet_class_index[predicted_idx]
#
#
#    @app.route('/predict', methods=['POST'])
#    def predict():
#        if request.method == 'POST':
#            file = request.files['file']
#            img_bytes = file.read()
#            class_id, class_name = get_prediction(image_bytes=img_bytes)
#            return jsonify({'class_id': class_id, 'class_name': class_name})
#
#
#    if __name__ == '__main__':
#        app.run()
#
#
######################################################################
# Let's test our web server! Run:
#
# .. code-block:: sh
#
#    FLASK_ENV=development FLASK_APP=app.py flask run
#
#######################################################################
# We can use the
# `requests <https://pypi.org/project/requests/>`_
# library to send a POST request to our app:
#
# .. code-block:: python
#
#    import requests
#
#    resp = requests.post("http://localhost:5000/predict",
#                         files={"file": open('<PATH/TO/.jpg/FILE>/cat.jpg','rb')})
#

#######################################################################
# Printing ``resp.json()`` will now show the following:
#
# .. code-block:: sh
#
#    {"class_id": "n02124075", "class_name": "Egyptian_cat"}
#
######################################################################
# Next steps
# --------------
#
# The server we wrote is quite trivial and may not do everything
# you need for your production application. So, here are some things you
# can do to make it better:
#
# - The endpoint ``/predict`` assumes that there will always be an image file
#   in the request. This may not hold true for all requests. Our user may
#   send the image under a different parameter name or send no image at all.
#
# - The user may send non-image type files too. Since we are not handling
#   errors, this will break our server. Adding an explicit error handling
#   path that raises an exception would allow us to better handle
#   the bad inputs.
#
# - Even though the model can recognize a large number of classes of images,
#   it may not be able to recognize all images. Enhance the implementation
#   to handle cases when the model does not recognize anything in the image.
#
# - We run the Flask server in the development mode, which is not suitable for
#   deploying in production. You can check out `this tutorial <https://flask.palletsprojects.com/en/1.1.x/tutorial/deploy/>`_
#   for deploying a Flask server in production.
#
# - You can also add a UI by creating a page with a form which takes the image and
#   displays the prediction. Check out the `demo <https://pytorch-imagenet.herokuapp.com/>`_
#   of a similar project and its `source code <https://github.com/avinassh/pytorch-flask-api-heroku>`_.
#
# - In this tutorial, we only showed how to build a service that could return predictions for
#   a single image at a time. We could modify our service to be able to return predictions for
#   multiple images at once. In addition, the `service-streamer <https://github.com/ShannonAI/service-streamer>`_
#   library automatically queues requests to your service and samples them into mini-batches
#   that can be fed into your model. You can check out `this tutorial <https://github.com/ShannonAI/service-streamer/wiki/Vision-Recognition-Service-with-Flask-and-service-streamer>`_.
#
# - Finally, we encourage you to check out our other tutorials on deploying PyTorch models
#   linked at the top of the page.
#
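
######################################################################
# For the case where the model does not recognize anything in the image, one
# possible approach is to turn the raw model outputs into probabilities with
# a softmax and reject predictions whose top probability falls below a
# threshold. The sketch below is pure Python over a plain list of scores (for
# example, what ``outputs.squeeze(0).tolist()`` would produce);
# ``top_prediction`` and the ``0.5`` threshold are hypothetical choices made
# for illustration, not values from this tutorial.

```python
import math


def softmax(scores):
    """Convert raw scores into probabilities that sum to 1."""
    highest = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - highest) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def top_prediction(scores, threshold=0.5):
    """Return (index, probability), or None if the model is not confident."""
    probs = softmax(scores)
    best = max(range(len(probs)), key=probs.__getitem__)
    if probs[best] < threshold:
        return None
    return best, probs[best]


# One clearly dominant score: the prediction is accepted.
print(top_prediction([0.1, 5.0, 0.3]))
# Nearly uniform scores: no class reaches the threshold, so None is returned.
print(top_prediction([1.0, 1.1, 0.9]))
```

######################################################################
# A ``None`` result could be mapped to a response such as
# ``{"class_id": null, "class_name": null}`` so that clients can tell a
# low-confidence answer apart from a server error.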