Skip to content

[Enhancement] Support more flexible layout host pythonic expr #623

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jul 10, 2025

Conversation

LeiWang1999
Copy link
Member

  • Added support for additional TIR expressions (FloorDiv, Min, Max, Add, Sub, FloorMod) in the pythonic_expr function to improve string representation.
  • Replaced the deprecated legalize_c function calls in TLCUDASourceWrapper and TLCPUSourceWrapper with pythonic_expr for better expression handling in kernel launch code.

… to use pythonic_expr

- Added support for additional TIR expressions (FloorDiv, Min, Max, Add, Sub, FloorMod) in the pythonic_expr function to improve string representation.
- Replaced the deprecated legalize_c function calls in TLCUDASourceWrapper and TLCPUSourceWrapper with pythonic_expr for better expression handling in kernel launch code.
Copy link

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run bash format.sh in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work!

🚀

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Summary of Changes

Hello @LeiWang1999, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the flexibility and consistency of expression handling within the JIT adapter. It expands the pythonic_expr utility to support a wider range of TVM TIR expressions, ensuring more accurate and readable string representations. Concurrently, it migrates the kernel launch code generation from using a deprecated legalize_c function to leverage the enhanced pythonic_expr, streamlining the codebase and preparing for future enhancements.

Highlights

  • Enhanced Expression Handling: Extended the pythonic_expr utility in tilelang/jit/adapter/utils.py to correctly convert additional TVM TIR expressions, specifically FloorDiv, Min, Max, Add, Sub, and FloorMod, into their appropriate Pythonic string representations. This improves the flexibility and accuracy of expression stringification.
  • Code Standardization and Deprecation: Replaced all usages of the deprecated legalize_c helper function with the now more capable pythonic_expr within the TLCUDASourceWrapper and TLCPUSourceWrapper classes in tilelang/jit/adapter/wrapper.py. This change standardizes how kernel launch parameters (like block and grid dimensions) are handled, ensuring consistency and leveraging the improved expression conversion.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in issue comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request enhances the pythonic_expr function to support more TIR expression types, making it more flexible for representing expressions in a Python-like string format. It also refactors the code by replacing the deprecated legalize_c function with pythonic_expr in TLCUDASourceWrapper and removing the unused legalize_c from TLCPUSourceWrapper. The changes are logical and improve the codebase. I've suggested a refactoring to reduce code duplication in the _pythonic_visitor function, which should improve maintainability, and pointed out where the deprecated function is still being called.

Comment on lines 132 to 155
elif isinstance(node, tvm.tir.FloorDiv):
a_str = node_to_str_map.get(node.a, str(node.a))
b_str = node_to_str_map.get(node.b, str(node.b))
s = f"{a_str} / {b_str}"
elif isinstance(node, tvm.tir.Min):
a_str = node_to_str_map.get(node.a, str(node.a))
b_str = node_to_str_map.get(node.b, str(node.b))
s = f"min({a_str}, {b_str})"
elif isinstance(node, tvm.tir.Max):
a_str = node_to_str_map.get(node.a, str(node.a))
b_str = node_to_str_map.get(node.b, str(node.b))
s = f"max({a_str}, {b_str})"
elif isinstance(node, tvm.tir.Add):
a_str = node_to_str_map.get(node.a, str(node.a))
b_str = node_to_str_map.get(node.b, str(node.b))
s = f"{a_str} + {b_str}"
elif isinstance(node, tvm.tir.Sub):
a_str = node_to_str_map.get(node.a, str(node.a))
b_str = node_to_str_map.get(node.b, str(node.b))
s = f"{a_str} - {b_str}"
elif isinstance(node, tvm.tir.FloorMod):
a_str = node_to_str_map.get(node.a, str(node.a))
b_str = node_to_str_map.get(node.b, str(node.b))
s = f"{a_str} % {b_str}"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The new elif blocks for handling different TIR expressions introduce some code repetition, specifically for retrieving the string representations of operands a and b. This can be refactored to improve code clarity and maintainability.

Suggested change
elif isinstance(node, tvm.tir.FloorDiv):
a_str = node_to_str_map.get(node.a, str(node.a))
b_str = node_to_str_map.get(node.b, str(node.b))
s = f"{a_str} / {b_str}"
elif isinstance(node, tvm.tir.Min):
a_str = node_to_str_map.get(node.a, str(node.a))
b_str = node_to_str_map.get(node.b, str(node.b))
s = f"min({a_str}, {b_str})"
elif isinstance(node, tvm.tir.Max):
a_str = node_to_str_map.get(node.a, str(node.a))
b_str = node_to_str_map.get(node.b, str(node.b))
s = f"max({a_str}, {b_str})"
elif isinstance(node, tvm.tir.Add):
a_str = node_to_str_map.get(node.a, str(node.a))
b_str = node_to_str_map.get(node.b, str(node.b))
s = f"{a_str} + {b_str}"
elif isinstance(node, tvm.tir.Sub):
a_str = node_to_str_map.get(node.a, str(node.a))
b_str = node_to_str_map.get(node.b, str(node.b))
s = f"{a_str} - {b_str}"
elif isinstance(node, tvm.tir.FloorMod):
a_str = node_to_str_map.get(node.a, str(node.a))
b_str = node_to_str_map.get(node.b, str(node.b))
s = f"{a_str} % {b_str}"
elif isinstance(node, (tvm.tir.FloorDiv, tvm.tir.Min, tvm.tir.Max, tvm.tir.Add, tvm.tir.Sub, tvm.tir.FloorMod)):
op_format_map = {
tvm.tir.FloorDiv: "{} / {}",
tvm.tir.Min: "min({}, {})",
tvm.tir.Max: "max({}, {})",
tvm.tir.Add: "{} + {}",
tvm.tir.Sub: "{} - {}",
tvm.tir.FloorMod: "{} % {}",
}
a_str = node_to_str_map.get(node.a, str(node.a))
b_str = node_to_str_map.get(node.b, str(node.b))
s = op_format_map[type(node)].format(a_str, b_str)

