# -*- coding: utf-8 -*-
# GitHub repository: pytorch/tutorials
# Path: blob/main/beginner_source/examples_nn/dynamic_net.py
"""
3
PyTorch: Control Flow + Weight Sharing
4
--------------------------------------
5
6
To showcase the power of PyTorch dynamic graphs, we will implement a very strange
7
model: a third-fifth order polynomial that on each forward pass
8
chooses a random number between 4 and 5 and uses that many orders, reusing
9
the same weights multiple times to compute the fourth and fifth order.
10
"""
import math
import random

import torch


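# The forward pass below extends the cubic with
# ``for exp in range(4, random.randint(4, 6))``. This is a minimal sketch, not
# part of the original tutorial, that checks which degrees that loop can
# produce. ``random.randint`` includes both endpoints, so the upper bound is
# 4, 5, or 6, and the polynomial has degree 3, 4, or 5. The helper name
# ``sampled_degrees`` is hypothetical.
import random


def sampled_degrees(num_draws=1000):
    """Hypothetical helper: collect the degrees produced by the loop bound."""
    degrees = set()
    for _ in range(num_draws):
        # Exponents added on top of the fixed cubic part (possibly none).
        extra_exponents = list(range(4, random.randint(4, 6)))
        # The degree is the largest exponent used, or 3 if nothing was added.
        degrees.add(max(extra_exponents, default=3))
    return degrees


print('Degrees sampled by the forward pass:', sorted(sampled_degrees()))
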
class DynamicNet(torch.nn.Module):
    def __init__(self):
        """
        In the constructor we instantiate five scalar parameters and assign them as members.
        """
        super().__init__()
        self.a = torch.nn.Parameter(torch.randn(()))
        self.b = torch.nn.Parameter(torch.randn(()))
        self.c = torch.nn.Parameter(torch.randn(()))
        self.d = torch.nn.Parameter(torch.randn(()))
        self.e = torch.nn.Parameter(torch.randn(()))

    def forward(self, x):
        """
        For the forward pass of the model, we randomly choose whether the polynomial
        has degree 3, 4, or 5, and reuse the e parameter as the coefficient of every
        fourth- and fifth-order term that gets added.

        Since each forward pass builds a dynamic computation graph, we can use normal
        Python control-flow operators like loops or conditional statements when
        defining the forward pass of the model.

        Here we also see that it is perfectly safe to reuse the same parameter many
        times when defining a computational graph: autograd sums the gradient
        contributions from every use.
        """
        y = self.a + self.b * x + self.c * x ** 2 + self.d * x ** 3
        # random.randint(4, 6) returns 4, 5, or 6 (both ends inclusive), so this
        # loop adds no extra term, the x^4 term, or the x^4 and x^5 terms.
        for exp in range(4, random.randint(4, 6)):
            y = y + self.e * x ** exp
        return y

    def string(self):
        """
        Just like any class in Python, you can also define custom methods on PyTorch modules.
        """
        return f'y = {self.a.item()} + {self.b.item()} x + {self.c.item()} x^2 + {self.d.item()} x^3 + {self.e.item()} x^4 ? + {self.e.item()} x^5 ?'
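

# The forward pass above claims it is safe to reuse ``self.e`` for several
# terms. This is a minimal sketch, not part of the original tutorial, of what
# autograd does in that case: when one parameter appears in several terms, the
# gradients from each use are added together. For L = sum(e * x**4 + e * x**5),
# dL/de = sum(x**4 + x**5). The helper name ``shared_weight_grad`` is
# hypothetical.
import torch


def shared_weight_grad(inputs):
    """Hypothetical helper: return (autograd gradient, hand-derived gradient)."""
    e = torch.nn.Parameter(torch.tensor(0.5))
    # The same parameter ``e`` multiplies both the x^4 and the x^5 term.
    out = (e * inputs ** 4 + e * inputs ** 5).sum()
    out.backward()
    # Each use of ``e`` contributes its own partial derivative; autograd sums them.
    expected = (inputs ** 4 + inputs ** 5).sum()
    return e.grad, expected


demo_grad, demo_expected = shared_weight_grad(torch.tensor([1.0, 2.0]))
# prints: autograd: 50.0 by hand: 50.0
print('autograd:', demo_grad.item(), 'by hand:', demo_expected.item())
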


# Create Tensors to hold inputs and outputs.
x = torch.linspace(-math.pi, math.pi, 2000)
y = torch.sin(x)

# Construct our model by instantiating the class defined above.
model = DynamicNet()

# Construct our loss function and an Optimizer. Training this strange model with
# vanilla stochastic gradient descent is tough, so we use momentum.
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-8, momentum=0.9)
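

# A minimal sketch, not part of the original tutorial, of the update that
# ``torch.optim.SGD(..., momentum=0.9)`` performs with its default
# ``dampening=0`` and ``nesterov=False``. It keeps a velocity buffer
# v <- momentum * v + grad and then steps p <- p - lr * v. The toy objective
# (p - 3)**2 and the helper name ``compare_momentum`` are hypothetical.
import torch


def compare_momentum(lr=0.01, momentum=0.9, steps=50):
    """Hypothetical helper: run hand-written momentum SGD next to torch.optim.SGD."""
    # Parameter updated by the library optimizer.
    p_torch = torch.nn.Parameter(torch.tensor(0.0))
    opt = torch.optim.SGD([p_torch], lr=lr, momentum=momentum)

    # Plain tensors for the hand-written version; velocity starts at zero.
    p_manual = torch.tensor(0.0)
    velocity = torch.tensor(0.0)

    for _ in range(steps):
        # Library version: zero grads, backpropagate, step.
        opt.zero_grad()
        toy_loss = (p_torch - 3.0) ** 2
        toy_loss.backward()
        opt.step()

        # Hand-written version: the gradient of (p - 3)^2 is 2 * (p - 3).
        grad = 2.0 * (p_manual - 3.0)
        velocity = momentum * velocity + grad
        p_manual = p_manual - lr * velocity

    return p_torch.item(), p_manual.item()


print('torch.optim.SGD: %.6f  by hand: %.6f' % compare_momentum())
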

for t in range(30000):
    # Forward pass: Compute predicted y by passing x to the model
    y_pred = model(x)

    # Compute and print loss
    loss = criterion(y_pred, y)
    if t % 2000 == 1999:
        print(t, loss.item())

    # Zero gradients, perform a backward pass, and update the weights.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
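

# A minimal sketch, not part of the original tutorial, of why the loop calls
# ``optimizer.zero_grad()`` before ``loss.backward()``. ``backward()`` adds
# into ``.grad`` instead of overwriting it, so without zeroing, each step would
# use the sum of every gradient computed so far. The helper name
# ``grads_with_and_without_zeroing`` is hypothetical.
import torch


def grads_with_and_without_zeroing(num_backward=3):
    """Hypothetical helper: return .grad after repeated backward calls, without and with zeroing."""
    w = torch.nn.Parameter(torch.tensor(1.0))
    # Only used for zero_grad(); step() is never called, so ``w`` never changes.
    opt = torch.optim.SGD([w], lr=0.1)

    # Without zeroing: the gradient of 2 * w (which is 2) piles up each call.
    for _ in range(num_backward):
        (2.0 * w).backward()
    accumulated = w.grad.item()

    # With zeroing before each backward pass, only the latest gradient remains.
    for _ in range(num_backward):
        opt.zero_grad()
        (2.0 * w).backward()
    fresh = w.grad.item()
    return accumulated, fresh


# prints: without zero_grad: 6.0  with zero_grad: 2.0
print('without zero_grad: %.1f  with zero_grad: %.1f' % grads_with_and_without_zeroing())
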

print(f'Result: {model.string()}')