"""
`Learn the Basics <intro.html>`_ ||
`Quickstart <quickstart_tutorial.html>`_ ||
`Tensors <tensorqs_tutorial.html>`_ ||
`Datasets & DataLoaders <data_tutorial.html>`_ ||
`Transforms <transforms_tutorial.html>`_ ||
**Build Model** ||
`Autograd <autogradqs_tutorial.html>`_ ||
`Optimization <optimization_tutorial.html>`_ ||
`Save & Load Model <saveloadrun_tutorial.html>`_

Build the Neural Network
========================

Neural networks are made up of layers/modules that perform operations on data.
The `torch.nn <https://pytorch.org/docs/stable/nn.html>`_ namespace provides all the building blocks you need to
build your own neural network. Every module in PyTorch subclasses `nn.Module <https://pytorch.org/docs/stable/generated/torch.nn.Module.html>`_.
A neural network is itself a module that consists of other modules (layers). This nested structure allows for
building and managing complex architectures easily.

In the following sections, we'll build a neural network to classify images in the FashionMNIST dataset.

"""

import os
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

#############################################
# Get Device for Training
# -----------------------
# We want to be able to train our model on an `accelerator <https://pytorch.org/docs/stable/torch.html#accelerators>`__
# such as CUDA, MPS, MTIA, or XPU. If the current accelerator is available, we will use it. Otherwise, we use the CPU.

device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
print(f"Using {device} device")
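
##############################################
# ``torch.accelerator`` is only available in recent PyTorch releases. On older
# versions, a similar choice can be sketched with backend-specific checks; the
# ``fallback_device`` name below is just for illustration, and the rest of the
# tutorial keeps using ``device`` from above.

fallback_device = (
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Fallback device: {fallback_device}")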

##############################################
# Define the Class
# -------------------------
# We define our neural network by subclassing ``nn.Module``, and
# initialize the neural network layers in ``__init__``. Every ``nn.Module`` subclass implements
# the operations on input data in the ``forward`` method.

class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

##############################################
# We create an instance of ``NeuralNetwork``, move it to the ``device``, and print
# its structure.

model = NeuralNetwork().to(device)
print(model)

##############################################
# To use the model, we pass it the input data. This executes the model's ``forward``,
# along with some `background operations <https://github.com/pytorch/pytorch/blob/270111b7b611d174967ed204776985cefca9c144/torch/nn/modules/module.py#L866>`_.
# Do not call ``model.forward()`` directly!
#
# Calling the model on the input returns a 2-dimensional tensor, with dim=0 indexing the samples
# (one set of 10 raw predicted values, one per class, for each sample) and dim=1 indexing the
# individual values within each output.
# We get the prediction probabilities by passing the logits through an instance of the ``nn.Softmax`` module.

X = torch.rand(1, 28, 28, device=device)
logits = model(X)
pred_probab = nn.Softmax(dim=1)(logits)
y_pred = pred_probab.argmax(1)
print(f"Predicted class: {y_pred}")
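
##############################################
# As a quick check of the shapes described above: the call returns one row of 10
# raw scores per input sample, and the softmax output keeps the same shape.

print(logits.shape)        # torch.Size([1, 10]): dim=0 is the sample, dim=1 the class scores
print(pred_probab.shape)   # same shape, but each row is now a probability distribution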


######################################################################
# --------------
#


##############################################
# Model Layers
# -------------------------
#
# Let's break down the layers in the FashionMNIST model. To illustrate this, we
# will take a sample minibatch of 3 images of size 28x28 and see what happens to it as
# we pass it through the network.

input_image = torch.rand(3,28,28)
print(input_image.size())

##################################################
# nn.Flatten
# ^^^^^^^^^^^^^^^^^^^^^^
# We initialize the `nn.Flatten <https://pytorch.org/docs/stable/generated/torch.nn.Flatten.html>`_
# layer to convert each 2D 28x28 image into a contiguous array of 784 pixel values (
# the minibatch dimension (at dim=0) is maintained).

flatten = nn.Flatten()
flat_image = flatten(input_image)
print(flat_image.size())
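
##################################################
# ``nn.Flatten`` keeps dim=0 because its defaults are ``start_dim=1, end_dim=-1``;
# the functional form below is one equivalent way to get the same result, shown
# purely as a cross-check.

also_flat = torch.flatten(input_image, start_dim=1)
print(also_flat.shape == flat_image.shape)  # True: both are (3, 784)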

##############################################
# nn.Linear
# ^^^^^^^^^^^^^^^^^^^^^^
# The `linear layer <https://pytorch.org/docs/stable/generated/torch.nn.Linear.html>`_
# is a module that applies a linear transformation on the input using its stored weights and biases.
#
layer1 = nn.Linear(in_features=28*28, out_features=20)
hidden1 = layer1(flat_image)
print(hidden1.size())
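
#################################################
# Under the hood this is just the affine map ``x @ W.T + b``; a minimal
# cross-check against the layer's stored parameters:

manual = flat_image @ layer1.weight.T + layer1.bias
print(torch.allclose(manual, hidden1))  # True, up to floating-point tolerance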


#################################################
# nn.ReLU
# ^^^^^^^^^^^^^^^^^^^^^^
# Non-linear activations are what create the complex mappings between the model's inputs and outputs.
# They are applied after linear transformations to introduce *nonlinearity*, helping neural networks
# learn a wide variety of phenomena.
#
# In this model, we use `nn.ReLU <https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html>`_ between our
# linear layers, but there are other activations that introduce non-linearity in your model (a few are sketched below).

print(f"Before ReLU: {hidden1}\n\n")
hidden1 = nn.ReLU()(hidden1)
print(f"After ReLU: {hidden1}")
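
#################################################
# A few of the other built-in activations, applied here to the activated tensor
# from above purely to show the call pattern (which activation works best is
# model- and task-dependent):

print(f"GELU: {nn.GELU()(hidden1)[:1]}")
print(f"Sigmoid: {nn.Sigmoid()(hidden1)[:1]}")
print(f"Tanh: {nn.Tanh()(hidden1)[:1]}")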



#################################################
# nn.Sequential
# ^^^^^^^^^^^^^^^^^^^^^^
# `nn.Sequential <https://pytorch.org/docs/stable/generated/torch.nn.Sequential.html>`_ is an ordered
# container of modules. The data is passed through all the modules in the same order as defined. You can use
# sequential containers to put together a quick network like ``seq_modules``.

seq_modules = nn.Sequential(
    flatten,
    layer1,
    nn.ReLU(),
    nn.Linear(20, 10)
)
input_image = torch.rand(3,28,28)
logits = seq_modules(input_image)
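
################################################################
# ``nn.Sequential`` also accepts an ``OrderedDict``, which gives each submodule a
# readable name in the printed structure; a small sketch of the same stack:

from collections import OrderedDict

named_seq = nn.Sequential(OrderedDict([
    ("flatten", nn.Flatten()),
    ("hidden", nn.Linear(28*28, 20)),
    ("relu", nn.ReLU()),
    ("output", nn.Linear(20, 10)),
]))
print(named_seq)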

################################################################
# nn.Softmax
# ^^^^^^^^^^^^^^^^^^^^^^
# The last linear layer of the neural network returns `logits` - raw values in [-\infty, \infty] - which are passed to the
# `nn.Softmax <https://pytorch.org/docs/stable/generated/torch.nn.Softmax.html>`_ module. The logits are scaled to values in
# [0, 1] representing the model's predicted probabilities for each class. The ``dim`` parameter indicates the dimension along
# which the values must sum to 1.

softmax = nn.Softmax(dim=1)
pred_probab = softmax(logits)
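
################################################################
# A quick check that the probabilities behave as described: every row of
# ``pred_probab`` should sum to 1.

print(pred_probab.sum(dim=1))  # a tensor of ones, one entry per sample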


#################################################
# Model Parameters
# -------------------------
# Many layers inside a neural network are *parameterized*, i.e. have associated weights
# and biases that are optimized during training. Subclassing ``nn.Module`` automatically
# tracks all fields defined inside your model object, and makes all parameters
# accessible using your model's ``parameters()`` or ``named_parameters()`` methods.
#
# In this example, we iterate over each parameter, and print its size and a preview of its values.
#

print(f"Model structure: {model}\n\n")

for name, param in model.named_parameters():
    print(f"Layer: {name} | Size: {param.size()} | Values : {param[:2]} \n")
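
#################################################
# A compact way to see how big the model is, using the same ``parameters()`` API:

total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total parameters: {total_params} | Trainable: {trainable_params}")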

######################################################################
# --------------
#

#################################################################
# Further Reading
# -----------------
# - `torch.nn API <https://pytorch.org/docs/stable/nn.html>`_