1
"""
2
Serve PyTorch models at scale with Ray Serve
3
============================================
4
5
**Author:** `Ricardo Decal <https://github.com/crypdick>`__
6
7
This tutorial shows how to deploy a PyTorch model using Ray Serve with
8
production-ready features.
9
10
.. grid:: 2
11
12
.. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn
13
:class-card: card-prerequisites
14
15
* How to create a production-ready PyTorch model deployment that can scale to thousands of nodes and GPUs
16
* How to configure an HTTP endpoint for the deployment using FastAPI
17
* Enable dynamic request batching for higher throughput
18
* Configure autoscaling and per-replica CPU/GPU resource allocation
19
* Load test the service with concurrent requests and monitor it with the Ray dashboard
20
* How Ray Serve deployments can self-heal from failures
21
* Ray Serve's advanced features like model multiplexing and model composition.
22
23
.. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites
24
:class-card: card-prerequisites
25
26
* PyTorch v2.9+ and ``torchvision``
27
* Ray Serve (``ray[serve]``) v2.52.1+
28
* A GPU is recommended for higher throughput but is not required
29
30
`Ray Serve <https://docs.ray.io/en/latest/serve/index.html>`__ is a
31
scalable framework for serving machine learning models in production.
32
It’s built on top of `Ray <https://docs.ray.io/en/latest/index.html>`__,
33
which is a unified framework for scaling AI and Python applications that
34
simplifies the complexities of distributed computing. Ray is also open
35
source and part of the PyTorch Foundation.
36
37
38
Setup
39
-----
40
41
To install the dependencies, run:
42
43
.. code-block:: bash
44
45
pip install "ray[serve]" torch torchvision
46
47
"""
48
49
######################################################################
50
# Start by importing the required libraries:
51
52
import asyncio
53
import time
54
from typing import Any
55
56
from fastapi import FastAPI
57
from pydantic import BaseModel
58
import aiohttp
59
import numpy as np
60
import torch
61
import torch.nn as nn
62
from ray import serve
63
from torchvision.transforms import v2
64
65
######################################################################
66
# Define a PyTorch model
67
# ----------------------
68
#
69
# Define a simple convolutional neural network for the MNIST digit
70
# classification dataset:
71
72
class MNISTNet(nn.Module):
    """A small convolutional network for 28x28 grayscale MNIST digits.

    Produces per-class log-probabilities (log-softmax) over the 10 digits.
    """

    def __init__(self):
        super().__init__()
        # Two convolutional feature extractors followed by a two-layer
        # fully-connected classifier head, with dropout for regularization.
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.fc1 = nn.Linear(9216, 128)
        self.dropout2 = nn.Dropout(0.5)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Return log-probabilities of shape (batch, 10)."""
        relu = nn.functional.relu
        features = relu(self.conv1(x))
        features = relu(self.conv2(features))
        features = nn.functional.max_pool2d(features, 2)
        features = self.dropout1(features)
        flat = torch.flatten(features, 1)
        hidden = self.dropout2(relu(self.fc1(flat)))
        return nn.functional.log_softmax(self.fc2(hidden), dim=1)
95
96
######################################################################
97
# Define the Ray Serve deployment
98
# -------------------------------
99
#
100
# To deploy this model with Ray Serve, wrap the model in a Python class
101
# and decorate it with ``@serve.deployment``.
102
#
103
# Processing requests in batches is more efficient than processing
104
# requests one by one, especially when using GPUs. Ray Serve provides
105
# built-in support for **dynamic request batching**, where individual
106
# incoming requests are opportunistically batched. Enable dynamic batching
107
# using the ``@serve.batch`` decorator as shown in the following code:
108
109
app = FastAPI()


class ImageRequest(BaseModel):
    """Request schema used for validation and the auto-generated API docs."""

    # A single image as a 2D (H x W) or 3D (C x H x W) array of pixel values.
    image: list[list[float]] | list[list[list[float]]]
113
114
@serve.deployment
@serve.ingress(app)
class MNISTClassifier:
    """Ray Serve deployment that serves MNIST digit predictions over HTTP.

    Requests arrive one at a time through the FastAPI route and are
    transparently grouped into batches by ``@serve.batch``.
    """

    def __init__(self):
        # Use a GPU when one is available; otherwise fall back to CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = MNISTNet().to(self.device)
        # Transformation pipeline applied to each incoming image.
        self.transform = v2.Compose([
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            # Mean and standard deviation of the MNIST training subset.
            # (Canonical values are 0.1307 / 0.3081; the original code had a
            # 0.3013 typo for the std.)
            v2.Normalize(mean=[0.1307], std=[0.3081]),
        ])

        # Inference only: disable dropout and other training-mode behavior.
        self.model.eval()

    # batch_wait_timeout_s is the maximum time to wait for a full batch,
    # trading off latency for throughput.
    @serve.batch(max_batch_size=128, batch_wait_timeout_s=0.1)
    async def predict_batch(self, images: list[np.ndarray]) -> list[dict[str, Any]]:
        """Classify a batch of images; returns one result dict per image."""
        # Stack all images into a single batch tensor.
        batch_tensor = torch.cat([
            self.transform(img).unsqueeze(0)
            for img in images
        ]).to(self.device).float()

        # Single forward pass on the entire batch at once.
        with torch.no_grad():
            logits = self.model(batch_tensor)
            predictions = torch.argmax(logits, dim=1).cpu().numpy()

        # Unbatch the results and preserve their original order.
        # NOTE: the "logits" values are log-probabilities (log_softmax output);
        # the key name is kept for backward compatibility with clients.
        return [
            {
                "predicted_label": int(pred),
                "logits": logit.cpu().numpy().tolist(),
            }
            for pred, logit in zip(predictions, logits)
        ]

    @app.post("/")
    async def handle_request(self, request: ImageRequest):
        """Handle an incoming HTTP request using FastAPI.

        Inputs are automatically validated against the ImageRequest
        Pydantic model.
        """
        # Process the single request.
        image_array = np.array(request.image)

        # @serve.batch transparently groups this call with other concurrent
        # requests into one predict_batch invocation.
        result = await self.predict_batch(image_array)

        return result
167
168
169
######################################################################
170
# This is a FastAPI app, which extends Ray Serve with features like
171
# automatic request validation with Pydantic, auto-generated OpenAPI-style
172
# API documentation, and more.
173
#
174
# Configure autoscaling and resource allocation
175
# ---------------------------------------------
176
#
177
# In production, traffic can vary significantly. Ray Serve’s
178
# **autoscaling** feature automatically adjusts the number of replicas
179
# based on traffic load, ensuring you have enough capacity during peaks
180
# while saving resources during quiet periods. Ray Serve scales to very
181
# large deployments with thousands of nodes and replicas.
182
#
183
# You can also specify **resource allocation** per replica, such as the
184
# number of CPUs or GPUs. Ray Serve handles the orchestration of these
185
# resources across your cluster. Ray also supports `fractional
186
# GPUs <https://docs.ray.io/en/latest/ray-core/scheduling/accelerators.html#fractional-accelerators>`__,
187
# allowing multiple replicas to share a single GPU when models are small
188
# enough to fit in memory together.
189
#
190
# The following is a sample configuration with autoscaling and resource
191
# allocation:
192
193
# Per-replica resource requests.
num_cpus_per_replica = 1
num_gpus_per_replica = 1  # Set to 0 to run the model on CPUs instead of GPUs.

# Autoscaling policy: keep roughly 50 in-flight requests per replica,
# growing from 1 replica up to 80 as traffic demands.
autoscaling_config = {
    "target_ongoing_requests": 50,
    "min_replicas": 1,
    "max_replicas": 80,
    "upscale_delay_s": 5,      # Wait 5s before scaling up.
    "downscale_delay_s": 30,   # Wait 30s before scaling down.
}

mnist_app = MNISTClassifier.options(
    autoscaling_config=autoscaling_config,
    # Max concurrent invocations of handle_request per replica; requests
    # exceeding this limit are queued by the router until capacity frees up.
    max_ongoing_requests=200,
    # Max queue size for requests beyond max_ongoing_requests. If the queue
    # fills, further requests are backpressured with errors. -1 means the
    # queue can grow until cluster memory is exhausted.
    max_queued_requests=-1,
    # Resources requested for each replica actor.
    ray_actor_options={
        "num_cpus": num_cpus_per_replica,
        "num_gpus": num_gpus_per_replica,
    },
).bind()
213
214
######################################################################
# The app is ready to deploy. Suppose you ran this on a cluster of 10
# machines, each with 4 GPUs. If you instead set
# ``num_gpus_per_replica = 0.5``, Ray schedules 2 replicas per GPU,
# giving you up to 80 replicas across the cluster. This configuration
# permits the deployment to elastically scale up to 80 replicas as
# needed to handle traffic spikes and scale back down to 1 replica when
# traffic subsides.
221
#
222
# Test the endpoint with concurrent requests
223
# ------------------------------------------
224
#
225
# To deploy the app, use the ``serve.run`` function:
226
227
# Deploy the application; the returned handle allows in-Python calls
# to the deployment in addition to the HTTP endpoint.
handle = serve.run(mnist_app, name="mnist_classifier")
229
230
######################################################################
231
# You will see output similar to:
232
#
233
# .. code-block:: sh
234
#
235
# Started Serve in namespace "serve".
236
# Registering autoscaling state for deployment Deployment(name='MNISTClassifier', app='mnist_classifier')
237
# Deploying new version of Deployment(name='MNISTClassifier', app='mnist_classifier') (initial target replicas: 1).
238
# Proxy starting on node ... (HTTP port: 8000).
239
# Got updated endpoints: {}.
240
# Got updated endpoints: {Deployment(name='MNISTClassifier', app='mnist_classifier'): EndpointInfo(route='/', app_is_cross_language=False, route_patterns=None)}.
241
# Started <ray.serve._private.router.SharedRouterLongPollClient object at 0x73a53c52c250>.
242
# Adding 1 replica to Deployment(name='MNISTClassifier', app='mnist_classifier').
243
# Got updated endpoints: {Deployment(name='MNISTClassifier', app='mnist_classifier'): EndpointInfo(route='/', app_is_cross_language=False, route_patterns=['/', '/docs', '/docs/oauth2-redirect', '/openapi.json', '/redoc'])}.
244
# Application 'mnist_classifier' is ready at http://127.0.0.1:8000/.
245
#
246
# The app is now listening for requests on port 8000.
247
#
248
# To test the deployment, you can send many requests concurrently using
249
# ``aiohttp``. The following code demonstrates how to send 1000 concurrent
250
# requests to the app:
251
252
async def send_single_request(session, url, data):
    """POST one JSON payload to *url* and return the decoded JSON response."""
    async with session.post(url, json=data) as resp:
        body = await resp.json()
    return body
255
256
async def send_concurrent_requests(num_requests):
    """Fire num_requests concurrent POSTs at the local endpoint and gather replies."""
    # One random 28x28 "image" reused as the payload for every request.
    payload = np.random.rand(28, 28).tolist()

    print(f"Sending {num_requests} concurrent requests...")
    async with aiohttp.ClientSession() as session:
        pending = [
            send_single_request(session, url="http://localhost:8000/", data={"image": payload})
            for _ in range(num_requests)
        ]
        responses = await asyncio.gather(*pending)

    return responses
268
269
# Drive the load test: time 1000 concurrent requests end to end.
num_requests = 1000
start_time = time.time()
responses = asyncio.run(send_concurrent_requests(num_requests))
elapsed = time.time() - start_time

# Report wall-clock latency and aggregate throughput.
print(f"Processed {len(responses)} requests in {elapsed:.2f} seconds")
print(f"Throughput: {len(responses)/elapsed:.2f} requests/second")
276
277
######################################################################
278
# Ray Serve automatically buffers and load balances requests across the
279
# replicas.
280
#
281
# Fault tolerance
282
# ---------------
283
#
284
# In production, process and machine failures are inevitable. Ray Serve is designed
285
# so that each major component in the Serve stack (the controller, replicas, and proxies) can fail
286
# and recover while your application continues to handle traffic.
287
#
288
# Serve can also recover from larger infrastructure failures, such as entire nodes or pods
289
# failing. Serve can even recover from head node failures, or the entire head pod if
290
# deploying on KubeRay.
291
#
292
# For more information about Ray Serve's fault tolerance, see the
293
# `Ray Serve fault-tolerance guide <https://docs.ray.io/en/master/serve/production-guide/fault-tolerance.html>`__.
294
#
295
# Monitor the deployment
296
# ----------------------
297
#
298
# Monitoring is critical when running large-scale deployments. The `Ray
299
# dashboard <https://docs.ray.io/en/latest/ray-observability/getting-started.html>`__
300
# displays Serve metrics like request throughput, latency, and error
301
# rates. It also shows cluster resource usage, replica status, and overall
302
# deployment health in real time. The dashboard also lets you inspect logs
303
# from individual replicas across the cluster.
304
#
305
# For debugging, Ray offers `distributed debugging
306
# tools <https://docs.ray.io/en/latest/ray-observability/index.html>`__
307
# that let you attach a debugger to running replicas across the cluster.
308
# For more information, see the `Ray Serve monitoring
309
# documentation <https://docs.ray.io/en/latest/serve/monitoring.html>`__.
310
#
311
# Conclusion
312
# ------------
313
#
314
# In this tutorial, you:
315
#
316
# - Deployed a PyTorch model using Ray Serve with production best
317
# practices.
318
# - Enabled **dynamic request batching** to optimize performance.
319
# - Configured **autoscaling** and **fractional GPU allocation** to
320
# efficiently scale across a cluster.
321
# - Tested the service with concurrent asynchronous requests.
322
#
323
# Further reading
324
# ---------------
325
#
326
# Ray Serve has more production features that are out of scope for this
327
# tutorial but are worth checking out:
328
#
329
# - Specialized `large language model (LLM) serving
330
# APIs <https://docs.ray.io/en/latest/serve/llm/index.html>`__ that
331
# handle complexities like managing key-value (KV) caches and continuous
332
# batching.
333
# - `Model
334
# multiplexing <https://docs.ray.io/en/latest/serve/model-multiplexing.html>`__
335
# to dynamically load and serve many different models on the same
336
# deployment. This is useful for serving per-user fine-tuned models, for
337
# example.
338
# - `Composed
339
# deployments <https://docs.ray.io/en/latest/serve/model_composition.html>`__
340
# to orchestrate multiple deployments into a single app.
341
#
342
# For more information, see the `Ray Serve
343
# documentation <https://docs.ray.io/en/latest/serve/index.html>`__ and
344
# `Ray Serve
345
# examples <https://docs.ray.io/en/latest/serve/examples.html>`__.
346
347