"""1Reasoning about Shapes in PyTorch2=================================34When writing models with PyTorch, it is commonly the case that the parameters5to a given layer depend on the shape of the output of the previous layer. For6example, the ``in_features`` of an ``nn.Linear`` layer must match the7``size(-1)`` of the input. For some layers, the shape computation involves8complex equations, for example convolution operations.910One way around this is to run the forward pass with random inputs, but this is11wasteful in terms of memory and compute.1213Instead, we can make use of the ``meta`` device to determine the output shapes14of a layer without materializing any data.15"""1617import torch18import timeit1920t = torch.rand(2, 3, 10, 10, device="meta")21conv = torch.nn.Conv2d(3, 5, 2, device="meta")22start = timeit.default_timer()23out = conv(t)24end = timeit.default_timer()2526print(out)27print(f"Time taken: {end-start}")282930##########################################################################31# Observe that since data is not materialized, passing arbitrarily large32# inputs will not significantly alter the time taken for shape computation.3334t_large = torch.rand(2**10, 3, 2**16, 2**16, device="meta")35start = timeit.default_timer()36out = conv(t_large)37end = timeit.default_timer()3839print(out)40print(f"Time taken: {end-start}")414243######################################################44# Consider an arbitrary network such as the following:4546import torch.nn as nn47import torch.nn.functional as F484950class Net(nn.Module):51def __init__(self):52super().__init__()53self.conv1 = nn.Conv2d(3, 6, 5)54self.pool = nn.MaxPool2d(2, 2)55self.conv2 = nn.Conv2d(6, 16, 5)56self.fc1 = nn.Linear(16 * 5 * 5, 120)57self.fc2 = nn.Linear(120, 84)58self.fc3 = nn.Linear(84, 10)5960def forward(self, x):61x = self.pool(F.relu(self.conv1(x)))62x = self.pool(F.relu(self.conv2(x)))63x = torch.flatten(x, 1) # flatten all dimensions except batch64x = F.relu(self.fc1(x))65x = F.relu(self.fc2(x))66x = self.fc3(x)67return x686970###############################################################################71# We can view the intermediate shapes within an entire network by registering a72# forward hook to each layer that prints the shape of the output.7374def fw_hook(module, input, output):75print(f"Shape of output to {module} is {output.shape}.")767778# Any tensor created within this torch.device context manager will be79# on the meta device.80with torch.device("meta"):81net = Net()82inp = torch.randn((1024, 3, 32, 32))8384for name, layer in net.named_modules():85layer.register_forward_hook(fw_hook)8687out = net(inp)888990