
Commit f7b6e86

Author: harborn

[common] add device option for TorchConfig (intel#126)

* add device option for TorchConfig
* update
* update
* update
1 parent 5157502 commit f7b6e86

File tree

2 files changed: +18 -12 lines changed

common/torch_config.py
finetune/finetune.py

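In brief: common/torch_config.py adds an optional device field to TorchConfig and has the CCL backend export ACCELERATE_TORCH_DEVICE on every worker (previously it deleted that variable), while finetune/finetune.py lowercases the configured device and forwards it when constructing the TorchConfig. A minimal sketch of the new call site, based only on the diff below; the values shown are illustrative assumptions, not taken from the repo:

import common  # repo-local package, as imported by finetune/finetune.py

device = "cpu"  # in the real code: config["Training"]["device"].lower()
torch_config = common.TorchConfig(
    backend="ccl" if device == "cpu" else None,  # CCL only for CPU runs
    device=device,                               # new field added by this commit
)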

common/torch_config.py

Lines changed: 8 additions & 3 deletions

@@ -2,6 +2,7 @@
 from ray.train.torch.config import TorchConfig as RayTorchConfig
 from ray.train._internal.worker_group import WorkerGroup
 from dataclasses import dataclass
+from typing import Optional
 import os
 import sys
 # The package importlib_metadata is in a different place, depending on the Python version.
@@ -13,9 +14,11 @@

 @dataclass
 class TorchConfig(RayTorchConfig):
+    device: Optional[str] = None

     @property
     def backend_cls(self):
+        EnableCCLBackend.device = self.device
         return EnableCCLBackend


@@ -41,11 +44,13 @@ def libs_import():
     ) from ccl_not_exist


-def _del_torch_distributed_env_vars():
-    del os.environ["ACCELERATE_TORCH_DEVICE"]
+def _set_torch_distributed_env_vars(device):
+    if device is not None:
+        os.environ["ACCELERATE_TORCH_DEVICE"] = device


 class EnableCCLBackend(_TorchBackend):
+    device: Optional[str] = None

     def on_start(self, worker_group: WorkerGroup, backend_config: RayTorchConfig):
         for i in range(len(worker_group)):
@@ -54,4 +59,4 @@ def on_start(self, worker_group: WorkerGroup, backend_config: RayTorchConfig):

     def on_training_start(self, worker_group: WorkerGroup, backend_config: RayTorchConfig):
         super().on_training_start(worker_group, backend_config)
-        worker_group.execute(_del_torch_distributed_env_vars)
+        worker_group.execute(_set_torch_distributed_env_vars, self.device)
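How the device value reaches each worker: backend_cls is evaluated on the driver and stashes self.device on the EnableCCLBackend class; when training starts, the backend executes _set_torch_distributed_env_vars(self.device) on every worker, which exports ACCELERATE_TORCH_DEVICE only when a device was actually configured (the old code unconditionally deleted it). A standalone sketch of the same propagation pattern, with the Ray worker-group call replaced by a plain function call and the classes renamed to mark them as illustrations:

import os
from dataclasses import dataclass
from typing import Optional


def _set_torch_distributed_env_vars(device):
    # Only export the variable when a device was explicitly configured.
    if device is not None:
        os.environ["ACCELERATE_TORCH_DEVICE"] = device


class EnableCCLBackendSketch:
    # Class attribute written by TorchConfigSketch.backend_cls below.
    device: Optional[str] = None


@dataclass
class TorchConfigSketch:
    device: Optional[str] = None

    @property
    def backend_cls(self):
        # Runs on the driver: stash the device on the backend class so the
        # backend instance created later can still read it.
        EnableCCLBackendSketch.device = self.device
        return EnableCCLBackendSketch


cfg = TorchConfigSketch(device="cpu")
backend = cfg.backend_cls()                      # property sets the class attribute
_set_torch_distributed_env_vars(backend.device)  # in the real code, run on each worker
print(os.environ["ACCELERATE_TORCH_DEVICE"])     # -> cpu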

finetune/finetune.py

Lines changed: 10 additions & 9 deletions

@@ -76,15 +76,15 @@ def train_func(config: Dict[str, Any]):
         } if config["General"].get("checkpoint_dir") else None
     })

-    try :
+    try:
         common.logger.info(f"trainer prepare start")
         trainer.prepare(model, tokenizer, datasets, optimizer, accelerator)
     except Exception as e:
         common.logger.critical(e, exc_info=True)
         exit(1)
     common.logger.info(f"trainer prepare finish")

-    try :
+    try:
         common.logger.info(f"train start")
         trainer.train()
     except Exception as e:
@@ -101,12 +101,12 @@ def main(external_config = None):
     num_training_workers = config["Training"].get("num_training_workers")
     resources_per_worker = config["Training"].get("resources_per_worker")

-    device = config["Training"]["device"]
+    device = config["Training"]["device"].lower()
     if not ray.is_initialized():
         runtime_env = {
             "env_vars": {
                 "OMP_NUM_THREADS": str(resources_per_worker["CPU"]),
-                "ACCELERATE_USE_CPU": "True" if device == "CPU" else "False",
+                "ACCELERATE_USE_CPU": "True" if device == "cpu" else "False",
                 "ACCELERATE_USE_IPEX": "False",
                 "ACCELERATE_MIXED_PRECISION": "no",
                 "CCL_WORKER_COUNT": "1",
@@ -122,14 +122,14 @@ def main(external_config = None):
         num_workers = num_training_workers,
         resources_per_worker = resources_per_worker,
         placement_strategy = "SPREAD",
-        use_gpu = False if device == "CPU" else True
+        use_gpu = False if device == "cpu" else True
     )

     if config.get("torch_config", None) is None:
-        torch_config = common.TorchConfig(backend = "ccl" if device == "CPU" else None)
+        torch_config = common.TorchConfig(backend = "ccl" if device == "cpu" else None, device=device)
     else:
         customer_torch_config = config.get("torch_config")
-        torch_config = common.TorchConfig(**customer_torch_config)
+        torch_config = common.TorchConfig(**customer_torch_config, device=device)

     if config.get("failure_config", None) is None:
         failure_config = FailureConfig()
@@ -149,10 +149,11 @@ def main(external_config = None):
         train_func,
         train_loop_config=config,
         scaling_config=scaling_config,
-        torch_config = torch_config,
-        run_config = run_config
+        torch_config=torch_config,
+        run_config=run_config
     )
     results = trainer.fit()
+
     return results

 if __name__ == "__main__":
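Because the device string is now lowercased once at the top of main(), every later comparison is against "cpu", and the same normalized value drives the Accelerate environment variables, use_gpu, the CCL backend selection, and the new TorchConfig(device=...) argument. A self-contained sketch of that decision logic; the config keys mirror those read in the diff, the values are illustrative assumptions:

config = {"Training": {"device": "CPU", "resources_per_worker": {"CPU": 32}}}

device = config["Training"]["device"].lower()    # "CPU" -> "cpu"

env_vars = {
    "OMP_NUM_THREADS": str(config["Training"]["resources_per_worker"]["CPU"]),
    "ACCELERATE_USE_CPU": "True" if device == "cpu" else "False",
    "ACCELERATE_USE_IPEX": "False",
    "ACCELERATE_MIXED_PRECISION": "no",
}
use_gpu = False if device == "cpu" else True     # feeds ScalingConfig(use_gpu=...)
backend = "ccl" if device == "cpu" else None     # feeds TorchConfig(backend=...)

print(device, use_gpu, backend, env_vars["ACCELERATE_USE_CPU"])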
