Commit eb7db29

committed
added latest generalized generate code
Signed-off-by: Onkar Chougule <quic_ochougul@quicinc.com>
1 parent c31f317 commit eb7db29

5 files changed: +328 -97 lines changed

QEfficient/generation/text_generation_inference.py

Lines changed: 13 additions & 0 deletions
@@ -63,6 +63,19 @@ def __repr__(self):
         \nTotal (E2E) inference time is= {round(self.perf_metrics.total_time, 2)}"


+@dataclass
+class CloudAI100ExecInfoNew:
+    batch_size: int
+    generated_ids: Union[List[np.ndarray], np.ndarray]
+    perf_metrics: PerfMetrics
+
+    def __repr__(self):
+        return f"Average Prefill time a.k.a TTFT is= {round(self.perf_metrics.prefill_time, 2)}\
+        \nDecode token/sec is= {round(self.perf_metrics.decode_perf * self.batch_size, 2)}\
+        \nTotal token/sec is= {round(self.perf_metrics.total_perf * self.batch_size, 2)}\
+        \nTotal (E2E) inference time is= {round(self.perf_metrics.total_time, 2)}"
+
+
 io_files = []


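For reference, a minimal sketch of how the new result object might be consumed, assuming PerfMetrics takes the prefill_time/decode_perf/total_perf/total_time fields used by the generate code further down; the numeric values are purely illustrative:

import numpy as np

from QEfficient.generation.text_generation_inference import CloudAI100ExecInfoNew, PerfMetrics

# Illustrative values only; real metrics come from cloud_ai_100_generate below.
exec_info = CloudAI100ExecInfoNew(
    batch_size=1,
    generated_ids=np.full((1, 32), 1),
    perf_metrics=PerfMetrics(prefill_time=0.12, decode_perf=85.3, total_perf=80.1, total_time=0.5),
)
print(exec_info)  # __repr__ reports TTFT, decode tok/s, total tok/s and E2E time
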
QEfficient/transformers/models/internvl/modeling_internvl.py

Lines changed: 1 addition & 1 deletion
@@ -62,7 +62,7 @@ def get_output_names(self,):
         return output_names


-    def generate_dummy_inputs(self, kv_offload: bool = False):
+    def get_dummy_inputs(self, kv_offload: bool = False):
         if kv_offload:
             raise ValueError("kv_offload method not supported for InternVL yet!")
         NUM_CROPS = 13

QEfficient/transformers/models/modeling_auto.py

Lines changed: 104 additions & 95 deletions
@@ -6,7 +6,6 @@
 # ----------------------------------------------------------------------------

 import hashlib
-import logging
 import sys
 import warnings
 from pathlib import Path
@@ -29,7 +28,7 @@
 from QEfficient.base.modeling_qeff import QEFFBaseModel
 from QEfficient.base.onnx_transforms import FP16ClipTransform, SplitTensorsTransform
 from QEfficient.generation.cloud_infer import QAICInferenceSession
-from QEfficient.generation.text_generation_inference import get_compilation_dims
+from QEfficient.generation.text_generation_inference import CloudAI100ExecInfoNew, PerfMetrics, get_compilation_dims
 from QEfficient.transformers.models.pytorch_transforms import (
     CustomOpsTransform,
     KVCacheModuleMethodMapperTransform,
@@ -820,7 +819,7 @@ def export(
         export_dir: Optional[str] = None,
         **kwargs,
     ) -> str:
-        inputs = self.model.generate_dummy_inputs()
+        inputs = self.model.get_dummy_inputs()
         dynamic_axes = self.model.get_onnx_dynamic_axes()
         output_names = self.model.get_output_names()
         self._export(inputs, output_names, dynamic_axes, export_dir=export_dir)
@@ -843,6 +842,7 @@ def compile(
         output_names = self.model.get_output_names()

         # Get specializations from modelling file
+        # TODO: expose this via the auto class as well
         specializations = self.model.get_specializations(batch_size=batch_size, prefill_seq_len=prefill_seq_len,
                                                          ctx_len=ctx_len, img_size=img_size, **compiler_options)

@@ -873,7 +873,6 @@ def compile(
         )
         return self.qpc_path

-    @property
     def get_onnx_dynamic_axes(self):
         return self.model.get_onnx_dynamic_axes()

@@ -905,111 +904,121 @@ def generate(
         )

     def cloud_ai_100_generate(
-        self,
-        inputs: torch.Tensor,
-        device_ids: List[int],
-        enable_debug_logs: bool = False,
-        generation_len: int = None,
-        streamer: Optional[TextStreamer] = None,
-        generation_len: Optional[int] = None
-    ) -> np.ndarray:
-        qpc_session = QAICInferenceSession(
-            self.qpc_path, device_ids, enable_debug_logs=enable_debug_logs, activate=False
-        )
-
-        batch_size, ctx_len, fbs = get_compilation_dims(self.qpc_path)
-
-        # Skip inputs/outputs
-        qpc_session.skip_buffers(
-            [x for x in (qpc_session.input_names + qpc_session.output_names) if x.startswith("past_")] + ["pixel_values_RetainedState"]
-        )
-
-        # Read prompt and ctx len from session
-        batch_size = max(
-            [x[qpc_session.binding_index_map["input_ids"]][1][0] for x in qpc_session.allowed_shapes]
-            + [qpc_session.bindings[qpc_session.binding_index_map["input_ids"]].dims[0]]
-        )
-
-        prefill_seq_len = max(
-            [x[qpc_session.binding_index_map["input_ids"]][1][1] for x in qpc_session.allowed_shapes]
-            + [qpc_session.bindings[qpc_session.binding_index_map["input_ids"]].dims[1]]
-        )
+        self,
+        inputs: torch.Tensor,
+        device_ids: List[int],
+        enable_debug_logs: bool = False,
+        generation_len: int = None,
+        streamer: Optional[TextStreamer] = None,
+    ) -> np.ndarray:
+        qpc_session = QAICInferenceSession(
+            self.qpc_path, device_ids, enable_debug_logs=enable_debug_logs, activate=False
+        )

-        input_len = inputs["attention_mask"].sum(1, keepdims=True)
-        padded_len = inputs["input_ids"].shape[1]
-        num_chunks = -(padded_len // -prefill_seq_len)  # ceil divide without float
-        padded_len = num_chunks * prefill_seq_len  # Convert to a multiple of prompt_len
+        batch_size, ctx_len, fbs = get_compilation_dims(self.qpc_path)

-        if generation_len is None:
-            generation_len = ctx_len - input_len.max()
+        pad_token_id = 1

-        assert generation_len > 0, "generation length should be greater than zero"
-        generated_ids = np.full((batch_size, generation_len + 1), -1)
+        # Skip inputs/outputs
+        qpc_session.skip_buffers(
+            [x for x in qpc_session.input_names + qpc_session.output_names if x.startswith("past_") or x.endswith("_RetainedState")]
+        )

-        # Prepare inputs for prefill
-        start = perf_counter()
+        # Read prompt and ctx len from session
+        batch_size = max(
+            [x[qpc_session.binding_index_map["input_ids"]][1][0] for x in qpc_session.allowed_shapes]
+            + [qpc_session.bindings[qpc_session.binding_index_map["input_ids"]].dims[0]]
+        )

-        inputs["position_ids"] = np.where(
-            inputs.pop("attention_mask"), np.arange(padded_len), -1
-        )  # Need to use -1 as position_ids for invalid tokens
-        inputs = dict(inputs)
+        prefill_seq_len = max(
+            [x[qpc_session.binding_index_map["input_ids"]][1][1] for x in qpc_session.allowed_shapes]
+            + [qpc_session.bindings[qpc_session.binding_index_map["input_ids"]].dims[1]]
+        )

-        # vision_session.deactivate()
-        qpc_session.activate()
+        input_len = inputs["attention_mask"].sum(1, keepdims=True)
+        input_ids_length = inputs["input_ids"].shape[1]

-        # Run prefill
-        for i in range(num_chunks):
-            chunk_inputs = inputs.copy()
-            chunk_inputs["input_ids"] = inputs["input_ids"].numpy()[:, i * prefill_seq_len : (i + 1) * prefill_seq_len]
-            chunk_inputs['pixel_values'] = chunk_inputs['pixel_values'].numpy().astype(np.float16)
-            chunk_inputs["position_ids"] = inputs["position_ids"][:, i * prefill_seq_len : (i + 1) * prefill_seq_len]
-            outputs = qpc_session.run(chunk_inputs)
+        num_chunks = -(input_ids_length // -prefill_seq_len)  # ceil divide without float
+
+        padded_len = num_chunks * prefill_seq_len  # Convert to a multiple of prompt_len
+        if generation_len is None:
+            generation_len = ctx_len - input_len.max()
+
+        assert generation_len > 0, "generation length should be greater than zero"
+        generated_ids = np.full((batch_size, generation_len + 1), pad_token_id)
+
+        # Prepare inputs for prefill
+        prefill_start = perf_counter()
+
+        input_ids = inputs["input_ids"]
+        input_ids_size = input_ids.shape[1]
+        inputs["input_ids"] = torch.nn.functional.pad(
+            inputs["input_ids"],
+            (0, padded_len - input_ids_size),
+            "constant",
+            1,
+        )
+        inputs["attention_mask"] = torch.nn.functional.pad(
+            inputs["attention_mask"],
+            (0, padded_len - input_ids_size), "constant", 0
+        )

-        # Skip inputs/outputs again
-        qpc_session.skip_buffers(
-            ["pixel_values"]
-        )
+        for k, v in inputs.items():
+            inputs[k] = np.array(v)

-        # Get first token
-        inputs = {}
-        inputs["input_ids"] = outputs["logits"].argmax(2)
-        inputs["position_ids"] = input_len.numpy().astype(np.int64)
-        # inputs["cross_attention_mask"] = inputs["cross_attention_mask"][:, -1:, :, :]
-        generated_ids[:, 0] = inputs["input_ids"].squeeze(1)
-        # finished_sequences = inputs["input_ids"] == self.tokenizer.eos_token_id
-        if streamer:
-            streamer.put(inputs["input_ids"][0])
+        inputs["pixel_values"] = inputs["pixel_values"].astype("float16")
+        inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1)

-        # Decode loop
-        loop_start = perf_counter()
-        for num_token in range(1, generation_len):
-            outputs = qpc_session.run(inputs)
+        qpc_session.activate()

-            # Prepare inputs for next iteration
+        # Run prefill
+
+        for i in range(num_chunks):
+            chunk_inputs = inputs.copy()
+            chunk_inputs["input_ids"] = inputs["input_ids"][:, i * prefill_seq_len : (i + 1) * prefill_seq_len]
+            chunk_inputs["position_ids"] = inputs["position_ids"][:, i * prefill_seq_len : (i + 1) * prefill_seq_len]
+            outputs = qpc_session.run(chunk_inputs)
+
+        prefill_time = perf_counter() - prefill_start
+        # Get first token
         inputs["input_ids"] = outputs["logits"].argmax(2)
-            inputs["position_ids"] += 1
-            generated_ids[:, num_token] = inputs["input_ids"].squeeze(1)
+        inputs["position_ids"] = input_len.numpy()
+        generated_ids[:, 0] = inputs["input_ids"].squeeze(1)
         if streamer:
             streamer.put(inputs["input_ids"][0])

-        end = perf_counter()
-        if streamer:
-            streamer.end()
-
-        prefill_perf = 1 / (loop_start - start)
-        decode_perf = (num_token - 1) / (end - loop_start)
-        total_perf = num_token / (end - start)
+        qpc_session.skip_buffers(["pixel_values"])
+        inputs.pop("pixel_values")
+
+        # Decode loop
+        decode_start = perf_counter()
+        for num_token in range(1, generation_len):
+            outputs = qpc_session.run(inputs)
+            # Prepare inputs for next iteration
+            inputs["input_ids"] = outputs["logits"].argmax(2)
+            inputs["position_ids"] += 1
+            generated_ids[:, num_token] = inputs["input_ids"].squeeze(1)
+            if streamer:
+                streamer.put(inputs["input_ids"][0])
+
+        decode_end = perf_counter()
+        if streamer:
+            streamer.end()

-        print("TTFT:", round(loop_start - start, 2), "s", file=sys.stderr)
-        print("E2ET:", round(end - start, 2), "s", file=sys.stderr)
-        print("Prefill:", round(prefill_perf, 2), "tok/s", file=sys.stderr)
-        print("Decode:", round(decode_perf, 2), "tok/s", file=sys.stderr)
-        print("E2E:", round(total_perf, 2), "tok/s", file=sys.stderr)
-        if batch_size > 1:
-            print("Prefill (batch):", round(prefill_perf * batch_size, 2), "tok/s", file=sys.stderr)
-            print("Decode (batch):", round(decode_perf * batch_size, 2), "tok/s", file=sys.stderr)
-            print("E2E (batch):", round(total_perf * batch_size, 2), "tok/s", file=sys.stderr)
-        return generated_ids[:, :generation_len]
+        decode_perf = (num_token - 1) / (decode_end - decode_start)
+        total_time = decode_end - prefill_start
+        total_perf = num_token / total_time
+
+        return CloudAI100ExecInfoNew(
+            batch_size=batch_size,
+            generated_ids=generated_ids,
+            perf_metrics=PerfMetrics(
+                prefill_time=prefill_time,
+                decode_perf=decode_perf,
+                total_perf=total_perf,
+                total_time=total_time
+            )
+        )

     @property
     def model_hash(self) -> str:
@@ -1040,9 +1049,9 @@ class QEFFAutoModelForImageTextToText:
     @classmethod
     def from_pytorch_model(cls, model: nn.Module, kv_offload=False, **kwargs):
         if kv_offload:
-            return QEffAutoModelForImageTextToText2QPC(model, **kwargs)
+            return _QEffAutoModelForImageTextToText2QPC(model, **kwargs)
         else:
-            return QEFFAutoModelForImageTextToText1QPC(model, **kwargs)
+            return _QEFFAutoModelForImageTextToText1QPC(model, **kwargs)


     @classmethod

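To make the prefill bookkeeping in the new cloud_ai_100_generate concrete, here is a small standalone sketch of the ceil-divide chunking and the -1 position-id masking it relies on; the sizes are made up for illustration:

import numpy as np

prefill_seq_len = 32   # compiled prefill sequence length (illustrative)
input_ids_length = 45  # actual, unpadded prompt length (illustrative)

# Ceil-divide without floats: 45 tokens -> 2 chunks of 32.
num_chunks = -(input_ids_length // -prefill_seq_len)
padded_len = num_chunks * prefill_seq_len  # 64

# Valid positions get 0..padded_len-1; padding positions get -1.
attention_mask = np.array([[1] * input_ids_length + [0] * (padded_len - input_ids_length)])
position_ids = np.where(attention_mask, np.arange(padded_len), -1)

assert num_chunks == 2 and padded_len == 64
assert position_ids[0, input_ids_length - 1] == input_ids_length - 1
assert position_ids[0, -1] == -1
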
QEfficient/transformers/models/pytorch_transforms.py

Lines changed: 1 addition & 1 deletion
@@ -377,7 +377,7 @@ class KVCacheModuleMethodMapperTransform(ModuleMethodMapperTransform):
     _match_string_replace_method = {
         "InternVLChatModel": {
             "forward": QEffInternVLModel.forward,
-            "generate_dummy_inputs": QEffInternVLModel.generate_dummy_inputs,
+            "get_dummy_inputs": QEffInternVLModel.get_dummy_inputs,
             "get_specializations": QEffInternVLModel.get_specializations,
             "get_onnx_dynamic_axes": QEffInternVLModel.get_onnx_dynamic_axes,
             "get_output_names": QEffInternVLModel.get_output_names,

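Putting the pieces together, a rough usage sketch of the flow these changes enable. The class and method names (from_pytorch_model, export, compile, generate) appear in the diff above, but the module path, the generate signature, and every argument value below are assumptions for illustration; `model` stands for an already-loaded InternVL checkpoint and `inputs` for its preprocessed dict (input_ids, attention_mask, pixel_values).

# Hypothetical usage sketch; values and argument names not shown in the diff are assumptions.
from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForImageTextToText

qeff_model = QEFFAutoModelForImageTextToText.from_pytorch_model(model, kv_offload=False)
qeff_model.export()                       # internally calls get_dummy_inputs()
qeff_model.compile(batch_size=1, prefill_seq_len=32, ctx_len=512, img_size=448)
exec_info = qeff_model.generate(inputs, device_ids=[0], generation_len=128)
print(exec_info)  # CloudAI100ExecInfoNew: TTFT, decode tok/s, total tok/s, E2E time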