-
Notifications
You must be signed in to change notification settings - Fork 4
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
API Compat fix on max_pool2d_with_indices_() #65
Conversation
The purpose of this PR is to address several test cases for This passes an additional 6 tests |
the red x here is because of whate we discussed offline? missing torch-mlir wheel? edit: indeed it is |
# Conflicts: # cpp_ext/TorchOps.cpp # cpp_ext/TorchOps.h
Co-authored-by: Arham Khan <arhamkhan@Arhams-MacBook-Pro.local>
Co-authored-by: Arham Khan <arhamkhan@Arhams-MacBook-Pro.local>
"resolved conflict btw previous PR"
Allowed |
@brucekimrokcmu it seems like you rebased against main but your PR is trying to change the torch-mlir submodule, can you verify whether your local branch is at the most recent torch-mlir commit in |
Co-authored-by: Arham Khan <arhamkhan@Arhams-MacBook-Pro.local>
conflict resolved
This PR addresses the following error to resolve a number of test cases:
Essentially, a recursive template function generates 16 different combinations of
PyAnyTorchListOfTorchIntValue
andPyTorch_IntValue
for four arguments (kernel_size, stride, padding, dilation) to be accepted formax_pool2d_with_indices_
Ops.Note that to avoid segfault,
max_pool2d_with_indices()
redefines defaultloc
andip
before all casted arguments are passed intomax_pool2d_with_indices_
() ops function.