
Commit b46fc36

rebased and updated
1 parent 02df60c commit b46fc36

File tree

2 files changed: +67 -48 lines

test_single_controller_ppo.py

Lines changed: 65 additions & 44 deletions
@@ -49,10 +49,7 @@
 from compose_rl.controllers import BaseDistributedGPUActor, SPMDActorGroup
 from compose_rl.controllers.buffer import Buffer
 from compose_rl.algorithms.online.callback_utils import preprocess_batches
-from databricks.sdk import WorkspaceClient

-MLFLOW_RUN_NAME=os.environ['COMPOSER_RUN_NAME'] # SHOULD BE SET BY MCLI
-MLFLOW_EXPERIMENT_NAME=f'/Users/{WorkspaceClient().current_user.me().user_name}/test_single_controller'


 @contextmanager
 def time_it(name: str):
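This hunk drops the module-level MLflow constants and the eager `databricks.sdk` import; the same values are now resolved at runtime inside `_setup_mlflow` (see the later hunk). The hunk's last context lines show the `time_it` helper, whose body is outside the diff. A minimal sketch of what such a timing context manager typically looks like; the body here is an assumption, not the repo's code:

```python
import time
from contextlib import contextmanager

@contextmanager
def time_it(name: str):
    # Time the wrapped block and report elapsed wall-clock seconds.
    start = time.perf_counter()
    try:
        yield
    finally:
        print(f'{name} took {time.perf_counter() - start:.2f}s')
```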
@@ -193,9 +190,11 @@ def build_ppo_trainer(self):
         dummy_distributed_sampler = torch.utils.data.distributed.DistributedSampler(dummy_dataset)
         dummy_dataloader = torch.utils.data.DataLoader(dummy_dataset, sampler=dummy_distributed_sampler)

+        # TODO: We might be able to skip part of the setup here as some mlflow
+        # environment variables are set in the _setup_mlflow function
         mlflow_logger = MLFlowLogger(
             experiment_name=self.config.loggers.mlflow.experiment_name,
-            run_name=f'test_single_controller_ppo_async_{self.config.max_async_step}_deepseek_l8b_open_r1_48k',
+            run_name=self.config.loggers.mlflow.run_name,
             tracking_uri=self.config.loggers.mlflow.tracking_uri,
         )

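The hard-coded run name is gone; both names now come from `config.loggers.mlflow`, which `_setup_mlflow` populates (shown below). The attribute-style access and the `setattr` calls later in the diff suggest an omegaconf-style config; a small sketch of that assumption, with placeholder values rather than the repo's defaults:

```python
from omegaconf import OmegaConf

# Hypothetical config mirroring the keys this diff reads.
config = OmegaConf.create({
    'loggers': {
        'mlflow': {
            'experiment_name': 'test_single_controller_ppo',
            'run_name': None,  # filled in from COMPOSER_RUN_NAME by _setup_mlflow
            'tracking_uri': 'databricks',
        },
    },
})

print(config.loggers.mlflow.experiment_name)  # attribute-style access, as in the diff
```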
@@ -397,7 +396,7 @@ def __init__(
         self.eval_interval_num = int(config.eval_interval.strip("iter"))
         self.num_batches_per_update = config.variables.num_batches_per_update
         self.experiment_name = config.loggers.mlflow.experiment_name
-        self.run_name = f'test_single_controller_ppo_async_{config.max_async_step}_deepseek_l8b_open_r1_48k'
+        self.run_name = config.loggers.mlflow.run_name

         self.callback = self.build_callback()

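A side note on the unchanged first context line: `str.strip('iter')` removes any of the characters `i`, `t`, `e`, `r` from both ends of the string, not the literal suffix `iter`. It behaves as intended here only because the remaining prefix is all digits:

```python
# strip('iter') strips a character set, not a suffix; fine for '10iter',
# but it would erase strings made entirely of those characters.
assert '10iter'.strip('iter') == '10'
assert int('10iter'.strip('iter')) == 10
assert 'retire'.strip('iter') == ''
```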
@@ -464,20 +463,19 @@ def __init__(
         self.tokenizer_pad_token_id = ray.get(self.streaming_dataset_actor.get_tokenizer_pad_token_id.remote())
         self.prompt_handler_config = ray.get(self.streaming_dataset_actor.get_prompt_handler_config.remote())
         self.max_gen_len = self.prompt_handler_config['max_gen_len']
-
-        # Load iter_num from the checkpoint
         self.save_folder = os.path.join(config.save_folder, 'RolloutAgent')
-
         self.iter_num = 0

-        # Load the latest checkpoint
-
-        self.latest_checkpoint = os.path.join(self.save_folder, 'latest_rollout_agent.symlink') # TODO: This might need to use the updated path
-
-        if config.autoresume and _artifact_exists(self.latest_checkpoint):
+        # Load the latest checkpoint if we are autoresuming.
+        # Note that since we are checking if the checkpoint exists with
+        # mlflow.client.list_artifacts, we need to use the relative path to
+        # the checkpoint (i.e. not include dbfs://.../{mlflow_experiment_id}/{mlflow_run_id}
+        # in the path).
+        self.latest_checkpoint_path = os.path.join(self.save_folder, 'latest_rollout_agent.symlink')
+        if config.autoresume and _artifact_exists(self.latest_checkpoint_path):
             print(f'Autoresuming from checkpoint for RolloutAgent.')
-            get_file(self.latest_checkpoint, self.latest_checkpoint, overwrite=True)
-            with open(self.latest_checkpoint, 'rb') as f:
+            get_file(self.latest_checkpoint_path, self.latest_checkpoint_path, overwrite=True)
+            with open(self.latest_checkpoint_path, 'rb') as f:
                 checkpoint = pickle.load(f)
                 self.iter_num = checkpoint['iter_num']
                 print(f'Loading streaming dataloader state dict for RolloutAgent.', checkpoint['streaming_dataloader'])
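The new comment pins down the subtlety behind `_artifact_exists`: MLflow's artifact listing only understands run-relative paths, never the full `dbfs://` URI. The function's body is only partially visible in this diff (its closing lines appear in a later hunk), so the following is a guessed reconstruction using only documented `MlflowClient` calls; treat the details as assumptions:

```python
import os
import mlflow

def _artifact_exists_sketch(artifact_path: str) -> bool:
    # Hypothetical shape: list the parent directory of the run-relative
    # artifact path and look for the target entry by name.
    run_id = os.environ.get('MLFLOW_RUN_ID')
    if run_id is None:
        return False
    parent, name = os.path.split(artifact_path)
    if not name:
        return True  # the artifact root trivially exists
    client = mlflow.MlflowClient()
    return any(
        os.path.basename(info.path) == name
        for info in client.list_artifacts(run_id, parent or None)
    )
```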
@@ -521,29 +519,31 @@ def get_next_iter_rollouts(self):
         processed_sequences = torch.cat([all_prompts, padded_responses], dim=-1)
         iter_data['sequences'] = processed_sequences

-        save_folder_iter = os.path.join(self.save_folder, f'iter_{self.iter_num}')
-        checkpoint_path = os.path.join(save_folder_iter, 'checkpoint.pt')
+        save_folder_for_curr_iter = os.path.join(self.save_folder, f'iter_{self.iter_num}')
+        checkpoint_path = os.path.join(save_folder_for_curr_iter, 'checkpoint.pt')
         self.iter_num += 1

         streaming_dataloader_state_dict = ray.get(self.streaming_dataset_actor.get_dataloader_state_dict.remote())
         print(f'Streaming dataloader state dict for RolloutAgent.', streaming_dataloader_state_dict)

         # make sure that the folder path can exist
-        os.makedirs(save_folder_iter, exist_ok=True)
+        os.makedirs(save_folder_for_curr_iter, exist_ok=True)
         with open(checkpoint_path, 'wb') as f:
             pickle.dump({
                 'iter_data': iter_data,
                 'iter_num': self.iter_num,
                 'streaming_dataloader': streaming_dataloader_state_dict,
             }, f)

-        mlflow.log_artifact(checkpoint_path, save_folder_iter, run_id=_get_mlflow_run_id())
+        # log the checkpoint to mlflow
+        mlflow.log_artifact(checkpoint_path, save_folder_for_curr_iter, run_id=_get_mlflow_run_id())

-        if os.path.exists(self.latest_checkpoint):
-            os.remove(self.latest_checkpoint)
-        create_symlink_file(checkpoint_path, self.latest_checkpoint)
-
-        mlflow.log_artifact(self.latest_checkpoint, self.config.save_folder, run_id=_get_mlflow_run_id())
+        if os.path.exists(self.latest_checkpoint_path):
+            os.remove(self.latest_checkpoint_path)
+        create_symlink_file(checkpoint_path, self.latest_checkpoint_path)
+
+        # log the latest checkpoint to mlflow
+        mlflow.log_artifact(self.latest_checkpoint_path, self.save_folder, run_id=_get_mlflow_run_id())
         return iter_data

     async def run(self, num_iterations: int, experience_buffer: 'ExperienceBuffer', lock: asyncio.Lock, rollout_semaphore: asyncio.Semaphore):
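The `.symlink` file acts as a stable latest pointer: each iteration writes a real checkpoint, then rewrites a tiny file that names it, and logs both as MLflow artifacts so autoresume can always fetch one fixed path. A condensed sketch of the pattern, assuming `create_symlink_file` is Composer's utility of that name (the import is not shown in the diff):

```python
import os
import pickle

from composer.utils import create_symlink_file  # assumed origin of the helper

def save_with_latest_pointer(state: dict, save_folder: str, iter_num: int) -> str:
    # Write the real checkpoint for this iteration.
    iter_folder = os.path.join(save_folder, f'iter_{iter_num}')
    os.makedirs(iter_folder, exist_ok=True)
    checkpoint_path = os.path.join(iter_folder, 'checkpoint.pt')
    with open(checkpoint_path, 'wb') as f:
        pickle.dump(state, f)

    # Refresh the stable pointer: a small text file naming checkpoint_path.
    latest = os.path.join(save_folder, 'latest_rollout_agent.symlink')
    if os.path.exists(latest):
        os.remove(latest)
    create_symlink_file(checkpoint_path, latest)
    return checkpoint_path
```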
@@ -743,42 +743,64 @@ async def train_async(self, max_duration: int | str):
 def _get_mlflow_run_id() -> Optional[str]:
     return os.environ.get('MLFLOW_RUN_ID', None)

-def _setup_mlflow():
-    print('setting up mlflow')
+def _get_valid_mlflow_experiment_name(config: Any) -> str:
+    """Fixes the experiment name to be an absolute path for mlflow.
+
+    MLflow requires the experiment name to be an absolute path.
+    If the experiment name is not an absolute path, we prepend the current
+    user's username to the experiment name.
+    """
+    mlflow_experiment_name = config.loggers.mlflow.experiment_name
+    if mlflow_experiment_name.startswith('/'):
+        return mlflow_experiment_name
+    else:
+        from databricks.sdk import WorkspaceClient
+        return f'/Users/{WorkspaceClient().current_user.me().user_name}/{mlflow_experiment_name}'
+
+def _setup_mlflow(config: Any):
     dist.init_process_group(backend='gloo')
-    # Create a new MLFlow run to be used for the entire run
     mlflow.set_tracking_uri('databricks')

-    # get mlflow experiment
-    experiment = mlflow.get_experiment_by_name(MLFLOW_EXPERIMENT_NAME)
-    if experiment is None:
-        experiment_id = mlflow.create_experiment(MLFLOW_EXPERIMENT_NAME)
-    else:
-        experiment_id = experiment.experiment_id
-    mlflow.set_experiment(experiment_id=experiment_id)
+    # mlflow experiment name needs to be an absolute path for databricks mlflow.
+    mlflow_experiment_name = _get_valid_mlflow_experiment_name(config)
+    setattr(config.loggers.mlflow, 'experiment_name', mlflow_experiment_name)
+    # COMPOSER_RUN_NAME is set for interactive mode as well.
+    mlflow_run_name = os.environ['COMPOSER_RUN_NAME']
+    setattr(config.loggers.mlflow, 'run_name', mlflow_run_name)

+    # get mlflow experiment if it exists, otherwise create it and set it to all ranks.
+    experiment_id = None
+    if composer_dist.get_global_rank() == 0:
+        experiment = mlflow.get_experiment_by_name(mlflow_experiment_name)
+        if experiment is None:
+            experiment_id = mlflow.create_experiment(mlflow_experiment_name)
+        else:
+            experiment_id = experiment.experiment_id
+    experiment_id_broadcast_list = [experiment_id]
+    composer_dist.broadcast_object_list(experiment_id_broadcast_list, src=0)
+    experiment_id = experiment_id_broadcast_list[0]

+    mlflow.set_experiment(experiment_id=experiment_id)

+    # get mlflow run if it exists and we are autoresuming, otherwise create it and set it to all ranks.
     run_id = None
     if composer_dist.get_global_rank() == 0:
-        # find a preexisting run if it exists
         existing_runs = mlflow.search_runs(
             experiment_ids=[experiment_id],
-            filter_string=f'tags.run_name = "{MLFLOW_RUN_NAME}"',
+            filter_string=f'tags.run_name = "{mlflow_run_name}"',
             output_format='list',
         ) if config.autoresume else []
         if len(existing_runs) > 0:
             run_id = existing_runs[0].info.run_id
             print(f'Resuming mlflow run with run id: {run_id}')
         else:
-            run_id = mlflow.start_run(run_name=MLFLOW_RUN_NAME).info.run_id
+            run_id = mlflow.start_run(run_name=mlflow_run_name).info.run_id
             print(f'Creating new mlflow run with run id: {run_id}')
-    broadcast_list = [run_id]
-
-    composer_dist.broadcast_object_list(broadcast_list, src=0)
+    run_id_broadcast_list = [run_id]
+    composer_dist.broadcast_object_list(run_id_broadcast_list, src=0)
+    run_id = run_id_broadcast_list[0]

     # set all the right environment variables
-    run_id = broadcast_list[0]
     assert run_id is not None and experiment_id is not None, "Run ID and experiment ID must be set"
     os.environ['MLFLOW_RUN_ID'] = run_id
     os.environ['MLFLOW_EXPERIMENT_ID'] = experiment_id
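The same rank-0-creates, everyone-receives shape appears twice above, once for the experiment ID and once for the run ID. Distilled into a helper, assuming `composer_dist` is `composer.utils.dist` (consistent with Composer's public API):

```python
from typing import Any, Callable

from composer.utils import dist as composer_dist

def create_once_and_share(create_fn: Callable[[], Any]) -> Any:
    """Run create_fn on rank 0 only, then share the result with every rank."""
    value = create_fn() if composer_dist.get_global_rank() == 0 else None
    # broadcast_object_list fills the list in place on non-source ranks.
    holder = [value]
    composer_dist.broadcast_object_list(holder, src=0)
    return holder[0]
```

Both broadcast sites in `_setup_mlflow` could then read, for example, `experiment_id = create_once_and_share(lambda: ...)`.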
@@ -812,8 +834,7 @@ def _artifact_exists(artifact_path: str) -> bool:

     # If we got here, the path exists (root or found item).
     return True
-
-
+

 def _run_single_controller_ppo(
     config: Any,
@@ -830,7 +851,7 @@ def _run_single_controller_ppo(
     # Disable setting CUDA_VISIBLE_DEVICES by ray, we will set it manually
     os.environ['RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES'] = '1'

-    _setup_mlflow()
+    _setup_mlflow(config)

     with start_ray_server() as _address:
         # only rank 0 is the master controller
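For context on the unchanged lines: `RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES=1` tells Ray not to rewrite `CUDA_VISIBLE_DEVICES` for its workers, leaving GPU pinning to this controller. A hypothetical sketch of such manual pinning inside an actor (the rank value is illustrative, not from this diff):

```python
import os

def pin_gpu(local_rank: int) -> None:
    # With Ray's automatic device masking disabled, each actor pins itself
    # to exactly one GPU before initializing CUDA.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(local_rank)

pin_gpu(local_rank=0)  # e.g. the first GPU on this node
```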

yamls/single-controller-grpo-workflow.yaml

Lines changed: 2 additions & 4 deletions
@@ -50,8 +50,6 @@ parameters:
   loggers:
     mlflow:
       tags:
-        run:
-          run_name: null
         group: grpo
       tracking_uri: databricks
       experiment_name: test_single_controller_ppo
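Dropping `run_name: null` is deliberate rather than lost configuration: per the Python changes above, `_setup_mlflow` now injects `loggers.mlflow.run_name` from the `COMPOSER_RUN_NAME` environment variable at startup, so a YAML placeholder would be dead weight.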
@@ -163,7 +161,7 @@ parameters:
   gradient_clipping:
     clipping_type: norm
     clipping_threshold: 0.001
-  autoresume: true
+  autoresume: false
   log_config: true
   fsdp_config:
     verbose: false
@@ -178,7 +176,7 @@ parameters:
     activation_checkpointing: true
     activation_checkpointing_reentrant: false
   max_seq_len: 10240
-  save_folder: /tmp/checkpoints
+  save_folder: artifacts/checkpoints
   dist_timeout: 1800
   max_duration: 10iter
   progress_bar: false
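A plausible reading of the `save_folder` change, stated as an inference rather than something the diff spells out: `_artifact_exists` matches checkpoints by run-relative artifact path, and a relative `save_folder` keeps the local layout and the MLflow artifact layout aligned:

```python
import os

# With a relative save_folder (from the YAML above), local checkpoint paths
# double as run-relative MLflow artifact paths, which is what the
# _artifact_exists check in the Python file expects.
save_folder = 'artifacts/checkpoints'
latest = os.path.join(save_folder, 'RolloutAgent', 'latest_rollout_agent.symlink')
print(latest)  # artifacts/checkpoints/RolloutAgent/latest_rollout_agent.symlink
```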
