File tree 1 file changed +6
-6
lines changed 1 file changed +6
-6
lines changed Original file line number Diff line number Diff line change @@ -670,18 +670,18 @@ def __new__(
670
670
quant_max : Optional [int ] = None ,
671
671
zero_point_domain : ZeroPointDomain = ZeroPointDomain .INT ,
672
672
dtype = None ,
673
- # TODO: remove args and kwargs
674
- * args ,
675
- ** kwargs
673
+ strides = None ,
676
674
):
675
+ kwargs = {}
677
676
kwargs ["device" ] = int_data .device
678
677
kwargs ["layout" ] = (
679
678
kwargs .get ("layout" ) if kwargs .get ("layout" , False ) else int_data .layout
680
679
)
681
680
if dtype is None :
682
681
dtype = scale .dtype
683
682
kwargs ["dtype" ] = dtype
684
- assert not kwargs .get ("requires_grad" , False )
683
+ if strides is not None :
684
+ kwargs ["strides" ] = strides
685
685
kwargs ["requires_grad" ] = False
686
686
return torch .Tensor ._make_wrapper_subclass (cls , shape , ** kwargs ) # type: ignore[attr-defined]
687
687
@@ -696,8 +696,7 @@ def __init__(
696
696
quant_max : Optional [int ] = None ,
697
697
zero_point_domain : ZeroPointDomain = ZeroPointDomain .INT ,
698
698
dtype = None ,
699
- * args ,
700
- ** kwargs
699
+ strides = None ,
701
700
):
702
701
self .int_data = int_data
703
702
self .scale = scale
@@ -912,6 +911,7 @@ def _apply_fn_to_data(self, fn):
912
911
self .quant_max ,
913
912
self .zero_point_domain ,
914
913
dtype = self .dtype ,
914
+ strides = self .stride (),
915
915
)
916
916
917
917
@classmethod
You can’t perform that action at this time.
0 commit comments