from compose_rl.controllers import BaseDistributedGPUActor, SPMDActorGroup
from compose_rl.controllers.buffer import Buffer
from compose_rl.algorithms.online.callback_utils import preprocess_batches
-from databricks.sdk import WorkspaceClient

-MLFLOW_RUN_NAME = os.environ['COMPOSER_RUN_NAME']  # SHOULD BE SET BY MCLI
-MLFLOW_EXPERIMENT_NAME = f'/Users/{WorkspaceClient().current_user.me().user_name}/test_single_controller'

@contextmanager
def time_it(name: str):
@@ -193,9 +190,11 @@ def build_ppo_trainer(self):
        dummy_distributed_sampler = torch.utils.data.distributed.DistributedSampler(dummy_dataset)
        dummy_dataloader = torch.utils.data.DataLoader(dummy_dataset, sampler=dummy_distributed_sampler)

+        # TODO: We might be able to skip part of the setup here as some mlflow
+        # environment variables are set in the _setup_mlflow function
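+        # experiment_name and run_name below are read from the config, which
+        # _setup_mlflow populates before the trainer is built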
        mlflow_logger = MLFlowLogger(
            experiment_name=self.config.loggers.mlflow.experiment_name,
-            run_name=f'test_single_controller_ppo_async_{self.config.max_async_step}_deepseek_l8b_open_r1_48k',
+            run_name=self.config.loggers.mlflow.run_name,
            tracking_uri=self.config.loggers.mlflow.tracking_uri,
        )
@@ -397,7 +396,7 @@ def __init__(
        self.eval_interval_num = int(config.eval_interval.strip("iter"))
        self.num_batches_per_update = config.variables.num_batches_per_update
        self.experiment_name = config.loggers.mlflow.experiment_name
-        self.run_name = f'test_single_controller_ppo_async_{config.max_async_step}_deepseek_l8b_open_r1_48k'
+        self.run_name = config.loggers.mlflow.run_name
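+        # the run name now comes from the config (set from COMPOSER_RUN_NAME
+        # in _setup_mlflow) instead of a hard-coded experiment string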

        self.callback = self.build_callback()
@@ -464,20 +463,19 @@ def __init__(
        self.tokenizer_pad_token_id = ray.get(self.streaming_dataset_actor.get_tokenizer_pad_token_id.remote())
        self.prompt_handler_config = ray.get(self.streaming_dataset_actor.get_prompt_handler_config.remote())
        self.max_gen_len = self.prompt_handler_config['max_gen_len']
-
-        # Load iter_num from the checkpoint
        self.save_folder = os.path.join(config.save_folder, 'RolloutAgent')
-
        self.iter_num = 0

-        # Load the latest checkpoint
-
-        self.latest_checkpoint = os.path.join(self.save_folder, 'latest_rollout_agent.symlink')  # TODO: This might need to use the updated path
-
-        if config.autoresume and _artifact_exists(self.latest_checkpoint):
+        # Load the latest checkpoint if we are autoresuming.
+        # Note that since we are checking if the checkpoint exists with
+        # mlflow.client.list_artifacts, we need to use the relative path to
+        # the checkpoint (i.e. not include dbfs://.../{mlflow_experiment_id}/{mlflow_run_id}
+        # in the path).
+        self.latest_checkpoint_path = os.path.join(self.save_folder, 'latest_rollout_agent.symlink')
+        if config.autoresume and _artifact_exists(self.latest_checkpoint_path):
            print(f'Autoresuming from checkpoint for RolloutAgent.')
-            get_file(self.latest_checkpoint, self.latest_checkpoint, overwrite=True)
-            with open(self.latest_checkpoint, 'rb') as f:
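+            # fetch the symlinked checkpoint artifact to the same local path,
+            # then unpickle it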
+            get_file(self.latest_checkpoint_path, self.latest_checkpoint_path, overwrite=True)
+            with open(self.latest_checkpoint_path, 'rb') as f:
                checkpoint = pickle.load(f)
            self.iter_num = checkpoint['iter_num']
            print(f'Loading streaming dataloader state dict for RolloutAgent.', checkpoint['streaming_dataloader'])
@@ -521,29 +519,31 @@ def get_next_iter_rollouts(self):
        processed_sequences = torch.cat([all_prompts, padded_responses], dim=-1)
        iter_data['sequences'] = processed_sequences

-        save_folder_iter = os.path.join(self.save_folder, f'iter_{self.iter_num}')
-        checkpoint_path = os.path.join(save_folder_iter, 'checkpoint.pt')
+        save_folder_for_curr_iter = os.path.join(self.save_folder, f'iter_{self.iter_num}')
+        checkpoint_path = os.path.join(save_folder_for_curr_iter, 'checkpoint.pt')
        self.iter_num += 1

        streaming_dataloader_state_dict = ray.get(self.streaming_dataset_actor.get_dataloader_state_dict.remote())
        print(f'Streaming dataloader state dict for RolloutAgent.', streaming_dataloader_state_dict)

        # make sure that the folder path exists
-        os.makedirs(save_folder_iter, exist_ok=True)
+        os.makedirs(save_folder_for_curr_iter, exist_ok=True)
        with open(checkpoint_path, 'wb') as f:
            pickle.dump({
                'iter_data': iter_data,
                'iter_num': self.iter_num,
                'streaming_dataloader': streaming_dataloader_state_dict,
            }, f)

-        mlflow.log_artifact(checkpoint_path, save_folder_iter, run_id=_get_mlflow_run_id())
+        # log the checkpoint to mlflow
+        mlflow.log_artifact(checkpoint_path, save_folder_for_curr_iter, run_id=_get_mlflow_run_id())

-        if os.path.exists(self.latest_checkpoint):
-            os.remove(self.latest_checkpoint)
-        create_symlink_file(checkpoint_path, self.latest_checkpoint)
-
-        mlflow.log_artifact(self.latest_checkpoint, self.config.save_folder, run_id=_get_mlflow_run_id())
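+        # point the well-known 'latest' symlink file at the newest
+        # per-iteration checkpoint so that autoresume only has to probe a
+        # single artifact path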
+        if os.path.exists(self.latest_checkpoint_path):
+            os.remove(self.latest_checkpoint_path)
+        create_symlink_file(checkpoint_path, self.latest_checkpoint_path)
+
+        # log the latest checkpoint to mlflow
+        mlflow.log_artifact(self.latest_checkpoint_path, self.save_folder, run_id=_get_mlflow_run_id())
        return iter_data

    async def run(self, num_iterations: int, experience_buffer: 'ExperienceBuffer', lock: asyncio.Lock, rollout_semaphore: asyncio.Semaphore):
@@ -743,42 +743,64 @@ async def train_async(self, max_duration: int | str):
def _get_mlflow_run_id() -> Optional[str]:
    return os.environ.get('MLFLOW_RUN_ID', None)

-def _setup_mlflow():
-    print('setting up mlflow')
+def _get_valid_mlflow_experiment_name(config: Any) -> str:
+    """Fixes the experiment name to be an absolute path for mlflow.
+
+    Databricks MLflow requires the experiment name to be an absolute
+    workspace path. If the experiment name is not an absolute path, we
+    prepend the current user's home folder (/Users/<username>) to it.
+    """
+    mlflow_experiment_name = config.loggers.mlflow.experiment_name
+    if mlflow_experiment_name.startswith('/'):
+        return mlflow_experiment_name
+    else:
+        from databricks.sdk import WorkspaceClient
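+        # e.g. 'my_experiment' -> '/Users/<user_name>/my_experiment'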
+        return f'/Users/{WorkspaceClient().current_user.me().user_name}/{mlflow_experiment_name}'
+
+def _setup_mlflow(config: Any):
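+    # gloo is a CPU backend; the process group is used below to broadcast the
+    # mlflow experiment/run ids across ranks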
    dist.init_process_group(backend='gloo')
-    # Create a new MLFlow run to be used for the entire run
    mlflow.set_tracking_uri('databricks')

-    # get mlflow experiment
-    experiment = mlflow.get_experiment_by_name(MLFLOW_EXPERIMENT_NAME)
-    if experiment is None:
-        experiment_id = mlflow.create_experiment(MLFLOW_EXPERIMENT_NAME)
-    else:
-        experiment_id = experiment.experiment_id
-    mlflow.set_experiment(experiment_id=experiment_id)
+    # mlflow experiment name needs to be an absolute path for databricks mlflow.
+    mlflow_experiment_name = _get_valid_mlflow_experiment_name(config)
+    setattr(config.loggers.mlflow, 'experiment_name', mlflow_experiment_name)
+    # COMPOSER_RUN_NAME is set for interactive mode as well.
+    mlflow_run_name = os.environ['COMPOSER_RUN_NAME']
+    setattr(config.loggers.mlflow, 'run_name', mlflow_run_name)

+    # get the mlflow experiment if it exists, otherwise create it, then
+    # broadcast the experiment id to all ranks.
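+    # NOTE: only rank 0 talks to the tracking server, which avoids ranks
+    # racing to create duplicate experiments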
+    experiment_id = None
+    if composer_dist.get_global_rank() == 0:
+        experiment = mlflow.get_experiment_by_name(mlflow_experiment_name)
+        if experiment is None:
+            experiment_id = mlflow.create_experiment(mlflow_experiment_name)
+        else:
+            experiment_id = experiment.experiment_id
+    experiment_id_broadcast_list = [experiment_id]
+    composer_dist.broadcast_object_list(experiment_id_broadcast_list, src=0)
+    experiment_id = experiment_id_broadcast_list[0]

+    mlflow.set_experiment(experiment_id=experiment_id)

+    # get the mlflow run if it exists and we are autoresuming, otherwise
+    # create it, then broadcast the run id to all ranks.
    run_id = None
    if composer_dist.get_global_rank() == 0:
-        # find a preexisting run if it exists
        existing_runs = mlflow.search_runs(
            experiment_ids=[experiment_id],
-            filter_string=f'tags.run_name = "{MLFLOW_RUN_NAME}"',
+            filter_string=f'tags.run_name = "{mlflow_run_name}"',
            output_format='list',
        ) if config.autoresume else []
        if len(existing_runs) > 0:
            run_id = existing_runs[0].info.run_id
            print(f'Resuming mlflow run with run id: {run_id}')
        else:
-            run_id = mlflow.start_run(run_name=MLFLOW_RUN_NAME).info.run_id
+            run_id = mlflow.start_run(run_name=mlflow_run_name).info.run_id
            print(f'Creating new mlflow run with run id: {run_id}')
-    broadcast_list = [run_id]
-
-    composer_dist.broadcast_object_list(broadcast_list, src=0)
+    run_id_broadcast_list = [run_id]
+    composer_dist.broadcast_object_list(run_id_broadcast_list, src=0)
+    run_id = run_id_broadcast_list[0]

    # set all the right environment variables
-    run_id = broadcast_list[0]
    assert run_id is not None and experiment_id is not None, "Run ID and experiment ID must be set"
    os.environ['MLFLOW_RUN_ID'] = run_id
    os.environ['MLFLOW_EXPERIMENT_ID'] = experiment_id
@@ -812,8 +834,7 @@ def _artifact_exists(artifact_path: str) -> bool:
    # If we got here, the path exists (root or found item).
    return True
-
-
+
def _run_single_controller_ppo(
    config: Any,
@@ -830,7 +851,7 @@ def _run_single_controller_ppo(
    # Disable setting CUDA_VISIBLE_DEVICES by ray, we will set it manually
    os.environ['RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES'] = '1'

-    _setup_mlflow()
+    _setup_mlflow(config)
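+    # populates config.loggers.mlflow and exports the MLFLOW_RUN_ID /
+    # MLFLOW_EXPERIMENT_ID environment variables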

    with start_ray_server() as _address:
        # only rank 0 is the master controller