3 changes: 3 additions & 0 deletions thunder/__init__.py
@@ -418,6 +418,9 @@ def _alias_tensor_of_args_kwargs_dict(*args, **kwargs) -> dict[int, list[int]]:
# subclasses.
if type(t) is pytorch.Tensor and t.layout is pytorch.strided:
data_ptr = t.untyped_storage().data_ptr()
if not data_ptr:
# This happens when t.numel() == 0
continue
if data_ptr not in data_ptr_to_tensor_group_index:
data_ptr_to_tensor_group_index[data_ptr] = len(data_ptr_to_tensor_group_index)
tgi = data_ptr_to_tensor_group_index[data_ptr]
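
A quick check of the case the new guard handles (an illustrative sketch, not part of the diff): a strided tensor with zero elements reports a data pointer of 0, so without the `continue` every such tensor would be lumped into the same alias group.

```python
import torch

t = torch.empty(0)
assert t.numel() == 0
# An empty storage has no allocation, so data_ptr() is 0 (falsy),
# which is exactly what the `if not data_ptr: continue` guard skips.
assert t.untyped_storage().data_ptr() == 0
```
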
4 changes: 2 additions & 2 deletions thunder/clang/__init__.py
@@ -648,10 +648,10 @@ def wrap_tensor(t: TensorLike, dim_length: int) -> TensorLike:


@clangop()
def copy_with_setitem(a: TensorLike, key, value: TensorLike) -> TensorLike:
def setitem(a: TensorLike, key, value: TensorLike) -> TensorLike:
# TODO: do more checking here. We used to have a check
# lambda: f"{key=} tries to index more dimensions than {a.ndim=}",
return prims.copy_with_setitem(a, key, value)
return prims.setitem(a, key, value)


# NOTE: currently supported indexing:
7 changes: 3 additions & 4 deletions thunder/core/prims.py
@@ -266,7 +266,7 @@ class PrimIDs(Enum):
SCATTER_ADD = auto()
TAKE = auto()
TAKE_ALONG_AXIS = auto()
COPY_WITH_SETITEM = auto()
SETITEM = auto()
# Linear algebra prims (Mostly experimental)
MATMUL = auto()
_GROUPED_MM = auto() # Used for grouped matmuls
@@ -3621,12 +3621,11 @@ def take_along_axis_meta(a: TensorProxy, /, index: TensorProxy, dim: int) -> Ten
take_along_axis = make_prim(PrimIDs.TAKE_ALONG_AXIS, "take_along_axis", meta=take_along_axis_meta)


def copy_with_setitem_meta(a: TensorProxy, index, value: TensorProxy) -> TensorProxy:
# TODO: port checks from clang, currently there because of the utilities they need
def setitem_meta(a: TensorProxy, index, value: TensorProxy | Number | NumberProxy) -> TensorProxy:
return TensorProxy(like=a)


copy_with_setitem = make_prim(PrimIDs.COPY_WITH_SETITEM, "copy_with_setitem", meta=copy_with_setitem_meta)
setitem = make_prim(PrimIDs.SETITEM, "setitem", meta=setitem_meta, tags=(OpTags.DONT_DCE,))


def gather_meta(a: TensorProxy, /, index: TensorProxy, dim: int) -> TensorProxy:
3 changes: 3 additions & 0 deletions thunder/core/transform_common.py
@@ -235,6 +235,9 @@ def replace_redundant_inputs(
# into one for ops in this set.
NON_FUNCTIONAL_OPS: set[prims.PrimIDs | str] = {
prims.PrimIDs.UNIFORM,
prims.PrimIDs.EMPTY,
"empty",
"torch.empty",
"torch.uniform", # this doesn't exist as of the PR
"torch.uniform_like", # this doesn't exist as of the PR
# thunder.core.prims doesn't support. See https://pytorch.org/docs/stable/generated/torch.rand.html.
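
A sketch of why empty joins uniform in this set (my reading of the change, not stated in the diff): two empty calls are not interchangeable, since each returns a distinct uninitialized buffer that later ops may write into, so CSE must not fold them into one.

```python
import torch

a = torch.empty(3)
b = torch.empty(3)
# Distinct buffers: folding the two calls into one would make later
# in-place writes (e.g. the empty + copy_ clone pattern used elsewhere
# in this PR) silently alias each other.
assert a.data_ptr() != b.data_ptr()
```
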
51 changes: 31 additions & 20 deletions thunder/core/transforms.py
@@ -12,6 +12,7 @@
import time
import dataclasses

from thunder.core.update_aliases import insert_alias_updates
import thunder.core.utils as utils
from thunder.core import dtypes, prims
from thunder.core.devices import Device
@@ -1442,26 +1443,6 @@ def _maximum_grad(a: TensorProxy, b: TensorProxy, /):
register_grad(pids.SHAPE, prims.shape)


def _copy_with_setitem_grad(a: TensorProxy, index, value: Number | TensorProxy):
fwd = prims.copy_with_setitem(a, index, value)
g = get_grad(fwd)

a_grad = prims.copy_with_setitem(g, index, 0)
put_grad(a, a_grad)

if isinstance(value, TensorProxy):
value_grad = g[index]
# NOTE: `value` could be broadcasted.
if not utils.same_shape(value_grad.shape, value.shape):
value_grad = sum_to(value_grad, value.shape)
put_grad(value, value_grad)

return fwd


register_grad(pids.COPY_WITH_SETITEM, _copy_with_setitem_grad)


def _log_sigmoid_grad(
a: TensorProxy,
) -> TensorProxy:
@@ -2594,6 +2575,34 @@ def index_put_aug_fwd(
return VJPDual(primal, residuals)


@register_augmented_forward(prims.PrimIDs.SETITEM)
Review note from the PR author: forward and backward have to be defined separately here. If they are not, then when the forward trace is extracted from the combined forward-and-backward function, the prims.setitem(clone(g), index, 0) call from the backward survives DCE (because setitem is tagged DONT_DCE) and leaks into the forward trace.

def setitem_aug_fwd(a, index, value) -> VJPDual:
primal = prims.setitem(a, index, value)
value_shape = value.shape if isinstance(value, TensorProxy) else None
return VJPDual(primal, (index, value_shape))


@register_backward(prims.PrimIDs.SETITEM)
def setitem_backward(index, value_shape, g):
# We avoid using Tensor.clone because nvfuserex has unsoundness in mutation on cloned tensors
# See https://github.com/Lightning-AI/lightning-thunder/issues/2793
def clone(t):
cd = get_compile_data()
buf = prims.empty(t.shape, device=t.device, dtype=t.dtype)
return prims.copy_(t, buf, grad_enabled=cd.is_grad_enabled if cd is not None else False)

a_grad = prims.setitem(clone(g), index, 0)

value_grad = None
if value_shape is not None:
value_grad = g[index]
# NOTE: `value` could be broadcasted.
if not utils.same_shape(value_grad.shape, value_shape):
value_grad = sum_to(value_grad, value_shape)

return a_grad, None, value_grad


if torch.distributed.is_available():
from torch.distributed import ReduceOp
from torch._C._distributed_c10d import _resolve_process_group
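
A standalone sketch (not part of the diff) checking the gradient rule implemented by setitem_aug_fwd/setitem_backward above against plain PyTorch autograd, assuming non-duplicate integer indices: the `a` gradient is the cotangent with zeros written at the indexed positions, and the `value` gradient is the indexed cotangent (summed to `value`'s shape when it was broadcast).

```python
import torch

a = torch.randn(4, requires_grad=True)
value = torch.randn(2, requires_grad=True)
idx = torch.tensor([1, 3])

# Out-of-place reference for a[idx] = value, kept differentiable via a clone.
out = a.clone()
out[idx] = value

g = torch.randn(4)  # cotangent
grad_a, grad_value = torch.autograd.grad(out, (a, value), g)

# The rule used by setitem_backward: zero the cotangent where values were written...
expected_a = g.clone()
expected_a[idx] = 0
# ...and gather the cotangent at the written positions for `value`.
expected_value = g[idx]

torch.testing.assert_close(grad_a, expected_a)
torch.testing.assert_close(grad_value, expected_value)
```
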
@@ -3048,6 +3057,8 @@ def vjp(func):
def _vjp(primals, cotangents, **kwargs):
flat_func, flat_args, spec = flatten_func(func, primals, kwargs)
trace = construct_trace()(flat_func, *flat_args)
# No need to insert prims.update_aliases, but we need insert_alias_updates for variable substitution
trace = insert_alias_updates(trace, [])
result, vjp_result = vjp_call(flat_args, cotangents, trace=trace)
# If the argument is a CPU scalar tensor, its gradient needs to be summed into a scalar tensor.
vjp_result = tuple(
19 changes: 14 additions & 5 deletions thunder/dynamo/utils.py
@@ -2,6 +2,7 @@
from collections.abc import Callable, Sequence
from contextlib import contextmanager
from enum import Enum, auto
import operator
from typing import TYPE_CHECKING
import dataclasses
import inspect
@@ -389,11 +390,19 @@ def _run_with_cache_info():
exception=str(e),
)

function_to_run = (
value_and_grad(thunder_symbol)
if requires_grad and (disable_torch_autograd is None or not disable_torch_autograd)
else thunder_symbol
)
if requires_grad and (disable_torch_autograd is None or not disable_torch_autograd):
if thunder_symbol is operator.setitem:
# operator.setitem returns None, which makes its backward pass empty
# We don't need to cover torch.Tensor.__setitem__ as dynamo uses operator.setitem instead
def setitem_and_return(a, key, value):
a[key] = value
return a

function_to_run = value_and_grad(setitem_and_return)
else:
function_to_run = value_and_grad(thunder_symbol)
else:
function_to_run = thunder_symbol
# We need to be under trace context to generate proxies.
with thunder.core.trace.tracectx(TraceCtx()):
try:
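
A small illustration (not from the diff) of why the wrapper above is needed: `operator.setitem` mutates its argument and returns `None`, which leaves `value_and_grad` with nothing to differentiate, so the wrapper performs the assignment and returns the mutated tensor instead.

```python
import operator
import torch

x = torch.zeros(3)
result = operator.setitem(x, 0, 1.0)  # equivalent to x[0] = 1.0
assert result is None                 # nothing for value_and_grad to differentiate
assert x[0] == 1.0                    # the mutation itself did happen

def setitem_and_return(a, key, value):
    # Same shape as the wrapper in the diff: do the assignment, then return `a`.
    a[key] = value
    return a

assert setitem_and_return(x, 1, 2.0) is x
```
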
22 changes: 9 additions & 13 deletions thunder/executors/torchex.py
@@ -1482,19 +1482,6 @@ def _take_along_axis_prim_transform(a: TensorProxy, /, index: TensorProxy, dim:
_register_implementation(ltorch.scatter_add, checker=_always_executable, execution_transform=_scatter_add_transform)
_register_implementation(ltorch.take_along_dim, take_along_dim, checker=_always_executable)

# out of place setitem helper


def _copy_with_setitem_impl(a, key, value):
c = a.clone()
c[key] = value
return c


copy_with_setitem_impl = ex.register_operator(
"copy_with_setitem_impl", meta=prims.copy_with_setitem_meta, fn=_copy_with_setitem_impl
)
_register_implementation(prims.copy_with_setitem, copy_with_setitem_impl, checker=_always_executable)

#
# Linear algebra operations
@@ -2376,6 +2363,15 @@ def _copy__impl(copy_from, copy_to, grad_enabled):
_register_implementation(prims.copy_, copy_, checker=_always_executable)


def _setitem_impl(a, key, value):
a[key] = value
return a


setitem = ex.register_operator("setitem", tags=(prims.OpTags.DONT_DCE,), like=ltorch.setitem_, fn=_setitem_impl)
_register_implementation(prims.setitem, setitem, checker=_always_executable)


def _shape_impl(t):
return t.shape

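
For contrast, a sketch (not part of the diff) of the behavioral difference between the removed helper and the new executor implementation: copy_with_setitem_impl left its input untouched, while _setitem_impl mutates the input and returns it, which is why the registered operator carries the DONT_DCE tag.

```python
import torch

def copy_with_setitem_impl(a, key, value):  # old, out-of-place behavior
    c = a.clone()
    c[key] = value
    return c

def setitem_impl(a, key, value):            # new, in-place behavior
    a[key] = value
    return a

a = torch.zeros(3)
out_of_place = copy_with_setitem_impl(a, 0, 1.0)
assert a[0] == 0.0 and out_of_place[0] == 1.0  # input untouched

in_place = setitem_impl(a, 0, 1.0)
assert a[0] == 1.0 and in_place is a           # input mutated and returned
```
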
44 changes: 44 additions & 0 deletions thunder/tests/opinfos.py
@@ -4473,6 +4473,50 @@ def make_nd_idx(dim_length: int, indices: int, ndim: int):
shape_ops.append(getitem_opinfo)


def setitem_sample_generator(op, device, dtype, requires_grad, **kwargs):
make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

def _make_setitem_sample(tensor, key):
indexed_shape = tensor[key].shape

# Tests for getitem are already slow, and doubling them is too time-consuming
# value = make_tensor(indexed_shape, device=device, dtype=dtype, requires_grad=requires_grad)
# yield SampleInput(tensor, key, value)

pre_broadcast_shape = tuple(random.choice((s, 1)) for s in indexed_shape)
pre_broadcast_value = make_tensor(pre_broadcast_shape, device=device, dtype=dtype, requires_grad=requires_grad)
return SampleInput(tensor, key, pre_broadcast_value)

for sample in getitem_sample_generator(op, device, dtype, requires_grad, **kwargs):
tensor, key = sample.args
yield _make_setitem_sample(tensor, key)

# Boolean mask indexing
boolean_mask_cases = [
((6,), (torch.tensor([True, False, True, False, True, False]),)),
((2, 3), (torch.tensor([[True, False, True], [False, True, False]]),)),
((2, 3, 4), ([False, True], [False, True, False], slice(None))),
((2, 3, 4), (torch.tensor([False, False]), slice(None))),
((2, 3, 4), (torch.tensor([True, False]), [1, 1], slice(None))),
((2, 3, 4), (1, torch.tensor([True, False, True]), slice(None))),
((2, 3), (torch.tensor([True, False]), None, [0, 2])),
((4, 2, 3), (Ellipsis, [False, True, False])),
]

for shape, key in boolean_mask_cases:
tensor = make(shape)
yield _make_setitem_sample(tensor, key)


setitem_opinfo = OpInfo(
operator.setitem,
sample_input_generator=setitem_sample_generator,
torch_reference=operator.setitem,
numpy_reference=operator.setitem,
)
shape_ops.append(setitem_opinfo)


def movedim_sample_generator(op, device, dtype, requires_grad, **kwargs):
make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

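
To illustrate the broadcasting case the sample generator exercises with `pre_broadcast_shape` (an example, not part of the diff): the assigned value only has to broadcast to the shape of `tensor[key]`, so any of its dimensions may be shrunk to 1.

```python
import torch

t = torch.zeros(2, 3, 4)
key = (slice(None), 1)     # t[key] has shape (2, 4)
value = torch.ones(2, 1)   # broadcasts to (2, 4) during the assignment
t[key] = value
assert torch.equal(t[:, 1, :], torch.ones(2, 4))
```
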
43 changes: 43 additions & 0 deletions thunder/tests/test_grad.py
@@ -42,6 +42,7 @@
"index_select",
# Finite difference approximation doesn't work for this function
"embedding",
"setitem",
"index_put",
"batch_norm",
"instance_norm",
@@ -689,6 +690,48 @@ def test_vjp_correctness_embedding_manual(op, device, dtype, executor, comp):
comp(actual_out, out)


@ops((get_opinfo("setitem"),), supported_dtypes=(dtypes.float64,))
def test_vjp_correctness_setitem_manual(op, device, dtype, executor, comp):
for sample in op.sample_inputs(device, dtype, requires_grad=True):

def torch_reference(tensor, idx, value):
cloned = tensor * 1
op.torch_reference(cloned, idx, value)
return cloned

def op_fn(tensor, idx, value):
cloned = tensor * 1
op.op(cloned, idx, value)
return cloned

tensor, key, value = sample.args
assert not sample.kwargs

tensor_ref = tensor.detach().clone().requires_grad_(True)
out = torch_reference(tensor_ref, key, value)
v = make_tensor_like(out)
expected = torch.autograd.grad(out, (tensor_ref, value), v)

flat_op, flat_args, spec = flatten_func(op_fn, (tensor, key, value), {})

t_key = key if isinstance(key, tuple) else (key,)
if any(isinstance(k, (torch.Tensor, Sequence)) and torch.tensor(k).dtype == torch.bool for k in t_key):
with pytest.raises(NotImplementedError):
vjp(flat_op)(flat_args, (v,))
continue

initial_trace = thunder.trace()(vjp(flat_op), flat_args, (v,))
jfn = executor.make_callable(initial_trace.python_callable(), disable_torch_autograd=True)
actual_out, actual_grad = jfn(flat_args, (v,))

# With advanced indexing, an element may be assigned multiple times and the assignment order is not guaranteed.
# comp(actual_out, out)

comp(tensor, tensor_ref)
comp(actual_grad[0], expected[0])
comp(actual_grad[-1], expected[1])


@ops((op for op in opinfos if op.name == "type_as"), supported_dtypes=(dtypes.float64,))
def test_vjp_correctness_type_as_manual(op, device, dtype, executor, comp):
for sample in op.sample_inputs(device, dtype, requires_grad=True):
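
A small illustration (not part of the diff) of the caveat behind the commented-out forward comparison in the test above: with duplicate advanced indices the same element is written more than once, and which write survives is not guaranteed, so only the gradients and the mutated input are compared.

```python
import torch

a = torch.zeros(3)
idx = torch.tensor([0, 0])            # index 0 is written twice
a[idx] = torch.tensor([1.0, 2.0])
# PyTorch does not guarantee the write order for duplicate indices,
# so a[0] may legitimately end up as either assigned value.
assert a[0].item() in (1.0, 2.0)
```
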
11 changes: 9 additions & 2 deletions thunder/torch/__init__.py
@@ -269,6 +269,13 @@ def _copy_(a, b, /):
return prims.copy_(b, a, grad_enabled=cd.is_grad_enabled if cd is not None else False)


def _clone_via_copy(t: TensorProxy) -> TensorProxy:
"""Produces a functional clone using an explicit copy instead of prims.clone."""
cd = get_compile_data()
buf = prims.empty(t.shape, device=t.device, dtype=t.dtype)
return prims.copy_(t, buf, grad_enabled=cd.is_grad_enabled if cd is not None else False)


@torchsymbol(torch.Tensor.copy_, is_method=True) # , tags=(prims.OpTags.IN_PLACE,))
def copy_(a, b, /):
return _copy_(a, b)
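
In plain PyTorch terms, _clone_via_copy above amounts to the following (a sketch that ignores Thunder's proxies and the grad-enabled plumbing):

```python
import torch

def clone_via_copy(t: torch.Tensor) -> torch.Tensor:
    # Allocate an uninitialized buffer with matching metadata, then copy into it,
    # mirroring the prims.empty + prims.copy_ pair in the Thunder helper.
    buf = torch.empty(t.shape, device=t.device, dtype=t.dtype)
    return buf.copy_(t)

x = torch.arange(4.0)
y = clone_via_copy(x)
assert torch.equal(x, y) and y.data_ptr() != x.data_ptr()
```
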
@@ -1223,12 +1230,12 @@ def flip(a: TensorLike, /, *dims: int) -> TensorLike:
# fake out of place variant
@torchsymbol(id="setitem")
def setitem(inp, idx, val):
return clang.copy_with_setitem(inp, idx, val)
raise NotImplementedError


@torchsymbol(torch.Tensor.__setitem__, id="setitem_", is_method=True, tags=(prims.OpTags.IN_PLACE,))
def setitem_(inp, idx, val):
return _copy_(inp, setitem(inp, idx, val))
return clang.setitem(inp, idx, val)


@torchsymbol(torch.Tensor.__getitem__, id="torch.Tensor.__getitem__", method_name="getitem")