[CUDA] Support memory reuse for dynamic shared memory #9341
Conversation
Hmm, for reuse on the constant-size dynamic shmem, I thought the existing […]
I think this is to support reusing heterogeneous data type allocs […]
Thanks @jinhongyii, it's great that you are able to reduce the alloc size from […]
Let's double check and get it in if there is no further issue :-)
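To illustrate what "heterogeneous data type allocs" means in this context: CUDA exposes only a single `extern __shared__` array for dynamic shared memory, so a merging pass must flatten buffers of different dtypes into one byte buffer, keeping each naturally aligned. The sketch below (buffer names, widths, and sizes are made up for illustration) computes such a layout without any reuse, i.e. the baseline the reuse pass improves on:

```python
# Illustrative sketch: flatten heterogeneous-dtype allocations into one
# byte buffer, as a single dynamic shared memory region forces on CUDA.
# Names, dtype widths, and element counts here are hypothetical.

def align(offset, alignment):
    """Round offset up to the next multiple of alignment."""
    return (offset + alignment - 1) // alignment * alignment

# (name, dtype byte width, element count); assume all live simultaneously.
allocs = [("A_shared", 2, 4096), ("B_shared", 2, 4096), ("C_shared", 4, 2048)]

offset = 0
layout = {}
for name, width, elems in allocs:
    offset = align(offset, width)  # keep each dtype naturally aligned
    layout[name] = offset
    offset += width * elems

print(layout, offset)  # final offset = total bytes with no reuse
```

With no lifetime analysis every buffer gets a disjoint span, so the total is simply the aligned sum of all sizes.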
Otherwise LGTM
tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py
Thanks @jinhongyii
* reuse shared dyn
* minor
* format
* format
* address comment and fix
* address comment
* address comment
This PR applies part of the StorageRewrite algorithm to detect memory-reuse opportunities for dynamic shared memory, and rewrites the body accordingly. This functionality is important for matmul, where we may want to place the shared memory used for data fetch and the shared memory used for write-back at the same location.
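The reuse idea described above can be sketched as a first-fit offset assignment over liveness intervals: two buffers may share bytes only if their live ranges do not overlap. This is a simplified stand-in, not the actual pass (which adapts StorageRewrite's more involved analysis); buffer names, sizes, and intervals below are hypothetical:

```python
# First-fit sketch: assign byte offsets to dynamic shared memory buffers
# from liveness intervals, letting non-overlapping buffers share a region.

def assign_offsets(allocs):
    """allocs: list of (name, size_bytes, (first_use, last_use))."""
    placed = []  # (offset, size_bytes, (first_use, last_use))
    layout = {}
    for name, size, live in sorted(allocs, key=lambda a: a[2][0]):
        offset = 0
        for o, s, (s0, s1) in sorted(placed):
            lifetimes_overlap = not (live[1] < s0 or s1 < live[0])
            spans_overlap = offset < o + s and o < offset + size
            if lifetimes_overlap and spans_overlap:
                offset = o + s  # conflict: move past this buffer
        placed.append((offset, size, live))
        layout[name] = offset
    return layout, max(o + s for o, s, _ in placed)

allocs = [
    ("A_shared", 8192, (0, 3)),  # data-fetch buffer, dead after step 3
    ("B_shared", 8192, (0, 3)),  # second fetch buffer, same lifetime
    ("C_shared", 8192, (4, 6)),  # write-back buffer, starts at step 4
]
layout, total = assign_offsets(allocs)
print(layout, total)
```

Because `C_shared` becomes live only after the fetch buffers die, it lands at the same offset as `A_shared`, and the merged allocation shrinks from the naive 24576 bytes to 16384 in this toy setup.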