
Remove int_scaled_mm's dependency on triton for cpu #128


Merged
merged 10 commits into pytorch:main on Oct 29, 2024

Conversation

Xia-Weiwen
Collaborator

@Xia-Weiwen commented Apr 8, 2024

The int_scaled_mm op in Torchao was originally designed for CUDA, but it is also needed for CPU. This PR adds a CPU path in intmm.int_scaled_mm. It is not registered as an implementation of torchao.int_scaled_mm because we want to use Inductor for further optimization, and Inductor cannot recognize the torchao.int_scaled_mm op.

This change requires pytorch/pytorch#136942; otherwise there may be numerical issues. It therefore works with PyTorch nightly builds from 20241026 onward.

Tests are covered by test/kernel/test_autotuner.py and test/prototype/test_smoothquant.py.
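For reference, a minimal sketch of the CPU path described above (illustrative only; the names follow the snippet discussed later in this thread, and the merged code may differ):

import torch

def int_scaled_matmul(a, b, scales1):
    if a.device.type == "cpu":
        # Eager CPU path: int8 x int8 -> int32 matmul via torch._int_mm,
        # then cast to the scale dtype and multiply by the scales.
        # Deliberately not registered as a torchao op so Inductor can see
        # this pattern and optimize it further.
        c = torch._int_mm(a, b)
        return c.to(scales1.dtype) * scales1
    # CUDA keeps using the registered torchao op.
    return torch.ops.torchao.int_scaled_matmul(a, b, scales1)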

@facebook-github-bot added the CLA Signed label Apr 8, 2024
@Xia-Weiwen
Collaborator Author

Hi @cpuhrsch, could you please review and see if the changes look reasonable to you? Thanks.

@Xia-Weiwen
Collaborator Author

Hi @cpuhrsch, could you please suggest how to deal with the issue that the CPU impl's availability depends on triton and AUTOTUNER_ENABLE? Thanks!

@cpuhrsch
Contributor

Hey @Xia-Weiwen - Thank you for the PR! Sorry for the delay in review. Also, please note the CI hasn't run green.

Another way to resolve this could be to move

@torch.library.impl(lib, "int_scaled_matmul", "CPU")
def int_scaled_matmul_cpu(a, b, scales1):
    c = torch._int_mm(a, b)
    return c.to(scales1.dtype) * scales1

into torchao/kernel/intmm.py, which shouldn't have a dependency on triton. Just be sure to also define lib = torch.library.Library("torchao", "FRAGMENT").
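Put together, the suggestion might look roughly like this inside torchao/kernel/intmm.py (a sketch only; it assumes the int_scaled_matmul op schema is already defined elsewhere in torchao):

import torch

# Library fragment so the CPU impl can be registered without importing triton.
lib = torch.library.Library("torchao", "FRAGMENT")

@torch.library.impl(lib, "int_scaled_matmul", "CPU")
def int_scaled_matmul_cpu(a, b, scales1):
    # int8 x int8 -> int32 matmul, then cast to the scale dtype and scale.
    c = torch._int_mm(a, b)
    return c.to(scales1.dtype) * scales1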

@Xia-Weiwen
Collaborator Author

@cpuhrsch Thanks! I will give it a try. One question: what is AUTOTUNER_ENABLE, and should the CPU impl depend on it?

@cpuhrsch
Contributor

@Xia-Weiwen - it's used for a Triton autotuner that allows us to cycle over a very large number of configs for a given fixed input shape. See https://github.com/pytorch-labs/ao/tree/main/torchao/kernel#autotuner-and-custom-triton-kernels

@Xia-Weiwen
Collaborator Author

Thank you @cpuhrsch. Looks like the CPU impl does not need this.
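For illustration, the kind of gate being discussed might look roughly like this, with the Triton path optional and the CPU impl kept outside it (the environment variable name and module path are assumptions, not the exact torchao code):

import os

# Only import the Triton-based autotuned kernels when explicitly enabled,
# so a CPU-only environment never needs triton installed.
if os.getenv("TORCHAO_AUTOTUNER_ENABLE", "0") == "1":
    from torchao.kernel import intmm_triton  # noqa: F401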


pytorch-bot bot commented Oct 28, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/128

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 97dfea8 with merge base cbd90e3:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@Xia-Weiwen requested a review from cpuhrsch October 29, 2024 01:55
@Xia-Weiwen
Collaborator Author

Hi @cpuhrsch, this PR requires the latest torch nightly (after 20241026) to pass CI. May I know when the torch nightly will be updated in CI? Thanks.

@cpuhrsch
Contributor

@Xia-Weiwen - can you try merging the latest version of main? You might be built on top of a commit that pinned the nightly version. I see dev20241022 here: https://github.com/pytorch/ao/actions/runs/11550979799/job/32147052309?pr=128

@Xia-Weiwen
Collaborator Author


Thanks

@Xia-Weiwen
Collaborator Author

Hi @cpuhrsch, CI is green. Could you please review? Thanks.

@cpuhrsch merged commit 5cfc4c7 into pytorch:main Oct 29, 2024
17 checks passed
yanbing-j pushed a commit to yanbing-j/ao that referenced this pull request Dec 9, 2024