"""1Example training code for ``ax_multiobjective_nas_tutorial.py``2"""34import argparse5import logging6import os7import sys8import time9import warnings1011import torch12from IPython.utils import io13from pytorch_lightning import LightningModule, Trainer14from pytorch_lightning import loggers as pl_loggers15from torch import nn16from torch.nn import functional as F17from torch.utils.data import DataLoader18from torchmetrics.functional.classification.accuracy import multiclass_accuracy19from torchvision import transforms20from torchvision.datasets import MNIST2122warnings.filterwarnings("ignore") # Disable data logger warnings23logging.getLogger("pytorch_lightning").setLevel(logging.ERROR) # Disable GPU/TPU prints242526def parse_args():27parser = argparse.ArgumentParser(description="train mnist")28parser.add_argument(29"--log_path", type=str, required=True, help="dir to place tensorboard logs from all trials"30)31parser.add_argument(32"--hidden_size_1", type=int, required=True, help="hidden size layer 1"33)34parser.add_argument(35"--hidden_size_2", type=int, required=True, help="hidden size layer 2"36)37parser.add_argument("--learning_rate", type=float, required=True, help="learning rate")38parser.add_argument("--epochs", type=int, required=True, help="number of epochs")39parser.add_argument("--dropout", type=float, required=True, help="dropout probability")40parser.add_argument("--batch_size", type=int, required=True, help="batch size")41return parser.parse_args()4243args = parse_args()4445PATH_DATASETS = os.environ.get("PATH_DATASETS", ".")464748class MnistModel(LightningModule):49def __init__(self):50super().__init__()5152# Tunable parameters53self.hidden_size_1 = args.hidden_size_154self.hidden_size_2 = args.hidden_size_255self.learning_rate = args.learning_rate56self.dropout = args.dropout57self.batch_size = args.batch_size5859# Set class attributes60self.data_dir = PATH_DATASETS6162# Hardcode some dataset specific attributes63self.num_classes = 1064self.dims = (1, 28, 28)65channels, width, height = self.dims66self.transform = transforms.Compose(67[68transforms.ToTensor(),69transforms.Normalize((0.1307,), (0.3081,)),70]71)7273# Create a PyTorch model74layers = [nn.Flatten()]75width = channels * width * height76hidden_layers = [self.hidden_size_1, self.hidden_size_2]77num_params = 078for hidden_size in hidden_layers:79if hidden_size > 0:80layers.append(nn.Linear(width, hidden_size))81layers.append(nn.ReLU())82layers.append(nn.Dropout(self.dropout))83num_params += width * hidden_size84width = hidden_size85layers.append(nn.Linear(width, self.num_classes))86num_params += width * self.num_classes8788# Save the model and parameter counts89self.num_params = num_params90self.model = nn.Sequential(*layers) # No need to use Relu for the last layer9192def forward(self, x):93x = self.model(x)94return F.log_softmax(x, dim=1)9596def training_step(self, batch, batch_idx):97x, y = batch98logits = self(x)99loss = F.nll_loss(logits, y)100return loss101102def validation_step(self, batch, batch_idx):103x, y = batch104logits = self(x)105loss = F.nll_loss(logits, y)106preds = torch.argmax(logits, dim=1)107acc = multiclass_accuracy(preds, y, num_classes=self.num_classes)108self.log("val_acc", acc, prog_bar=False)109return loss110111def configure_optimizers(self):112optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)113return optimizer114115def prepare_data(self):116MNIST(self.data_dir, train=True, download=True)117MNIST(self.data_dir, train=False, download=True)118119def setup(self, 

    def setup(self, stage=None):
        self.mnist_train = MNIST(self.data_dir, train=True, transform=self.transform)
        self.mnist_val = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size)


def run_training_job():

    mnist_model = MnistModel()

    # Initialize a trainer (don't log anything since things get so slow...)
    trainer = Trainer(
        logger=False,
        max_epochs=args.epochs,
        enable_progress_bar=False,
        deterministic=True,  # Do we want a bit of noise?
        default_root_dir=args.log_path,
    )

    logger = pl_loggers.TensorBoardLogger(args.log_path)

    print(f"Logging to path: {args.log_path}.")

    # Train the model and log time ⚡
    start = time.time()
    trainer.fit(model=mnist_model)
    end = time.time()
    train_time = end - start
    logger.log_metrics({"train_time": train_time})

    # Compute the validation accuracy once and log the score
    with io.capture_output() as captured:
        val_accuracy = trainer.validate()[0]["val_acc"]
    logger.log_metrics({"val_acc": val_accuracy})

    # Log the number of model parameters
    num_params = trainer.model.num_params
    logger.log_metrics({"num_params": num_params})

    logger.save()

    # Print outputs
    print(f"train time: {train_time}, val acc: {val_accuracy}, num_params: {num_params}")


if __name__ == "__main__":
    run_training_job()
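
# A minimal sketch of how this script is invoked from the command line; every
# flag is required by ``parse_args`` above, and the values here are
# illustrative placeholders, not settings taken from the tutorial:
#
#   python mnist_train_nas.py \
#       --log_path ./logs \
#       --hidden_size_1 32 \
#       --hidden_size_2 32 \
#       --learning_rate 1e-3 \
#       --epochs 1 \
#       --dropout 0.2 \
#       --batch_size 32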