| 1 | +# coding=utf-8 |
| 2 | +# Copyright 2025 The HuggingFace Inc. team |
| 3 | +# |
| 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +# you may not use this file except in compliance with the License. |
| 6 | +# You may obtain a copy of the License at |
| 7 | +# |
| 8 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +# |
| 10 | +# Unless required by applicable law or agreed to in writing, software |
| 11 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +# See the License for the specific language governing permissions and |
| 14 | +# limitations under the License. |
| 15 | +import argparse |
| 16 | +import json |
| 17 | +import os |
1 | 18 | import time |
| 19 | +from typing import Optional |
2 | 20 |
3 | 21 | import datasets |
4 | 22 | import torch |
7 | 25 | from transformers.generation import GenerationConfig |
8 | 26 |
9 | 27 |
10 | | -torch.set_float32_matmul_precision("high") |
| 28 | +MODEL_ID = "Qwen/Qwen3-4B-Instruct-2507" |
11 | 29 |
12 | | -model_id = "meta-llama/Llama-3.2-3b-Instruct" |
13 | | -model = ( |
14 | | - AutoModelForCausalLM.from_pretrained( |
15 | | - model_id, |
16 | | - attn_implementation="paged_attention|kernels-community/flash-attn", |
17 | | - dtype=torch.bfloat16, |
| 30 | + |
| 31 | +def generate_simple( |
| 32 | +    attn_implementation: str, simple_batch_inputs: list[list[int]], generation_config: GenerationConfig |
| 33 | +) -> list[str]: |
| 34 | + attn_implementation = { |
| 35 | + "sdpa_paged": "sdpa", |
| 36 | + "eager_paged": "eager", |
| 37 | + "flash_paged": "flash_attention_2", |
| 38 | + }[attn_implementation] |
| 39 | + |
| 40 | + model = ( |
| 41 | + AutoModelForCausalLM.from_pretrained( |
| 42 | + MODEL_ID, |
| 43 | +        dtype=torch.bfloat16, |
| 44 | + attn_implementation=attn_implementation, |
| 45 | + ) |
| 46 | + .cuda() |
| 47 | + .eval() |
18 | 48 | ) |
19 | | - .eval() |
20 | | - .cuda() |
21 | | -) |
22 | | -tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left") |
23 | | - |
24 | | -generation_config = GenerationConfig( |
25 | | - max_new_tokens=512, |
26 | | - # use_cuda_graph=False, |
27 | | - eos_token_id=tokenizer.eos_token_id, |
28 | | - pad_token_id=tokenizer.pad_token_id, |
29 | | - do_sample=False, |
30 | | -) |
31 | | - |
32 | | -train_dataset = datasets.load_dataset("openai/gsm8k", "socratic", split="test") |
33 | | -train_dataset = train_dataset.select(range(500)) # Use only 5 examples for the simple version |
34 | | -print("--- Running CB Generation Example ---") |
35 | | - |
36 | | - |
37 | | -def tokenize_function(examples): |
38 | | - return tokenizer(examples["question"]) |
39 | | - |
40 | | - |
41 | | -tokenized_datasets = train_dataset.map(tokenize_function, batched=True) |
42 | | -simple_batch_inputs = [item["input_ids"] for item in tokenized_datasets] |
43 | | - |
44 | | -start_time_simple = time.time() |
45 | | -model.forward = torch.compile(model.forward, mode="max-autotune-no-cudagraphs") |
46 | | -batch_outputs = model.generate_batch( |
47 | | - inputs=simple_batch_inputs, |
48 | | - generation_config=generation_config, |
49 | | -) |
50 | | -end_time_simple = time.time() |
51 | | -token_count = 0 |
52 | | -for request in batch_outputs: |
53 | | - input_text = tokenizer.decode(batch_outputs[request].prompt_ids, skip_special_tokens=False) |
| 49 | + |
| 50 | + decoded_outputs = [] |
| 51 | + for input_ids in simple_batch_inputs: |
| 52 | + input_ids = torch.tensor([input_ids]).to("cuda") |
| 53 | + attention_mask = torch.ones_like(input_ids) |
| 54 | + outputs = model.generate(input_ids, attention_mask=attention_mask, generation_config=generation_config) |
| 55 | + generated_tokens = outputs[0][input_ids.shape[1] :] |
| 56 | + decoded_output = tokenizer.decode(generated_tokens, skip_special_tokens=True) |
| 57 | + decoded_outputs.append(decoded_output) |
| 58 | + |
| 59 | + return decoded_outputs |
| 60 | + |
| 61 | + |
| 62 | +def setup_metrics(): |
54 | 63 | try: |
55 | | - output_text = tokenizer.decode(batch_outputs[request].generated_tokens, skip_special_tokens=False) |
56 | | - token_count += len(batch_outputs[request].generated_tokens[1:]) |
| 64 | + from opentelemetry import metrics, trace |
| 65 | + from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter |
| 66 | + from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter |
| 67 | + from opentelemetry.sdk.metrics import MeterProvider |
| 68 | + from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader |
| 69 | + from opentelemetry.sdk.resources import Resource |
| 70 | + from opentelemetry.sdk.trace import TracerProvider |
| 71 | + from opentelemetry.sdk.trace.export import BatchSpanProcessor |
| 72 | + |
| 73 | + resource = Resource.create({"service.name": "transformers"}) |
| 74 | + metrics_exporter = PeriodicExportingMetricReader( |
| 75 | + OTLPMetricExporter( |
| 76 | + endpoint="http://localhost:9090/api/v1/otlp/v1/metrics" |
| 77 | +        ),  # Explicit endpoint; falls back to OTEL_EXPORTER_OTLP_METRICS_ENDPOINT if omitted |
| 78 | + export_interval_millis=1000, |
| 79 | + ) |
| 80 | + meter_provider = MeterProvider(resource=resource, metric_readers=[metrics_exporter]) |
| 81 | + metrics.set_meter_provider(meter_provider) |
| 82 | + trace_exporter = OTLPSpanExporter( |
| 83 | + endpoint="http://localhost:4318/v1/traces" |
| 84 | +    )  # Explicit endpoint; falls back to OTEL_EXPORTER_OTLP_TRACES_ENDPOINT if omitted |
| 85 | + tracer_provider = TracerProvider(resource=resource) |
| 86 | + tracer_provider.add_span_processor(BatchSpanProcessor(trace_exporter)) |
| 87 | + trace.set_tracer_provider(tracer_provider) |
57 | 88 | except Exception as e: |
58 | | - print(f"Decoding failed for request {request}: {e}") |
59 | | - token_count += len(batch_outputs[request].generated_tokens[1:]) |
60 | | - output_text = tokenizer.decode(batch_outputs[request].generated_tokens[1:], skip_special_tokens=False) |
61 | | - if len(output_text) > 0: |
| 89 | + print(f"Error setting up metrics: {e}") |
| 90 | + |
| 91 | + |
| 92 | +def batch_generate( |
| 93 | + model: AutoModelForCausalLM, |
| 94 | + simple_batch_inputs: list, |
| 95 | + generation_config: GenerationConfig, |
| 96 | + tokenizer: AutoTokenizer, |
| 97 | + displayed_samples: int = 0, # -1: no display, 0: display stats, >0: display inputs and some outputs |
| 98 | + output_file: Optional[str] = None, |
| 99 | + expected_outputs: Optional[list[str]] = None, |
| 100 | + slice_inputs: bool = True, |
| 101 | +) -> tuple[float, float]: |
| 102 | + # Actual batch generation |
| 103 | + if displayed_samples >= 0: |
| 104 | + print("--- Running CB Generation Example ---") |
| 105 | + start_time_simple = time.time() |
| 106 | + batch_outputs = model.generate_batch( |
| 107 | + inputs=simple_batch_inputs, |
| 108 | + generation_config=generation_config, |
| 109 | + slice_inputs=slice_inputs, # TODO: move this to the generation config |
| 110 | + ) |
| 111 | + end_time_simple = time.time() |
| 112 | + if displayed_samples >= 0: |
| 113 | + print("Done with batch generation.") |
| 114 | + |
| 115 | + # Decode outputs |
| 116 | + token_count = 0 |
| 117 | + data = [] |
| 118 | + for i, request in enumerate(batch_outputs): |
| 119 | + input_text = tokenizer.decode(batch_outputs[request].prompt_ids, skip_special_tokens=True) |
| 120 | + data.append({"input": input_text}) |
| 121 | + |
| 122 | + # Try to decode the output |
| 123 | + try: |
| 124 | + output_text = tokenizer.decode(batch_outputs[request].generated_tokens, skip_special_tokens=True) |
| 125 | + token_count += len(batch_outputs[request].generated_tokens[1:]) |
| 126 | + data[-1]["output"] = output_text |
| 127 | + except Exception as e: |
| 128 | + print(f"Decoding failed for request {request}: {e}") |
| 129 | + data[-1]["output"] = "__ERROR__" |
| 130 | + continue |
| 131 | + |
| 132 | + # Display sample if asked |
| 133 | + if i < displayed_samples: |
| 134 | + if len(output_text) > 0: |
| 135 | + print("-" * 20) |
| 136 | + print(f"{request} Input: {input_text}") |
| 137 | + print(f"{request} Output: {output_text}") |
| 138 | + else: |
| 139 | + print(f"{request} Input: {input_text}") |
| 140 | + print("[WARN]") |
| 141 | + print(f"{request} Output was empty!") |
| 142 | + |
| 143 | + # Compare with classic generate if asked |
| 144 | + if expected_outputs is not None: |
| 145 | + matches = output_text == expected_outputs[i] |
| 146 | + data[-1]["ref"] = expected_outputs[i] |
| 147 | + data[-1]["matches"] = matches |
| 148 | + print(f"Request {i} matches" if matches else f"Request {i} does NOT match!") |
| 149 | + |
| 150 | + # Compute stats and maybe print them |
| 151 | + gen_time = end_time_simple - start_time_simple |
| 152 | + tok_per_sec = token_count / gen_time |
| 153 | + if displayed_samples >= 0: |
62 | 154 | print("-" * 20) |
63 | | - print(f"{request} Input: {input_text}") |
64 | | - print(f"{request} Output: {output_text}") |
65 | | - else: |
66 | | - print("", end="\r\r\r\r") |
67 | | -print("-" * 20) |
68 | | -print("--- Finished CB Generation Example ---\n\n") |
| 155 | + print("--- Finished CB Generation Example ---\n") |
| 156 | + print(f"CB generation took: {gen_time:.2f} seconds for {token_count} tokens. {tok_per_sec:.2f}tok/s") |
| 157 | + stats = { |
| 158 | + "num_blocks": generation_config.num_blocks, |
| 159 | + "max_batch_tokens": generation_config.max_batch_tokens, |
| 160 | + "gen_time": gen_time, |
| 161 | + "token_count": token_count, |
| 162 | + "tok_per_sec": tok_per_sec, |
| 163 | + } |
69 | 164 |
| 165 | +    # Reorder the data by input and, if an output file is provided, save it along with the stats |
| 166 | + data.sort(key=lambda x: x["input"]) |
| 167 | + data = [stats] + data |
| 168 | + if output_file is not None: |
| 169 | + with open(output_file, "w") as f: |
| 170 | + json.dump(data, f, indent=4) |
70 | 171 |
71 | | -print( |
72 | | - f"CB generation took: {end_time_simple - start_time_simple:.2f} seconds for {token_count} tokens. {token_count / (end_time_simple - start_time_simple)}tok/s" |
73 | | -) |
| 172 | + return gen_time, tok_per_sec |
74 | 173 |
75 | 174 |
76 | | -# train_dataset = train_dataset.select(range(5)) # Use only 5 examples for the simple version |
| 175 | +if __name__ == "__main__": |
| 176 | + # Parse args |
| 177 | + parser = argparse.ArgumentParser() |
| 178 | + parser.add_argument("--num-blocks", "-n", type=int, default=None) |
| 179 | + parser.add_argument("--max-batch-tokens", "-b", type=int, default=None) |
77 | 180 |
78 | | -# tokenized_test_prompts = tokenizer(_TEST_PROMPTS, padding=True, padding_side="left", truncation=True, max_length=512) |
79 | | -# simple_batch_inputs = list(tokenized_test_prompts["input_ids"]) |
| 181 | + parser.add_argument( |
| 182 | + "--attn", type=str, default="paged_attention|kernels-community/flash-attn", help="Attention implementation" |
| 183 | + ) |
| 184 | + parser.add_argument("--matmul-precision", "-mp", type=str, default="high") # set to "none" to disable |
| 185 | + parser.add_argument("--slice-inputs", action="store_true", default=False) |
| 186 | + parser.add_argument("--use-cuda-graph", action="store_true", default=False) |
| 187 | + parser.add_argument("--compile", action="store_true", default=False) |
80 | 188 |
81 | | -# def tokenize_function(examples): |
82 | | -# # Truncate to avoid overly long prompts exceeding max context length |
83 | | -# return tokenizer(examples["question"], padding=True, truncation=True, max_length=512) |
| 189 | + parser.add_argument("--samples", type=int, default=500) |
| 190 | + parser.add_argument("--displayed", type=int, default=0, help="Number of samples to display") |
| 191 | + parser.add_argument("--output-file", type=str, default=None) |
| 192 | + parser.add_argument("--compare", action="store_true", default=False) |
| 193 | + parser.add_argument("--metrics", action="store_true", default=False) |
| 194 | + args = parser.parse_args() |
84 | 195 |
| 196 | + # If turned on, we setup metrics |
| 197 | + if args.metrics: |
| 198 | + setup_metrics() |
85 | 199 |
86 | | -# tokenized_datasets = train_dataset.map(tokenize_function, batched=True) |
87 | | -# simple_batch_inputs = [item["input_ids"] for item in tokenized_datasets] |
| 200 | + # Set matmul precision if not none |
| 201 | + if args.matmul_precision != "none": |
| 202 | + torch.set_float32_matmul_precision(args.matmul_precision) |
88 | 203 |
| 204 | + # Prepare model |
| 205 | + model = AutoModelForCausalLM.from_pretrained( |
| 206 | + MODEL_ID, |
| 207 | + attn_implementation=args.attn, |
| 208 | + dtype=torch.bfloat16, |
| 209 | + ) |
| 210 | + model = model.cuda().eval() |
89 | 211 |
90 | | -# model.config.attn_implementation = "sdpa" |
91 | | -# start_time_simple = time.time() |
92 | | -# batch_size = 64 |
93 | | -# full_outputs = [] |
94 | | -# from tqdm import tqdm |
| 212 | + # If turned on, we compile the model |
| 213 | + if args.compile: |
| 214 | + model.forward = torch.compile(model.forward, mode="max-autotune-no-cudagraphs") |
| 215 | + if args.slice_inputs: |
| 216 | +        assert not args.compile, "Slicing inputs is not compatible with a compiled model" |
| 217 | + assert not args.use_cuda_graph, "Slicing inputs is not compatible with cuda graphs" |
95 | 218 |
96 | | -# for i in tqdm(range(0, len(simple_batch_inputs)-batch_size, batch_size)): |
97 | | -# outputs = model.generate( |
98 | | -# torch.tensor(simple_batch_inputs[i:i+batch_size], device=model.device), |
99 | | -# generation_config=GenerationConfig( |
100 | | -# max_new_tokens=16, eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.pad_token_id |
101 | | -# ), |
102 | | -# ) |
103 | | -# full_outputs.extend(outputs.tolist()) |
| 219 | + # Prepare tokenizer and dataset |
| 220 | + tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, padding_side="left") |
| 221 | + dataset = datasets.load_dataset("openai/gsm8k", "socratic", split="test") |
| 222 | +    dataset = dataset.select(range(args.samples))  # Keep only the first args.samples examples |
| 223 | + tokenized_datasets = dataset.map(lambda x: tokenizer(x["question"]), batched=True) |
| 224 | + simple_batch_inputs = [item["input_ids"] for item in tokenized_datasets] |
104 | 225 |
105 | | -# end_time_simple = time.time() |
106 | | -# print(f"\nSimple batch generation took: {end_time_simple - start_time_simple:.2f} seconds") |
| 226 | + # Prepare generation config |
| 227 | + generation_config = GenerationConfig( |
| 228 | + max_new_tokens=512, |
| 229 | + use_cuda_graph=args.use_cuda_graph, |
| 230 | + eos_token_id=tokenizer.eos_token_id, |
| 231 | + pad_token_id=tokenizer.pad_token_id, |
| 232 | + do_sample=False, |
| 233 | + num_blocks=args.num_blocks, |
| 234 | + max_batch_tokens=args.max_batch_tokens, |
| 235 | + ) |
| 236 | + |
| 237 | + # If we need to compare, we need to generate the reference outputs |
| 238 | + expected_outputs = generate_simple(args.attn, simple_batch_inputs, generation_config) if args.compare else None |
| 239 | + |
| 240 | + # If no output file is provided, we pick a name based on the args |
| 241 | + if args.output_file is None: |
| 242 | + os.makedirs("runs/cb", exist_ok=True) |
| 243 | + attn = args.attn.replace("|", "_").replace("/", "_") |
| 244 | + args.output_file = ( |
| 245 | + f"runs/cb/{args.num_blocks}_{args.max_batch_tokens}_{attn}_{args.matmul_precision}_{args.samples}.json" |
| 246 | + ) |
| 247 | + |
| 248 | + # Run warmup batch generation |
| 249 | + batch_generate( |
| 250 | + model, |
| 251 | + simple_batch_inputs[: min(5, args.samples)], |
| 252 | + generation_config, |
| 253 | + tokenizer, |
| 254 | + displayed_samples=-1, |
| 255 | + slice_inputs=args.slice_inputs, |
| 256 | + ) |
| 257 | + |
| 258 | + # Run batch generation |
| 259 | + gen_time, tok_per_sec = batch_generate( |
| 260 | + model, |
| 261 | + simple_batch_inputs, |
| 262 | + generation_config, |
| 263 | + tokenizer, |
| 264 | + displayed_samples=args.displayed, |
| 265 | + output_file=args.output_file, |
| 266 | + expected_outputs=expected_outputs, |
| 267 | + slice_inputs=args.slice_inputs, |
| 268 | + ) |
107 | 269 |
108 | | -# print("\nResults from simple generate_batch:") |
109 | | -# for i, request in enumerate(full_outputs): |
110 | | -# output_text = tokenizer.decode(request, skip_special_tokens=False) |
111 | | -# print("-" * 20) |
112 | | -# print(f" Output: {output_text}") |
113 | | -# print("-" * 20) |
114 | | -# print("--- Finished Simple Batch Generation Example ---\n\n") |
| 270 | +# Example usage: |
| 271 | +# python examples/pytorch/continuous_batching.py --num-blocks 369 --max-batch-tokens 23 --attn sdpa_paged -mp none --samples 1 --displayed 0 --output-file sliced.json |
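| 272 | +# Further illustrative invocations (flag values are arbitrary examples; --metrics assumes local OTLP endpoints): |
| 273 | +# python examples/pytorch/continuous_batching.py --attn eager_paged --samples 10 --displayed 2 --compare |
| 274 | +# python examples/pytorch/continuous_batching.py --metrics --samples 100 --displayed 5 --output-file full_run.json |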