GitHub Repository: pytorch/tutorials
Path: blob/main/beginner_source/examples_autograd/polynomial_custom_function.py
"""
PyTorch: Defining New autograd Functions
----------------------------------------

A third order polynomial, trained to predict :math:`y=\sin(x)` from :math:`-\pi`
to :math:`\pi` by minimizing squared Euclidean distance. Instead of writing the
polynomial as :math:`y=a+bx+cx^2+dx^3`, we write the polynomial as
:math:`y=a+b P_3(c+dx)` where :math:`P_3(x)=\\frac{1}{2}\\left(5x^3-3x\\right)` is
the `Legendre polynomial`_ of degree three.

.. _Legendre polynomial:
    https://en.wikipedia.org/wiki/Legendre_polynomials

This implementation computes the forward pass using operations on PyTorch
Tensors, and uses PyTorch autograd to compute gradients.

In this implementation we implement our own custom autograd Function: its forward
pass computes :math:`P_3(x)` and its backward pass applies the derivative
:math:`P_3'(x)=\\frac{1}{2}\\left(15x^2-3\\right)=\\frac{3}{2}\\left(5x^2-1\\right)`.
"""
import torch
import math


class LegendrePolynomial3(torch.autograd.Function):
    """
    We can implement our own custom autograd Functions by subclassing
    torch.autograd.Function and implementing the forward and backward passes
    which operate on Tensors.
    """

    @staticmethod
    def forward(ctx, input):
        """
        In the forward pass we receive a Tensor containing the input and return
        a Tensor containing the output. ctx is a context object that can be used
        to stash information for backward computation. You can cache tensors for
        use in the backward pass using the ``ctx.save_for_backward`` method. Other
        objects can be stored directly as attributes on the ctx object, such as
        ``ctx.my_object = my_object``. Check out `Extending torch.autograd <https://docs.pytorch.org/docs/stable/notes/extending.html#extending-torch-autograd>`_
        for further details.
        """
        ctx.save_for_backward(input)
        return 0.5 * (5 * input ** 3 - 3 * input)

    @staticmethod
    def backward(ctx, grad_output):
        """
        In the backward pass we receive a Tensor containing the gradient of the loss
        with respect to the output, and we need to compute the gradient of the loss
        with respect to the input.
        """
        input, = ctx.saved_tensors
        return grad_output * 1.5 * (5 * input ** 2 - 1)
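

# Added sketch (not part of the upstream tutorial): torch.autograd.gradcheck compares
# the analytic backward above against numerical finite-difference gradients. It expects
# double-precision inputs with requires_grad=True; the input size and tolerances below
# are illustrative assumptions.
gradcheck_input = torch.randn(20, dtype=torch.double, requires_grad=True)
assert torch.autograd.gradcheck(
    LegendrePolynomial3.apply, (gradcheck_input,), eps=1e-6, atol=1e-4
)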


dtype = torch.float
device = torch.device("cpu")
# device = torch.device("cuda:0")  # Uncomment this to run on GPU

# Create Tensors to hold input and outputs.
# By default, requires_grad=False, which indicates that we do not need to
# compute gradients with respect to these Tensors during the backward pass.
x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype)
y = torch.sin(x)

# Create Tensors for weights. For this example, we need 4 weights:
# y = a + b * P3(c + d * x); these weights need to be initialized not too far
# from the correct result to ensure convergence.
# Setting requires_grad=True indicates that we want to compute gradients with
# respect to these Tensors during the backward pass.
a = torch.full((), 0.0, device=device, dtype=dtype, requires_grad=True)
b = torch.full((), -1.0, device=device, dtype=dtype, requires_grad=True)
c = torch.full((), 0.0, device=device, dtype=dtype, requires_grad=True)
d = torch.full((), 0.3, device=device, dtype=dtype, requires_grad=True)

learning_rate = 5e-6
for t in range(2000):
    # To apply our Function, we use the Function.apply method. We alias this as 'P3'.
    P3 = LegendrePolynomial3.apply

    # Forward pass: compute predicted y using operations; we compute
    # P3 using our custom autograd operation.
    y_pred = a + b * P3(c + d * x)

    # Compute and print loss
    loss = (y_pred - y).pow(2).sum()
    if t % 100 == 99:
        print(t, loss.item())

    # Use autograd to compute the backward pass.
    loss.backward()

    # Update weights using gradient descent
    with torch.no_grad():
        a -= learning_rate * a.grad
        b -= learning_rate * b.grad
        c -= learning_rate * c.grad
        d -= learning_rate * d.grad

        # Manually zero the gradients after updating weights
        a.grad = None
        b.grad = None
        c.grad = None
        d.grad = None

print(f'Result: y = {a.item()} + {b.item()} * P3({c.item()} + {d.item()} x)')
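
# Added sketch (not part of the upstream tutorial): the manual no_grad() update and
# gradient zeroing above map directly onto torch.optim. As an illustration, run a few
# more steps with plain SGD on the same parameters; reusing learning_rate here is an
# assumption chosen to match the loop above.
optimizer = torch.optim.SGD([a, b, c, d], lr=learning_rate)
for t in range(200):
    y_pred = a + b * P3(c + d * x)
    loss = (y_pred - y).pow(2).sum()
    loss.backward()
    optimizer.step()       # performs p -= lr * p.grad for every parameter
    optimizer.zero_grad()  # clears p.grad so gradients do not accumulate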