
Commit 00596a2

Author: Sanggyu Lee committed

Rename model.py to layers.py

1 parent 0c8b4cd · commit 00596a2

File tree

  • test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention

1 file changed: +19 -53 lines changed

test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/model.py renamed to test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/layers.py

Lines changed: 19 additions & 53 deletions
@@ -2,46 +2,6 @@
 prompt = "Lily picked up a flower."
 model_name = "Maykeye/TinyLLama-v0"

-captured_input = ()
-
-import copy, inspect, types
-
-from transformers.models.llama.modeling_llama import LlamaDecoderLayer
-
-forward_org = LlamaDecoderLayer.forward
-
-
-def capture_and_forward(self, *args, **kwargs):
-    global captured_input
-
-    # Prepare args tuple for TICO.convert()
-    # Get arg_names in positional args order using inspect
-    sig = inspect.signature(forward_org)
-    args_names = [
-        # signature includes `self`` and `kwargs``.
-        # Just retrieve the ordinary positional inputs only
-        name
-        for name in sig.parameters.keys()
-        if name
-        not in ("self", "kwargs", "use_cache", "position_ids", "output_attentions")
-    ]
-
-    args_dict = dict(zip(args_names, args))
-    args_dict.update(kwargs)
-
-    def populate_args(args_dict, filter):
-        for key in filter:
-            args_dict.pop(key, None)
-        args_tuple = tuple(args_dict.get(name, None) for name in args_names)
-        return copy.deepcopy(args_tuple)
-
-    if args_dict["past_key_value"].get_seq_length() != 0 and captured_input == ():
-        input_to_remove = []
-        captured_input = populate_args(args_dict, input_to_remove)
-
-    return forward_org(self, *args, **kwargs)
-
-
 # Tokenizer
 from transformers import AutoTokenizer

@@ -56,30 +16,35 @@ def populate_args(args_dict, filter):
     truncation=True,
 )

-
 # Generator
 import torch

 from transformers import AutoModelForCausalLM

 model = AutoModelForCausalLM.from_pretrained(model_name)
 model.eval()
-model.model.layers[0].forward = types.MethodType(
-    capture_and_forward, model.model.layers[0]
-)
-with torch.no_grad():
+
+from tico.utils.record_input import RecordingInput
+
+target_model = model.model.layers[0]
+condition_fn = lambda args_dict: args_dict["past_key_value"].get_seq_length() != 0
+
+with torch.no_grad(), RecordingInput(target_model, condition_fn) as rec:
     outputs = model.generate(
         **inputs,
         max_new_tokens=32,
         do_sample=False,
         pad_token_id=tokenizer.eos_token_id,
     )
+captured_input = rec.captured_input
+
 generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
 print(generated_text)

+
 # ATTENTION FUSER

-from typing import List, Optional, Tuple
+from typing import Any, List, Optional, Tuple


 @torch.library.impl("circle::attention.llama", "CPU")
@@ -160,7 +125,6 @@ def forward_adapter(


 # Tico
-
 import tico

 from torch import nn
@@ -179,11 +143,13 @@ def forward(
         self,
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
-        past_key_values: Optional[Cache] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[
-            Tuple[torch.Tensor, torch.Tensor]
-        ] = None,  # necessary, but kept here for BC
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
+        **kwargs: Any,
     ) -> Tuple[
         torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
     ]:
@@ -192,7 +158,7 @@ def forward(
         layer_outputs = decoder_layer(
             hidden_states,
             attention_mask=attention_mask,
-            past_key_value=past_key_values,
+            past_key_value=past_key_value,
             cache_position=cache_position,
             position_embeddings=position_embeddings,
         )
@@ -205,4 +171,4 @@ def forward(
 LlamaAttention.forward = forward_adapter
 layers.eval()
 circle_model = tico.convert(layers, captured_input)
-circle_model.save(f"tinyllama.model.attn.circle")
+circle_model.save(f"tinyllama.layers.attn.circle")
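
For reference, the deleted capture_and_forward hook shows roughly what the new RecordingInput helper replaces. The sketch below is an assumption reconstructed from this diff alone, not the actual tico.utils.record_input.RecordingInput; the class name RecordingInputSketch and the restore logic in __exit__ are hypothetical. It is a minimal context manager with the interface seen above (target module, condition_fn, and a captured_input attribute):

import copy
import inspect
import types


class RecordingInputSketch:
    """Hypothetical stand-in for RecordingInput, inferred from this diff."""

    def __init__(self, target_module, condition_fn):
        self.target = target_module
        self.condition_fn = condition_fn
        self.captured_input = ()

    def __enter__(self):
        forward_org = type(self.target).forward
        sig = inspect.signature(forward_org)
        # Keep the ordinary positional inputs only, as the removed code did.
        arg_names = [
            name
            for name in sig.parameters.keys()
            if name
            not in ("self", "kwargs", "use_cache", "position_ids", "output_attentions")
        ]

        def capture_and_forward(module, *args, **kwargs):
            args_dict = dict(zip(arg_names, args))
            args_dict.update(kwargs)
            if self.condition_fn(args_dict) and self.captured_input == ():
                # Deep-copy a positional args tuple suitable for tico.convert().
                self.captured_input = copy.deepcopy(
                    tuple(args_dict.get(name, None) for name in arg_names)
                )
            return forward_org(module, *args, **kwargs)

        # Patch only this module instance; the class-level forward is untouched.
        self.target.forward = types.MethodType(capture_and_forward, self.target)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Drop the instance-level override so the original forward is used again.
        del self.target.forward
        return False

Used the same way as in the diff (with torch.no_grad(), RecordingInputSketch(target_model, condition_fn) as rec: ...), rec.captured_input then holds the first deep-copied input tuple for which condition_fn returned True.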
