
switch tests using check vjp to use PyTorch autograd #2134


Draft · wants to merge 1 commit into main

Conversation

@t-vi (Collaborator) commented May 25, 2025

Before submitting
  • n/a Was this discussed/approved via a GitHub issue? (not needed for typos and docs improvements)
  • Did you read the contributor guideline, Pull Request section?
  • n/a Did you make sure to update the docs?
  • n/a Did you write any new necessary tests?

What does this PR do?

This PR switches one test to use the PyTorch autograd integration instead of circumventing it.
One obstacle (aside from it not working yet) is time: we might want to look into further reducing the size of the numerically differentiated tensors.
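
To make the direction concrete, here is a rough sketch of what such a check could look like (this is not the actual test change; `thunder.jit` and `torch.sin` stand in for the real executor/op parametrization in thunder/tests/test_grad.py): gradients flow through thunder's PyTorch autograd integration and are compared against numerical derivatives, e.g. via `torch.autograd.gradcheck`.

```python
import torch
import thunder

def f(x):
    return torch.sin(x)

# Compile through thunder so gradients go through its PyTorch autograd integration.
jf = thunder.jit(f)

# float64 keeps the finite-difference comparison numerically stable; small
# tensors keep the (expensive) numerical Jacobian affordable.
x = torch.randn(3, 3, dtype=torch.float64, requires_grad=True)

# gradcheck compares autograd gradients against numerically estimated ones.
torch.autograd.gradcheck(jf, (x,))
```

Since a numerical check evaluates the function a couple of times per input element, keeping the sample tensors small is what drives the time concern above.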

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues there's a high chance it will not be merged.

Did you have fun?

Make sure you had fun coding 🙃

@t-vi (Collaborator, Author) commented May 25, 2025

Errors:

FAILED thunder/tests/test_grad.py::test_vjp_correctness_embedding_manual_embedding_torch_cuda_thunder.dtypes.float64 - ValueError: tensor([1, 1], device='cuda:0') had an unexpected type <class 'torch.Tensor'>. Supported types are <class 'thunder.core.proxies.TensorProxy'>
FAILED thunder/tests/test_grad.py::test_vjp_correctness_sdpa_manual_grad_forward_scaled_dot_product_attention_nvfuser_cuda_thunder.dtypes.bfloat16 - ValueError: attn_mask.dtype=torch.bool is expected to be of the boolean or a floating type
FAILED thunder/tests/test_grad.py::test_vjp_correctness_sdpa_manual_grad_forward_scaled_dot_product_attention_torch_cuda_thunder.dtypes.bfloat16 - ValueError: attn_mask.dtype=torch.bool is expected to be of the boolean or a floating type
FAILED thunder/tests/test_grad.py::test_vjp_correctness_embedding_manual_embedding_nvfuser_cuda_thunder.dtypes.float64 - ValueError: tensor([4, 2], device='cuda:0') had an unexpected type <class 'torch.Tensor'>. Supported types are <class 'thunder.core.proxies.TensorProxy'>
FAILED thunder/tests/test_grad.py::test_vjp_correctness_sdpa_manual_grad_forward_scaled_dot_product_attention_torch_cuda_thunder.dtypes.float16 - ValueError: attn_mask.dtype=torch.bool is expected to be of the boolean or a floating type
FAILED thunder/tests/test_grad.py::test_vjp_correctness_sdpa_manual_grad_forward_scaled_dot_product_attention_nvfuser_cuda_thunder.dtypes.float16 - ValueError: attn_mask.dtype=torch.bool is expected to be of the boolean or a floating type

And way too slow (also, getitem is xfail...); see the sketch after the timings for why the numerical checks scale this badly:

215.61s call     thunder/tests/test_grad.py::test_vjp_correctness_getitem_torch_cpu_thunder.dtypes.float64
164.32s call     thunder/tests/test_grad.py::test_vjp_correctness_scatter_torch_cpu_thunder.dtypes.float64
151.37s call     thunder/tests/test_grad.py::test_vjp_correctness_repeat_torch_cpu_thunder.dtypes.float64
130.76s call     thunder/tests/test_grad.py::test_vjp_correctness_topk_torch_cpu_thunder.dtypes.float64
129.41s call     thunder/tests/test_grad.py::test_vjp_correctness_softplus_torch_cpu_thunder.dtypes.float64
116.76s call     thunder/tests/test_grad.py::test_vjp_correctness_sort_torch_cpu_thunder.dtypes.float64
111.70s call     thunder/tests/test_grad.py::test_vjp_correctness_softshrink_torch_cpu_thunder.dtypes.float64
110.24s call     thunder/tests/test_grad.py::test_vjp_correctness_celu_torch_cpu_thunder.dtypes.float64
107.17s call     thunder/tests/test_grad.py::test_vjp_correctness_reshape_torch_cpu_thunder.dtypes.float64
106.61s call     thunder/tests/test_grad.py::test_vjp_correctness_elu_torch_cpu_thunder.dtypes.float64
106.28s call     thunder/tests/test_grad.py::test_vjp_correctness_gelu_torch_cpu_thunder.dtypes.float64
102.32s call     thunder/tests/test_grad.py::test_vjp_correctness_hardshrink_torch_cpu_thunder.dtypes.float64
101.89s call     thunder/tests/test_grad.py::test_vjp_correctness_scaled_dot_product_attention_torch_cpu_thunder.dtypes.float64
89.05s call     thunder/tests/test_grad.py::test_vjp_correctness_interpolate_torch_cpu_thunder.dtypes.float64
87.42s call     thunder/tests/test_grad.py::test_vjp_correctness_rrelu_torch_cpu_thunder.dtypes.float64
84.84s call     thunder/tests/test_grad.py::test_vjp_correctness_pad_torch_cpu_thunder.dtypes.float64
84.03s call     thunder/tests/test_grad.py::test_vjp_correctness_polygamma_torch_cpu_thunder.dtypes.float64
83.88s call     thunder/tests/test_grad.py::test_vjp_correctness_var_mean_torch_cpu_thunder.dtypes.float64
82.87s call     thunder/tests/test_grad.py::test_vjp_correctness_var_torch_cpu_thunder.dtypes.float64
81.75s call     thunder/tests/test_grad.py::test_vjp_correctness_take_torch_cpu_thunder.dtypes.float64
73.97s call     thunder/tests/test_grad.py::test_vjp_correctness_index_copy_torch_cpu_thunder.dtypes.float64
67.17s call     thunder/tests/test_grad.py::test_vjp_correctness_expand_torch_cpu_thunder.dtypes.float64
66.87s call     thunder/tests/test_grad.py::test_vjp_correctness_torch_pad_torch_cpu_thunder.dtypes.float64
64.46s call     thunder/tests/test_grad.py::test_vjp_correctness_scatter_add_torch_cpu_thunder.dtypes.float64
62.40s call     thunder/tests/test_grad.py::test_vjp_correctness_leaky_relu_torch_cpu_thunder.dtypes.float64
58.74s call     thunder/tests/test_grad.py::test_vjp_correctness_normalize_torch_cpu_thunder.dtypes.float64
58.44s call     thunder/tests/test_grad.py::test_vjp_correctness_mse_loss_torch_cpu_thunder.dtypes.float64
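
As a back-of-the-envelope illustration of why these checks scale so badly (a minimal sketch of a central-difference VJP estimate, not thunder's actual check_vjp implementation): the numerical side has to perturb every input element, so it pays roughly 2 * x.numel() forward passes per input tensor.

```python
import torch

def numerical_vjp(f, x, v, eps=1e-6):
    # Central differences: 2 * x.numel() forward evaluations of f.
    flat = x.detach().clone().reshape(-1)
    grad = torch.zeros_like(flat)
    for i in range(flat.numel()):
        e = torch.zeros_like(flat)
        e[i] = eps
        plus = f((flat + e).reshape(x.shape))
        minus = f((flat - e).reshape(x.shape))
        grad[i] = torch.sum(v * (plus - minus)) / (2 * eps)
    return grad.reshape(x.shape)

x = torch.randn(8, 8, dtype=torch.float64)
v = torch.randn(8, 8, dtype=torch.float64)

# For elementwise sin the exact VJP is v * cos(x); the estimate matches it.
torch.testing.assert_close(numerical_vjp(torch.sin, x, v), v * torch.cos(x))
```

Shrinking the sample tensors therefore cuts the runtime roughly linearly, which is the "reducing the size of numerically differentiated tensors" point in the description.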

@mruberry (Collaborator)

We should think about whether these tests would be different from

def test_phantom_grad_vs_torch_consistency(op, device: str, dtype: dtypes.dtype, executor, comp):

which, despite its name, I think now just tests grad consistency with PyTorch ops?
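
For reference, a grad-consistency check has roughly this shape (a sketch, not the actual test body; `thunder.jit` and `gelu` are stand-ins): gradients from the thunder-compiled function are compared directly against eager PyTorch, with no numerical differentiation involved.

```python
import torch
import thunder

def op(x):
    return torch.nn.functional.gelu(x)

x = torch.randn(4, 4, dtype=torch.float64, requires_grad=True)

# Gradient through the thunder-compiled function ...
(grad_thunder,) = torch.autograd.grad(thunder.jit(op)(x).sum(), x)
# ... versus the eager PyTorch reference.
(grad_torch,) = torch.autograd.grad(op(x).sum(), x)

torch.testing.assert_close(grad_thunder, grad_torch)
```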
