# -*- coding: utf-8 -*-
"""
Multi-Objective NAS with Ax
==================================================

**Authors:** `David Eriksson <https://github.com/dme65>`__,
`Max Balandat <https://github.com/Balandat>`__,
and the Adaptive Experimentation team at Meta.

In this tutorial, we show how to use `Ax <https://ax.dev/>`__ to run
multi-objective neural architecture search (NAS) for a simple neural
network model on the popular MNIST dataset. While the underlying
methodology would typically be used for more complicated models and
larger datasets, we opt for a tutorial that is easily runnable
end-to-end on a laptop in less than 20 minutes.

In many NAS applications, there is a natural tradeoff between multiple
objectives of interest. For instance, when deploying models on-device
we may want to maximize model performance (for example, accuracy), while
simultaneously minimizing competing metrics like power consumption,
inference latency, or model size in order to satisfy deployment
constraints. Often, we may be able to reduce computational requirements
or latency of predictions substantially by accepting minimally lower
model performance. Principled methods for exploring such tradeoffs
efficiently are key enablers of scalable and sustainable AI, and have
many successful applications at Meta - see for instance our
`case study <https://research.facebook.com/blog/2021/07/optimizing-model-accuracy-and-latency-using-bayesian-multi-objective-neural-architecture-search/>`__
on a Natural Language Understanding model.

In our example here, we will tune the widths of two hidden layers,
the learning rate, the dropout probability, the batch size, and the
number of training epochs.
The goal is to trade off performance
(accuracy on the validation set) and model size (the number of
model parameters).

This tutorial makes use of the following PyTorch libraries:

- `PyTorch Lightning <https://github.com/PyTorchLightning/pytorch-lightning>`__ (specifying the model and training loop)
- `TorchX <https://github.com/pytorch/torchx>`__ (for running training jobs remotely / asynchronously)
- `BoTorch <https://github.com/pytorch/botorch>`__ (the Bayesian Optimization library powering Ax's algorithms)
"""


######################################################################
# Defining the TorchX App
# -----------------------
#
# Our goal is to optimize the PyTorch Lightning training job defined in
# `mnist_train_nas.py <https://github.com/pytorch/tutorials/tree/main/intermediate_source/mnist_train_nas.py>`__.
# To do this using TorchX, we write a helper function that takes in
# the values of the architecture and hyperparameters of the training
# job and creates a `TorchX AppDef <https://pytorch.org/torchx/latest/basics.html>`__
# with the appropriate settings.
#

from pathlib import Path

import torchx

from torchx import specs
from torchx.components import utils


def trainer(
    log_path: str,
    hidden_size_1: int,
    hidden_size_2: int,
    learning_rate: float,
    epochs: int,
    dropout: float,
    batch_size: int,
    trial_idx: int = -1,
) -> specs.AppDef:

    # define the log path so we can pass it to the TorchX ``AppDef``
    if trial_idx >= 0:
        log_path = Path(log_path).joinpath(str(trial_idx)).absolute().as_posix()

    return utils.python(
        # command line arguments to the training script
        "--log_path",
        log_path,
        "--hidden_size_1",
        str(hidden_size_1),
        "--hidden_size_2",
        str(hidden_size_2),
        "--learning_rate",
        str(learning_rate),
        "--epochs",
        str(epochs),
        "--dropout",
        str(dropout),
        "--batch_size",
        str(batch_size),
        # other config options
        name="trainer",
        script="mnist_train_nas.py",
        image=torchx.version.TORCHX_IMAGE,
    )


######################################################################
# Setting up the Runner
# ---------------------
#
# Ax’s `Runner <https://ax.dev/api/core.html#ax.core.runner.Runner>`__
# abstraction allows writing interfaces to various backends.
# Ax already comes with a Runner for TorchX, so we just need to
# configure it. For the purpose of this tutorial we run jobs locally
# in a fully asynchronous fashion.
#
# In order to launch them on a cluster, you can instead specify a
# different TorchX scheduler and adjust the configuration appropriately.
# For example, if you have a Kubernetes cluster, you just need to change the
# scheduler from ``local_cwd`` to ``kubernetes``.
#


import tempfile
from ax.runners.torchx import TorchXRunner

# Make a temporary dir to log our results into
log_dir = tempfile.mkdtemp()

ax_runner = TorchXRunner(
    tracker_base="/tmp/",
    component=trainer,
    # NOTE: To launch this job on a cluster instead of locally you can
    # specify a different scheduler and adjust arguments appropriately.
    scheduler="local_cwd",
    component_const_params={"log_path": log_dir},
    cfg={},
)

######################################################################
# Setting up the ``SearchSpace``
# ------------------------------
#
# First, we define our search space.
# Ax supports both range parameters
# of type integer and float as well as choice parameters which can have
# non-numerical types such as strings.
# We will tune the hidden sizes, learning rate, dropout, and number of
# epochs as range parameters and tune the batch size as an ordered choice
# parameter to restrict it to powers of 2.
#

from ax.core import (
    ChoiceParameter,
    ParameterType,
    RangeParameter,
    SearchSpace,
)

parameters = [
    # NOTE: In a real-world setting, hidden_size_1 and hidden_size_2
    # should probably be powers of 2, but in our simple example this
    # would mean that ``num_params`` can't take on that many values, which
    # in turn makes the Pareto frontier look pretty weird.
    RangeParameter(
        name="hidden_size_1",
        lower=16,
        upper=128,
        parameter_type=ParameterType.INT,
        log_scale=True,
    ),
    RangeParameter(
        name="hidden_size_2",
        lower=16,
        upper=128,
        parameter_type=ParameterType.INT,
        log_scale=True,
    ),
    RangeParameter(
        name="learning_rate",
        lower=1e-4,
        upper=1e-2,
        parameter_type=ParameterType.FLOAT,
        log_scale=True,
    ),
    RangeParameter(
        name="epochs",
        lower=1,
        upper=4,
        parameter_type=ParameterType.INT,
    ),
    RangeParameter(
        name="dropout",
        lower=0.0,
        upper=0.5,
        parameter_type=ParameterType.FLOAT,
    ),
    ChoiceParameter(  # NOTE: ``ChoiceParameters`` don't require log-scale
        name="batch_size",
        values=[32, 64, 128, 256],
        parameter_type=ParameterType.INT,
        is_ordered=True,
        sort_values=True,
    ),
]

search_space = SearchSpace(
    parameters=parameters,
    # NOTE: In practice, it may make sense to add a constraint
    # hidden_size_2 <= hidden_size_1
    parameter_constraints=[],
)


######################################################################
# Setting up Metrics
# ------------------
#
# Ax has the concept of a `Metric <https://ax.dev/api/core.html#metric>`__
# that defines properties of outcomes and how observations are obtained
# for these outcomes. This allows, for example, encoding how data is
# fetched from some distributed execution backend and post-processed
# before being passed as input to Ax.
#
# In this tutorial we will use
# `multi-objective optimization <https://ax.dev/tutorials/multiobjective_optimization.html>`__
# with the goal of maximizing the validation accuracy and minimizing
# the number of model parameters. The latter represents a simple proxy
# of model latency, which is hard to estimate accurately for small ML
# models (in an actual application we would benchmark the latency while
# running the model on-device).
#
# In our example TorchX will run the training jobs in a fully asynchronous
# fashion locally and write the results to the ``log_dir`` based on the trial
# index (see the ``trainer()`` function above). We will define a metric
# class that is aware of that logging directory. By subclassing
# ``TensorboardMetric`` (see the
# `Ax metrics API reference <https://ax.dev/api/metrics.html>`__)
# we get the logic to read and parse the TensorBoard logs for free.
#

from ax.metrics.tensorboard import TensorboardMetric
from tensorboard.backend.event_processing import plugin_event_multiplexer as event_multiplexer

class MyTensorboardMetric(TensorboardMetric):

    # NOTE: We need to tell the new TensorBoard metric how to get the id /
    # file handle for the TensorBoard logs from a trial. In this case
    # our convention is to just save a separate file per trial in
    # the prespecified log dir.
    def _get_event_multiplexer_for_trial(self, trial):
        mul = event_multiplexer.EventMultiplexer(max_reload_threads=20)
        mul.AddRunsFromDirectory(Path(log_dir).joinpath(str(trial.index)).as_posix(), None)
        mul.Reload()

        return mul

    # This indicates whether the metric is queryable while the trial is
    # still running.
    # We don't use this in the current tutorial, but Ax
    # utilizes this to implement trial-level early-stopping functionality.
    @classmethod
    def is_available_while_running(cls):
        return False


######################################################################
# Now we can instantiate the metrics for accuracy and the number of
# model parameters. Here ``tag`` is the name of the metric in the
# TensorBoard logs, while ``name`` is the metric name used internally
# by Ax. We also specify ``lower_is_better`` to indicate the favorable
# direction of the two metrics.
#

val_acc = MyTensorboardMetric(
    name="val_acc",
    tag="val_acc",
    lower_is_better=False,
)
model_num_params = MyTensorboardMetric(
    name="num_params",
    tag="num_params",
    lower_is_better=True,
)


######################################################################
# Setting up the ``OptimizationConfig``
# -------------------------------------
#
# The way to tell Ax what it should optimize is by means of an
# `OptimizationConfig <https://ax.dev/api/core.html#module-ax.core.optimization_config>`__.
# Here we use a ``MultiObjectiveOptimizationConfig`` as we will
# be performing multi-objective optimization.
#
# Additionally, Ax supports placing constraints on the different
# metrics by specifying objective thresholds, which bound the region
# of interest in the outcome space that we want to explore.
# For this example, we will constrain the validation accuracy to be at
# least 0.94 (94%) and the number of model parameters to be at most 80,000.
#

from ax.core import MultiObjective, Objective, ObjectiveThreshold
from ax.core.optimization_config import MultiObjectiveOptimizationConfig


opt_config = MultiObjectiveOptimizationConfig(
    objective=MultiObjective(
        objectives=[
            Objective(metric=val_acc, minimize=False),
            Objective(metric=model_num_params, minimize=True),
        ],
    ),
    objective_thresholds=[
        ObjectiveThreshold(metric=val_acc, bound=0.94, relative=False),
        ObjectiveThreshold(metric=model_num_params, bound=80_000, relative=False),
    ],
)


######################################################################
# Creating the Ax Experiment
# --------------------------
#
# In Ax, the `Experiment <https://ax.dev/api/core.html#ax.core.experiment.Experiment>`__
# object stores all the information about the problem setup.
#
# .. tip::
#   ``Experiment`` objects can be serialized to JSON or stored to a
#   database backend such as MySQL in order to persist and be available
#   to load on different machines. See the `Ax Docs <https://ax.dev/docs/storage.html>`__
#   on the storage functionality for details.
#

from ax.core import Experiment

experiment = Experiment(
    name="torchx_mnist",
    search_space=search_space,
    optimization_config=opt_config,
    runner=ax_runner,
)

######################################################################
# Choosing the Generation Strategy
# --------------------------------
#
# A `GenerationStrategy <https://ax.dev/api/modelbridge.html#ax.modelbridge.generation_strategy.GenerationStrategy>`__
# is the abstract representation of how we would like to perform the
# optimization.
# While this can be customized (if you’d like to do so, see
# `this tutorial <https://ax.dev/tutorials/generation_strategy.html>`__),
# in most cases Ax can automatically determine an appropriate strategy
# based on the search space, optimization config, and the total number
# of trials we want to run.
#
# Typically, Ax chooses to evaluate a number of random configurations
# before starting a model-based Bayesian Optimization strategy.
#


total_trials = 48  # total evaluation budget

from ax.modelbridge.dispatch_utils import choose_generation_strategy

gs = choose_generation_strategy(
    search_space=experiment.search_space,
    optimization_config=experiment.optimization_config,
    num_trials=total_trials,
)


######################################################################
# Configuring the Scheduler
# -------------------------
#
# The ``Scheduler`` acts as the loop control for the optimization.
# It communicates with the backend to launch trials, check their status,
# and retrieve results. In the case of this tutorial, it is simply reading
# and parsing the locally saved logs. In a remote execution setting,
# it would call APIs. The following illustration from the Ax
# `Scheduler tutorial <https://ax.dev/tutorials/scheduler.html>`__
# summarizes how the Scheduler interacts with external systems used to run
# trial evaluations:
#
# .. image:: ../../_static/img/ax_scheduler_illustration.png
#
#
# The ``Scheduler`` requires the ``Experiment`` and the ``GenerationStrategy``.
# A set of options can be passed in via ``SchedulerOptions``. Here, we
# configure the number of total evaluations as well as ``max_pending_trials``,
# the maximum number of trials that should run concurrently.
# In our local setting, this is the number of training jobs running as
# individual processes, while in a remote execution setting, this would
# be the number of machines you want to use in parallel.
#


from ax.service.scheduler import Scheduler, SchedulerOptions

scheduler = Scheduler(
    experiment=experiment,
    generation_strategy=gs,
    options=SchedulerOptions(
        total_trials=total_trials, max_pending_trials=4
    ),
)


######################################################################
# Running the optimization
# ------------------------
#
# Now that everything is configured, we can let Ax run the optimization
# in a fully automated fashion. The Scheduler will periodically check
# the logs for the status of all currently running trials, and if a
# trial completes the scheduler will update its status on the
# experiment and fetch the observations needed for the Bayesian
# optimization algorithm.
#

scheduler.run_all_trials()


######################################################################
# Evaluating the results
# ----------------------
#
# We can now inspect the result of the optimization using helper
# functions and visualizations included with Ax.

######################################################################
# First, we generate a dataframe with a summary of the results
# of the experiment. Each row in this dataframe corresponds to a
# trial (that is, a training job that was run), and contains information
# on the status of the trial, the parameter configuration that was
# evaluated, and the metric values that were observed.
# This provides an easy way to sanity check the optimization.
#

from ax.service.utils.report_utils import exp_to_df

df = exp_to_df(experiment)
df.head(10)


######################################################################
# We can also visualize the Pareto frontier of tradeoffs between the
# validation accuracy and the number of model parameters.
#
# .. tip::
#   Ax uses Plotly to produce interactive plots, which allow you to
#   do things like zoom, crop, or hover in order to view details
#   of components of the plot. Try it out, and take a look at the
#   `visualization tutorial <https://ax.dev/tutorials/visualizations.html>`__
#   if you'd like to learn more.
#
# The final optimization results are shown in the figure below where
# the color corresponds to the iteration number for each trial.
# We see that our method was able to successfully explore the
# trade-offs and found both large models with high validation
# accuracy as well as small models with comparatively lower
# validation accuracy.
#

from ax.service.utils.report_utils import _pareto_frontier_scatter_2d_plotly

_pareto_frontier_scatter_2d_plotly(experiment)


######################################################################
# To better understand what our surrogate models have learned about
# the black box objectives, we can take a look at the leave-one-out
# cross validation results. Since our models are Gaussian Processes,
# they not only provide point predictions but also uncertainty estimates
# about these predictions.
# A good model means that the predicted means
# (the points in the figure) are close to the 45 degree line and that the
# confidence intervals cover the 45 degree line with the expected frequency
# (here we use 95% confidence intervals, so we would expect them to contain
# the true observation 95% of the time).
#
# As the figures below show, the model size (``num_params``) metric is
# much easier to model than the validation accuracy (``val_acc``) metric.
#

from ax.modelbridge.cross_validation import compute_diagnostics, cross_validate
from ax.plot.diagnostic import interact_cross_validation_plotly
from ax.utils.notebook.plotting import init_notebook_plotting, render

cv = cross_validate(model=gs.model)  # The surrogate model is stored on the ``GenerationStrategy``
compute_diagnostics(cv)

interact_cross_validation_plotly(cv)


######################################################################
# We can also make contour plots to better understand how the different
# objectives depend on two of the input parameters. In the figure below,
# we show the validation accuracy predicted by the model as a function
# of the two hidden sizes.
# The validation accuracy clearly increases
# as the hidden sizes increase.
#

from ax.plot.contour import interact_contour_plotly

interact_contour_plotly(model=gs.model, metric_name="val_acc")


######################################################################
# Similarly, we show the number of model parameters as a function of
# the hidden sizes in the figure below and see that it also increases
# as a function of the hidden sizes (the dependency on ``hidden_size_1``
# is much larger).

interact_contour_plotly(model=gs.model, metric_name="num_params")


######################################################################
# Acknowledgments
# ----------------
#
# We thank the TorchX team (in particular Kiuk Chung and Tristan Rice)
# for their help with integrating TorchX with Ax.
#
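

######################################################################
# Appendix: a rough parameter count
# ---------------------------------
#
# The contour plot for ``num_params`` indicates that the model size
# depends far more on ``hidden_size_1`` than on ``hidden_size_2``. The
# sketch below gives a back-of-the-envelope explanation. It is not part
# of the training script: it assumes a plain fully connected network
# with layers ``784 -> hidden_size_1 -> hidden_size_2 -> 10`` (with
# biases), and ``count_mlp_params`` is a hypothetical helper introduced
# here for illustration only. The values that ``mnist_train_nas.py``
# actually logs may differ if the model is defined differently.
#


def count_mlp_params(
    hidden_size_1: int,
    hidden_size_2: int,
    num_inputs: int = 28 * 28,  # assumption: flattened 28x28 MNIST images
    num_classes: int = 10,  # assumption: one output unit per digit
) -> int:
    """Count the weights and biases of three fully connected layers."""
    layer_sizes = [num_inputs, hidden_size_1, hidden_size_2, num_classes]
    # each layer contributes fan_in * fan_out weights plus fan_out biases
    return sum(
        fan_in * fan_out + fan_out
        for fan_in, fan_out in zip(layer_sizes[:-1], layer_sizes[1:])
    )


# Double one hidden size at a time, starting from a (64, 64) baseline.
baseline = count_mlp_params(64, 64)
extra_from_h1 = count_mlp_params(128, 64) - baseline
extra_from_h2 = count_mlp_params(64, 128) - baseline
print(f"baseline={baseline}, +hidden_size_1={extra_from_h1}, +hidden_size_2={extra_from_h2}")

######################################################################
# Under these assumptions, the first layer dominates because it is
# connected to all 784 input pixels, so widening ``hidden_size_1`` adds
# roughly ten times more parameters than widening ``hidden_size_2``.
#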