@@ -169,15 +169,23 @@ def replacement(
169169 scale_a : torch .Tensor ,
170170 scale_b : torch .Tensor ,
171171 ) -> torch .Tensor :
172+ # Calculate output shape: input @ mat2 with scatter_dim reduced
173+ output_shape = [* input .shape [:- 1 ], mat2 .shape [1 ]]
174+ scatter_dim = 0
172175 gemm_rs = torch .ops .symm_mem .fused_scaled_matmul_reduce_scatter (
173176 input ,
174177 mat2 ,
175178 scale_a ,
176179 scale_b ,
177180 "avg" ,
178- scatter_dim = 0 ,
179- out_dtype = self .dtype ,
180- group_name = self .tp .device_group .group_name ,
181+ scatter_dim , # orig_scatter_dim
182+ scatter_dim , # scatter_dim_after_maybe_reshape
183+ self .tp .device_group .group_name ,
184+ output_shape ,
185+ None , # bias
186+ None , # result_scale
187+ self .dtype , # out_dtype
188+ False , # use_fast_accum
181189 )
182190
183191 return gemm_rs
@@ -296,15 +304,23 @@ def replacement(
296304 scale_b : torch .Tensor ,
297305 cutlass_mm_output : torch .Tensor ,
298306 ) -> torch .Tensor :
307+ # Calculate output shape: input @ mat2 with scatter_dim reduced
308+ output_shape = [* input .shape [:- 1 ], mat2 .shape [1 ]]
309+ scatter_dim = 0
299310 gemm_rs = torch .ops .symm_mem .fused_scaled_matmul_reduce_scatter (
300311 input ,
301312 mat2 ,
302313 scale_a ,
303314 scale_b ,
304315 "avg" ,
305- scatter_dim = 0 ,
306- out_dtype = self .dtype ,
307- group_name = self .tp .device_group .group_name ,
316+ scatter_dim , # orig_scatter_dim
317+ scatter_dim , # scatter_dim_after_maybe_reshape
318+ self .tp .device_group .group_name ,
319+ output_shape ,
320+ None , # bias
321+ None , # result_scale
322+ self .dtype , # out_dtype
323+ False , # use_fast_accum
308324 )
309325
310326 return gemm_rs
0 commit comments