GitHub Repository: pytorch/tutorials
Path: blob/main/beginner_source/examples_autograd/polynomial_custom_function.py
# -*- coding: utf-8 -*-
"""
PyTorch: Defining New autograd Functions
----------------------------------------

A third-order polynomial, trained to predict :math:`y=\\sin(x)` from :math:`-\\pi`
to :math:`\\pi` by minimizing the squared Euclidean distance. Instead of writing the
polynomial as :math:`y=a+bx+cx^2+dx^3`, we write it as
:math:`y=a+b P_3(c+dx)`, where :math:`P_3(x)=\\frac{1}{2}\\left(5x^3-3x\\right)` is
the `Legendre polynomial`_ of degree three.

.. _Legendre polynomial:
    https://en.wikipedia.org/wiki/Legendre_polynomials

This implementation computes the forward pass using operations on PyTorch
Tensors, and uses PyTorch autograd to compute gradients.

In this implementation we define our own custom autograd function that computes
:math:`P_3(x)` in the forward pass and its derivative :math:`P_3'(x)` in the
backward pass. Differentiating :math:`P_3` gives
:math:`P_3'(x)=\\frac{3}{2}\\left(5x^2-1\\right)`.
"""
import torch
import math


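# Sanity check (a minimal sketch added for illustration, not part of the
# upstream tutorial): the derivative formula P3'(x) = 1.5 * (5x^2 - 1) from
# the docstring can be confirmed by letting autograd differentiate the
# plain-Tensor expression for P3. The names check_points, autograd_grad and
# closed_form_grad are hypothetical.
import torch

check_points = torch.linspace(-1.0, 1.0, 11, dtype=torch.double, requires_grad=True)
p3_values = 0.5 * (5 * check_points ** 3 - 3 * check_points)
# Each P3 value depends only on its own point, so the gradient of the sum
# with respect to check_points is P3' evaluated elementwise.
(autograd_grad,) = torch.autograd.grad(p3_values.sum(), check_points)
closed_form_grad = 1.5 * (5 * check_points.detach() ** 2 - 1)
assert torch.allclose(autograd_grad, closed_form_grad)

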
class LegendrePolynomial3(torch.autograd.Function):
    """
    We can implement our own custom autograd Functions by subclassing
    torch.autograd.Function and implementing the forward and backward passes,
    which operate on Tensors.
    """

    @staticmethod
    def forward(ctx, input):
        """
        In the forward pass we receive a Tensor containing the input and return
        a Tensor containing the output. ctx is a context object that can be used
        to stash information for backward computation. Tensors needed in the
        backward pass should be saved with the ctx.save_for_backward method;
        other (non-Tensor) objects can be stored directly as attributes on ctx.
        """
        ctx.save_for_backward(input)
        return 0.5 * (5 * input ** 3 - 3 * input)

    @staticmethod
    def backward(ctx, grad_output):
        """
        In the backward pass we receive a Tensor containing the gradient of the loss
        with respect to the output, and we need to compute the gradient of the loss
        with respect to the input. By the chain rule, this is grad_output
        multiplied by P3'(input).
        """
        input, = ctx.saved_tensors
        return grad_output * 1.5 * (5 * input ** 2 - 1)
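

# A minimal sketch, assuming you want to pass a non-Tensor argument to a
# custom Function (not part of the upstream tutorial). The class
# ScaledLegendrePolynomial3 and its `scale` argument are hypothetical.
# Tensors go through ctx.save_for_backward; plain Python values can be stored
# as attributes on ctx. backward must return one gradient per forward input,
# using None for inputs that need no gradient (here, `scale`).
# torch.autograd.gradcheck compares the analytic backward against finite
# differences and works best in double precision. The same call should also
# work for the class above:
#     torch.autograd.gradcheck(LegendrePolynomial3.apply, (gradcheck_input,))
import torch


class ScaledLegendrePolynomial3(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp, scale):
        ctx.save_for_backward(inp)  # Tensor: use save_for_backward
        ctx.scale = scale           # non-Tensor: store as a ctx attribute
        return scale * 0.5 * (5 * inp ** 3 - 3 * inp)

    @staticmethod
    def backward(ctx, grad_output):
        inp, = ctx.saved_tensors
        grad_inp = grad_output * ctx.scale * 1.5 * (5 * inp ** 2 - 1)
        return grad_inp, None  # no gradient for the Python float `scale`


gradcheck_input = torch.randn(8, dtype=torch.double, requires_grad=True)
gradcheck_passed = torch.autograd.gradcheck(
    lambda t: ScaledLegendrePolynomial3.apply(t, 2.0), (gradcheck_input,)
)
print("gradcheck passed:", gradcheck_passed)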


dtype = torch.float
device = torch.device("cpu")
# device = torch.device("cuda:0")  # Uncomment this to run on GPU

# Create Tensors to hold input and outputs.
# By default, requires_grad=False, which indicates that we do not need to
# compute gradients with respect to these Tensors during the backward pass.
x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype)
y = torch.sin(x)

# Create Tensors for the weights. For this example, we need
# 4 weights: y = a + b * P3(c + d * x). These weights need to be initialized
# not too far from the correct result to ensure convergence.
# Setting requires_grad=True indicates that we want to compute gradients with
# respect to these Tensors during the backward pass.
a = torch.full((), 0.0, device=device, dtype=dtype, requires_grad=True)
b = torch.full((), -1.0, device=device, dtype=dtype, requires_grad=True)
c = torch.full((), 0.0, device=device, dtype=dtype, requires_grad=True)
d = torch.full((), 0.3, device=device, dtype=dtype, requires_grad=True)

learning_rate = 5e-6
for t in range(2000):
    # To apply our Function, we use the Function.apply method. We alias this as 'P3'.
    P3 = LegendrePolynomial3.apply

    # Forward pass: compute predicted y using Tensor operations; we compute
    # P3 using our custom autograd operation.
    y_pred = a + b * P3(c + d * x)

    # Compute and print loss
    loss = (y_pred - y).pow(2).sum()
    if t % 100 == 99:
        print(t, loss.item())

    # Use autograd to compute the backward pass. For the P3 node, autograd
    # calls LegendrePolynomial3.backward.
    loss.backward()

    # Update weights using gradient descent
    with torch.no_grad():
        a -= learning_rate * a.grad
        b -= learning_rate * b.grad
        c -= learning_rate * c.grad
        d -= learning_rate * d.grad

        # Manually zero the gradients after updating weights
        a.grad = None
        b.grad = None
        c.grad = None
        d.grad = None

print(f'Result: y = {a.item()} + {b.item()} * P3({c.item()} + {d.item()} x)')
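

# A hedged sketch (not part of the upstream tutorial) of the same fit written
# two other ways. First, P3 is expressed with ordinary Tensor operations, so
# autograd derives P3'(x) by itself and no custom Function is required.
# Second, the manual update is replaced by torch.optim.SGD, which with its
# defaults (no momentum, no weight decay) applies the same step
# w -= lr * w.grad, so it should reach essentially the same weights as the
# loop above. The function fit_with_sgd is hypothetical; it assumes steps >= 1.
import math

import torch


def fit_with_sgd(steps=2000, lr=5e-6):
    xs = torch.linspace(-math.pi, math.pi, 2000, dtype=torch.float)
    ys = torch.sin(xs)
    # Same starting values as the tutorial: a=0, b=-1, c=0, d=0.3.
    params = [torch.tensor(v, requires_grad=True) for v in (0.0, -1.0, 0.0, 0.3)]
    optimizer = torch.optim.SGD(params, lr=lr)
    first_loss = None
    for _ in range(steps):
        w_a, w_b, w_c, w_d = params
        z = w_c + w_d * xs
        pred = w_a + w_b * 0.5 * (5 * z ** 3 - 3 * z)  # plain-Tensor P3(z)
        step_loss = (pred - ys).pow(2).sum()
        if first_loss is None:
            first_loss = step_loss.item()
        optimizer.zero_grad()  # clears .grad before the new backward pass
        step_loss.backward()
        optimizer.step()
    return first_loss, step_loss.item(), [p.item() for p in params]


initial_loss, final_loss, sgd_weights = fit_with_sgd()
print(f"SGD fit: loss {initial_loss:.2f} -> {final_loss:.2f}, weights {sgd_weights}")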