"""1Serve PyTorch models at scale with Ray Serve2============================================34**Author:** `Ricardo Decal <https://github.com/crypdick>`__56This tutorial shows how to deploy a PyTorch model using Ray Serve with7production-ready features.89.. grid:: 21011.. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn12:class-card: card-prerequisites1314* How to create a production-ready PyTorch model deployment that can scale to thousands of nodes and GPUs15* How to configure an HTTP endpoint for the deployment using FastAPI16* Enable dynamic request batching for higher throughput17* Configure autoscaling and per-replica CPU/GPU resource allocation18* Load test the service with concurrent requests and monitor it with the Ray dashboard19* How Ray Serve deployments can self-heal from failures20* Ray Serve's advanced features like model multiplexing and model composition.2122.. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites23:class-card: card-prerequisites2425* PyTorch v2.9+ and ``torchvision``26* Ray Serve (``ray[serve]``) v2.52.1+27* A GPU is recommended for higher throughput but is not required2829`Ray Serve <https://docs.ray.io/en/latest/serve/index.html>`__ is a30scalable framework for serving machine learning models in production.31It’s built on top of `Ray <https://docs.ray.io/en/latest/index.html>`__,32which is a unified framework for scaling AI and Python applications that33simplifies the complexities of distributed computing. Ray is also open34source and part of the PyTorch Foundation.353637Setup38-----3940To install the dependencies, run:4142.. 
code-block:: bash4344pip install "ray[serve]" torch torchvision4546"""4748######################################################################49# Start by importing the required libraries:5051import asyncio52import time53from typing import Any5455from fastapi import FastAPI56from pydantic import BaseModel57import aiohttp58import numpy as np59import torch60import torch.nn as nn61from ray import serve62from torchvision.transforms import v26364######################################################################65# Define a PyTorch model66# ----------------------67#68# Define a simple convolutional neural network for the MNIST digit69# classification dataset:7071class MNISTNet(nn.Module):72def __init__(self):73super().__init__()74self.conv1 = nn.Conv2d(1, 32, 3, 1)75self.dropout1 = nn.Dropout(0.25)76self.conv2 = nn.Conv2d(32, 64, 3, 1)77self.fc1 = nn.Linear(9216, 128)78self.dropout2 = nn.Dropout(0.5)79self.fc2 = nn.Linear(128, 10)8081def forward(self, x):82x = self.conv1(x)83x = nn.functional.relu(x)84x = self.conv2(x)85x = nn.functional.relu(x)86x = nn.functional.max_pool2d(x, 2)87x = self.dropout1(x)88x = torch.flatten(x, 1)89x = self.fc1(x)90x = nn.functional.relu(x)91x = self.dropout2(x)92x = self.fc2(x)93return nn.functional.log_softmax(x, dim=1)9495######################################################################96# Define the Ray Serve deployment97# -------------------------------98#99# To deploy this model with Ray Serve, wrap the model in a Python class100# and decorate it with ``@serve.deployment``.101#102# Processing requests in batches is more efficient than processing103# requests one by one, especially when using GPUs. Ray Serve provides104# built-in support for **dynamic request batching**, where individual105# incoming requests are opportunistically batched. 
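To see why batching pays off before wiring it into Serve, the following standalone sketch (plain PyTorch, no Ray, using a small stand-in linear model rather than ``MNISTNet``) checks that a single batched forward pass produces the same outputs as per-sample passes while amortizing per-call overhead:

```python
import torch
import torch.nn as nn

# Stand-in classifier; the same reasoning applies to MNISTNet.
model = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))
model.eval()

images = torch.randn(64, 1, 28, 28)

with torch.no_grad():
    # One forward pass over the whole batch...
    batched_out = model(images)
    # ...versus 64 separate single-image passes.
    single_out = torch.cat([model(img.unsqueeze(0)) for img in images])

# The outputs agree; batching simply does the same work with far fewer
# Python-level calls and kernel launches.
assert torch.allclose(batched_out, single_out, atol=1e-5)
print(batched_out.shape)  # torch.Size([64, 10])
```

On a GPU the gap is even larger, because a batched pass keeps the device busy instead of launching many tiny kernels.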
# Enable dynamic batching
# using the ``@serve.batch`` decorator, as shown in the following code:

app = FastAPI()

class ImageRequest(BaseModel):  # Used for request validation and for generating API documentation.
    image: list[list[float]] | list[list[list[float]]]  # A 2D or 3D array.

@serve.deployment
@serve.ingress(app)
class MNISTClassifier:
    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = MNISTNet().to(self.device)
        # Define the transformation pipeline for the input images.
        self.transform = v2.Compose([
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            # Mean and standard deviation of the MNIST training subset.
            v2.Normalize(mean=[0.1307], std=[0.3081]),
        ])

        self.model.eval()

    # batch_wait_timeout_s is the maximum time to wait for a full batch,
    # trading off latency for throughput.
    @serve.batch(max_batch_size=128, batch_wait_timeout_s=0.1)
    async def predict_batch(self, images: list[np.ndarray]) -> list[dict[str, Any]]:
        # Stack all images into a single tensor.
        batch_tensor = torch.cat([
            self.transform(img).unsqueeze(0)
            for img in images
        ]).to(self.device).float()

        # Single forward pass on the entire batch at once.
        with torch.no_grad():
            logits = self.model(batch_tensor)
            predictions = torch.argmax(logits, dim=1).cpu().numpy()

        # Unbatch the results and preserve their original order.
        return [
            {
                "predicted_label": int(pred),
                "logits": logit.cpu().numpy().tolist(),
            }
            for pred, logit in zip(predictions, logits)
        ]

    @app.post("/")
    async def handle_request(self, request: ImageRequest):
        """Handle an incoming HTTP request using FastAPI.

        Inputs are automatically validated using the Pydantic model.
        """
        # Process the single request.
        image_array = np.array(request.image)

        # Ray Serve's @serve.batch automatically batches requests.
        result = await self.predict_batch(image_array)

        return result


######################################################################
# This deployment wraps a FastAPI app, which extends Ray Serve with
# features like automatic request validation with Pydantic,
# auto-generated OpenAPI documentation, and more.
#
# Configure autoscaling and resource allocation
# ---------------------------------------------
#
# In production, traffic can vary significantly. Ray Serve's
# **autoscaling** feature automatically adjusts the number of replicas
# based on traffic load, ensuring you have enough capacity during peaks
# while saving resources during quiet periods. Ray Serve scales to very
# large deployments with thousands of nodes and replicas.
#
# You can also specify **resource allocation** per replica, such as the
# number of CPUs or GPUs. Ray Serve handles the orchestration of these
# resources across your cluster. Ray also supports `fractional
# GPUs <https://docs.ray.io/en/latest/ray-core/scheduling/accelerators.html#fractional-accelerators>`__,
# allowing multiple replicas to share a single GPU when models are small
# enough to fit in memory together.
#
# The following is a sample configuration with autoscaling and resource
# allocation:

num_cpus_per_replica = 1
num_gpus_per_replica = 0.5  # Fractional: two replicas share one GPU. Set to 0 to run the model on CPUs instead.
mnist_app = MNISTClassifier.options(
    autoscaling_config={
        "target_ongoing_requests": 50,  # Target 50 ongoing requests per replica.
        "min_replicas": 1,  # Keep at least 1 replica alive.
        "max_replicas": 80,  # Scale up to 80 replicas to maintain target_ongoing_requests.
        "upscale_delay_s": 5,  # Wait 5s before scaling up.
        "downscale_delay_s": 30,  # Wait 30s before scaling down.
    },
    # Max invocations of handle_request that each replica processes simultaneously.
    # Requests exceeding this limit are queued by the router until capacity is available.
    max_ongoing_requests=200,
    # Max queue size for requests that exceed max_ongoing_requests.
    # If the queue is full, further requests are rejected with errors until space frees up.
    # -1 means the queue can grow until cluster memory is exhausted.
    max_queued_requests=-1,
    # Set the resources per replica.
    ray_actor_options={"num_cpus": num_cpus_per_replica, "num_gpus": num_gpus_per_replica},
).bind()

######################################################################
# The app is ready to deploy. Suppose you ran this on a cluster of 10
# machines, each with 4 GPUs. With ``num_gpus=0.5``, Ray schedules 2
# replicas per GPU, giving you up to 80 replicas across the cluster. This
# configuration lets the deployment elastically scale up to 80 replicas
# to handle traffic spikes and back down to 1 replica when traffic
# subsides.
#
# Test the endpoint with concurrent requests
# ------------------------------------------
#
# To deploy the app, use the ``serve.run`` function:

# Start the Ray Serve application.
handle = serve.run(mnist_app, name="mnist_classifier")

######################################################################
# You will see output similar to:
#
# .. code-block:: sh
#
#    Started Serve in namespace "serve".
#    Registering autoscaling state for deployment Deployment(name='MNISTClassifier', app='mnist_classifier')
#    Deploying new version of Deployment(name='MNISTClassifier', app='mnist_classifier') (initial target replicas: 1).
#    Proxy starting on node ... (HTTP port: 8000).
#    Got updated endpoints: {}.
#    Got updated endpoints: {Deployment(name='MNISTClassifier', app='mnist_classifier'): EndpointInfo(route='/', app_is_cross_language=False, route_patterns=None)}.
#    Started <ray.serve._private.router.SharedRouterLongPollClient object at 0x73a53c52c250>.
#    Adding 1 replica to Deployment(name='MNISTClassifier', app='mnist_classifier').
#    Got updated endpoints: {Deployment(name='MNISTClassifier', app='mnist_classifier'): EndpointInfo(route='/', app_is_cross_language=False, route_patterns=['/', '/docs', '/docs/oauth2-redirect', '/openapi.json', '/redoc'])}.
#    Application 'mnist_classifier' is ready at http://127.0.0.1:8000/.
#
# The app is now listening for requests on port 8000.
#
# To test the deployment, you can send many requests concurrently using
# ``aiohttp``. The following code sends 1000 concurrent requests to the
# app:

async def send_single_request(session, url, data):
    async with session.post(url, json=data) as response:
        return await response.json()

async def send_concurrent_requests(num_requests):
    image = np.random.rand(28, 28).tolist()

    print(f"Sending {num_requests} concurrent requests...")
    async with aiohttp.ClientSession() as session:
        tasks = [
            send_single_request(session, url="http://localhost:8000/", data={"image": image})
            for _ in range(num_requests)
        ]
        responses = await asyncio.gather(*tasks)

    return responses

# Run the concurrent requests.
start_time = time.time()
responses = asyncio.run(send_concurrent_requests(1000))
elapsed = time.time() - start_time

print(f"Processed {len(responses)} requests in {elapsed:.2f} seconds")
print(f"Throughput: {len(responses)/elapsed:.2f} requests/second")

######################################################################
# Ray Serve automatically buffers and load balances requests across the
# replicas.
#
# Fault tolerance
# ---------------
#
# In production, process and machine failures are inevitable. Ray Serve is
# designed so that each major component in the Serve stack (the
# controller, replicas, and proxies) can fail and recover while your
# application continues to handle traffic.
#
# Serve can also recover from larger infrastructure failures, such as
# entire nodes or pods failing. It can even recover from the loss of the
# head node (or the head pod, when deployed on KubeRay).
#
# For more information about Ray Serve's fault tolerance, see the
# `Ray Serve fault-tolerance guide <https://docs.ray.io/en/latest/serve/production-guide/fault-tolerance.html>`__.
#
# Monitor the deployment
# ----------------------
#
# Monitoring is critical when running large-scale deployments. The `Ray
# dashboard <https://docs.ray.io/en/latest/ray-observability/getting-started.html>`__
# displays Serve metrics like request throughput, latency, and error
# rates. It also shows cluster resource usage, replica status, and overall
# deployment health in real time.
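Beyond what the dashboard reports, it's worth tracking tail latency on the client side during load tests, since averages hide stragglers. A minimal sketch (with made-up latency samples) using NumPy percentiles:

```python
import numpy as np

# Hypothetical per-request latencies collected during a load test, in ms.
latencies_ms = np.array([12.0, 13.0, 14.0, 15.0, 16.0, 250.0])

mean = latencies_ms.mean()
p50, p95 = np.percentile(latencies_ms, [50, 95])

# The mean is dragged up by a single straggler; the median barely moves.
print(f"mean={mean:.1f} ms, p50={p50:.1f} ms, p95={p95:.1f} ms")
```

In practice you would compute these over the ``responses`` timing data from the load test above rather than hard-coded samples.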
# The dashboard also lets you inspect
# logs from individual replicas across the cluster.
#
# For debugging, Ray offers `distributed debugging
# tools <https://docs.ray.io/en/latest/ray-observability/index.html>`__
# that let you attach a debugger to running replicas across the cluster.
# For more information, see the `Ray Serve monitoring
# documentation <https://docs.ray.io/en/latest/serve/monitoring.html>`__.
#
# Conclusion
# ----------
#
# In this tutorial, you:
#
# - Deployed a PyTorch model using Ray Serve with production best
#   practices.
# - Enabled **dynamic request batching** to optimize performance.
# - Configured **autoscaling** and **fractional GPU allocation** to
#   efficiently scale across a cluster.
# - Tested the service with concurrent asynchronous requests.
#
# Further reading
# ---------------
#
# Ray Serve has more production features that are out of scope for this
# tutorial but are worth checking out:
#
# - Specialized `large language model (LLM) serving
#   APIs <https://docs.ray.io/en/latest/serve/llm/index.html>`__ that
#   handle complexities like managing key-value (KV) caches and continuous
#   batching.
# - `Model
#   multiplexing <https://docs.ray.io/en/latest/serve/model-multiplexing.html>`__
#   to dynamically load and serve many different models on the same
#   deployment. This is useful, for example, for serving per-user
#   fine-tuned models.
# - `Composed
#   deployments <https://docs.ray.io/en/latest/serve/model_composition.html>`__
#   to orchestrate multiple deployments into a single app.
#
# For more information, see the `Ray Serve
# documentation <https://docs.ray.io/en/latest/serve/index.html>`__ and
# `Ray Serve
# examples <https://docs.ray.io/en/latest/serve/examples.html>`__.