Skip to content

Commit b159a52

Browse files
Revert "Add nvprims.var_mean (pytorch#83508)"
This reverts commit 7e7694b. Reverted pytorch#83508 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally
1 parent 71cd3fa commit b159a52

File tree

15 files changed

+44
-381
lines changed

15 files changed

+44
-381
lines changed

test/test_decomp.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,6 @@ def op_assert_equal(test_case, op, test_dtype, orig, decomp, args, kwargs):
188188
1e-3,
189189
1e-3,
190190
),
191-
(torch.float64, torch.ops.aten.native_layer_norm.default): (1e-6, 1e-6),
192191
}
193192
if (test_dtype, op) in tol_table:
194193
rtol, atol = tol_table[(decomp.dtype, op)]

test/test_prims.py

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -353,31 +353,6 @@ def _wrapper(a):
353353
self.assertTrue(result.is_contiguous)
354354
self.assertEqual(_wrapper(a), result)
355355

356-
@onlyCUDA
357-
@skipCUDAIfRocm
358-
@dtypes(torch.float16, torch.float32)
359-
@parametrize("correction", [0, 1])
360-
@parametrize("keepdim", [True, False])
361-
def test_var_mean(self, device, dtype, correction, keepdim):
362-
from torch.fx.experimental.proxy_tensor import make_fx
363-
from torch._prims.context import TorchRefsNvfuserCapabilityMode
364-
365-
366-
def _wrapper(a):
367-
return torch.var_mean(a, [0, 1], correction=correction, keepdim=keepdim)
368-
369-
make_arg = partial(make_tensor, device=device, dtype=dtype)
370-
371-
with TorchRefsNvfuserCapabilityMode():
372-
gm = make_fx(_wrapper)(make_arg((5, 5)))
373-
374-
call_function_nodes = list(filter(lambda n: n.op == "call_function", gm.graph.nodes))
375-
includes_nvprims_var_mean = any(
376-
torch.ops.nvprims.var_mean.main == node.target
377-
for node in call_function_nodes
378-
)
379-
self.assertTrue(includes_nvprims_var_mean)
380-
381356
@onlyCUDA
382357
@skipCUDAIfRocm
383358
@dtypes(torch.float32)

torch/_prims/__init__.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
DimsType,
2121
Number,
2222
NumberType,
23-
RETURN_TYPE,
2423
ShapeType,
2524
StrideType,
2625
TensorLike,
@@ -281,6 +280,20 @@ def TensorMeta(
281280
#
282281
# Common datastructures and helpers
283282
#
283+
284+
# Describes the return type of the primitive:
285+
#
286+
# - NEW, a new tensor is created
287+
# - VIEW, a view of an input tensor is returned
288+
# - INPLACE, one or more input tensors is modified
289+
#
290+
# these descriptors are mututally exclusive and exhaustive.
291+
class RETURN_TYPE(Enum):
292+
NEW = (0,)
293+
VIEW = (1,)
294+
INPLACE = (2,)
295+
296+
284297
def _wrap_tensor_meta(f):
285298
def wrap(t):
286299
if (

torch/_prims/context.py

Lines changed: 6 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
import torch._refs.nn.functional
1313
import torch._refs.special
1414
import torch.overrides
15-
from torch._prims.nvfuser_executor import NvfuserPrimOperatorSupport
1615

1716
from torch._prims_common import torch_function_passthrough
1817
from torch.fx.experimental.proxy_tensor import get_isolated_graphmodule
@@ -205,42 +204,15 @@ def _is_func_unsupported_nvfuser(torch_function_mode, func, args, kwargs):
205204
):
206205
gm = get_isolated_graphmodule(func, args, kwargs)
207206

208-
supported_ops = NvfuserPrimOperatorSupport()
209207
call_function_nodes = filter(lambda n: n.op == "call_function", gm.graph.nodes)
210208
any_unsupported = any(
211-
not supported_ops.is_node_supported(None, node) for node in call_function_nodes
209+
not _is_node_supported_nvfuser(node) for node in call_function_nodes
212210
)
213211
return any_unsupported
214212

215213

216-
class TorchRefsNvfuserCapabilityMode(TorchRefsMode):
217-
def __init__(self):
218-
super().__init__(
219-
strict=False,
220-
should_fallback_fn=_is_func_unsupported_nvfuser,
221-
prims_mode_cls=NvfuserPrimsMode,
222-
)
223-
224-
def _is_var_mean(self, func):
225-
return "torch.var_mean" == torch.overrides.resolve_name(func) or (
226-
(
227-
isinstance(func, torch._ops.OpOverload)
228-
or isinstance(func, torch._ops.OpOverloadPacket)
229-
)
230-
and "aten.var_mean" in str(func)
231-
)
232-
233-
def __torch_function__(
234-
self,
235-
orig_func: Callable,
236-
types: Sequence,
237-
args: Sequence[Any] = (),
238-
kwargs: Dict = None,
239-
):
240-
if kwargs is None:
241-
kwargs = {}
242-
# First we intercept calls for nvfuser-specific prims bypassing generic torch._refs
243-
if self._is_var_mean(orig_func):
244-
return torch.ops.nvprims.var_mean(*args, **kwargs)
245-
# Then we use TorchRefsMode to interpret the rest
246-
return super().__torch_function__(orig_func, types, args, kwargs)
214+
TorchRefsNvfuserCapabilityMode = functools.partial(
215+
TorchRefsMode,
216+
should_fallback_fn=_is_func_unsupported_nvfuser,
217+
prims_mode_cls=NvfuserPrimsMode,
218+
)

torch/_prims/nvfuser_executor.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,6 @@ def make_nvfuser_fusion(gm: GraphModule, *nv_args_templates):
5757
# PROTOTYPE nvfuser executor
5858
# Everything in the graph must support nvfuser
5959
for node in gm.graph.nodes:
60-
if node.op == "call_function" and "getitem" in node.name:
61-
continue
6260
if (
6361
node.op == "call_function"
6462
and getattr(node.target, "impl_nvfuser", None) is None
@@ -79,10 +77,6 @@ def _to_nvfuser_constant(arg):
7977

8078
class FusionInterpreter(torch.fx.Interpreter):
8179
def call_function(self, target, args, kwargs):
82-
# This handles tuple unpacking
83-
if "getitem" in str(target):
84-
assert isinstance(args[0], tuple)
85-
return target(*args, **kwargs)
8680
args = tuple(map(_to_nvfuser_constant, args))
8781
target = target.impl_nvfuser
8882
args = (fd,) + args
@@ -138,7 +132,6 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
138132
return (
139133
node.op == "call_function"
140134
and getattr(node.target, "impl_nvfuser", None) is not None
141-
or "getitem" in node.name # getitem is a special case
142135
)
143136

144137

torch/_prims/nvfuser_prims.py

Lines changed: 2 additions & 126 deletions
Original file line numberDiff line numberDiff line change
@@ -5,31 +5,24 @@
55
# can be added in the future for the corresponding higher-level torch/aten
66
# functions.
77

8-
from typing import Any, Dict, Optional
8+
from typing import Any, Dict
99

1010
import torch
1111

1212
from torch._prims_common import (
1313
DimsSequenceType,
14-
ELEMENTWISE_TYPE_PROMOTION_KIND,
1514
getnvFuserDtype,
1615
ShapeType,
1716
TensorLikeType,
1817
)
1918

20-
from torch._prims_common.wrappers import (
21-
backwards_not_supported,
22-
elementwise_type_promotion_wrapper,
23-
)
19+
from torch._prims_common.wrappers import backwards_not_supported
2420

2521
nvprim_namespace = "nvprims"
2622
nvprim = torch.library.Library(nvprim_namespace, "DEF")
2723
nvprim_impl = torch.library.Library(
2824
nvprim_namespace, "IMPL", "CompositeExplicitAutograd"
2925
)
30-
nvprim_implicit_impl = torch.library.Library(
31-
nvprim_namespace, "IMPL", "CompositeImplicitAutograd"
32-
)
3326
nvprim_autograd_impl = torch.library.Library(nvprim_namespace, "IMPL", "Autograd")
3427
nvprim_meta_impl = torch.library.Library(nvprim_namespace, "IMPL", "Meta")
3528

@@ -241,23 +234,6 @@ def _var_nvfuser(
241234
return fd.ops.var(a, dims, correction, keep_dims)
242235

243236

244-
def _var_mean_nvfuser(
245-
fd: Any,
246-
a: TensorLikeType,
247-
dims: DimsSequenceType,
248-
unbiased: Optional[bool] = None,
249-
keepdim: bool = False,
250-
*,
251-
correction: int,
252-
):
253-
# Unbiased arg shouldn't be set when this function is called
254-
assert unbiased is None
255-
# Ignore keepdim arg, because currently it's automatically converted into nvfuser's symbolic scalar
256-
# keepdim is handled by the reference implementation
257-
keepdim = False
258-
return fd.ops.var_mean(a, dims, correction, keepdim)
259-
260-
261237
def _amax_nvfuser(
262238
fd: Any,
263239
a: TensorLikeType,
@@ -280,112 +256,12 @@ def _amin_nvfuser(
280256
_nvfuser_impls["convert_element_type"] = _convert_element_type_nvfuser
281257
_nvfuser_impls["sum"] = _sum_nvfuser
282258
_nvfuser_impls["var"] = _var_nvfuser
283-
_nvfuser_impls["var_mean"] = _var_mean_nvfuser
284259
_nvfuser_impls["amax"] = _amax_nvfuser
285260
_nvfuser_impls["amin"] = _amin_nvfuser
286261

287262

288-
def register_var_mean():
289-
"""This function is used to register the var_mean function in torch.ops.nvprims module."""
290-
name = "var_mean.main"
291-
292-
# This overload must be default for correct dispatching of var_mean(Tensor, bool)
293-
nvprim.define("var_mean(Tensor inp, bool unbiased) -> (Tensor, Tensor)")
294-
295-
# This signature tries to combine several overloads of the torch.var_mean function into one overload.
296-
nvprim.define(
297-
f"{name}(Tensor inp, int[1]? dim=None, bool? unbiased=None, bool keepdim=False, *, int? correction=None)"
298-
+ " -> (Tensor, Tensor)"
299-
)
300-
301-
# This function is used for device="meta" Tensors.
302-
def _meta_var_mean(inp, dim=None, unbiased=None, keepdim=False, *, correction=None):
303-
if torch._prims_common.is_complex_dtype(inp.dtype):
304-
output_dtype = torch._prims_common.corresponding_real_dtype(inp.dtype)
305-
else:
306-
output_dtype = inp.dtype
307-
var = torch._prims._reduction_meta(inp, dim, output_dtype=output_dtype)
308-
mean = torch._prims._reduction_meta(inp, dim, output_dtype=inp.dtype)
309-
if keepdim:
310-
output_shape = [
311-
inp.shape[i] if i not in dim else 1 for i in range(inp.ndim)
312-
]
313-
broadcast_dims = [i for i in range(inp.ndim) if i not in dim]
314-
var = torch.ops.nvprims.broadcast_in_dim(var, output_shape, broadcast_dims)
315-
mean = torch.ops.nvprims.broadcast_in_dim(
316-
mean, output_shape, broadcast_dims
317-
)
318-
return (var, mean)
319-
320-
# This function is used under _AutoDispatchBelowAutograd context
321-
def _prim_impl(inp, dim=None, unbiased=None, keepdim=False, *, correction=None):
322-
correction = torch._prims_common.set_correction(unbiased, correction)
323-
return torch.var_mean(inp, dim, correction=correction, keepdim=keepdim)
324-
325-
nvprim_impl.impl(name, _prim_impl)
326-
nvprim_meta_impl.impl(name, _meta_var_mean)
327-
328-
prim_packet = torch.ops.nvprims.var_mean
329-
prim = prim_packet.main
330-
331-
def _unbiased_overload_impl(inp, unbiased):
332-
return prim(inp, dim=None, unbiased=unbiased)
333-
334-
nvprim_implicit_impl.impl("var_mean", _unbiased_overload_impl)
335-
336-
@elementwise_type_promotion_wrapper(
337-
type_promoting_args=("a",),
338-
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT,
339-
)
340-
def _var_mean_ref(a, dim=None, unbiased=None, keepdim=False, *, correction=None):
341-
correction = torch._prims_common.set_correction(unbiased, correction)
342-
# reduces over all dimensions if dim=() is passed
343-
if dim == () or dim == []:
344-
dim = None
345-
dim = torch._prims_common.reduction_dims(a.shape, dim)
346-
347-
# For complex tensors eager computes the variance as the sum of variances of
348-
# the real and imaginary parts
349-
# TODO: Creating a complex tensor from real and imaginary parts is not supported
350-
if torch._prims_common.is_complex_dtype(a.dtype):
351-
raise NotImplementedError("Complex tensors are not supported")
352-
353-
var_mean = prim(a, dim, correction=correction)
354-
355-
if keepdim:
356-
output_shape = [a.shape[i] if i not in dim else 1 for i in range(a.ndim)]
357-
broadcast_dims = [i for i in range(a.ndim) if i not in dim]
358-
var, mean = var_mean
359-
var = torch.ops.nvprims.broadcast_in_dim(var, output_shape, broadcast_dims)
360-
mean = torch.ops.nvprims.broadcast_in_dim(
361-
mean, output_shape, broadcast_dims
362-
)
363-
var_mean = (var, mean)
364-
return var_mean
365-
366-
def _var_mean_autograd(
367-
a, dim=None, unbiased=None, keepdim=False, *, correction=None
368-
):
369-
# This wrapper is needed to convert prims calls inside
370-
# elementwise_type_promotion_wrapper to nvprims calls
371-
from torch._prims.context import NvfuserPrimsMode
372-
373-
with NvfuserPrimsMode():
374-
return backwards_not_supported(_var_mean_ref)(
375-
a, dim, unbiased, keepdim, correction=correction
376-
)
377-
378-
nvprim_autograd_impl.impl(name, _var_mean_autograd)
379-
380-
for p in (prim_packet, prim):
381-
p.__doc__ = "Computes the variance and mean of x over the list of dimensions specified in the dim argument"
382-
p.impl_nvfuser = _nvfuser_impls["var_mean"]
383-
p.return_type = torch._prims_common.RETURN_TYPE.NEW # type: ignore[attr-defined]
384-
385-
386263
def register_nvprims():
387264
"""Registers all nvFuser primitives in the torch.ops.nvprims module."""
388-
register_var_mean()
389265
for name in nvprim_names:
390266
main_prim = getattr(torch.ops.prims, name)
391267

torch/_prims_common/__init__.py

Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1033,19 +1033,6 @@ class REDUCTION_OUTPUT_TYPE_KIND(Enum):
10331033
ALWAYS_BOOL = (3,)
10341034

10351035

1036-
# Describes the return type of the primitive:
1037-
#
1038-
# - NEW, a new tensor is created
1039-
# - VIEW, a view of an input tensor is returned
1040-
# - INPLACE, one or more input tensors is modified
1041-
#
1042-
# these descriptors are mututally exclusive and exhaustive.
1043-
class RETURN_TYPE(Enum):
1044-
NEW = (0,)
1045-
VIEW = (1,)
1046-
INPLACE = (2,)
1047-
1048-
10491036
# TODO: document type promotion kinds
10501037
def elementwise_dtypes(
10511038
*_args,
@@ -1361,23 +1348,6 @@ def reduction_dims(shape: ShapeType, dims: Optional[Sequence]) -> Tuple[int, ...
13611348
return dims
13621349

13631350

1364-
def set_correction(
1365-
unbiased: Optional[bool] = None,
1366-
correction: Optional[int] = None,
1367-
):
1368-
if correction is not None and unbiased is not None:
1369-
raise RuntimeError("cannot specify both correction and unbiased arguments")
1370-
elif correction is None and unbiased is None:
1371-
correction = 1
1372-
elif correction is None and unbiased is not None:
1373-
correction = 0 if unbiased is False else 1
1374-
if not isinstance(correction, int):
1375-
raise ValueError("correction argument should be integer")
1376-
if correction < 0:
1377-
raise ValueError("correction argument should be non-negative")
1378-
return correction
1379-
1380-
13811351
def check_in_bounds_for_storage(
13821352
a: torch.TypedStorage, shape: ShapeType, strides: StrideType, storage_offset: int
13831353
):

torch/_prims_common/wrappers.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -118,11 +118,9 @@ def _fn(*args, **kwargs):
118118

119119
result = fn(**bound.arguments)
120120

121-
if isinstance(result, TensorLike):
122-
return _maybe_convert_to_dtype(result, result_dtype)
123-
if isinstance(result, Sequence):
124-
return tuple(_maybe_convert_to_dtype(x, result_dtype) for x in result)
125-
raise AssertionError(f"Unhandled result type: {type(result)}")
121+
# FIXME?: assumes result is a single tensor
122+
assert isinstance(result, TensorLike)
123+
return _maybe_convert_to_dtype(result, result_dtype)
126124

127125
_fn.__signature__ = sig # type: ignore[attr-defined]
128126
return _fn

0 commit comments

Comments
 (0)