-
Notifications
You must be signed in to change notification settings - Fork 2.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement process_call
for LinearizeTrace
#25481
Conversation
jax/_src/interpreters/ad.py
Outdated
out_tangents = map(tangent_trace.to_jaxpr_tracer, out_tangents) | ||
jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents) | ||
num_residuals = len(consts) | ||
del attrs_tracked # TODO: attrs |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Raise an exception if there are any? (You can also unpack into an empty tuple to check the same, though with no explanation of what the error means.)
jax/_src/lax/lax.py
Outdated
@@ -4153,9 +4153,10 @@ def _broadcast_in_dim_typecheck_rule( | |||
|
|||
def _broadcast_in_dim_transpose_rule(ct, operand, *dyn_shape, | |||
shape, broadcast_dimensions, sharding): | |||
aval = core.get_aval(operand) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Huh, I would've thought operand would be an UndefinedPrimal here. Does get_aval work on that? Why this change?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point. I changed it because operand
was sometimes a literal. I thought the implementation was just missing that case (sometimes we assume things are tracers when they can be literals or tracers). But you're right, something else must be going wrong. I'll investigate.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Figured it out. I was instantiating zeros inside the linear jaxpr rather than outside it. Gotta dedent your instantiate_zeros.
19058ba
to
032799f
Compare
032799f
to
dea51cb
Compare
No description provided.