Skip to content

Commit 37353b1

Browse files
committed
fix checkpoint merge tools
1 parent 0229a69 commit 37353b1

File tree

4 files changed: +6 additions, −5 deletions

megatron/checkpointing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
233233
checkpoint_name = get_checkpoint_name(args.save, iteration)
234234

235235
# Save distributed optimizer's custom parameter state.
236-
if args.use_distributed_optimizer:
236+
if args.use_distributed_optimizer and not args.no_save_optim:
237237
optim_checkpoint_name = \
238238
get_distributed_optimizer_checkpoint_name(checkpoint_name)
239239
ensure_directory_exists(optim_checkpoint_name)

megatron/global_vars.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def _set_signal_handler():
8080

8181

8282

83-
def set_global_variables(args):
83+
def set_global_variables(args, build_tokenizer=True):
8484
"""Set args, tokenizer, tensorboard-writer, adlr-autoresume, and timers."""
8585

8686
assert args is not None
@@ -89,7 +89,8 @@ def set_global_variables(args):
8989
set_args(args)
9090

9191
_build_num_microbatches_calculator(args)
92-
_ = _build_tokenizer(args)
92+
if build_tokenizer:
93+
_ = _build_tokenizer(args)
9394
_set_tensorboard_writer(args)
9495
_set_adlr_autoresume(args)
9596
_set_timers(args)

tools/checkpoint_loader_megatron.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ def get_models(count, dtype):
152152
models[vp_rank].append(model_[vp_rank])
153153
return models
154154

155-
set_global_variables(margs)
155+
set_global_variables(margs, build_tokenizer=False)
156156
mpu.set_tensor_model_parallel_world_size(margs.tensor_model_parallel_size)
157157
mpu.set_pipeline_model_parallel_world_size(margs.pipeline_model_parallel_size)
158158
mpu.set_virtual_pipeline_model_parallel_world_size(margs.virtual_pipeline_model_parallel_size)

tools/checkpoint_saver_megatron.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ def check_message(msg):
164164

165165
validate_args(margs)
166166

167-
set_global_variables(margs)
167+
set_global_variables(margs, build_tokenizer=False)
168168

169169
# margs = megatron args
170170
margs = get_args()

0 commit comments

Comments (0)