Refactor attention mask and bias structures for clarity #54
Conversation
Improves code readability by renaming helper methods to better reflect their purpose:
- `zoh_offset` becomes `attn_mask_offset`
- `active_mask_offset` becomes `attn_bias_offset`

Makes the codebase more self-documenting by using descriptive names that clearly indicate these functions compute offsets for the attention mask and attention bias tensors, respectively.
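For illustration, a minimal sketch of how the renamed helpers might look; the struct layout, stride arguments, and offset arithmetic are assumed here, not taken from the diff in `csrc/src/block_info.h`:

```cpp
#include <cstdint>

struct BlockInfoSketch {                 // hypothetical stand-in for BlockInfo
    int sum_s_q;                         // assumed field: cumulative query offset for this batch

    // Formerly zoh_offset: start of this batch's attention-mask rows.
    __forceinline__ __device__ int64_t attn_mask_offset(int64_t batch_stride,
                                                        int64_t row_stride,
                                                        int bidb) const {
        return int64_t(bidb) * batch_stride + int64_t(sum_s_q) * row_stride;
    }

    // Formerly active_mask_offset: start of this batch's attention-bias rows.
    __forceinline__ __device__ int64_t attn_bias_offset(int64_t batch_stride,
                                                        int64_t row_stride,
                                                        int bidb) const {
        return int64_t(bidb) * batch_stride + int64_t(sum_s_q) * row_stride;
    }
};
```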
Splits the monolithic ZOH_params struct into two focused components: Mask_params for attention masking operations and Bias_params for attention bias handling. Simplifies parameter management by grouping related functionality and improves code organization for flash attention operations.
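A rough sketch of what the split could look like; the member names are illustrative only, and the actual fields live in `csrc/src/flash.h`:

```cpp
#include <cstdint>

struct Mask_params {
    void    *attn_mask_ptr;        // attention mask tensor (assumed member name)
    int64_t  mask_batch_stride;    // stride between batches
    int64_t  mask_head_stride;     // stride between heads
    int64_t  mask_row_stride;      // stride between query rows
};

struct Bias_params {
    void    *attn_bias_ptr;        // attention bias tensor (assumed member name)
    int64_t  bias_batch_stride;
    int64_t  bias_head_stride;
    int64_t  bias_row_stride;
};

// The forward-pass params struct can then compose both focused pieces, e.g.
// struct Flash_fwd_params : public Mask_params, public Bias_params { ... };
```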
Updates function name from copy_ZOH to copy_Mask to better describe its actual functionality and improve code clarity.
Simplifies the mask struct naming by removing the "Dynamic" prefix for better clarity. Updates parameter names from ZOH-related terminology to more descriptive "Mask" and "Bias" names, making the code more readable and self-documenting. Changes affect function signatures, variable names, and comments to reflect the new terminology while maintaining the same functionality.
Refactors shared memory layout definitions to distinguish between mask and bias operations, replacing the combined ZOH/ActiveMask approach with separate SmemLayoutMask and SmemLayoutBias structures. Updates memory size calculations to account for both mask and bias components independently, improving clarity and maintainability of the memory management system.
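Illustrative only: the tile shapes, swizzling, and element types below are assumed, not the ones in `csrc/src/kernel_traits.h`. The point is that mask and bias now get their own shared-memory layouts, and the shared-memory size sums both components:

```cpp
#include <cute/tensor.hpp>
using namespace cute;

template <int kBlockM, int kBlockN, typename ElementMask, typename ElementBias>
struct Kernel_traits_sketch {
    // Plain row-major (kBlockM, kBlockN) tiles; the real layouts likely add swizzling.
    using SmemLayoutMask = Layout<Shape<Int<kBlockM>, Int<kBlockN>>,
                                  Stride<Int<kBlockN>, _1>>;
    using SmemLayoutBias = Layout<Shape<Int<kBlockM>, Int<kBlockN>>,
                                  Stride<Int<kBlockN>, _1>>;

    // Shared memory now accounts for mask and bias independently.
    static constexpr int kSmemMaskSize = size(SmemLayoutMask{}) * sizeof(ElementMask);
    static constexpr int kSmemBiasSize = size(SmemLayoutBias{}) * sizeof(ElementBias);
    static constexpr int kSmemSize     = kSmemMaskSize + kSmemBiasSize;
};
```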
Updates parameter names from generic `zoh` and `active_mask` to more descriptive `attn_mask` and `attn_bias` throughout the flash attention API. Improves code readability and aligns naming conventions with standard attention mechanism terminology.
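As a hypothetical sketch of the naming direction only; the real declaration in `csrc/flash_api.cpp` takes more arguments, and the function name here is an assumption:

```cpp
#include <torch/extension.h>
#include <vector>

std::vector<at::Tensor> mha_fwd(at::Tensor &q,
                                const at::Tensor &k,
                                const at::Tensor &v,
                                const at::Tensor &attn_mask,   // was `zoh`
                                const at::Tensor &attn_bias,   // was `active_mask`
                                float softmax_scale,
                                bool is_causal);
```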
Improves code readability by replacing confusing ZOH/ActiveMask naming with clearer Mask/Bias terminology throughout the attention kernel. Updates variable names, tensor declarations, and function calls to use consistent naming conventions that better reflect the actual purpose of these components in the attention computation.
Corrects the argument mapping in the flash_dma_cuda.fwd call by swapping the order of zero_hold_states and active_mask parameters to match the expected function signature. The change ensures proper parameter alignment where active_mask is passed as the attn_mask argument and attn_mask is passed as the bias argument.
Pull Request Overview
This PR refactors the handling of attention masks and biases by splitting the previous ZOH and active-mask logic into dedicated Mask and Bias structures, renaming related functions and parameters, and realigning CUDA kernel calls to the new API.
- Rename `copy_ZOH` to `copy_Mask` and `DynamicMask` to `Mask`, updating their signatures to accept both mask and bias tensors.
- Update kernel traits to introduce separate layouts and copy types for mask and bias, and adjust shared-memory size calculations.
- Split parameter structures into `Mask_params` and `Bias_params`, and update Python/C++ API calls and offsets to use `attn_mask` and `attn_bias`.
Reviewed Changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| csrc/src/utils.h | Renamed copy_ZOH to copy_Mask. |
| csrc/src/mask.h | Renamed DynamicMask to Mask and updated apply_mask signature. |
| csrc/src/kernel_traits.h | Added SmemLayoutMask/SmemLayoutBias and GmemTiledCopyMask/GmemTiledCopyBias. |
| csrc/src/flash_fwd_kernel.h | Replaced ZOH/active-mask tensors with mask/bias tensors and updated all copy/apply calls. |
| csrc/src/flash.h | Split ZOH_params into Mask_params and Bias_params. |
| csrc/src/block_info.h | Renamed offset helpers to attn_mask_offset and attn_bias_offset. |
| csrc/flash_api.cpp | Updated API to accept attn_mask and attn_bias instead of zoh and active_mask. |
| benchmarks/benchmark_forward_equivalence.py | Updated test call to new mask/bias arguments. |
Comments suppressed due to low confidence (2)
csrc/src/utils.h:503
- [nitpick] The function name `copy_Mask` is now used for copying both mask and bias tensors, which can be misleading. Consider renaming it to a more generic name (e.g., `copy_AttnParam`) or introducing separate `copy_Mask` and `copy_Bias` overloads for clarity; a sketch of that split follows the snippet below.
__forceinline__ __device__ void copy_Mask(
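One possible shape for the suggested split, as a sketch only; the template parameters and the generic helper name are assumed, not the actual signatures in `csrc/src/utils.h`:

```cpp
#include <cute/tensor.hpp>

// Shared generic path; the name comes from the reviewer's suggestion.
template <typename TiledCopy, typename TensorS, typename TensorD>
__forceinline__ __device__ void copy_AttnParam(TiledCopy tiled_copy,
                                               TensorS const &src, TensorD &dst) {
    cute::copy(tiled_copy, src, dst);           // global -> shared tile copy
}

// Thin wrappers so each call site names what it actually moves.
template <typename TiledCopy, typename TensorS, typename TensorD>
__forceinline__ __device__ void copy_Mask(TiledCopy tiled_copy,
                                          TensorS const &src, TensorD &dst) {
    copy_AttnParam(tiled_copy, src, dst);       // attention-mask tile
}

template <typename TiledCopy, typename TensorS, typename TensorD>
__forceinline__ __device__ void copy_Bias(TiledCopy tiled_copy,
                                          TensorS const &src, TensorD &dst) {
    copy_AttnParam(tiled_copy, src, dst);       // attention-bias tile
}
```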
csrc/src/mask.h:41
- [nitpick] The parameter names `Mask` and `Bias` in `apply_mask` conflict with the type names and local variables, which may cause confusion. Consider renaming them to lowercase or more descriptive names (e.g., `mask_tensor`, `bias_tensor`); a sketch follows the snippet below.
MaskType &Mask, // Attention Mask (MMA=4, MMA_M, MMA_N)
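A sketch of the suggested rename; the tensor template types and trailing arguments are assumed, and only the parameter names differ from the current `apply_mask`:

```cpp
template <typename TensorScores, typename TensorMask, typename TensorBias>
__forceinline__ __device__ void apply_mask(TensorScores       &scores,        // (MMA=4, MMA_M, MMA_N)
                                           TensorMask   const &mask_tensor,   // was `Mask`
                                           TensorBias   const &bias_tensor,   // was `Bias`
                                           const int col_idx_offset,
                                           const int max_seqlen_k) {
    // Body elided; the existing logic that applies the mask and bias to `scores` stays the same.
}
```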
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Enhance code readability and organization by renaming functions and parameters related to attention masks and biases. Separate the ZOH parameters into distinct mask and bias structures, improving clarity in memory management and function signatures. Adjust the CUDA function call to ensure proper parameter alignment.