@@ -32,7 +32,7 @@ def launch(
32
32
ssh_config_file : str | os .PathLike | None = None ,
33
33
backend : Literal ["mpi" , "gloo" , "nccl" , "ucc" ] | None = None ,
34
34
log_dir : str = "./logs" ,
35
- clone_env_vars : list [str ] = ["PYTHON*" , "CUDA*" ],
35
+ clone_env_vars : list [str ] = ["PYTHON*" , "CUDA*" , "TORCH*" , "PYTORCH*" , "NCCL*" ],
36
36
env_file : str | os .PathLike | None = None ,
37
37
):
38
38
if not dist .is_available ():
@@ -77,14 +77,14 @@ def launch(
77
77
78
78
log_dir = os .path .abspath (log_dir )
79
79
80
+ explicit_env_vars = ["PATH" , "LD_LIBRARY" , "LIBRARY_PATH" ]
80
81
env_export_string = " " .join (
81
82
f'{ k } ="{ v } "'
82
83
for k , v in os .environ .items ()
83
- if any (fnmatch .fnmatch (e , k ) for e in clone_env_vars )
84
+ if any (fnmatch .fnmatch (k , e ) for e in clone_env_vars + explicit_env_vars )
84
85
)
85
86
if env_export_string != "" :
86
87
env_export_string = f"export { env_export_string } && "
87
-
88
88
env_file_string = f"source { env_file } && " if env_file is not None else ""
89
89
90
90
# start agents on each node
0 commit comments