Commit 49c50ee

[maskedtensor] Overview tutorial [1/4]
1 parent 04e1ba9 commit 49c50ee

2 files changed: +314 -0 lines changed

Lines changed: 304 additions & 0 deletions
@@ -0,0 +1,304 @@
# -*- coding: utf-8 -*-

"""
(Prototype) MaskedTensor Overview
=================================

**Author**: `George Qi <https://github.com/george-qi>`_
"""


######################################################################
# This tutorial is designed to serve as a starting point for using MaskedTensors
# and to discuss their masking semantics.
#

######################################################################
# Using MaskedTensor
# ++++++++++++++++++
#
# Construction
# ------------
#
# There are a few different ways to construct a MaskedTensor:
#
# * The first way is to directly invoke the MaskedTensor class
# * The second (and our recommended way) is to use the :func:`masked.masked_tensor` and :func:`masked.as_masked_tensor`
#   factory functions, which are analogous to :func:`torch.tensor` and :func:`torch.as_tensor`
#
# Throughout this tutorial, we will be assuming the import line:
# `from torch.masked import masked_tensor, as_masked_tensor`.
#
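# As a minimal sketch of the factory-function path (this block also assumes
# ``import torch``; the direct ``MaskedTensor`` constructor accepts the same
# ``data`` and ``mask`` arguments):

import torch
from torch.masked import masked_tensor, as_masked_tensor

data = torch.tensor([0., 1., 2., 3.])
mask = torch.tensor([True, False, True, False])  # True marks specified/valid entries
mt = masked_tensor(data, mask)
print(mt)

######################################################################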
# Accessing the data and mask
# ---------------------------
#
# The underlying fields in a MaskedTensor can be accessed through:
#
# * the :meth:`MaskedTensor.get_data` function
# * the :meth:`MaskedTensor.get_mask` function. Recall that ``True`` indicates "specified" or "valid"
#   while ``False`` indicates "unspecified" or "invalid".
#
# In general, the underlying data that is returned may not be valid in the unspecified entries, so we recommend
# that when users require a Tensor without any masked entries, they use :meth:`MaskedTensor.to_tensor` to
# return a Tensor with filled values.
#
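# A short sketch using the ``mt`` constructed above (we assume here that
# :meth:`MaskedTensor.to_tensor` takes the value used to fill the
# unspecified entries):

print("data:\n", mt.get_data())
print("mask:\n", mt.get_mask())
print("filled:\n", mt.to_tensor(0.))

######################################################################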
# Indexing and slicing
# --------------------
#
# :class:`MaskedTensor` is a Tensor subclass, which means that it inherits the same semantics for indexing and slicing
# as :class:`torch.Tensor`. Below are some examples of common indexing and slicing patterns:
#

import torch
from torch.masked import masked_tensor, as_masked_tensor

data = torch.arange(24).reshape(2, 3, 4)
mask = data % 2 == 0

print("data\n", data)
print("mask\n", mask)

# float is used for cleaner visualization when being printed
mt = masked_tensor(data.float(), mask)

print("mt[0]:\n", mt[0])
print("mt[:, :, 2:4]:\n", mt[:, :, 2:4])
######################################################################
# Why is MaskedTensor useful?
# +++++++++++++++++++++++++++
#
# Because :class:`MaskedTensor` treats specified and unspecified values as first-class citizens
# instead of an afterthought (with filled values, nans, etc.), it is able to solve several of the shortcomings
# that regular Tensors cannot; indeed, :class:`MaskedTensor` was born in large part due to these recurring issues.
#
# Below, we will discuss some of the most common issues that are still unresolved in PyTorch today
# and illustrate how :class:`MaskedTensor` can solve these problems.
#
# Distinguishing between 0 and NaN gradient
# -----------------------------------------
#
# One issue that :class:`torch.Tensor` runs into is the inability to distinguish between gradients that are
# undefined (NaN) and gradients that are actually 0. Because PyTorch does not have a way of marking a value
# as specified/valid vs. unspecified/invalid, it is forced to rely on NaN or 0 (depending on the use case), leading
# to unreliable semantics since many operations aren't meant to handle NaN values properly. What is even more confusing
# is that the gradient can vary depending on the order of operations (for example, depending on how early
# in the chain of operations a NaN value manifests).
#
# :class:`MaskedTensor` is the perfect solution for this!
#
# :func:`torch.where`
# ^^^^^^^^^^^^^^^^^^^
#
# In `Issue 10729 <https://github.com/pytorch/pytorch/issues/10729>`__, we notice a case where the order of operations
# can matter when using :func:`torch.where` because we have trouble differentiating whether a 0 is a real 0
# or one from undefined gradients. Therefore, we remain consistent and mask out the results:
#
# Current result:
#

x = torch.tensor([-10., -5, 0, 5, 10, 50, 60, 70, 80, 90, 100], requires_grad=True, dtype=torch.float)
y = torch.where(x < 0, torch.exp(x), torch.ones_like(x))
y.sum().backward()
x.grad
######################################################################
# :class:`MaskedTensor` result:
#

x = torch.tensor([-10., -5, 0, 5, 10, 50, 60, 70, 80, 90, 100])
mask = x < 0
mx = masked_tensor(x, mask, requires_grad=True)
my = masked_tensor(torch.ones_like(x), ~mask, requires_grad=True)
y = torch.where(mask, torch.exp(mx), my)
y.sum().backward()
mx.grad
######################################################################
# The gradient here is only provided to the selected subset. Effectively, this changes the gradient of `where`
# to mask out elements instead of setting them to zero.
#
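# As a quick check (continuing the variables above), the complementary branch
# receives a gradient masked in the opposite pattern. This is a sketch that
# assumes ``my.grad`` is populated the same way ``mx.grad`` is:

print("mx.grad:\n", mx.grad)
print("my.grad:\n", my.grad)

######################################################################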
# Another :func:`torch.where`
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# `Issue 52248 <https://github.com/pytorch/pytorch/issues/52248>`__ is another example.
#
# Current result:
#

a = torch.randn((), requires_grad=True)
b = torch.tensor(False)
c = torch.ones(())
print("torch.where(b, a/0, c):\n", torch.where(b, a/0, c))
print("torch.autograd.grad(torch.where(b, a/0, c), a):\n", torch.autograd.grad(torch.where(b, a/0, c), a))
######################################################################
# :class:`MaskedTensor` result:
#

a = masked_tensor(torch.randn(()), torch.tensor(True), requires_grad=True)
b = torch.tensor(False)
c = torch.ones(())
print("torch.where(b, a/0, c):\n", torch.where(b, a/0, c))
print("torch.autograd.grad(torch.where(b, a/0, c), a):\n", torch.autograd.grad(torch.where(b, a/0, c), a))
######################################################################
# This issue is similar (and even links to the next issue below) in that it expresses frustration with
# unexpected behavior because of the inability to differentiate "no gradient" vs "zero gradient",
# which in turn makes working with other ops difficult to reason about.
#
# When using mask, x/0 yields NaN grad
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# In `Issue 4132 <https://github.com/pytorch/pytorch/issues/4132>`__, the user proposes that
# `x.grad` should be `[0, 1]` instead of `[nan, 1]`,
# whereas :class:`MaskedTensor` makes this very clear by masking out the gradient altogether.
#
# Current result:
#

x = torch.tensor([1., 1.], requires_grad=True)
div = torch.tensor([0., 1.])
y = x/div # => y is [inf, 1]
mask = (div != 0) # => mask is [0, 1]
y[mask].backward()
x.grad
######################################################################
# :class:`MaskedTensor` result:
#

x = torch.tensor([1., 1.], requires_grad=True)
div = torch.tensor([0., 1.])
y = x/div # => y is [inf, 1]
mask = (div != 0) # => mask is [0, 1]
loss = as_masked_tensor(y, mask)
loss.sum().backward()
x.grad
######################################################################
# :func:`torch.nansum` and :func:`torch.nanmean`
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# In `Issue 67180 <https://github.com/pytorch/pytorch/issues/67180>`__,
# the gradient isn't calculated properly (a longstanding issue), whereas :class:`MaskedTensor` handles it correctly.
#
# Current result:
#

a = torch.tensor([1., 2., float('nan')])
b = torch.tensor(1.0, requires_grad=True)
c = a * b
c1 = torch.nansum(c)
bgrad1, = torch.autograd.grad(c1, b, retain_graph=True)
bgrad1
######################################################################
# :class:`MaskedTensor` result:
#

a = torch.tensor([1., 2., float('nan')])
b = torch.tensor(1.0, requires_grad=True)
mt = masked_tensor(a, ~torch.isnan(a))
c = mt * b
c1 = torch.sum(c)
bgrad1, = torch.autograd.grad(c1, b, retain_graph=True)
bgrad1
######################################################################
# Safe Softmax
# ------------
#
# Safe softmax is another great example of `an issue <https://github.com/pytorch/pytorch/issues/55056>`_
# that arises frequently. In a nutshell, if there is an entire batch that is "masked out"
# or consists entirely of padding (which, in the softmax case, translates to being set to `-inf`),
# then this will result in NaNs, which can lead to training divergence.
#
# Luckily, :class:`MaskedTensor` has solved this issue. Consider this setup:
#

data = torch.randn(3, 3)
mask = torch.tensor([[True, False, False], [True, False, True], [False, False, False]])
x = data.masked_fill(~mask, float('-inf'))
mt = masked_tensor(data, mask)
print("x:\n", x)
print("mt:\n", mt)
######################################################################
# For example, suppose we want to calculate the softmax along `dim=0`. Note that the second column is "unsafe" (i.e. entirely
# masked out), so when the softmax is calculated, the result will yield `0/0 = nan` since `exp(-inf) = 0`.
# However, what we would really like is for the gradients to be masked out, since they are unspecified and would be
# invalid for training.
#
# PyTorch result:
#

x.softmax(0)
######################################################################
# :class:`MaskedTensor` result:
#

mt.softmax(0)
######################################################################
# Implementing missing torch.nan* operators
# -----------------------------------------
#
# In `Issue 61474 <https://github.com/pytorch/pytorch/issues/61474>`__,
# there is a request to add additional operators to cover the various `torch.nan*` applications,
# such as ``torch.nanmax``, ``torch.nanmin``, etc.
#
# In general, these problems lend themselves more naturally to masked semantics, so instead of introducing additional
# operators, we propose using :class:`MaskedTensor` instead. Since
# `nanmean has already landed <https://github.com/pytorch/pytorch/issues/21987>`_, we can use it as a comparison point:
#

x = torch.arange(16).float()
y = x * x.fmod(4)
z = y.masked_fill(y == 0, float('nan'))  # we want to get the mean of y while ignoring the zeros

print("y:\n", y)
# z is just y with the zeros replaced with nan's
print("z:\n", z)
print("y.mean():\n", y.mean())
print("z.nanmean():\n", z.nanmean())
# MaskedTensor successfully ignores the 0's
print("torch.mean(masked_tensor(y, y != 0)):\n", torch.mean(masked_tensor(y, y != 0)))
######################################################################
# In the above example, we've constructed `y` and would like to calculate the mean of the series while ignoring
# the zeros. `torch.nanmean` can be used to do this, but we don't have implementations for the rest of the
# `torch.nan*` operations. :class:`MaskedTensor` solves this issue by being able to use the base operation,
# and we already have support for the other operations listed in the issue. For example:
#

torch.argmin(masked_tensor(y, y != 0))
######################################################################
# Indeed, the index of the minimum value when ignoring the 0's is index 1 (where the value is 1).
#
# :class:`MaskedTensor` can also support reductions when the data is fully masked out, which is equivalent
# to the case above where the data Tensor is completely ``nan``. ``nanmean`` would return ``nan``
# (an ambiguous return value), while MaskedTensor would more accurately indicate a masked out result.
#

x = torch.empty(16).fill_(float('nan'))
print("x:\n", x)
print("torch.nanmean(x):\n", torch.nanmean(x))
print("torch.nanmean via maskedtensor:\n", torch.mean(masked_tensor(x, ~torch.isnan(x))))
######################################################################
# This is a similar problem to safe softmax, where `0/0 = nan` when what we really want is an undefined value.
#
# Conclusion
# ++++++++++
#
# In this tutorial, we've introduced what MaskedTensors are, demonstrated how to use them, and motivated their
# value through a series of examples and issues that they've helped resolve.
#
# Further Reading
# +++++++++++++++
#
# To continue learning more, you can find our
# `Sparsity tutorial <https://github.com/pytorch/tutorials/pull/2050/files>`_ to see how MaskedTensor enables sparsity
# and the different storage formats we currently support.
#

prototype_source/prototype_index.rst

Lines changed: 10 additions & 0 deletions
@@ -141,6 +141,15 @@ Prototype features are not available as part of binary distributions like PyPI o
   :link: ../prototype/nestedtensor.html
   :tags: NestedTensor

.. MaskedTensor

.. customcarditem::
   :header: MaskedTensor Overview
   :card_description: Learn about masked tensors, the source of truth for specified and unspecified values
   :image: ../_static/img/thumbnails/cropped/generic-pytorch-logo.png
   :link: ../prototype/maskedtensor_overview.html
   :tags: MaskedTensor

.. End of tutorial card section

.. raw:: html

@@ -172,3 +181,4 @@ Prototype features are not available as part of binary distributions like PyPI o
   prototype/vmap_recipe.html
   prototype/vulkan_workflow.html
   prototype/nestedtensor.html
   prototype/maskedtensor_overview.html
