Added Causal Mask Pattern Fusion for LongRoPe Models and Cache Insertion for Phi4-mini-reasoning #2461
base: main
The first hunk adds a new import at the top of the module:

```diff
@@ -7,6 +7,7 @@
 import numpy as np
 import onnx_ir as ir
+import onnxscript.onnx_types as _onnx_types
 import onnxscript.rewriter._fusion_utils as _fusion_utils
 from onnxscript.rewriter import _basics, _ir_utils, pattern
```
The second hunk introduces the new rewrite rule class:

```diff
@@ -354,9 +355,163 @@
     _outputs=3,
 )

+class LongRoPeGQACausalMask(pattern.RewriteRuleClassBase):
```
Review comment: Could you use the docstring to document the pattern and its replacement? For the branches A, B, and C, I would consider giving them descriptive names.

Review comment (Rama): The following is my understanding; if it is correct, maybe they can be renamed appropriately. I believe that A constructs the kv_range, B constructs the query_range, and C constructs the batch_range. Each constructs the corresponding range as a 4D tensor with 1s in the other positions (for constructing a final attention mask of shape [Batch, NumHeads, QueryRange, KVRange] via broadcast). I am a bit puzzled that query_range and kv_range look to be the same here; it might be an artifact of this model usage, I guess.

Author reply: I wasn't sure what the branches referred to, but I'll make changes following what Rama is suggesting.
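The broadcast described above can be illustrated with a small numpy sketch (the shapes and the causal comparison are assumptions drawn from the comment, not code from the PR):

```python
import numpy as np

batch, q_len, kv_len = 2, 3, 3

# Each branch builds its range as a 4D tensor with 1s on the other axes:
kv_range = np.arange(kv_len).reshape(1, 1, 1, kv_len)    # branch A
query_range = np.arange(q_len).reshape(1, 1, q_len, 1)   # branch B
batch_range = np.arange(batch).reshape(batch, 1, 1, 1)   # branch C

# Comparing the query and kv ranges broadcasts to [1, 1, QueryRange, KVRange];
# broadcasting against batch_range then yields the full 4D mask shape.
causal = kv_range <= query_range  # placeholder causal comparison
mask, _ = np.broadcast_arrays(causal, batch_range)
print(mask.shape)  # (2, 1, 3, 3) -> [Batch, 1, QueryRange, KVRange]
```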
The class begins by initializing a per-instance cache:

```python
def __init__(self):
    super().__init__("LongRoPeGQACausalMask", remove_nodes=False)
    self._mask_cache = {}
```
justinchuby marked this conversation as resolved.

Review comment: The Copilot review is reasonable: the rewrite rule class should be stateless. Is there a different way to do this other than keeping a cache on the instance?

Review comment: I think use of state for this purpose is okay. It has been used before for a similar purpose, which is to introduce values that are reused across multiple rewrites. (Now that we have CSE, there is an alternative path, which is to create duplicate copies and then eliminate them via CSE, but I am not sure it is worth the bother.)

Review comment: BTW, my GQA fusion doesn't use state, and produces multiple copies (as described above).

Review comment: My concern is that the state will transfer from one model to another if we are not careful, which is probably not a good idea. Maybe we can have a class-managed state dict that will be cleared by the class?
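A minimal sketch of that class-managed, self-clearing state (the `cleanup` hook name is an assumption; RewriteRuleClassBase may expose a different extension point for this):

```python
from onnxscript.rewriter import pattern


class LongRoPeGQACausalMask(pattern.RewriteRuleClassBase):
    def __init__(self):
        super().__init__("LongRoPeGQACausalMask", remove_nodes=False)
        self._mask_cache = {}

    def cleanup(self):
        # Assumed hook: clear per-model state after a rewrite session so
        # cached values never leak from one model into the next.
        self._mask_cache.clear()
```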
The PR then adds a `_get_mask_key` helper:

```python
def _get_mask_key(self, attention_mask):
    """
    Generate a unique key for the mask based on input_ids and past_kv_cache.
    This is used to cache the mask to avoid recomputation.
    """
    return (id(attention_mask))
```

A suggested change replaces the identity-based key with a content-based one:

```diff
-        Generate a unique key for the mask based on input_ids and past_kv_cache.
-        This is used to cache the mask to avoid recomputation.
-        """
-        return (id(attention_mask))
+        Generate a unique key for the mask based on the content of attention_mask.
+        This is used to cache the mask to avoid recomputation.
+        """
+        if isinstance(attention_mask, np.ndarray):
+            return hash(attention_mask.tobytes())
+        elif isinstance(attention_mask, (list, tuple)):
+            return hash(tuple(attention_mask))
+        else:
+            raise TypeError("Unsupported type for attention_mask: {}".format(type(attention_mask)))
```
Review comment: If a cache is used, it should be cleaned up (as in the linked example) so that it is not carried over from one graph/model to another.

Review comment: And I am not sure we need to handle np arrays. If the key is either one or two ir.Values, that should be fine; ir.Values can be used as keys in dictionaries directly, and that should avoid the garbage-collection problem.

Review comment: I agree _get_mask_key seems unnecessary. We can use the Value objects directly as keys.
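A sketch of that simplification, assuming the matched attention_mask is an ir.Value (`_build_mask` is a hypothetical helper standing in for the PR's mask construction):

```python
def rewrite(self, op, attention_mask, **_):
    # ir.Value objects are hashable, so the matched value can key the
    # cache directly; no _get_mask_key indirection is needed.
    mask = self._mask_cache.get(attention_mask)
    if mask is None:
        mask = self._build_mask(op, attention_mask)  # hypothetical helper
        self._mask_cache[attention_mask] = mask
    return mask
```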
Review comment: The rewriter doesn't use onnxscript types (yet). Could you instead use a comment to document the shape of the attention_mask?
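For example, the annotation could become a plain comment (the shape shown is illustrative, not taken from the PR):

```python
def pattern(self, op, attention_mask):
    # attention_mask: integer tensor of shape [batch_size, total_sequence_length]
    ...
```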
Review comment: Prefer snake_case for variable names when possible.

```diff
-        seq_len_0D = op.Squeeze(seq_len, _outputs=["seq_len_0D"])
+        seq_len_0d = op.Squeeze(seq_len, _outputs=["seq_len_0d"])
```
Copilot review (Jul 24, 2025): The magic number 262144 should be defined as a named constant to improve code readability and maintainability. Consider defining it as a class constant with a descriptive name.

```diff
-        mask_expanded_A_sub = op.Sub(mask_expanded_A, 262144, _outputs=["mask_expanded_A_sub"])
+        mask_expanded_A_sub = op.Sub(mask_expanded_A, MASK_OFFSET, _outputs=["mask_expanded_A_sub"])
```

Review comment: Better to make it a pattern variable, I think. If I understand right, this is actually a magic sequence-length constant, perhaps model-specific?

Reply: On second thought, I am guessing this is the window_size, which should become an attribute parameter of the GQA op.
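A sketch of the pattern-variable approach: making window_size a parameter of the pattern function lets the matcher bind whatever constant the model carries instead of hard-coding 262144 (the method signatures and the `_ir_utils.get_singleton_value` usage are assumptions about the rewriter API):

```python
def pattern(self, op, mask_expanded_A, window_size):
    # window_size is a pattern variable: it matches the model's constant
    # (262144 for this model) rather than a hard-coded literal.
    return op.Sub(mask_expanded_A, window_size, _outputs=["mask_expanded_A_sub"])

def check(self, context, window_size, **_) -> bool:
    # The bound constant can be read here and later forwarded, e.g. as a
    # window_size attribute on the fused GQA op (a speculative design).
    return _ir_utils.get_singleton_value(window_size) is not None
```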
Review comment: I would document the branches in plain English for readers.
Copilot review (Jul 24, 2025): This commented-out code should be removed if it's not needed, or properly implemented if it serves a purpose. Dead code reduces maintainability.

```diff
-        #mask_where = op.Where(mask_sliced, pattern.ANY_VALUE, pattern.ANY_VALUE, _outputs=["mask_where"])
```
Copilot review (Jul 24, 2025): The gqa_rules variable is being reassigned, which overwrites the previous assignment on line 514. This means the first assignment, gqa_rules = pattern.RewriteRuleSet([_basic_gqa_rule]), is completely ignored.

```diff
-gqa_rules = pattern.RewriteRuleSet([_basic_gqa_rule])
```
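If both rules are meant to run, one way to resolve the overwrite is a single rule set containing both (the name of the new rule object is an assumption):

```python
# Combine both rules in one set instead of assigning gqa_rules twice.
# _longrope_gqa_causal_mask_rule is an assumed name for the new rule.
gqa_rules = pattern.RewriteRuleSet([
    _basic_gqa_rule,
    _longrope_gqa_causal_mask_rule,
])
```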
Review comment: _onnx_types is incompatible with the rewriter (yet).