
Commit 3aae6ff

IvanYashchuk authored and pytorchmergebot committed

Add nvprims.var_mean (pytorch#83508)
This PR adds an nvfuser-specific primitive, `nvprims.var_mean`. The interception of `torch.var_mean` and its translation to `torch.ops.nvprims.var_mean` are handled by the `TorchRefsNvfuserCapabilityMode` context manager. Some helper code is moved from `_prims/__init__.py` to `_prims_common`. Correctness is tested with OpInfo tests (see `PythonRefInfo("ops.nvprims.var_mean")`). The layer norm reference now uses `torch.var_mean` instead of `torch._refs.var_mean` so that the call can be intercepted.

Here's a simple performance comparison between this PR and master (on a 3080 Ti):

```py
import torch
from torch._prims.context import TorchRefsNvfuserCapabilityMode
from torch.fx.experimental.proxy_tensor import make_fx
from torch._prims.executor import execute

def func(a):
    return torch.native_layer_norm(a, (1024,), None, None, 1e-6)

a = torch.randn(10, 512, 1024, dtype=torch.float16, device="cuda")

with TorchRefsNvfuserCapabilityMode():
    gm = make_fx(func)(a)

for _ in range(10):
    execute(gm, a, executor="strictly_nvfuser")
```

Run with `PYTORCH_NVFUSER_DUMP=dump_eff_bandwidth python script.py`:

```py
# WITH THIS PR
# kernel1 run in 0.032768 ms, achieved: 641.25 GB/s
# kernel1 run in 0.033792 ms, achieved: 621.818 GB/s
# kernel1 run in 0.032768 ms, achieved: 641.25 GB/s
# kernel1 run in 0.032608 ms, achieved: 644.396 GB/s
# kernel1 run in 0.031744 ms, achieved: 661.935 GB/s
# kernel1 run in 0.031744 ms, achieved: 661.935 GB/s
# kernel1 run in 0.032768 ms, achieved: 641.25 GB/s
# kernel1 run in 0.03072 ms, achieved: 684 GB/s
# kernel1 run in 0.031744 ms, achieved: 661.935 GB/s
# kernel1 run in 0.031744 ms, achieved: 661.935 GB/s

# ON MASTER
# kernel1 run in 0.05632 ms, achieved: 373.091 GB/s
# kernel1 run in 0.044032 ms, achieved: 477.209 GB/s
# kernel1 run in 0.044032 ms, achieved: 477.209 GB/s
# kernel1 run in 0.044032 ms, achieved: 477.209 GB/s
# kernel1 run in 0.043808 ms, achieved: 479.649 GB/s
# kernel1 run in 0.043008 ms, achieved: 488.571 GB/s
# kernel1 run in 0.044032 ms, achieved: 477.209 GB/s
# kernel1 run in 0.043008 ms, achieved: 488.571 GB/s
# kernel1 run in 0.043008 ms, achieved: 488.571 GB/s
# kernel1 run in 0.043008 ms, achieved: 488.571 GB/s
```

So this PR gives roughly a 35% performance improvement with the nvfuser executor for this specific normalized shape.

This PR also fixes pytorch#83506 (see the change in `torch/csrc/jit/python/pybind_utils.cpp`).

Ref. pytorch#80187

Pull Request resolved: pytorch#83508
Approved by: https://github.com/ngimel
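As a quick illustration of the interception itself (a minimal sketch along the same lines as the new `test_var_mean` test below; not part of the diff):

```py
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._prims.context import TorchRefsNvfuserCapabilityMode

def fn(a):
    return torch.var_mean(a, [0, 1], correction=1)

# Tracing under the capability mode records nvprims.var_mean
# in the graph instead of the aten op.
with TorchRefsNvfuserCapabilityMode():
    gm = make_fx(fn)(torch.randn(5, 5, device="cuda"))

print(any(
    n.op == "call_function" and n.target == torch.ops.nvprims.var_mean.main
    for n in gm.graph.nodes
))  # expected: True
```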
1 parent 261be8e · commit 3aae6ff

File tree

15 files changed: +381, -44 lines changed


test/test_decomp.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -188,6 +188,7 @@ def op_assert_equal(test_case, op, test_dtype, orig, decomp, args, kwargs):
             1e-3,
             1e-3,
         ),
+        (torch.float64, torch.ops.aten.native_layer_norm.default): (1e-6, 1e-6),
     }
     if (test_dtype, op) in tol_table:
         rtol, atol = tol_table[(decomp.dtype, op)]
```

test/test_prims.py

Lines changed: 25 additions & 0 deletions
```diff
@@ -353,6 +353,31 @@ def _wrapper(a):
         self.assertTrue(result.is_contiguous)
         self.assertEqual(_wrapper(a), result)

+    @onlyCUDA
+    @skipCUDAIfRocm
+    @dtypes(torch.float16, torch.float32)
+    @parametrize("correction", [0, 1])
+    @parametrize("keepdim", [True, False])
+    def test_var_mean(self, device, dtype, correction, keepdim):
+        from torch.fx.experimental.proxy_tensor import make_fx
+        from torch._prims.context import TorchRefsNvfuserCapabilityMode
+
+
+        def _wrapper(a):
+            return torch.var_mean(a, [0, 1], correction=correction, keepdim=keepdim)
+
+        make_arg = partial(make_tensor, device=device, dtype=dtype)
+
+        with TorchRefsNvfuserCapabilityMode():
+            gm = make_fx(_wrapper)(make_arg((5, 5)))
+
+        call_function_nodes = list(filter(lambda n: n.op == "call_function", gm.graph.nodes))
+        includes_nvprims_var_mean = any(
+            torch.ops.nvprims.var_mean.main == node.target
+            for node in call_function_nodes
+        )
+        self.assertTrue(includes_nvprims_var_mean)
+
     @onlyCUDA
     @skipCUDAIfRocm
     @dtypes(torch.float32)
```

torch/_prims/__init__.py

Lines changed: 1 addition & 14 deletions
```diff
@@ -20,6 +20,7 @@
     DimsType,
     Number,
     NumberType,
+    RETURN_TYPE,
     ShapeType,
     StrideType,
     TensorLike,
@@ -280,20 +281,6 @@ def TensorMeta(
 #
 # Common datastructures and helpers
 #
-
-# Describes the return type of the primitive:
-#
-# - NEW, a new tensor is created
-# - VIEW, a view of an input tensor is returned
-# - INPLACE, one or more input tensors is modified
-#
-# these descriptors are mututally exclusive and exhaustive.
-class RETURN_TYPE(Enum):
-    NEW = (0,)
-    VIEW = (1,)
-    INPLACE = (2,)
-
-
 def _wrap_tensor_meta(f):
     def wrap(t):
         if (
```

torch/_prims/context.py

Lines changed: 34 additions & 6 deletions
```diff
@@ -12,6 +12,7 @@
 import torch._refs.nn.functional
 import torch._refs.special
 import torch.overrides
+from torch._prims.nvfuser_executor import NvfuserPrimOperatorSupport

 from torch._prims_common import torch_function_passthrough
 from torch.fx.experimental.proxy_tensor import get_isolated_graphmodule
@@ -204,15 +205,42 @@ def _is_func_unsupported_nvfuser(torch_function_mode, func, args, kwargs):
     ):
         gm = get_isolated_graphmodule(func, args, kwargs)

+        supported_ops = NvfuserPrimOperatorSupport()
         call_function_nodes = filter(lambda n: n.op == "call_function", gm.graph.nodes)
         any_unsupported = any(
-            not _is_node_supported_nvfuser(node) for node in call_function_nodes
+            not supported_ops.is_node_supported(None, node) for node in call_function_nodes
         )
         return any_unsupported


-TorchRefsNvfuserCapabilityMode = functools.partial(
-    TorchRefsMode,
-    should_fallback_fn=_is_func_unsupported_nvfuser,
-    prims_mode_cls=NvfuserPrimsMode,
-)
+class TorchRefsNvfuserCapabilityMode(TorchRefsMode):
+    def __init__(self):
+        super().__init__(
+            strict=False,
+            should_fallback_fn=_is_func_unsupported_nvfuser,
+            prims_mode_cls=NvfuserPrimsMode,
+        )
+
+    def _is_var_mean(self, func):
+        return "torch.var_mean" == torch.overrides.resolve_name(func) or (
+            (
+                isinstance(func, torch._ops.OpOverload)
+                or isinstance(func, torch._ops.OpOverloadPacket)
+            )
+            and "aten.var_mean" in str(func)
+        )
+
+    def __torch_function__(
+        self,
+        orig_func: Callable,
+        types: Sequence,
+        args: Sequence[Any] = (),
+        kwargs: Dict = None,
+    ):
+        if kwargs is None:
+            kwargs = {}
+        # First we intercept calls for nvfuser-specific prims bypassing generic torch._refs
+        if self._is_var_mean(orig_func):
+            return torch.ops.nvprims.var_mean(*args, **kwargs)
+        # Then we use TorchRefsMode to interpret the rest
+        return super().__torch_function__(orig_func, types, args, kwargs)
```
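Since `TorchRefsNvfuserCapabilityMode` is now a proper class rather than a `functools.partial`, existing call sites keep working unchanged, and `_is_var_mean` catches both spellings of the op before generic `torch._refs` interpretation. A small sketch of what the predicate matches (assuming the usual op objects):

```py
import torch
from torch._prims.context import TorchRefsNvfuserCapabilityMode

mode = TorchRefsNvfuserCapabilityMode()

# Matched via torch.overrides.resolve_name(...) == "torch.var_mean":
assert mode._is_var_mean(torch.var_mean)
# Matched as an OpOverloadPacket whose str(...) contains "aten.var_mean":
assert mode._is_var_mean(torch.ops.aten.var_mean)
# Everything else falls through to the normal TorchRefsMode handling:
assert not mode._is_var_mean(torch.add)
```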

torch/_prims/nvfuser_executor.py

Lines changed: 7 additions & 0 deletions
```diff
@@ -57,6 +57,8 @@ def make_nvfuser_fusion(gm: GraphModule, *nv_args_templates):
     # PROTOTYPE nvfuser executor
     # Everything in the graph must support nvfuser
     for node in gm.graph.nodes:
+        if node.op == "call_function" and "getitem" in node.name:
+            continue
         if (
             node.op == "call_function"
             and getattr(node.target, "impl_nvfuser", None) is None
@@ -77,6 +79,10 @@ def _to_nvfuser_constant(arg):

 class FusionInterpreter(torch.fx.Interpreter):
     def call_function(self, target, args, kwargs):
+        # This handles tuple unpacking
+        if "getitem" in str(target):
+            assert isinstance(args[0], tuple)
+            return target(*args, **kwargs)
         args = tuple(map(_to_nvfuser_constant, args))
         target = target.impl_nvfuser
         args = (fd,) + args
@@ -132,6 +138,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
         return (
             node.op == "call_function"
             and getattr(node.target, "impl_nvfuser", None) is not None
+            or "getitem" in node.name  # getitem is a special case
         )

```
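For context on the `getitem` special case: tuple-returning ops such as `var_mean` appear in FX graphs followed by `operator.getitem` nodes that unpack the result, and those nodes have no `impl_nvfuser` of their own. Roughly, a traced graph looks like this (an illustrative sketch, not exact printer output):

```py
# def fn(a):
#     return torch.var_mean(a, [0, 1])
#
# var_mean  = call_function  torch.ops.nvprims.var_mean.main  (a, [0, 1])
# getitem   = call_function  operator.getitem                 (var_mean, 0)
# getitem_1 = call_function  operator.getitem                 (var_mean, 1)
# output    = output                                          ((getitem, getitem_1),)
#
# FusionInterpreter.call_function evaluates the getitem directly on the
# tuple of nvfuser outputs instead of looking up target.impl_nvfuser.
```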

torch/_prims/nvfuser_prims.py

Lines changed: 126 additions & 2 deletions
```diff
@@ -5,24 +5,31 @@
 # can be added in the future for the corresponding higher-level torch/aten
 # functions.

-from typing import Any, Dict
+from typing import Any, Dict, Optional

 import torch

 from torch._prims_common import (
     DimsSequenceType,
+    ELEMENTWISE_TYPE_PROMOTION_KIND,
     getnvFuserDtype,
     ShapeType,
     TensorLikeType,
 )

-from torch._prims_common.wrappers import backwards_not_supported
+from torch._prims_common.wrappers import (
+    backwards_not_supported,
+    elementwise_type_promotion_wrapper,
+)

 nvprim_namespace = "nvprims"
 nvprim = torch.library.Library(nvprim_namespace, "DEF")
 nvprim_impl = torch.library.Library(
     nvprim_namespace, "IMPL", "CompositeExplicitAutograd"
 )
+nvprim_implicit_impl = torch.library.Library(
+    nvprim_namespace, "IMPL", "CompositeImplicitAutograd"
+)
 nvprim_autograd_impl = torch.library.Library(nvprim_namespace, "IMPL", "Autograd")
 nvprim_meta_impl = torch.library.Library(nvprim_namespace, "IMPL", "Meta")

@@ -234,6 +241,23 @@ def _var_nvfuser(
     return fd.ops.var(a, dims, correction, keep_dims)


+def _var_mean_nvfuser(
+    fd: Any,
+    a: TensorLikeType,
+    dims: DimsSequenceType,
+    unbiased: Optional[bool] = None,
+    keepdim: bool = False,
+    *,
+    correction: int,
+):
+    # Unbiased arg shouldn't be set when this function is called
+    assert unbiased is None
+    # Ignore keepdim arg, because currently it's automatically converted into nvfuser's symbolic scalar
+    # keepdim is handled by the reference implementation
+    keepdim = False
+    return fd.ops.var_mean(a, dims, correction, keepdim)
+
+
 def _amax_nvfuser(
     fd: Any,
     a: TensorLikeType,
@@ -256,12 +280,112 @@ def _amin_nvfuser(
 _nvfuser_impls["convert_element_type"] = _convert_element_type_nvfuser
 _nvfuser_impls["sum"] = _sum_nvfuser
 _nvfuser_impls["var"] = _var_nvfuser
+_nvfuser_impls["var_mean"] = _var_mean_nvfuser
 _nvfuser_impls["amax"] = _amax_nvfuser
 _nvfuser_impls["amin"] = _amin_nvfuser


+def register_var_mean():
+    """This function is used to register the var_mean function in torch.ops.nvprims module."""
+    name = "var_mean.main"
+
+    # This overload must be default for correct dispatching of var_mean(Tensor, bool)
+    nvprim.define("var_mean(Tensor inp, bool unbiased) -> (Tensor, Tensor)")
+
+    # This signature tries to combine several overloads of the torch.var_mean function into one overload.
+    nvprim.define(
+        f"{name}(Tensor inp, int[1]? dim=None, bool? unbiased=None, bool keepdim=False, *, int? correction=None)"
+        + " -> (Tensor, Tensor)"
+    )
+
+    # This function is used for device="meta" Tensors.
+    def _meta_var_mean(inp, dim=None, unbiased=None, keepdim=False, *, correction=None):
+        if torch._prims_common.is_complex_dtype(inp.dtype):
+            output_dtype = torch._prims_common.corresponding_real_dtype(inp.dtype)
+        else:
+            output_dtype = inp.dtype
+        var = torch._prims._reduction_meta(inp, dim, output_dtype=output_dtype)
+        mean = torch._prims._reduction_meta(inp, dim, output_dtype=inp.dtype)
+        if keepdim:
+            output_shape = [
+                inp.shape[i] if i not in dim else 1 for i in range(inp.ndim)
+            ]
+            broadcast_dims = [i for i in range(inp.ndim) if i not in dim]
+            var = torch.ops.nvprims.broadcast_in_dim(var, output_shape, broadcast_dims)
+            mean = torch.ops.nvprims.broadcast_in_dim(
+                mean, output_shape, broadcast_dims
+            )
+        return (var, mean)
+
+    # This function is used under _AutoDispatchBelowAutograd context
+    def _prim_impl(inp, dim=None, unbiased=None, keepdim=False, *, correction=None):
+        correction = torch._prims_common.set_correction(unbiased, correction)
+        return torch.var_mean(inp, dim, correction=correction, keepdim=keepdim)
+
+    nvprim_impl.impl(name, _prim_impl)
+    nvprim_meta_impl.impl(name, _meta_var_mean)
+
+    prim_packet = torch.ops.nvprims.var_mean
+    prim = prim_packet.main
+
+    def _unbiased_overload_impl(inp, unbiased):
+        return prim(inp, dim=None, unbiased=unbiased)
+
+    nvprim_implicit_impl.impl("var_mean", _unbiased_overload_impl)
+
+    @elementwise_type_promotion_wrapper(
+        type_promoting_args=("a",),
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT,
+    )
+    def _var_mean_ref(a, dim=None, unbiased=None, keepdim=False, *, correction=None):
+        correction = torch._prims_common.set_correction(unbiased, correction)
+        # reduces over all dimensions if dim=() is passed
+        if dim == () or dim == []:
+            dim = None
+        dim = torch._prims_common.reduction_dims(a.shape, dim)
+
+        # For complex tensors eager computes the variance as the sum of variances of
+        # the real and imaginary parts
+        # TODO: Creating a complex tensor from real and imaginary parts is not supported
+        if torch._prims_common.is_complex_dtype(a.dtype):
+            raise NotImplementedError("Complex tensors are not supported")
+
+        var_mean = prim(a, dim, correction=correction)
+
+        if keepdim:
+            output_shape = [a.shape[i] if i not in dim else 1 for i in range(a.ndim)]
+            broadcast_dims = [i for i in range(a.ndim) if i not in dim]
+            var, mean = var_mean
+            var = torch.ops.nvprims.broadcast_in_dim(var, output_shape, broadcast_dims)
+            mean = torch.ops.nvprims.broadcast_in_dim(
+                mean, output_shape, broadcast_dims
+            )
+            var_mean = (var, mean)
+        return var_mean
+
+    def _var_mean_autograd(
+        a, dim=None, unbiased=None, keepdim=False, *, correction=None
+    ):
+        # This wrapper is needed to convert prims calls inside
+        # elementwise_type_promotion_wrapper to nvprims calls
+        from torch._prims.context import NvfuserPrimsMode
+
+        with NvfuserPrimsMode():
+            return backwards_not_supported(_var_mean_ref)(
+                a, dim, unbiased, keepdim, correction=correction
+            )
+
+    nvprim_autograd_impl.impl(name, _var_mean_autograd)
+
+    for p in (prim_packet, prim):
+        p.__doc__ = "Computes the variance and mean of x over the list of dimensions specified in the dim argument"
+        p.impl_nvfuser = _nvfuser_impls["var_mean"]
+        p.return_type = torch._prims_common.RETURN_TYPE.NEW  # type: ignore[attr-defined]
+
+
 def register_nvprims():
     """Registers all nvFuser primitives in the torch.ops.nvprims module."""
+    register_var_mean()
     for name in nvprim_names:
         main_prim = getattr(torch.ops.prims, name)
```
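Once registered, the op behaves like any other torch op in eager mode; the `CompositeExplicitAutograd` impl simply defers to `torch.var_mean`. A sketch (assuming the nvprims registration has already run, e.g. on import of the prims machinery in this PR's tree):

```py
import torch
import torch._prims  # assumption: importing the prims machinery registers nvprims

a = torch.randn(5, 5)

# The combined "main" overload:
var, mean = torch.ops.nvprims.var_mean.main(a, [0], correction=1)

# The legacy (Tensor, bool) overload is rerouted through "main" by the
# CompositeImplicitAutograd impl registered above:
var2, mean2 = torch.ops.nvprims.var_mean(a, True)

torch.testing.assert_close((var, mean), torch.var_mean(a, [0], correction=1))
```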

torch/_prims_common/__init__.py

Lines changed: 30 additions & 0 deletions
```diff
@@ -1033,6 +1033,19 @@ class REDUCTION_OUTPUT_TYPE_KIND(Enum):
     ALWAYS_BOOL = (3,)


+# Describes the return type of the primitive:
+#
+# - NEW, a new tensor is created
+# - VIEW, a view of an input tensor is returned
+# - INPLACE, one or more input tensors is modified
+#
+# these descriptors are mututally exclusive and exhaustive.
+class RETURN_TYPE(Enum):
+    NEW = (0,)
+    VIEW = (1,)
+    INPLACE = (2,)
+
+
 # TODO: document type promotion kinds
 def elementwise_dtypes(
     *_args,
@@ -1348,6 +1361,23 @@ def reduction_dims(shape: ShapeType, dims: Optional[Sequence]) -> Tuple[int, ...]:
     return dims


+def set_correction(
+    unbiased: Optional[bool] = None,
+    correction: Optional[int] = None,
+):
+    if correction is not None and unbiased is not None:
+        raise RuntimeError("cannot specify both correction and unbiased arguments")
+    elif correction is None and unbiased is None:
+        correction = 1
+    elif correction is None and unbiased is not None:
+        correction = 0 if unbiased is False else 1
+    if not isinstance(correction, int):
+        raise ValueError("correction argument should be integer")
+    if correction < 0:
+        raise ValueError("correction argument should be non-negative")
+    return correction
+
+
 def check_in_bounds_for_storage(
     a: torch.TypedStorage, shape: ShapeType, strides: StrideType, storage_offset: int
 ):
```
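Spelled out, the mapping `set_correction` implements (values read directly off the branches above; a doctest-style sketch):

```py
from torch._prims_common import set_correction

set_correction()                              # -> 1 (default: Bessel's correction)
set_correction(unbiased=True)                 # -> 1
set_correction(unbiased=False)                # -> 0
set_correction(correction=2)                  # -> 2 (passed through unchanged)
set_correction(unbiased=True, correction=1)   # RuntimeError: cannot specify both
```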

torch/_prims_common/wrappers.py

Lines changed: 5 additions & 3 deletions
```diff
@@ -118,9 +118,11 @@ def _fn(*args, **kwargs):

         result = fn(**bound.arguments)

-        # FIXME?: assumes result is a single tensor
-        assert isinstance(result, TensorLike)
-        return _maybe_convert_to_dtype(result, result_dtype)
+        if isinstance(result, TensorLike):
+            return _maybe_convert_to_dtype(result, result_dtype)
+        if isinstance(result, Sequence):
+            return tuple(_maybe_convert_to_dtype(x, result_dtype) for x in result)
+        raise AssertionError(f"Unhandled result type: {type(result)}")

     _fn.__signature__ = sig  # type: ignore[attr-defined]
     return _fn
```
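This generalization is what lets `_var_mean_ref` in `nvfuser_prims.py` decorate a tuple-returning reference. A minimal sketch with a hypothetical wrapped function:

```py
import torch
from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND
from torch._prims_common.wrappers import elementwise_type_promotion_wrapper

@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT,
)
def my_var_mean(a):  # hypothetical tuple-returning reference
    return torch.var(a), torch.mean(a)

# Each element of the returned tuple is converted to the computed result
# dtype; previously the `assert isinstance(result, TensorLike)` would have fired.
var, mean = my_var_mean(torch.randn(4, 4))
```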
