Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions .cloud/.azure/run.json
Original file line number Diff line number Diff line change
@@ -1,3 +1,16 @@
[{
"experiment_name": "<your-experiment-name>",
"tags": {"<your-run-tag-key>": "<your-run-tag-value>"},
"wait_for_completion": true,
"runconfig_python_file": "<your-runconfig-python-file>",
"runconfig_python_function_name": "<your-runconfig-python-function-name>",
"runconfig_yaml_file": "<your-runconfig-yaml-file>",
"pipeline_yaml_file": "<your-pipeline-yaml-file>",
"pipeline_publish": false,
"pipeline_name": "<your-pipeline-name>",
"pipeline_version": "<your-pipeline-version>",
"pipeline_continue_on_step_failure": false
},
{
"experiment_name": "<your-experiment-name>",
"tags": {"<your-run-tag-key>": "<your-run-tag-value>"},
Expand All @@ -11,3 +24,4 @@
"pipeline_version": "<your-pipeline-version>",
"pipeline_continue_on_step_failure": false
}
]
182 changes: 106 additions & 76 deletions code/main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
import json

import time
from azureml.core import Workspace, Experiment
from azureml.core.authentication import ServicePrincipalAuthentication
from azureml.pipeline.core import PipelineRun
Expand All @@ -11,70 +11,7 @@
from utils import AMLConfigurationException, AMLExperimentConfigurationException, required_parameters_provided, mask_parameter, convert_to_markdown, load_pipeline_yaml, load_runconfig_yaml, load_runconfig_python


def main():
# Loading input values
print("::debug::Loading input values")
parameters_file = os.environ.get("INPUT_PARAMETERS_FILE", default="run.json")
azure_credentials = os.environ.get("INPUT_AZURE_CREDENTIALS", default="{}")
try:
azure_credentials = json.loads(azure_credentials)
except JSONDecodeError:
print("::error::Please paste output of `az ad sp create-for-rbac --name <your-sp-name> --role contributor --scopes /subscriptions/<your-subscriptionId>/resourceGroups/<your-rg> --sdk-auth` as value of secret variable: AZURE_CREDENTIALS")
raise AMLConfigurationException(f"Incorrect or poorly formed output from azure credentials saved in AZURE_CREDENTIALS secret. See setup in https://github.com/Azure/aml-workspace/blob/master/README.md")

# Checking provided parameters
print("::debug::Checking provided parameters")
required_parameters_provided(
parameters=azure_credentials,
keys=["tenantId", "clientId", "clientSecret"],
message="Required parameter(s) not found in your azure credentials saved in AZURE_CREDENTIALS secret for logging in to the workspace. Please provide a value for the following key(s): "
)

# Mask values
print("::debug::Masking parameters")
mask_parameter(parameter=azure_credentials.get("tenantId", ""))
mask_parameter(parameter=azure_credentials.get("clientId", ""))
mask_parameter(parameter=azure_credentials.get("clientSecret", ""))
mask_parameter(parameter=azure_credentials.get("subscriptionId", ""))

# Loading parameters file
print("::debug::Loading parameters file")
parameters_file_path = os.path.join(".cloud", ".azure", parameters_file)
try:
with open(parameters_file_path) as f:
parameters = json.load(f)
except FileNotFoundError:
print(f"::debug::Could not find parameter file in {parameters_file_path}. Please provide a parameter file in your repository if you do not want to use default settings (e.g. .cloud/.azure/run.json).")
parameters = {}

# Loading Workspace
print("::debug::Loading AML Workspace")
sp_auth = ServicePrincipalAuthentication(
tenant_id=azure_credentials.get("tenantId", ""),
service_principal_id=azure_credentials.get("clientId", ""),
service_principal_password=azure_credentials.get("clientSecret", "")
)
config_file_path = os.environ.get("GITHUB_WORKSPACE", default=".cloud/.azure")
config_file_name = "aml_arm_config.json"
try:
ws = Workspace.from_config(
path=config_file_path,
_file_name=config_file_name,
auth=sp_auth
)
except AuthenticationException as exception:
print(f"::error::Could not retrieve user token. Please paste output of `az ad sp create-for-rbac --name <your-sp-name> --role contributor --scopes /subscriptions/<your-subscriptionId>/resourceGroups/<your-rg> --sdk-auth` as value of secret variable: AZURE_CREDENTIALS: {exception}")
raise AuthenticationException
except AuthenticationError as exception:
print(f"::error::Microsoft REST Authentication Error: {exception}")
raise AuthenticationError
except AdalError as exception:
print(f"::error::Active Directory Authentication Library Error: {exception}")
raise AdalError
except ProjectSystemException as exception:
print(f"::error::Workspace authorizationfailed: {exception}")
raise ProjectSystemException

def submitRun(ws, parameters):
# Create experiment
print("::debug::Creating experiment")
try:
Expand Down Expand Up @@ -153,17 +90,7 @@ def main():
print(f"::set-output name=run_id::{run.id}")
print(f"::set-output name=run_url::{run.get_portal_url()}")

# Waiting for run to complete
print("::debug::Waiting for run to complete")
if parameters.get("wait_for_completion", True):
run.wait_for_completion(show_output=True)

