prompt = "Lily picked up a flower."
model_name = "Maykeye/TinyLLama-v0"

-captured_input = ()
-
-import copy, inspect, types
-
-from transformers.models.llama.modeling_llama import LlamaDecoderLayer
-
-forward_org = LlamaDecoderLayer.forward
-
-
-def capture_and_forward(self, *args, **kwargs):
-    global captured_input
-
-    # Prepare the args tuple for TICO.convert()
-    # Get argument names in positional order using inspect
-    sig = inspect.signature(forward_org)
-    args_names = [
-        # The signature includes `self` and `kwargs`.
-        # Keep only the ordinary positional inputs.
-        name
-        for name in sig.parameters.keys()
-        if name
-        not in ("self", "kwargs", "use_cache", "position_ids", "output_attentions")
-    ]
-
-    args_dict = dict(zip(args_names, args))
-    args_dict.update(kwargs)
-
-    def populate_args(args_dict, filter):
-        for key in filter:
-            args_dict.pop(key, None)
-        args_tuple = tuple(args_dict.get(name, None) for name in args_names)
-        return copy.deepcopy(args_tuple)
-
-    if args_dict["past_key_value"].get_seq_length() != 0 and captured_input == ():
-        input_to_remove = []
-        captured_input = populate_args(args_dict, input_to_remove)
-
-    return forward_org(self, *args, **kwargs)
-
-
# Tokenizer
from transformers import AutoTokenizer

@@ -56,30 +16,35 @@ def populate_args(args_dict, filter):
    truncation=True,
)

-
# Generator
import torch

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(model_name)
model.eval()
-model.model.layers[0].forward = types.MethodType(
-    capture_and_forward, model.model.layers[0]
-)
-with torch.no_grad():
+
+from tico.utils.record_input import RecordingInput
+
+target_model = model.model.layers[0]
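+# Capture inputs only once the KV cache already holds tokens, i.e. on a decode step after prefill
+# (the same condition the removed hand-written hook used).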
+condition_fn = lambda args_dict: args_dict["past_key_value"].get_seq_length() != 0
+
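+# RecordingInput wraps target_model's forward for the duration of the `with` block and, whenever
+# condition_fn evaluates to True, records the layer inputs, exposed afterwards via rec.captured_input.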
+with torch.no_grad(), RecordingInput(target_model, condition_fn) as rec:
    outputs = model.generate(
        **inputs,
        max_new_tokens=32,
        do_sample=False,
        pad_token_id=tokenizer.eos_token_id,
    )
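+    # Inputs recorded for the target layer; these are passed to tico.convert() below.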
+    captured_input = rec.captured_input
+
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(generated_text)

+
# ATTENTION FUSER

-from typing import List, Optional, Tuple
+from typing import Any, List, Optional, Tuple


@torch.library.impl("circle::attention.llama", "CPU")
@@ -160,7 +125,6 @@ def forward_adapter(


# Tico
-
import tico

from torch import nn
@@ -179,11 +143,13 @@ def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
-        past_key_values: Optional[Cache] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[
-            Tuple[torch.Tensor, torch.Tensor]
-        ] = None,  # necessary, but kept here for BC
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
+        **kwargs: Any,
    ) -> Tuple[
        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
    ]:
@@ -192,7 +158,7 @@ def forward(
        layer_outputs = decoder_layer(
            hidden_states,
            attention_mask=attention_mask,
-            past_key_value=past_key_values,
+            past_key_value=past_key_value,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
        )
@@ -205,4 +171,4 @@ def forward(
LlamaAttention.forward = forward_adapter
layers.eval()
circle_model = tico.convert(layers, captured_input)
-circle_model.save(f"tinyllama.model.attn.circle")
+circle_model.save(f"tinyllama.layers.attn.circle")