- Consolidated binary and min/max operation handling in the pythonic_expr function to improve readability and maintainability.
- Replaced individual checks for binary operations with a mapping approach, streamlining the code and enhancing performance in expression representation.
…tion

- Added operator precedence handling to the pythonic_expr function, enhancing the conversion of TVM PrimExpr to Python-style strings.
- Updated the visitor logic to intelligently add parentheses based on operator precedence, improving the accuracy of expression representation.
- Included a docstring for better clarity on the function's purpose and usage.
@LeiWang1999 LeiWang1999 merged commit f2203ae into tile-ai:main Jul 10, 2025
2 of 3 checks passed
LeiWang1999 added a commit to LeiWang1999/tilelang that referenced this pull request Jul 18, 2025
…i#623)

* [Refactor] Enhance expression handling in utils.py and update wrapper to use pythonic_expr

- Added support for additional TIR expressions (FloorDiv, Min, Max, Add, Sub, FloorMod) in the pythonic_expr function to improve string representation.
- Replaced the deprecated legalize_c function calls in TLCUDASourceWrapper and TLCPUSourceWrapper with pythonic_expr for better expression handling in kernel launch code.

* [Refactor] Simplify expression handling in pythonic_expr function

- Consolidated binary and min/max operation handling in the pythonic_expr function to improve readability and maintainability.
- Replaced individual checks for binary operations with a mapping approach, streamlining the code and enhancing performance in expression representation.

* [Enhancement] Improve expression representation in pythonic_expr function

- Added operator precedence handling to the pythonic_expr function, enhancing the conversion of TVM PrimExpr to Python-style strings.
- Updated the visitor logic to intelligently add parentheses based on operator precedence, improving the accuracy of expression representation.
- Included a docstring for better clarity on the function's purpose and usage.

* test fix
LeiWang1999 added a commit to LeiWang1999/tilelang that referenced this pull request Jul 18, 2025
…i#623)

* [Refactor] Enhance expression handling in utils.py and update wrapper to use pythonic_expr

- Added support for additional TIR expressions (FloorDiv, Min, Max, Add, Sub, FloorMod) in the pythonic_expr function to improve string representation.
- Replaced the deprecated legalize_c function calls in TLCUDASourceWrapper and TLCPUSourceWrapper with pythonic_expr for better expression handling in kernel launch code.

* [Refactor] Simplify expression handling in pythonic_expr function

- Consolidated binary and min/max operation handling in the pythonic_expr function to improve readability and maintainability.
- Replaced individual checks for binary operations with a mapping approach, streamlining the code and enhancing performance in expression representation.

* [Enhancement] Improve expression representation in pythonic_expr function

- Added operator precedence handling to the pythonic_expr function, enhancing the conversion of TVM PrimExpr to Python-style strings.
- Updated the visitor logic to intelligently add parentheses based on operator precedence, improving the accuracy of expression representation.
- Included a docstring for better clarity on the function's purpose and usage.

* test fix
LeiWang1999 added a commit to LeiWang1999/tilelang that referenced this pull request Jul 20, 2025
…i#623)

* [Refactor] Enhance expression handling in utils.py and update wrapper to use pythonic_expr

- Added support for additional TIR expressions (FloorDiv, Min, Max, Add, Sub, FloorMod) in the pythonic_expr function to improve string representation.
- Replaced the deprecated legalize_c function calls in TLCUDASourceWrapper and TLCPUSourceWrapper with pythonic_expr for better expression handling in kernel launch code.

* [Refactor] Simplify expression handling in pythonic_expr function

- Consolidated binary and min/max operation handling in the pythonic_expr function to improve readability and maintainability.
- Replaced individual checks for binary operations with a mapping approach, streamlining the code and enhancing performance in expression representation.

* [Enhancement] Improve expression representation in pythonic_expr function

- Added operator precedence handling to the pythonic_expr function, enhancing the conversion of TVM PrimExpr to Python-style strings.
- Updated the visitor logic to intelligently add parentheses based on operator precedence, improving the accuracy of expression representation.
- Included a docstring for better clarity on the function's purpose and usage.

* test fix
LeiWang1999 added a commit to LeiWang1999/tilelang that referenced this pull request Jul 20, 2025
…i#623)

* [Refactor] Enhance expression handling in utils.py and update wrapper to use pythonic_expr

- Added support for additional TIR expressions (FloorDiv, Min, Max, Add, Sub, FloorMod) in the pythonic_expr function to improve string representation.
- Replaced the deprecated legalize_c function calls in TLCUDASourceWrapper and TLCPUSourceWrapper with pythonic_expr for better expression handling in kernel launch code.

* [Refactor] Simplify expression handling in pythonic_expr function

- Consolidated binary and min/max operation handling in the pythonic_expr function to improve readability and maintainability.
- Replaced individual checks for binary operations with a mapping approach, streamlining the code and enhancing performance in expression representation.

* [Enhancement] Improve expression representation in pythonic_expr function

- Added operator precedence handling to the pythonic_expr function, enhancing the conversion of TVM PrimExpr to Python-style strings.
- Updated the visitor logic to intelligently add parentheses based on operator precedence, improving the accuracy of expression representation.
- Included a docstring for better clarity on the function's purpose and usage.

* test fix
LeiWang1999 added a commit to LeiWang1999/tilelang that referenced this pull request Jul 20, 2025
…i#623)

* [Refactor] Enhance expression handling in utils.py and update wrapper to use pythonic_expr

