
Commit 6a9b384

Merge branch 'main' of github.com:huggingface/transformers into remove-group-bs
2 parents 8af3af1 + 58cebc8

20 files changed, +2256 -1597 lines changed

examples/metrics-monitoring/README.md

Lines changed: 37 additions & 0 deletions

## Continuous Batching Metrics in Transformers

To set up metric monitoring with continuous batching, you will want to have Tempo and Prometheus running.

For this, we provide a Docker Compose file in `examples/metrics-monitoring`.

To run it:

```sh
cd examples/metrics-monitoring
docker compose up
```

Then, in the script running CB, you will need to create a `MeterProvider` and a `TracerProvider` as follows:

```py
from opentelemetry import metrics, trace
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.metrics import MeterProvider
from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor

resource = Resource.create({"service.name": "transformers"})

metrics_exporter = PeriodicExportingMetricReader(
    OTLPMetricExporter(endpoint="http://localhost:9090/api/v1/otlp/v1/metrics"),  # Uses OTEL_EXPORTER_OTLP_METRICS_ENDPOINT env var
    export_interval_millis=1000,
)
meter_provider = MeterProvider(resource=resource, metric_readers=[metrics_exporter])
metrics.set_meter_provider(meter_provider)

trace_exporter = OTLPSpanExporter(endpoint="http://localhost:4318/v1/traces")  # Uses OTEL_EXPORTER_OTLP_TRACES_ENDPOINT env var
tracer_provider = TracerProvider(resource=resource)
tracer_provider.add_span_processor(BatchSpanProcessor(trace_exporter))
trace.set_tracer_provider(tracer_provider)
```
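
Once these providers are registered, any continuous batching run in the same process exports its metrics and traces. Below is a minimal sketch of such a run, assuming a CUDA machine; the model ID, prompts, and generation settings are placeholders, while `generate_batch` and the paged-attention `attn_implementation` string follow the benchmarking script updated in this commit:

```py
import torch

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig

# Assumes the MeterProvider / TracerProvider from the snippet above were set up first.
model_id = "Qwen/Qwen3-4B-Instruct-2507"  # placeholder; any model supported by continuous batching works
tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    attn_implementation="paged_attention|kernels-community/flash-attn",
    dtype=torch.bfloat16,
).cuda().eval()

generation_config = GenerationConfig(
    max_new_tokens=128,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
    do_sample=False,
)

# Continuous batching entry point: metrics and traces are emitted while this runs.
prompts = ["What is 17 + 25?", "Name a prime number greater than 10."]
inputs = [tokenizer(p)["input_ids"] for p in prompts]
batch_outputs = model.generate_batch(inputs=inputs, generation_config=generation_config)
for request_id in batch_outputs:
    print(tokenizer.decode(batch_outputs[request_id].generated_tokens, skip_special_tokens=True))
```

The exporters push to the Prometheus and Tempo endpoints configured above, so no further instrumentation should be needed in the script.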

examples/pytorch/continuous_batching.py

Lines changed: 244 additions & 87 deletions

```diff
@@ -1,4 +1,22 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+import json
+import os
 import time
+from typing import Optional
 
 import datasets
 import torch
@@ -7,108 +25,247 @@
 from transformers.generation import GenerationConfig
 
 
-torch.set_float32_matmul_precision("high")
+MODEL_ID = "Qwen/Qwen3-4B-Instruct-2507"
 
-model_id = "meta-llama/Llama-3.2-3b-Instruct"
-model = (
-    AutoModelForCausalLM.from_pretrained(
-        model_id,
-        attn_implementation="paged_attention|kernels-community/flash-attn",
-        dtype=torch.bfloat16,
+
+def generate_simple(
+    attn_implementation: str, simple_batch_inputs: list[int], generation_config: GenerationConfig
+) -> list[str]:
+    attn_implementation = {
+        "sdpa_paged": "sdpa",
+        "eager_paged": "eager",
+        "flash_paged": "flash_attention_2",
+    }[attn_implementation]
+
+    model = (
+        AutoModelForCausalLM.from_pretrained(
+            MODEL_ID,
+            torch_dtype=torch.bfloat16,
+            attn_implementation=attn_implementation,
+        )
+        .cuda()
+        .eval()
     )
-    .eval()
-    .cuda()
-)
-tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
-
-generation_config = GenerationConfig(
-    max_new_tokens=512,
-    # use_cuda_graph=False,
-    eos_token_id=tokenizer.eos_token_id,
-    pad_token_id=tokenizer.pad_token_id,
-    do_sample=False,
-)
-
-train_dataset = datasets.load_dataset("openai/gsm8k", "socratic", split="test")
-train_dataset = train_dataset.select(range(500))  # Use only 5 examples for the simple version
-print("--- Running CB Generation Example ---")
-
-
-def tokenize_function(examples):
-    return tokenizer(examples["question"])
-
-
-tokenized_datasets = train_dataset.map(tokenize_function, batched=True)
-simple_batch_inputs = [item["input_ids"] for item in tokenized_datasets]
-
-start_time_simple = time.time()
-model.forward = torch.compile(model.forward, mode="max-autotune-no-cudagraphs")
-batch_outputs = model.generate_batch(
-    inputs=simple_batch_inputs,
-    generation_config=generation_config,
-)
-end_time_simple = time.time()
-token_count = 0
-for request in batch_outputs:
-    input_text = tokenizer.decode(batch_outputs[request].prompt_ids, skip_special_tokens=False)
+
+    decoded_outputs = []
+    for input_ids in simple_batch_inputs:
+        input_ids = torch.tensor([input_ids]).to("cuda")
+        attention_mask = torch.ones_like(input_ids)
+        outputs = model.generate(input_ids, attention_mask=attention_mask, generation_config=generation_config)
+        generated_tokens = outputs[0][input_ids.shape[1] :]
+        decoded_output = tokenizer.decode(generated_tokens, skip_special_tokens=True)
+        decoded_outputs.append(decoded_output)
+
+    return decoded_outputs
+
+
+def setup_metrics():
     try:
-        output_text = tokenizer.decode(batch_outputs[request].generated_tokens, skip_special_tokens=False)
-        token_count += len(batch_outputs[request].generated_tokens[1:])
+        from opentelemetry import metrics, trace
+        from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
+        from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
+        from opentelemetry.sdk.metrics import MeterProvider
+        from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
+        from opentelemetry.sdk.resources import Resource
+        from opentelemetry.sdk.trace import TracerProvider
+        from opentelemetry.sdk.trace.export import BatchSpanProcessor
+
+        resource = Resource.create({"service.name": "transformers"})
+        metrics_exporter = PeriodicExportingMetricReader(
+            OTLPMetricExporter(
+                endpoint="http://localhost:9090/api/v1/otlp/v1/metrics"
+            ),  # Uses OTEL_EXPORTER_OTLP_METRICS_ENDPOINT env var
+            export_interval_millis=1000,
+        )
+        meter_provider = MeterProvider(resource=resource, metric_readers=[metrics_exporter])
+        metrics.set_meter_provider(meter_provider)
+        trace_exporter = OTLPSpanExporter(
+            endpoint="http://localhost:4318/v1/traces"
+        )  # Uses OTEL_EXPORTER_OTLP_TRACES_ENDPOINT env var
+        tracer_provider = TracerProvider(resource=resource)
+        tracer_provider.add_span_processor(BatchSpanProcessor(trace_exporter))
+        trace.set_tracer_provider(tracer_provider)
     except Exception as e:
-        print(f"Decoding failed for request {request}: {e}")
-        token_count += len(batch_outputs[request].generated_tokens[1:])
-        output_text = tokenizer.decode(batch_outputs[request].generated_tokens[1:], skip_special_tokens=False)
-    if len(output_text) > 0:
+        print(f"Error setting up metrics: {e}")
+
+
+def batch_generate(
+    model: AutoModelForCausalLM,
+    simple_batch_inputs: list,
+    generation_config: GenerationConfig,
+    tokenizer: AutoTokenizer,
+    displayed_samples: int = 0,  # -1: no display, 0: display stats, >0: display inputs and some outputs
+    output_file: Optional[str] = None,
+    expected_outputs: Optional[list[str]] = None,
+    slice_inputs: bool = True,
+) -> tuple[float, float]:
+    # Actual batch generation
+    if displayed_samples >= 0:
+        print("--- Running CB Generation Example ---")
+    start_time_simple = time.time()
+    batch_outputs = model.generate_batch(
+        inputs=simple_batch_inputs,
+        generation_config=generation_config,
+        slice_inputs=slice_inputs,  # TODO: move this to the generation config
+    )
+    end_time_simple = time.time()
+    if displayed_samples >= 0:
+        print("Done with batch generation.")
+
+    # Decode outputs
+    token_count = 0
+    data = []
+    for i, request in enumerate(batch_outputs):
+        input_text = tokenizer.decode(batch_outputs[request].prompt_ids, skip_special_tokens=True)
+        data.append({"input": input_text})
+
+        # Try to decode the output
+        try:
+            output_text = tokenizer.decode(batch_outputs[request].generated_tokens, skip_special_tokens=True)
+            token_count += len(batch_outputs[request].generated_tokens[1:])
+            data[-1]["output"] = output_text
+        except Exception as e:
+            print(f"Decoding failed for request {request}: {e}")
+            data[-1]["output"] = "__ERROR__"
+            continue
+
+        # Display sample if asked
+        if i < displayed_samples:
+            if len(output_text) > 0:
+                print("-" * 20)
+                print(f"{request} Input: {input_text}")
+                print(f"{request} Output: {output_text}")
+            else:
+                print(f"{request} Input: {input_text}")
+                print("[WARN]")
+                print(f"{request} Output was empty!")
+
+        # Compare with classic generate if asked
+        if expected_outputs is not None:
+            matches = output_text == expected_outputs[i]
+            data[-1]["ref"] = expected_outputs[i]
+            data[-1]["matches"] = matches
+            print(f"Request {i} matches" if matches else f"Request {i} does NOT match!")
+
+    # Compute stats and maybe print them
+    gen_time = end_time_simple - start_time_simple
+    tok_per_sec = token_count / gen_time
+    if displayed_samples >= 0:
         print("-" * 20)
-        print(f"{request} Input: {input_text}")
-        print(f"{request} Output: {output_text}")
-    else:
-        print("", end="\r\r\r\r")
-print("-" * 20)
-print("--- Finished CB Generation Example ---\n\n")
+        print("--- Finished CB Generation Example ---\n")
+        print(f"CB generation took: {gen_time:.2f} seconds for {token_count} tokens. {tok_per_sec:.2f}tok/s")
+    stats = {
+        "num_blocks": generation_config.num_blocks,
+        "max_batch_tokens": generation_config.max_batch_tokens,
+        "gen_time": gen_time,
+        "token_count": token_count,
+        "tok_per_sec": tok_per_sec,
+    }
 
+    # If an output file is provided, save the reordered data to it
+    data.sort(key=lambda x: x["input"])
+    data = [stats] + data
+    if output_file is not None:
+        with open(output_file, "w") as f:
+            json.dump(data, f, indent=4)
 
-print(
-    f"CB generation took: {end_time_simple - start_time_simple:.2f} seconds for {token_count} tokens. {token_count / (end_time_simple - start_time_simple)}tok/s"
-)
+    return gen_time, tok_per_sec
 
 
-# train_dataset = train_dataset.select(range(5)) # Use only 5 examples for the simple version
+if __name__ == "__main__":
+    # Parse args
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--num-blocks", "-n", type=int, default=None)
+    parser.add_argument("--max-batch-tokens", "-b", type=int, default=None)
 
-# tokenized_test_prompts = tokenizer(_TEST_PROMPTS, padding=True, padding_side="left", truncation=True, max_length=512)
-# simple_batch_inputs = list(tokenized_test_prompts["input_ids"])
+    parser.add_argument(
+        "--attn", type=str, default="paged_attention|kernels-community/flash-attn", help="Attention implementation"
+    )
+    parser.add_argument("--matmul-precision", "-mp", type=str, default="high")  # set to "none" to disable
+    parser.add_argument("--slice-inputs", action="store_true", default=False)
+    parser.add_argument("--use-cuda-graph", action="store_true", default=False)
+    parser.add_argument("--compile", action="store_true", default=False)
 
-# def tokenize_function(examples):
-#     # Truncate to avoid overly long prompts exceeding max context length
-#     return tokenizer(examples["question"], padding=True, truncation=True, max_length=512)
+    parser.add_argument("--samples", type=int, default=500)
+    parser.add_argument("--displayed", type=int, default=0, help="Number of samples to display")
+    parser.add_argument("--output-file", type=str, default=None)
+    parser.add_argument("--compare", action="store_true", default=False)
+    parser.add_argument("--metrics", action="store_true", default=False)
+    args = parser.parse_args()
 
+    # If turned on, we set up metrics
+    if args.metrics:
+        setup_metrics()
 
-# tokenized_datasets = train_dataset.map(tokenize_function, batched=True)
-# simple_batch_inputs = [item["input_ids"] for item in tokenized_datasets]
+    # Set matmul precision if not none
+    if args.matmul_precision != "none":
+        torch.set_float32_matmul_precision(args.matmul_precision)
 
+    # Prepare model
+    model = AutoModelForCausalLM.from_pretrained(
+        MODEL_ID,
+        attn_implementation=args.attn,
+        dtype=torch.bfloat16,
+    )
+    model = model.cuda().eval()
 
-# model.config.attn_implementation = "sdpa"
-# start_time_simple = time.time()
-# batch_size = 64
-# full_outputs = []
-# from tqdm import tqdm
+    # If turned on, we compile the model
+    if args.compile:
+        model.forward = torch.compile(model.forward, mode="max-autotune-no-cudagraphs")
+    if args.slice_inputs:
+        assert not args.compile, "Slicing inputs is not compatible with a compiled model"
+        assert not args.use_cuda_graph, "Slicing inputs is not compatible with cuda graphs"
 
-# for i in tqdm(range(0, len(simple_batch_inputs)-batch_size, batch_size)):
-#     outputs = model.generate(
-#         torch.tensor(simple_batch_inputs[i:i+batch_size], device=model.device),
-#         generation_config=GenerationConfig(
-#             max_new_tokens=16, eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.pad_token_id
-#         ),
-#     )
-#     full_outputs.extend(outputs.tolist())
+    # Prepare tokenizer and dataset
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, padding_side="left")
+    dataset = datasets.load_dataset("openai/gsm8k", "socratic", split="test")
+    dataset = dataset.select(range(args.samples))  # Only keep the requested number of samples
+    tokenized_datasets = dataset.map(lambda x: tokenizer(x["question"]), batched=True)
+    simple_batch_inputs = [item["input_ids"] for item in tokenized_datasets]
 
-# end_time_simple = time.time()
-# print(f"\nSimple batch generation took: {end_time_simple - start_time_simple:.2f} seconds")
+    # Prepare generation config
+    generation_config = GenerationConfig(
+        max_new_tokens=512,
+        use_cuda_graph=args.use_cuda_graph,
+        eos_token_id=tokenizer.eos_token_id,
+        pad_token_id=tokenizer.pad_token_id,
+        do_sample=False,
+        num_blocks=args.num_blocks,
+        max_batch_tokens=args.max_batch_tokens,
+    )
+
+    # If we need to compare, we need to generate the reference outputs
+    expected_outputs = generate_simple(args.attn, simple_batch_inputs, generation_config) if args.compare else None
+
+    # If no output file is provided, we pick a name based on the args
+    if args.output_file is None:
+        os.makedirs("runs/cb", exist_ok=True)
+        attn = args.attn.replace("|", "_").replace("/", "_")
+        args.output_file = (
+            f"runs/cb/{args.num_blocks}_{args.max_batch_tokens}_{attn}_{args.matmul_precision}_{args.samples}.json"
+        )
+
+    # Run warmup batch generation
+    batch_generate(
+        model,
+        simple_batch_inputs[: min(5, args.samples)],
+        generation_config,
+        tokenizer,
+        displayed_samples=-1,
+        slice_inputs=args.slice_inputs,
+    )
+
+    # Run batch generation
+    gen_time, tok_per_sec = batch_generate(
+        model,
+        simple_batch_inputs,
+        generation_config,
+        tokenizer,
+        displayed_samples=args.displayed,
+        output_file=args.output_file,
+        expected_outputs=expected_outputs,
+        slice_inputs=args.slice_inputs,
+    )
 
-# print("\nResults from simple generate_batch:")
-# for i, request in enumerate(full_outputs):
-#     output_text = tokenizer.decode(request, skip_special_tokens=False)
-#     print("-" * 20)
-#     print(f" Output: {output_text}")
-#     print("-" * 20)
-#     print("--- Finished Simple Batch Generation Example ---\n\n")
+    # Example usage:
+    # python examples/pytorch/continuous_batching.py --num-blocks 369 --max-batch-tokens 23 --attn sdpa_paged -mp none --samples 1 --displayed 0 --output-file sliced.json
```
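
For reference, the flags defined above can be combined into a single benchmarking run. A hypothetical invocation (the flag names come from the argparse setup in the script; the numeric values are arbitrary examples) could look like:

```sh
# Hypothetical run: flag names are from the script above, values are arbitrary.
python examples/pytorch/continuous_batching.py \
    --attn sdpa_paged \
    --num-blocks 2048 \
    --max-batch-tokens 512 \
    --samples 100 \
    --slice-inputs \
    --compare \
    --metrics \
    --output-file runs/cb/sdpa_paged_compare.json
```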
