
Commit 8138032

Refactor repo (#17)
1 parent 43570b6 commit 8138032

File tree

14 files changed: +350 −400 lines changed

src/main.py

Lines changed: 129 additions & 20 deletions
@@ -1,41 +1,150 @@
+import contextlib
+import gc
+import time
+from argparse import ArgumentParser, Namespace
 from typing import List, Optional
 
-from src.pipelines import get_pipeline_class
-from src.utils.arguments import parse_args
-from src.utils.benchmark import benchmark_end_to_end
-from src.utils.input import get_dummy_batch
-from src.utils.logging import configure_logging
+import torch
+
+from src.pipeline import get_pipeline_class
+from src.profile import get_profiler, logger
+from src.utils import (
+    configure_logging,
+    format_mib,
+    format_ms,
+    get_dummy_batch,
+    log_dict,
+    log_rank_n,
+    parse_config_args,
+)
+
+
+def get_arg_parser() -> ArgumentParser:
+    parser = ArgumentParser()
+
+    # Model
+    parser.add_argument("--model_type")
+    parser.add_argument("--pretrained_config")
+    parser.add_argument("--pretrained_model")
+    parser.add_argument("--tokenizer", default="gpt2")
+    parser.add_argument("--trust_remote_code", action="store_true")
+    parser.add_argument("config_args", nargs="*")
+
+    # Runtime
+    parser.add_argument("--pipeline_class", default="HF_Pipeline")
+    parser.add_argument("--device", default="cuda", type=torch.device)
+    parser.add_argument("--dtype", default="float16", type=lambda x: getattr(torch, x))
+    parser.add_argument("--local_rank", type=int)
+    parser.add_argument("--no_fast_init", dest="fast_init", action="store_false")
+
+    # Input and output
+    parser.add_argument("--batch_size", default=1, type=int)
+    parser.add_argument("--max_input_length", default=-1, type=int)
+    parser.add_argument("--max_new_tokens", default=100, type=int)
+
+    # Cleanup
+    parser.add_argument("--clear_every_run", action="store_true")
+
+    # Benchmark cycles
+    parser.add_argument("--skip", type=int, default=1)
+    parser.add_argument("--warmup", type=int, default=None)
+    parser.add_argument("--cycles", type=int, default=5)
+
+    # Profiling and logging
+    parser.add_argument("--max_log_outputs", default=None, type=int)
+    parser.add_argument("--profile", action="store_true")
+    parser.add_argument("--full_trace", action="store_true")
+    parser.add_argument("--show_op_names", action="store_true")
+
+    return parser
 
 
 def main(argv: Optional[List[str]] = None) -> None:
-    args = parse_args(argv=argv)
+    parser = get_arg_parser()
+    args = parser.parse_args(argv)
+    config_args = parse_config_args(args.config_args)
+    generate_kwargs = {"max_new_tokens": args.max_new_tokens, "do_sample": False}
+    inputs = get_dummy_batch(args.batch_size, args.max_input_length)
+    warmup = args.profile if args.warmup is None else args.warmup
+    max_log_outputs = args.batch_size if args.max_log_outputs is None else args.max_log_outputs
 
     pipeline_class = get_pipeline_class(args.pipeline_class)
     pipeline = pipeline_class(
         model_type=args.model_type,
         pretrained_model=args.pretrained_model,
         pretrained_config=args.pretrained_config,
-        config_args=args.config_args,
+        config_args=config_args,
         tokenizer=args.tokenizer,
         device=args.device,
         dtype=args.dtype,
         fast_init=args.fast_init,
         trust_remote_code=args.trust_remote_code,
     )
 
-    benchmark_end_to_end(
-        pipeline=pipeline,
-        inputs=get_dummy_batch(args.batch_size, args.max_input_length),
-        generate_kwargs={"max_new_tokens": args.max_new_tokens, "do_sample": False},
-        profile=args.profile,
-        skip=args.skip,
-        warmup=args.profile if args.warmup is None else args.warmup,
-        cycles=args.cycles,
-        full_trace=args.full_trace,
-        show_op_names=args.show_op_names,
-        max_log_outputs=args.batch_size if args.max_log_outputs is None else args.max_log_outputs,
-        clear_every_run=args.clear_every_run,
-    )
+    all_metrics = []
+
+    if args.profile:
+        profiler = get_profiler(
+            skip=args.skip,
+            warmup=warmup,
+            cycles=args.cycles,
+            full_trace=args.full_trace,
+            show_op_names=args.show_op_names,
+        )
+    else:
+        profiler = contextlib.nullcontext()
+
+    benchmark_stats = {
+        "Model parameters": pipeline.get_num_parameters(),
+        "Batch size": len(inputs),
+        **generate_kwargs,
+        **pipeline.get_initialization_metrics(),
+        "Warmup cycles": args.skip + warmup,
+        "Benchmark cycles": args.cycles,
+        "Total cycles": args.skip + warmup + args.cycles,
+    }
+
+    if pipeline.device.type == "cuda":
+        benchmark_stats["Initial memory used"] = format_mib(torch.cuda.memory_allocated())
+        benchmark_stats["Initial memory reserved"] = format_mib(torch.cuda.memory_reserved())
+        torch.cuda.reset_peak_memory_stats()
+
+    t0 = time.perf_counter()
+    with profiler as p:
+        for step in range(args.skip + warmup + args.cycles):
+            if step == args.skip + warmup:
+                t1 = time.perf_counter()
+                benchmark_stats["Warmup time"] = format_ms(t1 - t0)
+            generated_text, metrics = pipeline(inputs, **generate_kwargs)
+            if args.profile:
+                p.step()
+
+            if step == 0:
+                for i, o, _ in zip(inputs, generated_text, range(max_log_outputs)):
+                    log_rank_n(f"{'-' * 60}\nINPUT = {i}\nOUTPUT = {o}", logger.info)
+
+            if step >= args.skip + warmup:
+                all_metrics.append(metrics)
+
+            if args.clear_every_run:
+                torch.cuda.synchronize()
+                gc.collect()
+                torch.cuda.empty_cache()
+    if pipeline.device.type == "cuda":
+        benchmark_stats["Memory used"] = format_mib(torch.cuda.memory_allocated())
+        benchmark_stats["Memory reserved"] = format_mib(torch.cuda.memory_reserved())
+        benchmark_stats["Max memory used"] = format_mib(torch.cuda.max_memory_allocated())
+        benchmark_stats["Max memory reserved"] = format_mib(torch.cuda.max_memory_reserved())
+
+    t2 = time.perf_counter()
+    benchmark_stats["Benchmark time"] = format_ms(t2 - t1)
+    benchmark_stats["Total time"] = format_ms(t2 - t0)
+
+    if len(all_metrics) > 0:
+        benchmark_stats.update(pipeline.aggregate_and_format_metrics(all_metrics))
+
+    log_rank_n("*** Benchmark results:", logger.info)
+    log_dict(benchmark_stats, logger.info)
 
 
 if __name__ == "__main__":
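
A quick sketch of how the refactored entry point might be driven after this change; the flag values (gpt2, batch size 4, 32 new tokens) are illustrative only, and a real run assumes a CUDA device plus the model weights being available.

from src.main import main

# Hypothetical invocation using the argparse flags defined in get_arg_parser() above.
main([
    "--pipeline_class", "HF_Pipeline",
    "--pretrained_model", "gpt2",
    "--tokenizer", "gpt2",
    "--batch_size", "4",
    "--max_new_tokens", "32",
    "--skip", "1",
    "--cycles", "5",
])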

src/pipelines/pipeline.py renamed to src/pipeline.py

Lines changed: 43 additions & 3 deletions
@@ -1,15 +1,15 @@
 import contextlib
 import gc
 import logging
+import os
 import time
 from typing import Any, Dict, List, Optional, Tuple
 
 import numpy as np
 import torch
 
-from src.utils.fast_init import fast_init
-from src.utils.logging import format_ms, log_rank_n
-from src.utils.utils import parse_revision
+from src.fast_init import fast_init
+from src.utils import format_ms, log_rank_n, parse_revision
 from transformers import (
     CONFIG_MAPPING,
     AutoConfig,
@@ -239,3 +239,43 @@ def aggregate_and_format_metrics(self, metrics: List[Dict[str, Any]]):
 
     def get_initialization_metrics(self):
         return {f"Initialization time ({key})": format_ms(value) for key, value in self.initialization_metrics.items()}
+
+
+class HF_Pipeline(Pipeline):
+    pass
+
+
+class DS_Pipeline(Pipeline):
+    def __init__(self, **kwargs):
+        import deepspeed
+
+        super().__init__(**kwargs)
+
+        if self.device != torch.device("cuda"):
+            raise ValueError(f"Deepspeed does not support device {self.device}")
+
+        if self.dtype not in (torch.float32, torch.float16, torch.bfloat16):
+            raise ValueError(f"Deepspeed does not support dtype {self.dtype}")
+
+        if self.config.model_type not in ("bloom", "gpt2"):
+            raise ValueError(f"Deepspeed does not support model type {self.config.model_type}")
+
+        self.model = deepspeed.init_inference(
+            self.model,
+            mp_size=int(os.getenv("WORLD_SIZE", "1")),
+            # base_dir="./",
+            dtype=self.dtype,
+            replace_with_kernel_inject=True,
+        )
+
+
+_PIPELINE_CLASS_MAP = {
+    "HF_Pipeline": HF_Pipeline,
+    "DS_Pipeline": DS_Pipeline,
+}
+
+
+def get_pipeline_class(name):
+    if name not in _PIPELINE_CLASS_MAP:
+        raise NotImplementedError(f"Unsupported pipeline class: {name}")
+    return _PIPELINE_CLASS_MAP[name]
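
A minimal sketch of the class-map lookup this hunk adds; the unknown name "TRT_Pipeline" is purely an illustration of the error path.

from src.pipeline import get_pipeline_class

pipeline_class = get_pipeline_class("DS_Pipeline")  # resolves to DS_Pipeline without instantiating it
try:
    get_pipeline_class("TRT_Pipeline")              # not in _PIPELINE_CLASS_MAP
except NotImplementedError as exc:
    print(exc)                                      # "Unsupported pipeline class: TRT_Pipeline"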

src/pipelines/__init__.py

Lines changed: 0 additions & 11 deletions
This file was deleted.

src/pipelines/ds.py

Lines changed: 0 additions & 28 deletions
This file was deleted.

src/pipelines/transformers.py

Lines changed: 0 additions & 5 deletions
This file was deleted.

src/profile.py

Lines changed: 62 additions & 0 deletions
@@ -0,0 +1,62 @@
+import contextlib
+import logging
+from typing import Union
+
+import torch
+
+from src.utils import log_rank_n
+
+
+logger = logging.getLogger(__name__)
+
+
+def get_trace_fn(full_trace: bool = False, show_op_names: bool = False, rank: int = -1):
+    def trace_fn(
+        p: torch.profiler.profile,
+    ):
+        averages = p.key_averages()
+        if full_trace:
+            # Show every GPU op.
+            # Exclude CPU cuda ops to shorten the table.
+            events = torch.autograd.profiler.EventList(
+                [evt for evt in p.profiler.function_events if evt.self_cuda_time_total > 0]
+            )
+            log_rank_n(events.table(row_limit=-1, max_src_column_width=1000), logger.info, rank)
+
+        if show_op_names:
+            # Show non-cropped names, in the same order as in the table.
+            averages_sorted = torch.autograd.profiler.EventList(
+                sorted(averages, key=lambda evt: evt.self_cuda_time_total, reverse=True)
+            )
+            for entry in averages_sorted:
+                log_rank_n(entry.key, logger.info, rank)
+
+        # Try to avoid name cropping, still hard-coded to max 55 characters
+        log_rank_n(
+            averages.table(sort_by="self_cuda_time_total", row_limit=-1, max_src_column_width=1000), logger.info, rank
+        )
+
+    return trace_fn
+
+
+def get_profiler(
+    skip: int,
+    warmup: int,
+    cycles: int,
+    full_trace: bool = False,
+    show_op_names: bool = False,
+) -> Union[torch.profiler.profile, contextlib.nullcontext]:
+    schedule = torch.profiler.schedule(
+        # Warmup is a must if measuring speed as it's when all the optimizations are performed
+        # e.g. on 8x80 a100 the first pass of 100 tokens takes 23sec, and the next one is 4secs
+        skip_first=skip,
+        # Warmup for the profiler
+        warmup=warmup,
+        wait=0,
+        active=cycles,
+    )
+    return torch.profiler.profile(
+        schedule=schedule,
+        activities=[torch.profiler.ProfilerActivity.CUDA],
+        on_trace_ready=get_trace_fn(full_trace, show_op_names),
+    )
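
A small sketch of how the profiler returned by get_profiler steps through its skip, warmup, and active cycles; the matmul workload and the cycle counts are placeholders, and a CUDA device is assumed since only CUDA activity is recorded.

import torch

from src.profile import get_profiler

skip, warmup, cycles = 1, 2, 3
x = torch.randn(1024, 1024, device="cuda")
with get_profiler(skip=skip, warmup=warmup, cycles=cycles) as p:
    for _ in range(skip + warmup + cycles):  # same total as the loop in main()
        torch.matmul(x, x)
        torch.cuda.synchronize()
        p.step()                             # advance the profiler schedule after each run
# Once the active cycles finish, the callback from get_trace_fn() logs the op table.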