- Added support for additional TIR expressions (FloorDiv, Min, Max, Add, Sub, FloorMod) in the pythonic_expr function to improve string representation.
- Replaced the deprecated legalize_c function calls in TLCUDASourceWrapper and TLCPUSourceWrapper with pythonic_expr for better expression handling in kernel launch code.

* [Refactor] Simplify expression handling in pythonic_expr function

- Consolidated binary and min/max operation handling in the pythonic_expr function to improve readability and maintainability.
- Replaced individual checks for binary operations with a mapping approach, streamlining the code and enhancing performance in expression representation.

* [Enhancement] Improve expression representation in pythonic_expr function

- Added operator precedence handling to the pythonic_expr function, enhancing the conversion of TVM PrimExpr to Python-style strings.
- Updated the visitor logic to intelligently add parentheses based on operator precedence, improving the accuracy of expression representation.
- Included a docstring for better clarity on the function's purpose and usage.

* test fix
LeiWang1999 added a commit to LeiWang1999/tilelang that referenced this pull request Jul 20, 2025
…i#623)

* [Refactor] Enhance expression handling in utils.py and update wrapper to use pythonic_expr

- Added support for additional TIR expressions (FloorDiv, Min, Max, Add, Sub, FloorMod) in the pythonic_expr function to improve string representation.
- Replaced the deprecated legalize_c function calls in TLCUDASourceWrapper and TLCPUSourceWrapper with pythonic_expr for better expression handling in kernel launch code.

* [Refactor] Simplify expression handling in pythonic_expr function

- Consolidated binary and min/max operation handling in the pythonic_expr function to improve readability and maintainability.
- Replaced individual checks for binary operations with a mapping approach, streamlining the code and enhancing performance in expression representation.

* [Enhancement] Improve expression representation in pythonic_expr function

- Added operator precedence handling to the pythonic_expr function, enhancing the conversion of TVM PrimExpr to Python-style strings.
- Updated the visitor logic to intelligently add parentheses based on operator precedence, improving the accuracy of expression representation.
- Included a docstring for better clarity on the function's purpose and usage.

* test fix
LeiWang1999 added a commit to LeiWang1999/tilelang that referenced this pull request Jul 20, 2025
…i#623)

* [Refactor] Enhance expression handling in utils.py and update wrapper to use pythonic_expr

- Added support for additional TIR expressions (FloorDiv, Min, Max, Add, Sub, FloorMod) in the pythonic_expr function to improve string representation.
- Replaced the deprecated legalize_c function calls in TLCUDASourceWrapper and TLCPUSourceWrapper with pythonic_expr for better expression handling in kernel launch code.

* [Refactor] Simplify expression handling in pythonic_expr function

- Consolidated binary and min/max operation handling in the pythonic_expr function to improve readability and maintainability.
- Replaced individual checks for binary operations with a mapping approach, streamlining the code and enhancing performance in expression representation.

* [Enhancement] Improve expression representation in pythonic_expr function

- Added operator precedence handling to the pythonic_expr function, enhancing the conversion of TVM PrimExpr to Python-style strings.
- Updated the visitor logic to intelligently add parentheses based on operator precedence, improving the accuracy of expression representation.
- Included a docstring for better clarity on the function's purpose and usage.

* test fix
LeiWang1999 added a commit to LeiWang1999/tilelang that referenced this pull request Jul 20, 2025
…i#623)

* [Refactor] Enhance expression handling in utils.py and update wrapper to use pythonic_expr

- Added support for additional TIR expressions (FloorDiv, Min, Max, Add, Sub, FloorMod) in the pythonic_expr function to improve string representation.
- Replaced the deprecated legalize_c function calls in TLCUDASourceWrapper and TLCPUSourceWrapper with pythonic_expr for better expression handling in kernel launch code.

* [Refactor] Simplify expression handling in pythonic_expr function

- Consolidated binary and min/max operation handling in the pythonic_expr function to improve readability and maintainability.
- Replaced individual checks for binary operations with a mapping approach, streamlining the code and enhancing performance in expression representation.

* [Enhancement] Improve expression representation in pythonic_expr function

- Added operator precedence handling to the pythonic_expr function, enhancing the conversion of TVM PrimExpr to Python-style strings.
- Updated the visitor logic to intelligently add parentheses based on operator precedence, improving the accuracy of expression representation.
- Included a docstring for better clarity on the function's purpose and usage.

* test fix
LeiWang1999 added a commit to Hzfengsy/tilelang that referenced this pull request Jul 24, 2025
…i#623)

* [Refactor] Enhance expression handling in utils.py and update wrapper to use pythonic_expr

- Added support for additional TIR expressions (FloorDiv, Min, Max, Add, Sub, FloorMod) in the pythonic_expr function to improve string representation.
- Replaced the deprecated legalize_c function calls in TLCUDASourceWrapper and TLCPUSourceWrapper with pythonic_expr for better expression handling in kernel launch code.

* [Refactor] Simplify expression handling in pythonic_expr function

- Consolidated binary and min/max operation handling in the pythonic_expr function to improve readability and maintainability.
- Replaced individual checks for binary operations with a mapping approach, streamlining the code and enhancing performance in expression representation.

* [Enhancement] Improve expression representation in pythonic_expr function

- Added operator precedence handling to the pythonic_expr function, enhancing the conversion of TVM PrimExpr to Python-style strings.
- Updated the visitor logic to intelligently add parentheses based on operator precedence, improving the accuracy of expression representation.
- Included a docstring for better clarity on the function's purpose and usage.

* test fix
LeiWang1999 pushed a commit that referenced this pull request Jul 30, 2025
**Summarize part of the rebase pr:**

1. **Support T.thread_return() → CUDA return syntax**  
   Added support for translating `T.thread_return()` to CUDA's native `return` statement.

