# -*- coding: utf-8 -*-
"""
Creating Extensions Using NumPy and SciPy
=========================================
**Author**: `Adam Paszke <https://github.com/apaszke>`_

**Updated by**: `Adam Dziedzic <https://github.com/adam-dziedzic>`_

In this tutorial, we shall go through two tasks:

1. Create a neural network layer with no parameters.

   -  This calls into **numpy** as part of its implementation

2. Create a neural network layer that has learnable weights.

   -  This calls into **SciPy** as part of its implementation
"""

import torch
from torch.autograd import Function

###############################################################
# Parameter-less example
# ----------------------
#
# This layer doesn't particularly do anything useful or mathematically
# correct.
#
# It is aptly named ``BadFFTFunction``
#
# **Layer Implementation**

from numpy.fft import rfft2, irfft2


class BadFFTFunction(Function):
    @staticmethod
    def forward(ctx, input):
        numpy_input = input.detach().numpy()
        result = abs(rfft2(numpy_input))
        return input.new(result)

    @staticmethod
    def backward(ctx, grad_output):
        numpy_go = grad_output.numpy()
        result = irfft2(numpy_go)
        return grad_output.new(result)

# since this layer does not have any parameters, we can
# simply declare this as a function, rather than as an ``nn.Module`` class


def incorrect_fft(input):
    return BadFFTFunction.apply(input)

###############################################################
# **Example usage of the created layer:**

input = torch.randn(8, 8, requires_grad=True)
result = incorrect_fft(input)
print(result)
result.backward(torch.randn(result.size()))
# the backward pass has populated ``input.grad``
print(input.grad)
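###############################################################
# As a quick sanity check (a minimal sketch, assuming PyTorch 1.8+ where the
# ``torch.fft`` module is available), the forward pass above should agree with
# the magnitude of PyTorch's own real 2D FFT:

with torch.no_grad():
    # both sides compute abs(rfft2(input)); only the FFT backend differs
    native = torch.fft.rfft2(input).abs()
    print("Forward matches torch.fft.rfft2: ",
          torch.allclose(incorrect_fft(input), native, atol=1e-4))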
###############################################################
# Parametrized example
# --------------------
#
# In deep learning literature, this layer is confusingly referred
# to as convolution while the actual operation is cross-correlation
# (the only difference is that the filter is flipped for convolution,
# which is not the case for cross-correlation).
#
# Below is the implementation of a layer with learnable weights, where the
# cross-correlation filter (kernel) represents the weights.
#
# The backward pass computes the gradient with respect to the input and the
# gradient with respect to the filter.

from numpy import flip
import numpy as np
from scipy.signal import convolve2d, correlate2d
from torch.nn.modules.module import Module
from torch.nn.parameter import Parameter


class ScipyConv2dFunction(Function):
    @staticmethod
    def forward(ctx, input, filter, bias):
        # detach so we can cast to NumPy
        input, filter, bias = input.detach(), filter.detach(), bias.detach()
        result = correlate2d(input.numpy(), filter.numpy(), mode='valid')
        result += bias.numpy()
        ctx.save_for_backward(input, filter, bias)
        return torch.as_tensor(result, dtype=input.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        grad_output = grad_output.detach()
        input, filter, bias = ctx.saved_tensors
        grad_output = grad_output.numpy()
        grad_bias = np.sum(grad_output, keepdims=True)
        grad_input = convolve2d(grad_output, filter.numpy(), mode='full')
        # the previous line can be expressed equivalently as:
        # grad_input = correlate2d(grad_output, flip(flip(filter.numpy(), axis=0), axis=1), mode='full')
        grad_filter = correlate2d(input.numpy(), grad_output, mode='valid')
        return (torch.from_numpy(grad_input),
                torch.from_numpy(grad_filter).to(torch.float),
                torch.from_numpy(grad_bias).to(torch.float))


class ScipyConv2d(Module):
    def __init__(self, filter_width, filter_height):
        super(ScipyConv2d, self).__init__()
        self.filter = Parameter(torch.randn(filter_width, filter_height))
        self.bias = Parameter(torch.randn(1, 1))

    def forward(self, input):
        return ScipyConv2dFunction.apply(input, self.filter, self.bias)


###############################################################
# **Example usage:**

module = ScipyConv2d(3, 3)
print("Filter and bias: ", list(module.parameters()))
input = torch.randn(10, 10, requires_grad=True)
output = module(input)
print("Output from the convolution: ", output)
output.backward(torch.randn(8, 8))
print("Gradient for the input map: ", input.grad)

###############################################################
# **Check the gradients:**

from torch.autograd import gradcheck

moduleConv = ScipyConv2d(3, 3)

input = [torch.randn(20, 20, dtype=torch.double, requires_grad=True)]
test = gradcheck(moduleConv, input, eps=1e-6, atol=1e-4)
print("Are the gradients correct: ", test)
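###############################################################
# Since PyTorch's built-in ``conv2d`` also computes a cross-correlation, the
# forward pass of ``ScipyConv2d`` can be cross-checked against it. This is a
# minimal sketch on a fresh input (``x`` is a hypothetical test tensor):

import torch.nn.functional as F

x = torch.randn(10, 10)
with torch.no_grad():
    ours = module(x)
    # conv2d expects (batch, channel, H, W) inputs and an (out_channels,) bias
    theirs = F.conv2d(x.view(1, 1, 10, 10),
                      module.filter.view(1, 1, 3, 3),
                      bias=module.bias.view(1)).view(8, 8)
print("Forward matches F.conv2d: ", torch.allclose(ours, theirs, atol=1e-5))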
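###############################################################
# By contrast, the gradients of ``BadFFTFunction`` from the first part are
# wrong by construction, and ``gradcheck`` flags them. A short sketch of that
# check follows; passing ``raise_exception=False`` makes ``gradcheck`` return
# ``False`` instead of raising:

bad_input = (torch.randn(8, 8, dtype=torch.double, requires_grad=True),)
bad_test = gradcheck(incorrect_fft, bad_input, eps=1e-6, atol=1e-4,
                     raise_exception=False)
print("Are BadFFTFunction's gradients correct: ", bad_test)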