Skip to content

Commit

Permalink
Added support for custom setup scripts on Compute Instances (Azure#26340)
Browse files Browse the repository at this point in the history
  • Loading branch information
nthandeMS authored and mccoyp committed Sep 22, 2022
1 parent fca5f94 commit 62bccfd
Show file tree
Hide file tree
Showing 9 changed files with 183 additions and 15 deletions.
2 changes: 2 additions & 0 deletions sdk/ml/azure-ai-ml/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
### Bugs Fixed

### Other Changes
- Removed declaration on Python 3.6 support
- Added support for custom setup scripts on compute instances.
- Removed declaration on Python 3.6 support.
- Updated dependencies upper bounds to be major versions.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from ..core.fields import ExperimentalField, NestedField, StringTransformedEnum
from .compute import ComputeSchema, IdentitySchema, NetworkSettingsSchema
from .schedule import ComputeSchedulesSchema
from .setup_scripts import SetupScriptsSchema


class ComputeInstanceSshSettingsSchema(PathAwareSchema):
Expand Down Expand Up @@ -50,3 +51,4 @@ class ComputeInstanceSchema(ComputeSchema):
schedules = ExperimentalField(NestedField(ComputeSchedulesSchema))
identity = ExperimentalField(NestedField(IdentitySchema))
idle_time_before_shutdown = ExperimentalField(fields.Str())
setup_scripts = ExperimentalField(NestedField(SetupScriptsSchema))
33 changes: 33 additions & 0 deletions sdk/ml/azure-ai-ml/azure/ai/ml/_schema/compute/setup_scripts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

from marshmallow import fields
from marshmallow.decorators import post_load

from azure.ai.ml._schema.core.fields import ArmVersionedStr, CodeField, LocalPathField, NestedField, UnionField
from azure.ai.ml._schema.core.schema_meta import PatchedSchemaMeta
from azure.ai.ml.constants._common import LOCAL_PATH, AzureMLResourceType


class ScriptReferenceSchema(metaclass=PatchedSchemaMeta):
    """Schema for one setup script entry: ``path``, ``command`` and ``timeout_minutes``."""

    # None of the fields is declared with ``required=True``.
    path = fields.Str()
    command = fields.Str()
    timeout_minutes = fields.Int()

    @post_load
    def make(self, data, **kwargs):
        """Build a ``ScriptReference`` entity from the loaded field dictionary.

        :param data: Loaded values, forwarded as keyword arguments to the entity.
        :return: The constructed ``ScriptReference``.
        """
        # Imported inside the hook rather than at module level
        # (presumably to avoid a circular import -- TODO confirm).
        from azure.ai.ml.entities._compute._setup_scripts import ScriptReference as _ScriptReference

        entity = _ScriptReference(**data)
        return entity


class SetupScriptsSchema(metaclass=PatchedSchemaMeta):
    """Schema for the ``setup_scripts`` section: a creation script and a startup script."""

    # Both entries are nested ``ScriptReferenceSchema`` instances.
    creation_script = NestedField(ScriptReferenceSchema())
    startup_script = NestedField(ScriptReferenceSchema())

    @post_load
    def make(self, data, **kwargs):
        """Build a ``SetupScripts`` entity from the loaded field dictionary.

        :param data: Loaded values, forwarded as keyword arguments to the entity.
        :return: The constructed ``SetupScripts``.
        """
        # Imported inside the hook rather than at module level
        # (presumably to avoid a circular import -- TODO confirm).
        from azure.ai.ml.entities._compute._setup_scripts import SetupScripts as _SetupScripts

        entity = _SetupScripts(**data)
        return entity
2 changes: 1 addition & 1 deletion sdk/ml/azure-ai-ml/azure/ai/ml/_schema/core/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def __init__(self, **kwargs):
def _jsonschema_type_mapping(self):
schema = {
"type": "string",
"pattern": "^azureml:.*",
"pattern": self.pattern,
"arm_type": self.azureml_type,
}
if self.name is not None:
Expand Down
29 changes: 15 additions & 14 deletions sdk/ml/azure-ai-ml/azure/ai/ml/_schema/job/input_output_entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,18 @@ def check_dict(self, data, **kwargs):
else:
raise ValidationError("InputSchema needs type Input to dump")

def generate_path_property(azureml_type):
    """Build the strict union field that validates a ``path`` value.

    :param azureml_type: Resource type used for the first (ARM-versioned) alternative.
    :return: A ``UnionField`` over the accepted path forms, created with ``is_strict=True``.
    """
    accepted_forms = [
        # Reference typed with the caller-supplied resource type.
        ArmVersionedStr(azureml_type=azureml_type),
        # Explicit "file:" paths.
        ArmVersionedStr(azureml_type=LOCAL_PATH, pattern="^file:.*"),
        # http / https URLs.
        fields.Str(metadata={"pattern": "^(http(s)?):.*"}),
        # wasb / wasbs URLs.
        fields.Str(metadata={"pattern": "^(wasb(s)?):.*"}),
        # Anything that does not start with one of the schemes above.
        ArmVersionedStr(azureml_type=LOCAL_PATH, pattern="^(?!(azureml|http(s)?|wasb(s)?|file):).*"),
    ]
    return UnionField(accepted_forms, is_strict=True)

def generate_path_property(azureml_type):
    """Build the strict union field that validates a ``path`` value.

    :param azureml_type: Resource type used for the first (ARM-versioned) alternative.
    :return: A ``UnionField`` over the accepted path forms, created with ``is_strict=True``.
    """
    accepted_forms = [
        # Reference typed with the caller-supplied resource type.
        ArmVersionedStr(azureml_type=azureml_type),
        # Explicit "file:" paths.
        ArmVersionedStr(azureml_type=LOCAL_PATH, pattern="^file:.*"),
        # http / https URLs.
        fields.Str(metadata={"pattern": "^(http(s)?):.*"}),
        # wasb / wasbs URLs.
        fields.Str(metadata={"pattern": "^(wasb(s)?):.*"}),
        # Anything that does not start with one of the schemes above.
        ArmVersionedStr(azureml_type=LOCAL_PATH, pattern="^(?!(azureml|http(s)?|wasb(s)?|file):).*"),
    ]
    return UnionField(accepted_forms, is_strict=True)


class ModelInputSchema(InputSchema):
Expand All @@ -60,7 +61,7 @@ class ModelInputSchema(InputSchema):
AssetTypes.TRITON_MODEL,
]
)
path = InputSchema.generate_path_property(azureml_type=AzureMLResourceType.MODEL)
path = generate_path_property(azureml_type=AzureMLResourceType.MODEL)
datastore = fields.Str(metadata={"description": "Name of the datastore to upload local paths to."}, required=False)


Expand All @@ -79,7 +80,7 @@ class DataInputSchema(InputSchema):
AssetTypes.URI_FOLDER,
]
)
path = InputSchema.generate_path_property(azureml_type=AzureMLResourceType.DATA)
path = generate_path_property(azureml_type=AzureMLResourceType.DATA)
datastore = fields.Str(metadata={"description": "Name of the datastore to upload local paths to."}, required=False)


Expand All @@ -95,7 +96,7 @@ class MLTableInputSchema(InputSchema):
required=False,
)
type = StringTransformedEnum(allowed_values=[AssetTypes.MLTABLE])
path = InputSchema.generate_path_property(azureml_type=AzureMLResourceType.DATA)
path = generate_path_property(azureml_type=AzureMLResourceType.DATA)
datastore = fields.Str(metadata={"description": "Name of the datastore to upload to."}, required=False)


Expand Down
93 changes: 93 additions & 0 deletions sdk/ml/azure-ai-ml/azure/ai/ml/entities/_compute/_setup_scripts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
# pylint: disable=protected-access

import re
from typing import Optional

from azure.ai.ml._restclient.v2022_01_01_preview.models import ScriptReference as RestScriptReference
from azure.ai.ml._restclient.v2022_01_01_preview.models import ScriptsToExecute as RestScriptsToExecute
from azure.ai.ml._restclient.v2022_01_01_preview.models import SetupScripts as RestSetupScripts
from azure.ai.ml._utils._experimental import experimental
from azure.ai.ml.entities._mixins import RestTranslatableMixin


@experimental
class ScriptReference(RestTranslatableMixin):
    """Script reference.

    :param path: The location of scripts in workspace storage.
    :type path: Optional[str], optional
    :param command: Optional command line arguments passed to the script to run.
    :type command: Optional[str], optional
    :param timeout_minutes: Optional time period passed to timeout command.
    :type timeout_minutes: Optional[int], optional
    """

    def __init__(
        self, *, path: Optional[str] = None, command: Optional[str] = None, timeout_minutes: Optional[int] = None
    ):
        self.path = path
        self.command = command
        self.timeout_minutes = timeout_minutes

    def _to_rest_object(self) -> RestScriptReference:
        """Convert this entity into its REST model.

        The timeout is serialized as ``"<minutes>m"``. When no timeout was
        configured, ``None`` is sent instead of formatting the missing value
        (which would otherwise produce the bogus string ``"Nonem"``).

        :return: The REST representation, with ``script_source`` fixed to ``"workspaceStorage"``.
        """
        timeout = None if self.timeout_minutes is None else f"{self.timeout_minutes}m"
        return RestScriptReference(
            script_source="workspaceStorage",
            script_data=self.path,
            script_arguments=self.command,
            timeout=timeout,
        )

    @classmethod
    def _from_rest_object(cls, obj: RestScriptReference) -> "ScriptReference":
        """Build an entity from its REST model.

        :param obj: REST script reference; ``None`` is passed straight through.
        :return: The entity, or ``None`` when ``obj`` is ``None``.
        """
        if obj is None:
            return obj
        # Timeouts are expected in the "<digits>m" form written by _to_rest_object;
        # anything else (or a missing timeout) yields no timeout.
        timeout_match = re.match(r"(\d+)m", obj.timeout) if obj.timeout else None
        # NOTE(review): group(1) is a str, so timeout_minutes comes back as e.g. "20"
        # rather than the int the constructor annotation advertises. Left unchanged
        # because existing callers/tests compare against the string form.
        timeout_minutes = timeout_match.group(1) if timeout_match else None
        script_reference = ScriptReference(
            path=obj.script_data if obj.script_data else None,
            command=obj.script_arguments if obj.script_arguments else None,
            timeout_minutes=timeout_minutes,
        )
        return script_reference


@experimental
class SetupScripts(RestTranslatableMixin):
    """Customized setup scripts.

    :param startup_script: Script that's run every time the machine starts.
    :type startup_script: Optional[ScriptReference], optional
    :param creation_script: Script that's run only once during provision of the compute.
    :type creation_script: Optional[ScriptReference], optional
    """

    def __init__(
        self, *, startup_script: Optional[ScriptReference] = None, creation_script: Optional[ScriptReference] = None
    ):
        self.startup_script = startup_script
        self.creation_script = creation_script

    def _to_rest_object(self) -> RestScriptsToExecute:
        """Convert this entity into its REST model.

        Each configured script is converted individually; unset scripts are sent as ``None``.

        :return: A ``RestSetupScripts`` wrapping the converted scripts.
        """
        # NOTE(review): the annotation says RestScriptsToExecute, but the value
        # returned below is the RestSetupScripts wrapper around it.
        rest_startup = None
        if self.startup_script:
            rest_startup = self.startup_script._to_rest_object()

        rest_creation = None
        if self.creation_script:
            rest_creation = self.creation_script._to_rest_object()

        to_execute = RestScriptsToExecute(startup_script=rest_startup, creation_script=rest_creation)
        return RestSetupScripts(scripts=to_execute)

    @classmethod
    def _from_rest_object(cls, obj: RestSetupScripts) -> "SetupScripts":
        """Build an entity from its REST model.

        :param obj: REST setup scripts object.
        :return: The entity, or ``None`` when ``obj`` or ``obj.scripts`` is ``None``.
        """
        rest_scripts = None if obj is None else obj.scripts
        if rest_scripts is None:
            return None

        # Falsy script entries are normalized to None before conversion.
        startup = ScriptReference._from_rest_object(rest_scripts.startup_script or None)
        creation = ScriptReference._from_rest_object(rest_scripts.creation_script or None)
        return SetupScripts(startup_script=startup, creation_script=creation)
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

