Skip to content

Commit 533203f

Browse files
Natalia Gimelsheinpytorchmergebot
authored andcommitted
_to_copy decomp (pytorch#84108)
Per title Pull Request resolved: pytorch#84108 Approved by: https://github.com/Chillee
1 parent 9fc02f6 commit 533203f

File tree

1 file changed

+30
-0
lines changed

1 file changed

+30
-0
lines changed

torch/_decomp/decompositions.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1000,6 +1000,36 @@ def _fused_dropout_decomposition(input, p, generator=None):
10001000
return (res, mask)
10011001

10021002

1003+
@register_decomposition(aten._to_copy)
1004+
def _to_copy(
1005+
x: Tensor,
1006+
*,
1007+
dtype: Optional[torch.dtype] = None,
1008+
layout=None,
1009+
device: Optional[torch.device] = None,
1010+
pin_memory: bool = False,
1011+
non_blocking: bool = False,
1012+
memory_format: Optional[torch.memory_format] = None,
1013+
):
1014+
assert not layout or layout == torch.strided, "TODO"
1015+
assert not pin_memory, "TODO"
1016+
assert device is not None or dtype is not None or memory_format is not None
1017+
dtype_converted = False
1018+
if device is not None and device != x.get_device():
1019+
# avoid conversions on cpu
1020+
if dtype is not None and device.type == "cpu":
1021+
x = torch._prims.convert_element_type(x, dtype)
1022+
dtype_converted = True
1023+
x = torch._prims.device_put(x, device)
1024+
if dtype is not None and not dtype_converted:
1025+
x = torch._prims.convert_element_type(x, dtype)
1026+
if memory_format is not None: # no ref/prim for memory format
1027+
out = torch.empty_like(x, memory_format=memory_format)
1028+
out.copy_(x)
1029+
return out # type: ignore[call-overload]
1030+
return x
1031+
1032+
10031033
@register_decomposition(aten.xlogy.Tensor)
10041034
@pw_cast_for_int_to_real
10051035
def xlogy(self: Tensor, other: Tensor) -> Tensor:

0 commit comments

Comments
 (0)