Conversation

@Steboss Steboss commented Sep 10, 2025

@matthew-e-hopkins
Hey people, this is a big update that allows us to use JAX > 0.5.3 (we're currently testing AXLearn with JAX 0.7.2).
I've implemented the following changes:

  • I've created a backward-compatibility flag _JAX_MEMORY_SPACE_SUPPORT so that all of these changes work across different JAX versions
  • In utils.py, the import from jax._src.sharding_impls import TransferToMemoryKind has been replaced with its JAX 0.7 equivalent (jax.memory.Space.*). I preserve the previous behavior by checking the JAX version:
if _JAX_MEMORY_SPACE_SUPPORT:
    # JAX >= 0.7: memory kinds are exposed via the jax.memory.Space enum.
    MemoryKind = jax.memory.Space
    DEVICE_MEMORY = jax.memory.Space.Device
    HOST_MEMORY = jax.memory.Space.Host

    def transfer_to_memory_kind(tensor: Tensor, memory_kind: MemoryKind) -> Tensor:
        return jax.device_put(tensor, memory_kind)

else:
    # Older JAX: keep the private API and the string-based memory kinds.
    from jax._src.sharding_impls import TransferToMemoryKind  # pylint: disable=ungrouped-imports

    MemoryKind = Literal["device", "pinned_host"]
    DEVICE_MEMORY = "device"
    HOST_MEMORY = "pinned_host"
  • These changes have been propagated to optimizers_test.py and optimizers.py
  • jax.experimental.pallas.triton.TritonCompilerParams has been renamed to CompilerParams, so gpu_attention.py, gpu_decoding.py, gpu_paged_attention.py, and paged_kv_cache_gpu_kernel.py have been updated accordingly. As before, I import _JAX_MEMORY_SPACE_SUPPORT to check the JAX version and preserve the previous code path (a sketch follows this list).
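
A minimal sketch of that import gate (the aliasing is illustrative on my part, not the exact diff; the flag is the same _JAX_MEMORY_SPACE_SUPPORT defined in utils.py):

if _JAX_MEMORY_SPACE_SUPPORT:
    # JAX >= 0.7: the Pallas Triton params class is named CompilerParams.
    from jax.experimental.pallas.triton import CompilerParams as TritonCompilerParams
else:
    # Older JAX: keep the original name.
    from jax.experimental.pallas.triton import TritonCompilerParams

# Call sites can then stay unchanged, e.g.:
# pl.pallas_call(..., compiler_params=TritonCompilerParams(num_warps=4, num_stages=2))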

I've tested the changes with the Fuji models. It would be great to find an optimal solution for this, as we'd like to support AXLearn in JAX-Toolbox again with newer JAX versions.
Please let me know if you'd like any changes. Thank you!

changlan and others added 6 commits July 28, 2025 15:10
- Rename live_step_len -> unpadded_len across attention and KV cache modules
- Update documentation to clarify that unpadded_len specifies the number of
  non-padding tokens per sequence, with actual behavior depending on KV cache implementation
- Fix pre-existing pylint error in rattention.py where rla_output was used before assignment
- Update all test files to use the new parameter name

The new name better reflects the parameter's purpose: indicating the number of
non-padding tokens in each sequence, rather than the ambiguous "live step length".
Implementation behavior varies by KV cache type:
- Standard KVCache: ignores the parameter
- SlidingWindowKVCache: uses it for sequence masking (see the sketch below)
- PagedKVCache: ignores the parameter

GitOrigin-RevId: 5b0d848
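
For illustration only (not part of the commit message): a minimal sketch of how unpadded_len can drive sequence masking in a sliding-window KV cache; the helper name and shapes are hypothetical:

import jax.numpy as jnp

def unpadded_len_mask(unpadded_len: jnp.ndarray, seq_len: int) -> jnp.ndarray:
    # unpadded_len: [batch] number of non-padding tokens per sequence.
    # Returns a [batch, seq_len] boolean mask, True at non-padding positions.
    positions = jnp.arange(seq_len)[None, :]  # [1, seq_len]
    return positions < unpadded_len[:, None]  # broadcasts to [batch, seq_len]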
@Steboss Steboss requested a review from a team as a code owner September 10, 2025 14:01