# -*- coding: utf-8 -*-
"""
Hyperparameter tuning with Ray Tune
===================================

Hyperparameter tuning can make the difference between an average model and a highly
accurate one. Often simple things like choosing a different learning rate or changing
a network layer size can have a dramatic impact on your model performance.

Fortunately, there are tools that help with finding the best combination of parameters.
`Ray Tune <https://docs.ray.io/en/latest/tune.html>`_ is an industry standard tool for
distributed hyperparameter tuning. Ray Tune includes the latest hyperparameter search
algorithms, integrates with various analysis libraries, and natively
supports distributed training through `Ray's distributed machine learning engine
<https://ray.io/>`_.

In this tutorial, we will show you how to integrate Ray Tune into your PyTorch
training workflow. We will extend `this tutorial from the PyTorch documentation
<https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html>`_ for training
a CIFAR10 image classifier.

As you will see, we only need to make some slight modifications. In particular, we
need to

1. wrap data loading and training in functions,
2. make some network parameters configurable,
3. add checkpointing (optional),
4. and define the search space for the model tuning.

|

To run this tutorial, please make sure the following packages are
installed:

- ``ray[tune]``: Distributed hyperparameter tuning library
- ``torchvision``: For the dataset and data transforms

Setup / Imports
---------------
Let's start with the imports:
"""
from functools import partial
import os
import tempfile
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import random_split
import torchvision
import torchvision.transforms as transforms
# sphinx_gallery_start_ignore
# Fixes ``AttributeError: '_LoggingTee' object has no attribute 'fileno'``.
# This is only needed to run with sphinx-build.
import sys
if not hasattr(sys.stdout, "encoding"):
    sys.stdout.encoding = "latin1"
    sys.stdout.fileno = lambda: 0
# sphinx_gallery_end_ignore
from ray import tune
from ray import train
from ray.train import Checkpoint, get_checkpoint
from ray.tune.schedulers import ASHAScheduler
import ray.cloudpickle as pickle

######################################################################
# Most of the imports are needed for building the PyTorch model. Only the last
# imports are for Ray Tune.
#
# Data loaders
# ------------
# We wrap the data loaders in their own function and pass a global data directory.
# This way we can share a data directory between different trials.


def load_data(data_dir="./data"):
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )

    trainset = torchvision.datasets.CIFAR10(
        root=data_dir, train=True, download=True, transform=transform
    )

    testset = torchvision.datasets.CIFAR10(
        root=data_dir, train=False, download=True, transform=transform
    )

    return trainset, testset


######################################################################
# Configurable neural network
# ---------------------------
# We can only tune those parameters that are configurable.
# In this example, we can specify
# the layer sizes of the fully connected layers:


class Net(nn.Module):
    def __init__(self, l1=120, l2=84):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, l1)
        self.fc2 = nn.Linear(l1, l2)
        self.fc3 = nn.Linear(l2, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


######################################################################
# The train function
# ------------------
# Now it gets interesting, because we introduce some changes to the example `from the PyTorch
# documentation <https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html>`_.
#
# We wrap the training script in a function ``train_cifar(config, data_dir=None)``.
# The ``config`` parameter will receive the hyperparameters we would like to
# train with. The ``data_dir`` specifies the directory where we load and store the data,
# so that multiple runs can share the same data source.
# We also load the model and optimizer state at the start of the run, if a checkpoint
# is provided. Further down in this tutorial you will find information on how
# to save the checkpoint and what it is used for.
#
# .. code-block:: python
#
#     net = Net(config["l1"], config["l2"])
#
#     checkpoint = get_checkpoint()
#     if checkpoint:
#         with checkpoint.as_directory() as checkpoint_dir:
#             data_path = Path(checkpoint_dir) / "data.pkl"
#             with open(data_path, "rb") as fp:
#                 checkpoint_state = pickle.load(fp)
#             start_epoch = checkpoint_state["epoch"]
#             net.load_state_dict(checkpoint_state["net_state_dict"])
#             optimizer.load_state_dict(checkpoint_state["optimizer_state_dict"])
#     else:
#         start_epoch = 0
#
# The learning rate of the optimizer is made configurable, too:
#
# .. code-block:: python
#
#     optimizer = optim.SGD(net.parameters(), lr=config["lr"], momentum=0.9)
#
# We also split the training data into training and validation subsets. We thus train on
# 80% of the data and calculate the validation loss on the remaining 20%. The batch sizes
# with which we iterate through the training and validation sets are configurable as well.
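#
# For reference, the corresponding split inside ``train_cifar`` looks like this
# (the same code appears in the full training function further below):
#
# .. code-block:: python
#
#     test_abs = int(len(trainset) * 0.8)
#     train_subset, val_subset = random_split(
#         trainset, [test_abs, len(trainset) - test_abs]
#     )
#
#     trainloader = torch.utils.data.DataLoader(
#         train_subset, batch_size=int(config["batch_size"]), shuffle=True, num_workers=8
#     )
#     valloader = torch.utils.data.DataLoader(
#         val_subset, batch_size=int(config["batch_size"]), shuffle=True, num_workers=8
#     )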
#
# Adding (multi) GPU support with DataParallel
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Image classification benefits greatly from GPUs. Luckily, we can continue to use
# PyTorch's abstractions in Ray Tune. Thus, we can wrap our model in ``nn.DataParallel``
# to support data parallel training on multiple GPUs:
#
# .. code-block:: python
#
#     device = "cpu"
#     if torch.cuda.is_available():
#         device = "cuda:0"
#         if torch.cuda.device_count() > 1:
#             net = nn.DataParallel(net)
#     net.to(device)
#
# By using a ``device`` variable we make sure that training also works when we have
# no GPUs available. PyTorch requires us to send our data to the GPU memory explicitly,
# like this:
#
# .. code-block:: python
#
#     for i, data in enumerate(trainloader, 0):
#         inputs, labels = data
#         inputs, labels = inputs.to(device), labels.to(device)
#
# The code now supports training on CPUs, on a single GPU, and on multiple GPUs. Notably, Ray
# also supports `fractional GPUs <https://docs.ray.io/en/master/using-ray-with-gpus.html#fractional-gpus>`_,
# so we can share GPUs among trials, as long as the model still fits in GPU memory. We'll come back
# to that later.
#
# Communicating with Ray Tune
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# The most interesting part is the communication with Ray Tune:
#
# .. code-block:: python
#
#     checkpoint_data = {
#         "epoch": epoch,
#         "net_state_dict": net.state_dict(),
#         "optimizer_state_dict": optimizer.state_dict(),
#     }
#     with tempfile.TemporaryDirectory() as checkpoint_dir:
#         data_path = Path(checkpoint_dir) / "data.pkl"
#         with open(data_path, "wb") as fp:
#             pickle.dump(checkpoint_data, fp)
#
#         checkpoint = Checkpoint.from_directory(checkpoint_dir)
#         train.report(
#             {"loss": val_loss / val_steps, "accuracy": correct / total},
#             checkpoint=checkpoint,
#         )
#
# Here we first save a checkpoint and then report some metrics back to Ray Tune. Specifically,
# we send the validation loss and accuracy back to Ray Tune. Ray Tune can then use these metrics
# to decide which hyperparameter configuration leads to the best results. These metrics
# can also be used to stop badly performing trials early in order to avoid wasting
# resources on those trials.
#
# Checkpoint saving is optional. However, it is necessary if we want to use advanced
# schedulers like
# `Population Based Training <https://docs.ray.io/en/latest/tune/examples/pbt_guide.html>`_.
# Also, by saving the checkpoint we can later load the trained models and validate them
# on a test set. Lastly, saving checkpoints is useful for fault tolerance, and it allows
# us to interrupt training and resume it later.
#
# Full training function
# ~~~~~~~~~~~~~~~~~~~~~~
#
# The full code example looks like this:


def train_cifar(config, data_dir=None):
    net = Net(config["l1"], config["l2"])

    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda:0"
        if torch.cuda.device_count() > 1:
            net = nn.DataParallel(net)
    net.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=config["lr"], momentum=0.9)

    checkpoint = get_checkpoint()
    if checkpoint:
        with checkpoint.as_directory() as checkpoint_dir:
            data_path = Path(checkpoint_dir) / "data.pkl"
            with open(data_path, "rb") as fp:
                checkpoint_state = pickle.load(fp)
            start_epoch = checkpoint_state["epoch"]
            net.load_state_dict(checkpoint_state["net_state_dict"])
            optimizer.load_state_dict(checkpoint_state["optimizer_state_dict"])
    else:
        start_epoch = 0

    trainset, testset = load_data(data_dir)

    test_abs = int(len(trainset) * 0.8)
    train_subset, val_subset = random_split(
        trainset, [test_abs, len(trainset) - test_abs]
    )

    trainloader = torch.utils.data.DataLoader(
        train_subset, batch_size=int(config["batch_size"]), shuffle=True, num_workers=8
    )
    valloader = torch.utils.data.DataLoader(
        val_subset, batch_size=int(config["batch_size"]), shuffle=True, num_workers=8
    )

    for epoch in range(start_epoch, 10):  # loop over the dataset multiple times
        running_loss = 0.0
        epoch_steps = 0
        for i, data in enumerate(trainloader, 0):
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            epoch_steps += 1
            if i % 2000 == 1999:  # print every 2000 mini-batches
                print(
                    "[%d, %5d] loss: %.3f"
                    % (epoch + 1, i + 1, running_loss / epoch_steps)
                )
                running_loss = 0.0

        # Validation loss
        val_loss = 0.0
        val_steps = 0
        total = 0
        correct = 0
        for i, data in enumerate(valloader, 0):
            with torch.no_grad():
                inputs, labels = data
                inputs, labels = inputs.to(device), labels.to(device)

                outputs = net(inputs)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

                loss = criterion(outputs, labels)
                val_loss += loss.cpu().numpy()
                val_steps += 1

        checkpoint_data = {
            "epoch": epoch,
            "net_state_dict": net.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
        }
        with tempfile.TemporaryDirectory() as checkpoint_dir:
            data_path = Path(checkpoint_dir) / "data.pkl"
            with open(data_path, "wb") as fp:
                pickle.dump(checkpoint_data, fp)

            checkpoint = Checkpoint.from_directory(checkpoint_dir)
            train.report(
                {"loss": val_loss / val_steps, "accuracy": correct / total},
                checkpoint=checkpoint,
            )

    print("Finished Training")


######################################################################
# As you can see, most of the code is adapted directly from the original example.
#
# Test set accuracy
# -----------------
# Commonly, the performance of a machine learning model is tested on a hold-out test
# set with data that has not been used for training the model. We also wrap this in a
# function:


def test_accuracy(net, device="cpu"):
    trainset, testset = load_data()

    testloader = torch.utils.data.DataLoader(
        testset, batch_size=4, shuffle=False, num_workers=2
    )

    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    return correct / total


######################################################################
# The function also expects a ``device`` parameter, so we can do the
# test set validation on a GPU.
#
# Configuring the search space
# ----------------------------
# Lastly, we need to define Ray Tune's search space. Here is an example:
#
# .. code-block:: python
#
#     config = {
#         "l1": tune.choice([2 ** i for i in range(9)]),
#         "l2": tune.choice([2 ** i for i in range(9)]),
#         "lr": tune.loguniform(1e-4, 1e-1),
#         "batch_size": tune.choice([2, 4, 8, 16])
#     }
#
# ``tune.choice()`` accepts a list of values that are uniformly sampled from.
# In this example, the ``l1`` and ``l2`` parameters
# should be powers of 2 between 1 and 256, so 1, 2, 4, 8, 16, 32, 64, 128, or 256.
# The ``lr`` (learning rate) should be sampled log-uniformly between 0.0001 and 0.1. Lastly,
# the batch size is a choice between 2, 4, 8, and 16.
#
# For each trial, Ray Tune will now randomly sample a combination of parameters from these
# search spaces. It will then train a number of models in parallel and find the best
# performing one among these. We also use the ``ASHAScheduler``, which terminates badly
# performing trials early.
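#
# For reference, the scheduler is created like this in the ``main`` function further
# below. ``max_t`` caps the number of training iterations (here: epochs) per trial,
# ``grace_period`` is the minimum number of iterations a trial runs before it can be
# stopped, and ``reduction_factor`` controls how aggressively trials are culled at
# each halving step:
#
# .. code-block:: python
#
#     scheduler = ASHAScheduler(
#         metric="loss",
#         mode="min",
#         max_t=max_num_epochs,
#         grace_period=1,
#         reduction_factor=2,
#     )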
#
# We wrap the ``train_cifar`` function with ``functools.partial`` to set the constant
# ``data_dir`` parameter. We can also tell Ray Tune what resources should be
# available for each trial:
#
# .. code-block:: python
#
#     gpus_per_trial = 2
#     # ...
#     result = tune.run(
#         partial(train_cifar, data_dir=data_dir),
#         resources_per_trial={"cpu": 8, "gpu": gpus_per_trial},
#         config=config,
#         num_samples=num_samples,
#         scheduler=scheduler,
#         checkpoint_at_end=True)
#
# You can specify the number of CPUs, which are then available, for example, to
# increase the ``num_workers`` of the PyTorch ``DataLoader`` instances. The selected
# number of GPUs is made visible to PyTorch in each trial. Trials do not have access to
# GPUs that haven't been requested for them, so you don't have to worry about two trials
# using the same set of resources.
#
# Here we can also specify fractional GPUs, so something like ``gpus_per_trial=0.5`` is
# completely valid. The trials will then share GPUs among each other.
# You just have to make sure that the models still fit in the GPU memory.
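#
# As a small sketch (not part of the code below), such a fractional-GPU setup would
# only change the ``resources_per_trial`` argument:
#
# .. code-block:: python
#
#     result = tune.run(
#         partial(train_cifar, data_dir=data_dir),
#         resources_per_trial={"cpu": 2, "gpu": 0.5},  # two trials share one GPU
#         config=config,
#         num_samples=num_samples,
#         scheduler=scheduler,
#     )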
#
# After training the models, we will find the best performing one and load the trained
# network from the checkpoint file. We then obtain the test set accuracy and print
# the results.
#
# The full main function looks like this:


def main(num_samples=10, max_num_epochs=10, gpus_per_trial=2):
    data_dir = os.path.abspath("./data")
    load_data(data_dir)
    config = {
        "l1": tune.choice([2**i for i in range(9)]),
        "l2": tune.choice([2**i for i in range(9)]),
        "lr": tune.loguniform(1e-4, 1e-1),
        "batch_size": tune.choice([2, 4, 8, 16]),
    }
    scheduler = ASHAScheduler(
        metric="loss",
        mode="min",
        max_t=max_num_epochs,
        grace_period=1,
        reduction_factor=2,
    )
    result = tune.run(
        partial(train_cifar, data_dir=data_dir),
        resources_per_trial={"cpu": 2, "gpu": gpus_per_trial},
        config=config,
        num_samples=num_samples,
        scheduler=scheduler,
    )

    best_trial = result.get_best_trial("loss", "min", "last")
    print(f"Best trial config: {best_trial.config}")
    print(f"Best trial final validation loss: {best_trial.last_result['loss']}")
    print(f"Best trial final validation accuracy: {best_trial.last_result['accuracy']}")

    best_trained_model = Net(best_trial.config["l1"], best_trial.config["l2"])
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda:0"
        if gpus_per_trial > 1:
            best_trained_model = nn.DataParallel(best_trained_model)
    best_trained_model.to(device)

    best_checkpoint = result.get_best_checkpoint(trial=best_trial, metric="accuracy", mode="max")
    with best_checkpoint.as_directory() as checkpoint_dir:
        data_path = Path(checkpoint_dir) / "data.pkl"
        with open(data_path, "rb") as fp:
            best_checkpoint_data = pickle.load(fp)

        best_trained_model.load_state_dict(best_checkpoint_data["net_state_dict"])
        test_acc = test_accuracy(best_trained_model, device)
        print("Best trial test set accuracy: {}".format(test_acc))


if __name__ == "__main__":
    # You can change the number of GPUs per trial here:
    main(num_samples=10, max_num_epochs=10, gpus_per_trial=0)


######################################################################
# If you run the code, an example output could look like this:
#
# .. code-block:: sh
#
#     Number of trials: 10/10 (10 TERMINATED)
#     +-----+--------------+------+------+-------------+--------+---------+------------+
#     | ... |   batch_size |   l1 |   l2 |          lr |   iter |    loss |   accuracy |
#     |-----+--------------+------+------+-------------+--------+---------+------------|
#     | ... |            2 |    1 |  256 | 0.000668163 |      1 | 2.31479 |     0.0977 |
#     | ... |            4 |   64 |    8 |   0.0331514 |      1 | 2.31605 |     0.0983 |
#     | ... |            4 |    2 |    1 | 0.000150295 |      1 | 2.30755 |     0.1023 |
#     | ... |           16 |   32 |   32 |   0.0128248 |     10 | 1.66912 |     0.4391 |
#     | ... |            4 |    8 |  128 |  0.00464561 |      2 |  1.7316 |     0.3463 |
#     | ... |            8 |  256 |    8 |  0.00031556 |      1 | 2.19409 |     0.1736 |
#     | ... |            4 |   16 |  256 |  0.00574329 |      2 | 1.85679 |     0.3368 |
#     | ... |            8 |    2 |    2 |  0.00325652 |      1 | 2.30272 |     0.0984 |
#     | ... |            2 |    2 |    2 | 0.000342987 |      2 | 1.76044 |      0.292 |
#     | ... |            4 |   64 |   32 |    0.003734 |      8 | 1.53101 |     0.4761 |
#     +-----+--------------+------+------+-------------+--------+---------+------------+
#
#     Best trial config: {'l1': 64, 'l2': 32, 'lr': 0.0037339984519545164, 'batch_size': 4}
#     Best trial final validation loss: 1.5310075663924216
#     Best trial final validation accuracy: 0.4761
#     Best trial test set accuracy: 0.4737
#
# Most trials have been stopped early in order to avoid wasting resources.
# The best performing trial achieved a validation accuracy of about 47%, which could
# be confirmed on the test set.
#
# So that's it! You can now tune the parameters of your PyTorch models.
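#
# If you want to tune more hyperparameters, the same pattern applies. As a possible
# next step (a sketch, not part of the tutorial code above), you could sample the SGD
# momentum as well instead of hard-coding it to ``0.9``; ``train_cifar`` would then
# read ``config["momentum"]`` when creating the optimizer:
#
# .. code-block:: python
#
#     config = {
#         "l1": tune.choice([2**i for i in range(9)]),
#         "l2": tune.choice([2**i for i in range(9)]),
#         "lr": tune.loguniform(1e-4, 1e-1),
#         "momentum": tune.uniform(0.1, 0.9),  # additional hyperparameter to tune
#         "batch_size": tune.choice([2, 4, 8, 16]),
#     }
#
#     # ... and inside train_cifar:
#     optimizer = optim.SGD(net.parameters(), lr=config["lr"], momentum=config["momentum"])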