Skip to content

Commit 82460ad

Browse files
author
Sanggyu Lee
committed
Update tico/utils/validate_args_kwargs.py
1 parent a7bb828 commit 82460ad

File tree

2 files changed

+30
-15
lines changed

2 files changed

+30
-15
lines changed

tico/serialize/operators/op_attention.py

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from tico.serialize.operators.hashable_opcode import OpCode
2525
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
2626
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
27-
27+
from tico.utils.validate_args_kwargs import CircleAttentionArgs
2828

2929

3030
@register_node_visitor
@@ -40,20 +40,7 @@ def define_node(
4040
self,
4141
node: torch.fx.Node,
4242
) -> circle.Operator.OperatorT:
43-
(
44-
hidden_states,
45-
wq,
46-
wk,
47-
wv,
48-
wo,
49-
position_cos,
50-
position_sin,
51-
attention_mask,
52-
past_key,
53-
past_value,
54-
cache_position,
55-
) = node.args
56-
43+
args = CircleAttentionArgs(*node.args, **node.kwargs) # type: ignore[arg-type]
5744
op_index = get_op_index(
5845
circle.BuiltinOperator.BuiltinOperator.ATTENTION, self._op_codes
5946
)

tico/utils/validate_args_kwargs.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,34 @@ class CatArgs:
171171
dim: int = 0
172172

173173

174+
@enforce_type
175+
@dataclass
176+
class CircleAttentionArgs:
177+
"""
178+
For circle.BuiltinOperator.BuiltinOperator.RMS_NORM
179+
"""
180+
181+
182+
@enforce_type
183+
@dataclass
184+
class CircleAttentionArgs:
185+
"""
186+
For circle.BuiltinOperator.BuiltinOperator.ATTENTION
187+
"""
188+
189+
hidden_states: torch.fx.Node
190+
wq: torch.fx.Node
191+
wk: torch.fx.Node
192+
wv: torch.fx.Node
193+
wo: torch.fx.Node
194+
position_cos: torch.fx.Node
195+
position_sin: torch.fx.Node
196+
attention_mask: torch.fx.Node
197+
past_key: torch.fx.Node
198+
past_value: torch.fx.Node
199+
cache_position: torch.fx.Node
200+
201+
174202
@enforce_type
175203
@dataclass
176204
class CircleRMSNormArgs:

0 commit comments

Comments
 (0)