Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place.
Path: blob/main/beginner_source/examples_nn/dynamic_net.py
# -*- coding: utf-8 -*-
"""
PyTorch: Control Flow + Weight Sharing
--------------------------------------

To showcase the power of PyTorch dynamic graphs, we will implement a very strange
model: a third-to-fifth order polynomial that on each forward pass randomly
chooses its degree (3, 4, or 5), reusing the same weight multiple times to
compute the fourth- and fifth-order contributions.
"""
import random
import torch
import math


class DynamicNet(torch.nn.Module):
    def __init__(self):
        """
        In the constructor we instantiate five parameters and assign them as members.
        """
        super().__init__()
        self.a = torch.nn.Parameter(torch.randn(()))
        self.b = torch.nn.Parameter(torch.randn(()))
        self.c = torch.nn.Parameter(torch.randn(()))
        self.d = torch.nn.Parameter(torch.randn(()))
        self.e = torch.nn.Parameter(torch.randn(()))

    def forward(self, x):
        """
        For the forward pass of the model, we randomly choose the polynomial
        degree (3, 4, or 5) and reuse the e parameter to compute the
        contribution of the fourth and fifth orders.

        Since each forward pass builds a dynamic computation graph, we can use normal
        Python control-flow operators like loops or conditional statements when
        defining the forward pass of the model.

        Here we also see that it is perfectly safe to reuse the same parameter many
        times when defining a computational graph.
        """
        y = self.a + self.b * x + self.c * x ** 2 + self.d * x ** 3
        # random.randint(4, 6) returns 4, 5, or 6; since range() excludes its
        # upper bound, the loop adds zero, one, or two extra terms (x^4, x^5).
        for exp in range(4, random.randint(4, 6)):
            y = y + self.e * x ** exp
        return y

    def string(self):
        """
        Just like any class in Python, you can also define custom methods on PyTorch modules.
        """
        return f'y = {self.a.item()} + {self.b.item()} x + {self.c.item()} x^2 + {self.d.item()} x^3 + {self.e.item()} x^4 ? + {self.e.item()} x^5 ?'


# Create Tensors to hold input and outputs.
x = torch.linspace(-math.pi, math.pi, 2000)
y = torch.sin(x)

# Construct our model by instantiating the class defined above.
model = DynamicNet()

# Construct our loss function and an optimizer. Training this strange model with
# vanilla stochastic gradient descent is tough, so we use momentum.
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-8, momentum=0.9)
for t in range(30000):
    # Forward pass: compute predicted y by passing x to the model.
    y_pred = model(x)

    # Compute and print loss.
    loss = criterion(y_pred, y)
    if t % 2000 == 1999:
        print(t, loss.item())

    # Zero gradients, perform a backward pass, and update the weights.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

print(f'Result: {model.string()}')
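Because forward re-samples the degree on every call, the same untrained model can return different values for the same input. A minimal sketch of that behavior, assuming the DynamicNet class above is in scope (the probe value 2.0 is arbitrary):

# Sketch: DynamicNet re-samples its polynomial degree per forward pass,
# so repeated evaluations at the same point can disagree (before training).
probe = torch.tensor([2.0])
net = DynamicNet()
with torch.no_grad():
    for _ in range(5):
        print(net(probe).item())  # may print up to three distinct values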
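After training, one quick sanity check is to compare predictions against sin(x) at a few probe points. Since the model's output is random, averaging over many forward passes smooths out the degree choice; a hedged sketch, assuming the training loop above has finished:

# Sketch: compare the trained model with sin(x) at a few probe points.
# Averaging over 100 forward passes smooths out the random degree choice.
test_x = torch.tensor([-math.pi / 2, 0.0, math.pi / 2])
with torch.no_grad():
    avg_pred = sum(model(test_x) for _ in range(100)) / 100
print('target:   ', torch.sin(test_x))
print('predicted:', avg_pred)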