Add decompositions for zero_, fill_, new_full, new_zeros, new_ones (pytorch#82332)

This makes symbolic tracing tests for logsigmoid and xlogy start working again.

While I'm at it, add pin_memory and layout kwargs to empty; they don't actually do anything and raise an error if they are non-standard.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: pytorch#82332
Approved by: https://github.com/eellison
ezyang authored and pytorchmergebot committed Jul 28, 2022
1 parent 4a000ff commit 98b9dfa
Showing 4 changed files with 117 additions and 9 deletions.
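All five decompositions follow one pattern: build the result with new_empty, then fill it in place, so tracing only ever sees primitives that already have meta support. A minimal sketch of the idea (illustrative only; the actual refs, shown in the diff below, route through prims.fill and prims.copy_to, and new_zeros_sketch is a hypothetical helper name):

import torch

# Sketch of the decomposition pattern used by the new _refs: new_zeros is
# new_empty followed by an in-place zero_.
def new_zeros_sketch(a: torch.Tensor, size) -> torch.Tensor:
    r = a.new_empty(size)  # inherits dtype/device from `a` unless overridden
    r.zero_()              # in-place fill; itself decomposed as fill + copy_to
    return r

a = torch.randn(2, 3)
assert torch.equal(new_zeros_sketch(a, (4,)), a.new_zeros((4,)))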
5 changes: 0 additions & 5 deletions test/test_proxy_tensor.py
@@ -646,9 +646,6 @@ def f(a, b):
     xfail('nanmean', ''), # The underlying op of 'aten.stride' has no overload name '_schema'
     xfail('narrow', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
     xfail('native_layer_norm', ''), # Unexpected type <class 'torch.SymbolicIntNode'> when computing elementwise type promot...
-    xfail('new_full', ''),
-    xfail('new_ones', ''),
-    xfail('new_zeros', ''),
     xfail('nn.functional.adaptive_avg_pool1d', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
     xfail('nn.functional.adaptive_avg_pool2d', ''), # argument 'size' must be tuple of ints, but found element o...
     xfail('nn.functional.adaptive_avg_pool3d', ''), # aten._adaptive_avg_pool3d.default - couldn't find symbolic meta func...
@@ -697,7 +694,6 @@ def f(a, b):
     xfail('nn.functional.layer_norm', ''), # Unexpected type <class 'torch.SymbolicIntNode'> when computing elementwise type...
     xfail('nn.functional.linear', ''), # aten.mv.default - couldn't find symbolic meta function/decomposition
     xfail('nn.functional.local_response_norm', ''), # Tensors of type TensorImpl do not have numel
-    xfail('nn.functional.logsigmoid', ''),
     xfail('nn.functional.margin_ranking_loss', ''), # The underlying op of 'aten.stride' has no overload name '_schema'
     xfail('nn.functional.max_pool2d', ''), # aten.max_pool2d_with_indices.default - couldn't find symbolic meta function/d...
     xfail('nn.functional.max_pool3d', ''), # aten.max_pool3d_with_indices.default - couldn't find symbolic meta function/d...
@@ -831,7 +827,6 @@ def f(a, b):
     xfail('view', ''), # Tensors of type TensorImpl do not have numel
     xfail('vsplit', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
     xfail('where', ''), # expected predicate to be bool, got torch.float32
-    xfail('xlogy', ''),
     xfail('zero_', ''), # aten.clone.default - couldn't find symbolic meta function/decomposition
     xfail('zeros_like', ''), # aten.zeros_like.default - couldn't find symbolic meta function/decomposition
 }
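With these xfails removed, symbolic tracing of logsigmoid and xlogy is expected to go through. A hedged smoke test, assuming make_fx and the tracing_mode="symbolic" flag that test_proxy_tensor.py exercises:

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    return torch.nn.functional.logsigmoid(x)

# Previously this xfail'd; with the new decompositions the tracer only
# encounters ops that have symbolic meta functions.
gm = make_fx(f, tracing_mode="symbolic")(torch.randn(3))
print(gm.graph)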
6 changes: 6 additions & 0 deletions torch/_prims/context.py
@@ -34,6 +34,12 @@ def torch_to_refs_map():
         torch.Tensor.__and__: torch._refs.bitwise_and,
         torch.Tensor.__or__: torch._refs.bitwise_or,
         torch.Tensor.__eq__: torch._refs.eq,
+        torch.Tensor.new_empty: torch._refs.new_empty,
+        torch.Tensor.new_full: torch._refs.new_full,
+        torch.Tensor.new_zeros: torch._refs.new_zeros,
+        torch.Tensor.new_ones: torch._refs.new_ones,
+        torch.Tensor.fill_: torch._refs.fill_,
+        torch.Tensor.zero_: torch._refs.zero_,
         # TODO: Should these methods be mapped some other way?
         torch.Tensor.copy_: torch._prims.copy_to,
         torch.Tensor.resize: torch._prims.resize,
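A quick way to check the effect of the six new map entries, assuming torch_to_refs_map remains importable from torch._prims.context as in this diff:

import torch
import torch._refs
from torch._prims.context import torch_to_refs_map

# Each newly mapped Tensor method should resolve to its _refs counterpart
# when a TorchRefs context consults this map.
mapping = torch_to_refs_map()
assert mapping[torch.Tensor.new_zeros] is torch._refs.new_zeros
assert mapping[torch.Tensor.fill_] is torch._refs.fill_
assert mapping[torch.Tensor.zero_] is torch._refs.zero_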
100 changes: 96 additions & 4 deletions torch/_refs/__init__.py
@@ -488,6 +488,18 @@ def fill(a: TensorLikeType, value: NumberType) -> TensorLikeType:
     return prims.fill(a, value)


+def fill_(a: TensorLikeType, value: NumberType) -> TensorLikeType:
+    r = prims.fill(a, value)
+    prims.copy_to(a, r)
+    return a
+
+
+def zero_(a: TensorLikeType) -> TensorLikeType:
+    r = prims.fill(a, 0)
+    prims.copy_to(a, r)
+    return a
+
+
 @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
 def floor(a):
     return prims.floor(a)
@@ -2949,7 +2961,9 @@ def empty(
     *shape,
     dtype: Optional[torch.dtype] = None,
     device: Optional[torch.device] = None,
+    layout: Optional[torch.layout] = None,
     requires_grad: bool = False,
+    pin_memory: bool = False,
     memory_format: torch.memory_format = torch.contiguous_format,
 ) -> TensorLikeType:
     check(
@@ -2971,7 +2985,13 @@ def empty(
         strides = utils.make_channels_last_2d_strides_for(shape)

     return torch.empty_strided(
-        shape, strides, dtype=dtype, device=device, requires_grad=requires_grad
+        shape,
+        strides,
+        dtype=dtype,
+        layout=layout,
+        device=device,
+        pin_memory=pin_memory,
+        requires_grad=requires_grad,
     )

@@ -2998,13 +3018,66 @@ def new_empty(
     )


-# TODO: missing kwargs (e.g. layout)
+@register_decomposition(torch.ops.aten.new_zeros)
+def new_zeros(
+    a: TensorLikeType,
+    size: ShapeType,
+    *,
+    dtype: Optional[torch.dtype] = None,
+    layout: Optional[torch.layout] = None,
+    device: Optional[torch.device] = None,
+    pin_memory: bool = False,
+) -> TensorLikeType:
+    r = a.new_empty(
+        size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
+    )
+    r.zero_()
+    return r
+
+
+@register_decomposition(torch.ops.aten.new_ones)
+def new_ones(
+    a: TensorLikeType,
+    size: ShapeType,
+    *,
+    dtype: Optional[torch.dtype] = None,
+    layout: Optional[torch.layout] = None,
+    device: Optional[torch.device] = None,
+    pin_memory: bool = False,
+) -> TensorLikeType:
+    r = a.new_empty(
+        size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
+    )
+    r.fill_(1)
+    return r
+
+
+@register_decomposition(torch.ops.aten.new_full)
+def new_full(
+    a: TensorLikeType,
+    size: ShapeType,
+    fill_value: NumberType,
+    *,
+    dtype: Optional[torch.dtype] = None,
+    layout: Optional[torch.layout] = None,
+    device: Optional[torch.device] = None,
+    pin_memory: bool = False,
+) -> TensorLikeType:
+    r = a.new_empty(
+        size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
+    )
+    r.fill_(fill_value)  # type: ignore[arg-type]
+    return r
+
+
 def empty_like(
     a: TensorLikeType,
     *,
     dtype: Optional[torch.dtype] = None,
     device: Optional[torch.device] = None,
+    layout: Optional[torch.layout] = None,
     requires_grad: bool = False,
+    pin_memory: bool = False,
     memory_format: torch.memory_format = torch.preserve_format,
 ) -> TensorLikeType:

@@ -3017,15 +3090,23 @@ def empty_like(
         return torch.empty(
             a.shape,
             dtype=dtype,
+            layout=layout,
             device=device,
             requires_grad=requires_grad,
+            pin_memory=pin_memory,
             memory_format=memory_format,
         )

     # memory_format == torch.preserve_format
     strides = utils.compute_elementwise_output_strides(a)
     return torch.empty_strided(
-        a.shape, strides, dtype=dtype, device=device, requires_grad=requires_grad
+        a.shape,
+        strides,
+        dtype=dtype,
+        layout=layout,
+        device=device,
+        pin_memory=pin_memory,
+        requires_grad=requires_grad,
     )

@@ -3226,15 +3307,26 @@ def empty_strided(
     *,
     dtype: Optional[torch.dtype] = None,
     device: Optional[torch.device] = None,
+    layout: Optional[torch.layout] = None,
     requires_grad: bool = False,
+    pin_memory: bool = False,
 ) -> TensorLikeType:

+    if pin_memory:
+        raise NotImplementedError("PrimTorch doesn't support pinned memory")
+    if layout is not None and layout is not torch.strided:
+        raise NotImplementedError(f"PrimTorch doesn't support layout={layout}")
+
     shape = utils.extract_shape_from_varargs(shape)
     dtype = torch.get_default_dtype() if dtype is None else dtype
     device = torch.device("cpu") if device is None else device

     return prims.empty_strided(
-        shape, strides, dtype=dtype, device=device, requires_grad=requires_grad
+        shape,
+        strides,
+        dtype=dtype,
+        device=device,
+        requires_grad=requires_grad,
     )

Expand Down
15 changes: 15 additions & 0 deletions torch/testing/_internal/common_methods_invocations.py
@@ -21976,6 +21976,21 @@ def __init__(
             DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor'),
         ),
     ),
+    PythonRefInfo(
+        "_refs.new_full",
+        torch_opinfo_name="new_full",
+        supports_nvfuser=False,
+    ),
+    PythonRefInfo(
+        "_refs.new_ones",
+        torch_opinfo_name="new_ones",
+        supports_nvfuser=False,
+    ),
+    PythonRefInfo(
+        "_refs.new_zeros",
+        torch_opinfo_name="new_zeros",
+        supports_nvfuser=False,
+    ),
     #
     # Conditional Reference OpInfos
     #
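These PythonRefInfo entries opt the new refs into the stock _refs consistency tests (with nvfuser execution disabled via supports_nvfuser=False). The core check amounts to comparing each ref against its eager counterpart; a hand-rolled sketch:

import torch
import torch.testing
import torch._refs

a = torch.arange(6, dtype=torch.float32).reshape(2, 3)

# Each ref should match the eager Tensor method it decomposes.
torch.testing.assert_close(torch._refs.new_full(a, (2, 2), 7.0), a.new_full((2, 2), 7.0))
torch.testing.assert_close(torch._refs.new_ones(a, (2, 2)), a.new_ones((2, 2)))
torch.testing.assert_close(torch._refs.new_zeros(a, (2, 2)), a.new_zeros((2, 2)))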
