CoCalc provides the best real-time collaborative environment for Jupyter Notebooks, LaTeX documents, and SageMath, scalable from individual users to large groups and classes!
CoCalc provides the best real-time collaborative environment for Jupyter Notebooks, LaTeX documents, and SageMath, scalable from individual users to large groups and classes!
Path: blob/main/prototype_source/maskedtensor_overview.py
Views: 494
# -*- coding: utf-8 -*-12"""3(Prototype) MaskedTensor Overview4*********************************5"""67######################################################################8# This tutorial is designed to serve as a starting point for using MaskedTensors9# and discuss its masking semantics.10#11# MaskedTensor serves as an extension to :class:`torch.Tensor` that provides the user with the ability to:12#13# * use any masked semantics (for example, variable length tensors, nan* operators, etc.)14# * differentiation between 0 and NaN gradients15# * various sparse applications (see tutorial below)16#17# For a more detailed introduction on what MaskedTensors are, please find the18# `torch.masked documentation <https://pytorch.org/docs/master/masked.html>`__.19#20# Using MaskedTensor21# ==================22#23# In this section we discuss how to use MaskedTensor including how to construct, access, the data24# and mask, as well as indexing and slicing.25#26# Preparation27# -----------28#29# We'll begin by doing the necessary setup for the tutorial:30#3132import torch33from torch.masked import masked_tensor, as_masked_tensor34import warnings3536# Disable prototype warnings and such37warnings.filterwarnings(action='ignore', category=UserWarning)3839######################################################################40# Construction41# ------------42#43# There are a few different ways to construct a MaskedTensor:44#45# * The first way is to directly invoke the MaskedTensor class46# * The second (and our recommended way) is to use :func:`masked.masked_tensor` and :func:`masked.as_masked_tensor`47# factory functions, which are analogous to :func:`torch.tensor` and :func:`torch.as_tensor`48#49# Throughout this tutorial, we will be assuming the import line: `from torch.masked import masked_tensor`.50#51# Accessing the data and mask52# ---------------------------53#54# The underlying fields in a MaskedTensor can be accessed through:55#56# * the :meth:`MaskedTensor.get_data` function57# * the :meth:`MaskedTensor.get_mask` function. Recall that ``True`` indicates "specified" or "valid"58# while ``False`` indicates "unspecified" or "invalid".59#60# In general, the underlying data that is returned may not be valid in the unspecified entries, so we recommend that61# when users require a Tensor without any masked entries, that they use :meth:`MaskedTensor.to_tensor` (as shown above) to62# return a Tensor with filled values.63#64# Indexing and slicing65# --------------------66#67# :class:`MaskedTensor` is a Tensor subclass, which means that it inherits the same semantics for indexing and slicing68# as :class:`torch.Tensor`. Below are some examples of common indexing and slicing patterns:69#7071data = torch.arange(24).reshape(2, 3, 4)72mask = data % 2 == 07374print("data:\n", data)75print("mask:\n", mask)7677######################################################################78#7980# float is used for cleaner visualization when being printed81mt = masked_tensor(data.float(), mask)8283print("mt[0]:\n", mt[0])84print("mt[:, :, 2:4]:\n", mt[:, :, 2:4])8586######################################################################87# Why is MaskedTensor useful?88# ===========================89#90# Because of :class:`MaskedTensor`'s treatment of specified and unspecified values as a first-class citizen91# instead of an afterthought (with filled values, nans, etc.), it is able to solve for several of the shortcomings92# that regular Tensors are unable to; indeed, :class:`MaskedTensor` was born in a large part due to these recurring issues.93#94# Below, we will discuss some of the most common issues that are still unresolved in PyTorch today95# and illustrate how :class:`MaskedTensor` can solve these problems.96#97# Distinguishing between 0 and NaN gradient98# -----------------------------------------99#100# One issue that :class:`torch.Tensor` runs into is the inability to distinguish between gradients that are101# undefined (NaN) vs. gradients that are actually 0. Because PyTorch does not have a way of marking a value102# as specified/valid vs. unspecified/invalid, it is forced to rely on NaN or 0 (depending on the use case), leading103# to unreliable semantics since many operations aren't meant to handle NaN values properly. What is even more confusing104# is that sometimes depending on the order of operations, the gradient could vary (for example, depending on how early105# in the chain of operations a NaN value manifests).106#107# :class:`MaskedTensor` is the perfect solution for this!108#109# torch.where110# ^^^^^^^^^^^111#112# In `Issue 10729 <https://github.com/pytorch/pytorch/issues/10729>`__, we notice a case where the order of operations113# can matter when using :func:`torch.where` because we have trouble differentiating between if the 0 is a real 0114# or one from undefined gradients. Therefore, we remain consistent and mask out the results:115#116# Current result:117#118119x = torch.tensor([-10., -5, 0, 5, 10, 50, 60, 70, 80, 90, 100], requires_grad=True, dtype=torch.float)120y = torch.where(x < 0, torch.exp(x), torch.ones_like(x))121y.sum().backward()122x.grad123124######################################################################125# :class:`MaskedTensor` result:126#127128x = torch.tensor([-10., -5, 0, 5, 10, 50, 60, 70, 80, 90, 100])129mask = x < 0130mx = masked_tensor(x, mask, requires_grad=True)131my = masked_tensor(torch.ones_like(x), ~mask, requires_grad=True)132y = torch.where(mask, torch.exp(mx), my)133y.sum().backward()134mx.grad135136######################################################################137# The gradient here is only provided to the selected subset. Effectively, this changes the gradient of `where`138# to mask out elements instead of setting them to zero.139#140# Another torch.where141# ^^^^^^^^^^^^^^^^^^^142#143# `Issue 52248 <https://github.com/pytorch/pytorch/issues/52248>`__ is another example.144#145# Current result:146#147148a = torch.randn((), requires_grad=True)149b = torch.tensor(False)150c = torch.ones(())151print("torch.where(b, a/0, c):\n", torch.where(b, a/0, c))152print("torch.autograd.grad(torch.where(b, a/0, c), a):\n", torch.autograd.grad(torch.where(b, a/0, c), a))153154######################################################################155# :class:`MaskedTensor` result:156#157158a = masked_tensor(torch.randn(()), torch.tensor(True), requires_grad=True)159b = torch.tensor(False)160c = torch.ones(())161print("torch.where(b, a/0, c):\n", torch.where(b, a/0, c))162print("torch.autograd.grad(torch.where(b, a/0, c), a):\n", torch.autograd.grad(torch.where(b, a/0, c), a))163164######################################################################165# This issue is similar (and even links to the next issue below) in that it expresses frustration with166# unexpected behavior because of the inability to differentiate "no gradient" vs "zero gradient",167# which in turn makes working with other ops difficult to reason about.168#169# When using mask, x/0 yields NaN grad170# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^171#172# In `Issue 4132 <https://github.com/pytorch/pytorch/issues/4132>`__, the user proposes that173# `x.grad` should be `[0, 1]` instead of the `[nan, 1]`,174# whereas :class:`MaskedTensor` makes this very clear by masking out the gradient altogether.175#176# Current result:177#178179x = torch.tensor([1., 1.], requires_grad=True)180div = torch.tensor([0., 1.])181y = x/div # => y is [inf, 1]182mask = (div != 0) # => mask is [0, 1]183y[mask].backward()184x.grad185186######################################################################187# :class:`MaskedTensor` result:188#189190x = torch.tensor([1., 1.], requires_grad=True)191div = torch.tensor([0., 1.])192y = x/div # => y is [inf, 1]193mask = (div != 0) # => mask is [0, 1]194loss = as_masked_tensor(y, mask)195loss.sum().backward()196x.grad197198######################################################################199# :func:`torch.nansum` and :func:`torch.nanmean`200# ----------------------------------------------201#202# In `Issue 67180 <https://github.com/pytorch/pytorch/issues/67180>`__,203# the gradient isn't calculate properly (a longstanding issue), whereas :class:`MaskedTensor` handles it correctly.204#205# Current result:206#207208a = torch.tensor([1., 2., float('nan')])209b = torch.tensor(1.0, requires_grad=True)210c = a * b211c1 = torch.nansum(c)212bgrad1, = torch.autograd.grad(c1, b, retain_graph=True)213bgrad1214215######################################################################216# :class:`MaskedTensor` result:217#218219a = torch.tensor([1., 2., float('nan')])220b = torch.tensor(1.0, requires_grad=True)221mt = masked_tensor(a, ~torch.isnan(a))222c = mt * b223c1 = torch.sum(c)224bgrad1, = torch.autograd.grad(c1, b, retain_graph=True)225bgrad1226227######################################################################228# Safe Softmax229# ------------230#231# Safe softmax is another great example of `an issue <https://github.com/pytorch/pytorch/issues/55056>`__232# that arises frequently. In a nutshell, if there is an entire batch that is "masked out"233# or consists entirely of padding (which, in the softmax case, translates to being set `-inf`),234# then this will result in NaNs, which can lead to training divergence.235#236# Luckily, :class:`MaskedTensor` has solved this issue. Consider this setup:237#238239data = torch.randn(3, 3)240mask = torch.tensor([[True, False, False], [True, False, True], [False, False, False]])241x = data.masked_fill(~mask, float('-inf'))242mt = masked_tensor(data, mask)243print("x:\n", x)244print("mt:\n", mt)245246######################################################################247# For example, we want to calculate the softmax along `dim=0`. Note that the second column is "unsafe" (i.e. entirely248# masked out), so when the softmax is calculated, the result will yield `0/0 = nan` since `exp(-inf) = 0`.249# However, what we would really like is for the gradients to be masked out since they are unspecified and would be250# invalid for training.251#252# PyTorch result:253#254255x.softmax(0)256257######################################################################258# :class:`MaskedTensor` result:259#260261mt.softmax(0)262263######################################################################264# Implementing missing torch.nan* operators265# -----------------------------------------266#267# In `Issue 61474 <https://github.com/pytorch/pytorch/issues/61474>`__,268# there is a request to add additional operators to cover the various `torch.nan*` applications,269# such as ``torch.nanmax``, ``torch.nanmin``, etc.270#271# In general, these problems lend themselves more naturally to masked semantics, so instead of introducing additional272# operators, we propose using :class:`MaskedTensor` instead.273# Since `nanmean has already landed <https://github.com/pytorch/pytorch/issues/21987>`__,274# we can use it as a comparison point:275#276277x = torch.arange(16).float()278y = x * x.fmod(4)279z = y.masked_fill(y == 0, float('nan')) # we want to get the mean of y when ignoring the zeros280281######################################################################282#283print("y:\n", y)284# z is just y with the zeros replaced with nan's285print("z:\n", z)286287######################################################################288#289290print("y.mean():\n", y.mean())291print("z.nanmean():\n", z.nanmean())292# MaskedTensor successfully ignores the 0's293print("torch.mean(masked_tensor(y, y != 0)):\n", torch.mean(masked_tensor(y, y != 0)))294295######################################################################296# In the above example, we've constructed a `y` and would like to calculate the mean of the series while ignoring297# the zeros. `torch.nanmean` can be used to do this, but we don't have implementations for the rest of the298# `torch.nan*` operations. :class:`MaskedTensor` solves this issue by being able to use the base operation,299# and we already have support for the other operations listed in the issue. For example:300#301302torch.argmin(masked_tensor(y, y != 0))303304######################################################################305# Indeed, the index of the minimum argument when ignoring the 0's is the 1 in index 1.306#307# :class:`MaskedTensor` can also support reductions when the data is fully masked out, which is equivalent308# to the case above when the data Tensor is completely ``nan``. ``nanmean`` would return ``nan``309# (an ambiguous return value), while MaskedTensor would more accurately indicate a masked out result.310#311312x = torch.empty(16).fill_(float('nan'))313print("x:\n", x)314print("torch.nanmean(x):\n", torch.nanmean(x))315print("torch.nanmean via maskedtensor:\n", torch.mean(masked_tensor(x, ~torch.isnan(x))))316317######################################################################318# This is a similar problem to safe softmax where `0/0 = nan` when what we really want is an undefined value.319#320# Conclusion321# ==========322#323# In this tutorial, we've introduced what MaskedTensors are, demonstrated how to use them, and motivated their324# value through a series of examples and issues that they've helped resolve.325#326# Further Reading327# ===============328#329# To continue learning more, you can find our330# `MaskedTensor Sparsity tutorial <https://pytorch.org/tutorials/prototype/maskedtensor_sparsity.html>`__331# to see how MaskedTensor enables sparsity and the different storage formats we currently support.332#333334335