Skip to content

Commit

Permalink
[ML] Add enable_node_public_ip to Compute Instances and Aml Computes (A…
Browse files Browse the repository at this point in the history
…zure#28004)

* Added enable_node_public_ip to compute

* Formatting changes

* Added enable_node_public_ip to schemas

* Added unit tests for no public ip

* Updated changelog

* Updated pylint rule
  • Loading branch information
nthandeMS authored Dec 20, 2022
1 parent 6810ee4 commit 7c3fad7
Show file tree
Hide file tree
Showing 9 changed files with 176 additions and 38 deletions.
1 change: 1 addition & 0 deletions sdk/ml/azure-ai-ml/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

### Features Added
- Change print behavior of entity classes to show object yaml in notebooks, can be configured on in other contexts.
- Added property to enable/disable public ip addresses to Compute Instances and AML Computes.

### Bugs Fixed
- Fixed issue with date-time format for utc_time_created field when creating models.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,18 @@ def make(self, data, **kwargs):
class AmlComputeSchema(ComputeSchema):
type = StringTransformedEnum(allowed_values=[ComputeType.AMLCOMPUTE], required=True)
size = fields.Str()
tier = StringTransformedEnum(allowed_values=[ComputeTier.LOWPRIORITY, ComputeTier.DEDICATED])
tier = StringTransformedEnum(
allowed_values=[ComputeTier.LOWPRIORITY, ComputeTier.DEDICATED]
)
min_instances = fields.Int()
max_instances = fields.Int()
idle_time_before_scale_down = fields.Int()
ssh_public_access_enabled = fields.Bool()
ssh_settings = NestedField(AmlComputeSshSettingsSchema)
network_settings = NestedField(NetworkSettingsSchema)
identity = NestedField(IdentitySchema)
enable_node_public_ip = fields.Bool(
metadata={
"description": "Enable or disable node public IP address provisioning."
}
)
17 changes: 14 additions & 3 deletions sdk/ml/azure-ai-ml/azure/ai/ml/_schema/compute/compute_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,18 +51,29 @@ def make(self, data, **kwargs):


class ComputeInstanceSchema(ComputeSchema):
type = StringTransformedEnum(allowed_values=[ComputeType.COMPUTEINSTANCE], required=True)
type = StringTransformedEnum(
allowed_values=[ComputeType.COMPUTEINSTANCE], required=True
)
size = fields.Str()
network_settings = NestedField(NetworkSettingsSchema)
create_on_behalf_of = NestedField(CreateOnBehalfOfSchema)
ssh_settings = NestedField(ComputeInstanceSshSettingsSchema)
ssh_public_access_enabled = fields.Bool(dump_default=None)
state = fields.Str(dump_only=True)
last_operation = fields.Dict(keys=fields.Str(), values=fields.Str(), dump_only=True)
services = fields.List(fields.Dict(keys=fields.Str(), values=fields.Str()), dump_only=True)
services = fields.List(
fields.Dict(keys=fields.Str(), values=fields.Str()), dump_only=True
)
schedules = NestedField(ComputeSchedulesSchema)
identity = ExperimentalField(NestedField(IdentitySchema))
idle_time_before_shutdown = ExperimentalField(fields.Str())
idle_time_before_shutdown_minutes = ExperimentalField(fields.Int())
setup_scripts = ExperimentalField(NestedField(SetupScriptsSchema))
os_image_metadata = ExperimentalField(NestedField(OsImageMetadataSchema, dump_only=True))
os_image_metadata = ExperimentalField(
NestedField(OsImageMetadataSchema, dump_only=True)
)
enable_node_public_ip = fields.Bool(
metadata={
"description": "Enable or disable node public IP address provisioning."
}
)
76 changes: 59 additions & 17 deletions sdk/ml/azure-ai-ml/azure/ai/ml/entities/_compute/aml_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

# pylint: disable=protected-access
# pylint: disable=protected-access,too-many-instance-attributes

from typing import Dict, Optional

from azure.ai.ml._restclient.v2022_10_01_preview.models import AmlCompute as AmlComputeRest
from azure.ai.ml._restclient.v2022_10_01_preview.models import (
AmlCompute as AmlComputeRest,
)
from azure.ai.ml._restclient.v2022_10_01_preview.models import (
AmlComputeProperties,
ComputeResource,
Expand All @@ -16,7 +18,11 @@
)
from azure.ai.ml._schema._utils.utils import get_subnet_str
from azure.ai.ml._schema.compute.aml_compute import AmlComputeSchema
from azure.ai.ml._utils.utils import camel_to_snake, snake_to_pascal, to_iso_duration_format
from azure.ai.ml._utils.utils import (
camel_to_snake,
snake_to_pascal,
to_iso_duration_format,
)
from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, TYPE
from azure.ai.ml.constants._compute import ComputeDefaults, ComputeType
from azure.ai.ml.entities._credentials import IdentityConfiguration
Expand Down Expand Up @@ -57,7 +63,9 @@ def _to_user_account_credentials(self) -> UserAccountCredentials:
)