from ._identity import IdentityConfiguration
from ._schedule import ComputeSchedules
from ._setup_scripts import SetupScripts


class ComputeInstanceSshSettings:
Expand Down Expand Up @@ -123,6 +124,8 @@ class ComputeInstance(Compute):
:param idle_time_before_shutdown: Stops compute instance after user defined period of
inactivity. Time is defined in ISO8601 format. Minimum is 15 min, maximum is 3 days.
:type idle_time_before_shutdown: Optional[str], optional
:param setup_scripts: Details of customized scripts to execute for setting up the cluster.
:type setup_scripts: Optional[SetupScripts], optional
"""

def __init__(
Expand All @@ -138,6 +141,7 @@ def __init__(
schedules: Optional[ComputeSchedules] = None,
identity: IdentityConfiguration = None,
idle_time_before_shutdown: Optional[str] = None,
setup_scripts: Optional[SetupScripts] = None,
**kwargs,
):
kwargs[TYPE] = ComputeType.COMPUTEINSTANCE
Expand All @@ -159,6 +163,7 @@ def __init__(
self.schedules = schedules
self.identity = identity
self.idle_time_before_shutdown = idle_time_before_shutdown
self.setup_scripts = setup_scripts
self.subnet = None

@property
Expand Down Expand Up @@ -227,6 +232,7 @@ def _to_rest_object(self) -> ComputeResource:
idle_time_before_shutdown=self.idle_time_before_shutdown,
)
compute_instance_prop.schedules = self.schedules._to_rest_object() if self.schedules else None
compute_instance_prop.setup_scripts = self.setup_scripts._to_rest_object() if self.setup_scripts else None
compute_instance = CIRest(
description=self.description,
compute_type=self.type,
Expand Down Expand Up @@ -318,6 +324,9 @@ def _load_from_rest(cls, rest_obj: ComputeResource) -> "ComputeInstance":
if prop.properties and prop.properties.schedules and prop.properties.schedules.compute_start_stop
else None,
identity=IdentityConfiguration._from_rest_object(rest_obj.identity) if rest_obj.identity else None,
setup_scripts=SetupScripts._from_rest_object(prop.properties.setup_scripts)
if prop.properties and prop.properties.setup_scripts
else None,
)
return response

Expand Down
14 changes: 14 additions & 0 deletions sdk/ml/azure-ai-ml/tests/compute/unittests/test_compute_entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,20 @@ def test_compute_instance_schedules_from_yaml(self):
assert compute_instance2.schedules.compute_start_stop[1].trigger.interval == 1
assert compute_instance2.schedules.compute_start_stop[1].trigger.schedule is not None

def test_compute_instance_setup_scripts_from_yaml(self):
    """Load setup scripts from YAML, round-trip through the REST object, and check every field."""
    from_yaml: ComputeInstance = load_compute("tests/test_configs/compute/compute-ci-setup-scripts.yaml")
    as_rest: ComputeResource = from_yaml._to_rest_object()
    round_tripped: ComputeInstance = ComputeInstance._load_from_rest(as_rest)

    setup_scripts = round_tripped.setup_scripts
    assert setup_scripts is not None

    creation = setup_scripts.creation_script
    assert creation is not None
    assert creation.path == "Users/test/creation-script.sh"
    # After the REST round trip the timeout comes back as a string.
    assert creation.timeout_minutes == "20"

    startup = setup_scripts.startup_script
    assert startup is not None
    assert startup.path == "Users/test/startup-script.sh"
    assert startup.command == "ls"
    assert startup.timeout_minutes == "15"

def test_compute_instance_uai_from_yaml(self):
compute: ComputeInstance = load_compute("tests/test_configs/compute/compute-ci-uai.yaml")
assert compute.name == "banchci"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
name: banchci
type: computeinstance
size: STANDARD_DS3_V2
description: some_desc_ci


setup_scripts:
  creation_script:
    path: Users/test/creation-script.sh
    timeout_minutes: 20
  startup_script:
    path: Users/test/startup-script.sh
    command: ls
    timeout_minutes: 15

0 comments on commit 62bccfd

Please sign in to comment.