GitHub Repository: pytorch/tutorials
Path: blob/main/recipes_source/recipes/reasoning_about_shapes.py
"""
Reasoning about Shapes in PyTorch
=================================

When writing models with PyTorch, it is commonly the case that the parameters
to a given layer depend on the shape of the output of the previous layer. For
example, the ``in_features`` of an ``nn.Linear`` layer must match the
``size(-1)`` of the input. For some layers, the shape computation involves
complex equations, for example convolution operations.

One way around this is to run the forward pass with random inputs, but this is
wasteful in terms of memory and compute.

Instead, we can make use of the ``meta`` device to determine the output shapes
of a layer without materializing any data.
"""

import torch
import timeit

t = torch.rand(2, 3, 10, 10, device="meta")
conv = torch.nn.Conv2d(3, 5, 2, device="meta")
start = timeit.default_timer()
out = conv(t)
end = timeit.default_timer()

print(out)
print(f"Time taken: {end-start}")


##########################################################################
# Observe that since data is not materialized, passing arbitrarily large
# inputs will not significantly alter the time taken for shape computation.

t_large = torch.rand(2**10, 3, 2**16, 2**16, device="meta")
start = timeit.default_timer()
out = conv(t_large)
end = timeit.default_timer()

print(out)
print(f"Time taken: {end-start}")
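
######################################################################
# To put the size of ``t_large`` in perspective, the back-of-the-envelope
# sketch below computes how much memory a materialized tensor of that
# shape would need. It uses plain Python arithmetic only; the variable
# names are illustrative, and 4 bytes per element assumes the default
# ``torch.float32`` dtype.

import math

large_shape = (2**10, 3, 2**16, 2**16)
num_elements = math.prod(large_shape)  # 3 * 2**42 elements
bytes_needed = num_elements * 4        # 4 bytes per float32 element
print(f"{num_elements} elements, about {bytes_needed / 2**40:.0f} TiB")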


######################################################
# Consider an arbitrary network such as the following:

import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
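

######################################################################
# The ``16 * 5 * 5`` passed to ``fc1`` above is tied to a particular
# input size. As a hand-worked check, assuming 32x32 inputs and the
# default stride of 1 and padding of 0 for ``nn.Conv2d``, the spatial
# size can be traced through the convolution and pooling layers. The
# variable names below are illustrative only.

side = 32
side = side - 5 + 1  # conv1, 5x5 kernel: 32 -> 28
side = side // 2     # pool, 2x2 window with stride 2: 28 -> 14
side = side - 5 + 1  # conv2, 5x5 kernel: 14 -> 10
side = side // 2     # pool again: 10 -> 5
fc1_in_features = 16 * side * side  # conv2 produces 16 channels
print(fc1_in_features)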


###############################################################################
# We can view the intermediate shapes within an entire network by registering a
# forward hook to each layer that prints the shape of the output.

def fw_hook(module, input, output):
    print(f"Shape of output to {module} is {output.shape}.")


# Any tensor created within this torch.device context manager will be
# on the meta device.
with torch.device("meta"):
    net = Net()
    inp = torch.randn((1024, 3, 32, 32))

for name, layer in net.named_modules():
    layer.register_forward_hook(fw_hook)

out = net(inp)
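

######################################################################
# Printing from a hook is convenient for inspection, but the shapes are
# often needed programmatically. A minimal sketch, assuming the same
# forward-hook mechanism: record each submodule's output shape in a
# dictionary and remove the hooks afterwards. ``collect_output_shapes``
# and the small ``demo`` model are hypothetical names introduced here
# for illustration; they are not part of PyTorch.

import torch
import torch.nn as nn


def collect_output_shapes(model, example_input):
    shapes = {}
    handles = []
    for name, module in model.named_modules():
        if name == "":
            continue  # skip the root module itself

        # Bind ``name`` as a default argument so each hook keeps its own key.
        def record_shape(mod, args, output, name=name):
            shapes[name] = tuple(output.shape)

        handles.append(module.register_forward_hook(record_shape))
    try:
        model(example_input)
    finally:
        for handle in handles:
            handle.remove()  # leave the model without leftover hooks
    return shapes


with torch.device("meta"):
    demo = nn.Sequential(nn.Conv2d(3, 6, 5), nn.MaxPool2d(2, 2))
    demo_inp = torch.randn(8, 3, 32, 32)

demo_shapes = collect_output_shapes(demo, demo_inp)
print(demo_shapes)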