GitHub Repository: pytorch/tutorials
Path: blob/main/beginner_source/examples_nn/polynomial_module.py
# -*- coding: utf-8 -*-
r"""
PyTorch: Custom nn Modules
--------------------------

A third-order polynomial, trained to predict :math:`y=\sin(x)` from :math:`-\pi`
to :math:`\pi` by minimizing squared Euclidean distance.

This implementation defines the model as a custom Module subclass. Whenever you
want a model more complex than a simple sequence of existing Modules, you will
need to define your model this way.
"""
import torch
import math

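# For contrast with the docstring's point about "a simple sequence of existing
# Modules": when a model really is just a chain of existing layers,
# torch.nn.Sequential is enough and no subclass is needed. This is a minimal
# sketch, assuming we precompute the features (x, x^2, x^3) so that
# Linear(3, 1) computes bias + w1*x + w2*x^2 + w3*x^3, the same cubic family as
# Polynomial3. The names _seq_x, _seq_features, _seq_model, _seq_pred and
# _seq_num_params are hypothetical, chosen so they do not clash with the
# tutorial's own variables.
import math
import torch

_seq_x = torch.linspace(-math.pi, math.pi, 2000)
# Stack x, x^2, x^3 as columns via broadcasting: shape (2000, 3).
_seq_features = _seq_x.unsqueeze(-1).pow(torch.tensor([1, 2, 3]))
_seq_model = torch.nn.Sequential(
    torch.nn.Linear(3, 1),   # 3 weights (playing b, c, d) plus 1 bias (playing a)
    torch.nn.Flatten(0, 1),  # (2000, 1) -> (2000,) so the output matches sin(x)
)
_seq_pred = _seq_model(_seq_features)
# Same number of learnable scalars as Polynomial3: four.
_seq_num_params = sum(p.numel() for p in _seq_model.parameters())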
class Polynomial3(torch.nn.Module):
    def __init__(self):
        """
        In the constructor we instantiate four scalar parameters and assign them
        as attributes. nn.Module registers every nn.Parameter attribute, so
        model.parameters() returns all four.
        """
        super().__init__()
        self.a = torch.nn.Parameter(torch.randn(()))
        self.b = torch.nn.Parameter(torch.randn(()))
        self.c = torch.nn.Parameter(torch.randn(()))
        self.d = torch.nn.Parameter(torch.randn(()))

    def forward(self, x):
        """
        In the forward function we accept a Tensor of input data and we must return
        a Tensor of output data. We can use Modules defined in the constructor as
        well as arbitrary operators on Tensors.
        """
        return self.a + self.b * x + self.c * x ** 2 + self.d * x ** 3

    def string(self):
        """
        Just like any class in Python, you can also define custom methods on PyTorch modules.
        """
        return f'y = {self.a.item()} + {self.b.item()} x + {self.c.item()} x^2 + {self.d.item()} x^3'

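# The constructor above relies on nn.Module bookkeeping: assigning a
# torch.nn.Parameter as an attribute registers it, so it appears in
# parameters() and named_parameters(), while a plain tensor attribute does not.
# Calling a module instance, as in model(x), goes through nn.Module.__call__,
# which runs any registered hooks and then dispatches to forward(). A minimal
# sketch; the class _RegistrationDemo and the _demo* names are hypothetical and
# exist only for this illustration.
import torch


class _RegistrationDemo(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.w = torch.nn.Parameter(torch.randn(()))  # registered, learnable
        # A plain tensor attribute is NOT registered. To keep a non-learnable
        # tensor in state_dict and move it with .to(device), register_buffer
        # is the usual route instead.
        self.scale = torch.tensor(2.0)

    def forward(self, x):
        return self.scale * self.w * x


_demo = _RegistrationDemo()
_demo_param_names = [name for name, _ in _demo.named_parameters()]  # ['w']
# _demo(x) and _demo.forward(x) compute the same values; prefer _demo(x) so hooks run.
_demo_input = torch.ones(3)
_demo_same = torch.equal(_demo(_demo_input), _demo.forward(_demo_input))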
# Create Tensors to hold the inputs and outputs.
x = torch.linspace(-math.pi, math.pi, 2000)
y = torch.sin(x)

# Construct our model by instantiating the class defined above.
model = Polynomial3()

# Construct our loss function and an Optimizer. The call to model.parameters()
# in the SGD constructor returns the learnable parameters of the model (the
# attributes defined with torch.nn.Parameter).
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-6)
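# About the loss above: MSELoss(reduction='sum') returns the plain sum of
# squared errors, sum_i (y_pred_i - y_i)^2, i.e. the squared Euclidean distance
# the docstring mentions. The default reduction='mean' divides that sum by the
# number of elements, so with 2000 points the summed loss and its gradients are
# 2000 times larger. An SGD step on the summed loss with lr=1e-6 therefore
# moves the parameters as far (up to rounding) as a step on the mean loss with
# lr=2e-3. A minimal numeric check, using hypothetical random tensors _pred and
# _target rather than the tutorial's data:
import torch

_pred = torch.randn(2000)
_target = torch.randn(2000)
_sum_loss = torch.nn.MSELoss(reduction='sum')(_pred, _target)
_mean_loss = torch.nn.MSELoss()(_pred, _target)  # default reduction='mean'
_manual_sum = ((_pred - _target) ** 2).sum()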
for t in range(2000):
    # Forward pass: Compute predicted y by passing x to the model
    y_pred = model(x)

    # Compute and print loss
    loss = criterion(y_pred, y)
    if t % 100 == 99:
        print(t, loss.item())

    # Zero gradients, perform a backward pass, and update the weights.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
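# What the three calls in the loop do, spelled out by hand: zero_grad() clears
# the .grad left from the previous iteration (backward() accumulates into .grad
# rather than overwriting it), backward() fills p.grad with d(loss)/dp, and for
# plain SGD (no momentum, no weight decay) step() applies p -= lr * p.grad.
# A minimal sketch comparing a hand-written update with torch.optim.SGD on a
# single hypothetical parameter; the _w_*, _lr, _opt and _loss_* names and the
# toy loss (2w - 1)^2 are illustrative assumptions, not part of the tutorial.
import torch

_lr = 0.1
_w_manual = torch.nn.Parameter(torch.tensor(1.5))
_w_optim = torch.nn.Parameter(torch.tensor(1.5))
_opt = torch.optim.SGD([_w_optim], lr=_lr)

for _ in range(3):
    # Manual version: the update itself must not be tracked by autograd.
    _loss_manual = (_w_manual * 2.0 - 1.0) ** 2
    _loss_manual.backward()
    with torch.no_grad():
        _w_manual -= _lr * _w_manual.grad
    _w_manual.grad = None  # same effect as zero_grad(set_to_none=True)

    # Optimizer version: the same three calls as the training loop.
    _opt.zero_grad()
    _loss_optim = (_w_optim * 2.0 - 1.0) ** 2
    _loss_optim.backward()
    _opt.step()

# Both paths take w from 1.5 to 0.7, 0.54 and then 0.508, heading toward the
# minimiser w = 0.5.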

print(f'Result: {model.string()}')
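# A cross-check that is not part of the original tutorial: the model is linear
# in its four coefficients, so the minimiser of the summed squared error has a
# closed form, which torch.linalg.lstsq computes directly. Gradient descent
# with a small enough learning rate should move the coefficients toward this
# solution, so the printed result above can be compared with it. A minimal
# sketch; the _ls_* names are hypothetical.
import math
import torch

_ls_x = torch.linspace(-math.pi, math.pi, 2000)
_ls_y = torch.sin(_ls_x)
# Design matrix with columns 1, x, x^2, x^3, matching a, b, c, d in Polynomial3.
_ls_A = torch.stack([torch.ones_like(_ls_x), _ls_x, _ls_x ** 2, _ls_x ** 3], dim=1)
_ls_coef = torch.linalg.lstsq(_ls_A, _ls_y.unsqueeze(1)).solution.squeeze(1)
print('Least-squares optimum: y = {:.4f} + {:.4f} x + {:.4f} x^2 + {:.4f} x^3'.format(
    *_ls_coef.tolist()))
# By symmetry the even coefficients (a, c) are close to 0; b is roughly 0.857
# and d roughly -0.093.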