Skip to content

Conversation

gnecula
Copy link
Collaborator

@gnecula gnecula commented Feb 6, 2020

… params.

The goal is to make the Jaxpr language more uniform: all higher-order
primitives carry sub-Jaxprs that are part of the parameters, and they
are all called xxx_jaxpr. As a side-effect, some code is simplified
(e.g., the code that searches for sub-jaxprs).

For now the code assumes that all the call (final-style) primitives
carry exactly one subjaxpr with the parameter name call_jaxpr. These
primitives are still processed differently in the internal code, but
there is no reason any external consumer of a Jaxpr needs to know this.

@gnecula gnecula requested a review from mattjj February 6, 2020 10:21
@gnecula gnecula requested a review from jekbradbury February 6, 2020 10:21
Copy link
Collaborator

@mattjj mattjj left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have one main suggestion, basically to avoid switching behavior on whether there's a special "call_jaxpr" parameter.

… params.

The goal is to make the Jaxpr language more uniform: all higher-order
primitives carry sub-Jaxprs that are part of the parameters, and they
are all called xxx_jaxpr. As a side-effect, some code is simplified
(e.g., the code that searches for sub-jaxprs).

For now the code assumes that all the `call` (final-style) primitives
carry exactly one subjaxpr with the parameter name `call_jaxpr`. These
primitives are still processed differently in the internal code, but
there is no reason any external consumer of a Jaxpr needs to know this.
@gnecula gnecula merged commit fb7e48f into jax-ml:master Feb 11, 2020
@gnecula gnecula deleted the simple_jaxpr2 branch February 11, 2020 09:47
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants