CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Sign Up | Sign In
pytorch

Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place.

GitHub Repository: pytorch/tutorials
Path: blob/main/intermediate_source/mnist_train_nas.py
Views: 712
1
"""
2
Example training code for ``ax_multiobjective_nas_tutorial.py``
3
"""
4
5
import argparse
6
import logging
7
import os
8
import sys
9
import time
10
import warnings
11
12
import torch
13
from IPython.utils import io
14
from pytorch_lightning import LightningModule, Trainer
15
from pytorch_lightning import loggers as pl_loggers
16
from torch import nn
17
from torch.nn import functional as F
18
from torch.utils.data import DataLoader
19
from torchmetrics.functional.classification.accuracy import multiclass_accuracy
20
from torchvision import transforms
21
from torchvision.datasets import MNIST
22
23
# Keep each trial's console output quiet: Lightning's own logger only reports
# errors (it otherwise emits GPU/TPU prints), and all Python warnings -- the
# data logger ones included -- are ignored.
logging.getLogger("pytorch_lightning").setLevel(logging.ERROR)
warnings.simplefilter("ignore")
25
26
27
def parse_args(argv=None):
    """Parse the hyperparameters for a single training trial.

    Args:
        argv: Optional sequence of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, exactly as before;
            passing an explicit list lets callers (tests, notebooks) parse
            arguments without mutating ``sys.argv``.

    Returns:
        An ``argparse.Namespace`` with ``log_path`` (str), ``hidden_size_1``,
        ``hidden_size_2``, ``epochs``, ``batch_size`` (int) and
        ``learning_rate``, ``dropout`` (float). Every option is required, so
        argparse reports an error and exits if any is missing or mistyped.
    """
    parser = argparse.ArgumentParser(description="train mnist")
    parser.add_argument(
        "--log_path", type=str, required=True, help="dir to place tensorboard logs from all trials"
    )
    parser.add_argument(
        "--hidden_size_1", type=int, required=True, help="hidden size layer 1"
    )
    parser.add_argument(
        "--hidden_size_2", type=int, required=True, help="hidden size layer 2"
    )
    parser.add_argument("--learning_rate", type=float, required=True, help="learning rate")
    parser.add_argument("--epochs", type=int, required=True, help="number of epochs")
    parser.add_argument("--dropout", type=float, required=True, help="dropout probability")
    parser.add_argument("--batch_size", type=int, required=True, help="batch size")
    return parser.parse_args(argv)
43
44
# The trial's hyperparameters are parsed once, when the module loads;
# MnistModel and run_training_job both read them from this namespace.
args = parse_args()

# Where MNIST is downloaded/read from. The PATH_DATASETS environment variable
# overrides it; otherwise the current working directory is used.
PATH_DATASETS = os.getenv("PATH_DATASETS", ".")
47
48
49
class MnistModel(LightningModule):
    """Configurable fully connected MNIST classifier trained with Lightning.

    Hyperparameters come from the module-level ``args`` namespace and the
    dataset root from ``PATH_DATASETS``. The network is ``Flatten`` followed
    by up to two ``Linear -> ReLU -> Dropout`` hidden blocks and a final
    ``Linear`` onto the 10 classes; a hidden size <= 0 omits that hidden
    block, so the search space can vary the network's depth.

    Attributes:
        num_params: Number of parameters (weights and biases) in
            ``self.model``; reported as one of the NAS objectives.
    """

    def __init__(self):
        super().__init__()

        # Tunable parameters
        self.hidden_size_1 = args.hidden_size_1
        self.hidden_size_2 = args.hidden_size_2
        self.learning_rate = args.learning_rate
        self.dropout = args.dropout
        self.batch_size = args.batch_size

        # Set class attributes
        self.data_dir = PATH_DATASETS

        # Hardcode some dataset specific attributes
        self.num_classes = 10
        self.dims = (1, 28, 28)
        channels, height, width = self.dims
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                # presumably the MNIST pixel mean/std -- values kept as given
                transforms.Normalize((0.1307,), (0.3081,)),
            ]
        )

        # Create a PyTorch model: Flatten -> [Linear -> ReLU -> Dropout]* -> Linear
        layers = [nn.Flatten()]
        in_features = channels * height * width
        for hidden_size in (self.hidden_size_1, self.hidden_size_2):
            if hidden_size > 0:  # a non-positive size drops this hidden block
                layers += [
                    nn.Linear(in_features, hidden_size),
                    nn.ReLU(),
                    nn.Dropout(self.dropout),
                ]
                in_features = hidden_size
        # No ReLU after the output layer: forward() applies log_softmax instead.
        layers.append(nn.Linear(in_features, self.num_classes))
        self.model = nn.Sequential(*layers)

        # Derive the parameter count from the built model so it includes bias
        # terms as well as weight matrices and always matches the architecture.
        self.num_params = sum(p.numel() for p in self.model.parameters())

    def forward(self, x):
        """Return per-class log-probabilities (log_softmax over dim 1)."""
        x = self.model(x)
        return F.log_softmax(x, dim=1)

    def training_step(self, batch, batch_idx):
        """Return the negative log-likelihood loss for one training batch."""
        x, y = batch
        logits = self(x)  # already log-probabilities, as F.nll_loss expects
        loss = F.nll_loss(logits, y)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log the batch accuracy as ``val_acc`` and return the batch NLL loss."""
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        # NOTE(review): relies on torchmetrics' default ``average`` for
        # multiclass_accuracy (macro in recent releases) -- confirm that is
        # the intended definition of "accuracy" for the NAS objective.
        acc = multiclass_accuracy(preds, y, num_classes=self.num_classes)
        self.log("val_acc", acc, prog_bar=False)
        return loss

    def configure_optimizers(self):
        """Optimize all parameters with Adam at the tuned learning rate."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

    def prepare_data(self):
        """Download the MNIST train and test splits into ``data_dir``."""
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Load the train split for training and the test split for validation."""
        self.mnist_train = MNIST(self.data_dir, train=True, transform=self.transform)
        self.mnist_val = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        """Return a loader over the training split, reshuffled every epoch.

        Shuffling keeps SGD from seeing the batches in the same fixed order
        on every epoch.
        """
        return DataLoader(self.mnist_train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        """Return a loader over the validation split in its stored order."""
        return DataLoader(self.mnist_val, batch_size=self.batch_size)
129
130
131
def run_training_job():
    """Train one MnistModel configuration and report its NAS objectives.

    Fits the model described by the module-level ``args``, then writes three
    scalars to a TensorBoard logger rooted at ``args.log_path``:
    ``train_time`` (seconds spent in ``trainer.fit``), ``val_acc`` and
    ``num_params``. The same three values are printed at the end.
    """
    mnist_model = MnistModel()

    # The Trainer itself logs nothing (per-step logging makes things slow);
    # the summary metrics are written explicitly to ``logger`` below.
    trainer = Trainer(
        logger=False,
        max_epochs=args.epochs,
        enable_progress_bar=False,
        deterministic=True,  # Do we want a bit of noise?
        default_root_dir=args.log_path,
    )

    logger = pl_loggers.TensorBoardLogger(args.log_path)

    print(f"Logging to path: {args.log_path}.")

    # Train the model and log time. perf_counter is monotonic, so the
    # measurement is not skewed by system clock adjustments mid-run.
    start = time.perf_counter()
    trainer.fit(model=mnist_model)
    train_time = time.perf_counter() - start
    logger.log_metrics({"train_time": train_time})

    # Compute the validation accuracy once and log the score; whatever the
    # trainer prints during validation is captured (and discarded).
    with io.capture_output():
        val_accuracy = trainer.validate()[0]["val_acc"]
    logger.log_metrics({"val_acc": val_accuracy})

    # Read the count from the module we built: trainer.model may be a
    # strategy wrapper (e.g. under DDP) that does not expose custom attributes.
    num_params = mnist_model.num_params
    logger.log_metrics({"num_params": num_params})

    logger.save()

    # Print outputs
    print(f"train time: {train_time}, val acc: {val_accuracy}, num_params: {num_params}")
168
169
170
if __name__ == "__main__":
    # Script entry point: run exactly one training trial with the parsed
    # command-line hyperparameters.
    run_training_job()
172
173