-
Notifications
You must be signed in to change notification settings - Fork 168
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Use ONNX Rewriter and IR to simplify the mnb_to_qdq pass #1482
base: main
Are you sure you want to change the base?
Conversation
|
||
# Add Logic handling input 3 | ||
|
||
unpacked_weight_arrays = _unpack_weights( |
Check failure
Code scanning / lintrunner
RUFF/F821 Error
See https://docs.astral.sh/ruff/rules/undefined-name
@@ -7,8 +7,10 @@ | |||
from pathlib import Path | |||
from typing import TYPE_CHECKING, Any, Dict | |||
|
|||
import ml_dtypes |
Check warning
Code scanning / lintrunner
PYLINT/W0611 Warning
See unused-import.
@@ -7,8 +7,10 @@ | |||
from pathlib import Path | |||
from typing import TYPE_CHECKING, Any, Dict | |||
|
|||
import ml_dtypes |
Check warning
Code scanning / lintrunner
RUFF/F401 Warning
See https://docs.astral.sh/ruff/rules/unused-import
return False | ||
g_idx = g_idx.constant_value.numpy() | ||
trivial_g_idx = np.arange(k, dtype=np.int32) // block_size | ||
if not np.array_equal(g_idx, trivial_g_idx): |
Check warning
Code scanning / lintrunner
RUFF/SIM103 Warning
See https://docs.astral.sh/ruff/rules/needless-bool
g_idx = g_idx.constant_value.numpy() | ||
trivial_g_idx = np.arange(k, dtype=np.int32) // block_size | ||
if not np.array_equal(g_idx, trivial_g_idx): | ||
# TODO: We can log why the pattern is not matched here |
Check warning
Code scanning / lintrunner
RUFF/TD002 Warning
See https://docs.astral.sh/ruff/rules/missing-todo-author
matmul = op.Add(matmul, bias) | ||
return matmul | ||
|
||
replace_mat_mul_n_bits = orp.RewriteRule( |
Check warning
Code scanning / lintrunner
PYLINT/W0612 Warning
See unused-variable.
matmul = op.Add(matmul, bias) | ||
return matmul | ||
|
||
replace_mat_mul_n_bits = orp.RewriteRule( |
Check warning
Code scanning / lintrunner
RUFF/F841 Warning
See https://docs.astral.sh/ruff/rules/unused-variable
graph: ir.Graph = context.graph | ||
return value in graph.initializers.values() | ||
|
||
def mat_mul_n_bits_pattern_check(context, *, q_weight, g_idx, mat_mul_n_bits_out: ir.Value, **_) -> bool: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
does q_weight here match for the input right before g_idx or it is whatever it is in the mat_mul_n_bits_pattern signature? The input before g_idx is qzero and can be optional. we want to check the second input
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The inputs of the pattern-function (mat_mul_n_bits_pattern) are bound to values in the graph, and these values are passed in as keyword-arguments to the rewrite function here. So, the order here doesn't really matter, though I usually just copy-paste and use the same argument list for both.
del node.meta["N"] | ||
|
||
# TODO(justinchuby): Register and remove initializers | ||
ir_model.opset_imports[""] = max(21, ir_model.opset_imports[""]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TODO: Use a more robust version conversion process
Describe your changes
Checklist before requesting a review
lintrunner -a
(Optional) Issue link