@@ -76,15 +76,15 @@ def train_func(config: Dict[str, Any]):
7676 } if config ["General" ].get ("checkpoint_dir" ) else None
7777 })
7878
79- try :
79+ try :
8080 common .logger .info (f"trainer prepare start" )
8181 trainer .prepare (model , tokenizer , datasets , optimizer , accelerator )
8282 except Exception as e :
8383 common .logger .critical (e , exc_info = True )
8484 exit (1 )
8585 common .logger .info (f"trainer prepare finish" )
8686
87- try :
87+ try :
8888 common .logger .info (f"train start" )
8989 trainer .train ()
9090 except Exception as e :
@@ -101,12 +101,12 @@ def main(external_config = None):
101101 num_training_workers = config ["Training" ].get ("num_training_workers" )
102102 resources_per_worker = config ["Training" ].get ("resources_per_worker" )
103103
104- device = config ["Training" ]["device" ]
104+ device = config ["Training" ]["device" ]. lower ()
105105 if not ray .is_initialized ():
106106 runtime_env = {
107107 "env_vars" : {
108108 "OMP_NUM_THREADS" : str (resources_per_worker ["CPU" ]),
109- "ACCELERATE_USE_CPU" : "True" if device == "CPU " else "False" ,
109+ "ACCELERATE_USE_CPU" : "True" if device == "cpu " else "False" ,
110110 "ACCELERATE_USE_IPEX" : "False" ,
111111 "ACCELERATE_MIXED_PRECISION" : "no" ,
112112 "CCL_WORKER_COUNT" : "1" ,
@@ -122,14 +122,14 @@ def main(external_config = None):
122122 num_workers = num_training_workers ,
123123 resources_per_worker = resources_per_worker ,
124124 placement_strategy = "SPREAD" ,
125- use_gpu = False if device == "CPU " else True
125+ use_gpu = False if device == "cpu " else True
126126 )
127127
128128 if config .get ("torch_config" , None ) is None :
129- torch_config = common .TorchConfig (backend = "ccl" if device == "CPU " else None )
129+ torch_config = common .TorchConfig (backend = "ccl" if device == "cpu " else None , device = device )
130130 else :
131131 customer_torch_config = config .get ("torch_config" )
132- torch_config = common .TorchConfig (** customer_torch_config )
132+ torch_config = common .TorchConfig (** customer_torch_config , device = device )
133133
134134 if config .get ("failure_config" , None ) is None :
135135 failure_config = FailureConfig ()
@@ -149,10 +149,11 @@ def main(external_config = None):
149149 train_func ,
150150 train_loop_config = config ,
151151 scaling_config = scaling_config ,
152- torch_config = torch_config ,
153- run_config = run_config
152+ torch_config = torch_config ,
153+ run_config = run_config
154154 )
155155 results = trainer .fit ()
156+
156157 return results
157158
158159if __name__ == "__main__" :
0 commit comments