Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add v6e special meshes #952

Merged
merged 4 commits into from
Jan 27, 2025
Merged

Add v6e special meshes #952

merged 4 commits into from
Jan 27, 2025

Conversation

hanzhi713
Copy link
Member

This special mesh improves MFU of 150b model by 1.5% (or 5% step time) when using non-native mesh shape of 64x4 on v6e.

@hanzhi713 hanzhi713 requested review from ruomingp, markblee and a team as code owners January 26, 2025 08:31
@hanzhi713
Copy link
Member Author

@ruomingp Could you please take a look?

@hanzhi713 hanzhi713 added this pull request to the merge queue Jan 27, 2025
Merged via the queue into apple:main with commit b125f00 Jan 27, 2025
6 checks passed
@hanzhi713 hanzhi713 deleted the custom-mesh branch January 27, 2025 20:02
rahul003 added a commit to rahul003/axlearn that referenced this pull request Mar 4, 2025
commit 336c75d
Author: Mark Lee <markblee@apple.com>
Date:   Mon Mar 3 09:04:07 2025 -0800

    Supports arbitrary uniform partitioning in host-global array conversions. (apple#1029)

    * Allows specifying PartitionSpec to host_to_global_device_array.

    * Generalizes to arbitrary uniform partitioning.

    * Addresses comments and adds mixed shape test.

commit 0881412
Author: Dongseong Hwang <dhwang2@apple.com>
Date:   Sat Mar 1 15:41:38 2025 -0800

    Refactor Mask in Attention (apple#1028)

    Currently, the attention code is **hardcoded** to handle either `causal_mask`
    or an arbitrary `mask_fn`.

    To support **sliding window masks**, we previously used a **hack** by injecting
    the `_sliding_window_size` attribute into functions.

    This refactor **makes the masking logic more flexible** by allowing arbitrary
    `MaskFnAttentionBias`.
    - If downstream requires a **new mask pattern**, they can simply:
      1. Implement a **subclass of `MaskFnAttentionBias`**.
      2. Set `attention.mask` accordingly.

commit f67d3f9
Author: Dongseong Hwang <dhwang2@apple.com>
Date:   Fri Feb 28 08:53:00 2025 -0800

    Flash Attention now explicitly checks whether it is in decoding mode. (apple#1026)

    Currently, Flash Attention infers decoding implicitly based on circumstantial
    evidence. This PR makes the check explicit.

commit f8d2c66
Author: qdavid1 <168590940+qdavid1@users.noreply.github.com>
Date:   Thu Feb 27 15:26:18 2025 -0800

    External KV input for _update_layer_kwargs (apple#1025)

commit a3bf5e2
Author: Hanzhi Zhou <hanzhi_zhou@apple.com>
Date:   Wed Feb 26 17:23:40 2025 -0800

    Minor changes to Checkpointer (apple#1024)

commit 55e1841
Author: Wentao Wu <wentao_wu@apple.com>
Date:   Wed Feb 26 15:45:51 2025 -0800

    Add an option to break ties for top_k_logits when k = 1 (apple#1022)

    * Add an option to support stable top_k = 1.

    * address comments

    * address comments

    * address comments

    * Update axlearn/common/logit_modifiers.py

    Co-authored-by: Mark Lee <markblee@apple.com>

    * Update axlearn/common/logit_modifiers.py

    Co-authored-by: Mark Lee <markblee@apple.com>

    * Update axlearn/common/logit_modifiers.py

    Co-authored-by: Mark Lee <markblee@apple.com>

    * Update axlearn/common/logit_modifiers.py

    Co-authored-by: Mark Lee <markblee@apple.com>

    * address comments

    ---------

    Co-authored-by: Mark Lee <markblee@apple.com>

commit fbca3fc
Author: Meng (Ethan) Li <ethanli@apple.com>
Date:   Wed Feb 26 14:05:25 2025 -0800

    Add priority_class as a launch flag (apple#1020)

commit b26bd74
Author: Meng (Ethan) Li <ethanli@apple.com>
Date:   Wed Feb 26 14:04:47 2025 -0800

    Fix TypeError in calcualte_goodput.py (apple#1023)

commit f8191e1
Author: Dongseong Hwang <dhwang2@apple.com>
Date:   Wed Feb 26 11:03:44 2025 -0800

    Emulate flash attentnion unittests on CPU. (apple#1021)

    utils.py codebase is not well covered by CI because it branches different
    backend.

    This PR introduces new CPU test, utils_test.py.
    This test is expected to run on CPU and is designed to validate GPU/TPU code
    from a CPU environment by fake mesh.
    It allows quick verification in CI and local environments to ensure that code
    changes do not break GPU/TPU Flash Attention.

commit daec8c5
Author: Chang Liu <raytomato@users.noreply.github.com>
Date:   Tue Feb 25 12:38:43 2025 -0800

    Add additional_network and additional_subnetwork config to support multi-nic for v6e (apple#1019)

    Co-authored-by: Chang Liu <changliu@apple.com>

commit ac642ea
Author: Dongseong Hwang <dhwang2@apple.com>
Date:   Tue Feb 25 12:02:57 2025 -0800

    Fix crash in log-mel frontend when waveform samples are integers. (apple#1017)

    After updating JAX, this existing hidden bug started causing CI failures.
    When the sample dtype is int32 (which is valid), `jnp.finfo` returns None,
    even though `jnp.iinfo` is available.
    The previous JAX version seemed to handle this case more forgivingly.

    ```
    ../axlearn/axlearn/audio/frontend_utils.py:297: in linear_to_log_spectrogram
        return jnp.log(jnp.maximum(x, jnp.finfo(x.dtype).tiny))
    ```

commit 7c64b55
Author: Meng (Ethan) Li <ethanli@apple.com>
Date:   Tue Feb 25 10:52:50 2025 -0800

    Add LoadBalancer to GKE replicatedJob (apple#1015)

    Co-authored-by: Liang (SPG) He <lhe27@apple.com>

commit 8e8a41b
Author: Chang Lan <c.lan@apple.com>
Date:   Tue Feb 25 10:26:59 2025 -0800

    Expose jax.lax.scan's unroll option to Repeat layer (apple#1016)

    * Expose jax.lax.scan's unroll option to Repeat layer.

    * Defaults to None to avoid golden config changes

commit 682bce6
Author: Dongseong Hwang <dhwang2@apple.com>
Date:   Tue Feb 25 10:09:41 2025 -0800

    Handle None bias in BiasAndResidual (apple#1018)

commit f053318
Author: Ruoming Pang <ruoming@gmail.com>
Date:   Mon Feb 24 11:12:23 2025 -0500

    Allows a required value in a config_for_{function,class} to be specified via **kwargs in instantiate(). (apple#1013)

    * Allows a required value in a ClassConfigBase to be specified via **kwargs in instantiate().

    * Allows a required value in a FunctionConfigBase to be specified via **kwargs in instantiate().

commit a93cd1b
Author: Luzy <zhiyun@users.noreply.github.com>
Date:   Sat Feb 22 19:55:01 2025 -0500

    fix dtype in frontend pre emphasis (apple#1014)

commit c1fe2e9
Author: Maggie Zhang <jiya.zhang.98@gmail.com>
Date:   Fri Feb 21 19:07:39 2025 -0800

    GoodPut minor fix: only process 0 should start goodput uploader (apple#984)

    * only process 0 will start goodput uploader

    * Add unit test

commit 4b1fbf0
Author: Chang Lan <c.lan@apple.com>
Date:   Fri Feb 21 13:36:45 2025 -0800

    Async context invocation for checkpointing (apple#1012)

    * Async context invocation supoprt for checkpointer

    * Add comment

    * Add comments

commit d4cd158
Author: Ruoming Pang <ruoming@gmail.com>
Date:   Fri Feb 21 14:22:24 2025 -0500

    Allows the kwargs given in `cfg.instantiate(**kwargs)` override field values in `cfg` for FunctionConfigBase and ClassConfigBase. (apple#1011)

    * Allows the kwargs given in `cfg.instantiate(**kwargs)` override field values in `cfg` for FunctionConfigBase and ClassConfigBase.

    This makes it easier for `config_for_function` and `config_for_class` to be used for functions and classes that take args of types not allowed by Config fields, e.g., Tensor.

    * Fixes pytype.

    * Addresses review.

commit 2ae6e66
Author: Meng (Ethan) Li <ethanli@apple.com>
Date:   Fri Feb 21 09:22:35 2025 -0800

    Enable megascale abort on hang or error (apple#1010)

    * Enable megascale_error_reporter_abort on hang and error by default

    * Increase threshold to 10m

commit ce4b2fb
Author: Chunyang Wen <chunyang.wen@gmail.com>
Date:   Fri Feb 21 23:12:38 2025 +0800

    Add GPU monitor (apple#1006)

commit baf8ad7
Author: Dongseong Hwang <dhwang2@apple.com>
Date:   Thu Feb 20 19:35:07 2025 -0800

    Clarify setting sliding_window_size = 8 results in a window size of 9, including itself. (apple#1009)

commit cf41112
Author: Hanzhi Zhou <hanzhi_zhou@apple.com>
Date:   Thu Feb 20 16:29:13 2025 -0800

    Partially reverts "gRPC Checkpointer (apple#1005)" (apple#1008)

    * Revert "gRPC Checkpointer (apple#1005)"

    This reverts commit d27c562.

    * Keep some changes

commit 454bdba
Author: Matthew Hopkins <matthew_e_hopkins@apple.com>
Date:   Thu Feb 20 15:10:51 2025 -0800

    upgrade jax 0.4.38 (apple#1007)

commit d27c562
Author: Hanzhi Zhou <hanzhi_zhou@apple.com>
Date:   Tue Feb 18 18:54:53 2025 -0800

    gRPC Checkpointer (apple#1005)

commit fb90620
Author: Ruoming Pang <ruoming@gmail.com>
Date:   Tue Feb 18 21:08:38 2025 -0500

    Makes file_system.glob support multiple patterns. (apple#1003)

    * Makes file_system.glob support multiple patterns.

    * Makes file_system.glob support multiple patterns.

    * Makes file_system.glob support multiple patterns.

    * Makes file_system.glob support multiple patterns.

commit 334f421
Author: Mark Lee <markblee@apple.com>
Date:   Tue Feb 18 17:03:39 2025 -0800

    Reverts sliding window attention changes. (apple#1004)

    * Revert "Fix flash decoding in GPU. (apple#999)"

    This reverts commit fdadfd8.

    * Revert "Supports TPU context parallel training (apple#981)"

    This reverts commit e151d69.

    * Revert "Implemented sliding window attention to maintain KV cache only for the window size to enable infinite decoding. (apple#995)"

    This reverts commit 67645d0.

    * Retain model/decoder asr changes.

commit 3dacc6b
Author: Chang Lan <c.lan@apple.com>
Date:   Mon Feb 17 19:45:00 2025 -0800

    Refactor aot_compilation for reuse (apple#1000)

commit c44fe18
Author: Ruoming Pang <ruoming@gmail.com>
Date:   Mon Feb 17 21:38:17 2025 -0500

    Makes checkpointer_test.py use file_system. (apple#1001)

commit fdadfd8
Author: Dongseong Hwang <dhwang2@apple.com>
Date:   Mon Feb 17 18:23:21 2025 -0800

    Fix flash decoding in GPU. (apple#999)

    target_positions used to be time_step, but after PR apple#995, it now represents the
    actual target positions with shape [batch, step_len].
    apple#995

    Updating the GPU decoding code to align with this change.

    CI did not cover GPU unit tests.

    TEST=test_extend_step10 of axlearn/common/flash_attention/layer_test.py in GPU

commit 9e64388
Author: Ruoming Pang <ruoming@gmail.com>
Date:   Mon Feb 17 16:40:03 2025 -0500

    Makes axlearn/cloud/ use file_system. (apple#998)

    * Makes bastion.py use file_system. This is a first step towards removing the tf.io.gfile dependency.

    * Adds testing for file_system.readfile.

    * Fixes pytype.

    * Makes axlearn/cloud use file_system instead of gfile.

commit 5fba4ce
Author: Chang Lan <c.lan@apple.com>
Date:   Mon Feb 17 09:44:10 2025 -0800

    AOT compilation support for inference (apple#997)

    * Add optional `devices` init argument to InferenceRunner for passing
      fake devices during AOT compilation.

    * Add more v5e slice types.

commit e151d69
Author: Hanzhi Zhou <hanzhi_zhou@apple.com>
Date:   Sun Feb 16 13:01:15 2025 -0800

    Supports TPU context parallel training (apple#981)

    Fix

    Fix tests

commit 67645d0
Author: Dongseong Hwang <dhwang2@apple.com>
Date:   Sat Feb 15 13:26:51 2025 -0800

    Implemented sliding window attention to maintain KV cache only for the window size to enable infinite decoding. (apple#995)

    * Revert "Transpose kv cache for better decode performance (apple#979)"

    This reverts commit b130416.

    * Update golden configs

    * Implemented sliding window attention to maintain KV cache only for the window size to enable infinite decoding.

    Currently, when using `MultiheadAttention` or `GroupedQueryAttention` for
    sliding window attention, the KV cache is kept for the full sequence length
    (`seq_len`) instead of the window length (`window_len`).

    For example, a model with `window_len=1k` and `seq_len=2M` keeps a KV cache
    for the full 2M tokens. It then biases 1999k invalid KV tokens before
    calculating attention, resulting in a computational complexity of **O(2M²)**
    instead of the desired **O(1k²)**.

    This issue persists even when using flash attention. Flash attention uses the
    KV cache allocated in HBM as its input. While unnecessary blocks are discarded
    during computation, the KV cache still occupies HBM inefficiently for the full
    2M tokens.

    To address this, when `MultiheadAttention` detects a sliding window mask, it
    stores the key-value (KV) cache in a ring buffer inside the input linear layer.
    As a result, downstream projects using `MultiheadAttention` automatically
    benefit from efficient KV cache handling in `init_states` and `extend_step`.

    Additionally, for use cases like local-global attention in LLMs, it is
    recommended to use sliding window masks for even the global attention as well.
    For example, if you want to train an LLM with a context length of 8k, you can
    set the sliding window size to 8k during training. This enables functionally
    infinite decoding during inference. Accuracy wouldn't be good tho.

    Note:
    * query_positions in QKVLinear.forward() was introduced by
      apple#914. Now it returns to the caller.

    This PR actually moves from downstream speech/streaming/sliding_window_attention.py

    * transpose

commit 272a4d2
Author: Chang Lan <c.lan@apple.com>
Date:   Fri Feb 14 10:48:21 2025 -0800

    Add v5e-8 (apple#994)

commit debb46a
Author: Mark Lee <markblee@apple.com>
Date:   Thu Feb 13 22:27:28 2025 -0800

    Decouples jobsets from replicated jobs. (apple#991)

    * Decouples jobsets from replicated jobs.

    * Address comments.

commit 8f2b99d
Author: Maggie Zhang <jiya.zhang.98@gmail.com>
Date:   Thu Feb 13 20:49:48 2025 -0800

    Add Goodput documentation (apple#989)

    * Temporarily change checkpointing to every 5 steps

    * revert local changes

    * Add example command for goodput usage

commit 31e8da0
Author: Alexander Pivovarov <apivovarov@gmail.com>
Date:   Thu Feb 13 15:44:51 2025 -0800

    Fix Missing return statement in base_layer_test.py::ExplicitFanLayer::_compute_fan_axes (apple#987)

commit 7f2dd9e
Author: Apoorv Gupta <apoorvgu@amazon.com>
Date:   Thu Feb 13 14:36:11 2025 -0800

    Flash Attention for Neuron (apple#939)

commit 6ca4f56
Author: Philipp Dufter <philipp.dufter@gmail.com>
Date:   Thu Feb 13 23:09:28 2025 +0100

    pass on log_warning in input_tf_data.skip_on_error (apple#990)

    * make log_warnings customizable in tfds skip error

    * address comments

commit 1a8a0eb
Author: Hanzhi Zhou <hanzhi_zhou@apple.com>
Date:   Thu Feb 13 13:32:34 2025 -0800

    Integrate Orbax's emergency checkpoint. (apple#820)

    * Integrate Orbax emergency checkpoint

    * Address comments

    * comment

    * Address comments

    * Upgrade orbax

    * Improve comments

    * Improve comments

    * Update for new orbax versions

    * Better timer

    * Address comments

    * Add step test

    * Fix

    * Add comment

commit 42fd715
Author: Apoorv Gupta <apoorvgu@amazon.com>
Date:   Thu Feb 13 09:13:30 2025 -0800

    TRN2 Meshes and Configurations (apple#916)

    * TRN2 Meshes and Configurations

    * Add get_recursive and set_recursive to ConfigBase.

    * Use loops inside get/set_recursively

    + address comments

    * Update partition spec

    * Use get_recursively inside set

    * Move trn2 configs to a helper function.

    + Fix modifier tests

    * TRN2 partitionspec supports DP over FSDP and TP

    * Use for loop in get_recursively

    * Update Golden Configs

commit d47d5ce
Author: Haoshuo Huang <haoshuo_huang@apple.com>
Date:   Tue Feb 11 18:13:13 2025 -0800

    Add support to slice dataset based on proportions. (apple#982)

commit ed8f382
Author: Mark Lee <markblee@apple.com>
Date:   Tue Feb 11 13:22:44 2025 -0800

    Allow metrics layers to have state. (apple#978)

    * Allow metrics layers to have state.

    * Move BaseLossMetrics to a new file.

commit b130416
Author: Chang Lan <c.lan@apple.com>
Date:   Tue Feb 11 00:01:28 2025 -0800

    Transpose kv cache for better decode performance (apple#979)

commit 48bf488
Author: Haoshuo Huang <haoshuo_huang@apple.com>
Date:   Mon Feb 10 22:25:18 2025 -0800

    Add support for grain.IterDataset in sampling (apple#980)

commit d4b563c
Author: Alexander Pivovarov <apivovarov@gmail.com>
Date:   Mon Feb 10 15:36:22 2025 -0800

    Replace jnp.ndarray with Tensor from axlearn.common.utils (apple#973)

commit 0666d80
Author: Alexander Pivovarov <apivovarov@gmail.com>
Date:   Mon Feb 10 15:35:23 2025 -0800

    Fix membership checks in tool_use_execution.py (apple#974)

commit 2f4763c
Author: Alexander Pivovarov <apivovarov@gmail.com>
Date:   Mon Feb 10 15:31:59 2025 -0800

    Remove redundant import logging (apple#975)

commit 58dcf33
Author: Hanzhi Zhou <hanzhi_zhou@apple.com>
Date:   Mon Feb 10 13:41:33 2025 -0800

    Enable cudnn dropout (apple#913)

commit ae855ed
Author: Mark Lee <markblee@apple.com>
Date:   Mon Feb 10 12:43:50 2025 -0800

    Ensures that cache_dtype is respected. (apple#977)

commit cfef38b
Author: Daniel Swann <daniel_swann@apple.com>
Date:   Mon Feb 10 10:56:10 2025 -0800

    :sparkles: Add cache for CloudBuild API location queries (apple#967)

commit 8fd9137
Author: Wei Liu <wliu.aiml@apple.com>
Date:   Sun Feb 9 15:33:53 2025 -0800

    Add segment_ids option in DiTAttentionLayer (apple#976)

commit e55a404
Author: Chang Lan <c.lan@apple.com>
Date:   Sun Feb 9 04:38:49 2025 -0800

    Use broadcasting trick for KV update (apple#972)

    * Use vmap and dynamic_update_slice for KV update

    * Broadcasting trick

    * Simplify the impl per @markblee's suggestion

    * comments

commit b955187
Author: Dongseong Hwang <dhwang2@apple.com>
Date:   Fri Feb 7 14:12:48 2025 -0800

    Don't keep initial key/value inputs in the KV cache. (apple#968)

    The current code is weird. It stores the input key/value in the KV cache, but
    this doesn’t make sense in either init_states or prefill:
    * init_states: This is not prefill, so key/value should not be stored in the KV cache.
    * prefill: The extend_step() function overrides this part anyway.

    Thus, this PR removes this unnecessary and confusing logic.
    The logic was introduced in apple#860

commit c3d656d
Author: zhengdong-zhang <zhengdong_zhang@apple.com>
Date:   Fri Feb 7 10:18:42 2025 -0800

    Refactorization. (apple#963)

commit 1c883d8
Author: Zhao Xu <1541199+zetaqubit@users.noreply.github.com>
Date:   Fri Feb 7 10:02:56 2025 -0800

    Support system role when calling the Gemini API. (apple#971)

commit ceab4f4
Author: Haoshuo Huang <haoshuo_huang@apple.com>
Date:   Thu Feb 6 20:41:07 2025 -0800

    Making shared_memory configurable (apple#969)

    * Making shared_memory configurable

    * fix eol space

commit 323faa3
Author: Meng (Ethan) Li <ethanli@apple.com>
Date:   Thu Feb 6 12:11:28 2025 -0800

    Use env id for gcp settings (apple#957)

    * Use env_id to replace zone as gcp_settings key to support multiple env under the same zone

    * fall back to zone

    * address comments

    * Suppport project in the label filter; always get zone from gcp_setting value instead of return it directly

commit 2ec3a02
Author: Chang Lan <c.lan@apple.com>
Date:   Wed Feb 5 22:25:58 2025 -0800

    Fix incorrect number of formatting arguments (apple#966)

commit d131d3b
Author: Nan Du <dunanbupt@gmail.com>
Date:   Mon Feb 3 11:44:00 2025 -0800

    Reduce the verbosity of variable norm summaries (apple#965)

commit c1c6e29
Author: Kelvin Zou <xuan_zou@apple.com>
Date:   Fri Jan 31 22:24:39 2025 -0800

    Sliding window support for GPU flash attention (apple#962)

    * snapshot

    * snapshot

    * snapshot

    * remove unexpected change

    * adding shape commenbt

    * fix pylint

    * snapshot

commit 0936a17
Author: Mark Lee <markblee@apple.com>
Date:   Fri Jan 31 13:59:12 2025 -0800

    Supports loss_weights and live_targets in metrics. (apple#960)

    * Supports loss_weights, live_targets, and module sharing in metrics.

    * Addresses comments.

    * Explicitly test flatten_metrics=True.

commit 7a40f91
Author: Dipannita Shaw <shaw.dipannita@gmail.com>
Date:   Fri Jan 31 11:45:33 2025 -0800

    Add Goodput & Badput recording and monitoring support. (apple#783)

    * Code clean up

    * Add more testing

    * Fix docstrings

    * Remove recorder calls from trainer for now

    * Code cleanup gcp/measurement.py

    Co-authored-by: Ruoming Pang <ruoming@gmail.com>

    * Code cleanup  common/measurement.py

    Co-authored-by: Ruoming Pang <ruoming@gmail.com>

    * Fix pre commit errors

    * Adding more tests

    * Further clean up

    * Fix a test error

    ---------

    Co-authored-by: Ruoming Pang <ruoming@gmail.com>

commit 031a7f3
Author: Mark Lee <markblee@apple.com>
Date:   Thu Jan 30 20:19:12 2025 -0800

    Skipping empty grain batches during unbatch. (apple#961)

    * Skipping empty grain batches during unbatch.

    * Use a loop instead of recursion.

commit 795da33
Author: Hanzhi Zhou <hanzhi_zhou@apple.com>
Date:   Thu Jan 30 07:17:16 2025 -0800

    Optimizer offloading through weight-only offload (apple#867)

    * Optimizer offloading

    * Style fix

    * Type fix

commit b1a1a5a
Author: Haoshuo Huang <haoshuo_huang@apple.com>
Date:   Wed Jan 29 21:44:15 2025 -0800

    Improve gcsfuse io (apple#959)

commit d76ef6f
Author: Hanzhi Zhou <hanzhi_zhou@apple.com>
Date:   Wed Jan 29 15:10:13 2025 -0800

    SplashAttention performance tuning for v6e (apple#958)

    * SplashAttention tuning for v6e

    * Add import to fix pytype errors

commit 2d002e3
Author: Hanzhi Zhou <hanzhi_zhou@apple.com>
Date:   Wed Jan 29 12:07:56 2025 -0800

    Use InputDispatcher for fuji models (apple#956)

    * Use dispatcher

    * Update golden configs

    * Remove logical feed indices

commit fad264b
Author: Mark Lee <markblee@apple.com>
Date:   Tue Jan 28 10:41:54 2025 -0800

    Explicitly pass module outputs to metrics. (apple#953)

    * Explicitly pass module outputs to metrics.

    * Support and add checks for module/state updates.

    * Only flatten summaries.

commit 59508e3
Author: Hanzhi Zhou <hanzhi_zhou@apple.com>
Date:   Tue Jan 28 10:34:52 2025 -0800

    Add v6e PCIe overload workaround flag (apple#955)

commit 028ecfd
Author: Haoshuo Huang <haoshuo_huang@apple.com>
Date:   Mon Jan 27 20:54:28 2025 -0800

    Fix GCSFUSE flags by setting resource limit. (apple#954)

commit 3e2c6dd
Author: Matthew Hopkins <matthew_e_hopkins@apple.com>
Date:   Mon Jan 27 14:56:42 2025 -0800

    update jax to 0.4.37 (apple#948)

    update BlockSpec usage in tpu_attention
    use TYPE_CHECKING for BuildDatasetFn in input_fake
    add todo for BuildDatasetFn

commit b125f00
Author: Hanzhi Zhou <hanzhi_zhou@apple.com>
Date:   Mon Jan 27 11:29:23 2025 -0800

    Add v6e special meshes (apple#952)

    * Add v6e special mesh

    * Add v6e special mesh

    * Fix

    * Fix

commit a854738
Author: Firenze11 <im.lezhi@gmail.com>
Date:   Mon Jan 27 09:17:46 2025 -0800

    Allow external positions to be inputed in RoPE embedding layer (apple#926)

    * Allow external positions to be inputed in RoPE embedding layer

    Use case: In RoPE embedding, position embeddings are applied to Q, K, V values after `i_proj`. Unlike the implementation of current `RoFormerQKVLinear`, in MaskedDiT we need to customize positions to indicate masked versus non-masked positions in the position embedding. When we convert this masked roformer attention module to flash attention, we need its signature to be supported by `MultiheadAttention`.

    * Update attention_test.py

    * Update dit.py

    * Update attention.py

    * Update attention_test.py

    * Update attention.py

    * Update dit.py

    * Update axlearn/common/attention.py

    Co-authored-by: Mark Lee <mmaarrkklleeee@gmail.com>

    * respond to comments.

    Co-authored-by: Ruoming Pang <ruoming@gmail.com>

    * Update attention.py

    * Update attention.py

    * Update attention.py

    ---------

    Co-authored-by: Mark Lee <mmaarrkklleeee@gmail.com>
    Co-authored-by: Ruoming Pang <ruoming@gmail.com>

commit 999401a
Author: qdavid1 <168590940+qdavid1@users.noreply.github.com>
Date:   Mon Jan 27 09:11:17 2025 -0800

    Update LoraFusedQKVLinear (apple#949)

commit 1c22688
Author: Mark Lee <markblee@apple.com>
Date:   Sun Jan 26 04:51:02 2025 -0800

    Workaround module outputs being dropped. (apple#951)

commit 94c81cb
Author: Meng (Ethan) Li <ethanli@apple.com>
Date:   Fri Jan 24 11:01:45 2025 -0800

    Add link to github issue regarding kubernetes-32.0.0 (apple#947)

commit a6e0f4a
Author: Meng (Ethan) Li <ethanli@apple.com>
Date:   Fri Jan 24 08:40:25 2025 -0800

    Pin kubernetes pip version to 31.0.0 to fix client authentication error (apple#946)

commit 076521a
Author: Mark Lee <markblee@apple.com>
Date:   Thu Jan 23 15:11:00 2025 -0800

    Forward input keys to decoder. (apple#944)

commit 30284c8
Author: Hanzhi Zhou <hanzhi_zhou@apple.com>
Date:   Thu Jan 23 10:33:54 2025 -0800

    Legacy flash remat fix (apple#943)

    * Fix the same problem for legacy tpu attn

    * Fix

commit 6a9f980
Author: Mark Lee <markblee@apple.com>
Date:   Thu Jan 23 09:20:46 2025 -0800

    Adds mesh rule for a3-megagpu-8g. (apple#936)

commit ac7a3ed
Author: Dongseong Hwang <dhwang2@apple.com>
Date:   Thu Jan 23 08:15:27 2025 -0800

    Enabled running Pallas Flash Attention on CPU. (apple#922)

    Pallas supports CPU simulation (`interpret=True`), so we can use the same
    TPU Pallas kernel on CPU — making code debugging easier.

    This change lets the following unittests run on CPU as if they were on TPU,
    enabling easier testing and debugging:
    - `axlearn/common/flash_attention/tpu_attention_test.py`

    Similarly, `gpu_attention_test.py` can also be run on CPU as if they were on GPU.
    - `axlearn/common/flash_attention/gpu_attention_test.py`

    Now CI covers those tests on CPU as well.
    In M3 Max MacBook Pro, test coverages and processing time are as follows,
    * axlearn/common/flash_attention/gpu_attention_test.py: 3024 passed, 1345 skipped in 200.38s (0:03:20)
    * axlearn/common/flash_attention/tpu_attention_test.py: 18 passed, 435 skipped in 34.82s

commit 8ea85bd
Author: Hanzhi Zhou <hanzhi_zhou@apple.com>
Date:   Wed Jan 22 09:51:15 2025 -0800

    Some fixes for flash remat (apple#942)

commit 185b1b5
Author: Chang Lan <c.lan@apple.com>
Date:   Tue Jan 21 11:21:08 2025 -0800

    Repeat KV heads in Flash Attention (apple#938)

    * Roll back '_repeat_kv_heads' change in Flash Attention

    Recent PR removed _repeat_kv_heads from Flash Attention for GQA optimization,
    in the hope to reduce HBM usage. However the actual HBM saving would be limited
    in the model-parallel setting, as the heads are already sharded across devices.
    It also introduces some limitation which breaks some of the existing sharding
    configurations.

    For example, let's say num_heads = 8 and num_kv_heads = 4. When we repeat KV heads,
    we can set the model axis as 8 so that each device will have only one Q, K, V head;
    Without repeat_kv_heads, the max value of model axis is 4, and each device will have
    2 Q heads as a result, increasing the actual HBM usage.

    * Repeat kv as necessary for sharding

    * Unit tests

    * Address comments.

commit 4678740
Author: Chang Lan <c.lan@apple.com>
Date:   Mon Jan 20 20:36:44 2025 -0800

    AOT compilation for v6e (apple#937)

commit 357bef6
Author: Mark Lee <markblee@apple.com>
Date:   Mon Jan 20 20:23:39 2025 -0800

    Makes causal lm metrics configurable. (apple#934)

    * Makes causal lm metrics configurable.

    * Address review comments.

    * Make metrics required.

    * Update golden configs.

    * Removes PredictModel.

commit 16ca0c2
Author: Mark Lee <markblee@apple.com>
Date:   Sun Jan 19 14:19:20 2025 -0800

    Supports flexible input partition specs. (apple#933)

    * Supports flexible input partition specs in causal lm.

    * Moves the input partitioning to Input.

    * Adds missing pytest marker.

    * Address review comments.

    * Rebase and update golden configs.

    * Fixes batch axis names and adds a test.

commit 9b75ef1
Author: Mark Lee <markblee@apple.com>
Date:   Sun Jan 19 07:43:19 2025 -0800

    Avoid a top-level import of tokenizers. (apple#935)

commit 9996f34
Author: sychen52 <41452870+sychen52@users.noreply.github.com>
Date:   Sat Jan 18 09:44:04 2025 -0800

    Add llama 3 tokenizer (apple#850)

    * Add llama 3 tokenizer

    add a new version called V3_TIKTOKEN.

    other edits based on suggestions.

    * Handle special tokens like other vocabularies.

    * use encode instead of encode_batch

commit ad14de3
Author: Haoshuo Huang <haoshuo_huang@apple.com>
Date:   Fri Jan 17 14:19:24 2025 -0800

    Add ReadOptions args to _make_autoregressive_inputs (apple#931)

    * Add ReadOptions args to _make_autoregressive_inputs

    * use read_options as args instead

commit 4858070
Author: Sam Stoelinga <sammiestoel@gmail.com>
Date:   Fri Jan 17 13:54:05 2025 -0800

    improve GCS perf: Change resource limit to request (apple#851)

commit b0ee05e
Author: Bailin <bailin.wang28@gmail.com>
Date:   Fri Jan 17 22:53:00 2025 +0800

    Add Mamab2 and its Jamba variant (apple#839)

    * add mamab2

    * merge

    * unify init and prefill

    * adapt final changes

    ---------

    Co-authored-by: bailin_wang <bwang47@apple.com>

commit 1e25e4a
Author: Hanzhi Zhou <hanzhi_zhou@apple.com>
Date:   Thu Jan 16 11:25:24 2025 -0800

    Cache AoT compilation result (apple#927)

    * Cache AoT compilation result

    * Fix comments

    * Fix

    * Fix

    * Fix

    * Fix
rahul003 added a commit to rahul003/axlearn that referenced this pull request Mar 5, 2025
commit 336c75d
Author: Mark Lee <markblee@apple.com>
Date:   Mon Mar 3 09:04:07 2025 -0800

    Supports arbitrary uniform partitioning in host-global array conversions. (apple#1029)

    * Allows specifying PartitionSpec to host_to_global_device_array.

    * Generalizes to arbitrary uniform partitioning.

    * Addresses comments and adds mixed shape test.

commit 0881412
Author: Dongseong Hwang <dhwang2@apple.com>
Date:   Sat Mar 1 15:41:38 2025 -0800

    Refactor Mask in Attention (apple#1028)

    Currently, the attention code is **hardcoded** to handle either `causal_mask`
    or an arbitrary `mask_fn`.

    To support **sliding window masks**, we previously used a **hack** by injecting
    the `_sliding_window_size` attribute into functions.

    This refactor **makes the masking logic more flexible** by allowing arbitrary
    `MaskFnAttentionBias`.
    - If downstream requires a **new mask pattern**, they can simply:
      1. Implement a **subclass of `MaskFnAttentionBias`**.
      2. Set `attention.mask` accordingly.

commit f67d3f9
Author: Dongseong Hwang <dhwang2@apple.com>
Date:   Fri Feb 28 08:53:00 2025 -0800

    Flash Attention now explicitly checks whether it is in decoding mode. (apple#1026)

    Currently, Flash Attention infers decoding implicitly based on circumstantial
    evidence. This PR makes the check explicit.

commit f8d2c66
Author: qdavid1 <168590940+qdavid1@users.noreply.github.com>
Date:   Thu Feb 27 15:26:18 2025 -0800

    External KV input for _update_layer_kwargs (apple#1025)

commit a3bf5e2
Author: Hanzhi Zhou <hanzhi_zhou@apple.com>
Date:   Wed Feb 26 17:23:40 2025 -0800

    Minor changes to Checkpointer (apple#1024)

commit 55e1841
Author: Wentao Wu <wentao_wu@apple.com>
Date:   Wed Feb 26 15:45:51 2025 -0800

    Add an option to break ties for top_k_logits when k = 1 (apple#1022)

    * Add an option to support stable top_k = 1.

    * address comments

    * address comments

    * address comments

    * Update axlearn/common/logit_modifiers.py

    Co-authored-by: Mark Lee <markblee@apple.com>

    * Update axlearn/common/logit_modifiers.py

    Co-authored-by: Mark Lee <markblee@apple.com>

    * Update axlearn/common/logit_modifiers.py

    Co-authored-by: Mark Lee <markblee@apple.com>

    * Update axlearn/common/logit_modifiers.py

    Co-authored-by: Mark Lee <markblee@apple.com>

    * address comments

    ---------

    Co-authored-by: Mark Lee <markblee@apple.com>

commit fbca3fc
Author: Meng (Ethan) Li <ethanli@apple.com>
Date:   Wed Feb 26 14:05:25 2025 -0800

    Add priority_class as a launch flag (apple#1020)

commit b26bd74
Author: Meng (Ethan) Li <ethanli@apple.com>
Date:   Wed Feb 26 14:04:47 2025 -0800

    Fix TypeError in calcualte_goodput.py (apple#1023)

commit f8191e1
Author: Dongseong Hwang <dhwang2@apple.com>
Date:   Wed Feb 26 11:03:44 2025 -0800

    Emulate flash attentnion unittests on CPU. (apple#1021)

    utils.py codebase is not well covered by CI because it branches different
    backend.

    This PR introduces new CPU test, utils_test.py.
    This test is expected to run on CPU and is designed to validate GPU/TPU code
    from a CPU environment by fake mesh.
    It allows quick verification in CI and local environments to ensure that code
    changes do not break GPU/TPU Flash Attention.

commit daec8c5
Author: Chang Liu <raytomato@users.noreply.github.com>
Date:   Tue Feb 25 12:38:43 2025 -0800

    Add additional_network and additional_subnetwork config to support multi-nic for v6e (apple#1019)

    Co-authored-by: Chang Liu <changliu@apple.com>

commit ac642ea
Author: Dongseong Hwang <dhwang2@apple.com>
Date:   Tue Feb 25 12:02:57 2025 -0800

    Fix crash in log-mel frontend when waveform samples are integers. (apple#1017)

    After updating JAX, this existing hidden bug started causing CI failures.
    When the sample dtype is int32 (which is valid), `jnp.finfo` returns None,
    even though `jnp.iinfo` is available.
    The previous JAX version seemed to handle this case more forgivingly.

    ```
    ../axlearn/axlearn/audio/frontend_utils.py:297: in linear_to_log_spectrogram
        return jnp.log(jnp.maximum(x, jnp.finfo(x.dtype).tiny))
    ```

commit 7c64b55
Author: Meng (Ethan) Li <ethanli@apple.com>
Date:   Tue Feb 25 10:52:50 2025 -0800

    Add LoadBalancer to GKE replicatedJob (apple#1015)

    Co-authored-by: Liang (SPG) He <lhe27@apple.com>

commit 8e8a41b
Author: Chang Lan <c.lan@apple.com>
Date:   Tue Feb 25 10:26:59 2025 -0800

    Expose jax.lax.scan's unroll option to Repeat layer (apple#1016)

    * Expose jax.lax.scan's unroll option to Repeat layer.

    * Defaults to None to avoid golden config changes

commit 682bce6
Author: Dongseong Hwang <dhwang2@apple.com>
Date:   Tue Feb 25 10:09:41 2025 -0800

    Handle None bias in BiasAndResidual (apple#1018)

commit f053318
Author: Ruoming Pang <ruoming@gmail.com>
Date:   Mon Feb 24 11:12:23 2025 -0500

    Allows a required value in a config_for_{function,class} to be specified via **kwargs in instantiate(). (apple#1013)

    * Allows a required value in a ClassConfigBase to be specified via **kwargs in instantiate().

    * Allows a required value in a FunctionConfigBase to be specified via **kwargs in instantiate().

commit a93cd1b
Author: Luzy <zhiyun@users.noreply.github.com>
Date:   Sat Feb 22 19:55:01 2025 -0500

    fix dtype in frontend pre emphasis (apple#1014)

commit c1fe2e9
Author: Maggie Zhang <jiya.zhang.98@gmail.com>
Date:   Fri Feb 21 19:07:39 2025 -0800

    GoodPut minor fix: only process 0 should start goodput uploader (apple#984)

    * only process 0 will start goodput uploader

    * Add unit test

commit 4b1fbf0
Author: Chang Lan <c.lan@apple.com>
Date:   Fri Feb 21 13:36:45 2025 -0800

    Async context invocation for checkpointing (apple#1012)

    * Async context invocation supoprt for checkpointer

    * Add comment

    * Add comments

commit d4cd158
Author: Ruoming Pang <ruoming@gmail.com>
Date:   Fri Feb 21 14:22:24 2025 -0500

    Allows the kwargs given in `cfg.instantiate(**kwargs)` override field values in `cfg` for FunctionConfigBase and ClassConfigBase. (apple#1011)

    * Allows the kwargs given in `cfg.instantiate(**kwargs)` override field values in `cfg` for FunctionConfigBase and ClassConfigBase.

    This makes it easier for `config_for_function` and `config_for_class` to be used for functions and classes that take args of types not allowed by Config fields, e.g., Tensor.

    * Fixes pytype.

    * Addresses review.

commit 2ae6e66
Author: Meng (Ethan) Li <ethanli@apple.com>
Date:   Fri Feb 21 09:22:35 2025 -0800

    Enable megascale abort on hang or error (apple#1010)

    * Enable megascale_error_reporter_abort on hang and error by default

    * Increase threshold to 10m

commit ce4b2fb
Author: Chunyang Wen <chunyang.wen@gmail.com>
Date:   Fri Feb 21 23:12:38 2025 +0800

    Add GPU monitor (apple#1006)

commit baf8ad7
Author: Dongseong Hwang <dhwang2@apple.com>
Date:   Thu Feb 20 19:35:07 2025 -0800

    Clarify setting sliding_window_size = 8 results in a window size of 9, including itself. (apple#1009)

commit cf41112
Author: Hanzhi Zhou <hanzhi_zhou@apple.com>
Date:   Thu Feb 20 16:29:13 2025 -0800

    Partially reverts "gRPC Checkpointer (apple#1005)" (apple#1008)

    * Revert "gRPC Checkpointer (apple#1005)"

    This reverts commit d27c562.

    * Keep some changes

commit 454bdba
Author: Matthew Hopkins <matthew_e_hopkins@apple.com>
Date:   Thu Feb 20 15:10:51 2025 -0800

    upgrade jax 0.4.38 (apple#1007)

commit d27c562
Author: Hanzhi Zhou <hanzhi_zhou@apple.com>
Date:   Tue Feb 18 18:54:53 2025 -0800

    gRPC Checkpointer (apple#1005)

commit fb90620
Author: Ruoming Pang <ruoming@gmail.com>
Date:   Tue Feb 18 21:08:38 2025 -0500

    Makes file_system.glob support multiple patterns. (apple#1003)

    * Makes file_system.glob support multiple patterns.

    * Makes file_system.glob support multiple patterns.

    * Makes file_system.glob support multiple patterns.

    * Makes file_system.glob support multiple patterns.

commit 334f421
Author: Mark Lee <markblee@apple.com>
Date:   Tue Feb 18 17:03:39 2025 -0800

    Reverts sliding window attention changes. (apple#1004)

    * Revert "Fix flash decoding in GPU. (apple#999)"

    This reverts commit fdadfd8.

    * Revert "Supports TPU context parallel training (apple#981)"

    This reverts commit e151d69.

    * Revert "Implemented sliding window attention to maintain KV cache only for the window size to enable infinite decoding. (apple#995)"

    This reverts commit 67645d0.

    * Retain model/decoder asr changes.

commit 3dacc6b
Author: Chang Lan <c.lan@apple.com>
Date:   Mon Feb 17 19:45:00 2025 -0800

    Refactor aot_compilation for reuse (apple#1000)

commit c44fe18
Author: Ruoming Pang <ruoming@gmail.com>
Date:   Mon Feb 17 21:38:17 2025 -0500

    Makes checkpointer_test.py use file_system. (apple#1001)

commit fdadfd8
Author: Dongseong Hwang <dhwang2@apple.com>
Date:   Mon Feb 17 18:23:21 2025 -0800

    Fix flash decoding in GPU. (apple#999)

    target_positions used to be time_step, but after PR apple#995, it now represents the
    actual target positions with shape [batch, step_len].
    apple#995

    Updating the GPU decoding code to align with this change.

    CI did not cover GPU unit tests.

    TEST=test_extend_step10 of axlearn/common/flash_attention/layer_test.py in GPU

commit 9e64388
Author: Ruoming Pang <ruoming@gmail.com>
Date:   Mon Feb 17 16:40:03 2025 -0500

    Makes axlearn/cloud/ use file_system. (apple#998)

    * Makes bastion.py use file_system. This is a first step towards removing the tf.io.gfile dependency.

    * Adds testing for file_system.readfile.

    * Fixes pytype.

    * Makes axlearn/cloud use file_system instead of gfile.

commit 5fba4ce
Author: Chang Lan <c.lan@apple.com>
Date:   Mon Feb 17 09:44:10 2025 -0800

    AOT compilation support for inference (apple#997)

    * Add optional `devices` init argument to InferenceRunner for passing
      fake devices during AOT compilation.

    * Add more v5e slice types.

commit e151d69
Author: Hanzhi Zhou <hanzhi_zhou@apple.com>
Date:   Sun Feb 16 13:01:15 2025 -0800

    Supports TPU context parallel training (apple#981)

    Fix

    Fix tests

commit 67645d0
Author: Dongseong Hwang <dhwang2@apple.com>
Date:   Sat Feb 15 13:26:51 2025 -0800

    Implemented sliding window attention to maintain KV cache only for the window size to enable infinite decoding. (apple#995)

    * Revert "Transpose kv cache for better decode performance (apple#979)"

    This reverts commit b130416.

    * Update golden configs

    * Implemented sliding window attention to maintain KV cache only for the window size to enable infinite decoding.

    Currently, when using `MultiheadAttention` or `GroupedQueryAttention` for
    sliding window attention, the KV cache is kept for the full sequence length
    (`seq_len`) instead of the window length (`window_len`).

    For example, a model with `window_len=1k` and `seq_len=2M` keeps a KV cache
    for the full 2M tokens. It then biases 1999k invalid KV tokens before
    calculating attention, resulting in a computational complexity of **O(2M²)**
    instead of the desired **O(1k²)**.

    This issue persists even when using flash attention. Flash attention uses the
    KV cache allocated in HBM as its input. While unnecessary blocks are discarded
    during computation, the KV cache still occupies HBM inefficiently for the full
    2M tokens.

    To address this, when `MultiheadAttention` detects a sliding window mask, it
    stores the key-value (KV) cache in a ring buffer inside the input linear layer.
    As a result, downstream projects using `MultiheadAttention` automatically
    benefit from efficient KV cache handling in `init_states` and `extend_step`.

    Additionally, for use cases like local-global attention in LLMs, it is
    recommended to use sliding window masks for even the global attention as well.
    For example, if you want to train an LLM with a context length of 8k, you can
    set the sliding window size to 8k during training. This enables functionally
    infinite decoding during inference. Accuracy wouldn't be good tho.

    Note:
    * query_positions in QKVLinear.forward() was introduced by
      apple#914. Now it returns to the caller.

    This PR actually moves from downstream speech/streaming/sliding_window_attention.py

    * transpose

commit 272a4d2
Author: Chang Lan <c.lan@apple.com>
Date:   Fri Feb 14 10:48:21 2025 -0800

    Add v5e-8 (apple#994)

commit debb46a
Author: Mark Lee <markblee@apple.com>
Date:   Thu Feb 13 22:27:28 2025 -0800

    Decouples jobsets from replicated jobs. (apple#991)

    * Decouples jobsets from replicated jobs.

    * Address comments.

commit 8f2b99d
Author: Maggie Zhang <jiya.zhang.98@gmail.com>
Date:   Thu Feb 13 20:49:48 2025 -0800

    Add Goodput documentation (apple#989)

    * Temporarily change checkpointing to every 5 steps

    * revert local changes

    * Add example command for goodput usage

commit 31e8da0
Author: Alexander Pivovarov <apivovarov@gmail.com>
Date:   Thu Feb 13 15:44:51 2025 -0800

    Fix Missing return statement in base_layer_test.py::ExplicitFanLayer::_compute_fan_axes (apple#987)

commit 7f2dd9e
Author: Apoorv Gupta <apoorvgu@amazon.com>
Date:   Thu Feb 13 14:36:11 2025 -0800

    Flash Attention for Neuron (apple#939)

commit 6ca4f56
Author: Philipp Dufter <philipp.dufter@gmail.com>
Date:   Thu Feb 13 23:09:28 2025 +0100

    pass on log_warning in input_tf_data.skip_on_error (apple#990)

    * make log_warnings customizable in tfds skip error

    * address comments

commit 1a8a0eb
Author: Hanzhi Zhou <hanzhi_zhou@apple.com>
Date:   Thu Feb 13 13:32:34 2025 -0800

    Integrate Orbax's emergency checkpoint. (apple#820)

    * Integrate Orbax emergency checkpoint

    * Address comments

    * comment

    * Address comments

    * Upgrade orbax

    * Improve comments

    * Improve comments

    * Update for new orbax versions

    * Better timer

    * Address comments

    * Add step test

    * Fix

    * Add comment

commit 42fd715
Author: Apoorv Gupta <apoorvgu@amazon.com>
Date:   Thu Feb 13 09:13:30 2025 -0800

    TRN2 Meshes and Configurations (apple#916)

    * TRN2 Meshes and Configurations

    * Add get_recursive and set_recursive to ConfigBase.

    * Use loops inside get/set_recursively

    + address comments

    * Update partition spec

    * Use get_recursively inside set

    * Move trn2 configs to a helper function.

    + Fix modifier tests

    * TRN2 partitionspec supports DP over FSDP and TP

    * Use for loop in get_recursively

    * Update Golden Configs

commit d47d5ce
Author: Haoshuo Huang <haoshuo_huang@apple.com>
Date:   Tue Feb 11 18:13:13 2025 -0800

    Add support to slice dataset based on proportions. (apple#982)

commit ed8f382
Author: Mark Lee <markblee@apple.com>
Date:   Tue Feb 11 13:22:44 2025 -0800

    Allow metrics layers to have state. (apple#978)

    * Allow metrics layers to have state.

    * Move BaseLossMetrics to a new file.

commit b130416
Author: Chang Lan <c.lan@apple.com>
Date:   Tue Feb 11 00:01:28 2025 -0800

    Transpose kv cache for better decode performance (apple#979)

commit 48bf488
Author: Haoshuo Huang <haoshuo_huang@apple.com>
Date:   Mon Feb 10 22:25:18 2025 -0800

    Add support for grain.IterDataset in sampling (apple#980)

commit d4b563c
Author: Alexander Pivovarov <apivovarov@gmail.com>
Date:   Mon Feb 10 15:36:22 2025 -0800

    Replace jnp.ndarray with Tensor from axlearn.common.utils (apple#973)

commit 0666d80
Author: Alexander Pivovarov <apivovarov@gmail.com>
Date:   Mon Feb 10 15:35:23 2025 -0800

    Fix membership checks in tool_use_execution.py (apple#974)

commit 2f4763c
Author: Alexander Pivovarov <apivovarov@gmail.com>
Date:   Mon Feb 10 15:31:59 2025 -0800

    Remove redundant import logging (apple#975)

commit 58dcf33
Author: Hanzhi Zhou <hanzhi_zhou@apple.com>
Date:   Mon Feb 10 13:41:33 2025 -0800

    Enable cudnn dropout (apple#913)

commit ae855ed
Author: Mark Lee <markblee@apple.com>
Date:   Mon Feb 10 12:43:50 2025 -0800

    Ensures that cache_dtype is respected. (apple#977)

commit cfef38b
Author: Daniel Swann <daniel_swann@apple.com>
Date:   Mon Feb 10 10:56:10 2025 -0800

    :sparkles: Add cache for CloudBuild API location queries (apple#967)

commit 8fd9137
Author: Wei Liu <wliu.aiml@apple.com>
Date:   Sun Feb 9 15:33:53 2025 -0800

    Add segment_ids option in DiTAttentionLayer (apple#976)

commit e55a404
Author: Chang Lan <c.lan@apple.com>
Date:   Sun Feb 9 04:38:49 2025 -0800

    Use broadcasting trick for KV update (apple#972)

    * Use vmap and dynamic_update_slice for KV update

    * Broadcasting trick

    * Simplify the impl per @markblee's suggestion

    * comments

commit b955187
Author: Dongseong Hwang <dhwang2@apple.com>
Date:   Fri Feb 7 14:12:48 2025 -0800

    Don't keep initial key/value inputs in the KV cache. (apple#968)

    The current code is weird. It stores the input key/value in the KV cache, but
    this doesn’t make sense in either init_states or prefill:
    * init_states: This is not prefill, so key/value should not be stored in the KV cache.
    * prefill: The extend_step() function overrides this part anyway.

    Thus, this PR removes this unnecessary and confusing logic.
    The logic was introduced in apple#860

commit c3d656d
Author: zhengdong-zhang <zhengdong_zhang@apple.com>
Date:   Fri Feb 7 10:18:42 2025 -0800

    Refactorization. (apple#963)

commit 1c883d8
Author: Zhao Xu <1541199+zetaqubit@users.noreply.github.com>
Date:   Fri Feb 7 10:02:56 2025 -0800

    Support system role when calling the Gemini API. (apple#971)

commit ceab4f4
Author: Haoshuo Huang <haoshuo_huang@apple.com>
Date:   Thu Feb 6 20:41:07 2025 -0800

    Making shared_memory configurable (apple#969)

    * Making shared_memory configurable

    * fix eol space

commit 323faa3
Author: Meng (Ethan) Li <ethanli@apple.com>
Date:   Thu Feb 6 12:11:28 2025 -0800

    Use env id for gcp settings (apple#957)

    * Use env_id to replace zone as gcp_settings key to support multiple env under the same zone

    * fall back to zone

    * address comments

    * Suppport project in the label filter; always get zone from gcp_setting value instead of return it directly

commit 2ec3a02
Author: Chang Lan <c.lan@apple.com>
Date:   Wed Feb 5 22:25:58 2025 -0800

    Fix incorrect number of formatting arguments (apple#966)

commit d131d3b
Author: Nan Du <dunanbupt@gmail.com>
Date:   Mon Feb 3 11:44:00 2025 -0800

    Reduce the verbosity of variable norm summaries (apple#965)

commit c1c6e29
Author: Kelvin Zou <xuan_zou@apple.com>
Date:   Fri Jan 31 22:24:39 2025 -0800

    Sliding window support for GPU flash attention (apple#962)

    * snapshot

    * snapshot

    * snapshot

    * remove unexpected change

    * adding shape commenbt

    * fix pylint

    * snapshot

commit 0936a17
Author: Mark Lee <markblee@apple.com>
Date:   Fri Jan 31 13:59:12 2025 -0800

    Supports loss_weights and live_targets in metrics. (apple#960)

    * Supports loss_weights, live_targets, and module sharing in metrics.

    * Addresses comments.

    * Explicitly test flatten_metrics=True.

commit 7a40f91
Author: Dipannita Shaw <shaw.dipannita@gmail.com>
Date:   Fri Jan 31 11:45:33 2025 -0800

    Add Goodput & Badput recording and monitoring support. (apple#783)

    * Code clean up

    * Add more testing

    * Fix docstrings

    * Remove recorder calls from trainer for now

    * Code cleanup gcp/measurement.py

    Co-authored-by: Ruoming Pang <ruoming@gmail.com>

    * Code cleanup  common/measurement.py

    Co-authored-by: Ruoming Pang <ruoming@gmail.com>

    * Fix pre commit errors

    * Adding more tests

    * Further clean up

    * Fix a test error

    ---------

    Co-authored-by: Ruoming Pang <ruoming@gmail.com>

commit 031a7f3
Author: Mark Lee <markblee@apple.com>
Date:   Thu Jan 30 20:19:12 2025 -0800

    Skipping empty grain batches during unbatch. (apple#961)

    * Skipping empty grain batches during unbatch.

    * Use a loop instead of recursion.

commit 795da33
Author: Hanzhi Zhou <hanzhi_zhou@apple.com>
Date:   Thu Jan 30 07:17:16 2025 -0800

    Optimizer offloading through weight-only offload (apple#867)

    * Optimizer offloading

    * Style fix

    * Type fix

commit b1a1a5a
Author: Haoshuo Huang <haoshuo_huang@apple.com>
Date:   Wed Jan 29 21:44:15 2025 -0800

    Improve gcsfuse io (apple#959)

commit d76ef6f
Author: Hanzhi Zhou <hanzhi_zhou@apple.com>
Date:   Wed Jan 29 15:10:13 2025 -0800

    SplashAttention performance tuning for v6e (apple#958)

    * SplashAttention tuning for v6e

    * Add import to fix pytype errors

commit 2d002e3
Author: Hanzhi Zhou <hanzhi_zhou@apple.com>
Date:   Wed Jan 29 12:07:56 2025 -0800

    Use InputDispatcher for fuji models (apple#956)

    * Use dispatcher

    * Update golden configs

    * Remove logical feed indices

commit fad264b
Author: Mark Lee <markblee@apple.com>
Date:   Tue Jan 28 10:41:54 2025 -0800

    Explicitly pass module outputs to metrics. (apple#953)

    * Explicitly pass module outputs to metrics.

    * Support and add checks for module/state updates.

    * Only flatten summaries.

commit 59508e3
Author: Hanzhi Zhou <hanzhi_zhou@apple.com>
Date:   Tue Jan 28 10:34:52 2025 -0800

    Add v6e PCIe overload workaround flag (apple#955)

commit 028ecfd
Author: Haoshuo Huang <haoshuo_huang@apple.com>
Date:   Mon Jan 27 20:54:28 2025 -0800

    Fix GCSFUSE flags by setting resource limit. (apple#954)

commit 3e2c6dd
Author: Matthew Hopkins <matthew_e_hopkins@apple.com>
Date:   Mon Jan 27 14:56:42 2025 -0800

    update jax to 0.4.37 (apple#948)

    update BlockSpec usage in tpu_attention
    use TYPE_CHECKING for BuildDatasetFn in input_fake
    add todo for BuildDatasetFn

commit b125f00
Author: Hanzhi Zhou <hanzhi_zhou@apple.com>
Date:   Mon Jan 27 11:29:23 2025 -0800

    Add v6e special meshes (apple#952)

    * Add v6e special mesh

    * Add v6e special mesh

    * Fix

    * Fix

commit a854738
Author: Firenze11 <im.lezhi@gmail.com>
Date:   Mon Jan 27 09:17:46 2025 -0800

    Allow external positions to be inputed in RoPE embedding layer (apple#926)

    * Allow external positions to be inputed in RoPE embedding layer

    Use case: In RoPE embedding, position embeddings are applied to Q, K, V values after `i_proj`. Unlike the implementation of current `RoFormerQKVLinear`, in MaskedDiT we need to customize positions to indicate masked versus non-masked positions in the position embedding. When we convert this masked roformer attention module to flash attention, we need its signature to be supported by `MultiheadAttention`.

    * Update attention_test.py

    * Update dit.py

    * Update attention.py

    * Update attention_test.py

    * Update attention.py

    * Update dit.py

    * Update axlearn/common/attention.py

    Co-authored-by: Mark Lee <mmaarrkklleeee@gmail.com>

    * respond to comments.

    Co-authored-by: Ruoming Pang <ruoming@gmail.com>

    * Update attention.py

    * Update attention.py

    * Update attention.py

    ---------

    Co-authored-by: Mark Lee <mmaarrkklleeee@gmail.com>
    Co-authored-by: Ruoming Pang <ruoming@gmail.com>

commit 999401a
Author: qdavid1 <168590940+qdavid1@users.noreply.github.com>
Date:   Mon Jan 27 09:11:17 2025 -0800

    Update LoraFusedQKVLinear (apple#949)

commit 1c22688
Author: Mark Lee <markblee@apple.com>
Date:   Sun Jan 26 04:51:02 2025 -0800

    Workaround module outputs being dropped. (apple#951)

commit 94c81cb
Author: Meng (Ethan) Li <ethanli@apple.com>
Date:   Fri Jan 24 11:01:45 2025 -0800

    Add link to github issue regarding kubernetes-32.0.0 (apple#947)

commit a6e0f4a
Author: Meng (Ethan) Li <ethanli@apple.com>
Date:   Fri Jan 24 08:40:25 2025 -0800

    Pin kubernetes pip version to 31.0.0 to fix client authentication error (apple#946)

commit 076521a
Author: Mark Lee <markblee@apple.com>
Date:   Thu Jan 23 15:11:00 2025 -0800

    Forward input keys to decoder. (apple#944)

commit 30284c8
Author: Hanzhi Zhou <hanzhi_zhou@apple.com>
Date:   Thu Jan 23 10:33:54 2025 -0800

    Legacy flash remat fix (apple#943)

    * Fix the same problem for legacy tpu attn

    * Fix

commit 6a9f980
Author: Mark Lee <markblee@apple.com>
Date:   Thu Jan 23 09:20:46 2025 -0800

    Adds mesh rule for a3-megagpu-8g. (apple#936)

commit ac7a3ed
Author: Dongseong Hwang <dhwang2@apple.com>
Date:   Thu Jan 23 08:15:27 2025 -0800

    Enabled running Pallas Flash Attention on CPU. (apple#922)

    Pallas supports CPU simulation (`interpret=True`), so we can use the same
    TPU Pallas kernel on CPU — making code debugging easier.

    This change lets the following unittests run on CPU as if they were on TPU,
    enabling easier testing and debugging:
    - `axlearn/common/flash_attention/tpu_attention_test.py`

    Similarly, `gpu_attention_test.py` can also be run on CPU as if they were on GPU.
    - `axlearn/common/flash_attention/gpu_attention_test.py`

    Now CI covers those tests on CPU as well.
    In M3 Max MacBook Pro, test coverages and processing time are as follows,
    * axlearn/common/flash_attention/gpu_attention_test.py: 3024 passed, 1345 skipped in 200.38s (0:03:20)
    * axlearn/common/flash_attention/tpu_attention_test.py: 18 passed, 435 skipped in 34.82s

commit 8ea85bd
Author: Hanzhi Zhou <hanzhi_zhou@apple.com>
Date:   Wed Jan 22 09:51:15 2025 -0800

    Some fixes for flash remat (apple#942)

commit 185b1b5
Author: Chang Lan <c.lan@apple.com>
Date:   Tue Jan 21 11:21:08 2025 -0800

    Repeat KV heads in Flash Attention (apple#938)

    * Roll back '_repeat_kv_heads' change in Flash Attention

    Recent PR removed _repeat_kv_heads from Flash Attention for GQA optimization,
    in the hope to reduce HBM usage. However the actual HBM saving would be limited
    in the model-parallel setting, as the heads are already sharded across devices.
    It also introduces some limitation which breaks some of the existing sharding
    configurations.

    For example, let's say num_heads = 8 and num_kv_heads = 4. When we repeat KV heads,
    we can set the model axis as 8 so that each device will have only one Q, K, V head;
    Without repeat_kv_heads, the max value of model axis is 4, and each device will have
    2 Q heads as a result, increasing the actual HBM usage.

    * Repeat kv as necessary for sharding

    * Unit tests

    * Address comments.

commit 4678740
Author: Chang Lan <c.lan@apple.com>
Date:   Mon Jan 20 20:36:44 2025 -0800

    AOT compilation for v6e (apple#937)

commit 357bef6
Author: Mark Lee <markblee@apple.com>
Date:   Mon Jan 20 20:23:39 2025 -0800

    Makes causal lm metrics configurable. (apple#934)

    * Makes causal lm metrics configurable.

    * Address review comments.

    * Make metrics required.

    * Update golden configs.

    * Removes PredictModel.

commit 16ca0c2
Author: Mark Lee <markblee@apple.com>
Date:   Sun Jan 19 14:19:20 2025 -0800

    Supports flexible input partition specs. (apple#933)

    * Supports flexible input partition specs in causal lm.

    * Moves the input partitioning to Input.

    * Adds missing pytest marker.

    * Address review comments.

    * Rebase and update golden configs.

    * Fixes batch axis names and adds a test.

commit 9b75ef1
Author: Mark Lee <markblee@apple.com>
Date:   Sun Jan 19 07:43:19 2025 -0800

    Avoid a top-level import of tokenizers. (apple#935)

commit 9996f34
Author: sychen52 <41452870+sychen52@users.noreply.github.com>
Date:   Sat Jan 18 09:44:04 2025 -0800

    Add llama 3 tokenizer (apple#850)

    * Add llama 3 tokenizer

    add a new version called V3_TIKTOKEN.

    other edits based on suggestions.

    * Handle special tokens like other vocabularies.

    * use encode instead of encode_batch

commit ad14de3
Author: Haoshuo Huang <haoshuo_huang@apple.com>
Date:   Fri Jan 17 14:19:24 2025 -0800

    Add ReadOptions args to _make_autoregressive_inputs (apple#931)

    * Add ReadOptions args to _make_autoregressive_inputs

    * use read_options as args instead

commit 4858070
Author: Sam Stoelinga <sammiestoel@gmail.com>
Date:   Fri Jan 17 13:54:05 2025 -0800

    improve GCS perf: Change resource limit to request (apple#851)

commit b0ee05e
Author: Bailin <bailin.wang28@gmail.com>
Date:   Fri Jan 17 22:53:00 2025 +0800

    Add Mamab2 and its Jamba variant (apple#839)

    * add mamab2

    * merge

    * unify init and prefill

    * adapt final changes

    ---------

    Co-authored-by: bailin_wang <bwang47@apple.com>

commit 1e25e4a
Author: Hanzhi Zhou <hanzhi_zhou@apple.com>
Date:   Thu Jan 16 11:25:24 2025 -0800

    Cache AoT compilation result (apple#927)

    * Cache AoT compilation result

    * Fix comments

    * Fix

    * Fix

    * Fix

    * Fix
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants