
Add helion grouped GEMM kernel implementation #186


Open
mandroid6 wants to merge 2 commits into main

Conversation


@mandroid6 commented Jun 16, 2025

Summary:
Implements a grouped GEMM kernel for helion based on the tritonbench reference,
enabling efficient batched matrix multiplication with varying matrix sizes.

Key Features:

- Two implementation strategies to work within helion's type system constraints
- Individual kernel calls approach (grouped_gemm_v2): simple and robust (sketched below)
- Concatenated tensor approach (grouped_gemm_concatenated): single kernel launch
- Comprehensive testing and benchmarking utilities
- MoE-style workload support for expert routing scenarios
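
For orientation, a minimal sketch of the individual-kernel-calls strategy, assuming a per-group Helion matmul kernel along the lines of the repo's matmul example; the function names and wrapper below are illustrative, not this PR's actual code:

```python
import torch
import helion
import helion.language as hl


@helion.kernel  # autotuned on first call; a fixed helion.Config can be passed to skip autotuning
def single_gemm(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    m, k = x.size()
    k2, n = y.size()
    assert k == k2, "inner dimensions must match"
    out = torch.empty([m, n], dtype=x.dtype, device=x.device)
    for tile_m, tile_n in hl.tile([m, n]):
        acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
        for tile_k in hl.tile(k):
            acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n])
        out[tile_m, tile_n] = acc.to(x.dtype)
    return out


def grouped_gemm_individual(As: list[torch.Tensor], Bs: list[torch.Tensor]) -> list[torch.Tensor]:
    # One kernel launch per group; each group may have a different (M, N, K),
    # which is what makes a single fused launch awkward under Helion's type system.
    return [single_gemm(a, b) for a, b in zip(As, Bs)]
```

The trade-off relative to the concatenated approach is launch overhead: K groups mean K kernel launches instead of one.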

Test Plan:

```
python examples/grouped_gemm.py
```

Additional logs:


=== Test Case 1: 3 groups ===
Group sizes: [(128, 128, 64), (256, 256, 128), (64, 64, 32)]
[0s] Starting DifferentialEvolutionSearch with population=40, generations=20, crossover_rate=0.8
[12s] Initial population: min=0.0052 mid=0.0065 max=0.0089 best=Config(block_sizes=[16, 64, 32], loop_orders=[[0, 1]], l2_groupings=[8], num_warps=4, num_stages=8, indexing='block_ptr', use_yz_grid=False)
[30s] Generation 2: replaced=17 min=0.0052 mid=0.0055 max=0.0063 best=Config(block_sizes=[16, 16, 16], loop_orders=[[1, 0]], l2_groupings=[4], num_warps=4, num_stages=8, indexing='pointer', use_yz_grid=False)
[48s] Generation 3: replaced=17 min=0.0050 mid=0.0054 max=0.0060 best=Config(block_sizes=[32, 16, 64], loop_orders=[[0, 1]], l2_groupings=[1], num_warps=4, num_stages=5, indexing='block_ptr', use_yz_grid=True)
[66s] Generation 4: replaced=9 min=0.0050 mid=0.0054 max=0.0059 best=Config(block_sizes=[32, 16, 64], loop_orders=[[0, 1]], l2_groupings=[1], num_warps=4, num_stages=5, indexing='block_ptr', use_yz_grid=True)
[84s] Generation 5: replaced=4 min=0.0050 mid=0.0053 max=0.0059 best=Config(block_sizes=[32, 16, 64], loop_orders=[[0, 1]], l2_groupings=[1], num_warps=4, num_stages=5, indexing='block_ptr', use_yz_grid=True)
[102s] Generation 6: replaced=9 min=0.0050 mid=0.0052 max=0.0057 best=Config(block_sizes=[32, 16, 64], loop_orders=[[0, 1]], l2_groupings=[1], num_warps=4, num_stages=5, indexing='block_ptr', use_yz_grid=True)
[119s] Generation 7: replaced=8 min=0.0050 mid=0.0052 max=0.0056 best=Config(block_sizes=[32, 16, 64], loop_orders=[[0, 1]], l2_groupings=[1], num_warps=4, num_stages=5, indexing='block_ptr', use_yz_grid=True)
[137s] Generation 8: replaced=8 min=0.0050 mid=0.0052 max=0.0056 best=Config(block_sizes=[16, 16, 64], loop_orders=[[0, 1]], l2_groupings=[1], num_warps=4, num_stages=4, indexing='pointer', use_yz_grid=True)
[156s] Generation 9: replaced=6 min=0.0050 mid=0.0052 max=0.0055 best=Config(block_sizes=[16, 16, 64], loop_orders=[[0, 1]], l2_groupings=[1], num_warps=4, num_stages=4, indexing='pointer', use_yz_grid=True)
[174s] Generation 10: replaced=6 min=0.0050 mid=0.0052 max=0.0055 best=Config(block_sizes=[16, 16, 64], loop_orders=[[0, 1]], l2_groupings=[1], num_warps=4, num_stages=4, indexing='pointer', use_yz_grid=True)
[192s] Generation 11: replaced=5 min=0.0050 mid=0.0052 max=0.0055 best=Config(block_sizes=[16, 16, 64], loop_orders=[[0, 1]], l2_groupings=[1], num_warps=4, num_stages=4, indexing='pointer', use_yz_grid=True)
[210s] Generation 12: replaced=7 min=0.0050 mid=0.0052 max=0.0053 best=Config(block_sizes=[16, 16, 64], loop_orders=[[0, 1]], l2_groupings=[1], num_warps=4, num_stages=4, indexing='pointer', use_yz_grid=True)
[228s] Generation 13: replaced=7 min=0.0050 mid=0.0052 max=0.0053 best=Config(block_sizes=[16, 16, 64], loop_orders=[[0, 1]], l2_groupings=[1], num_warps=4, num_stages=4, indexing='pointer', use_yz_grid=True)
[246s] Generation 14: replaced=8 min=0.0050 mid=0.0051 max=0.0053 best=Config(block_sizes=[16, 16, 64], loop_orders=[[0, 1]], l2_groupings=[1], num_warps=4, num_stages=4, indexing='pointer', use_yz_grid=True)
[264s] Generation 15: replaced=4 min=0.0049 mid=0.0051 max=0.0053 best=Config(block_sizes=[16, 32, 32], loop_orders=[[1, 0]], l2_groupings=[1], num_warps=4, num_stages=4, indexing='pointer', use_yz_grid=True)
[282s] Generation 16: replaced=5 min=0.0049 mid=0.0051 max=0.0053 best=Config(block_sizes=[16, 32, 32], loop_orders=[[1, 0]], l2_groupings=[1], num_warps=4, num_stages=4, indexing='pointer', use_yz_grid=True)
[299s] Generation 17: replaced=5 min=0.0049 mid=0.0051 max=0.0052 best=Config(block_sizes=[16, 32, 32], loop_orders=[[1, 0]], l2_groupings=[1], num_warps=4, num_stages=4, indexing='pointer', use_yz_grid=True)
[317s] Generation 18: replaced=2 min=0.0049 mid=0.0051 max=0.0052 best=Config(block_sizes=[16, 32, 32], loop_orders=[[1, 0]], l2_groupings=[1], num_warps=4, num_stages=4, indexing='pointer', use_yz_grid=True)
[335s] Generation 19: replaced=2 min=0.0049 mid=0.0051 max=0.0052 best=Config(block_sizes=[16, 32, 32], loop_orders=[[1, 0]], l2_groupings=[1], num_warps=4, num_stages=4, indexing='pointer', use_yz_grid=True)
[335s] Autotuning complete in 335.1s after searching 1520 configs.
One can hardcode the best config and skip autotuning with:
    @helion.kernel(config=helion.Config(block_sizes=[16, 32, 32], loop_orders=[[1, 0]], l2_groupings=[1], num_warps=4, num_stages=4, indexing='pointer', use_yz_grid=True))

✓ Correctness test passed for 3 groups using grouped_gemm_v2 (individual kernels)
Helion time: 0.0123ms
Reference time: 0.0137ms
Speedup: 1.12x

=== Test Case 2: 4 groups ===
Group sizes: [(512, 1024, 256), (128, 512, 128), (256, 256, 64), (64, 128, 32)]
✓ Correctness test passed for 4 groups using grouped_gemm_v2 (individual kernels)
Helion time: 0.0201ms
Reference time: 0.0189ms
Speedup: 0.94x

=== Test Case 3: 2 groups ===
Group sizes: [(1024, 1024, 512), (512, 512, 256)]
✓ Correctness test passed for 2 groups using grouped_gemm_v2 (individual kernels)
Helion time: 0.0279ms
Reference time: 0.0132ms
Speedup: 0.47x

@facebook-github-bot added the CLA Signed label Jun 16, 2025
@mandroid6 requested review from jansel and yf225 June 16, 2025 22:13
On the new file (@@ -0,0 +1,330 @@):

@yf225 (Contributor) commented Jun 16, 2025:
Would be great to add an expected code check in test_examples.py (similar to other examples).
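
For reference, a hedged sketch of what a minimal correctness check could look like; the repo's test_examples.py likely uses its own expected-code harness for comparing generated code, which this sketch does not replicate, and the import path examples.grouped_gemm and function name grouped_gemm_v2 are assumptions based on this PR's description:

```python
import unittest
import torch

from examples.grouped_gemm import grouped_gemm_v2  # assumed module/function from this PR


class TestGroupedGemm(unittest.TestCase):
    @unittest.skipUnless(torch.cuda.is_available(), "requires CUDA")
    def test_matches_torch_reference(self):
        # (M, N, K) per group, mirroring Test Case 1 from the PR description
        shapes = [(128, 128, 64), (256, 256, 128), (64, 64, 32)]
        As = [torch.randn(m, k, device="cuda", dtype=torch.float16) for m, _, k in shapes]
        Bs = [torch.randn(k, n, device="cuda", dtype=torch.float16) for _, n, k in shapes]
        outs = grouped_gemm_v2(As, Bs)
        for out, a, b in zip(outs, As, Bs):
            torch.testing.assert_close(out, a @ b, rtol=1e-2, atol=1e-2)


if __name__ == "__main__":
    unittest.main()
```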

@jansel (Contributor) left a comment:

Add a test for this in test_examples.py

```python
C_group = acc.view(-1)  # Flatten the result
tile_start = c_start + tile_m.begin * N + tile_n.begin
tile_end = tile_start + C_group.numel()
C_concat[tile_start:tile_end] = C_group
```
I'm a bit surprised this works. Is this tested?
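
For context, a sketch of the offset arithmetic this write relies on, assuming C_concat is a flat row-major buffer that stores each group's (M, N) output contiguously starting at offset c_start; that layout is an assumption about the PR, not something verified here:

```python
# Row-major address of tile element (i, j) in the group's output:
#   c_start + (tile_m.begin + i) * N + (tile_n.begin + j)
# Address implied by the contiguous flattened write above:
#   tile_start + i * tile_n_size + j
#   = c_start + tile_m.begin * N + tile_n.begin + i * tile_n_size + j
# The two agree for every (i, j) only when each tile row is a full output row,
# i.e. tile_n_size == N and tile_n.begin == 0; otherwise the flattened slice
# lands elements from neighboring rows at the wrong offsets.
def contiguous_write_matches_row_major(N: int, tile_n_begin: int, tile_n_size: int) -> bool:
    return tile_n_begin == 0 and tile_n_size == N
```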
