Skip to content

Commit

Permalink
Pass to eliminate redundant branch and overcompute (#17170)
Browse files Browse the repository at this point in the history
* Implementation to eliminate redundant branch introduced due to operator padding and overcompute, this creates more opportunities to vectorize the code

* Fixed lint error in transform.py file

* Fixed lint errors in the file using_assume_to_reduce_branches.cc

* Fixed lint error in transform.py related to line too long

* Fixed Lint error related to space and length of the sentence in using_assume_to_reduce_branches.cc

* Fixed lint error : trailing whitespaces in using_assume_to_reduce_breanches.cc

* Fixed lint error: clang format issue in cpp files

* fixed pylint errors in python files and used clang format to format the cpp files

* Ran black format and removed the attr_registry_map.h import as it was running into some other issue because of which build was failing
  • Loading branch information
sdalvi-quic authored Jul 25, 2024
1 parent 7bd738a commit 6704175
Show file tree
Hide file tree
Showing 4 changed files with 1,063 additions and 0 deletions.
8 changes: 8 additions & 0 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -834,6 +834,14 @@ TVM_DLL Pass InstrumentProfileIntrinsics();
*/
TVM_DLL Pass DefaultGPUSchedule();

/*!
* \brief This pass analyzes primfunc & eliminates branch introdued due to layout specific padding.
* It leverages from the buffer assumptions and use the information to eliminate the branch.
* \note This creates more opportunity to vectorize the code.
* \return The Pass.
*/
TVM_DLL Pass UseAssumeToReduceBranches();

} // namespace transform
} // namespace tir
} // namespace tvm
Expand Down
13 changes: 13 additions & 0 deletions python/tvm/tir/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1199,3 +1199,16 @@ def DefaultGPUSchedule():
ret: tvm.transform.Pass
"""
return _ffi_api.DefaultGPUSchedule() # type: ignore


def UseAssumeToReduceBranches():
"""This pass attempts to eliminates layout specific pad branch by overcomputing the values
for padded region. Eliminating the branch will help to vectorize code,
and improve element wise ops performance.
Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.UseAssumeToReduceBranches() # type: ignore
Loading

0 comments on commit 6704175

Please sign in to comment.