
Should tree_sum use jax.tree.reduce_associative? #1498

@SobhanMP

tree_sum is currently implemented using jax.tree.map and jax.tree.reduce. Shouldn't jax.tree.reduce_associative be used instead? On the one problem I tested, the runtime of the two versions is very close, but compile time is noticeably lower with jax.tree.reduce_associative (18 s vs. 23 s).
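To make the comparison concrete, here is a minimal sketch. It assumes tree_sum has roughly the map-then-reduce shape described above (a per-leaf jnp.sum followed by jax.tree.reduce with operator.add); the actual library source may differ. It does not call jax.tree.reduce_associative, which is only available in newer JAX versions. Instead, a hypothetical hand-written tree_sum_pairwise combines the per-leaf sums as a balanced binary tree, which is the idea behind an associative reduction. The compile_seconds helper and the toy pytree are illustrative assumptions, and timings will vary with the tree and hardware.

```python
import operator
import time

import jax
import jax.numpy as jnp


def tree_sum_linear(tree):
    # Reduce every leaf to a scalar, then fold the scalars left to right:
    # ((0 + s0) + s1) + s2 ..., i.e. a chain of n dependent adds.
    sums = jax.tree.map(jnp.sum, tree)
    return jax.tree.reduce(operator.add, sums, initializer=0)


def tree_sum_pairwise(tree):
    # Hypothetical helper: the same per-leaf sums, combined as a balanced
    # binary tree (s0 + s1) + (s2 + s3) ..., so the add chain has about
    # log2(n) depth instead of n. This is valid because addition is
    # associative (results may differ slightly due to float rounding).
    sums = jax.tree.leaves(jax.tree.map(jnp.sum, tree))
    if not sums:
        return jnp.asarray(0.0)
    while len(sums) > 1:
        paired = [a + b for a, b in zip(sums[0::2], sums[1::2])]
        if len(sums) % 2:  # odd count: carry the last scalar up one level
            paired.append(sums[-1])
        sums = paired
    return sums[0]


def compile_seconds(fn, tree):
    # Hypothetical helper: wall-clock time to trace, lower and compile fn.
    start = time.perf_counter()
    jax.jit(fn).lower(tree).compile()
    return time.perf_counter() - start


# Toy pytree with many leaves of different shapes.
tree = {f"w{i}": jnp.ones((i + 1, 3)) for i in range(100)}

# Time compilation first: later calls in the same process may hit JAX's caches.
t_linear = compile_seconds(tree_sum_linear, tree)
t_pairwise = compile_seconds(tree_sum_pairwise, tree)
print(f"compile: linear {t_linear:.3f}s, pairwise {t_pairwise:.3f}s")
print(jax.jit(tree_sum_linear)(tree), jax.jit(tree_sum_pairwise)(tree))
```

One plausible (unverified) explanation for the compile-time gap is that a left fold hands XLA a chain of n dependent adds, while an associative reduction produces a graph of depth about log2(n). Both versions give the same result up to floating-point rounding.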

Labels: type:support (Further information is requested)

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions