Skip to content

Conversation

@Alex4210987
Copy link
Collaborator

No description provided.

xinxyxiao added 5 commits July 29, 2025 03:26
…nd clarity (tile-ai#668)

- Enhanced buffer index handling to address precision issues by removing redundant operations.
- Streamlined the logic for determining buffer overlaps, ensuring more accurate conflict detection.
- Updated related documentation to reflect changes in buffer management practices.
…ed flexibility

- Introduced a new input.txt file for configurable parameters.
- Modified the example_amd_flash_attn_fwd.py script to allow for a wider range of configurations, including additional options for num_stages, enable_rasterization, and k_pack.
- Streamlined the main function for better clarity and organization.
- Added a new test script to facilitate running the example with specified parameters.
… example with swizzle layout annotations

- Deleted input.txt and test.sh files as they are no longer needed.
- Updated example_amd_flash_attn_fwd.py to include swizzle layout annotations for shared memory, improving bank conflict avoidance.
- Reintroduced swizzle usage in the kernel for better performance.
@github-actions
Copy link

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run bash format.sh in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work!

🚀

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Summary of Changes

Hello @Alex4210987, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a new example demonstrating an optimized FlashAttention-2 forward pass specifically tailored for AMD MI300 series GPUs. It leverages the tilelang framework to implement a highly performant attention mechanism, including autotuning capabilities and correctness verification against a PyTorch reference implementation. Minor adjustments to a core reduction template are also included to support this new functionality.

Highlights

  • New FlashAttention-2 Example for AMD MI300: I've added a new Python example (examples/amd/example_amd_flash_attn_fwd.py) that implements a FlashAttention-2 forward pass. This example is designed to run on AMD MI300 series GPUs, leveraging tilelang for optimized performance.
  • Autotuning and Verification: The new FlashAttention-2 example includes an autotuner to discover optimal configurations for various parameters (e.g., block_M, block_N, threads, num_stages). It also provides a PyTorch reference implementation and asserts correctness, ensuring the optimized kernel produces accurate results.
  • Core Kernel Implementation: The FlashAttention-2 logic is encapsulated within a T.prim_func that handles block-wise processing of Q, K, V tensors, shared memory management, GEMM operations, and causal masking, following the principles of FlashAttention-2.
  • HIP Reduction Template Enhancement: I've updated the AllReduce struct in src/tl_templates/hip/reduce.h by adding a thread_offset template parameter. This modification provides more flexibility for reduction operations within the HIP backend, which may be utilized by the new FlashAttention kernel.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in issue comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments or fill out our survey to provide feedback.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

- Updated function names for clarity, changing `get_v2_configs` to `get_configs` and `fast_flashattn_v2` to `fast_flashattn`.
- Streamlined the main function by renaming `main_v2` to `main` and adjusting the corresponding calls.
- Removed outdated comments and improved code organization for better readability.
Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a new example for the FlashAttention forward pass on AMD MI300 series GPUs using tilelang. The implementation follows the principles of FlashAttention-2 and includes autotuning, verification against a PyTorch reference, and benchmarking. A minor change is also made to a C++ template in reduce.h.


def ref_program(Q, K, V, is_causal, groups=1):
assert Q.size(
2) == K.size(2) * groups, f"Q heads {Q.size(2)} K heads {K.size(2)} groups {groups}"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This comment is in Chinese. For consistency and to make the code more accessible to a wider audience, it's best to use English for all comments.

Suggested change
2) == K.size(2) * groups, f"Q heads {Q.size(2)} K heads {K.size(2)} groups {groups}"
# PyTorch reference implementation

"""Generates configurations for the autotuner, tailored for FA-2 style parallelism."""
block_M = [64, 128, 256]
block_N = [32, 64, 128]
threads = [128, 256, 512]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Consider adding a comment to explain the purpose of annotating the layout for Q_shared and how it optimizes memory access.

        T.annotate_layout({Q_shared: tl.layout.make_swizzled_layout(Q_shared)}) # Optimize memory access by using a swizzled layout

scale = (1.0 / dim)**0.5 * 1.44269504
head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim]
kv_shape = [batch, seq_len, head_kv, dim]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The magic number 1.44269504 is an approximation of 1/ln(2) (or log2(e)), used to convert natural exponentiation to base-2 exponentiation (exp(x) = exp2(x * log2(e))). Adding a comment to explain this will improve readability.

    scale = (1.0 / dim)**0.5 * 1.44269504  # 1/sqrt(dim) * log2(e)

def main(batch: int = 1,
heads: int = 8,
seq_len: int = 4096,
dim: int = 128,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This comment is in Chinese. For consistency and to make the code more accessible to a wider audience, it's best to use English for all comments.

# main function

};

template <class Reducer, int threads, int scale> struct AllReduce {
template <class Reducer, int threads, int scale, int thread_offset = 0> struct AllReduce {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The new template parameter thread_offset is introduced but it is not used within the AllReduce struct. This is potentially dead code and makes the template signature more complex than necessary. If this parameter is intended for future use, a comment explaining its purpose would be beneficial. Otherwise, it should be removed to improve code clarity.

- Improved code readability by adjusting line breaks and indentation in the `fast_flashattn` function.
- Streamlined the `main` function parameter formatting for consistency.
- Removed unnecessary blank lines to enhance overall code organization.
@@ -0,0 +1,240 @@
# Copyright (c) Tile-AI Corporation.
# Licensed under the MIT License.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lisence should be removed.

@LeiWang1999 LeiWang1999 merged commit adcba27 into tile-ai:main Jul 31, 2025
2 of 3 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants