Skip to content

Commit 8fe845f

Browse files
Apply ruff-format to collective_fusion.py after rebase
Signed-off-by: Jason Li <jasonlizhengjian@gmail.com> Signed-off-by: jasonlizhengjian <jasonlizhengjian@gmail.com>
1 parent 4f500ae commit 8fe845f

File tree

1 file changed

+13
-6
lines changed

1 file changed

+13
-6
lines changed

vllm/compilation/collective_fusion.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -163,9 +163,12 @@ def pattern(
163163
)
164164
return reduce_scatter
165165

166-
def replacement(input: torch.Tensor, mat2: torch.Tensor,
167-
scale_a: torch.Tensor,
168-
scale_b: torch.Tensor) -> torch.Tensor:
166+
def replacement(
167+
input: torch.Tensor,
168+
mat2: torch.Tensor,
169+
scale_a: torch.Tensor,
170+
scale_b: torch.Tensor,
171+
) -> torch.Tensor:
169172
# Calculate output shape: input @ mat2 with scatter_dim reduced
170173
output_shape = [*input.shape[:-1], mat2.shape[1]]
171174
scatter_dim = 0
@@ -294,9 +297,13 @@ def pattern(
294297
)
295298
return reduce_scatter
296299

297-
def replacement(input: torch.Tensor, mat2: torch.Tensor,
298-
scale_a: torch.Tensor, scale_b: torch.Tensor,
299-
cutlass_mm_output: torch.Tensor) -> torch.Tensor:
300+
def replacement(
301+
input: torch.Tensor,
302+
mat2: torch.Tensor,
303+
scale_a: torch.Tensor,
304+
scale_b: torch.Tensor,
305+
cutlass_mm_output: torch.Tensor,
306+
) -> torch.Tensor:
300307
# Calculate output shape: input @ mat2 with scatter_dim reduced
301308
output_shape = [*input.shape[:-1], mat2.shape[1]]
302309
scatter_dim = 0

0 commit comments

Comments
 (0)