Update docstrings, Cleanup TextGeneration class
Signed-off-by: quic-suppugun <quic_suppugun@quicinc.com>
quic-suppugun committed Nov 15, 2024
1 parent 8878f4a commit a1832a2
Showing 2 changed files with 109 additions and 91 deletions.
198 changes: 108 additions & 90 deletions QEfficient/generation/text_generation_inference.py
@@ -51,10 +51,7 @@ class CloudAI100ExecInfo:
:batch_size (int): Batch size of the QPC compilation.
:generated_texts (Union[List[List[str]], List[str]]): Generated text(s).
:generated_ids (Union[List[np.ndarray], np.ndarray]): Generated IDs.
:prefill_time (float): Time for prefilling.
:decode_perf (float): Decoding performance.
:total_perf (float): Total performance.
:total_time (float): Total time.
:perf_metrics (PerfMetrics): Performance metrics.
"""

batch_size: int
@@ -158,8 +155,17 @@ def latency_stats_bertstyle(
print(round((cur_len - init_len) / (end - start), 2), "tok/s")


# Read from specializations.json, Relative path computed from qpc path
def get_compilation_dims(qpc_path: str) -> Tuple[int, int]:
"""
Function to fetch compilation dimensions from specializations.json.
Uses qpc path to compute path to specilaizations.json.
Args:
qpc_path (str): Path to directory comprising generated binary file after compilation.
Returns:
:tuple: compilation batch size, compilation context length
"""
qpc_base_path = os.path.dirname(os.path.normpath(qpc_path))
specialization_file_path = os.path.join(qpc_base_path, "specializations.json")
logger.info(f"specialization_file_path : {specialization_file_path}")
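A minimal usage sketch of this helper (illustrative only; the qpc directory path below is an assumed example, not part of this diff):

# Sketch only: the qpc path is an invented example.
batch_size, ctx_len = get_compilation_dims("cache/gpt2/qpc_16cores_1bs_32pl_128cl/qpcs")
print(f"QPC compiled with batch_size={batch_size}, ctx_len={ctx_len}")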
Expand Down Expand Up @@ -237,17 +243,24 @@ def print_latency_stats_kv(prompt, exec_info, automation: bool = False):
print("=====================================================================")

def calculate_latency(total_decoded_tokens, loop_start, start, end, decode_pause_time=0):
"""
Method will calculate the latency metrics using the time loops and based on the total decoded token count.
"""
Method will calculate the latency metrics using the time loops and based on the total decoded token count.
Returns:
total_num_decoded_tokens, prefill_perf, decode_perf, total_perf
"""
prefill_time = loop_start - start + decode_pause_time
decode_perf = (total_decoded_tokens) / (end - loop_start - decode_pause_time)
total_perf = (total_decoded_tokens) / (end - start)
total_time = end - start
return prefill_time, decode_perf, total_perf, total_time
Args:
:total_decoded_tokens (int): Number of tokens generated in decode stage.
:loop_start (float): Start time of decode loop.
:start (float): Start time.
:end (float): End time.
:decode_pause_time (float): Total decode pause time in continuous batching decode stage.
Returns:
:tuple: prefill time, decode performance, total performance, total time
"""
prefill_time = loop_start - start + decode_pause_time
decode_perf = (total_decoded_tokens) / (end - loop_start - decode_pause_time)
total_perf = (total_decoded_tokens) / (end - start)
total_time = end - start
return prefill_time, decode_perf, total_perf, total_time
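A small worked example of the arithmetic above, with invented timestamps (start at 0.0 s, decode loop from 0.5 s to 2.5 s, 100 decoded tokens, no decode pause):

# Sketch only: the timestamps and token count are made-up values.
prefill_time, decode_perf, total_perf, total_time = calculate_latency(
    total_decoded_tokens=100, loop_start=0.5, start=0.0, end=2.5
)
# prefill_time = 0.5 s, decode_perf = 100 / 2.0 = 50 tok/s,
# total_perf = 100 / 2.5 = 40 tok/s, total_time = 2.5 s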

def cloud_ai_100_exec_kv(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
@@ -307,7 +320,7 @@ def cloud_ai_100_exec_kv(
)
if full_batch_size is None:
exec_info = [
generate_text.cloud_ai_100_exec_kv_helper(prompt[i : i + batch_size], generation_len, stream)
generate_text.generate(prompt[i : i + batch_size], generation_len, stream)
for i in range(0, len(prompt), batch_size)
]
prefill_time = np.average([info.perf_metrics.prefill_time for info in exec_info])
@@ -324,12 +337,12 @@ def cloud_ai_100_exec_kv(
perf_metrics=PerfMetrics(prefill_time, decode_perf, total_perf, total_time),
)
else:
exec_info = generate_text.cloud_ai_100_exec_kv_helper(prompt=prompt, generation_len=generation_len)
exec_info = generate_text.generate(prompt=prompt, generation_len=generation_len)

print_latency_stats_kv(prompt, exec_info=exec_info, automation=automation)
return exec_info
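A usage sketch for this top-level entry point; the model card, qpc path, and prompt values are assumptions, and only keyword names that appear in this file (tokenizer, qpc_path, prompt, generation_len) are used:

# Illustrative only: model card, qpc path and prompt are assumed examples;
# the expected prompt format (single string vs. list) is not shown in this hunk.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
exec_info = cloud_ai_100_exec_kv(
    tokenizer=tokenizer,
    qpc_path="cache/gpt2/qpc_16cores_1bs_32pl_128cl/qpcs",
    prompt=["My name is"],
    generation_len=32,
)
print(exec_info.generated_texts)
print(exec_info.perf_metrics.decode_perf, "decode tok/s")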

class TextGenerationBase:
class QEffTextGenerationBase:
def __init__(
self,
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
@@ -365,6 +378,10 @@ def __init__(

self.tokenizer = tokenizer
self._set_tokenizer_params() # set tokenizer params
# Skip past_* KV-cache buffers among the session inputs/outputs
self._session.skip_buffers(
[x for x in self._session.input_names + self._session.output_names if x.startswith("past_")]
)

def _set_tokenizer_params(self):
"""
@@ -484,7 +501,6 @@ def initialize_decode_inputs(self, num_prompts, execution_batch_size, max_gen_le
self.decode_input_ids = np.zeros((execution_batch_size, 1), np.int64)
self.decode_pos_ids = np.zeros((execution_batch_size, 1), np.int64)
self.generation_len = np.zeros((execution_batch_size, 1), np.int64)
return

def update_decode_input(self, outputs, position_ids, generation_len, decode_batch_id=None):
"""
@@ -545,11 +561,6 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i
position_ids (array): The position IDs.
generation_len (int): The generation length.
"""
# Skip inputs/outputs
self._session.skip_buffers(
[x for x in self._session.input_names + self._session.output_names if x.startswith("past_")]
)

# Run prefill
inputs = self.tokenizer(prompt, return_tensors="np", padding=True)
position_ids = inputs["attention_mask"].sum(1, keepdims=True)
@@ -661,7 +672,7 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
generated_id_current_index[decode_batch_id] += 1
return decode_pause_time

def run_decode(self, decode_inputs, generation_len, stream):
def run_decode(self, decode_inputs, generation_len, streamer: Optional[transformers.TextStreamer] = None):
"""
Default method for running decode. Executes the decoding process for a given set of inputs and a specified generation length.
@@ -670,16 +681,14 @@ def run_decode(self, decode_inputs, generation_len, stream):
Args:
decode_inputs (dict): The initial inputs for decoding. This should be a dictionary containing 'input_ids' and 'position_ids'.
generation_len (int): Max allowed length for generating tokens. The decoding process will be terminated when generation length is reached.
streamer (transformers.TextStreamer): TextStreamer object to print decoded tokens to console.
Returns:
num_token (int): The number of tokens processed in the decoding process.
"""
finished_sequences = decode_inputs["input_ids"] == self.tokenizer.eos_token_id
if stream:
streamer = transformers.TextStreamer(self.tokenizer)
num_token = 0
for num_token in range(1, generation_len):
if stream:
if streamer:
streamer.put(decode_inputs["input_ids"][0])
outputs = self._session.run(decode_inputs)

@@ -741,7 +750,7 @@ def __init__(
enable_debug_logs: bool = False,
write_io_dir: Optional[str] = None,
) -> None:
self._qaic_model = TextGenerationBase(tokenizer,
self._qaic_model = QEffTextGenerationBase(tokenizer,
qpc_path,
full_batch_size,
ctx_len,
@@ -750,48 +759,58 @@ def __init__(
write_io_dir)
self._full_batch_size = self._qaic_model.full_batch_size
self._tokenizer = self._qaic_model.tokenizer
self._perf_metrics: Optional[PerfMetrics] = None
self._prompt_queue = None
self._ctx_len = ctx_len
self._perf_metrics = None
self._prompt_queue = None
self._text_streamer = None

@property
def perf_metrics(self):
return self._perf_metrics

def setup(self, prompt: List[str], generation_len):
def _setup_model_execution_inputs(self, prompt: List[str], generation_len: Optional[int] = None):
"""
This method should be called to set/reset inputs
Args:
:prompt (List[str]): prompts for the model text generation
:generation_len (Optional[int], optional): Number of tokens to be generated.
"""
execution_batch_size = self._full_batch_size if self._full_batch_size is not None else self._qaic_model.batch_size
max_gen_length = self._ctx_len if not generation_len else max(self._ctx_len, generation_len)

# Create a prompt queue.
self._prompt_queue = deque(prompt)
# initialize np arrays for storing the prefill output for all the decode batch size.
# Initialize np arrays for storing the prefill output for all the decode batch size.
num_prompts = len(self._prompt_queue)

self._qaic_model.initialize_decode_inputs(num_prompts, execution_batch_size, max_gen_length)
return

def regular_model_execution(self, prompt: List[str], generation_len: Optional[int] = None, stream: bool = True):
def _regular_model_execution(self, prompt: List[str], generation_len: Optional[int] = None, stream: Optional[bool] = True):
"""
Executes the model in regular mode.
This method runs the prefill, prepares the decode inputs, and then runs the decode. The generated texts are decoded and optionally streamed. Latency metrics are calculated and returned.
Args:
:prompt (str): Sample prompt for the model text generation.
:generation_len (int): Number of tokens to be generated.
:prompt (List[str]): The list of prompts for the model.
:generation_len (Optional[int], optional): The generation length.
:stream (Optional[bool], optional): Boolean flag to enable stream output to console.
Returns:
:tuple: A tuple containing prefill time, decode performance, total performance, total time and generated_texts.
:tuple: A tuple containing performance metrics and generated texts.
"""
self._setup_model_execution_inputs(prompt, generation_len)
if stream and self._text_streamer is None:
self._text_streamer = transformers.TextStreamer(self._tokenizer)
start = perf_counter()
outputs, position_ids, generation_len = self._qaic_model.run_prefill(
prompt, generation_len, prefill_logit_bs=self._qaic_model.batch_size
)
self._qaic_model.update_decode_input(outputs, position_ids, generation_len)

decode_inputs = self._qaic_model.prepare_decode_inputs()

loop_start = perf_counter() # Start decode loop timer
num_token = self._qaic_model.run_decode(decode_inputs, generation_len, stream)
num_token = self._qaic_model.run_decode(decode_inputs, generation_len, self._text_streamer)
end = perf_counter()
generated_texts = self._tokenizer.batch_decode(self._qaic_model.generated_ids, skip_special_tokens=True)

@@ -802,19 +821,56 @@ def regular_model_execution(self, prompt: List[str], generation_len: Optional[in
self._perf_metrics = PerfMetrics(prefill_time, decode_perf, total_perf, total_time)
return self._perf_metrics, generated_texts

def regular_model_execution_stream_tokens(self, prompt: List[str], generation_len: Optional[int] = None):
def _continuous_batching_execution(self, prompt: List[str], generation_len: Optional[int] = None):
"""
Executes the model in regular mode.
Executes the model using continuous batching.
This method handles the execution of the model when continuous batching is enabled. It runs the prefill step for all inputs, performs continuous batching decode, and then decodes the generated texts. The texts are optionally streamed. Latency metrics are calculated and returned.
Args:
:prompt (List[str]): The list of prompts for the model.
:generation_len (Optional[int], optional): The generation length.
Returns:
:tuple: A tuple containing performance metrics and generated texts.
"""
self._setup_model_execution_inputs(prompt, generation_len)
self._qaic_model.batch_index = np.arange(self._full_batch_size).reshape(-1, 1)
start = perf_counter()
self._qaic_model.run_prefill_for_all_inputs(self._prompt_queue, generation_len)

loop_start = perf_counter() # Start decode loop timer
decode_pause_time = self._qaic_model.run_continuous_batching_decode(self._prompt_queue, generation_len)
end = perf_counter()

generated_texts = self._tokenizer.batch_decode(self._qaic_model.generated_ids, skip_special_tokens=True)

total_decode_tokens = sum(
np.sum(self._qaic_model.generated_ids[i] != self._tokenizer.pad_token_id) - 1 for i in range(len(prompt))
)
prefill_time, decode_perf, total_perf, total_time = calculate_latency(
total_decode_tokens, loop_start, start, end, decode_pause_time
)
prefill_time /= len(prompt) # Average prefill time for continuous batching
self._perf_metrics = PerfMetrics(prefill_time, decode_perf, total_perf, total_time)
return self._perf_metrics, generated_texts
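Continuous batching is selected by constructing the runner with a full_batch_size (the QPC must have been compiled accordingly); a sketch with assumed constructor arguments:

# Sketch only: tokenizer, qpc path and sizes are assumed example values.
tg = TextGeneration(tokenizer=tokenizer, qpc_path="qpcs/", ctx_len=256, full_batch_size=4)
exec_info = tg.generate(prompt=["p0", "p1", "p2", "p3", "p4", "p5"], generation_len=32)
# With full_batch_size set, generate() routes to _continuous_batching_execution:
# prompts are queued, prefilled, and decoded with continuous batching, and the
# reported prefill time is an average over the prompts.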

def generate_stream_tokens(self, prompt: List[str], generation_len: Optional[int] = None):
"""
Executes the model for a given list of prompts and a specified generation length.
This method runs the prefill, prepares the decode inputs, and then runs the decode. The tokens are decoded and streamed as they are generated. Latency metrics are calculated and can be retrieved after all tokens are streamed.
Args:
:prompt (str): Sample prompt for the model text generation.
:generation_len (int): Number of tokens to be generated.
:prompt (List[str]): The list of prompts for the model.
:generation_len (Optional[int], optional): The generation length.
Returns:
:tuple: A list containing decoded tokens corresponding to each index of batch size.
Yields:
:list: A list containing decoded tokens corresponding to each index of batch size.
"""
if self._full_batch_size is not None:
raise NotImplementedError("Streaming tokens is currently unavailable for continuous batch execution.")
self._setup_model_execution_inputs(prompt, generation_len)
start = perf_counter()
outputs, position_ids, generation_len = self._qaic_model.run_prefill(
prompt, generation_len, prefill_logit_bs=self._qaic_model.batch_size
@@ -839,63 +895,25 @@ def regular_model_execution_stream_tokens(self, prompt: List[str], generation_le
)
self._perf_metrics = PerfMetrics(prefill_time, decode_perf, total_perf, total_time)
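A usage sketch for the streaming generator; the constructor arguments are assumed, while the method name, arguments, and yielded values follow the signature and docstring above:

# Sketch only: tokenizer and qpc path are assumed example values.
tg = TextGeneration(tokenizer=tokenizer, qpc_path="qpcs/", ctx_len=128)
for decoded_tokens in tg.generate_stream_tokens(prompt=["My name is"], generation_len=20):
    # decoded_tokens[i] holds the newly decoded text for batch index i
    print(decoded_tokens[0], end="", flush=True)
# Metrics are available once streaming finishes; the field name is assumed from
# the PerfMetrics constructor arguments.
print("\n", tg.perf_metrics.total_perf, "tok/s")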

def continuous_batching_execution(self, prompt: List[str], prompt_queue: Deque[List[str]], generation_len: Optional[int] = None):
"""
Executes the model using continuous batching.
This method handles the execution of the model when continuous batching is enabled. It runs the prefill step for all inputs, performs continuous batching decode, and then decodes the generated texts. The texts are optionally streamed. Latency metrics are calculated and returned.
Args:
:prompt (list): prompts for the model text generation
:prompt_queue (list): Queue of prompts for the model text generation.
:generation_len (int): Number of tokens to be generated.
Returns:
:tuple: A tuple containing prefill time, decode performance, total performance, total time and generated_texts.
"""
self._qaic_model.batch_index = np.arange(self._full_batch_size).reshape(-1, 1)
start = perf_counter()
self._qaic_model.run_prefill_for_all_inputs(prompt_queue, generation_len)

loop_start = perf_counter() # Start decode loop timer
decode_pause_time = self._qaic_model.run_continuous_batching_decode(prompt_queue, generation_len)
end = perf_counter()

generated_texts = self._tokenizer.batch_decode(self._qaic_model.generated_ids, skip_special_tokens=True)

total_decode_tokens = sum(
np.sum(self._qaic_model.generated_ids[i] != self._tokenizer.pad_token_id) - 1 for i in range(len(prompt))
)
prefill_time, decode_perf, total_perf, total_time = calculate_latency(
total_decode_tokens, loop_start, start, end, decode_pause_time
)
prefill_time /= len(prompt) # Average prefill time for continuous batching
self._perf_metrics = PerfMetrics(prefill_time, decode_perf, total_perf, total_time)
return self._perf_metrics, generated_texts

def cloud_ai_100_exec_kv_helper(self, prompt: List[str], generation_len: Optional[int] = None, stream: bool = True):
def generate(self, prompt: List[str], generation_len: Optional[int] = None, stream: bool = True):
"""
Executes the model for a given list of prompts and a specified generation length.
Args:
prompt (List[str]): The list of prompts for the model.
generation_len (Optional[int], optional): The generation length.
stream (Optional[bool], optional): Boolean flag to enable stream output to console.
Returns:
latency_stats (tuple): A tuple containing the generated texts, prefill time, decode performance, total
performance, and total time.
latency_stats (tuple): A tuple containing the generated texts and performance metrics.
"""
self.setup(prompt, generation_len)

if self._full_batch_size is not None:
logger.warning("Streamer is currently unavailable for continuous batch execution.")
perf_metrics, generated_texts = self.continuous_batching_execution(
prompt, self._prompt_queue, generation_len
)
perf_metrics, generated_texts = self._continuous_batching_execution(prompt, generation_len)
else:
if stream:
print("\nPrompt : " + prompt[0] + "\nCompletion :", flush=True, end="")
perf_metrics, generated_texts = self.regular_model_execution(
prompt, generation_len, stream
)
perf_metrics, generated_texts = self._regular_model_execution(prompt, generation_len, stream)

if stream:
stream_start = 0 if self._full_batch_size else 1
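For existing callers, one of the central changes in this file is the rename of the public helper; a migration sketch (object construction elided, arguments as used in the run_utils change below):

# Before this commit:
#   exec_info = text_generator.cloud_ai_100_exec_kv_helper(prompt=prompts, generation_len=32, stream=False)
# After this commit:
exec_info = text_generator.generate(prompt=prompts, generation_len=32, stream=False)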
2 changes: 1 addition & 1 deletion QEfficient/utils/run_utils.py
@@ -231,7 +231,7 @@ def run_kv_model_on_cloud_ai_100(self, qpc_path, device_group=None):
device_id=device_group,
ctx_len=self.input_handler.ctx_len,
full_batch_size=self.input_handler.full_batch_size,
).cloud_ai_100_exec_kv_helper(
).generate(
prompt=self.input_handler.prompt,
generation_len=self.gen_len,
stream=False
