
fix step_num
Ubuntu committed Jun 3, 2024
1 parent 9b29cfd commit 373114d
Showing 3 changed files with 18 additions and 14 deletions.
10 changes: 7 additions & 3 deletions profile.sh
@@ -47,8 +47,10 @@
 #**IMPORTANT** There are issues with recording stack traces and exporting traces simultaneously (see this [issue](https://github.com/pytorch/pytorch/issues/113564)) depending on `python` version. The only combination I was able to get both to work at the same time was with `python=3.11.9` and `torch=2.3.0`.
 #Tested on `python=3.11.9` and `torch=2.3.0`

+#"meta-llama/Llama-2-7b-hf"
+
 python train.py \
---model_name "meta-llama/Llama-2-7b-hf" \
+--model_name "hf-internal-testing/tiny-random-LlamaForCausalLM" \
 --gradient_accumulation_steps 2 \
 --batch_size 1 \
 --context_length 256 \
@@ -62,5 +64,7 @@ python train.py \
 --dataset dummy \
 --profile true \
 --export_trace true \
---export_memory_timeline true \
---max_steps 10
+--export_memory_timeline false \
+--with_stack true \
+--max_steps 10 \
+--repeat 1
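
Note: the flags above are consumed by profiling_utils.py and forwarded to torch.profiler. As a rough, hedged sketch of what the export options correspond to on the profiler side (the handler name and file paths here are illustrative, not this repo's exact code):

import torch

def example_trace_handler(prof, output_dir="."):
    # --export_trace: write a Chrome/Perfetto trace for the profiled steps
    prof.export_chrome_trace(f"{output_dir}/trace.json")
    # --export_memory_timeline: write an HTML memory timeline; assumes the profiler
    # was created with profile_memory=True (stacks and shapes make it more readable)
    prof.export_memory_timeline(f"{output_dir}/memory.html", device="cuda:0")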
20 changes: 10 additions & 10 deletions profiling_utils.py
@@ -17,8 +17,8 @@

 #adapted from https://github.com/pytorch/torchtitan

-def trace_handler(prof, rank, export_memory_timeline, output_dir, metric="self_cuda_time_total", with_stack=True, group_by_stack=0, group_by_input_shapes=False, row_limit=25):
-    curr_trace_dir_name = str(prof.step_num)
+def trace_handler(prof, rank, export_memory_timeline, output_dir, metric="self_cuda_time_total", with_stack=True, group_by_stack=0, group_by_input_shape=False, row_limit=25):
+    curr_trace_dir_name = "iteration_" + str(prof.step_num)
     curr_trace_dir = os.path.join(output_dir, curr_trace_dir_name)
     if not os.path.exists(curr_trace_dir):
         os.makedirs(curr_trace_dir, exist_ok=True)
@@ -43,7 +43,7 @@ def trace_handler(prof, rank, export_memory_timeline, output_dir, metric="self_c

     #Export event averages
     key_avgs = prof.key_averages(
-        group_by_input_shape=group_by_input_shapes, group_by_stack_n=group_by_stack
+        group_by_input_shape=group_by_input_shape, group_by_stack_n=group_by_stack
     ).table(sort_by=metric, row_limit=row_limit)
     with open(f"{curr_trace_dir}/rank{rank}_key_averages.txt", "w") as f:
         print(
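
For context, prof.key_averages(...) is the torch profiler call that aggregates recorded events into a summary table, and its keyword is group_by_input_shape; the rename above makes the helper's own parameter match the keyword it forwards. A small illustrative helper in the same shape (not this repo's trace_handler), assuming prof is an active torch.profiler.profile object:

def export_key_averages(prof, path, metric="self_cuda_time_total", row_limit=25):
    # Aggregate profiler events and write a table sorted by the chosen metric.
    key_avgs = prof.key_averages(group_by_input_shape=False, group_by_stack_n=0)
    with open(path, "w") as f:
        print(key_avgs.table(sort_by=metric, row_limit=row_limit), file=f)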
@@ -56,10 +56,13 @@ def trace_handler(prof, rank, export_memory_timeline, output_dir, metric="self_
     torch.distributed.barrier()

 @contextlib.contextmanager
-def profiling_context(args, rank, *, global_step: int = 0):
-    enable_profiling = args.profile
+def profiling_context(args, rank):
+    enable_profiling = args["profile"]

     if enable_profiling:
+        output_dir = args["profiling_output"] if args["profiling_output"] else f"./{model_name}_{train_type}"
+        model_name = args["model_name"].split("/")[-1]
+        train_type = args["train_type"]

         logger.info(f"Profiling enabled. Traces will be saved at {output_dir}")

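
The surrounding pattern is a contextlib-based context manager that either yields a live profiler or skips profiling entirely. A hedged skeleton of that shape (the yield None fallback is an assumption for illustration, not necessarily what this repo does when profiling is off):

import contextlib
import torch
from torch.profiler import profile, ProfilerActivity

@contextlib.contextmanager
def profiling_context_sketch(args, rank):
    # Minimal stand-in for the real profiling_context; see the hunks below for the actual setup.
    if args["profile"]:
        with profile(activities=[ProfilerActivity.CPU]) as prof:
            yield prof
    else:
        yield None  # callers would need to guard prof.step() etc. when profiling is off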
@@ -85,9 +88,6 @@ def profiling_context(args, rank, *, global_step: int = 0):
         export_memory_timeline = args["export_memory_timeline"]
         with_stack = args["with_stack"] or args["export_memory_timeline"]
         with_shapes = args["with_shapes"] or export_memory_timeline
-        model_name = args["model_name"].split("/")[-1]
-        train_type = args["train_type"]
-        output_dir = args["profiling_output"] if args["profiling_output"] else f"./{model_name}_{train_type}"
         callback = partial(trace_handler, rank=rank,
                            export_memory_timeline=export_memory_timeline,
                            output_dir=output_dir,
@@ -102,7 +102,7 @@ def profiling_context(args, rank):
             ],
             with_stack=with_stack,
             profile_memory=profile_memory,
-            with_shapes=with_shapes,
+            record_shapes=with_shapes,
             schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=repeat),
             on_trace_ready=callback,
             experimental_config=torch._C._profiler._ExperimentalConfig(verbose=True) if with_stack else None,
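
The last hunk fixes the keyword: torch.profiler.profile accepts record_shapes, not with_shapes. A minimal hedged sketch of a compatible profiler construction (placeholder schedule values and output directory, not the script's actual arguments):

import torch
from torch.profiler import profile, schedule, ProfilerActivity, tensorboard_trace_handler

prof_ctx = profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    with_stack=True,
    profile_memory=True,
    record_shapes=True,  # corrected keyword; shape recording also feeds the memory timeline
    schedule=schedule(wait=1, warmup=1, active=2, repeat=1),
    on_trace_ready=tensorboard_trace_handler("./profiler_out"),
)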
2 changes: 1 addition & 1 deletion train.py
@@ -820,7 +820,7 @@ def load_and_quantize_parallel(name_param, model, **kwargs):
     ddp_loss = torch.zeros(2).to(local_rank)

     for batch_idx, batch in enumerate(dataloader):
-        prof.step_num = f"epoch{epoch}-batch{batch_idx}"
+        #prof.step_num = f"epoch{epoch}-batch{batch_idx}"

         accumulate_grads = (batch_idx+1) % gradient_accumulation_steps == 0

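The train.py change is the core of the step_num fix: the profiler keeps step_num as its own integer counter, advanced by prof.step(), and trace_handler above reads it to build the iteration_<n> directory name, so the manual string assignment is dropped. A hedged sketch of the intended loop shape (placeholder training step, not train.py verbatim):

import torch
from torch.profiler import profile, schedule, ProfilerActivity

activities = [ProfilerActivity.CPU] + ([ProfilerActivity.CUDA] if torch.cuda.is_available() else [])
with profile(
    activities=activities,
    schedule=schedule(wait=1, warmup=1, active=2, repeat=1),
    on_trace_ready=lambda p: print(f"trace ready at step {p.step_num}"),
) as prof:
    for batch_idx in range(10):
        # ... forward / backward / optimizer step ...
        prof.step()  # advances the integer prof.step_num; no manual assignment needed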
