Skip to content

improve error message with when custom_vjp bwd rule produces wrong shape/dtype #28757

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 15, 2025

Conversation

mattjj
Copy link
Collaborator

@mattjj mattjj commented May 14, 2025

Even the best of us can miss important error information, like noticing when shapes disagree. So let's make things better.

import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(x):
  return x

def f_fwd(x):
  return x, ()

def f_bwd(_, g):
  return jnp.ones((10,), g.dtype),

f.defvjp(f_fwd, f_bwd)

jax.grad(f)(3.14)

Before:
Traceback (most recent call last):
File "/usr/local/google/home/mattjj/packages/jax/rahul11.py", line 16, in
jax.grad(f)(3.14)
ValueError: Custom VJP bwd rule must produce an output with the same shape/dtypes as the args tuple of the primal function, but at output[0] the bwd rule produced an output of shape/dtype float32[10] corresponding to an input of shape/dtype float32[].

After (emphasis added):
Traceback (most recent call last):
File "/usr/local/google/home/mattjj/packages/jax/rahul11.py", line 16, in
jax.grad(f)(3.14)
ValueError: Custom VJP bwd rule must produce an output with the same shape/dtypes as the args tuple of the primal function, but at output[0] the bwd rule produced an output of shape/dtype float32[10] corresponding to an input of shape/dtype float32[], so the shapes do not match

@mattjj mattjj requested a review from froystig May 14, 2025 21:46
@mattjj mattjj self-assigned this May 14, 2025
@mattjj mattjj added skill issue better_errors Improve the error reporting labels May 14, 2025
@mattjj mattjj force-pushed the custom-vjp-aval-mismatch-extra branch from c7cf7c1 to 5667659 Compare May 15, 2025 16:38
@mattjj mattjj added the pull ready Ready for copybara import and testing label May 15, 2025
@mattjj mattjj force-pushed the custom-vjp-aval-mismatch-extra branch from 5667659 to 0984dc8 Compare May 15, 2025 16:46
@copybara-service copybara-service bot merged commit 0533263 into jax-ml:main May 15, 2025
23 checks passed
@mattjj mattjj deleted the custom-vjp-aval-mismatch-extra branch May 15, 2025 18:06
Copy link
Member

@froystig froystig left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good (hindsight review).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
better_errors Improve the error reporting pull ready Ready for copybara import and testing skill issue
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants