Skip to content

fix device bug on GPU #254

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 26, 2025
Merged

fix device bug on GPU #254

merged 2 commits into from
Mar 26, 2025

Conversation

noahho
Copy link
Collaborator

@noahho noahho commented Mar 26, 2025

Fix #253

@noahho noahho requested a review from Copilot March 26, 2025 07:33
Copy link

@Copilot Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR addresses a bug related to device handling on GPU by updating how the device is inferred and validated during runtime.

  • Added a new import for device inference.
  • Modified the device-check logic to use an inferred device mapping.
Comments suppressed due to low confidence (2)

src/tabpfn/base.py:257

  • Verify that infer_device_and_type consistently returns a type accepted by torch.device to prevent runtime errors on GPU.
device_mapped = infer_device_and_type(device)

src/tabpfn/base.py:265

  • [nitpick] Consider renaming 'device_mapped' to a more descriptive name such as 'resolved_device' to clarify its purpose.
if torch.device(device_mapped).type == "cpu":

@noahho noahho merged commit d778401 into main Mar 26, 2025
8 checks passed
@LeoGrin LeoGrin deleted the fix-device-bug branch March 26, 2025 09:13
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

"cpu" in device - TypeError: argument of type 'torch.device' is not iterable
1 participant