Implement process_call for LinearizeTrace #25481

Merged: 1 commit, Dec 14, 2024

Collaborator

@dougalm dougalm commented Dec 13, 2024

No description provided.

@dougalm dougalm requested a review from mattjj December 13, 2024 21:24
out_tangents = map(tangent_trace.to_jaxpr_tracer, out_tangents)
jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents)
num_residuals = len(consts)
del attrs_tracked # TODO: attrs
Collaborator

Raise an exception if there are any? (You can also unpack into an empty tuple to check the same, though with no explanation of what the error means.)
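
As a rough sketch of the two options mentioned in this comment (not the code that was merged): the helper names below are hypothetical, and `attrs_tracked` is assumed to be an ordinary sequence that should be empty.

```python
def check_with_exception(attrs_tracked):
    """Fail with an explanatory message if any attrs were tracked."""
    if attrs_tracked:
        raise NotImplementedError(
            f"attrs are not supported here yet, but got {len(attrs_tracked)}: "
            f"{attrs_tracked!r}")


def check_by_unpacking(attrs_tracked):
    """Same check via unpacking into an empty tuple (Python 3.6+)."""
    # Raises ValueError if attrs_tracked is non-empty; the message only
    # mentions unpacking, not what went wrong at the call site.
    () = attrs_tracked
```

Both variants reject a non-empty `attrs_tracked`; the explicit exception costs a few more lines but tells the reader why the failure happened.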

@@ -4153,9 +4153,10 @@ def _broadcast_in_dim_typecheck_rule(

def _broadcast_in_dim_transpose_rule(ct, operand, *dyn_shape,
                                     shape, broadcast_dimensions, sharding):
  aval = core.get_aval(operand)
Collaborator

Huh, I would've thought operand would be an UndefinedPrimal here. Does get_aval work on that? Why this change?

Collaborator Author

Good point. I changed it because operand was sometimes a literal. I thought the implementation was just missing that case (sometimes we assume things are tracers when they can be literals or tracers). But you're right, something else must be going wrong. I'll investigate.

Collaborator Author

Figured it out. I was instantiating zeros inside the linear jaxpr rather than outside it. Gotta dedent your instantiate_zeros.

@google-ml-butler google-ml-butler bot added the kokoro:force-run and pull ready (Ready for copybara import and testing) labels Dec 13, 2024
@dougalm dougalm force-pushed the custom-linearize-process-call branch 2 times, most recently from 19058ba to 032799f, December 13, 2024 22:10
@dougalm dougalm force-pushed the custom-linearize-process-call branch from 032799f to dea51cb, December 13, 2024 22:12
@copybara-service copybara-service bot merged commit f4e5f14 into main Dec 14, 2024
20 checks passed
3 participants