@@ -5,31 +5,24 @@
 # can be added in the future for the corresponding higher-level torch/aten
 # functions.
 
-from typing import Any, Dict, Optional
+from typing import Any, Dict
 
 import torch
 
 from torch._prims_common import (
     DimsSequenceType,
-    ELEMENTWISE_TYPE_PROMOTION_KIND,
     getnvFuserDtype,
     ShapeType,
     TensorLikeType,
 )
 
-from torch._prims_common.wrappers import (
-    backwards_not_supported,
-    elementwise_type_promotion_wrapper,
-)
+from torch._prims_common.wrappers import backwards_not_supported
 
 nvprim_namespace = "nvprims"
 nvprim = torch.library.Library(nvprim_namespace, "DEF")
 nvprim_impl = torch.library.Library(
     nvprim_namespace, "IMPL", "CompositeExplicitAutograd"
 )
-nvprim_implicit_impl = torch.library.Library(
-    nvprim_namespace, "IMPL", "CompositeImplicitAutograd"
-)
 nvprim_autograd_impl = torch.library.Library(nvprim_namespace, "IMPL", "Autograd")
 nvprim_meta_impl = torch.library.Library(nvprim_namespace, "IMPL", "Meta")
 
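For context on the handles kept above: a `torch.library.Library` created with "DEF" declares op schemas, while the "IMPL" handles attach an implementation per dispatch key. A minimal sketch of that pattern, under an assumed throwaway namespace (`nvprims_example` and `add_one` are illustrative, not part of this diff):

    import torch

    # Hypothetical namespace for illustration; the diff registers into "nvprims".
    example_def = torch.library.Library("nvprims_example", "DEF")
    example_def.define("add_one(Tensor inp) -> Tensor")

    # Implementations are attached per dispatch key, mirroring the
    # CompositeExplicitAutograd and Meta handles created in the diff above.
    example_impl = torch.library.Library(
        "nvprims_example", "IMPL", "CompositeExplicitAutograd"
    )

    def _add_one(inp):
        # Eager implementation used for real (non-meta) tensors.
        return inp + 1

    example_impl.impl("add_one", _add_one)

    # The op then becomes reachable through the torch.ops namespace:
    # torch.ops.nvprims_example.add_one(torch.tensor([1.0]))  # -> tensor([2.])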
@@ -241,23 +234,6 @@ def _var_nvfuser(
     return fd.ops.var(a, dims, correction, keep_dims)
 
 
-def _var_mean_nvfuser(
-    fd: Any,
-    a: TensorLikeType,
-    dims: DimsSequenceType,
-    unbiased: Optional[bool] = None,
-    keepdim: bool = False,
-    *,
-    correction: int,
-):
-    # Unbiased arg shouldn't be set when this function is called
-    assert unbiased is None
-    # Ignore keepdim arg, because currently it's automatically converted into nvfuser's symbolic scalar
-    # keepdim is handled by the reference implementation
-    keepdim = False
-    return fd.ops.var_mean(a, dims, correction, keepdim)
-
-
 def _amax_nvfuser(
     fd: Any,
     a: TensorLikeType,
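The removed `_var_mean_nvfuser` lowering asserted that the deprecated `unbiased` flag had already been folded into `correction` before reaching nvFuser. For reference, the two eager-mode spellings relate as sketched below, using only public torch ops (the equivalence, not this exact snippet, is what the removed code relied on):

    import torch

    a = torch.randn(4, 5)

    # unbiased=True corresponds to correction=1 (Bessel's correction), ...
    var_u, mean_u = torch.var_mean(a, dim=1, unbiased=True)
    var_c, mean_c = torch.var_mean(a, dim=1, correction=1)
    assert torch.allclose(var_u, var_c) and torch.allclose(mean_u, mean_c)

    # ... and unbiased=False to correction=0 (population variance).
    var_p, _ = torch.var_mean(a, dim=1, unbiased=False)
    var_c0, _ = torch.var_mean(a, dim=1, correction=0)
    assert torch.allclose(var_p, var_c0)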
@@ -280,112 +256,12 @@ def _amin_nvfuser(
 _nvfuser_impls["convert_element_type"] = _convert_element_type_nvfuser
 _nvfuser_impls["sum"] = _sum_nvfuser
 _nvfuser_impls["var"] = _var_nvfuser
-_nvfuser_impls["var_mean"] = _var_mean_nvfuser
 _nvfuser_impls["amax"] = _amax_nvfuser
 _nvfuser_impls["amin"] = _amin_nvfuser
 
 
-def register_var_mean():
-    """This function is used to register the var_mean function in torch.ops.nvprims module."""
-    name = "var_mean.main"
-
-    # This overload must be default for correct dispatching of var_mean(Tensor, bool)
-    nvprim.define("var_mean(Tensor inp, bool unbiased) -> (Tensor, Tensor)")
-
-    # This signature tries to combine several overloads of the torch.var_mean function into one overload.
-    nvprim.define(
-        f"{name}(Tensor inp, int[1]? dim=None, bool? unbiased=None, bool keepdim=False, *, int? correction=None)"
-        + " -> (Tensor, Tensor)"
-    )
-
-    # This function is used for device="meta" Tensors.
-    def _meta_var_mean(inp, dim=None, unbiased=None, keepdim=False, *, correction=None):
-        if torch._prims_common.is_complex_dtype(inp.dtype):
-            output_dtype = torch._prims_common.corresponding_real_dtype(inp.dtype)
-        else:
-            output_dtype = inp.dtype
-        var = torch._prims._reduction_meta(inp, dim, output_dtype=output_dtype)
-        mean = torch._prims._reduction_meta(inp, dim, output_dtype=inp.dtype)
-        if keepdim:
-            output_shape = [
-                inp.shape[i] if i not in dim else 1 for i in range(inp.ndim)
-            ]
-            broadcast_dims = [i for i in range(inp.ndim) if i not in dim]
-            var = torch.ops.nvprims.broadcast_in_dim(var, output_shape, broadcast_dims)
-            mean = torch.ops.nvprims.broadcast_in_dim(
-                mean, output_shape, broadcast_dims
-            )
-        return (var, mean)
-
-    # This function is used under _AutoDispatchBelowAutograd context
-    def _prim_impl(inp, dim=None, unbiased=None, keepdim=False, *, correction=None):
-        correction = torch._prims_common.set_correction(unbiased, correction)
-        return torch.var_mean(inp, dim, correction=correction, keepdim=keepdim)
-
-    nvprim_impl.impl(name, _prim_impl)
-    nvprim_meta_impl.impl(name, _meta_var_mean)
-
-    prim_packet = torch.ops.nvprims.var_mean
-    prim = prim_packet.main
-
-    def _unbiased_overload_impl(inp, unbiased):
-        return prim(inp, dim=None, unbiased=unbiased)
-
-    nvprim_implicit_impl.impl("var_mean", _unbiased_overload_impl)
-
-    @elementwise_type_promotion_wrapper(
-        type_promoting_args=("a",),
-        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT,
-    )
-    def _var_mean_ref(a, dim=None, unbiased=None, keepdim=False, *, correction=None):
-        correction = torch._prims_common.set_correction(unbiased, correction)
-        # reduces over all dimensions if dim=() is passed
-        if dim == () or dim == []:
-            dim = None
-        dim = torch._prims_common.reduction_dims(a.shape, dim)
-
-        # For complex tensors eager computes the variance as the sum of variances of
-        # the real and imaginary parts
-        # TODO: Creating a complex tensor from real and imaginary parts is not supported
-        if torch._prims_common.is_complex_dtype(a.dtype):
-            raise NotImplementedError("Complex tensors are not supported")
-
-        var_mean = prim(a, dim, correction=correction)
-
-        if keepdim:
-            output_shape = [a.shape[i] if i not in dim else 1 for i in range(a.ndim)]
-            broadcast_dims = [i for i in range(a.ndim) if i not in dim]
-            var, mean = var_mean
-            var = torch.ops.nvprims.broadcast_in_dim(var, output_shape, broadcast_dims)
-            mean = torch.ops.nvprims.broadcast_in_dim(
-                mean, output_shape, broadcast_dims
-            )
-            var_mean = (var, mean)
-        return var_mean
-
-    def _var_mean_autograd(
-        a, dim=None, unbiased=None, keepdim=False, *, correction=None
-    ):
-        # This wrapper is needed to convert prims calls inside
-        # elementwise_type_promotion_wrapper to nvprims calls
-        from torch._prims.context import NvfuserPrimsMode
-
-        with NvfuserPrimsMode():
-            return backwards_not_supported(_var_mean_ref)(
-                a, dim, unbiased, keepdim, correction=correction
-            )
-
-    nvprim_autograd_impl.impl(name, _var_mean_autograd)
-
-    for p in (prim_packet, prim):
-        p.__doc__ = "Computes the variance and mean of x over the list of dimensions specified in the dim argument"
-        p.impl_nvfuser = _nvfuser_impls["var_mean"]
-        p.return_type = torch._prims_common.RETURN_TYPE.NEW  # type: ignore[attr-defined]
-
-
 def register_nvprims():
     """Registers all nvFuser primitives in the torch.ops.nvprims module."""
-    register_var_mean()
     for name in nvprim_names:
         main_prim = getattr(torch.ops.prims, name)
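Both removed pieces handled keepdim the same way: compute the reduction without keepdim, then re-expand var and mean over the reduced dimensions with nvprims.broadcast_in_dim. A sketch of the equivalent shape logic with public ops (reshape stands in for broadcast_in_dim, which here also just reinserts size-1 dimensions):

    import torch

    a = torch.randn(3, 4, 5)
    dims = (1,)

    var, mean = torch.var_mean(a, dim=dims, correction=1)  # shapes: (3, 5)

    # Reinsert the reduced dimensions as size 1, mirroring the output_shape /
    # broadcast_dims computation in the removed reference implementation.
    output_shape = [a.shape[i] if i not in dims else 1 for i in range(a.ndim)]
    var_keepdim = var.reshape(output_shape)
    mean_keepdim = mean.reshape(output_shape)

    expected, _ = torch.var_mean(a, dim=dims, correction=1, keepdim=True)
    assert var_keepdim.shape == expected.shape == torch.Size([3, 1, 5])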