2. **Dynamic type support for function inputs**  
   Functions now accept dynamically typed parameters using `typing`:
   ```python
   dyn_type = T.int32 or T.float
   @T.prim_func
   def main(
       a: dyn_type,
   )
   ```

3. **Device Function Codegen**  
   Added support for generating `__device__` functions in CUDA:
   ```python
   @I.ir_module
   class Module:
       @T.prim_func(private=True)
       def add(a: T.int32, b: T.int32) -> T.int32:
           return a + b

       @T.prim_func
       def main(
           A: T.Buffer((128, 128), "int32"),
           B: T.Buffer((128, 128), "int32"),
           C: T.Buffer((128, 128), "int32"),
       ):
           T.func_attr({"global_symbol": "main"})
           length: T.int32 = Module.add(64, 64)  # Host call
           for bx in T.thread_binding(length, "blockIdx.x"):
               for tx in T.thread_binding(length, "threadIdx.x"):
                   C[bx, tx] = Module.add(A[bx, tx], B[bx, tx])  # Device call
   ```
   After compilation, `add` becomes a CUDA `__device__` function.

4. **Cython-based Python/C++ interop**  
   Replaced ctypes with Cython for all Python/C++ interactions:
   - Python → C++ calls
   - C++ → Cython calls  
   This improves performance by around 100x and reduces CPU overhead during compile/runtime.

5. **FP8 data type standardization**  
   Migrated `e5m2_float8` and similar types to Torch-standardized variants`float8_e5m2` and etc.



* Refactor CMakeLists.txt to set default build type and manage dependencies for tvm_cython modules

* Update default value of `check_well_formed` parameter in `prim_func` to False for improved flexibility in TIR function parsing.

* Add StorageRewrite function to transform module

Introduced the StorageRewrite function in the tilelang.transform module, which returns a TVM transform pass. This addition enhances the functionality of the module by providing a new transformation option for users.

* Refactor null option handling in IR and layout inference

- Updated instances of `NullOpt` to `std::nullopt` in `ir.cc` and `parallel.cc` for consistency with modern C++ practices.
- Enhanced layout inference logic in `layout_inference.cc` to improve type safety by replacing `as<Fragment>().get()` with `as<FragmentNode>()`.
- Adjusted error handling in `multi_version_buffer_rewriter.cc` and `persist_threadblock.cc` to use more concise null checks.
- Cleaned up test files by commenting out `tilelang.testing.main()` and replacing it with specific test function calls for better clarity.
- Removed unused test file `test_tilelang_kernel_deepseek_nsa.py` to streamline the testing suite.

* Update TVM subproject and refactor cluster planning and tile operation handling

- Updated the TVM subproject to a dirty commit state.
- Refactored copyright headers in `cluster_planning.cc` to reflect the new licensing.
- Enhanced error handling in `lower_tile_op.cc` to check for missing padding map annotations.
- Modified test files to improve clarity and functionality, including adjustments to kernel compilation and test assertions.
- Updated various test cases to ensure proper handling of annotations and configurations in the TileLang testing framework.

* Update annotation type in warp specialized test for consistency

- Changed the annotation type in the `test_warp_specialized` function from a literal integer to `T.int32(3)` for improved type safety and consistency with the TileLang framework.

* Refactor test execution in warp specialized test

- Replaced the direct call to `test_warp_specialized()` with `tilelang.testing.main()` in the test file to standardize test execution and improve integration with the TileLang testing framework.

* refactor

