@@ -1088,6 +1088,36 @@ def _fused_dropout_decomposition(input, p, generator=None):
     return (res, mask)


+@register_decomposition(aten._to_copy)
+def _to_copy(
+    x: Tensor,
+    *,
+    dtype: Optional[torch.dtype] = None,
+    layout=None,
+    device: Optional[torch.device] = None,
+    pin_memory: bool = False,
+    non_blocking: bool = False,
+    memory_format: Optional[torch.memory_format] = None,
+):
+    assert not layout or layout == torch.strided, "TODO"
+    assert not pin_memory, "TODO"
+    assert device is not None or dtype is not None or memory_format is not None
+    dtype_converted = False
+    if device is not None and device != x.get_device():
+        # avoid conversions on cpu
+        if dtype is not None and device.type == "cpu":
+            x = torch._prims.convert_element_type(x, dtype)
+            dtype_converted = True
+        x = torch._prims.device_put(x, device)
+    if dtype is not None and not dtype_converted:
+        x = torch._prims.convert_element_type(x, dtype)
+    if memory_format is not None:  # no ref/prim for memory format
+        out = torch.empty_like(x, memory_format=memory_format)
+        out.copy_(x)
+        return out  # type: ignore[call-overload]
+    return x
+
+
 @register_decomposition(aten.xlogy.Tensor)
 @pw_cast_for_int_to_real
 def xlogy(self: Tensor, other: Tensor) -> Tensor:
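
For illustration only, not part of the commit: a minimal sketch of what the new decomposition reduces to, assuming a torch build where torch._prims exposes convert_element_type as used in the patch. The input tensors are made up.

import torch
import torch._prims

# Made-up input; dtype-only call, so device and memory_format are None.
x = torch.arange(6, dtype=torch.int32).reshape(2, 3)

# Eager reference: Tensor.to(dtype) bottoms out in aten._to_copy.
ref = torch.ops.aten._to_copy(x, dtype=torch.float32)

# For this case the decomposition emits a single prim call.
out = torch._prims.convert_element_type(x, torch.float32)
torch.testing.assert_close(ref, out)

# The memory_format branch, mirrored by hand: empty_like + copy_,
# since there is no ref/prim for memory format.
y = torch.rand(1, 3, 4, 4)
fmt = torch.empty_like(y, memory_format=torch.channels_last)
fmt.copy_(y)
assert fmt.is_contiguous(memory_format=torch.channels_last)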