File tree Expand file tree Collapse file tree 1 file changed +13
-6
lines changed Expand file tree Collapse file tree 1 file changed +13
-6
lines changed Original file line number Diff line number Diff line change @@ -163,9 +163,12 @@ def pattern(
163163 )
164164 return reduce_scatter
165165
166- def replacement (input : torch .Tensor , mat2 : torch .Tensor ,
167- scale_a : torch .Tensor ,
168- scale_b : torch .Tensor ) -> torch .Tensor :
166+ def replacement (
167+ input : torch .Tensor ,
168+ mat2 : torch .Tensor ,
169+ scale_a : torch .Tensor ,
170+ scale_b : torch .Tensor ,
171+ ) -> torch .Tensor :
169172 # Calculate output shape: input @ mat2 with scatter_dim reduced
170173 output_shape = [* input .shape [:- 1 ], mat2 .shape [1 ]]
171174 scatter_dim = 0
@@ -294,9 +297,13 @@ def pattern(
294297 )
295298 return reduce_scatter
296299
297- def replacement (input : torch .Tensor , mat2 : torch .Tensor ,
298- scale_a : torch .Tensor , scale_b : torch .Tensor ,
299- cutlass_mm_output : torch .Tensor ) -> torch .Tensor :
300+ def replacement (
301+ input : torch .Tensor ,
302+ mat2 : torch .Tensor ,
303+ scale_a : torch .Tensor ,
304+ scale_b : torch .Tensor ,
305+ cutlass_mm_output : torch .Tensor ,
306+ ) -> torch .Tensor :
300307 # Calculate output shape: input @ mat2 with scatter_dim reduced
301308 output_shape = [* input .shape [:- 1 ], mat2 .shape [1 ]]
302309 scatter_dim = 0
You can’t perform that action at this time.
0 commit comments