Skip to content

Commit bc06a65

Browse files
committed
FIX-#1807: Fix passing zone in rayscale, add ability to override image
Signed-off-by: Vasilij Litvinov <vasilij.n.litvinov@intel.com>
1 parent dcff8da commit bc06a65

File tree

2 files changed

+21
-4
lines changed

2 files changed

+21
-4
lines changed

modin/experimental/cloud/cluster.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,19 +34,22 @@ class Provider:
3434
__KNOWN = {AWS: [_RegionZone(region="us-west-1", zone="us-west-1a")]}
3535
__DEFAULT_HEAD = {AWS: "m5.large"}
3636
__DEFAULT_WORKER = {AWS: "m5.large"}
37+
__DEFAULT_IMAGE = {AWS: "ami-0f56279347d2fa43e"}
3738

3839
def __init__(
3940
self,
4041
name: str,
4142
credentials_file: str = None,
4243
region: str = None,
4344
zone: str = None,
45+
image: str = None,
4446
):
4547
"""
4648
Class that holds all information about particular connection to cluster provider, namely
4749
* provider name (must be one of known ones)
4850
* path to file with credentials (file format is provider-specific); omit to use global provider-default credentials
4951
* region and zone where cluster is to be spawned (optional, would be deduced if omitted)
52+
* image to use (optional, would use default for provider if omitted)
5053
"""
5154

5255
if name not in self.__KNOWN:
@@ -77,6 +80,7 @@ def __init__(
7780
self.credentials_file = (
7881
os.path.abspath(credentials_file) if credentials_file is not None else None
7982
)
83+
self.image = image or self.__DEFAULT_IMAGE[name]
8084

8185
@property
8286
def default_head_type(self):
@@ -197,6 +201,7 @@ def create(
197201
credentials: str = None,
198202
region: str = None,
199203
zone: str = None,
204+
image: str = None,
200205
project_name: str = None,
201206
cluster_name: str = "modin-cluster",
202207
workers: int = 4,
@@ -222,6 +227,9 @@ def create(
222227
If omitted a default for given provider will be taken.
223228
zone : str, optional
224229
Availability zone (part of region) where to spawn the cluster.
230+
If omitted a default for given provider and region will be taken.
231+
image: str, optional
232+
Image to use for spawning head and worker nodes.
225233
If omitted a default for given provider will be taken.
226234
project_name : str, optional
227235
Project name to assign to the cluster in cloud, for easier manual tracking.
@@ -247,12 +255,16 @@ def create(
247255
"""
248256
if not isinstance(provider, Provider):
249257
provider = Provider(
250-
name=provider, credentials_file=credentials, region=region, zone=zone
258+
name=provider,
259+
credentials_file=credentials,
260+
region=region,
261+
zone=zone,
262+
image=image,
251263
)
252264
else:
253-
if any(p is not None for p in (credentials, region, zone)):
265+
if any(p is not None for p in (credentials, region, zone, image)):
254266
warnings.warn(
255-
"Ignoring credentials, region and zone parameters because provider is specified as Provider descriptor, not as name",
267+
"Ignoring credentials, region, zone and image parameters because provider is specified as Provider descriptor, not as name",
256268
UserWarning,
257269
)
258270
if __spawner__ == "rayscale":

modin/experimental/cloud/rayscale.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ class RayCluster(BaseCluster):
5959
os.path.abspath(os.path.dirname(__file__)), "ray-autoscaler.yml"
6060
)
6161
__instance_key = {Provider.AWS: "InstanceType"}
62+
__image_key = {Provider.AWS: "ImageId"}
6263
__credentials_env = {Provider.AWS: "AWS_SHARED_CREDENTIALS_FILE"}
6364

6465
def __init__(self, *a, **kw):
@@ -109,7 +110,7 @@ def __make_config(self):
109110
if self.provider.region:
110111
config["provider"]["region"] = self.provider.region
111112
if self.provider.zone:
112-
config["provider"]["zone"] = self.provider.zone
113+
config["provider"]["availability_zone"] = self.provider.zone
113114

114115
# connection details
115116
config["auth"]["ssh_user"] = "ubuntu"
@@ -120,10 +121,14 @@ def __make_config(self):
120121
# instance types
121122
try:
122123
instance_key = self.__instance_key[self.provider.name]
124+
image_key = self.__image_key[self.provider.name]
123125
except KeyError:
124126
raise ValueError(f"Unsupported provider: {self.provider.name}")
127+
125128
config["head_node"][instance_key] = self.head_node_type
129+
config["head_node"][image_key] = self.provider.image
126130
config["worker_nodes"][instance_key] = self.worker_node_type
131+
config["worker_nodes"][image_key] = self.provider.image
127132

128133
return _bootstrap_config(config)
129134

0 commit comments

Comments
 (0)