[dtensor] handle the case where output of op is Optional[Tensor] (pytorch#90241)

Observed by @aazzolini: some ops have Optional[Tensor] returns and may return None
(e.g. native_layer_norm_backward). This is a mismatch between the C++ ATen op
signature and Python's None, so we need to handle it on the Python side.
Pull Request resolved: pytorch#90241
Approved by: https://github.com/aazzolini
wanchaol authored and pytorchmergebot committed Dec 6, 2022
1 parent eace084 commit 9e314bd
Showing 1 changed file with 6 additions and 1 deletion.
torch/distributed/_tensor/utils.py (6 additions, 1 deletion)
@@ -10,7 +10,7 @@
 ArgKwargsType = Union[Tuple[object, ...], Dict[str, object]]
 # ATen op schemas could have Tensor, Tuple[Tensor] and List[Tensor], so output type should
 # be the same set of possibilities.
-OutputSpecType = Optional[Union[DTensorSpec, Sequence[DTensorSpec]]]
+OutputSpecType = Optional[Union[DTensorSpec, Sequence[Optional[DTensorSpec]]]]


def unwrap_local_tensor(e: "dtensor.DTensor") -> torch.Tensor:
@@ -45,8 +45,13 @@ def wrap(res: object, spec: OutputSpecType) -> object:
         assert spec is not None and isinstance(
             spec, tuple
         ), f"output spec does not match with output! Expected tuple, got {spec}"
+
+        # NOTE: local results might return Optional Tensor from ATen op, so we need to
+        # handle that case and make sure we don't wrap None with DTensor.
+        # (i.e. native_layer_norm.backward)
         return tuple(
             dtensor.DTensor(e, s.mesh, s.placements, size=s.shape)
+            if e is not None and s is not None else None
             for e, s in zip(res, spec)
         )
     else:
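The patched wrap logic can be sketched in isolation. This is a minimal illustration, not the real DTensor code: constructing an actual dtensor.DTensor requires a device mesh and placements, so a hypothetical LocalWrapper class stands in for it, and wrap_tuple mirrors only the None-skipping behavior the commit adds.

```python
from typing import Optional, Sequence, Tuple


class LocalWrapper:
    """Hypothetical stand-in for dtensor.DTensor, for illustration only."""

    def __init__(self, local: object) -> None:
        self.local = local


def wrap_tuple(
    res: Sequence[Optional[object]],
    spec: Sequence[Optional[object]],
) -> Tuple[Optional[LocalWrapper], ...]:
    # Mirror the patched wrap(): an ATen op (e.g. native_layer_norm_backward)
    # may return None in some output slots, so only wrap an entry when both
    # the local result and its output spec are present; otherwise pass None
    # through unchanged.
    return tuple(
        LocalWrapper(e) if e is not None and s is not None else None
        for e, s in zip(res, spec)
    )


out = wrap_tuple(("grad_input", None, "grad_bias"), ("spec0", None, "spec2"))
print([type(o).__name__ if o is not None else None for o in out])
# → ['LocalWrapper', None, 'LocalWrapper']
```

Without the `e is not None and s is not None` guard, the generator would try to wrap None as a tensor, which is exactly the C++/Python signature mismatch the commit message describes.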
