Add attention operator and adapter for onert #400
Changes from all commits: b59177e, fa45293, e94388d, 2c49c53, 275b6cf, c01470d, 6a5b1f5, be57232, 9281efc, ad53700, c882fe6, 3233267, c817971
New file (1 line):

    # DO NOT REMOVE THIS FILE
New file (51 lines), an attention forward adapter for Hugging Face LlamaAttention:

    # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.

    from typing import Dict, List, TYPE_CHECKING

    import torch

    from transformers.cache_utils import DynamicCache
    from transformers.models.llama.modeling_llama import LlamaAttention


    def llama_attention_forward_adapter(
        self: LlamaAttention,
        hidden_states: torch.Tensor,
        position_embeddings: List[torch.Tensor],
        attention_mask: torch.Tensor,
        past_key_value: DynamicCache,
        cache_position: torch.Tensor,
        **kwargs,
    ):
        # past_key_value is a DynamicCache holding key_cache and value_cache lists.
        # It has to be decomposed into plain tensors here because tico and circle
        # do not understand dict-like containers.
        key_cache = past_key_value.key_cache  # type: ignore[union-attr]
        value_cache = past_key_value.value_cache  # type: ignore[union-attr]
        return (
            torch.ops.circle_custom.attention(
                hidden_states,
                self.q_proj.weight,
                self.k_proj.weight,
                self.v_proj.weight,
                self.o_proj.weight,
                position_embeddings[0],  # cos
                position_embeddings[1],  # sin
                attention_mask,
                key_cache[self.layer_idx],
                value_cache[self.layer_idx],
                cache_position,
            ),
            None,  # attention weights are not returned
        )
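For context, a minimal sketch of how such an adapter could be attached to a model before export. The import path tico.serialize.operators.adapters.llama, the monkey-patching approach, and the TinyLlama checkpoint are illustrative assumptions, not taken from this PR:

```python
# Hypothetical usage sketch: route every LlamaAttention.forward through the
# adapter so that tracing sees torch.ops.circle_custom.attention instead of
# the decomposed attention math. Not part of this PR.
import torch
from transformers import AutoModelForCausalLM
from transformers.models.llama.modeling_llama import LlamaAttention

from tico.utils.register_custom_op import RegisterOps

# Assumed import path for the adapter added in this PR.
from tico.serialize.operators.adapters.llama import llama_attention_forward_adapter

RegisterOps()  # makes circle_custom::attention available (see register_custom_op.py below)

model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
model.eval()

# Replace the forward of every attention layer with the adapter (bound as a method).
for module in model.modules():
    if isinstance(module, LlamaAttention):
        module.forward = llama_attention_forward_adapter.__get__(module, LlamaAttention)
```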
New file (58 lines), the Circle serialization visitor for the attention op:

    # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.

    from typing import Dict, List, TYPE_CHECKING

    if TYPE_CHECKING:
        import torch._ops
        import torch.fx
    import torch
    from circle_schema import circle

    from tico.serialize.circle_graph import CircleSubgraph
    from tico.serialize.operators.hashable_opcode import OpCode
    from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
    from tico.serialize.operators.utils import create_builtin_operator, get_op_index
    from tico.utils.validate_args_kwargs import CircleAttentionArgs


    @register_node_visitor
    class AttentionVisitor(NodeVisitor):
        target: List[torch._ops.OpOverload] = [
            torch.ops.circle_custom.attention.default,
        ]

        def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
            super().__init__(op_codes, graph)

        def define_node(
            self,
            node: torch.fx.Node,
        ) -> circle.Operator.OperatorT:
            args = CircleAttentionArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
            op_index = get_op_index(
                circle.BuiltinOperator.BuiltinOperator.ATTENTION, self._op_codes
            )

            inputs = node.args
            outputs = [node]
            operator = create_builtin_operator(self.graph, op_index, inputs, outputs)

            # Op-specific option
            operator.builtinOptionsType = (
                circle.BuiltinOptions.BuiltinOptions.AttentionOptions
            )
            operator.builtinOptions = circle.AttentionOptions.AttentionOptionsT()

            return operator
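The visitor validates the node with CircleAttentionArgs from tico.utils.validate_args_kwargs, whose diff is not shown on this page. Below is a hypothetical sketch of what such an args class might look like, assuming a plain dataclass whose fields mirror the custom op signature registered in register_custom_op.py; none of this is taken from the PR:

```python
# Hypothetical sketch of CircleAttentionArgs; the real definition is added to
# tico/utils/validate_args_kwargs.py in this PR but its diff is not shown here.
from dataclasses import dataclass

import torch.fx


@dataclass
class CircleAttentionArgs:
    """Positional arguments of circle_custom::attention, in op order."""

    hidden_states: torch.fx.Node
    wq: torch.fx.Node
    wk: torch.fx.Node
    wv: torch.fx.Node
    wo: torch.fx.Node
    position_cos: torch.fx.Node
    position_sin: torch.fx.Node
    attention_mask: torch.fx.Node
    past_key: torch.fx.Node
    past_value: torch.fx.Node
    cache_position: torch.fx.Node
```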
Changes to tico/utils/register_custom_op.py, hunk @@ -727,6 +727,40 @@ def _( (context lines included):

            return hidden_states.new_empty(hidden_states.size())


    def CircleAttention():
        @custom_op("circle_custom::attention", mutates_args=())
        def attention(
            hidden_states: torch.Tensor,
            wq: torch.Tensor,
            wk: torch.Tensor,
            wv: torch.Tensor,
            wo: torch.Tensor,
            position_cos: torch.Tensor,
            position_sin: torch.Tensor,
            attention_mask: torch.Tensor,
            past_key: torch.Tensor,
            past_value: torch.Tensor,
            cache_position: torch.Tensor,
        ) -> torch.Tensor:
            return None
Inline review comments on the unimplemented attention body:

Contributor:
@glistening If you'd like to enable some tests on TICO, it could be done by implementing this function and adding a simple 'attention' function with KV cache. TICO also uses the onert nightly, so it is testable:

    # FILE: test/requirements_pre.txt
    onert==0.2.0.dev250922

Contributor (author):
I have no idea how to test attention in TICO. It is beyond the scope of TICO. I edited the title to specify the scope of this PR. If it is meaningful, the prefill phase of tinyllama (which has no fused attention) may be testable using …
The hunk continues with the fake (meta) kernel and the existing registration entry point:

        @register_fake("circle_custom::attention")
        def _(
            hidden_states: torch.Tensor,
            wq: torch.Tensor,
            wk: torch.Tensor,
            wv: torch.Tensor,
            wo: torch.Tensor,
            position_cos: torch.Tensor,
            position_sin: torch.Tensor,
            attention_mask: torch.Tensor,
            past_key: torch.Tensor,
            past_value: torch.Tensor,
            cache_position: torch.Tensor,
        ) -> torch.Tensor:
            return hidden_states


    # Add custom ops to the torch namespace
    def RegisterOps():
        CircleResizeNearestNeighbor()

Hunk @@ -740,3 +774,4 @@ def RegisterOps(): adds the new registration call:

        CircleInstanceNorm()
        CircleQuantizeMX()
        CircleRMSNorm()
        CircleAttention()
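Because the registered fake kernel returns a tensor shaped like hidden_states, a graph containing circle_custom::attention can be traced with torch.export even though the eager body is left unimplemented. A minimal sketch under stated assumptions: the wrapper module AttentionCall and all tensor shapes below are illustrative, not taken from this PR.

```python
# Minimal export sketch: trace a graph that contains circle_custom::attention.
# The wrapper module and shapes are assumptions for illustration only.
import torch
from tico.utils.register_custom_op import RegisterOps

RegisterOps()  # registers circle_custom::attention among the other custom ops


class AttentionCall(torch.nn.Module):
    def forward(self, hidden_states, wq, wk, wv, wo, cos, sin, mask, past_k, past_v, cache_pos):
        return torch.ops.circle_custom.attention(
            hidden_states, wq, wk, wv, wo, cos, sin, mask, past_k, past_v, cache_pos
        )


hidden, heads, head_dim, seq, past = 64, 4, 16, 8, 16
example_inputs = (
    torch.randn(1, seq, hidden),            # hidden_states
    torch.randn(hidden, hidden),            # wq
    torch.randn(hidden, hidden),            # wk
    torch.randn(hidden, hidden),            # wv
    torch.randn(hidden, hidden),            # wo
    torch.randn(1, seq, head_dim),          # position cos
    torch.randn(1, seq, head_dim),          # position sin
    torch.zeros(1, 1, seq, past + seq),     # attention_mask
    torch.randn(1, heads, past, head_dim),  # past_key
    torch.randn(1, heads, past, head_dim),  # past_value
    torch.arange(past, past + seq),         # cache_position
)

ep = torch.export.export(AttentionCall(), example_inputs)
print(ep.graph)  # the graph contains a call to circle_custom.attention.default
```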
Review thread on where these changes should live:

Comment:
@jinevening Are you okay with this location? You told me that you prefer a separate directory for onert-only operators like op_attention.

Comment:
The location can be like below:
tico/serialize/operators/
tico/utils/register_custom_op.py
tico/serialize/operators/adapters/

Comment:
Could you put the adapters in the current directory (tico/serialize/operators/adapters/onert) and place the other code according to @mhs4670go's suggestion? adapters may be open to users, so I think it would be good to specify that this adapter is only for onert.

Comment:
Yes, it is exactly the same as my understanding. Thank you for confirming. As I wrote in #400 (comment), I will follow TICO's way, as in #266. (Though I don't think putting all operators in a single file, register_custom_op.py, is a good approach: it is already 740 lines and will keep growing as supported operators increase. It would be better to give each operator its own file and let register_custom_op.py do the registering only.) Anyway, again, I will follow TICO's current way.

Comment:
@jinevening I totally agree. Adapters will vary per model. I am trying an encoder-decoder model (for a translation task). I will suggest an adapters code structure after finishing the encoder-decoder model, which requires another adapter (i.e. TRIVMultiheadAttention).