* [Enhancement] Add strict layout map for improved buffer layout inference (#594)

- Introduced a `strict_layout_map` to enhance layout inference by ensuring that buffers with strict layout requirements are properly accounted for during the inference process.
- Updated the inference logic to check for the presence of buffers in the `strict_layout_map` before applying layout changes, improving the accuracy of layout assignments.
- Refactored the layout inference steps to include the copying of layouts into the new strict map, ensuring a clear separation of layout handling based on inference levels.

* [Example] Update examples to use @tilelang.jit (#597)

* [Example] Update kernel compilation in examples to use @tilelang.jit

- Refactored multiple examples to eliminate the use of `tilelang.compile` for kernel creation, directly invoking the functions instead.
- Added `@tilelang.jit` decorators with appropriate output indices to enhance performance and maintainability.
- Improved code clarity by simplifying the kernel invocation process across various examples, ensuring consistency in how kernels are defined and executed.

* format

* Update example_tilelang_sparse_gqa_decode_varlen_indice.py

* Update example_dequant_gemm_fine_grained.py

* Update example_gemm_autotune.py

---------

Co-authored-by: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>

* [Enhancement] Refine error messaging in LowerBulkCopy for global and shared range checks (#599)

* [Enhancement] Improve error messaging for global and shared range legality checks in LowerBulkCopy

- Updated error messages in the LowerBulkCopy function to provide clearer context when global and shared ranges are illegal.
- Enhanced the readability of the error output by including tensor names, improving debugging and validation processes during bulk copy operations.

* [Enhancement] Refine error messaging in LowerBulkCopy for global and shared range checks

- Improved the clarity of error messages in the LowerBulkCopy function by enhancing the output format.
- Included additional context in error messages to aid debugging when global and shared ranges are found to be illegal, ensuring better traceability during bulk copy operations.

* [Enhancement] Introduce PassConfig `TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE` to enable aggressive shared memory reuse (#602)

* [Enhancement] Add aggressive shared memory merge option in memory allocation

- Introduced a new configuration option `tl.enable_aggressive_shared_memory_merge` to enable aggressive merging of shared memory allocations.
- Updated the `SharedMemLinearAccessPatternFinder` class to support an aggressive merge strategy, allowing for improved memory reuse.
- Modified the `MergeSharedMemoryAllocations` function to incorporate the new merging strategy based on the configuration.
- Enhanced the `PassConfigKey` enumeration to include the new aggressive merge option, ensuring it can be configured appropriately.

* lint fix

* [Enhancement] Add aggressive shared memory merge configuration option

- Introduced a new configuration option `kEnableAggressiveSharedMemoryMerge` to enable aggressive merging of shared memory allocations, enhancing memory management capabilities.

* [Enhancement] Update MergeSharedMemoryAllocations to support aggressive merge option

- Modified the `MergeSharedMemoryAllocations` function to accept an `enable_aggressive_merge` parameter, allowing for more flexible memory management.
- Introduced a new helper function `should_enable_aggressive_merge` to determine the aggressive merge configuration based on the pass context and target.
- Updated the relevant calls in the `phase.py` and `__init__.py` files to utilize the new aggressive merge functionality, enhancing the overall memory allocation strategy.

* [Refactor] Update accumulation handling in gemm_sm90.h (#603)

- Replaced the use of `tiled_mma.accumulate_ = GMMA::ScaleOut::Zero` with a call to `clear(acc)` for better clarity and maintainability in the accumulation logic.
- This change enhances the readability of the code by standardizing the approach to clearing accumulation values across multiple sections of the file.

* [Enhancement] Add tma bulk copy. (#600)

* [Bugfix] Fixed mha_bwd shape inconsistency error (#604)

* lint fix

* Update requirements-lint.txt to maintain clang-format version consistency

* [Bugfix] Avoid duplicate data access when cross thread buffer meet replicate register (#606)

* [Enhancement] Improve debug output formatting in layout and fragment nodes

- Updated the `DebugOutput` methods in `LayoutNode` and `FragmentNode` to provide more structured and informative output, including transformation details and thread range information.
- Enhanced layout inference logic in `ParallelOp` to add predicates for cross-thread shared memory access, improving layout handling in parallel operations.
- Minor adjustment in `layout_inference.cc` to ensure clarity in parallel loop handling.

* lint fix

* [Enhancement] Support tf32 gemm_rs (#607)

- Added a line break in `quickstart.py` for better readability.
- Simplified the JIT kernel compilation in `quickstart.py` by removing the unused execution backend option.
- Modified `example_elementwise_add.py` to disable cache for `tilelang` and optimized the element-wise addition kernel by utilizing shared memory for input tensors, improving performance.
- Updated default values for matrix dimensions and block sizes in the argument parser to enhance usability.

* [Enhancement] Introduce option `TL_DISABLE_FAST_MATH` and `TL_ENABLE_PTXAS_VERBOSE_OUTPUT` (#609)

* [Enhancement] Introduce new PassConfig options for fast math and PTXAS verbosity

- Added `kDisableFastMath` and `kEnablePTXASVerboseOutput` configuration options to enhance control over compilation settings.
- Updated `LibraryGenerator` to utilize these new pass configurations, allowing for more flexible compilation behavior based on user preferences.
- Enhanced `PassConfigKey` enumeration to include the new options, ensuring they can be configured appropriately in the pass context.

* [Refactor] Update PTXAS verbosity configuration key in LibraryGenerator

- Changed the configuration key for PTXAS verbosity from `TL_VERBOSE_PTXAS_OUTPUT` to `TL_ENABLE_PTXAS_VERBOSE_OUTPUT` to align with the new naming convention introduced in recent enhancements.
- This update ensures consistency in the configuration options used within the `LibraryGenerator` class, improving clarity and maintainability of the code.

* lint fix

* fix build

* [Experimental][Language] add `T.GEMM_SP` for sm90 sparse tensor core (#526)

* [experimental] add a draft gemm_sp

* [3rdparty] bump cutlass to v3.9.3

* [lint] run format.sh

* [chore] rebase

* [chore] use abs path

* [gemm_sp] add metadata layout

* [ci] add more example

* [lint] run format.sh

* [chore] polish

* [chore] move gemm_sp to experimental

* [chore] polish

* [lint] run format.sh

* [Enhancement] Improve bulk copy handling and update GEMM sparse tensor test

* Added a warning log for unsupported non-swizzled global layouts in the bulk copy operation, ensuring fallback to normal copy.
* Refactored the GEMM sparse tensor test by removing unnecessary imports and simplifying the kernel compilation process.
* Updated the test to directly call the `run_gemm_sp` function, enhancing clarity and functionality.

* Implement Test

* [Enhancement] Update GEMM SP and SM89 templates for improved functionality

* Refactored GEMM SP computation to enhance warp partitioning logic, ensuring compatibility with Hopper architecture.
* Updated layout inference to support new WGMMA conditions and improved error messaging for unsupported targets.
* Modified SM89 templates to utilize new MMA atom structures, enhancing performance and compatibility with fp8 types.
* Added conditional inclusion for GEMM SP header based on CUDA architecture version.

* lint fix

* [gemm_sp] support more layout and data types

* Enhancement: sync T.gemm_sp's layout inference with T.gemm

* Enhancement: support more block_k in compress util

* [Enhancement] enable block_k=64

* [Lint] run format.sh

* [Enhancement] compressor support more dtype

* Enhancement: enable block_K=32

* [Lint] format.sh

* [Fixbug] fix shape

* Refactor: sync gemm

* [Enhancement] enable transpose

* [Enhancement] enable fp8_e4m3

* [Enhancement] enable int8

* [Lint] run format.sh

* [Benchmark] add gemm_sp benchmark

* [Example] fix 256 threads hang

* [CI] fix ci

* [Chore] resolve gemini feedback

* [Benchmark] increase search space

* [Lint] format

* [CI] skip sparse tensor core related tests as only sm90 is supported

* [CI] pass local run

* Update gemm_sm89.h

* lint fix

* lint fix

* [Enhancement] Add support for sparse GEMM and initialize CUDA architecture flags

- Introduced a new boolean flag `enable_sparse_gemm_` to control the inclusion of sparse GEMM functionality in CUDA code generation.
- Updated the `Finish` method to conditionally include the sparse GEMM header based on the new flag.
- Implemented logic in `VisitStmt_` to enable sparse GEMM when the corresponding external call is detected.
- Added a function to initialize the `TORCH_CUDA_ARCH_LIST` environment variable based on the target compute version, enhancing compatibility with PyTorch.
- Refactored the initialization function into the appropriate module and ensured it is called in the sparse utilities module.

* Update test_compress_utils.py

---------

Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>
Co-authored-by: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>

* [Doc] Phaseout Legacy documentations (#610)

- Added a new entry in the README for the introduction of `T.gemm_sp` supporting 2:4 sparse tensor core.
- Removed several outdated documentation files related to convolution, flash attention, and other tutorials to streamline the documentation structure.

* [Refactor] Phaseout Pass ParallelLoopTransformer (#611)

* Refactor layout inference by removing the ParallelLoopTransformer class. Updated layout inference logic to streamline buffer access collection and condition handling in parallel loops. This change simplifies the code structure and enhances maintainability.

* Update MHA backward test cases to use reduced dimensions for batch size and context length

* fix build

* [Enhancement] Update ReduceOp initialization values for integer types (#614)

* [Enhancement] Update ReduceOp initialization values for integer types

- Modified the `MakeInitValue` method in `ReduceOp` to handle integer data types correctly by returning appropriate minimum and maximum values based on the bit width.
- Added checks for integer types to ensure correct initialization for `kMax` and `kMin` reduction types, enhancing the robustness of the reduction operations.

* [Enhancement] Update ReduceOp to handle unsigned integer initialization values

- Enhanced the `MakeInitValue` method in `ReduceOp` to include support for unsigned integer data types.
- Added conditions to return appropriate initialization values for `kMax` and `kMin` reduction types based on the data type, improving the robustness of reduction operations.

* Bump transformers from 4.50.0 to 4.51.0 in /examples/bitnet-1.58b (#615)

Bumps [transformers](https://github.com/huggingface/transformers) from 4.50.0 to 4.51.0.
- [Release notes](https://github.com/huggingface/transformers/releases)
- [Commits](huggingface/transformers@v4.50.0...v4.51.0)

---
updated-dependencies:
- dependency-name: transformers
  dependency-version: 4.51.0
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* [Refactor] refactor autotune examples (#617)

* [Refactor] Update tilelang kernel functions and remove unused imports

- Refactored the `flashattn_fwd`, `flashattn_bwd_preprocess`, and `flashattn_bwd_postprocess` functions to utilize direct kernel calls instead of cached versions, improving clarity and performance.
- Added `@tilelang.jit` decorators with specified output indices to enhance kernel compilation.
- Removed unused import of `cached` from `tilelang`, streamlining the code.
- Commented out the main testing function call in `test_tilelang_kernel_mha_bwd.py` for potential future use.

* [Refactor] Simplify configuration generation in benchmark and example scripts

- Refactored the `get_configs` functions in multiple benchmark and example scripts to utilize a dictionary-based approach for parameter configuration, improving readability and maintainability.
- Updated the `flashattn` and `chunk_scan_fwd` functions to directly accept configuration parameters, enhancing flexibility in kernel tuning.
- Removed redundant code and streamlined the configuration generation process across various files, ensuring consistency in how configurations are defined and utilized.

* [Refactor] Update configuration handling in benchmark scripts

- Refactored the `get_configs` functions in benchmark scripts to accept a variable argument list, improving flexibility in configuration management.
- Enhanced the `matmul` and `flashattn` functions to utilize the updated configuration approach, streamlining parameter handling for kernel tuning.
- Added `@autotune` decorators to relevant functions, ensuring consistent autotuning behavior across benchmarks.
- Cleaned up redundant code and improved overall readability in the affected files.

* [Refactor] Clean up formatting and update subproject commit

- Updated the subproject commit reference in the TVM directory to indicate a dirty state.
- Removed unnecessary blank lines and improved formatting in the `benchmark_matmul` and `benchmark_matmul_fp8` scripts for better readability.
- Streamlined the function definitions in the `flashattn` example script to enhance clarity and maintainability.

* [Refactor] Update AutoTuner configuration handling

- Modified the AutoTuner class to check if kernel parameters are set before processing tunable arguments, improving robustness in configuration handling.
- Enhanced the logic for skipping compilation when tunable parameters are already provided, ensuring efficient use of resources.
- Updated comments for clarity and maintainability.

* lint fix

* Update TVM subproject commit to indicate dirty state and modify MHA backward test cases

- Updated the subproject commit reference in the TVM directory to reflect a dirty state.
- Adjusted the `test_mha_bwd` function to use a new configuration for the MHA backward tests, changing the context size from 128 to 256.
- Uncommented the main testing function call for potential execution.

* lint fix

* Bump transformers from 4.51.0 to 4.52.1 in /examples/bitnet-1.58b (#619)

Bumps [transformers](https://github.com/huggingface/transformers) from 4.51.0 to 4.52.1.
- [Release notes](https://github.com/huggingface/transformers/releases)
- [Commits](huggingface/transformers@v4.51.0...v4.52.1)

---
updated-dependencies:
- dependency-name: transformers
  dependency-version: 4.52.1
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* Fix PTXAS options flag in LibraryGenerator for consistency (#620)

* Refactor FP8 type handling across multiple files to standardize usage of "float8_e4m3" and "float8_e5m2" instead of "e4m3_float8" and "e5m2_float8". This includes updates in benchmarks, examples, tests, and internal utilities.

* [Refactor] Add parallel loop transform pass for condition extraction (#618)

* [Refactor] Add parallel loop transform

* done format check

* pull 3rdparty repo

* Refactor loop variable handling in transformation utilities

- Updated the logic in `loop_parallel_transform_utils.h` to simplify the handling of related loop variables.
- Removed the check that enforced a single related loop variable, replacing it with a return statement when multiple variables are detected, enhancing clarity and maintainability of the transformation process.

* Update loop_parallel_transform_utils.h

* Refactor loop variable handling in transformation utilities

- Enhanced the logic in `loop_parallel_transform_utils.h` to improve clarity and maintainability by simplifying the handling of related loop variables.
- Replaced the previous enforcement of a single related loop variable with a return statement for multiple variables detected.

* remove disable cache flag as commit id has been key component

* lint fix

---------

Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>
Co-authored-by: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>

* [Dev] Update linear attention examples to enhance performance on Hopper GPUs (#621)

* Tune linear attention examples on H100

* Add retnet fwd kernel

* fix lint

* [Enhancement] Add ahead of time cython compilation in setup.py (#622)

* [Enhancement] Add Cython support and compiler detection in setup.py

- Introduced a new `CythonExtension` class for building Cython-based extensions, enhancing the build process for Cython projects.
- Implemented functions to detect the Cython compiler and C++ compiler, improving compatibility and user experience.
- Updated the build process to handle Cython extensions alongside CMake extensions, ensuring a seamless integration for users.
- Added caching mechanisms for Cython compilation to optimize build times and reduce unnecessary recompilation.

* [Enhancement] Add Cython dependency and enable CMake extension building

- Added Cython as a required dependency in `pyproject.toml` to support Cython-based extensions.
- Updated `setup.py` to enable building CMake extensions, improving the build process for projects utilizing both Cython and CMake.
- Modified the Cython compiler detection logic to streamline installation instructions for users.

* [Enhancement] Support more flexible layout host pythonic expr (#623)

* [Refactor] Enhance expression handling in utils.py and update wrapper to use pythonic_expr

- Added support for additional TIR expressions (FloorDiv, Min, Max, Add, Sub, FloorMod) in the pythonic_expr function to improve string representation.
- Replaced the deprecated legalize_c function calls in TLCUDASourceWrapper and TLCPUSourceWrapper with pythonic_expr for better expression handling in kernel launch code.

* [Refactor] Simplify expression handling in pythonic_expr function

- Consolidated binary and min/max operation handling in the pythonic_expr function to improve readability and maintainability.
- Replaced individual checks for binary operations with a mapping approach, streamlining the code and enhancing performance in expression representation.

* [Enhancement] Improve expression representation in pythonic_expr function

- Added operator precedence handling to the pythonic_expr function, enhancing the conversion of TVM PrimExpr to Python-style strings.
- Updated the visitor logic to intelligently add parentheses based on operator precedence, improving the accuracy of expression representation.
- Included a docstring for better clarity on the function's purpose and usage.

* test fix

* [Enhancement] support composable expression for shape with symbolic vars (#624)

* [Refactor] Enhance expression handling in utils.py and update wrapper to use pythonic_expr

- Added support for additional TIR expressions (FloorDiv, Min, Max, Add, Sub, FloorMod) in the pythonic_expr function to improve string representation.
- Replaced the deprecated legalize_c function calls in TLCUDASourceWrapper and TLCPUSourceWrapper with pythonic_expr for better expression handling in kernel launch code.

* [Refactor] Simplify expression handling in pythonic_expr function

- Consolidated binary and min/max operation handling in the pythonic_expr function to improve readability and maintainability.
- Replaced individual checks for binary operations with a mapping approach, streamlining the code and enhancing performance in expression representation.

* [Enhancement] Improve expression representation in pythonic_expr function

- Added operator precedence handling to the pythonic_expr function, enhancing the conversion of TVM PrimExpr to Python-style strings.
- Updated the visitor logic to intelligently add parentheses based on operator precedence, improving the accuracy of expression representation.
- Included a docstring for better clarity on the function's purpose and usage.

* test fix

* minor update

* 🐍Fix the file name "test_exmaple_tilelang_nsa" (#629)

* [Enhancement] Add CPU utilization and count settings for Auto-Tuning (#630)

* [Enhancement] Add CPU utilization and count settings for Auto-Tuning

- Introduced environment variables for CPU utilization, counts, and maximum CPU count for auto-tuning.
- Updated the AutoTuner class to utilize these new settings, improving flexibility and performance in multi-threaded environments.
- Enhanced logging to provide better insights into the auto-tuning process based on the configured CPU settings.

* typo fix

* [AutoTune] Support `with set_autotune_inputs` to set auto tuning input tensors (#632)

* [Refactor] Simplify and modularize autotuner implementation

- Removed unused imports and extensive code sections from the autotuner module to enhance readability and maintainability.
- Modularized the code by introducing new imports for autotuning and capturing functionalities, streamlining the overall structure.
- Improved logging setup and removed redundant timeout handling functions, focusing on core autotuning logic.
- Updated the AutoTuner class to better utilize the new modular structure, ensuring efficient performance during auto-tuning processes.

* [Refactor] Clean up and enhance capture and tuner modules

- Improved code readability by removing unnecessary blank lines and organizing imports in `capture.py` and `tuner.py`.
- Enhanced logging in the `AutoTuner` class to provide clearer warnings regarding the usage of `supply_prog` in the context of auto-tuning.
- Streamlined the `CaptureStack` class for better thread-local context management.

* lint fix

* [Refactor] Simplify configuration and autotuning logic in blocksparse GEMM example

- Updated `get_configs` function to reduce the number of configurations, enhancing performance and clarity.
- Removed the `get_best_config` function, integrating its logic directly into the `blocksparse_matmul` function with the `@autotune` decorator for streamlined autotuning.
- Adjusted the main function to directly utilize the autotuned kernel, simplifying the overall structure and improving readability.
- Deleted obsolete test file for autotuning decorator, cleaning up the codebase.

* [Refactor] Improve code formatting and readability in autotune test file

- Reformatted the `matmul` function and `get_configs` function for better readability by adjusting line breaks and indentation.
- Fixed a typo in the `enable_rasteration` parameter name to ensure consistency.
- Cleaned up unnecessary blank lines to enhance overall code clarity.

* Update example_blocksparse_gemm.py

* Update capture.py

* [Pass] Introduce flag to diable cp async lowering (#633)

* [Enhancement] Update PipelinePlanner to support async copy configuration

- Modified the `Substitute` method in `PipelinePlanner` to accept a `use_async_copy` parameter, allowing for more flexible pipeline planning based on async copy requirements.
- Updated the constructor of `PipelinePlanner` to initialize the `use_async_copy_` member variable.
- Adjusted the logic in the pipeline planning process to conditionally apply async copy annotations based on the new parameter.
- Commented out the `LoopVectorizeDynamic` call in `LowerAndLegalize` to prevent unintended modifications during the legalizing phase.

* Refactor PipelinePlanning function for improved readability

- Adjusted the formatting of the `use_async_copy` variable assignment in the `PipelinePlanning` function to enhance code clarity and maintainability.

* fix typo (#635)

* [Pass][Simplify] Introduce symbolic level simplify for condition expression (#634)

* [Enhancement] Add argument simplification option to StmtSimplifier

- Introduced a new `simplify_arguments` flag in the `StmtSimplifier::Apply` method to control argument simplification behavior.
- Updated the `Simplify` function to accept the new flag, allowing for enhanced flexibility in the simplification process.
- Adjusted the `LowerAndLegalize` and `_Simplify` functions to utilize the new argument, ensuring consistent behavior across the codebase.
- Added comments to clarify the purpose of the new flag and its impact on simplification logic.

* lint fix

* [Enhancement] Improve layout inference and reduce operation handling

- Updated `ParallelOp::InferLayout` to check for pure buffer stores, enhancing layout inference logic.
- Modified `ReduceOp::Lower` to include all threads in the AllReduce operation, improving performance on specific architectures.
- Added a TODO comment in `AllReduce` to consider merging synchronization barriers for optimization.

* lint fix

* [Enhancement] Add input validation for GEMM parameters

- Introduced checks to ensure that the dimensions M and N are divisible by their respective warp sizes (kMPerWarp and kNPerWarp) in the Gemm::ComputeWarpPartition method.
- Added informative error messages to assist in debugging when the input parameters do not meet the required conditions.

* bug fix

* Enhance test coverage by adding LLVM requirement decorator to multiple function call tests. This ensures that tests for argument count, type code, null data pointer, and dimensionality checks are only executed when LLVM is available, improving test reliability and clarity.

* lint fix

* Fix software pipeline stage annotation and update optional config handling in StmtSimplifier

* Add Python executable detection in CMake configuration and update TVM submodule reference. Remove unused vectorization tests for improved clarity.

* Update TVM submodule reference and refactor FFI registration to use static initialization blocks for improved organization and clarity.

* Refactor attribute handling in layout and IR nodes to use reflection registration. This change replaces the VisitAttrs method with a RegisterReflection method for improved clarity and organization across multiple classes, including KernelLaunchFrameNode, WarpSpecializeFrameNode, LayoutNode, FragmentNode, and SwizzledLayoutNode.

* finish rebase

* tvm update

* Refactor FFI registration across tilelang modules to use the updated `tvm.ffi` namespace. This includes changes in various files to replace `tvm._ffi` with `tvm.ffi`, enhancing consistency and clarity in the codebase.

* lint fix

* Update TVM submodule reference and modify CUDA runtime argument handling to use the new runtime constants for improved clarity and consistency.

* lint fix

* Refactor tensor data type references from "e4m3_float8" and "e5m2_float8" to "float8_e4m3" and "float8_e5m2" across multiple files for consistency and clarity.

* lint fix

* Refactor forward_index initialization in Fragment class to default to an empty array instead of None, ensuring consistent handling of optional outputs.

* test fix

* lint fix

* bugfix

* lint fix

* reduce fix

* lint fix

* carver fix

* cast fix

* Update submodule and enhance kernel launch functionality with optional block size parameter; add device kernel launch transformation.

* lint fix

* bugfix

* Refactor test execution in test_tilelang_cpu_gemm.py and enhance device call checks in lower.py to exclude C packed functions from kernel launch conditions.

* lint fix

* Update runtime.cc

* phase out lisence

* Update subproject commit for TVM to 555cc71

* Update subproject commit for TVM to d39953fa

* Update subproject commit for TVM to 9574805f

* Update subproject commit for TVM to a08b7c3

* fix ci

* ci fix

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>
Co-authored-by: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>
Co-authored-by: Cunxiao Ni <85601223+Cunxiao2002@users.noreply.github.com>
Co-authored-by: Yuxi Chi <cherichy@outlook.com>
Co-authored-by: Nathan Chen <120630832+Nathancgy@users.noreply.github.com>
Co-authored-by: botbw <wang1570@e.ntu.edu.sg>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: xs-keju <93414213+xs-keju@users.noreply.github.com>
Co-authored-by: Tong WU <109033598+Rachmanino@users.noreply.github.com>
Co-authored-by: Kadir Nar <kadir.nar@hotmail.com>
Co-authored-by: Yuqing Xia <35415939+xiayuqing0622@users.noreply.github.com>
Co-authored-by: xwhzz <wh.xie@outlook.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant