-
Notifications
You must be signed in to change notification settings. — Fork: 272
[CK_TILE][FMHA] Add sparse attention VSA #3341
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
34 commits
Select commit
Hold shift + click to select a range
0607d31
add sparse attention VSA
jiangyon-amd 4fc61d9
Merge branch 'develop' into sparse_attention_VSA
jiangyon-amd 1dc15fb
Merge branch 'develop' into sparse_attention_VSA
jiangyon-amd 8ff98b8
fix the pre-commit
jiangyon-amd 3b00e40
Add jenga test and pre-commit
jiangyon-amd 997ec8f
add bf16 for vsa
jiangyon-amd 29d96a9
add jenga support bf16
jiangyon-amd 5e8a010
remove lse arg
jiangyon-amd d2278ab
split kernel code to block & kernel
jiangyon-amd faff9ab
fix the pre-commit
jiangyon-amd 55d9a8e
fix the pre-commit
jiangyon-amd a86fc80
fix the copyrights
jiangyon-amd 12420cd
fix the copyright
jiangyon-amd 776664a
fix the copyright & rename block to pipeline
jiangyon-amd 8ba592d
fix the copyright and pipeline
jiangyon-amd dcd0b5e
Merge branch 'develop' into sparse_attention_VSA
jiangyon-amd 4176fc6
Merge branch 'develop' into sparse_attention_VSA
jiangyon-amd 5990286
remove lse & dropout & add fmt
jiangyon-amd 3bada52
Merge branch 'develop' into sparse_attention_VSA
jiangyon-amd 6a9461f
Merge branch 'develop' into sparse_attention_VSA
poyenc 6aa8466
Merge branch 'develop' into sparse_attention_VSA
asleepzzz 404a7ef
fix the jenga&VSA code review
jiangyon-amd 51846d7
Merge branch 'develop' into sparse_attention_VSA
jiangyon-amd 83bed30
remove the useless code & resolved the comments
jiangyon-amd 3eef3ad
remove useless code
jiangyon-amd a9456ba
remove useless code
jiangyon-amd 56a3ab3
Merge branch 'develop' into sparse_attention_VSA
jiangyon-amd 4693a5c
Clean up code
poyenc 685842c
Remove more unused code
poyenc 368d050
Re-format .hpp
poyenc 565ab74
Refactor codegen scripts
poyenc b375f94
Merge branch 'develop' into sparse_attention_VSA
poyenc d885111
Merge branch 'develop' into sparse_attention_VSA
jiangyon-amd 6dd398d
Merge branch 'develop' into sparse_attention_VSA
poyenc File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,156 @@ | ||
| # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. | ||
| # SPDX-License-Identifier: MIT | ||
| # CMakeLists.txt for sparse attention (Jenga and VSA) | ||
|
|
||
| # Use SUPPORTED_GPU_TARGETS directly | ||
| set(INST_TARGETS ${SUPPORTED_GPU_TARGETS}) | ||
| set(GPU_TARGETS ${SUPPORTED_GPU_TARGETS}) | ||
|
|
||
| message(STATUS "Sparse Attention: SUPPORTED_GPU_TARGETS=${SUPPORTED_GPU_TARGETS}, INST_TARGETS=${INST_TARGETS}") | ||
|
|
||
| list(FILTER INST_TARGETS INCLUDE REGEX "gfx9|gfx12") | ||
| if(NOT INST_TARGETS) | ||
| message(WARNING "Skipping Tile Engine Sparse Attention: No supported GPU targets found") | ||
| return() | ||
| endif() | ||
|
|
||
| message(STATUS "Building Sparse Attention (Jenga & VSA) for targets: ${INST_TARGETS}") | ||
|
|
||
| # Code generation scripts | ||
| file(GLOB_RECURSE CODE_GEN_SCRIPTS CONFIGURE_DEPENDS | ||
| ${CMAKE_CURRENT_LIST_DIR}/generate.py | ||
| ${CMAKE_CURRENT_LIST_DIR}/codegen/*.py | ||
| ) | ||
| set_directory_properties(PROPERTIES CMAKE_CONFIGURE_DEPENDS "${CODE_GEN_SCRIPTS}") | ||
|
|
||
| # ============================================================================ | ||
| # Jenga Sparse Attention | ||
| # ============================================================================ | ||
| set(SPARSE_ATTN_JENGA_CODE_GEN_ARGS | ||
| ${CMAKE_CURRENT_LIST_DIR}/generate.py | ||
| --api fwd_jenga | ||
| --receipt 600 | ||
| ) | ||
|
|
||
| # Generate list of Jenga kernels (at configure time, only list) | ||
| execute_process( | ||
| COMMAND ${Python3_EXECUTABLE} ${SPARSE_ATTN_JENGA_CODE_GEN_ARGS} | ||
| --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/jenga_blob_list.txt | ||
| RESULT_VARIABLE ret | ||
| ) | ||
| if(ret AND NOT ret EQUAL 0) | ||
| message(FATAL_ERROR "Failed to generate Jenga kernel list") | ||
| endif() | ||
|
|
||
| file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/jenga_blob_list.txt SPARSE_ATTN_JENGA_GEN_BLOBS) | ||
|
|
||
| # Generate Jenga kernel source files at build time | ||
| add_custom_command( | ||
| OUTPUT ${SPARSE_ATTN_JENGA_GEN_BLOBS} | ||
| COMMAND ${Python3_EXECUTABLE} ${SPARSE_ATTN_JENGA_CODE_GEN_ARGS} | ||
| --output_dir ${CMAKE_CURRENT_BINARY_DIR} | ||
| DEPENDS ${CODE_GEN_SCRIPTS} | ||
| COMMENT "Generate CK Tile Jenga Sparse Attention kernels" | ||
| ) | ||
|
|
||
| message(STATUS "Jenga kernel files to be generated: ${SPARSE_ATTN_JENGA_GEN_BLOBS}") | ||
|
|
||
| # Jenga Instances | ||
| set(SPARSE_ATTN_JENGA_INSTANCES "tile_sparse_attn_jenga_instances") | ||
|
|
||
| add_library(${SPARSE_ATTN_JENGA_INSTANCES} OBJECT EXCLUDE_FROM_ALL | ||
| ${SPARSE_ATTN_JENGA_GEN_BLOBS} | ||
| ${CMAKE_CURRENT_LIST_DIR}/jenga_sparse_attention.cpp | ||
| ) | ||
| target_include_directories(${SPARSE_ATTN_JENGA_INSTANCES} PRIVATE | ||
| ${CMAKE_CURRENT_LIST_DIR} | ||
| ${PROJECT_SOURCE_DIR}/include/ck_tile/ops/sparse_attn | ||
| ) | ||
| set_source_files_properties(${SPARSE_ATTN_JENGA_GEN_BLOBS} PROPERTIES LANGUAGE HIP) | ||
| set_source_files_properties(${CMAKE_CURRENT_LIST_DIR}/jenga_sparse_attention.cpp PROPERTIES LANGUAGE HIP) | ||
| set_property(TARGET ${SPARSE_ATTN_JENGA_INSTANCES} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS}) | ||
|
|
||
| target_compile_options(${SPARSE_ATTN_JENGA_INSTANCES} PRIVATE | ||
| -DCK_TILE_USE_BUFFER_ADDRESSING_BUILTIN | ||
| -DCK_TILE_FMHA_FWD_FAST_EXP2 | ||
| -Wno-undefined-func-template | ||
| -Wno-float-equal | ||
| ) | ||
|
|
||
| # Jenga Example executable | ||
| set(EXAMPLE_JENGA_SPARSE_ATTN "tile_example_jenga_sparse_attn") | ||
| message(DEBUG "adding example ${EXAMPLE_JENGA_SPARSE_ATTN}") | ||
| add_executable(${EXAMPLE_JENGA_SPARSE_ATTN} EXCLUDE_FROM_ALL test_jenga_sparse_attn.cpp) | ||
| target_link_libraries(${EXAMPLE_JENGA_SPARSE_ATTN} ${SPARSE_ATTN_JENGA_INSTANCES}) | ||
| target_include_directories(${EXAMPLE_JENGA_SPARSE_ATTN} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) | ||
| target_compile_options(${EXAMPLE_JENGA_SPARSE_ATTN} PRIVATE | ||
| -Wno-undefined-func-template | ||
| -Wno-float-equal | ||
| ) | ||
|
|
||
| # ============================================================================ | ||
| # VSA Sparse Attention | ||
| # ============================================================================ | ||
| set(SPARSE_ATTN_VSA_CODE_GEN_ARGS | ||
| ${CMAKE_CURRENT_LIST_DIR}/generate.py | ||
| --api fwd_vsa | ||
| --receipt 600 | ||
| ) | ||
|
|
||
| # Generate list of VSA kernels (at configure time, only list) | ||
| execute_process( | ||
| COMMAND ${Python3_EXECUTABLE} ${SPARSE_ATTN_VSA_CODE_GEN_ARGS} | ||
| --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/vsa_blob_list.txt | ||
| RESULT_VARIABLE ret | ||
| ) | ||
| if(ret AND NOT ret EQUAL 0) | ||
| message(FATAL_ERROR "Failed to generate VSA kernel list") | ||
| endif() | ||
|
|
||
| file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/vsa_blob_list.txt SPARSE_ATTN_VSA_GEN_BLOBS) | ||
|
|
||
| # Generate VSA kernel source files at build time | ||
| add_custom_command( | ||
| OUTPUT ${SPARSE_ATTN_VSA_GEN_BLOBS} | ||
| COMMAND ${Python3_EXECUTABLE} ${SPARSE_ATTN_VSA_CODE_GEN_ARGS} | ||
| --output_dir ${CMAKE_CURRENT_BINARY_DIR} | ||
| DEPENDS ${CODE_GEN_SCRIPTS} | ||
| COMMENT "Generate CK Tile VSA Sparse Attention kernels" | ||
| ) | ||
|
|
||
| message(STATUS "VSA kernel files to be generated: ${SPARSE_ATTN_VSA_GEN_BLOBS}") | ||
|
|
||
| # VSA Instances | ||
| set(SPARSE_ATTN_VSA_INSTANCES "tile_sparse_attn_vsa_instances") | ||
|
|
||
| add_library(${SPARSE_ATTN_VSA_INSTANCES} OBJECT EXCLUDE_FROM_ALL | ||
| ${SPARSE_ATTN_VSA_GEN_BLOBS} | ||
| ${CMAKE_CURRENT_LIST_DIR}/vsa_sparse_attention.cpp | ||
| ) | ||
| target_include_directories(${SPARSE_ATTN_VSA_INSTANCES} PRIVATE | ||
| ${CMAKE_CURRENT_LIST_DIR} | ||
| ${PROJECT_SOURCE_DIR}/include/ck_tile/ops/sparse_attn | ||
| ) | ||
| set_source_files_properties(${SPARSE_ATTN_VSA_GEN_BLOBS} PROPERTIES LANGUAGE HIP) | ||
| set_source_files_properties(${CMAKE_CURRENT_LIST_DIR}/vsa_sparse_attention.cpp PROPERTIES LANGUAGE HIP) | ||
| set_property(TARGET ${SPARSE_ATTN_VSA_INSTANCES} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS}) | ||
|
|
||
| target_compile_options(${SPARSE_ATTN_VSA_INSTANCES} PRIVATE | ||
| -DCK_TILE_USE_BUFFER_ADDRESSING_BUILTIN | ||
| -DCK_TILE_FMHA_FWD_FAST_EXP2 | ||
| -Wno-undefined-func-template | ||
| -Wno-float-equal | ||
| ) | ||
|
|
||
| # VSA Example executable | ||
| set(EXAMPLE_VSA_SPARSE_ATTN "tile_example_vsa_sparse_attn") | ||
| message(DEBUG "adding example ${EXAMPLE_VSA_SPARSE_ATTN}") | ||
| add_executable(${EXAMPLE_VSA_SPARSE_ATTN} EXCLUDE_FROM_ALL test_vsa_sparse_attn.cpp) | ||
| target_link_libraries(${EXAMPLE_VSA_SPARSE_ATTN} ${SPARSE_ATTN_VSA_INSTANCES}) | ||
| target_include_directories(${EXAMPLE_VSA_SPARSE_ATTN} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) | ||
| target_compile_options(${EXAMPLE_VSA_SPARSE_ATTN} PRIVATE | ||
| -Wno-undefined-func-template | ||
| -Wno-float-equal | ||
| ) | ||
|
|
||
| set_property(GLOBAL PROPERTY RULE_MESSAGES OFF) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,3 @@ | ||
| # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. | ||
| # SPDX-License-Identifier: MIT | ||
|
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,73 @@ | ||
| # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. | ||
| # SPDX-License-Identifier: MIT | ||
| # generate kernel instances to speed up compilation | ||
|
|
||
| FWD_DTYPE_MAP = { | ||
| "fp16": "FmhaSparseFwdFp16", | ||
| "bf16": "FmhaSparseFwdBf16", | ||
| } | ||
|
|
||
| _MASK_SIMPLIFIED_MAP = { | ||
| "s_no": "ck_tile::SimplifiedGenericAttentionMask<false>", | ||
| "s_mask": "ck_tile::SimplifiedGenericAttentionMask<true>", | ||
| } | ||
|
|
||
| _MASK_MAP = { | ||
| "no": "FmhaMasks::NoMask", | ||
| "causal": "FmhaMasks::CausalMask", | ||
| "generic": "FmhaMasks::GenericMask", | ||
| } | ||
|
|
||
|
|
||
| def get_mask_map(mask: str): | ||
| if mask == "generic": | ||
| return _MASK_MAP | ||
| elif mask == "simplified": | ||
| return _MASK_SIMPLIFIED_MAP | ||
| else: | ||
| assert False | ||
| return None | ||
poyenc marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
|
|
||
| _MASK_CHECK_MAP = { | ||
| "no": "t.mask_type == mask_enum::no_mask", | ||
| "causal": "t.mask_type == mask_enum::mask_top_left || t.mask_type == mask_enum::mask_bottom_right", | ||
| "generic": "t.mask_type == mask_enum::window_generic", | ||
| } | ||
|
|
||
| _MASK_SIMPLIFIED_CHECK_MAP = { | ||
| "s_no": "t.mask_type == mask_enum::no_mask", | ||
| "s_mask": "t.mask_type != mask_enum::no_mask", | ||
| } | ||
|
|
||
|
|
||
| def get_mask_check_map(mask: str): | ||
| if mask == "generic": | ||
| return _MASK_CHECK_MAP | ||
| elif mask == "simplified": | ||
| return _MASK_SIMPLIFIED_CHECK_MAP | ||
| else: | ||
| assert False | ||
| return None | ||
poyenc marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
|
|
||
| MODE_MAP = {"batch": "false"} | ||
|
|
||
| LAYOUT_MAP = {"row": "true", "col": "false"} | ||
|
|
||
| PIPELINE_MAP = { | ||
| "qr_async": "ck_tile::BlockFmhaPipelineQRKSVSAsyncJenga", | ||
| "qr_async_vsa": "ck_tile::BlockFmhaPipelineQRKSVSAsyncVSA", | ||
| } | ||
|
|
||
| PIPELINE_ENUM_MAP = { | ||
| "qr_async": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC", | ||
| "qr_async_vsa": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC", | ||
| } | ||
|
|
||
| BOOL_MAP = { | ||
| "t": "true", | ||
| "f": "false", | ||
| True: "true", | ||
| False: "false", | ||
| } | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,3 @@ | ||
| # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. | ||
| # SPDX-License-Identifier: MIT | ||
|
|
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.