Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions .cloud/.azure/run.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
experiment_name: "<your-experiment-name>" # this will be used as your experiment name in AML and will default to the repo name + branch name
tags:
- your-run-tag-key: "your-run-tag-value" # Not required. Tags to be added to the submitted run.
wait_for_completion: true # Whether to have the action wait for the run to complete.
runconfig_python_file: "your-runconfig-python-file>" # path to the python file that will return an Estimator, Pipeline, AutoMLConfig or ScriptRunConfig object.
runconfig_python_function_name: "<your-runconfig-python-function-name>"
runconfig_yaml_file: "<your-runconfig-yaml-file>"
pipeline_yaml_file: "<your-pipeline-yaml-file>" # If running a pipeline, the yaml file describing the pipeline
pipeline_publish: false # Publish after running
pipeline_name: "<your-pipeline-name>"
pipeline_version: "<your-pipeline-version>"
pipeline_continue_on_step_failure: false # Continue pipeline execution when a step fails.
2 changes: 1 addition & 1 deletion .github/workflows/python.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,5 @@ jobs:
- name: Test
id: python_test
run: |
pip install pytest jsonschema azureml-sdk
pip install pytest jsonschema azureml-sdk pyyaml
pytest
2 changes: 2 additions & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,7 @@ FROM marvinbuss/aml-docker:1.4.0

LABEL maintainer="azure/gh-aml"

RUN pip install pyyaml

COPY /code /code
ENTRYPOINT ["/code/entrypoint.sh"]
12 changes: 6 additions & 6 deletions action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ inputs:
description: "Paste output of `az ad sp create-for-rbac --name <your-sp-name> --role contributor --scopes /subscriptions/<your-subscriptionId>/resourceGroups/<your-rg> --sdk-auth` as value of secret variable: AZURE_CREDENTIALS"
required: true
parameters_file:
description: "JSON file including the parameters of the run."
required: true
default: "run.json"
description: "YAML or JSON file including the parameters for the run."
required: false
default: "run.yaml"
outputs:
experiment_name:
description: "Name of the experiment of the run"
Expand All @@ -21,11 +21,11 @@ outputs:
run_metrics_markdown:
description: "Metrics of the run formatted as markdown table (will only be provided if wait_for_completion is set to True)"
published_pipeline_id:
description: "Id of the publised pipeline (will only be provided if you submitted a pipeline and pipeline_publish is set to True)"
description: "Id of the published pipeline (will only be provided if you submitted a pipeline and pipeline_publish is set to True)"
published_pipeline_status:
description: "Status of the publised pipeline (will only be provided if you submitted a pipeline and pipeline_publish is set to True)"
description: "Status of the published pipeline (will only be provided if you submitted a pipeline and pipeline_publish is set to True)"
published_pipeline_endpoint:
description: "Endpoint of the publised pipeline (will only be provided if you submitted a pipeline and pipeline_publish is set to True)"
description: "Endpoint of the published pipeline (will only be provided if you submitted a pipeline and pipeline_publish is set to True)"
branding:
icon: "chevron-up"
color: "blue"
Expand Down
31 changes: 21 additions & 10 deletions code/main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import json
import yaml

from azureml.core import Workspace, Experiment
from azureml.core.authentication import ServicePrincipalAuthentication
Expand All @@ -8,7 +9,7 @@
from adal.adal_error import AdalError
from msrest.exceptions import AuthenticationError
from json import JSONDecodeError
from utils import AMLConfigurationException, AMLExperimentConfigurationException, mask_parameter, convert_to_markdown, load_pipeline_yaml, load_runconfig_yaml, load_runconfig_python, validate_json
from utils import AMLConfigurationException, AMLExperimentConfigurationException, mask_parameter, convert_to_markdown, load_pipeline_yaml, load_runconfig_yaml, load_runconfig_python, validate_params
from schemas import azure_credentials_schema, parameters_schema


Expand All @@ -24,7 +25,7 @@ def main():

# Checking provided parameters
print("::debug::Checking provided parameters")
validate_json(
validate_params(
data=azure_credentials,
schema=azure_credentials_schema,
input_name="AZURE_CREDENTIALS"
Expand All @@ -39,18 +40,28 @@ def main():

# Loading parameters file
print("::debug::Loading parameters file")
parameters_file = os.environ.get("INPUT_PARAMETERS_FILE", default="run.json")
parameters_file = os.environ.get("INPUT_PARAMETERS_FILE", default="run.yaml")
parameters_file_path = os.path.join(".cloud", ".azure", parameters_file)
try:
with open(parameters_file_path) as f:
parameters = json.load(f)
except FileNotFoundError:
print(f"::debug::Could not find parameter file in {parameters_file_path}. Please provide a parameter file in your repository if you do not want to use default settings (e.g. .cloud/.azure/run.json).")
parameters = {}
if os.path.splitext(parameters_file_path)[1] in ['yaml', 'yml']:
try:
with open(parameters_file_path) as f:
parameters = yaml.safe_load(f)
except FileNotFoundError:
print(f"::debug::Could not find parameter file in {parameters_file_path}. Please provide a parameter file in your repository if you do not want to use default settings (e.g. .cloud/.azure/run.json).")
parameters = {}
# checking provided parameters
# TODO: Add mlspec-lib for validation
else:
try:
with open(parameters_file_path) as f:
parameters = json.load(f)
except FileNotFoundError:
print(f"::debug::Could not find parameter file in {parameters_file_path}. Please provide a parameter file in your repository if you do not want to use default settings (e.g. .cloud/.azure/run.json).")
parameters = {}

# Checking provided parameters
print("::debug::Checking provided parameters")
validate_json(
validate_params(
data=parameters,
schema=parameters_schema,
input_name="PARAMETERS_FILE"
Expand Down
2 changes: 1 addition & 1 deletion code/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def mask_parameter(parameter):
print(f"::add-mask::{parameter}")


def validate_json(data, schema, input_name):
def validate_params(data, schema, input_name):
validator = jsonschema.Draft7Validator(schema)
errors = list(validator.iter_errors(data))
if len(errors) > 0:
Expand Down
8 changes: 4 additions & 4 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
myPath = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.join(myPath, "..", "code"))

from utils import AMLConfigurationException, convert_to_markdown, validate_json, load_pipeline_yaml, load_runconfig_yaml, load_runconfig_python
from utils import AMLConfigurationException, convert_to_markdown, validate_params, load_pipeline_yaml, load_runconfig_yaml, load_runconfig_python
from schemas import azure_credentials_schema
from objects import markdown_conversion_input, markdown_conversion_output

Expand All @@ -31,7 +31,7 @@ def test_validate_json_valid_inputs():
"tenantId": ""
}
schema_object = azure_credentials_schema
validate_json(
validate_params(
data=json_object,
schema=schema_object,
input_name="PARAMETERS_FILE"
Expand All @@ -49,7 +49,7 @@ def test_validate_json_invalid_json():
}
schema_object = azure_credentials_schema
with pytest.raises(AMLConfigurationException):
assert validate_json(
assert validate_params(
data=json_object,
schema=schema_object,
input_name="PARAMETERS_FILE"
Expand All @@ -63,7 +63,7 @@ def test_validate_json_invalid_schema():
json_object = {}
schema_object = {}
with pytest.raises(Exception):
assert validate_json(
assert validate_params(
data=json_object,
schema=schema_object,
input_name="PARAMETERS_FILE"
Expand Down