@@ -34,19 +34,22 @@ class Provider:
3434 __KNOWN = {AWS : [_RegionZone (region = "us-west-1" , zone = "us-west-1a" )]}
3535 __DEFAULT_HEAD = {AWS : "m5.large" }
3636 __DEFAULT_WORKER = {AWS : "m5.large" }
37+ __DEFAULT_IMAGE = {AWS : "ami-0f56279347d2fa43e" }
3738
3839 def __init__ (
3940 self ,
4041 name : str ,
4142 credentials_file : str = None ,
4243 region : str = None ,
4344 zone : str = None ,
45+ image : str = None ,
4446 ):
4547 """
4648 Class that holds all information about particular connection to cluster provider, namely
4749 * provider name (must be one of known ones)
4850 * path to file with credentials (file format is provider-specific); omit to use global provider-default credentials
4951 * region and zone where cluster is to be spawned (optional, would be deduced if omitted)
52+ * image to use (optional, would use default for provider if omitted)
5053 """
5154
5255 if name not in self .__KNOWN :
@@ -77,6 +80,7 @@ def __init__(
7780 self .credentials_file = (
7881 os .path .abspath (credentials_file ) if credentials_file is not None else None
7982 )
83+ self .image = image or self .__DEFAULT_IMAGE [name ]
8084
8185 @property
8286 def default_head_type (self ):
@@ -197,6 +201,7 @@ def create(
197201 credentials : str = None ,
198202 region : str = None ,
199203 zone : str = None ,
204+ image : str = None ,
200205 project_name : str = None ,
201206 cluster_name : str = "modin-cluster" ,
202207 workers : int = 4 ,
@@ -222,6 +227,9 @@ def create(
222227 If omitted a default for given provider will be taken.
223228 zone : str, optional
224229 Availability zone (part of region) where to spawn the cluster.
230+ If omitted a default for given provider and region will be taken.
231+ image: str, optional
232+ Image to use for spawning head and worker nodes.
225233 If omitted a default for given provider will be taken.
226234 project_name : str, optional
227235 Project name to assign to the cluster in cloud, for easier manual tracking.
@@ -247,12 +255,16 @@ def create(
247255 """
248256 if not isinstance (provider , Provider ):
249257 provider = Provider (
250- name = provider , credentials_file = credentials , region = region , zone = zone
258+ name = provider ,
259+ credentials_file = credentials ,
260+ region = region ,
261+ zone = zone ,
262+ image = image ,
251263 )
252264 else :
253- if any (p is not None for p in (credentials , region , zone )):
265+ if any (p is not None for p in (credentials , region , zone , image )):
254266 warnings .warn (
255- "Ignoring credentials, region and zone parameters because provider is specified as Provider descriptor, not as name" ,
267+ "Ignoring credentials, region, zone and image parameters because provider is specified as Provider descriptor, not as name" ,
256268 UserWarning ,
257269 )
258270 if __spawner__ == "rayscale" :
0 commit comments