File tree Expand file tree Collapse file tree 2 files changed +30
-15
lines changed
Expand file tree Collapse file tree 2 files changed +30
-15
lines changed Original file line number Diff line number Diff line change 2424from tico .serialize .operators .hashable_opcode import OpCode
2525from tico .serialize .operators .node_visitor import NodeVisitor , register_node_visitor
2626from tico .serialize .operators .utils import create_builtin_operator , get_op_index
27-
27+ from tico . utils . validate_args_kwargs import CircleAttentionArgs
2828
2929
3030@register_node_visitor
@@ -40,20 +40,7 @@ def define_node(
4040 self ,
4141 node : torch .fx .Node ,
4242 ) -> circle .Operator .OperatorT :
43- (
44- hidden_states ,
45- wq ,
46- wk ,
47- wv ,
48- wo ,
49- position_cos ,
50- position_sin ,
51- attention_mask ,
52- past_key ,
53- past_value ,
54- cache_position ,
55- ) = node .args
56-
43+ args = CircleAttentionArgs (* node .args , ** node .kwargs ) # type: ignore[arg-type]
5744 op_index = get_op_index (
5845 circle .BuiltinOperator .BuiltinOperator .ATTENTION , self ._op_codes
5946 )
Original file line number Diff line number Diff line change @@ -171,6 +171,34 @@ class CatArgs:
171171 dim : int = 0
172172
173173
174+ @enforce_type
175+ @dataclass
176+ class CircleAttentionArgs :
177+ """
178+ For circle.BuiltinOperator.BuiltinOperator.RMS_NORM
179+ """
180+
181+
182+ @enforce_type
183+ @dataclass
184+ class CircleAttentionArgs :
185+ """
186+ For circle.BuiltinOperator.BuiltinOperator.ATTENTION
187+ """
188+
189+ hidden_states : torch .fx .Node
190+ wq : torch .fx .Node
191+ wk : torch .fx .Node
192+ wv : torch .fx .Node
193+ wo : torch .fx .Node
194+ position_cos : torch .fx .Node
195+ position_sin : torch .fx .Node
196+ attention_mask : torch .fx .Node
197+ past_key : torch .fx .Node
198+ past_value : torch .fx .Node
199+ cache_position : torch .fx .Node
200+
201+
174202@enforce_type
175203@dataclass
176204class CircleRMSNormArgs :
You can’t perform that action at this time.
0 commit comments