[maskedtensor] Overview tutorial [1/4] #2050

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged · 27 commits · Oct 28, 2022
Changes from 1 commit
undo requirements.txt changes
Svetlana Karslioglu authored and george-qi committed Oct 22, 2022
commit 16abb37b50af238c5c7e3604af7fb340a72c455f
77 changes: 50 additions & 27 deletions prototype_source/maskedtensor_overview.py
@@ -2,7 +2,7 @@

"""
(Prototype) MaskedTensor Overview
*********************************
"""

######################################################################
@@ -14,16 +14,33 @@
# * use any masked semantics (for example, variable length tensors, nan* operators, etc.)
# * differentiation between 0 and NaN gradients
# * various sparse applications (see tutorial below)
#
#
# For a more detailed introduction to what MaskedTensors are, please see the
# `torch.masked documentation <https://pytorch.org/docs/master/masked.html>`__.
#
#
# Using MaskedTensor
# ==================
#
# In this section, we discuss how to use MaskedTensor: how to construct one, how to
# access its data and mask, and how to index and slice it.
#
# Preparation
# -----------
#
# We'll begin by doing the necessary setup for the tutorial:
#

import torch
from torch.masked import masked_tensor, as_masked_tensor
import warnings

# Disable prototype warnings and such
warnings.filterwarnings(action='ignore', category=UserWarning)

######################################################################
# Construction
# ------------
#
#
# There are a few different ways to construct a MaskedTensor:
#
# * The first way is to directly invoke the MaskedTensor class
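#
#   As a minimal sketch of construction, using the factory functions imported in
#   the Preparation cell (per the torch.masked docs linked above, ``masked_tensor``
#   mirrors ``torch.tensor`` while ``as_masked_tensor`` mirrors ``torch.as_tensor``;
#   the class constructor itself takes the same ``data`` and ``mask`` pair):
#
#   .. code-block:: python
#
#      data = torch.arange(6).float()
#      mask = data % 2 == 0
#      mt = masked_tensor(data, mask)
#      mt2 = as_masked_tensor(data, mask)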
@@ -52,24 +69,24 @@
# as :class:`torch.Tensor`. Below are some examples of common indexing and slicing patterns:
#

data = torch.arange(24).reshape(2, 3, 4)
mask = data % 2 == 0

print("data\n", data)
print("mask\n", mask)
print("data:\n", data)
print("mask:\n", mask)

######################################################################
#

# float is used for cleaner visualization when being printed
mt = masked_tensor(data.float(), mask)

print ("mt[0]:\n", mt[0])
print ("mt[:, :, 2:4]", mt[:, :, 2:4])
print("mt[0]:\n", mt[0])
print("mt[:, :, 2:4]:\n", mt[:, :, 2:4])

######################################################################
# Why is MaskedTensor useful?
# ===========================
#
# Because of :class:`MaskedTensor`'s treatment of specified and unspecified values as first-class citizens
# instead of an afterthought (with filled values, nans, etc.), it is able to solve for several of the shortcomings
@@ -90,8 +107,8 @@
#
# :class:`MaskedTensor` is the perfect solution for this!
#
# torch.where
# ^^^^^^^^^^^
#
# In `Issue 10729 <https://github.com/pytorch/pytorch/issues/10729>`__, we notice a case where the order of operations
# can matter when using :func:`torch.where` because we have trouble differentiating whether the 0 is a real 0
@@ -121,8 +138,8 @@
# The gradient here is only provided to the selected subset. Effectively, this changes the gradient of `where`
# to mask out elements instead of setting them to zero.
#
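# As a sketch of that pattern (input values assumed for illustration, chosen
# so that the unselected ``exp`` branch overflows for large entries):
#

x = torch.tensor([-10., -5, 0, 5, 10, 50, 60, 70, 80, 90, 100], requires_grad=True)
y = torch.where(x < 0, torch.exp(x), torch.ones_like(x))
y.sum().backward()
# where exp overflows to inf on the unselected branch, backward multiplies
# a zeroed incoming gradient by inf, producing nan
print("x.grad:\n", x.grad)

# the MaskedTensor version defines each branch only where its mask holds,
# so those gradients come back masked out instead of polluted with NaNs
x2 = torch.tensor([-10., -5, 0, 5, 10, 50, 60, 70, 80, 90, 100])
mask = x2 < 0
mx = masked_tensor(x2, mask, requires_grad=True)
my = masked_tensor(torch.ones_like(x2), ~mask, requires_grad=True)
s = torch.where(mask, torch.exp(mx), my)
s.sum().backward()
print("mx.grad:\n", mx.grad)

######################################################################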
# Another torch.where
# ^^^^^^^^^^^^^^^^^^^
#
# `Issue 52248 <https://github.com/pytorch/pytorch/issues/52248>`__ is another example.
#
@@ -174,15 +191,14 @@
x = torch.tensor([1., 1.], requires_grad=True)
div = torch.tensor([0., 1.])
y = x/div # => y is [inf, 1]
mask = (div != 0) # => mask is [0, 1]
loss = as_masked_tensor(y, mask)
loss.sum().backward()
print("x.grad:\n", x.grad)

######################################################################
# :func:`torch.nansum` and :func:`torch.nanmean`
# ----------------------------------------------
#
# In `Issue 67180 <https://github.com/pytorch/pytorch/issues/67180>`__,
# the gradient isn't calculated properly (a longstanding issue), whereas :class:`MaskedTensor` handles it correctly.
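#
# As a sketch (tensor values assumed for illustration), compare how autograd
# treats the NaN slot in each case:

x = torch.tensor([1., 2., float('nan'), 4.], requires_grad=True)
print("torch.nansum(x):\n", torch.nansum(x))
torch.nansum(x).backward()
print("x.grad:\n", x.grad)

# the masked equivalent: the NaN slot is simply unspecified, so it never
# enters the reduction and its gradient stays masked out
mx = masked_tensor(x.detach(), ~torch.isnan(x.detach()), requires_grad=True)
mx.sum().backward()
print("mx.grad:\n", mx.grad)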
@@ -213,7 +229,7 @@
# Safe Softmax
# ------------
#
# Safe softmax is another great example of `an issue <https://github.com/pytorch/pytorch/issues/55056>`__
# that arises frequently. In a nutshell, if there is an entire batch that is "masked out"
# or consists entirely of padding (which, in the softmax case, translates to being set to `-inf`),
# then this will result in NaNs, which can lead to training divergence.
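#
# As a sketch, with an assumed mask whose last row is fully masked out:

data = torch.randn(3, 3)
mask = torch.tensor([[True, False, False],
                     [True, False, True],
                     [False, False, False]])

# with regular tensors, the fully padded row becomes all -inf and softmaxes to NaN
x = data.masked_fill(~mask, float('-inf'))
print("x.softmax(-1):\n", x.softmax(-1))

# with MaskedTensor, the fully masked row stays masked out, producing no NaNs
mt = masked_tensor(data, mask)
print("mt.softmax(-1):\n", mt.softmax(-1))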
@@ -247,24 +263,31 @@

######################################################################
# Implementing missing torch.nan* operators
# -----------------------------------------
#
# In `Issue 61474 <https://github.com/pytorch/pytorch/issues/61474>`__,
# there is a request to add additional operators to cover the various `torch.nan*` applications,
# such as ``torch.nanmax``, ``torch.nanmin``, etc.
#
# In general, these problems lend themselves more naturally to masked semantics, so instead of introducing additional
# operators, we propose using :class:`MaskedTensor` instead.
# Since `nanmean has already landed <https://github.com/pytorch/pytorch/issues/21987>`__,
# we can use it as a comparison point:
#

x = torch.arange(16).float()
y = x * x.fmod(4)
z = y.masked_fill(y == 0, float('nan')) # we want to get the mean of y when ignoring the zeros

print("y:\n, y")
######################################################################
#
print("y:\n", y)
# z is just y with the zeros replaced with nan's
print("z:\n", z)

######################################################################
#

print("y.mean():\n", y.mean())
print("z.nanmean():\n", z.nanmean())
# MaskedTensor successfully ignores the 0's
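# (a sketch: the masked ``mean`` reduction over just the nonzero slots)
mt = masked_tensor(y, y != 0)
print("mt.mean():\n", mt.mean())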
@@ -296,13 +319,13 @@
# This is a problem similar to safe softmax, where `0/0 = nan` when what we really want is an undefined value.
#
# Conclusion
# ==========
#
# In this tutorial, we've introduced what MaskedTensors are, demonstrated how to use them, and motivated their
# value through a series of examples and issues that they've helped resolve.
#
# Further Reading
# ===============
#
# To continue learning more, you can find our
# `MaskedTensor Sparsity tutorial <https://pytorch.org/tutorials/prototype/maskedtensor_sparsity.html>`__