Path: blob/main/beginner_source/examples_autograd/polynomial_custom_function.py
# -*- coding: utf-8 -*-
"""
PyTorch: Defining New autograd Functions
----------------------------------------

A third order polynomial, trained to predict :math:`y=\sin(x)` from :math:`-\pi`
to :math:`\pi` by minimizing squared Euclidean distance. Instead of writing the
polynomial as :math:`y=a+bx+cx^2+dx^3`, we write the polynomial as
:math:`y=a+b P_3(c+dx)` where :math:`P_3(x)=\\frac{1}{2}\\left(5x^3-3x\\right)` is
the `Legendre polynomial`_ of degree three.

.. _Legendre polynomial:
    https://en.wikipedia.org/wiki/Legendre_polynomials

This implementation computes the forward pass using operations on PyTorch
Tensors, and uses PyTorch autograd to compute gradients.

In this implementation we define our own custom autograd function to compute
:math:`P_3`, and hand-code its backward pass using the derivative
:math:`P_3'(x)=\\frac{3}{2}\\left(5x^2-1\\right)`.
"""
import torch
import math


class LegendrePolynomial3(torch.autograd.Function):
    """
    We can implement our own custom autograd Functions by subclassing
    torch.autograd.Function and implementing the forward and backward passes
    which operate on Tensors.
    """

    @staticmethod
    def forward(ctx, input):
        """
        In the forward pass we receive a Tensor containing the input and return
        a Tensor containing the output. ctx is a context object that can be used
        to stash information for backward computation. You can cache arbitrary
        objects for use in the backward pass using the ctx.save_for_backward method.
        """
        ctx.save_for_backward(input)
        return 0.5 * (5 * input ** 3 - 3 * input)

    @staticmethod
    def backward(ctx, grad_output):
        """
        In the backward pass we receive a Tensor containing the gradient of the loss
        with respect to the output, and we need to compute the gradient of the loss
        with respect to the input.
        """
        input, = ctx.saved_tensors
        return grad_output * 1.5 * (5 * input ** 2 - 1)


dtype = torch.float
device = torch.device("cpu")
# device = torch.device("cuda:0")  # Uncomment this to run on GPU

# Create Tensors to hold input and outputs.
# By default, requires_grad=False, which indicates that we do not need to
# compute gradients with respect to these Tensors during the backward pass.
x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype)
y = torch.sin(x)

# Create Tensors for the weights. For this example, we need
# 4 weights: y = a + b * P3(c + d * x). These weights need to be initialized
# not too far from the correct result to ensure convergence.
# Setting requires_grad=True indicates that we want to compute gradients with
# respect to these Tensors during the backward pass.
a = torch.full((), 0.0, device=device, dtype=dtype, requires_grad=True)
b = torch.full((), -1.0, device=device, dtype=dtype, requires_grad=True)
c = torch.full((), 0.0, device=device, dtype=dtype, requires_grad=True)
d = torch.full((), 0.3, device=device, dtype=dtype, requires_grad=True)

learning_rate = 5e-6
for t in range(2000):
    # To apply our Function, we use the Function.apply method and alias it as 'P3'.
    P3 = LegendrePolynomial3.apply

    # Forward pass: compute predicted y using operations; we compute
    # P3 using our custom autograd operation.
    y_pred = a + b * P3(c + d * x)

    # Compute and print loss
    loss = (y_pred - y).pow(2).sum()
    if t % 100 == 99:
        print(t, loss.item())

    # Use autograd to compute the backward pass.
    loss.backward()

    # Update weights using gradient descent
    with torch.no_grad():
        a -= learning_rate * a.grad
        b -= learning_rate * b.grad
        c -= learning_rate * c.grad
        d -= learning_rate * d.grad

        # Manually zero the gradients after updating weights
        a.grad = None
        b.grad = None
        c.grad = None
        d.grad = None

print(f'Result: y = {a.item()} + {b.item()} * P3({c.item()} + {d.item()} x)')
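
If you want to check that the hand-written backward pass agrees with a numerical estimate of the gradient, torch.autograd.gradcheck can compare the two. The snippet below is a minimal sketch, not part of the original tutorial: it reuses the LegendrePolynomial3 class defined above, and the input size and tolerances are illustrative choices. Note that gradcheck expects double-precision inputs with requires_grad=True.

# Sanity check (sketch, not part of the original tutorial): compare the
# analytical backward() above against numerical finite differences.
# The input size and tolerances below are illustrative choices.
test_input = torch.randn(20, dtype=torch.double, requires_grad=True)
if torch.autograd.gradcheck(LegendrePolynomial3.apply, (test_input,), eps=1e-6, atol=1e-4):
    print("gradcheck passed: custom backward matches numerical gradients")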