# -*- coding: utf-8 -*-
"""
Creating Extensions Using NumPy and SciPy
=========================================
**Author**: `Adam Paszke <https://github.com/apaszke>`_

**Updated by**: `Adam Dziedzic <https://github.com/adam-dziedzic>`_

In this tutorial, we shall go through two tasks:

1. Create a neural network layer with no parameters.

    -  This calls into **numpy** as part of its implementation

2. Create a neural network layer that has learnable weights

    -  This calls into **SciPy** as part of its implementation
"""

import torch
from torch.autograd import Function

###############################################################
# Parameter-less example
# ----------------------
#
# This layer doesn’t particularly do anything useful or mathematically
# correct.
#
# It is aptly named ``BadFFTFunction``
#
# **Layer Implementation**

from numpy.fft import rfft2, irfft2


class BadFFTFunction(Function):
    @staticmethod
    def forward(ctx, input):
        # detach so we can cast to NumPy, then take the magnitude
        # of the 2D real-to-complex FFT
        numpy_input = input.detach().numpy()
        result = abs(rfft2(numpy_input))
        return input.new(result)

    @staticmethod
    def backward(ctx, grad_output):
        # the "gradient" is just the inverse FFT of the incoming
        # gradient -- not the true derivative, hence the class name
        numpy_go = grad_output.numpy()
        result = irfft2(numpy_go)
        return grad_output.new(result)

# since this layer does not have any parameters, we can
# simply declare this as a function, rather than as an ``nn.Module`` class


def incorrect_fft(input):
    return BadFFTFunction.apply(input)

###############################################################
# **Example usage of the created layer:**

input = torch.randn(8, 8, requires_grad=True)
result = incorrect_fft(input)
print(result)
result.backward(torch.randn(result.size()))
# backward() populates input.grad with the computed gradient
print(input.grad)
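
###############################################################
# As a quick sanity check (a sketch added for illustration, not part of
# the original tutorial, and assuming PyTorch >= 1.8 where the
# ``torch.fft`` module is available), the forward pass can be compared
# against PyTorch's built-in FFT: both compute ``abs(rfft2(input))``.

with torch.no_grad():
    builtin = torch.fft.rfft2(input).abs()
print("Forward matches torch.fft: ", torch.allclose(result, builtin, atol=1e-5))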

###############################################################
# Parametrized example
# --------------------
#
# In deep learning literature, this layer is confusingly referred
# to as convolution, while the actual operation is cross-correlation
# (the only difference is that the filter is flipped for convolution,
# which is not the case for cross-correlation; see the quick check
# after the imports below).
#
# Implementation of a layer with learnable weights, where cross-correlation
# has a filter (kernel) that represents weights.
#
# The backward pass computes the gradient w.r.t. the input and the gradient
# w.r.t. the filter.

from numpy import flip
import numpy as np
from scipy.signal import convolve2d, correlate2d
from torch.nn.modules.module import Module
from torch.nn.parameter import Parameter
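
###############################################################
# A quick illustration of the point above (a sketch added for
# illustration, not part of the original tutorial): convolving with a
# kernel gives the same result as cross-correlating with that kernel
# flipped along both axes.

_x = np.random.randn(5, 5)
_k = np.random.randn(3, 3)
print("convolution == correlation with flipped kernel: ",
      np.allclose(convolve2d(_x, _k, mode='valid'),
                  correlate2d(_x, flip(flip(_k, 0), 1), mode='valid')))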


class ScipyConv2dFunction(Function):
    @staticmethod
    def forward(ctx, input, filter, bias):
        # detach so we can cast to NumPy
        input, filter, bias = input.detach(), filter.detach(), bias.detach()
        result = correlate2d(input.numpy(), filter.numpy(), mode='valid')
        result += bias.numpy()
        ctx.save_for_backward(input, filter, bias)
        return torch.as_tensor(result, dtype=input.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        grad_output = grad_output.detach()
        input, filter, bias = ctx.saved_tensors
        grad_output = grad_output.numpy()
        grad_bias = np.sum(grad_output, keepdims=True)
        # 'full' convolution with the filter expands grad_output back
        # to the shape of the input
        grad_input = convolve2d(grad_output, filter.numpy(), mode='full')
        # the previous line can be expressed equivalently as:
        # grad_input = correlate2d(grad_output, flip(flip(filter.numpy(), axis=0), axis=1), mode='full')
        grad_filter = correlate2d(input.numpy(), grad_output, mode='valid')
        return (torch.from_numpy(grad_input),
                torch.from_numpy(grad_filter).to(torch.float),
                torch.from_numpy(grad_bias).to(torch.float))
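
###############################################################
# A quick shape check on the two modes used above (a sketch added for
# illustration, not part of the original tutorial): for a 10x10 input and
# a 3x3 filter, the forward output -- and therefore ``grad_output`` -- is
# 8x8; 'full' convolution with the filter restores the 10x10 input shape,
# while 'valid' correlation of the input with ``grad_output`` yields the
# 3x3 filter shape.

print(convolve2d(np.zeros((8, 8)), np.zeros((3, 3)), mode='full').shape)      # (10, 10)
print(correlate2d(np.zeros((10, 10)), np.zeros((8, 8)), mode='valid').shape)  # (3, 3)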


class ScipyConv2d(Module):
    def __init__(self, filter_width, filter_height):
        super(ScipyConv2d, self).__init__()
        self.filter = Parameter(torch.randn(filter_width, filter_height))
        self.bias = Parameter(torch.randn(1, 1))

    def forward(self, input):
        return ScipyConv2dFunction.apply(input, self.filter, self.bias)

###############################################################
# **Example usage:**

module = ScipyConv2d(3, 3)
print("Filter and bias: ", list(module.parameters()))
input = torch.randn(10, 10, requires_grad=True)
output = module(input)
print("Output from the convolution: ", output)
# a 10x10 input correlated with a 3x3 filter in 'valid' mode yields an
# 8x8 output, hence the 8x8 upstream gradient passed to backward()
output.backward(torch.randn(8, 8))
print("Gradient for the input map: ", input.grad)
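
###############################################################
# As a cross-check (a sketch added for illustration, not part of the
# original tutorial), the forward pass should agree with PyTorch's
# built-in ``conv2d``, which -- like this layer -- actually computes
# cross-correlation.

import torch.nn.functional as F

with torch.no_grad():
    builtin = F.conv2d(input.view(1, 1, 10, 10),
                       module.filter.view(1, 1, 3, 3)) + module.bias
print("Matches F.conv2d: ", torch.allclose(output, builtin.view(8, 8), atol=1e-5))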

###############################################################
# **Check the gradients:**

from torch.autograd.gradcheck import gradcheck

moduleConv = ScipyConv2d(3, 3)

# gradcheck compares the analytical gradients from our backward() against
# numerical finite-difference estimates; double-precision input keeps
# those estimates accurate
input = [torch.randn(20, 20, dtype=torch.double, requires_grad=True)]
test = gradcheck(moduleConv, input, eps=1e-6, atol=1e-4)
print("Are the gradients correct: ", test)