CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutSign UpSign In
pytorch

CoCalc provides the best real-time collaborative environment for Jupyter Notebooks, LaTeX documents, and SageMath, scalable from individual users to large groups and classes!

GitHub Repository: pytorch/tutorials
Path: blob/main/prototype_source/maskedtensor_overview.py
Views: 494
1
# -*- coding: utf-8 -*-
2
3
"""
4
(Prototype) MaskedTensor Overview
5
*********************************
6
"""
7
8
######################################################################
9
# This tutorial is designed to serve as a starting point for using MaskedTensors
10
# and discuss its masking semantics.
11
#
12
# MaskedTensor serves as an extension to :class:`torch.Tensor` that provides the user with the ability to:
13
#
14
# * use any masked semantics (for example, variable length tensors, nan* operators, etc.)
15
# * differentiation between 0 and NaN gradients
16
# * various sparse applications (see tutorial below)
17
#
18
# For a more detailed introduction on what MaskedTensors are, please find the
19
# `torch.masked documentation <https://pytorch.org/docs/master/masked.html>`__.
20
#
21
# Using MaskedTensor
22
# ==================
23
#
24
# In this section we discuss how to use MaskedTensor including how to construct, access, the data
25
# and mask, as well as indexing and slicing.
26
#
27
# Preparation
28
# -----------
29
#
30
# We'll begin by doing the necessary setup for the tutorial:
31
#
32
33
import torch
34
from torch.masked import masked_tensor, as_masked_tensor
35
import warnings
36
37
# Disable prototype warnings and such
38
warnings.filterwarnings(action='ignore', category=UserWarning)
39
40
######################################################################
41
# Construction
42
# ------------
43
#
44
# There are a few different ways to construct a MaskedTensor:
45
#
46
# * The first way is to directly invoke the MaskedTensor class
47
# * The second (and our recommended way) is to use :func:`masked.masked_tensor` and :func:`masked.as_masked_tensor`
48
# factory functions, which are analogous to :func:`torch.tensor` and :func:`torch.as_tensor`
49
#
50
# Throughout this tutorial, we will be assuming the import line: `from torch.masked import masked_tensor`.
51
#
52
# Accessing the data and mask
53
# ---------------------------
54
#
55
# The underlying fields in a MaskedTensor can be accessed through:
56
#
57
# * the :meth:`MaskedTensor.get_data` function
58
# * the :meth:`MaskedTensor.get_mask` function. Recall that ``True`` indicates "specified" or "valid"
59
# while ``False`` indicates "unspecified" or "invalid".
60
#
61
# In general, the underlying data that is returned may not be valid in the unspecified entries, so we recommend that
62
# when users require a Tensor without any masked entries, that they use :meth:`MaskedTensor.to_tensor` (as shown above) to
63
# return a Tensor with filled values.
64
#
65
# Indexing and slicing
66
# --------------------
67
#
68
# :class:`MaskedTensor` is a Tensor subclass, which means that it inherits the same semantics for indexing and slicing
69
# as :class:`torch.Tensor`. Below are some examples of common indexing and slicing patterns:
70
#
71
72
data = torch.arange(24).reshape(2, 3, 4)
73
mask = data % 2 == 0
74
75
print("data:\n", data)
76
print("mask:\n", mask)
77
78
######################################################################
79
#
80
81
# float is used for cleaner visualization when being printed
82
mt = masked_tensor(data.float(), mask)
83
84
print("mt[0]:\n", mt[0])
85
print("mt[:, :, 2:4]:\n", mt[:, :, 2:4])
86
87
######################################################################
88
# Why is MaskedTensor useful?
89
# ===========================
90
#
91
# Because of :class:`MaskedTensor`'s treatment of specified and unspecified values as a first-class citizen
92
# instead of an afterthought (with filled values, nans, etc.), it is able to solve for several of the shortcomings
93
# that regular Tensors are unable to; indeed, :class:`MaskedTensor` was born in a large part due to these recurring issues.
94
#
95
# Below, we will discuss some of the most common issues that are still unresolved in PyTorch today
96
# and illustrate how :class:`MaskedTensor` can solve these problems.
97
#
98
# Distinguishing between 0 and NaN gradient
99
# -----------------------------------------
100
#
101
# One issue that :class:`torch.Tensor` runs into is the inability to distinguish between gradients that are
102
# undefined (NaN) vs. gradients that are actually 0. Because PyTorch does not have a way of marking a value
103
# as specified/valid vs. unspecified/invalid, it is forced to rely on NaN or 0 (depending on the use case), leading
104
# to unreliable semantics since many operations aren't meant to handle NaN values properly. What is even more confusing
105
# is that sometimes depending on the order of operations, the gradient could vary (for example, depending on how early
106
# in the chain of operations a NaN value manifests).
107
#
108
# :class:`MaskedTensor` is the perfect solution for this!
109
#
110
# torch.where
111
# ^^^^^^^^^^^
112
#
113
# In `Issue 10729 <https://github.com/pytorch/pytorch/issues/10729>`__, we notice a case where the order of operations
114
# can matter when using :func:`torch.where` because we have trouble differentiating between if the 0 is a real 0
115
# or one from undefined gradients. Therefore, we remain consistent and mask out the results:
116
#
117
# Current result:
118
#
119
120
x = torch.tensor([-10., -5, 0, 5, 10, 50, 60, 70, 80, 90, 100], requires_grad=True, dtype=torch.float)
121
y = torch.where(x < 0, torch.exp(x), torch.ones_like(x))
122
y.sum().backward()
123
x.grad
124
125
######################################################################
126
# :class:`MaskedTensor` result:
127
#
128
129
x = torch.tensor([-10., -5, 0, 5, 10, 50, 60, 70, 80, 90, 100])
130
mask = x < 0
131
mx = masked_tensor(x, mask, requires_grad=True)
132
my = masked_tensor(torch.ones_like(x), ~mask, requires_grad=True)
133
y = torch.where(mask, torch.exp(mx), my)
134
y.sum().backward()
135
mx.grad
136
137
######################################################################
138
# The gradient here is only provided to the selected subset. Effectively, this changes the gradient of `where`
139
# to mask out elements instead of setting them to zero.
140
#
141
# Another torch.where
142
# ^^^^^^^^^^^^^^^^^^^
143
#
144
# `Issue 52248 <https://github.com/pytorch/pytorch/issues/52248>`__ is another example.
145
#
146
# Current result:
147
#
148
149
a = torch.randn((), requires_grad=True)
150
b = torch.tensor(False)
151
c = torch.ones(())
152
print("torch.where(b, a/0, c):\n", torch.where(b, a/0, c))
153
print("torch.autograd.grad(torch.where(b, a/0, c), a):\n", torch.autograd.grad(torch.where(b, a/0, c), a))
154
155
######################################################################
156
# :class:`MaskedTensor` result:
157
#
158
159
a = masked_tensor(torch.randn(()), torch.tensor(True), requires_grad=True)
160
b = torch.tensor(False)
161
c = torch.ones(())
162
print("torch.where(b, a/0, c):\n", torch.where(b, a/0, c))
163
print("torch.autograd.grad(torch.where(b, a/0, c), a):\n", torch.autograd.grad(torch.where(b, a/0, c), a))
164
165
######################################################################
166
# This issue is similar (and even links to the next issue below) in that it expresses frustration with
167
# unexpected behavior because of the inability to differentiate "no gradient" vs "zero gradient",
168
# which in turn makes working with other ops difficult to reason about.
169
#
170
# When using mask, x/0 yields NaN grad
171
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
172
#
173
# In `Issue 4132 <https://github.com/pytorch/pytorch/issues/4132>`__, the user proposes that
174
# `x.grad` should be `[0, 1]` instead of the `[nan, 1]`,
175
# whereas :class:`MaskedTensor` makes this very clear by masking out the gradient altogether.
176
#
177
# Current result:
178
#
179
180
x = torch.tensor([1., 1.], requires_grad=True)
181
div = torch.tensor([0., 1.])
182
y = x/div # => y is [inf, 1]
183
mask = (div != 0) # => mask is [0, 1]
184
y[mask].backward()
185
x.grad
186
187
######################################################################
188
# :class:`MaskedTensor` result:
189
#
190
191
x = torch.tensor([1., 1.], requires_grad=True)
192
div = torch.tensor([0., 1.])
193
y = x/div # => y is [inf, 1]
194
mask = (div != 0) # => mask is [0, 1]
195
loss = as_masked_tensor(y, mask)
196
loss.sum().backward()
197
x.grad
198
199
######################################################################
200
# :func:`torch.nansum` and :func:`torch.nanmean`
201
# ----------------------------------------------
202
#
203
# In `Issue 67180 <https://github.com/pytorch/pytorch/issues/67180>`__,
204
# the gradient isn't calculate properly (a longstanding issue), whereas :class:`MaskedTensor` handles it correctly.
205
#
206
# Current result:
207
#
208
209
a = torch.tensor([1., 2., float('nan')])
210
b = torch.tensor(1.0, requires_grad=True)
211
c = a * b
212
c1 = torch.nansum(c)
213
bgrad1, = torch.autograd.grad(c1, b, retain_graph=True)
214
bgrad1
215
216
######################################################################
217
# :class:`MaskedTensor` result:
218
#
219
220
a = torch.tensor([1., 2., float('nan')])
221
b = torch.tensor(1.0, requires_grad=True)
222
mt = masked_tensor(a, ~torch.isnan(a))
223
c = mt * b
224
c1 = torch.sum(c)
225
bgrad1, = torch.autograd.grad(c1, b, retain_graph=True)
226
bgrad1
227
228
######################################################################
229
# Safe Softmax
230
# ------------
231
#
232
# Safe softmax is another great example of `an issue <https://github.com/pytorch/pytorch/issues/55056>`__
233
# that arises frequently. In a nutshell, if there is an entire batch that is "masked out"
234
# or consists entirely of padding (which, in the softmax case, translates to being set `-inf`),
235
# then this will result in NaNs, which can lead to training divergence.
236
#
237
# Luckily, :class:`MaskedTensor` has solved this issue. Consider this setup:
238
#
239
240
data = torch.randn(3, 3)
241
mask = torch.tensor([[True, False, False], [True, False, True], [False, False, False]])
242
x = data.masked_fill(~mask, float('-inf'))
243
mt = masked_tensor(data, mask)
244
print("x:\n", x)
245
print("mt:\n", mt)
246
247
######################################################################
248
# For example, we want to calculate the softmax along `dim=0`. Note that the second column is "unsafe" (i.e. entirely
249
# masked out), so when the softmax is calculated, the result will yield `0/0 = nan` since `exp(-inf) = 0`.
250
# However, what we would really like is for the gradients to be masked out since they are unspecified and would be
251
# invalid for training.
252
#
253
# PyTorch result:
254
#
255
256
x.softmax(0)
257
258
######################################################################
259
# :class:`MaskedTensor` result:
260
#
261
262
mt.softmax(0)
263
264
######################################################################
265
# Implementing missing torch.nan* operators
266
# -----------------------------------------
267
#
268
# In `Issue 61474 <https://github.com/pytorch/pytorch/issues/61474>`__,
269
# there is a request to add additional operators to cover the various `torch.nan*` applications,
270
# such as ``torch.nanmax``, ``torch.nanmin``, etc.
271
#
272
# In general, these problems lend themselves more naturally to masked semantics, so instead of introducing additional
273
# operators, we propose using :class:`MaskedTensor` instead.
274
# Since `nanmean has already landed <https://github.com/pytorch/pytorch/issues/21987>`__,
275
# we can use it as a comparison point:
276
#
277
278
x = torch.arange(16).float()
279
y = x * x.fmod(4)
280
z = y.masked_fill(y == 0, float('nan')) # we want to get the mean of y when ignoring the zeros
281
282
######################################################################
283
#
284
print("y:\n", y)
285
# z is just y with the zeros replaced with nan's
286
print("z:\n", z)
287
288
######################################################################
289
#
290
291
print("y.mean():\n", y.mean())
292
print("z.nanmean():\n", z.nanmean())
293
# MaskedTensor successfully ignores the 0's
294
print("torch.mean(masked_tensor(y, y != 0)):\n", torch.mean(masked_tensor(y, y != 0)))
295
296
######################################################################
297
# In the above example, we've constructed a `y` and would like to calculate the mean of the series while ignoring
298
# the zeros. `torch.nanmean` can be used to do this, but we don't have implementations for the rest of the
299
# `torch.nan*` operations. :class:`MaskedTensor` solves this issue by being able to use the base operation,
300
# and we already have support for the other operations listed in the issue. For example:
301
#
302
303
torch.argmin(masked_tensor(y, y != 0))
304
305
######################################################################
306
# Indeed, the index of the minimum argument when ignoring the 0's is the 1 in index 1.
307
#
308
# :class:`MaskedTensor` can also support reductions when the data is fully masked out, which is equivalent
309
# to the case above when the data Tensor is completely ``nan``. ``nanmean`` would return ``nan``
310
# (an ambiguous return value), while MaskedTensor would more accurately indicate a masked out result.
311
#
312
313
x = torch.empty(16).fill_(float('nan'))
314
print("x:\n", x)
315
print("torch.nanmean(x):\n", torch.nanmean(x))
316
print("torch.nanmean via maskedtensor:\n", torch.mean(masked_tensor(x, ~torch.isnan(x))))
317
318
######################################################################
319
# This is a similar problem to safe softmax where `0/0 = nan` when what we really want is an undefined value.
320
#
321
# Conclusion
322
# ==========
323
#
324
# In this tutorial, we've introduced what MaskedTensors are, demonstrated how to use them, and motivated their
325
# value through a series of examples and issues that they've helped resolve.
326
#
327
# Further Reading
328
# ===============
329
#
330
# To continue learning more, you can find our
331
# `MaskedTensor Sparsity tutorial <https://pytorch.org/tutorials/prototype/maskedtensor_sparsity.html>`__
332
# to see how MaskedTensor enables sparsity and the different storage formats we currently support.
333
#
334
335