-
Couldn't load subscription status.
- Fork 87
Added LongRoPe Model Causal Mask Pattern Fusion #2473
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
…xscript into longrope_causal_mask
| """ | ||
| Pattern for LongRoPe GQA Causal Mask. | ||
| This pattern computes the causal mask for Group Query Attention with LongRoPe. | ||
| It constructs the mask based on input_ids and past_kv_cache, and handles the |
Check notice
Code scanning / CodeQL
Unused local variable Note
| """ | ||
| Pattern for LongRoPe GQA Causal Mask. | ||
| This pattern computes the causal mask for Group Query Attention with LongRoPe. | ||
| It constructs the mask based on input_ids and past_kv_cache, and handles the |
Check notice
Code scanning / CodeQL
Unused local variable Note
| mask_key = _get_mask_key(attention_mask) | ||
|
|
||
| if mask_key in self._mask_cache: | ||
| total_seq_length_int32, seqlens_k_int32 = self._mask_cache[mask_key] |
Check notice
Code scanning / CodeQL
Unused local variable Note
| mask_key = _get_mask_key(attention_mask) | ||
|
|
||
| if mask_key in self._mask_cache: | ||
| total_seq_length_int32, seqlens_k_int32 = self._mask_cache[mask_key] |
Check notice
Code scanning / CodeQL
Unused local variable Note
| # Licensed under the MIT License. See License.txt in the project root for | ||
| # license information. | ||
| # -------------------------------------------------------------------------- | ||
| import onnx |
Check notice
Code scanning / CodeQL
Unused import Note
| # -------------------------------------------------------------------------- | ||
| import onnx | ||
| from onnxscript import ir | ||
| import onnx.helper |
Check notice
Code scanning / CodeQL
Unused import Note
| cache_length = self.rotemb_attrs["cache_length"] | ||
| position_ids = torch.arange(cache_length, dtype=torch.int64).unsqueeze(0) # Shape: (1, cache_length) | ||
|
|
||
| inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) # (1, dim//2, 1) |
Check failure
Code scanning / CodeQL
Potentially uninitialized local variable Error
| with torch.autocast(device_type=device_type, enabled=False): | ||
| freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) # (1, cache_length, dim//2) | ||
| emb = torch.cat((freqs, freqs), dim=-1) # (1, cache_length, dim) | ||
| cos_cache = emb.cos() * attention_factor # (1, cache_length, dim) |
Check failure
Code scanning / CodeQL
Potentially uninitialized local variable Error
| attention_factor = self.rotemb_attrs["multi_cache"]["short_mscale"] | ||
|
|
||
| inv_freq_shape = torch.arange(0, dim, 2, dtype=torch.int64, device="cpu").float() / dim | ||
| inv_freq = 1.0 / (ext_factors * base**inv_freq_shape) |
Check failure
Code scanning / CodeQL
Potentially uninitialized local variable Error
| if "rescale_inv_freq" in self.rotemb_attrs: | ||
| inv_freq = self.make_inv_freq_rescaled(inv_freq) | ||
|
|
||
| return inv_freq, attention_factor |
Check failure
Code scanning / CodeQL
Potentially uninitialized local variable Error
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lintrunner found more than 20 potential problems in the proposed changes. Check the Files changed tab for more details.
…t#2465) Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
Provide a way to indicate that a pattern-variable can match successfully against a None-valued input. Cleanup current handling which was inconsistent in one place. Add test cases. --------- Signed-off-by: Ganesan Ramalingam <grama@microsoft.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This PR adds comprehensive documentation for the rewriter pattern
options that were previously undocumented. The rewriter pattern system
supports four key options for controlling pattern matching and
replacement behavior:
## New Documentation Added
### `_allow_other_inputs` option
- **File**: `docs/tutorial/rewriter/allow_other_inputs.md`
- **Purpose**: Controls whether patterns can match nodes with additional
inputs beyond those specified
- **Default**: `False` (exact input matching)
- **Example**: Matching `Conv` operations that may have optional bias
inputs
```python
def conv_pattern(op, input, weight):
# Matches Conv with 2 or 3 inputs (weight + optional bias)
return op.Conv(input, weight, _allow_other_inputs=True)
```
### `_domain` option
- **File**: `docs/tutorial/rewriter/domain_option.md`
- **Purpose**: Specifies operator domains for pattern matching and
replacement
- **Use cases**: Domain-specific rewrites, migrating between operator
domains
- **Example**: Targeting operations from specific domains like
"com.microsoft"
```python
def custom_relu_pattern(op, input):
# Only matches Relu from custom domain
return op.Relu(input, _domain="custom.domain")
```
### `_outputs` option
- **File**: `docs/tutorial/rewriter/outputs_option.md`
- **Purpose**: Specifies number and names of operation outputs
- **Formats**: Integer count (`_outputs=2`) or named list
(`_outputs=["first", "second"]`)
- **Example**: Handling multi-output operations like `Split`
```python
def split_pattern(op, input):
# Matches Split operations with exactly 2 outputs
return op.Split(input, num_outputs=2, axis=0, _outputs=2)
```
### Enhanced `_allow_other_attributes` documentation
- **File**: `docs/tutorial/rewriter/attributes.md` (improved formatting)
- **Already documented**: Controls whether patterns match nodes with
additional attributes
- **Default**: `True` (allows extra attributes)
## Documentation Structure Improvements
- Added "Pattern Options" section to main rewriter documentation
- Integrated all option docs into the tutorial flow
- Created working code examples for each option
- Followed existing documentation patterns and style
- All examples compile and run successfully
- Documentation builds correctly with Sphinx
The documentation now provides complete coverage of all rewriter pattern
options with practical examples showing real-world usage patterns.
Fixes microsoft#2405.
> [!WARNING]
>
> <details>
> <summary>Firewall rules blocked me from connecting to one or more
addresses</summary>
>
> #### I tried to connect to the following addresses, but was blocked by
firewall rules:
>
> - `docs.python.org`
> - Triggering command: `python -m sphinx docs dist/html -W -q ` (dns
block)
> - Triggering command: `python -m sphinx docs dist/html -q -E -j 1 `
(dns block)
> - `docs.scipy.org`
> - Triggering command: `python -m sphinx docs dist/html -W -q ` (dns
block)
> - Triggering command: `python -m sphinx docs dist/html -q -E -j 1 `
(dns block)
> - `matplotlib.org`
> - Triggering command: `python -m sphinx docs dist/html -W -q ` (dns
block)
> - Triggering command: `python -m sphinx docs dist/html -q -E -j 1 `
(dns block)
> - `numpy.org`
> - Triggering command: `python -m sphinx docs dist/html -W -q ` (dns
block)
> - Triggering command: `python -m sphinx docs dist/html -q -E -j 1 `
(dns block)
> - `onnx.ai`
> - Triggering command: `python -m sphinx docs dist/html -W -q ` (dns
block)
> - Triggering command: `python -m sphinx docs dist/html -q -E -j 1 `
(dns block)
> - `onnxruntime.ai`
> - Triggering command: `python -m sphinx docs dist/html -W -q ` (dns
block)
> - Triggering command: `python -m sphinx docs dist/html -q -E -j 1 `
(dns block)
> - `pytorch.org`
> - Triggering command: `python -m sphinx docs dist/html -W -q ` (dns
block)
> - Triggering command: `python -m sphinx docs dist/html -q -E -j 1 `
(dns block)
>
> If you need me to access, download, or install something from one of
these locations, you can either:
>
> - Configure [Actions setup
steps](https://gh.io/copilot/actions-setup-steps) to set up my
environment, which run before the firewall is enabled
> - Add the appropriate URLs or hosts to my [firewall allow
list](https://gh.io/copilot/firewall-config)
>
> </details>
<!-- START COPILOT CODING AGENT TIPS -->
---
💬 Share your feedback on Copilot coding agent for the chance to win a
$200 gift card! Click
[here](https://survey.alchemer.com/s3/8343779/Copilot-Coding-agent) to
start the survey.
---------
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com>
Co-authored-by: gramalingam <10075881+gramalingam@users.noreply.github.com>
In onnx2script, nan, inf etc. were converted to plain text, which causes evaluation to fail because they don't exist in the script. I updated the logic to replace them with np. values. --------- Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
Simplify implementation for `aten_chunk` and allow it to work on all data types. Original author: @xadupre Updated: Conditionally use the new implementation when torch>=2.7 --------- Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com> Co-authored-by: Xavier Dupré <xadupre@users.noreply.github.com>
Codecov Report❌ Patch coverage is Additional details and impacted files@@ Coverage Diff @@
## main #2473 +/- ##
==========================================
- Coverage 69.81% 69.01% -0.81%
==========================================
Files 209 211 +2
Lines 25313 25978 +665
Branches 2525 2612 +87
==========================================
+ Hits 17673 17928 +255
- Misses 6762 7175 +413
+ Partials 878 875 -3 ☔ View full report in Codecov by Sentry. |
This PR introduces a specialized LongRoPe (Long Range Rotary Position Embedding) GQA (Group Query Attention) causal mask fusion rule specifically designed for Phi-4-mini-reasoning and similar models. The implementation optimizes attention mask computation for models using sliding window attention with LongRoPe position embeddings.
New LongRoPeGQACausalMask Class
Advanced Mask Computation
Note: This PR is meant to replace #2461 by introducing the requested changes.