Path: blob/main/beginner_source/examples_autograd/polynomial_custom_function.py
"""1PyTorch: Defining New autograd Functions2----------------------------------------34A third order polynomial, trained to predict :math:`y=\sin(x)` from :math:`-\pi`5to :math:`\pi` by minimizing squared Euclidean distance. Instead of writing the6polynomial as :math:`y=a+bx+cx^2+dx^3`, we write the polynomial as7:math:`y=a+b P_3(c+dx)` where :math:`P_3(x)=\\frac{1}{2}\\left(5x^3-3x\\right)` is8the `Legendre polynomial`_ of degree three.910.. _Legendre polynomial:11https://en.wikipedia.org/wiki/Legendre_polynomials1213This implementation computes the forward pass using operations on PyTorch14Tensors, and uses PyTorch autograd to compute gradients.1516In this implementation we implement our own custom autograd function to perform17:math:`P_3'(x)`. By mathematics, :math:`P_3'(x)=\\frac{3}{2}\\left(5x^2-1\\right)`18"""19import torch20import math212223class LegendrePolynomial3(torch.autograd.Function):24"""25We can implement our own custom autograd Functions by subclassing26torch.autograd.Function and implementing the forward and backward passes27which operate on Tensors.28"""2930@staticmethod31def forward(ctx, input):32"""33In the forward pass we receive a Tensor containing the input and return34a Tensor containing the output. ctx is a context object that can be used35to stash information for backward computation. You can cache tensors for36use in the backward pass using the ``ctx.save_for_backward`` method. Other37objects can be stored directly as attributes on the ctx object, such as38``ctx.my_object = my_object``. Check out `Extending torch.autograd <https://docs.pytorch.org/docs/stable/notes/extending.html#extending-torch-autograd>`_39for further details.40"""41ctx.save_for_backward(input)42return 0.5 * (5 * input ** 3 - 3 * input)4344@staticmethod45def backward(ctx, grad_output):46"""47In the backward pass we receive a Tensor containing the gradient of the loss48with respect to the output, and we need to compute the gradient of the loss49with respect to the input.50"""51input, = ctx.saved_tensors52return grad_output * 1.5 * (5 * input ** 2 - 1)535455dtype = torch.float56device = torch.device("cpu")57# device = torch.device("cuda:0") # Uncomment this to run on GPU5859# Create Tensors to hold input and outputs.60# By default, requires_grad=False, which indicates that we do not need to61# compute gradients with respect to these Tensors during the backward pass.62x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype)63y = torch.sin(x)6465# Create random Tensors for weights. For this example, we need66# 4 weights: y = a + b * P3(c + d * x), these weights need to be initialized67# not too far from the correct result to ensure convergence.68# Setting requires_grad=True indicates that we want to compute gradients with69# respect to these Tensors during the backward pass.70a = torch.full((), 0.0, device=device, dtype=dtype, requires_grad=True)71b = torch.full((), -1.0, device=device, dtype=dtype, requires_grad=True)72c = torch.full((), 0.0, device=device, dtype=dtype, requires_grad=True)73d = torch.full((), 0.3, device=device, dtype=dtype, requires_grad=True)7475learning_rate = 5e-676for t in range(2000):77# To apply our Function, we use Function.apply method. 
We alias this as 'P3'.78P3 = LegendrePolynomial3.apply7980# Forward pass: compute predicted y using operations; we compute81# P3 using our custom autograd operation.82y_pred = a + b * P3(c + d * x)8384# Compute and print loss85loss = (y_pred - y).pow(2).sum()86if t % 100 == 99:87print(t, loss.item())8889# Use autograd to compute the backward pass.90loss.backward()9192# Update weights using gradient descent93with torch.no_grad():94a -= learning_rate * a.grad95b -= learning_rate * b.grad96c -= learning_rate * c.grad97d -= learning_rate * d.grad9899# Manually zero the gradients after updating weights100a.grad = None101b.grad = None102c.grad = None103d.grad = None104105print(f'Result: y = {a.item()} + {b.item()} * P3({c.item()} + {d.item()} x)')106107108
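
# A possible sanity check, not part of the original tutorial: for a custom
# autograd Function, ``torch.autograd.gradcheck`` compares the hand-written
# backward against numerically estimated gradients. It requires double-precision
# inputs with requires_grad=True; the tensor name and tolerances below are
# illustrative assumptions.
check_input = torch.randn(32, dtype=torch.double, requires_grad=True)
# Prints True if the analytical and numerical gradients of P3 agree within
# tolerance; by default gradcheck raises an error if they do not.
print(torch.autograd.gradcheck(LegendrePolynomial3.apply, (check_input,),
                               eps=1e-6, atol=1e-4))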
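
# As an aside, also not part of the original example: the manual weight updates
# and gradient zeroing above are exactly what ``torch.optim.SGD`` does for you.
# A minimal sketch of the same training loop with an optimizer, on fresh
# weights, is shown below; the variable names are illustrative.
a2 = torch.full((), 0.0, device=device, dtype=dtype, requires_grad=True)
b2 = torch.full((), -1.0, device=device, dtype=dtype, requires_grad=True)
c2 = torch.full((), 0.0, device=device, dtype=dtype, requires_grad=True)
d2 = torch.full((), 0.3, device=device, dtype=dtype, requires_grad=True)
optimizer = torch.optim.SGD([a2, b2, c2, d2], lr=learning_rate)
for t in range(2000):
    y_pred = a2 + b2 * P3(c2 + d2 * x)
    loss = (y_pred - y).pow(2).sum()
    optimizer.zero_grad()  # clear gradients from the previous iteration
    loss.backward()        # autograd computes gradients, including through P3
    optimizer.step()       # apply the gradient-descent update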