GitHub Repository: pytorch/tutorials
Path: blob/main/beginner_source/basics/buildmodel_tutorial.py
"""
`Learn the Basics <intro.html>`_ ||
`Quickstart <quickstart_tutorial.html>`_ ||
`Tensors <tensorqs_tutorial.html>`_ ||
`Datasets & DataLoaders <data_tutorial.html>`_ ||
`Transforms <transforms_tutorial.html>`_ ||
**Build Model** ||
`Autograd <autogradqs_tutorial.html>`_ ||
`Optimization <optimization_tutorial.html>`_ ||
`Save & Load Model <saveloadrun_tutorial.html>`_

Build the Neural Network
========================

Neural networks comprise layers/modules that perform operations on data.
The `torch.nn <https://pytorch.org/docs/stable/nn.html>`_ namespace provides all the building blocks you need to
build your own neural network. Every module in PyTorch subclasses `nn.Module <https://pytorch.org/docs/stable/generated/torch.nn.Module.html>`_.
A neural network is itself a module that consists of other modules (layers). This nested structure allows for
building and managing complex architectures easily.

In the following sections, we'll build a neural network to classify images in the FashionMNIST dataset.

"""

import os
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms


#############################################
# Get Device for Training
# -----------------------
# We want to be able to train our model on a hardware accelerator like the GPU or MPS,
# if available. Let's check to see if `torch.cuda <https://pytorch.org/docs/stable/notes/cuda.html>`_
# or `torch.backends.mps <https://pytorch.org/docs/stable/notes/mps.html>`_ is available; otherwise we use the CPU.

device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")

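##############################################
# A quick aside (an addition, not in the original tutorial): tensors are created
# on the CPU by default, so inputs must be moved to the same device as the model
# with ``.to(device)`` before the forward pass.

x = torch.ones(2, 2)   # lives on the CPU by default
x = x.to(device)       # move it to the selected device
print(x.device)
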
##############################################
# Define the Class
# -------------------------
# We define our neural network by subclassing ``nn.Module``, and
# initialize the neural network layers in ``__init__``. Every ``nn.Module`` subclass implements
# the operations on input data in the ``forward`` method.

class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

##############################################
# We create an instance of ``NeuralNetwork``, move it to the ``device``, and print
# its structure.

model = NeuralNetwork().to(device)
print(model)


##############################################
# To use the model, we pass it the input data. This executes the model's ``forward``,
# along with some `background operations <https://github.com/pytorch/pytorch/blob/270111b7b611d174967ed204776985cefca9c144/torch/nn/modules/module.py#L866>`_.
# Do not call ``model.forward()`` directly!
#
# Calling the model on the input returns a 2-dimensional tensor: dim=0 indexes the samples in the minibatch,
# and dim=1 indexes the 10 raw predicted values (one per class) for each sample.
# We get the prediction probabilities by passing the logits through an instance of the ``nn.Softmax`` module.

X = torch.rand(1, 28, 28, device=device)
logits = model(X)
pred_probab = nn.Softmax(dim=1)(logits)
y_pred = pred_probab.argmax(1)
print(f"Predicted class: {y_pred}")


######################################################################
# --------------
#

##############################################
# Model Layers
# -------------------------
#
# Let's break down the layers in the FashionMNIST model. To illustrate it, we
# will take a sample minibatch of 3 images of size 28x28 and see what happens to it as
# we pass it through the network.

input_image = torch.rand(3,28,28)
print(input_image.size())

##################################################
# nn.Flatten
# ^^^^^^^^^^^^^^^^^^^^^^
# We initialize the `nn.Flatten <https://pytorch.org/docs/stable/generated/torch.nn.Flatten.html>`_
# layer to convert each 2D 28x28 image into a contiguous array of 784 pixel values
# (the minibatch dimension, at dim=0, is maintained).

flatten = nn.Flatten()
flat_image = flatten(input_image)
print(flat_image.size())

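##################################################
# By default ``nn.Flatten`` flattens dimensions 1 through -1, which is why the
# minibatch dimension survives. A quick equivalence check (an addition for
# illustration, using the documented ``torch.flatten`` function):

assert torch.equal(flat_image, torch.flatten(input_image, start_dim=1))
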
##############################################
# nn.Linear
# ^^^^^^^^^^^^^^^^^^^^^^
# The `linear layer <https://pytorch.org/docs/stable/generated/torch.nn.Linear.html>`_
# is a module that applies a linear transformation on the input using its stored weights and biases.
#
layer1 = nn.Linear(in_features=28*28, out_features=20)
hidden1 = layer1(flat_image)
print(hidden1.size())

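#################################################
# ``nn.Linear`` computes the affine map ``y = x @ W.T + b`` from its stored
# ``weight`` and ``bias`` parameters. A small sketch (an addition for
# illustration) reproducing the layer's output by hand:

manual_hidden1 = flat_image @ layer1.weight.T + layer1.bias
print(torch.allclose(hidden1, manual_hidden1, atol=1e-6))  # expect True
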
#################################################
# nn.ReLU
# ^^^^^^^^^^^^^^^^^^^^^^
# Non-linear activations are what create the complex mappings between the model's inputs and outputs.
# They are applied after linear transformations to introduce *nonlinearity*, helping neural networks
# learn a wide variety of phenomena.
#
# In this model, we use `nn.ReLU <https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html>`_ between our
# linear layers, but there are other activations to introduce non-linearity in your model.

print(f"Before ReLU: {hidden1}\n\n")
hidden1 = nn.ReLU()(hidden1)
print(f"After ReLU: {hidden1}")

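#################################################
# ``nn.ReLU`` replaces every negative entry with zero, so it matches clamping
# at zero. A tiny check (an addition for illustration):

t = torch.tensor([-1.0, 0.0, 2.0])
print(torch.equal(nn.ReLU()(t), torch.clamp(t, min=0)))  # expect True
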
#################################################
# nn.Sequential
# ^^^^^^^^^^^^^^^^^^^^^^
# `nn.Sequential <https://pytorch.org/docs/stable/generated/torch.nn.Sequential.html>`_ is an ordered
# container of modules. The data is passed through all the modules in the same order as defined. You can use
# sequential containers to put together a quick network like ``seq_modules``.

seq_modules = nn.Sequential(
    flatten,
    layer1,
    nn.ReLU(),
    nn.Linear(20, 10)
)
input_image = torch.rand(3,28,28)
logits = seq_modules(input_image)

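################################################################
# Because ``nn.Sequential`` simply calls its children in order, the same logits
# can be produced by chaining the modules by hand (a check added for
# illustration; integer indexing such as ``seq_modules[3]`` returns the
# corresponding child module):

manual_logits = seq_modules[3](nn.ReLU()(layer1(flatten(input_image))))
print(torch.allclose(logits, manual_logits))  # expect True
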
################################################################
# nn.Softmax
# ^^^^^^^^^^^^^^^^^^^^^^
# The last linear layer of the neural network returns `logits` - raw values in [-\infty, \infty] - which are passed to the
# `nn.Softmax <https://pytorch.org/docs/stable/generated/torch.nn.Softmax.html>`_ module. The logits are scaled to values in
# [0, 1] representing the model's predicted probabilities for each class. The ``dim`` parameter indicates the dimension along
# which the values must sum to 1.

softmax = nn.Softmax(dim=1)
pred_probab = softmax(logits)

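#################################################
# A quick sanity check (an addition for illustration): each row of
# ``pred_probab`` is a probability distribution that sums to 1 along ``dim=1``:

print(pred_probab.sum(dim=1))  # expect a tensor of ones, one entry per sample
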
#################################################
# Model Parameters
# -------------------------
# Many layers inside a neural network are *parameterized*, i.e. have associated weights
# and biases that are optimized during training. Subclassing ``nn.Module`` automatically
# tracks all fields defined inside your model object, and makes all parameters
# accessible using your model's ``parameters()`` or ``named_parameters()`` methods.
#
# In this example, we iterate over each parameter, and print its size and a preview of its values.
#

print(f"Model structure: {model}\n\n")

for name, param in model.named_parameters():
    print(f"Layer: {name} | Size: {param.size()} | Values : {param[:2]} \n")

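#################################################
# A common companion to this loop (an aside, not in the original tutorial) is
# counting the total number of trainable parameters with ``numel()``:

total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total trainable parameters: {total_params}")
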
######################################################################
# --------------
#

#################################################################
# Further Reading
# -----------------
# - `torch.nn API <https://pytorch.org/docs/stable/nn.html>`_