Enabled running Pallas Flash Attention on CPU. #922
Conversation
@ruomingp Could you take a look?
A few thoughts missed in earlier reviews...
```diff
@@ -152,6 +153,8 @@ def test_decode_against_ref(
     kv_head_factor: int,
     window_len: int,
 ):
+    if jax.default_backend() != "gpu" and seq_len > 1024:
```
Nit: can we check it against "cpu" directly instead of `!= "gpu"`?
Yes, done.
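Presumably the guard now checks "cpu" directly, along these lines (a sketch; the skip reason and the elided test body are assumptions):

```python
import jax
import pytest

def test_decode_against_ref(seq_len: int, kv_head_factor: int, window_len: int):
    # Skip long sequences when running on CPU (Pallas interpret mode is slow).
    if jax.default_backend() == "cpu" and seq_len > 1024:
        pytest.skip(reason="Too slow on CPU.")
    ...
```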
```diff
@@ -346,6 +357,9 @@ def test_cudnn_against_triton_ref(
     causal: bool,
     dtype: jnp.dtype,
 ):
+    if jax.default_backend() == "cpu":
```
Likewise, let's avoid assuming that the backend is either gpu or cpu in multiple places.
Suggested change:

```diff
-    if jax.default_backend() == "cpu":
+    if jax.default_backend() != "gpu":
```
I'll leave this code as-is, as you asked:

> Nit: can we check it against "cpu" directly instead of `!= "gpu"`?

In addition, at the beginning of the file, only "gpu" and "cpu" are allowed, so `== "cpu"` is equivalent to `!= "gpu"` in this code:

```python
if jax.default_backend() not in ("gpu", "cpu"):
```
> In addition, at the beginning of the file, only "gpu" and "cpu" are allowed, so `== "cpu"` is equivalent to `!= "gpu"` in this code.

I know you are making this assumption, but such a dependency is fragile: what if we extend the supported backends in the future? In this case, requiring the backend to be "gpu" is both more robust and readable. What's the downside?
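In other words, the guard being advocated would look like this (a sketch based on the suggestion above; the test body is elided):

```python
import jax
import jax.numpy as jnp
import pytest

def test_cudnn_against_triton_ref(causal: bool, dtype: jnp.dtype):
    # Require the backend this test actually needs; the check stays
    # correct even if backends beyond "gpu"/"cpu" are supported later.
    if jax.default_backend() != "gpu":
        pytest.skip(reason="cudnn function needs GPU.")
    ...
```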
```python
if jax.default_backend() == "cpu":
    pytest.skip(reason="cudnn function needs GPU.")
```
And here and elsewhere.
As mentioned above, I'll keep using `jax.default_backend() == "cpu"`.
```diff
-    seq_len=[1024, 32768],
+    seq_len=[1024],
```
Since the sliding window size is 1024, it will be useful to keep a test case with seq_len > 1024. We can enable that case only on TPU if it's too slow on CPU, or use a seq_len such as 2048 for CPU if that's fast enough; see the sketch below.
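One possible shape for this, as a sketch (the test name and the CPU cap of 2048 are assumptions):

```python
import jax
import pytest

# Keep a case with seq_len > window size (1024) on every backend, but
# cap it on CPU, where 32768 would be too slow under Pallas interpret mode.
@pytest.mark.parametrize(
    "seq_len",
    [1024, 2048] if jax.default_backend() == "cpu" else [1024, 32768],
)
def test_sliding_window_mask(seq_len: int):
    ...
```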
```python
softmax_scale=softmax_scale,
block_size=block_size,
interpret=(backend == "cpu"),
```
Given how often we do this across locations, I wonder if we can do the following:

- Make `interpret` default to None (instead of False);
- If it's None, assume `interpret=True` if the backend is "cpu".

WDYT?
Thank you for your suggestion. `interpret=True` applies only to the Pallas kernel, so having an `interpret` variable in the flash layer is not aligned with the appropriate level of abstraction: neither the JAX fallback nor the cudnn code path needs this variable.

Additionally, this line was added so contributors can easily debug the Pallas kernel on the CPU. For instance, changing the `if` statement to `elif backend in ("cpu", "tpu"):` would allow debugging in `layer_test.py`.
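A hypothetical sketch of the dispatch being described (function names and structure are assumptions, not the actual axlearn code):

```python
import jax

def _cudnn_attention(q, k, v, **kwargs):  # hypothetical GPU path
    raise NotImplementedError

def _pallas_attention(q, k, v, *, interpret=False, **kwargs):  # hypothetical
    raise NotImplementedError

def flash_attention(q, k, v, **kwargs):
    backend = jax.default_backend()
    if backend == "gpu":
        return _cudnn_attention(q, k, v, **kwargs)
    # Changing this branch to `elif backend in ("cpu", "tpu"):` lets a CPU
    # run exercise the TPU Pallas kernel in interpret mode for debugging.
    elif backend == "cpu":
        return _pallas_attention(q, k, v, interpret=True, **kwargs)
    raise NotImplementedError(backend)
```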
Thank you for the review. I've responded to all comments. Could you check again?
It seems possible to support `interpret=None` to simplify the code, with the following behavior:

- `interpret=True/False`: enable/disable interpret;
- `interpret=None`: let the implementation choose whether to interpret, depending on whether Pallas is used and running on CPU vs. accelerator.

But I don't want to block this PR, as we can simplify it later.
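A minimal sketch of that behavior, assuming a hypothetical helper (not actual axlearn code):

```python
from typing import Optional

import jax

def resolve_interpret(interpret: Optional[bool], *, uses_pallas: bool) -> bool:
    """Resolves interpret=None to a concrete flag for a Pallas kernel."""
    if interpret is None:
        # Interpret only when a Pallas kernel would run on CPU;
        # accelerators compile the kernel natively.
        return uses_pallas and jax.default_backend() == "cpu"
    return interpret
```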
Thank you for the review!
Force-pushed from 0e39bdd to 3f4a177.
@ruomingp Could you reapprove it? I rebased the commit because it's been pending for a few days and couldn't be merged due to an unrelated test failure.
Enabled running Pallas Flash Attention on CPU.

Pallas supports CPU simulation (`interpret=True`), so we can use the same TPU Pallas kernel on CPU, making code debugging easier.

This change lets the following unittests run on CPU as if they were on TPU, enabling easier testing and debugging:

- `axlearn/common/flash_attention/tpu_attention_test.py`

Similarly, `gpu_attention_test.py` can also be run on CPU as if it were on GPU:

- `axlearn/common/flash_attention/gpu_attention_test.py`

Now CI covers those tests on CPU as well.

On an M3 Max MacBook Pro, test coverage and processing times are as follows:

- `axlearn/common/flash_attention/gpu_attention_test.py`: 3024 passed, 1345 skipped in 200.38s (0:03:20)
- `axlearn/common/flash_attention/tpu_attention_test.py`: 18 passed, 435 skipped in 34.82s
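For context, this is what Pallas CPU simulation looks like in general: a toy kernel (not from this PR) that runs on CPU when `interpret=True` is passed to `pl.pallas_call`.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_kernel(x_ref, y_ref, o_ref):
    # Element-wise add; reads and writes go through Pallas refs.
    o_ref[...] = x_ref[...] + y_ref[...]

x = jnp.arange(8, dtype=jnp.float32)
y = jnp.ones(8, dtype=jnp.float32)

# interpret=True simulates the kernel on CPU, so the same TPU/GPU
# Pallas code can be debugged without an accelerator.
out = pl.pallas_call(
    add_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    interpret=True,
)(x, y)
print(out)
```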