Add IndexCache support for GLM5 DSA #45424
Conversation
self.is_nextn = config.is_nextn
if self.is_nextn:
    self.skip_topk = False
    self.next_skip_topk = False
else:
    self.index_topk_freq = config.index_topk_freq
    self.index_topk_pattern = config.index_topk_pattern
    if self.index_topk_pattern is None:
        self.skip_topk = max(layer_idx - 1, 0) % self.index_topk_freq != 0
        self.next_skip_topk = layer_idx % self.index_topk_freq != 0
    else:
        self.skip_topk = self.index_topk_pattern[layer_idx] == "S"
        if layer_idx < len(self.index_topk_pattern) - 1:
            self.next_skip_topk = self.index_topk_pattern[layer_idx + 1] == "S"
        else:
            self.next_skip_topk = False
all of this should never happen here. You should just be doing self.is_nextn = config.is_next_n[layer_idx].
This makes it explicit which layers are skipping topK, and which are not!
Agreed! I'll refactor to use is_next_n: List[bool] in Config (similar to mlp_type) instead of the complex derivation logic in __init__. Much cleaner.
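A minimal sketch of that shape, with placeholder values (hypothetical code, not the PR's implementation):

# Hypothetical sketch of the agreed shape: the config carries one explicit
# flag per layer, and the layer just indexes into it (no derivation logic).
class GlmMoeDsaConfig:
    def __init__(self, num_hidden_layers: int = 4):
        self.num_hidden_layers = num_hidden_layers
        # Explicit per-layer flags, analogous to mlp_type; values here are
        # placeholders (e.g. only a trailing NextN layer set to True).
        self.is_next_n = [False] * (num_hidden_layers - 1) + [True]

class GlmMoeDsaAttention:
    def __init__(self, config: GlmMoeDsaConfig, layer_idx: int):
        self.is_nextn = config.is_next_n[layer_idx]  # plain lookup, nothing else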
index_topk_freq: int = 1
index_topk_pattern: str | None = None
I kept index_topk_freq and index_topk_pattern to align with the IndexCache paper terminology (Shared vs Full patterns). However, as per your first suggestion, these will only be used in the Config to construct the is_next_n list; there won't be any derivation logic in the layer __init__. The layer will simply read config.is_next_n[layer_idx].
yep it's much simpler, explicit and aligned with what we try to have!
if self.next_skip_topk is None:
    return attn_output, attn_weights
else:
    if self.next_skip_topk:
        return attn_output, attn_weights, topk_indices
    else:
        return attn_output, attn_weights, None
let's always return topk maybe? let's simplify our life
My concern is that the original implementation only returned 2 values, so forcing a 3-tuple return might break backward compatibility for existing code that expects (output, weights). However, I agree that a consistent API is cleaner, so I'll refactor to always return 3 values and handle the compatibility aspect properly.
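A standalone sketch of the uniform 3-tuple return being discussed (hypothetical code, assuming a skip_topk flag and a prev_topk_indices argument; not the PR's actual forward):

import torch

def attention_forward(hidden_states: torch.Tensor, skip_topk: bool,
                      prev_topk_indices: torch.Tensor | None = None):
    # Placeholders for the real attention math.
    attn_output, attn_weights = hidden_states, None
    if skip_topk:
        # Shared layer: reuse indices cached by the nearest preceding Full layer.
        topk_indices = prev_topk_indices
    else:
        # Full layer: compute fresh indices (toy top-k over the last dimension).
        topk_indices = hidden_states.topk(k=min(4, hidden_states.shape[-1]), dim=-1).indices
    # Uniform API: always three values; callers that only want
    # (output, weights) can simply ignore the third element.
    return attn_output, attn_weights, topk_indices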
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
ArthurZucker
left a comment
cool that looks simpler!
    self.index_topk_pattern = "F" * self.num_hidden_layers
else:
    self.index_topk_pattern = "".join(
        "F" if (i == 0 or (i - 1) % self.index_topk_freq == 0) else "S"
IDK what F stands for! + typing should reflect it: pattern should be a tuple, no? You can default it to ("skip", "meaningful name", ...)
"F" stands for Full (layers that run the indexer independently to compute top-k indices) and "S" stands for Shared (layers that reuse cached indices from the nearest preceding Full layer). We keep this as a string pattern (e.g., "FFSF...") rather than a tuple to maintain consistency with the IndexCache paper (GLM-5) and the existing implementations in SGLang and vLLM.
Done! Changed from string pattern "FSFS..." to list format ["full", "shared", ...] with explicit naming. The generation logic now uses max(i - 1, 0) % freq to match the official IndexCache implementation exactly, and the type annotation is updated to list[str]. This should be much clearer while maintaining consistency with the reference implementation.
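For concreteness, here is how the max(i - 1, 0) % freq rule plays out (a standalone snippet for illustration, not the PR's actual code):

def build_indexer_types(num_layers: int, freq: int) -> list[str]:
    # Layer 0 and every layer i where (i - 1) is a multiple of `freq` run the
    # indexer themselves ("full"); all other layers reuse cached indices ("shared").
    return ["full" if max(i - 1, 0) % freq == 0 else "shared" for i in range(num_layers)]

print(build_indexer_types(8, 2))
# ['full', 'full', 'shared', 'full', 'shared', 'full', 'shared', 'full']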
Force-pushed from b43c161 to b246219:
Moves index_topk_pattern generation from Attention.__init__ to Config.__post_init__ as suggested. Layers now simply check `config.index_topk_pattern[layer_idx]` instead of computing skip conditions, matching the mlp_layer_types pattern for consistent explicit configuration.
ArthurZucker
left a comment
LGTM thanks for adding this!
if self.index_topk_pattern is None:
    self.index_topk_pattern = [
can we just use something similar to layer_types? instead of freq + pattern we just have a list that we default to the pattern? 🤗
Updated as suggested:
if self.indexer_types is None:
    pattern = kwargs.pop("index_topk_pattern", None)
    freq = kwargs.pop("index_topk_freq", 1)
    if pattern is not None:
        self.indexer_types = [{"F": "full", "S": "shared"}[c] for c in pattern] if isinstance(pattern, str) else list(pattern)
    else:
        self.indexer_types = ["full" if (max(i - 1, 0) % freq) == 0 else "shared" for i in range(self.num_hidden_layers)]

The legacy fallbacks are kept because the official IndexCache repo's patches for vLLM and SGLang currently expose these exact kwargs to end users. For example, in SGLang users launch with:
--json-model-override-args '{"index_topk_freq": 2}'
# or
--json-model-override-args '{"index_topk_pattern": "FFSFSSSFSSFFFSSSFFFSFSSSSSSFFSFFSFFSSFFFFFFSFFFFFSFFSSSSSSFSFFFSFSSSFSFFSFFSSS"}'

And in vLLM:
--hf-overrides '{"index_topk_freq": 2}'
# or
--hf-overrides '{"index_topk_pattern": "FFSF..."}'

The official README documents index_topk_freq and index_topk_pattern as the two configuration parameters for both engines. Removing them outright would break existing deployments that rely on these patches. New usage can pass indexer_types directly; the old args are deprecated and only consulted as fallbacks.
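For illustration, a hypothetical sketch of the three entry points (the GlmMoeDsaConfig constructor calls here are assumed):

# Hypothetical usage sketch; assumes GlmMoeDsaConfig accepts these kwargs.
cfg_new = GlmMoeDsaConfig(indexer_types=["full", "full", "shared", "full"])   # new style
cfg_pattern = GlmMoeDsaConfig(index_topk_pattern="FFSF", num_hidden_layers=4)  # deprecated fallback
cfg_freq = GlmMoeDsaConfig(index_topk_freq=2, num_hidden_layers=4)             # deprecated fallback

All three resolve to the same indexer_types list for a 4-layer model.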
If this looks good, I'll push the commit shortly.
yeah of course! sounds good
@louzongzhi is this ready for merge? let us know 🤗
Please give me a moment. Installing TileLang messed up my environment, so I'm reconfiguring it now. I'll submit the commit shortly.
[For maintainers] Suggested jobs to run (before merge): run-slow: glm_moe_dsa
@vasqu @ArthurZucker All done. indexer_types is in place with backward-compatible fallback for index_topk_pattern/index_topk_freq, and modeling references are updated. Please take a look.
vasqu
left a comment
LGTM, thanks for iterating, merging in a second
What does this PR do?
This PR implements IndexCache support for GLM5's DeepSeek Sparse Attention (DSA), enabling cross-layer index reuse to accelerate long-context inference.
IndexCache accelerates sparse attention by reusing top-k token indices across consecutive layers, removing ~75% of redundant indexer computations while maintaining accuracy.
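As a schematic illustration of that reuse (hypothetical standalone code; the actual modeling code differs), a "full" layer recomputes the top-k indices and every following "shared" layer reuses them:

import torch

def run_layers(hidden_states: torch.Tensor, indexer_types: list[str]):
    # Schematic only: "full" layers recompute and cache top-k indices;
    # "shared" layers keep reusing the last "full" layer's indices.
    prev_topk_indices = None
    for layer_type in indexer_types:
        if layer_type == "full":
            prev_topk_indices = hidden_states.abs().topk(k=2, dim=-1).indices
        # A real "shared" layer would attend only over the tokens selected
        # by prev_topk_indices instead of running the indexer again.
    return hidden_states, prev_topk_indices

out, idx = run_layers(torch.randn(1, 8), ["full", "shared", "shared", "full"])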
Key implementation details:
- index_topk_freq, index_topk_pattern, and is_nextn added to GlmMoeDsaConfig for flexible layer scheduling (Full/Shared pattern)
- skip_topk/next_skip_topk logic in GlmMoeDsaAttention to determine whether to compute new indices or reuse the previous layer's indices
- prev_topk_indices parameter propagation through GlmMoeDsaDecoderLayer and GlmMoeDsaModel for cross-layer index sharing
- Attention returns topk_indices when enabled

Performance impact:
Reference: https://github.com/THUDM/IndexCache
Who can review?
@ArthurZucker @Cyrilvallez @vasqu