-
Notifications
You must be signed in to change notification settings - Fork 207
Feature : LC2ST MLP GPU support (closes : #1160) #1715
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Feature : LC2ST MLP GPU support (closes : #1160) #1715
Conversation
…darshan/sbi into fix/pymc-version-pin-1397
|
Hi @Dev-Sudarshan, Thank you for this contribution adding GPU support to LC2ST! Apologies for the long silence on this PR. Note: This PR contains some unrelated PyMC version pinning changes. Please remove those by rebasing on main so we can focus on the GPU support implementation. I've done a review and found the following issues: Issues Found
Current: Fix: (Keep
The
The current implementation only checks for CUDA. Please add MPS support (for Apple Silicon) and refactor to use a generic use_gpu pattern: Similarly, update the RandomForest warning to check for device.lower() in ('cuda', 'mps').
When a user explicitly requests device="cuda" (or "mps") but that backend isn't available, the code silently falls back to
Instead of running the lc2st device test on both cpu and gpu, restrict it to GPU by adding a marker decorator: and testing only cuda and mps (if available). Thanks again for working on this feature – the overall approach using Let me know if you have any questions about the feedback above. |
janfb
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See PR comment above.
|
Thank you for the feedback. I will review everything thoroughly and make the required changes. |
This PR adds optional GPU support to L-C2ST by introducing a PyTorch-based
MLP classifier implemented via skorch. This addresses issue #1160 .
Changes: