GitHub Repository: pytorch/tutorials
Path: blob/main/intermediate_source/ax_multiobjective_nas_tutorial.py
# -*- coding: utf-8 -*-
"""
Multi-Objective NAS with Ax
==================================================

**Authors:** `David Eriksson <https://github.com/dme65>`__,
`Max Balandat <https://github.com/Balandat>`__,
and the Adaptive Experimentation team at Meta.

In this tutorial, we show how to use `Ax <https://ax.dev/>`__ to run
multi-objective neural architecture search (NAS) for a simple neural
network model on the popular MNIST dataset. While the underlying
methodology would typically be used for more complicated models and
larger datasets, we opt for a tutorial that is easily runnable
end-to-end on a laptop in less than 20 minutes.

In many NAS applications, there is a natural tradeoff between multiple
objectives of interest. For instance, when deploying models on-device
we may want to maximize model performance (for example, accuracy), while
simultaneously minimizing competing metrics like power consumption,
inference latency, or model size in order to satisfy deployment
constraints. Often, we may be able to reduce computational requirements
or latency of predictions substantially by accepting minimally lower
model performance. Principled methods for exploring such tradeoffs
efficiently are key enablers of scalable and sustainable AI, and have
many successful applications at Meta - see for instance our
`case study <https://research.facebook.com/blog/2021/07/optimizing-model-accuracy-and-latency-using-bayesian-multi-objective-neural-architecture-search/>`__
on a Natural Language Understanding model.

In our example here, we will tune the widths of two hidden layers,
the learning rate, the dropout probability, the batch size, and the
number of training epochs. The goal is to trade off performance
(accuracy on the validation set) and model size (the number of
model parameters).

This tutorial makes use of the following PyTorch libraries:

- `PyTorch Lightning <https://github.com/PyTorchLightning/pytorch-lightning>`__ (specifying the model and training loop)
- `TorchX <https://github.com/pytorch/torchx>`__ (for running training jobs remotely / asynchronously)
- `BoTorch <https://github.com/pytorch/botorch>`__ (the Bayesian Optimization library powering Ax's algorithms)
"""


######################################################################
# Defining the TorchX App
# -----------------------
#
# Our goal is to optimize the PyTorch Lightning training job defined in
# `mnist_train_nas.py <https://github.com/pytorch/tutorials/tree/main/intermediate_source/mnist_train_nas.py>`__.
# To do this using TorchX, we write a helper function that takes in
# the values of the architecture and hyperparameters of the training
# job and creates a `TorchX AppDef <https://pytorch.org/torchx/latest/basics.html>`__
# with the appropriate settings.
#

from pathlib import Path

import torchx

from torchx import specs
from torchx.components import utils


def trainer(
    log_path: str,
    hidden_size_1: int,
    hidden_size_2: int,
    learning_rate: float,
    epochs: int,
    dropout: float,
    batch_size: int,
    trial_idx: int = -1,
) -> specs.AppDef:

    # define the log path so we can pass it to the TorchX ``AppDef``
    if trial_idx >= 0:
        log_path = Path(log_path).joinpath(str(trial_idx)).absolute().as_posix()

    return utils.python(
        # command line arguments to the training script
        "--log_path",
        log_path,
        "--hidden_size_1",
        str(hidden_size_1),
        "--hidden_size_2",
        str(hidden_size_2),
        "--learning_rate",
        str(learning_rate),
        "--epochs",
        str(epochs),
        "--dropout",
        str(dropout),
        "--batch_size",
        str(batch_size),
        # other config options
        name="trainer",
        script="mnist_train_nas.py",
        image=torchx.version.TORCHX_IMAGE,
    )
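
######################################################################
# The per-trial logging convention used by ``trainer()`` can be checked in
# isolation. The snippet below is a minimal sketch using only ``pathlib``;
# ``sketch_trial_log_path`` and the example base directory are hypothetical
# names introduced here for illustration and are not part of TorchX or Ax.
#

from pathlib import Path


def sketch_trial_log_path(log_path: str, trial_idx: int = -1) -> str:
    """Mirror the path logic in ``trainer()`` without creating a TorchX app."""
    if trial_idx >= 0:
        # Each trial gets its own sub-directory named after its index.
        return Path(log_path).joinpath(str(trial_idx)).absolute().as_posix()
    # A negative index leaves the base path untouched.
    return log_path


# Hypothetical base directory; trial 3 logs to the sub-directory ``3``.
print(sketch_trial_log_path("/tmp/nas_logs", trial_idx=3))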


######################################################################
# Setting up the Runner
# ---------------------
#
# Ax’s `Runner <https://ax.dev/api/core.html#ax.core.runner.Runner>`__
# abstraction allows writing interfaces to various backends.
# Ax already comes with a Runner for TorchX, and so we just need to
# configure it. For the purpose of this tutorial we run jobs locally
# in a fully asynchronous fashion.
#
# In order to launch them on a cluster, you can instead specify a
# different TorchX scheduler and adjust the configuration appropriately.
# For example, if you have a Kubernetes cluster, you just need to change the
# scheduler from ``local_cwd`` to ``kubernetes``.
#


import tempfile
from ax.runners.torchx import TorchXRunner

# Make a temporary dir to log our results into
log_dir = tempfile.mkdtemp()

ax_runner = TorchXRunner(
    tracker_base="/tmp/",
    component=trainer,
    # NOTE: To launch this job on a cluster instead of locally you can
    # specify a different scheduler and adjust arguments appropriately.
    scheduler="local_cwd",
    component_const_params={"log_path": log_dir},
    cfg={},
)

######################################################################
# Setting up the ``SearchSpace``
# ------------------------------
#
# First, we define our search space. Ax supports range parameters
# of type integer and float, as well as choice parameters, which can have
# non-numerical types such as strings.
# We will tune the hidden sizes, learning rate, dropout, and number of
# epochs as range parameters and tune the batch size as an ordered choice
# parameter to restrict it to powers of 2.
#

from ax.core import (
    ChoiceParameter,
    ParameterType,
    RangeParameter,
    SearchSpace,
)

parameters = [
    # NOTE: In a real-world setting, hidden_size_1 and hidden_size_2
    # should probably be powers of 2, but in our simple example this
    # would mean that ``num_params`` can't take on that many values, which
    # in turn makes the Pareto frontier look pretty weird.
    RangeParameter(
        name="hidden_size_1",
        lower=16,
        upper=128,
        parameter_type=ParameterType.INT,
        log_scale=True,
    ),
    RangeParameter(
        name="hidden_size_2",
        lower=16,
        upper=128,
        parameter_type=ParameterType.INT,
        log_scale=True,
    ),
    RangeParameter(
        name="learning_rate",
        lower=1e-4,
        upper=1e-2,
        parameter_type=ParameterType.FLOAT,
        log_scale=True,
    ),
    RangeParameter(
        name="epochs",
        lower=1,
        upper=4,
        parameter_type=ParameterType.INT,
    ),
    RangeParameter(
        name="dropout",
        lower=0.0,
        upper=0.5,
        parameter_type=ParameterType.FLOAT,
    ),
    ChoiceParameter(  # NOTE: ``ChoiceParameters`` don't require log-scale
        name="batch_size",
        values=[32, 64, 128, 256],
        parameter_type=ParameterType.INT,
        is_ordered=True,
        sort_values=True,
    ),
]

search_space = SearchSpace(
    parameters=parameters,
    # NOTE: In practice, it may make sense to add a constraint
    # hidden_size_2 <= hidden_size_1
    parameter_constraints=[],
)
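
######################################################################
# To build intuition for ``log_scale=True``, the sketch below draws integer
# values uniformly in log space using only the standard library. This is an
# illustration of the idea, not Ax's implementation (Ax handles the
# transform internally, and its details may differ);
# ``sketch_log_uniform_int`` is a hypothetical helper.
#

import math
import random


def sketch_log_uniform_int(lower: int, upper: int, rng: random.Random) -> int:
    """Sample an integer in [lower, upper], uniform in log space."""
    log_value = rng.uniform(math.log(lower), math.log(upper))
    # Round to the nearest integer and clip to guard against rounding drift.
    return min(upper, max(lower, round(math.exp(log_value))))


_sketch_rng = random.Random(0)
_sketch_sizes = [sketch_log_uniform_int(16, 128, _sketch_rng) for _ in range(1000)]
# Roughly half of the draws fall below the geometric mean sqrt(16 * 128) ~ 45,
# whereas a linear-scale draw would put only about a quarter of them there.
print(sum(size <= 45 for size in _sketch_sizes) / len(_sketch_sizes))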


######################################################################
# Setting up Metrics
# ------------------
#
# Ax has the concept of a `Metric <https://ax.dev/api/core.html#metric>`__
# that defines properties of outcomes and how observations are obtained
# for these outcomes. This allows e.g. encoding how data is fetched from
# some distributed execution backend and post-processed before being
# passed as input to Ax.
#
# In this tutorial we will use
# `multi-objective optimization <https://ax.dev/tutorials/multiobjective_optimization.html>`__
# with the goal of maximizing the validation accuracy and minimizing
# the number of model parameters. The latter represents a simple proxy
# for model latency, which is hard to estimate accurately for small ML
# models (in an actual application we would benchmark the latency while
# running the model on-device).
#
# In our example TorchX will run the training jobs in a fully asynchronous
# fashion locally and write the results to the ``log_dir`` based on the trial
# index (see the ``trainer()`` function above). We will define a metric
# class that is aware of that logging directory. By subclassing
# `TensorboardMetric <https://ax.dev/api/metrics.html?highlight=tensorboardmetric#ax.metrics.tensorboard.TensorboardMetric>`__
# we get the logic to read and parse the TensorBoard logs for free.
#

from ax.metrics.tensorboard import TensorboardMetric
from tensorboard.backend.event_processing import plugin_event_multiplexer as event_multiplexer

class MyTensorboardMetric(TensorboardMetric):

    # NOTE: We need to tell the new TensorBoard metric how to get the id /
    # file handle for the TensorBoard logs from a trial. In this case
    # our convention is to just save a separate file per trial in
    # the prespecified log dir.
    def _get_event_multiplexer_for_trial(self, trial):
        mul = event_multiplexer.EventMultiplexer(max_reload_threads=20)
        mul.AddRunsFromDirectory(Path(log_dir).joinpath(str(trial.index)).as_posix(), None)
        mul.Reload()

        return mul

    # This indicates whether the metric is queryable while the trial is
    # still running. We don't use this in the current tutorial, but Ax
    # utilizes this to implement trial-level early-stopping functionality.
    @classmethod
    def is_available_while_running(cls):
        return False


######################################################################
# Now we can instantiate the metrics for accuracy and the number of
# model parameters. Here ``tag`` is the name of the metric in the
# TensorBoard logs, while ``name`` is the metric name used internally
# by Ax. We also specify ``lower_is_better`` to indicate the favorable
# direction of the two metrics.
#

val_acc = MyTensorboardMetric(
    name="val_acc",
    tag="val_acc",
    lower_is_better=False,
)
model_num_params = MyTensorboardMetric(
    name="num_params",
    tag="num_params",
    lower_is_better=True,
)


######################################################################
# Setting up the ``OptimizationConfig``
# -------------------------------------
#
# We tell Ax what it should optimize by means of an
# `OptimizationConfig <https://ax.dev/api/core.html#module-ax.core.optimization_config>`__.
# Here we use a ``MultiObjectiveOptimizationConfig`` as we will
# be performing multi-objective optimization.
#
# Additionally, Ax supports placing constraints on the different
# metrics by specifying objective thresholds, which bound the region
# of interest in the outcome space that we want to explore. For this
# example, we will constrain the validation accuracy to be at least
# 0.94 (94%) and the number of model parameters to be at most 80,000.
#

from ax.core import MultiObjective, Objective, ObjectiveThreshold
from ax.core.optimization_config import MultiObjectiveOptimizationConfig


opt_config = MultiObjectiveOptimizationConfig(
    objective=MultiObjective(
        objectives=[
            Objective(metric=val_acc, minimize=False),
            Objective(metric=model_num_params, minimize=True),
        ],
    ),
    objective_thresholds=[
        ObjectiveThreshold(metric=val_acc, bound=0.94, relative=False),
        ObjectiveThreshold(metric=model_num_params, bound=80_000, relative=False),
    ],
)


######################################################################
# Creating the Ax Experiment
# --------------------------
#
# In Ax, the `Experiment <https://ax.dev/api/core.html#ax.core.experiment.Experiment>`__
# object stores all the information about the problem setup.
#
# .. tip::
#    ``Experiment`` objects can be serialized to JSON or stored to a
#    database backend such as MySQL in order to persist and be available
#    to load on different machines. See the `Ax Docs <https://ax.dev/docs/storage.html>`__
#    on the storage functionality for details.
#

from ax.core import Experiment

experiment = Experiment(
    name="torchx_mnist",
    search_space=search_space,
    optimization_config=opt_config,
    runner=ax_runner,
)

######################################################################
# Choosing the Generation Strategy
# --------------------------------
#
# A `GenerationStrategy <https://ax.dev/api/modelbridge.html#ax.modelbridge.generation_strategy.GenerationStrategy>`__
# is the abstract representation of how we would like to perform the
# optimization. While this can be customized (if you’d like to do so, see
# `this tutorial <https://ax.dev/tutorials/generation_strategy.html>`__),
# in most cases Ax can automatically determine an appropriate strategy
# based on the search space, optimization config, and the total number
# of trials we want to run.
#
# Typically, Ax chooses to evaluate a number of random configurations
# before starting a model-based Bayesian Optimization strategy.
#


total_trials = 48  # total evaluation budget

from ax.modelbridge.dispatch_utils import choose_generation_strategy

gs = choose_generation_strategy(
    search_space=experiment.search_space,
    optimization_config=experiment.optimization_config,
    num_trials=total_trials,
)


######################################################################
# Configuring the Scheduler
# -------------------------
#
# The ``Scheduler`` acts as the loop control for the optimization.
# It communicates with the backend to launch trials, check their status,
# and retrieve results. In the case of this tutorial, it is simply reading
# and parsing the locally saved logs. In a remote execution setting,
# it would call APIs. The following illustration from the Ax
# `Scheduler tutorial <https://ax.dev/tutorials/scheduler.html>`__
# summarizes how the Scheduler interacts with external systems used to run
# trial evaluations:
#
# .. image:: ../../_static/img/ax_scheduler_illustration.png
#
#
# The ``Scheduler`` requires the ``Experiment`` and the ``GenerationStrategy``.
# A set of options can be passed in via ``SchedulerOptions``. Here, we
# configure the number of total evaluations as well as ``max_pending_trials``,
# the maximum number of trials that should run concurrently. In our
# local setting, this is the number of training jobs running as individual
# processes, while in a remote execution setting, this would be the number
# of machines you want to use in parallel.
#


from ax.service.scheduler import Scheduler, SchedulerOptions

scheduler = Scheduler(
    experiment=experiment,
    generation_strategy=gs,
    options=SchedulerOptions(
        total_trials=total_trials, max_pending_trials=4
    ),
)
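
######################################################################
# The role of ``max_pending_trials`` can be illustrated with a toy polling
# loop. This is a deliberately simplified sketch, not Ax's ``Scheduler``
# implementation: ``sketch_run_trials`` is a hypothetical function, and each
# simulated trial simply finishes after a fixed number of polls.
#


def sketch_run_trials(total_trials, max_pending_trials, polls_per_trial=3):
    """Simulate launching trials with a cap on how many run concurrently."""
    remaining_polls = {}  # trial index -> polls left until completion
    completed = []
    next_index = 0
    peak_concurrency = 0
    while len(completed) < total_trials:
        # Launch new trials while there is budget and free capacity.
        while next_index < total_trials and len(remaining_polls) < max_pending_trials:
            remaining_polls[next_index] = polls_per_trial
            next_index += 1
        peak_concurrency = max(peak_concurrency, len(remaining_polls))
        # Poll: advance every running trial and collect the finished ones.
        for index in sorted(remaining_polls):
            remaining_polls[index] -= 1
            if remaining_polls[index] == 0:
                completed.append(index)
                del remaining_polls[index]
    return completed, peak_concurrency


# Same budget and concurrency cap as configured above.
_sketch_completed, _sketch_peak = sketch_run_trials(total_trials=48, max_pending_trials=4)
print(len(_sketch_completed), _sketch_peak)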


######################################################################
# Running the optimization
# ------------------------
#
# Now that everything is configured, we can let Ax run the optimization
# in a fully automated fashion. The Scheduler will periodically check
# the logs for the status of all currently running trials, and if a
# trial completes the scheduler will update its status on the
# experiment and fetch the observations needed for the Bayesian
# optimization algorithm.
#

scheduler.run_all_trials()


######################################################################
# Evaluating the results
# ----------------------
#
# We can now inspect the result of the optimization using helper
# functions and visualizations included with Ax.

######################################################################
# First, we generate a dataframe with a summary of the results
# of the experiment. Each row in this dataframe corresponds to a
# trial (that is, a training job that was run), and contains information
# on the status of the trial, the parameter configuration that was
# evaluated, and the metric values that were observed. This provides
# an easy way to sanity check the optimization.
#

from ax.service.utils.report_utils import exp_to_df

df = exp_to_df(experiment)
df.head(10)


######################################################################
# We can also visualize the Pareto frontier of tradeoffs between the
# validation accuracy and the number of model parameters.
#
# .. tip::
#    Ax uses Plotly to produce interactive plots, which allow you to
#    do things like zoom, crop, or hover in order to view details
#    of components of the plot. Try it out, and take a look at the
#    `visualization tutorial <https://ax.dev/tutorials/visualizations.html>`__
#    if you'd like to learn more.
#
# The final optimization results are shown in the figure below where
# the color corresponds to the iteration number for each trial.
# We see that our method was able to successfully explore the
# trade-offs and found both large models with high validation
# accuracy as well as small models with comparatively lower
# validation accuracy.
#

from ax.service.utils.report_utils import _pareto_frontier_scatter_2d_plotly

_pareto_frontier_scatter_2d_plotly(experiment)
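
######################################################################
# For intuition about what the Pareto frontier plot shows, the sketch below
# filters a handful of made-up ``(val_acc, num_params)`` pairs down to the
# non-dominated ones. The numbers are illustrative only and are not results
# from this experiment; ``sketch_pareto_front`` is a hypothetical helper and
# not part of the Ax API.
#


def sketch_pareto_front(points):
    """Return the non-dominated points (maximize accuracy, minimize size)."""
    front = []
    for acc, size in points:
        # A point is dominated if another point is at least as good in both
        # objectives and strictly better in at least one of them.
        dominated = any(
            (other_acc >= acc and other_size <= size)
            and (other_acc > acc or other_size < size)
            for other_acc, other_size in points
        )
        if not dominated:
            front.append((acc, size))
    return sorted(front, key=lambda point: point[1])


_sketch_points = [
    (0.95, 30_000),
    (0.97, 90_000),
    (0.96, 50_000),
    (0.94, 60_000),
    (0.93, 20_000),
]
_sketch_front = sketch_pareto_front(_sketch_points)
# Keep only frontier points inside the region bounded by the objective
# thresholds used in this tutorial (val_acc >= 0.94, num_params <= 80,000).
_sketch_in_region = [(a, s) for a, s in _sketch_front if a >= 0.94 and s <= 80_000]
print(_sketch_front)
print(_sketch_in_region)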


######################################################################
# To better understand what our surrogate models have learned about
# the black box objectives, we can take a look at the leave-one-out
# cross validation results. Since our models are Gaussian Processes,
# they not only provide point predictions but also uncertainty estimates
# about these predictions. A good model means that the predicted means
# (the points in the figure) are close to the 45-degree line and that the
# confidence intervals cover the 45-degree line with the expected frequency
# (here we use 95% confidence intervals, so we would expect them to contain
# the true observation 95% of the time).
#
# As the figures below show, the model size (``num_params``) metric is
# much easier to model than the validation accuracy (``val_acc``) metric.
#

from ax.modelbridge.cross_validation import compute_diagnostics, cross_validate
from ax.plot.diagnostic import interact_cross_validation_plotly
from ax.utils.notebook.plotting import init_notebook_plotting, render

cv = cross_validate(model=gs.model)  # The surrogate model is stored on the ``GenerationStrategy``
compute_diagnostics(cv)

interact_cross_validation_plotly(cv)
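
######################################################################
# The "expected frequency" criterion can be made concrete with a small
# coverage check. The sketch below uses made-up predictions (not output from
# this experiment) and assumes Gaussian predictive distributions, so that a
# 95% interval is ``mean +/- 1.96 * std``; ``sketch_interval_coverage`` is a
# hypothetical helper, not an Ax function.
#


def sketch_interval_coverage(observed, predicted_mean, predicted_std, z=1.96):
    """Fraction of observations that fall inside ``mean +/- z * std``."""
    inside = [
        abs(obs - mean) <= z * std
        for obs, mean, std in zip(observed, predicted_mean, predicted_std)
    ]
    return sum(inside) / len(inside)


_sketch_observed = [0.95, 0.96, 0.93, 0.97]
_sketch_mean = [0.94, 0.96, 0.96, 0.97]
_sketch_std = [0.01, 0.01, 0.01, 0.01]
# Three of the four observations lie within 1.96 standard deviations of the
# predicted mean; with many points, a calibrated model would approach 0.95.
print(sketch_interval_coverage(_sketch_observed, _sketch_mean, _sketch_std))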


######################################################################
# We can also make contour plots to better understand how the different
# objectives depend on two of the input parameters. In the figure below,
# we show the validation accuracy predicted by the model as a function
# of the two hidden sizes. The validation accuracy clearly increases
# as the hidden sizes increase.
#

from ax.plot.contour import interact_contour_plotly

interact_contour_plotly(model=gs.model, metric_name="val_acc")


######################################################################
# Similarly, we show the number of model parameters as a function of
# the hidden sizes in the figure below and see that it also increases
# as a function of the hidden sizes (the dependency on ``hidden_size_1``
# is much larger).

interact_contour_plotly(model=gs.model, metric_name="num_params")
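
######################################################################
# The stronger dependency on ``hidden_size_1`` follows from a parameter count.
# The sketch below assumes the model is a fully connected network with layers
# ``784 -> hidden_size_1 -> hidden_size_2 -> 10`` (flattened 28x28 MNIST
# images, 10 classes, biases included); check ``mnist_train_nas.py`` for the
# actual architecture. ``sketch_mlp_num_params`` is a hypothetical helper.
#


def sketch_mlp_num_params(hidden_size_1, hidden_size_2, num_inputs=784, num_classes=10):
    """Count weights and biases of a three-layer fully connected network."""
    layer_1 = num_inputs * hidden_size_1 + hidden_size_1
    layer_2 = hidden_size_1 * hidden_size_2 + hidden_size_2
    layer_3 = hidden_size_2 * num_classes + num_classes
    return layer_1 + layer_2 + layer_3


# Under this assumption, one extra unit in the first hidden layer adds
# 784 + 1 + hidden_size_2 parameters, while one extra unit in the second
# adds only hidden_size_1 + 1 + 10.
print(sketch_mlp_num_params(17, 16) - sketch_mlp_num_params(16, 16))
print(sketch_mlp_num_params(16, 17) - sketch_mlp_num_params(16, 16))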


######################################################################
# Acknowledgments
# ----------------
#
# We thank the TorchX team (in particular Kiuk Chung and Tristan Rice)
# for their help with integrating TorchX with Ax.
#