# -*- coding: utf-8 -*-

"""
(Prototype) Efficiently writing "sparse" semantics for Adagrad with MaskedTensor
================================================================================
"""
######################################################################
# Before working through this tutorial, please review the MaskedTensor
# `Overview <https://pytorch.org/tutorials/prototype/maskedtensor_overview.html>`__ and
# `Sparsity <https://pytorch.org/tutorials/prototype/maskedtensor_sparsity.html>`__ tutorials.
#
# Introduction and Motivation
# ---------------------------
# `Issue 1369 <https://github.com/pytorch/pytorch/issues/1369>`__ discussed the additional lines of code
# that were introduced while writing "sparse" semantics for Adagrad, but really,
# the code uses sparsity as a proxy for masked semantics rather than for its intended use case:
# as a compression and optimization technique.
# Previously, we worked around the lack of formal masked semantics by introducing one-off semantics and operators
# while forcing users to be aware of storage details such as indices and values.
#
# Now that we have masked semantics, we are better equipped to point out when sparsity is used as a semantic extension.
# We'll also compare and contrast this with equivalent code written using MaskedTensor.
# In the end, the code snippets are repeated without additional comments to show the difference in brevity.
#
# Preparation
# -----------
#

import torch
import warnings

# Disable prototype warnings and such
warnings.filterwarnings(action='ignore', category=UserWarning)

# Some hyperparameters
eps = 1e-10
clr = 0.1

i = torch.tensor([[0, 1, 1], [2, 0, 2]])
v = torch.tensor([3, 4, 5], dtype=torch.float32)
grad = torch.sparse_coo_tensor(i, v, [2, 4])
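
# As a quick illustration, the dense view of `grad` makes the pattern of specified
# entries explicit (the other five entries are unspecified):
#
#   tensor([[0., 0., 3., 0.],
#           [4., 0., 5., 0.]])
print("grad (dense):\n", grad.to_dense())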

######################################################################
# Simpler Code with MaskedTensor
# ------------------------------
#
# Before we get too far into the weeds, let's introduce the problem a bit more concretely. We will be taking a look
# at the `Adagrad (functional) <https://github.com/pytorch/pytorch/blob/6c2f235d368b697072699e5ca9485fd97d0b9bcc/torch/optim/_functional.py#L16-L51>`__
# implementation in PyTorch, with the ultimate goal of simplifying it and representing the masked approach more faithfully.
#
# For reference, this is the regular, dense code path without masked gradients or sparsity:
#
# .. code-block:: python
#
#     state_sum.addcmul_(grad, grad, value=1)
#     std = state_sum.sqrt().add_(eps)
#     param.addcdiv_(grad, std, value=-clr)
#
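# These three lines are the standard elementwise Adagrad update; written with the same
# names as the code above, the rule is:
#
# .. math::
#
#     \text{state\_sum} \leftarrow \text{state\_sum} + \text{grad}^2, \qquad
#     \text{std} = \sqrt{\text{state\_sum}} + \text{eps}, \qquad
#     \text{param} \leftarrow \text{param} - \text{clr} \cdot \frac{\text{grad}}{\text{std}}
#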
# The vanilla tensor implementation for the sparse case is:
#
# .. code-block:: python
#
#     def _make_sparse(grad, grad_indices, values):
#         size = grad.size()
#         if grad_indices.numel() == 0 or values.numel() == 0:
#             return torch.empty_like(grad)
#         return torch.sparse_coo_tensor(grad_indices, values, size)
#
#     grad = grad.coalesce()  # the update is non-linear so indices must be unique
#     grad_indices = grad._indices()
#     grad_values = grad._values()
#
#     state_sum.add_(_make_sparse(grad, grad_indices, grad_values.pow(2)))  # a different _make_sparse per layout
#     std = state_sum.sparse_mask(grad)
#     std_values = std._values().sqrt_().add_(eps)
#     param.add_(_make_sparse(grad, grad_indices, grad_values / std_values), alpha=-clr)
#
# while :class:`MaskedTensor` reduces the code to the following snippet:
#
# .. code-block:: python
#
#     state_sum2 = state_sum2 + masked_grad.pow(2).get_data()
#     std2 = masked_tensor(state_sum2.to_sparse(), mask)
#     std2 = std2.sqrt().add(eps)
#     param2 = param2.add((masked_grad / std2).get_data(), alpha=-clr)
#
# In this tutorial, we will go through each implementation line by line, but at first glance, we can notice
# (1) how much shorter the MaskedTensor implementation is, and
# (2) how it avoids conversions between dense and sparse tensors.
#

######################################################################
# Original Sparse Implementation
# ------------------------------
#
# Now, let's break down the code with some inline comments:
#

def _make_sparse(grad, grad_indices, values):
    size = grad.size()
    if grad_indices.numel() == 0 or values.numel() == 0:
        return torch.empty_like(grad)
    return torch.sparse_coo_tensor(grad_indices, values, size)

# We don't support sparse gradients
param = torch.arange(8).reshape(2, 4).float()
state_sum = torch.full_like(param, 0.5)  # initial value for state sum

grad = grad.coalesce()  # the update is non-linear so indices must be unique
grad_indices = grad._indices()
grad_values = grad._values()
# pow(2) has the same semantics for both sparse and dense memory layouts since 0^2 is zero
state_sum.add_(_make_sparse(grad, grad_indices, grad_values.pow(2)))

# We take care to make std sparse, even though state_sum clearly is not.
# This means that we're only applying the gradient to parts of the state_sum
# for which it is specified. This further drives the point home that the passed gradient is not sparse, but masked.
# We currently dodge all these concerns using the private method `_values`.
std = state_sum.sparse_mask(grad)
std_values = std._values().sqrt_().add_(eps)
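# Note that `sqrt_` and `add_` operate in place on the values of `std`, so `std` itself
# now holds sqrt(state_sum) + eps at its specified entries.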

# Note here that we currently don't support div for sparse Tensors because zero / zero is not well defined,
# so we're forced to perform `grad_values / std_values` outside of sparse semantics and then convert back to a
# sparse tensor with `_make_sparse`.
# We'll later see that MaskedTensor will actually handle these operations for us, as well as properly denote
# undefined / undefined = undefined!
param.add_(_make_sparse(grad, grad_indices, grad_values / std_values), alpha=-clr)

######################################################################
# The third-to-last line -- `std = state_sum.sparse_mask(grad)` -- is where we have a very important divergence.
#
# The addition of eps should technically be applied to all values but is instead only applied to specified values.
# Here we're using sparsity as a semantic extension and to enforce a certain pattern of defined and undefined values.
# If parts of the gradient's values are zero, they are still included when materialized, even though they
# could be compressed away by other sparse storage layouts. This is theoretically quite brittle!
# That said, one could argue that eps is always very small, so it might not matter so much in practice.
#
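# As a small sketch using the tensors defined above: a dense update would add eps to all
# eight entries of ``state_sum.sqrt()``, whereas the masked-by-sparsity code above only
# touches the three specified values surfaced by ``std._values()``:
#
# .. code-block:: python
#
#     state_sum.sqrt() + eps   # eps applied to every element
#     std._values()            # eps was applied only to the specified values
#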
# Moreover, an implementation of `add_` for sparsity as a storage layout and compression scheme
# should cause densification, but we force it not to for performance.
# For this one-off case it is fine... until we want to introduce a new compression scheme, such as
# `CSC <https://pytorch.org/docs/master/sparse.html#sparse-csc-docs>`__,
# `BSR <https://pytorch.org/docs/master/sparse.html#sparse-bsr-docs>`__,
# or `BSC <https://pytorch.org/docs/master/sparse.html#sparse-bsc-docs>`__.
# We will then need to introduce separate Tensor types for each and write variations for gradients compressed
# using different storage formats, which is inconvenient and neither scalable nor clean.
#
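# As an illustrative sketch, even the small gradient above needs a different set of indices
# once it is expressed in, say, the CSC layout, so each layout would require its own
# ``_make_sparse``-style handling:
#
# .. code-block:: python
#
#     ccol_indices = torch.tensor([0, 1, 1, 3, 3])
#     row_indices = torch.tensor([1, 0, 1])
#     grad_csc = torch.sparse_csc_tensor(ccol_indices, row_indices,
#                                        torch.tensor([4., 3., 5.]), size=(2, 4))
#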
# MaskedTensor Sparse Implementation
# ----------------------------------
#
# We've been conflating sparsity as an optimization with sparsity as a semantic extension to PyTorch.
# MaskedTensor proposes to disentangle the sparsity optimization from the semantic extension; for example,
# currently we can't have dense semantics with sparse storage or masked semantics with dense storage.
# MaskedTensor enables these ideas by purposefully separating the storage from the semantics.
#
# Consider the above example using a masked gradient:
#

# Let's now import MaskedTensor!
from torch.masked import masked_tensor

# Create an entirely new set of parameters to avoid errors
param2 = torch.arange(8).reshape(2, 4).float()
state_sum2 = torch.full_like(param, 0.5)  # initial value for state sum

mask = (grad.to_dense() != 0).to_sparse()
masked_grad = masked_tensor(grad, mask)
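# The mask marks the nonzero entries of the dense gradient, so masked_grad pairs grad's
# specified values with an explicit notion of which entries are defined.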

state_sum2 = state_sum2 + masked_grad.pow(2).get_data()
std2 = masked_tensor(state_sum2.to_sparse(), mask)

# We can add support for in-place operations later. Notice how this doesn't
# need to access any storage internals and is in general a lot shorter.
std2 = std2.sqrt().add(eps)

param2 = param2.add((masked_grad / std2).get_data(), alpha=-clr)

######################################################################
# Note that the implementations look quite similar, but the MaskedTensor implementation is shorter and simpler.
# In particular, much of the boilerplate code around ``_make_sparse``
# (and needing to have a separate implementation per layout) is handled for the user with :class:`MaskedTensor`.
#
# At this point, let's print both this version and the original version for easier comparison:
#

print("state_sum:\n", state_sum)
print("state_sum2:\n", state_sum2)

######################################################################
#

print("std:\n", std)
print("std2:\n", std2)

######################################################################
#

print("param:\n", param)
print("param2:\n", param2)

######################################################################
# Conclusion
# ----------
#
# In this tutorial, we've discussed how native masked semantics can enable a cleaner developer experience for
# Adagrad's existing implementation in PyTorch, which used sparsity as a proxy for writing masked semantics.
# But more importantly, making masked semantics a first-class citizen through MaskedTensor
# removes the reliance on sparsity and on unreliable hacks to mimic masking, allowing the two to be developed
# independently while still enabling sparse use cases such as this one.
#
# Further Reading
# ---------------
#
# To continue learning more, you can find our final review (for now) on
# `MaskedTensor Advanced Semantics <https://pytorch.org/tutorials/prototype/maskedtensor_advanced_semantics.html>`__
# to see some of the differences in design decisions between :class:`MaskedTensor` and NumPy's MaskedArray, as well
# as reduction semantics.
#