-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[BUGFIX] fix illegal memory access bug in reduce op schedule by constriant thread_y #8566
[BUGFIX] fix illegal memory access bug in reduce op schedule by constriant thread_y #8566
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. @vinx13 could you help review this PR as well?
Adding mem-check
could definitely avoid such issues, but it needs more discussions and considerations to make it robust and flaky-free. It would be great to file an RFC here (https://github.com/apache/tvm-rfcs) if you have a concrete proposal.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cc @MasterJH5574 would be great to check if this is an issue with autotir
c00d701
to
6260154
Compare
…riant threadIdx.y Signed-off-by: ziqiang.pzq <ziqiang.pzq@alibaba-inc.com>
6260154
to
a86e87e
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you @ZQPei ! |
…riant threadIdx.y (apache#8566) Signed-off-by: ziqiang.pzq <ziqiang.pzq@alibaba-inc.com> Co-authored-by: ziqiang.pzq <ziqiang.pzq@alibaba-inc.com>
…riant threadIdx.y (apache#8566) Signed-off-by: ziqiang.pzq <ziqiang.pzq@alibaba-inc.com> Co-authored-by: ziqiang.pzq <ziqiang.pzq@alibaba-inc.com>
This commit is to fix an illegal memory access bug in reduction op.
Bug description
Earlier last week I was running a tvm module with
cuda-memcheck
to make sure it was safe with memory process. However, the module failed to pass the memory check. Then I realized there must be something wrong with the CUDA kernels and finally I found it.To reproduce this error, you can simply run the following commands on a cuda enabled machine,
cuda-memcheck --report-api-errors no python3 ${TVM_HOME}/tests/python/topi/python/test_topi_reduce.py
and the terminal will show a stack trace like this,
This means that there are illegal memory accessed in reduction CUDA kernel.
Bug analysis
To solve this error, I wrote a simple python debug code as follows, which build and run a sum op and also it save the CUDA kernel at the same time,
test_reduce_sum.py
Also, I can reproduce the same memcheck error by running
and the CUDA kernel code in my simple
test_reduce_sum.py
will be saved tolib_sum.cu
.Also, we can infer the kernel are launched with grid(1, 1, 1) and block(32, 32, 1) from
python/tvm/topi/cuda/reduction.py
.From the CUDA kernel code and the error report, we can find that the code lacks a constriant to
threadIdx.y
at the end of buffer copy step.If the output size is less than 32, then
threadIdx.y
may access the illegal memory. The code from line 51 to line 53 should be like this,Fix the reduction schedule
After analysising the CUDA kernel, we can fix the schedule of all reduction ops in
python/tvm/topi/cuda/reduction.py
.I amend the code in line 89 and add a constriant to thread_y by the following code
BTW, since this bug can only be detected with
cuda-memcheck
tool, I think it is essential to addcuda-memcheck
to tvm Github Action to avoid bugs like this.