Skip to content

Commit

Permalink
[Onprem] Support for Different Type of GPUs + Small Bugfix (#1356)
Browse files Browse the repository at this point in the history
* Ok

* Great suggestion from Zhanghao

* fix
  • Loading branch information
michaelzhiluo authored Nov 4, 2022
1 parent 75ab3de commit fd6c335
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 7 deletions.
8 changes: 5 additions & 3 deletions sky/backends/backend_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,9 +159,6 @@ def fill_template(template_name: str,
output_prefix)
output_path = os.path.abspath(output_path)

# Add yaml file path to the template variables.
variables['sky_ray_yaml_remote_path'] = SKY_RAY_YAML_REMOTE_PATH
variables['sky_ray_yaml_local_path'] = output_path
# Write out yaml config.
template = jinja2.Template(template)
content = template.render(**variables)
Expand Down Expand Up @@ -786,6 +783,11 @@ def write_cluster_config(
# Sky remote utils.
'sky_remote_path': SKY_REMOTE_PATH,
'sky_local_path': str(local_wheel_path),
# Add yaml file path to the template variables.
'sky_ray_yaml_remote_path': SKY_RAY_YAML_REMOTE_PATH,
'sky_ray_yaml_local_path':
tmp_yaml_path
if not isinstance(cloud, clouds.Local) else yaml_path,
'sky_version': str(version.parse(sky.__version__)),
'sky_wheel_hash': wheel_hash,
# Local IP handling (optional).
Expand Down
13 changes: 10 additions & 3 deletions sky/backends/onprem_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,11 @@ def get_local_cluster_accelerators(
'T4',
'P4',
'K80',
'A100',]
'A100',
'1080',
'2080',
'A5000',
'A6000']
accelerators_dict = {}
for acc in all_accelerators:
output_str = os.popen(f'lspci | grep \'{acc}\'').read()
Expand Down Expand Up @@ -358,9 +362,10 @@ def _stop_ray_workers(runner: command_runner.SSHCommandRunner):

# Launching Ray on the head node.
head_resources = json.dumps(custom_resources[0], separators=(',', ':'))
head_gpu_count = sum(list(custom_resources[0].values()))
head_cmd = ('ray start --head --port=6379 '
'--object-manager-port=8076 --dashboard-port 8265 '
f'--resources={head_resources!r}')
f'--resources={head_resources!r} --num-gpus={head_gpu_count}')

with console.status('[bold cyan]Launching ray cluster on head'):
backend_utils.run_command_and_handle_ssh_failure(
Expand Down Expand Up @@ -399,9 +404,11 @@ def _start_ray_workers(

worker_resources = json.dumps(custom_resources[idx + 1],
separators=(',', ':'))
worker_gpu_count = sum(list(custom_resources[idx + 1].values()))
worker_cmd = (f'ray start --address={head_ip}:6379 '
'--object-manager-port=8076 --dashboard-port 8265 '
f'--resources={worker_resources!r}')
f'--resources={worker_resources!r} '
f'--num-gpus={worker_gpu_count}')
backend_utils.run_command_and_handle_ssh_failure(
runner,
worker_cmd,
Expand Down
7 changes: 6 additions & 1 deletion sky/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,12 @@ def _set_accelerators(
except ValueError:
with ux_utils.print_exception_no_traceback():
raise ValueError(parse_error) from None
assert len(accelerators) == 1, accelerators

# Ignore check for the local cloud case.
# It is possible the accelerators dict can contain multiple
# types of accelerators for some on-prem clusters.
if not isinstance(self._cloud, clouds.Local):
assert len(accelerators) == 1, accelerators

# Canonicalize the accelerator names.
accelerators = {
Expand Down

0 comments on commit fd6c335

Please sign in to comment.