Commit 0c8b4cd

Update layer.py

Authored and committed by Sanggyu Lee
1 parent 0d7a67f · commit 0c8b4cd

File tree (1 file changed, +5 −3 lines):

  • test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/layer.py


test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/layer.py

Lines changed: 5 additions & 3 deletions
@@ -26,9 +26,10 @@
 
 from tico.utils.record_input import RecordingInput
 
+target_model = model.model.layers[0]
 condition_fn = lambda args_dict: args_dict["past_key_value"].get_seq_length() != 0
 
-with torch.no_grad(), RecordingInput(model.model.layers[0], condition_fn) as rec:
+with torch.no_grad(), RecordingInput(target_model, condition_fn) as rec:
     outputs = model.generate(
         **inputs,
         max_new_tokens=32,
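This first hunk names the recorded module once (target_model = model.model.layers[0]) instead of inlining the expression, and the condition_fn only accepts calls where the KV cache already holds tokens: get_seq_length() is the public method on Hugging Face Cache objects and returns 0 before any tokens are cached, so the prefill pass is skipped and a decode-step invocation of the decoder layer is recorded. As a hypothetical illustration of what a conditional recorder like RecordingInput could do internally (tico's actual implementation may differ), here is a forward pre-hook that keeps the first matching call:

import torch

class ConditionalInputRecorder:
    # Illustrative stand-in for tico's RecordingInput, not its real code.
    def __init__(self, module: torch.nn.Module, condition_fn):
        self.module = module
        self.condition_fn = condition_fn
        self.captured_input = None  # assumed accessor; the diff only shows "rec"
        self._handle = None

    def __enter__(self):
        def pre_hook(module, args, kwargs):
            # Keep the first call whose kwargs satisfy the condition, e.g. a
            # decode step where past_key_value.get_seq_length() != 0.
            if self.captured_input is None and self.condition_fn(kwargs):
                self.captured_input = (args, kwargs)

        self._handle = self.module.register_forward_pre_hook(
            pre_hook, with_kwargs=True
        )
        return self

    def __exit__(self, exc_type, exc, tb):
        self._handle.remove()

Under this reading, the args_dict the lambda receives corresponds to kwargs here, and the captured_input used by tico.convert in the second hunk would be whatever the recorder kept from that one decode step.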
@@ -123,12 +124,13 @@ def forward_adapter(
     )
 
 
-LlamaAttention.forward = forward_adapter
-
 # Tico
 import tico
 
 model = AutoModelForCausalLM.from_pretrained(model_name)
+
+LlamaAttention.forward = forward_adapter
+
 model.eval()
 circle_model = tico.convert(model.model.layers[0], captured_input)
 circle_model.save(f"tinyllama.layer.attn.circle")
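The second hunk moves the class-level monkeypatch LlamaAttention.forward = forward_adapter from before the Tico section to after AutoModelForCausalLM.from_pretrained(model_name), presumably so the checkpoint is loaded through the stock attention path and the fused-attention adapter is swapped in only for the tico.convert export. A reversible variant of that patch, sketched under the assumption that forward_adapter, model, and captured_input are already defined as in layer.py:

from transformers.models.llama.modeling_llama import LlamaAttention

import tico

_orig_forward = LlamaAttention.forward
LlamaAttention.forward = forward_adapter  # class-level patch, as in the commit
try:
    circle_model = tico.convert(model.model.layers[0], captured_input)
    circle_model.save("tinyllama.layer.attn.circle")
finally:
    LlamaAttention.forward = _orig_forward  # restore stock attention

Restoring the original forward in a finally block keeps any later model.generate() calls on unpatched attention even if the conversion raises; the commit itself patches unconditionally, which is fine for a one-shot export script.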