@classmethod
def _from_user_account_credentials(cls, credentials: UserAccountCredentials) -> "AmlComputeSshSettings":
def _from_user_account_credentials(
cls, credentials: UserAccountCredentials
) -> "AmlComputeSshSettings":
return cls(
admin_username=credentials.admin_user_name,
admin_password=credentials.admin_user_password,
Expand Down Expand Up @@ -95,6 +103,11 @@ class AmlCompute(Compute):
else is open all public nodes. It can be default only during cluster creation time, after
creation it will be either True or False. Possible values include: True, False, None. Default value: None.
:type ssh_public_access_enabled: bool, optional
:param enable_node_public_ip: Enable or disable node public IP address provisioning. Possible values are:
True - Indicates that the compute nodes will have public IPs provisioned.
False - Indicates that the compute nodes will have a private endpoint and no public IPs.
Default Value: True.
:type enable_node_public_ip: Optional[bool], optional
"""

def __init__(
Expand All @@ -111,6 +124,7 @@ def __init__(
idle_time_before_scale_down: Optional[int] = None,
identity: Optional[IdentityConfiguration] = None,
tier: Optional[str] = None,
enable_node_public_ip: Optional[bool] = True,
**kwargs,
):
kwargs[TYPE] = ComputeType.AMLCOMPUTE
Expand All @@ -129,20 +143,25 @@ def __init__(
self.ssh_settings = ssh_settings
self.network_settings = network_settings
self.tier = tier
self.enable_node_public_ip = enable_node_public_ip
self.subnet = None

@classmethod
def _load_from_rest(cls, rest_obj: ComputeResource) -> "AmlCompute":
prop = rest_obj.properties

network_settings = None
if prop.properties.subnet or (prop.properties.enable_node_public_ip is not None):
if prop.properties.subnet or (
prop.properties.enable_node_public_ip is not None
):
network_settings = NetworkSettings(
subnet=prop.properties.subnet.id if prop.properties.subnet else None,
)

ssh_settings = (
AmlComputeSshSettings._from_user_account_credentials(prop.properties.user_account_credentials)
AmlComputeSshSettings._from_user_account_credentials(
prop.properties.user_account_credentials
)
if prop.properties.user_account_credentials
else None
)
Expand All @@ -158,16 +177,28 @@ def _load_from_rest(cls, rest_obj: ComputeResource) -> "AmlCompute":
else None,
size=prop.properties.vm_size,
tier=camel_to_snake(prop.properties.vm_priority),
min_instances=prop.properties.scale_settings.min_node_count if prop.properties.scale_settings else None,
max_instances=prop.properties.scale_settings.max_node_count if prop.properties.scale_settings else None,
min_instances=prop.properties.scale_settings.min_node_count
if prop.properties.scale_settings
else None,
max_instances=prop.properties.scale_settings.max_node_count
if prop.properties.scale_settings
else None,
network_settings=network_settings or None,
ssh_settings=ssh_settings,
ssh_public_access_enabled=(prop.properties.remote_login_port_public_access == "Enabled"),
ssh_public_access_enabled=(
prop.properties.remote_login_port_public_access == "Enabled"
),
idle_time_before_scale_down=prop.properties.scale_settings.node_idle_time_before_scale_down.total_seconds()
if prop.properties.scale_settings and prop.properties.scale_settings.node_idle_time_before_scale_down
if prop.properties.scale_settings
and prop.properties.scale_settings.node_idle_time_before_scale_down
else None,
identity=IdentityConfiguration._from_compute_rest_object(rest_obj.identity)
if rest_obj.identity
else None,
identity=IdentityConfiguration._from_compute_rest_object(rest_obj.identity) if rest_obj.identity else None,
created_on=prop.additional_properties.get("createdOn", None),
enable_node_public_ip=prop.properties.enable_node_public_ip
if prop.properties.enable_node_public_ip
else True,
)
return response

Expand Down Expand Up @@ -199,27 +230,38 @@ def _to_rest_object(self) -> ComputeResource:
scale_settings = ScaleSettings(
max_node_count=self.max_instances,
min_node_count=self.min_instances,
node_idle_time_before_scale_down=to_iso_duration_format(int(self.idle_time_before_scale_down))
node_idle_time_before_scale_down=to_iso_duration_format(
int(self.idle_time_before_scale_down)
)
if self.idle_time_before_scale_down
else None,
)
remote_login_public_access = "Enabled"
if self.ssh_public_access_enabled is not None:
remote_login_public_access = "Enabled" if self.ssh_public_access_enabled else "Disabled"
remote_login_public_access = (
"Enabled" if self.ssh_public_access_enabled else "Disabled"
)
else:
remote_login_public_access = "NotSpecified"
aml_prop = AmlComputeProperties(
vm_size=self.size if self.size else ComputeDefaults.VMSIZE,
vm_priority=snake_to_pascal(self.tier),
user_account_credentials=self.ssh_settings._to_user_account_credentials() if self.ssh_settings else None,
user_account_credentials=self.ssh_settings._to_user_account_credentials()
if self.ssh_settings
else None,
scale_settings=scale_settings,
subnet=subnet_resource,
remote_login_port_public_access=remote_login_public_access,
enable_node_public_ip=self.enable_node_public_ip,
)

aml_comp = AmlComputeRest(description=self.description, compute_type=self.type, properties=aml_prop)
aml_comp = AmlComputeRest(
description=self.description, compute_type=self.type, properties=aml_prop
)
return ComputeResource(
location=self.location,
properties=aml_comp,
identity=(self.identity._to_compute_rest_object() if self.identity else None),
)
identity=(
self.identity._to_compute_rest_object() if self.identity else None
),
)
Loading

0 comments on commit 7c3fad7

Please sign in to comment.