# Creating additional outputs of finished run
run_metrics = run.get_metrics(recursive=True)
run_metrics_markdown = convert_to_markdown(run_metrics)
print(f"::set-output name=run_metrics::{run_metrics}")
print(f"::set-output name=run_metrics_markdown::{run_metrics_markdown}")

# we can publish the pipeline without waiting for run to be finished. need to verify it
# Publishing pipeline
print("::debug::Publishing pipeline")
if type(run) is PipelineRun and parameters.get("publish_pipeline", False):
Expand All @@ -188,6 +115,109 @@ def main():

print("::debug::Successfully finished Azure Machine Learning Train Action")

wait_for_completion = False
# as we don't want to wait here, we just return the run object from here.
if parameters.get("wait_for_completion", True):
wait_for_completion = True

return (run, wait_for_completion)


def postRun(submittedRuns_for_wait):
# Waiting for run to complete
print("::debug::Waiting for run to complete")
run_pending = True

while run_pending:
tempStack = submittedRuns_for_wait
for run in tempStack:
if run.get_status() in ['Completed', 'Failed']:
# Creating additional outputs of finished run
run_metrics = run.get_metrics(recursive=True)
run_metrics_markdown = convert_to_markdown(run_metrics)
print(f"::set-output name=run_metrics::{run_metrics}")
print(f"::set-output name=run_metrics_markdown::{run_metrics_markdown}")
submittedRuns_for_wait.remove(run)
time.sleep(10) # wait for 10 seconds to check again.
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this wait time is point of check, do we wait for 10 seconds before checking status again?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given runs can be quite long, this seems reasonable. @marvinbuss thoughts?

if len(submittedRuns_for_wait) == 0:
run_pending = False


def main():
# Loading input values
print("::debug::Loading input values")
parameters_file = os.environ.get("INPUT_PARAMETERS_FILE", default="run.json")
azure_credentials = os.environ.get("INPUT_AZURE_CREDENTIALS", default="{}")
try:
azure_credentials = json.loads(azure_credentials)
except JSONDecodeError:
print("::error::Please paste output of `az ad sp create-for-rbac --name <your-sp-name> --role contributor --scopes /subscriptions/<your-subscriptionId>/resourceGroups/<your-rg> --sdk-auth` as value of secret variable: AZURE_CREDENTIALS")
raise AMLConfigurationException(f"Incorrect or poorly formed output from azure credentials saved in AZURE_CREDENTIALS secret. See setup in https://github.com/Azure/aml-workspace/blob/master/README.md")

# Checking provided parameters
print("::debug::Checking provided parameters")
required_parameters_provided(
parameters=azure_credentials,
keys=["tenantId", "clientId", "clientSecret"],
message="Required parameter(s) not found in your azure credentials saved in AZURE_CREDENTIALS secret for logging in to the workspace. Please provide a value for the following key(s): "
)

# Mask values
print("::debug::Masking parameters")
mask_parameter(parameter=azure_credentials.get("tenantId", ""))
mask_parameter(parameter=azure_credentials.get("clientId", ""))
mask_parameter(parameter=azure_credentials.get("clientSecret", ""))
mask_parameter(parameter=azure_credentials.get("subscriptionId", ""))

# Loading parameters file
print("::debug::Loading parameters file")
parameters_file_path = os.path.join(".cloud", ".azure", parameters_file)
try:
with open(parameters_file_path) as f:
parameters = json.load(f)
except FileNotFoundError:
print(f"::debug::Could not find parameter file in {parameters_file_path}. Please provide a parameter file in your repository if you do not want to use default settings (e.g. .cloud/.azure/run.json).")
parameters = [{}] # we want to run atleast once with default values.

# Loading Workspace
print("::debug::Loading AML Workspace")
sp_auth = ServicePrincipalAuthentication(
tenant_id=azure_credentials.get("tenantId", ""),
service_principal_id=azure_credentials.get("clientId", ""),
service_principal_password=azure_credentials.get("clientSecret", "")
)
config_file_path = os.environ.get("GITHUB_WORKSPACE", default=".cloud/.azure")
config_file_name = "aml_arm_config.json"
try:
ws = Workspace.from_config(
path=config_file_path,
_file_name=config_file_name,
auth=sp_auth
)
except AuthenticationException as exception:
print(f"::error::Could not retrieve user token. Please paste output of `az ad sp create-for-rbac --name <your-sp-name> --role contributor --scopes /subscriptions/<your-subscriptionId>/resourceGroups/<your-rg> --sdk-auth` as value of secret variable: AZURE_CREDENTIALS: {exception}")
raise AuthenticationException
except AuthenticationError as exception:
print(f"::error::Microsoft REST Authentication Error: {exception}")
raise AuthenticationError
except AdalError as exception:
print(f"::error::Active Directory Authentication Library Error: {exception}")
raise AdalError
except ProjectSystemException as exception:
print(f"::error::Workspace authorizationfailed: {exception}")
raise ProjectSystemException

submittedRuns_for_wait = []
for parameter in parameters:
run, wait_for_completion = submitRun(ws, parameter)

# add a list of tuple to be used later, we will use it to wait.
if wait_for_completion is True:
submittedRuns_for_wait.append(run)

postRun(submittedRuns_for_wait)
print("submission over")


if __name__ == "__main__":
main()