@@ -1000,6 +1000,36 @@ def _fused_dropout_decomposition(input, p, generator=None):
     return (res, mask)


+@register_decomposition(aten._to_copy)
+def _to_copy(
+    x: Tensor,
+    *,
+    dtype: Optional[torch.dtype] = None,
+    layout=None,
+    device: Optional[torch.device] = None,
+    pin_memory: bool = False,
+    non_blocking: bool = False,
+    memory_format: Optional[torch.memory_format] = None,
+):
+    assert not layout or layout == torch.strided, "TODO"
+    assert not pin_memory, "TODO"
+    assert device is not None or dtype is not None or memory_format is not None
+    dtype_converted = False
+    if device is not None and device != x.get_device():
+        # avoid conversions on cpu
+        if dtype is not None and device.type == "cpu":
+            x = torch._prims.convert_element_type(x, dtype)
+            dtype_converted = True
+        x = torch._prims.device_put(x, device)
+    if dtype is not None and not dtype_converted:
+        x = torch._prims.convert_element_type(x, dtype)
+    if memory_format is not None:  # no ref/prim for memory format
+        out = torch.empty_like(x, memory_format=memory_format)
+        out.copy_(x)
+        return out  # type: ignore[call-overload]
+    return x
+
+
 @register_decomposition(aten.xlogy.Tensor)
 @pw_cast_for_int_to_real
 def xlogy(self: Tensor, other: Tensor) -> Tensor:
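As a quick sanity check, a minimal sketch of exercising the new decomposition directly, assuming it lands in torch/_decomp/decompositions.py (the import path and the comparison against torch.ops.aten._to_copy are assumptions, not part of this diff):

# Minimal sketch; assumed import path, not part of this commit.
import torch
from torch._decomp.decompositions import _to_copy

# dtype-only path: reduces to prims.convert_element_type
x = torch.randn(4, 4)
ref = torch.ops.aten._to_copy(x, dtype=torch.float64)
out = _to_copy(x, dtype=torch.float64)
assert out.dtype == torch.float64 and torch.equal(out, ref)

# memory_format path: falls back to empty_like + copy_, since there is
# no ref/prim for memory format yet
y = torch.randn(2, 3, 4, 5)
out2 = _to_copy(y, memory_format=torch.channels_last)
assert out2.is_contiguous(memory_format=torch.channels_last)
assert torch.equal(out2, y)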