-
Notifications
You must be signed in to change notification settings - Fork 1.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[AWS SageMaker] Integration Test for AWS SageMaker GroundTruth Component (#3830)
* Integration Test for AWS SageMaker GroundTruth Component * Unfix already fixed bug * Fix the README I overwrote by mistake * Remove use of aws-secret for OIDC * Rev 2: Fix linting errors
- Loading branch information
Showing
12 changed files
with
367 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
87 changes: 87 additions & 0 deletions
87
...nents/aws/sagemaker/tests/integration_tests/component_tests/test_groundtruth_component.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
import pytest | ||
import os | ||
import json | ||
import utils | ||
from utils import kfp_client_utils | ||
from utils import sagemaker_utils | ||
from test_workteam_component import create_workteamjob | ||
import time | ||
|
||
|
||
@pytest.mark.parametrize(
    "test_file_dir",
    [
        pytest.param(
            "resources/config/image-classification-groundtruth",
            marks=pytest.mark.canary_test,
        )
    ],
)
def test_groundtruth_labeling_job(
    kfp_client, experiment_id, region, sagemaker_client, test_file_dir
):
    """Run the GroundTruth pipeline and verify the labeling job it starts.

    A workteam is first created through the workteam pipeline, then the
    GroundTruth pipeline is run against it. The test checks that the
    labeling job is ``InProgress`` and that it is the only labeling job
    listed for the workteam.

    Cleanup is layered so that a failure at any step after the workteam
    exists still deletes the workteam, and a failure while stopping the
    labeling job does not prevent the workteam deletion from being
    attempted.

    Args:
        kfp_client: Fixture passed through to the pipeline helpers.
        experiment_id: Fixture passed through to the pipeline helpers.
        region: Fixture passed through to ``create_workteamjob``.
        sagemaker_client: Fixture used for all SageMaker lookups and cleanup.
        test_file_dir: Directory holding the ``config.yaml`` for this test.
    """
    download_dir = utils.mkdir(os.path.join(test_file_dir, "generated"))
    test_params = utils.load_params(
        utils.replace_placeholders(
            os.path.join(test_file_dir, "config.yaml"),
            os.path.join(download_dir, "config.yaml"),
        )
    )

    # First create a workteam using a separate pipeline and get the name of
    # the workteam created.
    workteam_name, _ = create_workteamjob(
        kfp_client,
        experiment_id,
        region,
        sagemaker_client,
        "resources/config/create-workteam",
        download_dir,
    )

    # The workteam exists from here on, so every remaining step runs under a
    # finally that deletes it.
    try:
        workteam_arn = sagemaker_utils.get_workteam_arn(
            sagemaker_client, workteam_name
        )
        test_params["Arguments"]["workteam_arn"] = workteam_arn

        # Generate the ground_truth_train_job_name based on the workteam
        # which will be used for labeling.
        ground_truth_train_job_name = (
            test_params["Arguments"]["ground_truth_train_job_name"]
            + "-by-"
            + workteam_name
        )
        test_params["Arguments"][
            "ground_truth_train_job_name"
        ] = ground_truth_train_job_name

        kfp_client_utils.compile_run_monitor_pipeline(
            kfp_client,
            experiment_id,
            test_params["PipelineDefinition"],
            test_params["Arguments"],
            download_dir,
            test_params["TestName"],
            test_params["Timeout"],
            test_params["StatusToCheck"],
        )

        # Verify the GroundTruthJob was created in SageMaker and is InProgress.
        # TODO: Add a bot to complete the labeling job and check for completion instead.
        try:
            response = sagemaker_utils.describe_labeling_job(
                sagemaker_client, ground_truth_train_job_name
            )
            assert response["LabelingJobStatus"] == "InProgress"

            # Verify that the workteam has the specified labeling job
            labeling_jobs = sagemaker_utils.list_labeling_jobs_for_workteam(
                sagemaker_client, workteam_arn
            )
            assert len(labeling_jobs["LabelingJobSummaryList"]) == 1
            assert (
                labeling_jobs["LabelingJobSummaryList"][0]["LabelingJobName"]
                == ground_truth_train_job_name
            )
        finally:
            # Stop the labeling job before the workteam is removed.
            sagemaker_utils.stop_labeling_job(
                sagemaker_client, ground_truth_train_job_name
            )
    finally:
        # Always attempt to delete the workteam, even if an earlier step or
        # the labeling-job cleanup raised.
        sagemaker_utils.delete_workteam(sagemaker_client, workteam_name)

    # Delete generated files (reached only when nothing above raised)
    utils.remove_dir(download_dir)
83 changes: 83 additions & 0 deletions
83
components/aws/sagemaker/tests/integration_tests/component_tests/test_workteam_component.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
import pytest | ||
import os | ||
import json | ||
import utils | ||
from utils import kfp_client_utils | ||
from utils import sagemaker_utils | ||
from utils import minio_utils | ||
|
||
|
||
def create_workteamjob(
    kfp_client, experiment_id, region, sagemaker_client, test_file_dir, download_dir
):
    """Run the workteam pipeline and return the created team's name.

    Loads ``config.yaml`` from ``test_file_dir`` (with placeholders replaced
    into a copy under ``download_dir``), overrides the Cognito settings and
    the team name in its arguments, then compiles, runs and monitors the
    pipeline.

    Returns:
        A ``(workteam_name, workflow_json)`` tuple, where ``workflow_json`` is
        the third value returned by ``compile_run_monitor_pipeline``.
    """
    rendered_config = utils.replace_placeholders(
        os.path.join(test_file_dir, "config.yaml"),
        os.path.join(download_dir, "config.yaml"),
    )
    test_params = utils.load_params(rendered_config)
    arguments = test_params["Arguments"]

    # Account- and region-specific user pool, client id and user groups for
    # the SageMaker workforce replace the placeholder values from the config.
    user_pool, client_id, user_groups = sagemaker_utils.get_cognito_member_definitions(
        sagemaker_client
    )
    arguments["user_pool"] = user_pool
    arguments["client_id"] = client_id
    arguments["user_groups"] = user_groups

    # A random prefix avoids errors if a resource with the same name exists.
    random_prefix = utils.generate_random_string(5)
    workteam_name = random_prefix + "-" + arguments["team_name"]
    arguments["team_name"] = workteam_name

    _, _, workflow_json = kfp_client_utils.compile_run_monitor_pipeline(
        kfp_client,
        experiment_id,
        test_params["PipelineDefinition"],
        arguments,
        download_dir,
        test_params["TestName"],
        test_params["Timeout"],
    )

    return workteam_name, workflow_json
|
||
|
||
@pytest.mark.parametrize(
    "test_file_dir",
    [pytest.param("resources/config/create-workteam", marks=pytest.mark.canary_test)],
)
def test_workteamjob(
    kfp_client, experiment_id, region, sagemaker_client, test_file_dir
):
    """Run the workteam pipeline and verify the workteam it creates.

    Checks that the workteam exists in SageMaker under the generated name and
    that the ``workteam_arn`` artifact downloaded from Minio matches the ARN
    SageMaker reports.

    The artifact download happens inside the ``try`` so that a download
    failure still deletes the workteam that was just created.

    Args:
        kfp_client: Fixture passed through to ``create_workteamjob``.
        experiment_id: Fixture passed through to ``create_workteamjob``.
        region: Fixture passed through to ``create_workteamjob``.
        sagemaker_client: Fixture used to describe and delete the workteam.
        test_file_dir: Directory holding the ``config.yaml`` for this test.
    """
    download_dir = utils.mkdir(os.path.join(test_file_dir, "generated"))
    workteam_name, workflow_json = create_workteamjob(
        kfp_client, experiment_id, region, sagemaker_client, test_file_dir, download_dir
    )

    try:
        outputs = {"sagemaker-private-workforce": ["workteam_arn"]}
        output_files = minio_utils.artifact_download_iterator(
            workflow_json, outputs, download_dir
        )

        response = sagemaker_utils.describe_workteam(sagemaker_client, workteam_name)

        # Verify WorkTeam was created in SageMaker
        assert response["Workteam"]["CreateDate"] is not None
        assert response["Workteam"]["WorkteamName"] == workteam_name

        # Verify WorkTeam arn artifact was created in Minio and matches the one in SageMaker
        workteam_arn = utils.read_from_file_in_tar(
            output_files["sagemaker-private-workforce"]["workteam_arn"],
            "workteam_arn.txt",
        )
        assert response["Workteam"]["WorkteamArn"] == workteam_arn

    finally:
        # Cleanup the SageMaker Resources
        sagemaker_utils.delete_workteam(sagemaker_client, workteam_name)

    # Delete generated files only if the test is successful
    utils.remove_dir(download_dir)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
10 changes: 10 additions & 0 deletions
10
...onents/aws/sagemaker/tests/integration_tests/resources/config/create-workteam/config.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
PipelineDefinition: resources/definition/workteam_pipeline.py | ||
TestName: create-workteam | ||
Timeout: 3600 | ||
Arguments: | ||
region: ((REGION)) | ||
team_name: 'test-workteam' | ||
description: 'Team for GroundTruth Integ Test' | ||
user_pool: 'user-pool' | ||
user_groups: 'user-group' | ||
client_id: 'client-id' |
22 changes: 22 additions & 0 deletions
22
...ker/tests/integration_tests/resources/config/image-classification-groundtruth/config.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
PipelineDefinition: resources/definition/groundtruth_pipeline.py | ||
TestName: image-classification-groundtruth | ||
Timeout: 10 | ||
StatusToCheck: 'running' | ||
Arguments: | ||
region: ((REGION)) | ||
role: ((ROLE_ARN)) | ||
ground_truth_train_job_name: 'image-labeling' | ||
ground_truth_label_attribute_name: 'category' | ||
ground_truth_train_manifest_location: 's3://((DATA_BUCKET))/mini-image-classification/ground-truth-demo/train.manifest' | ||
ground_truth_output_location: 's3://((DATA_BUCKET))/mini-image-classification/ground-truth-demo/output' | ||
ground_truth_task_type: 'image classification' | ||
ground_truth_worker_type: 'private' | ||
ground_truth_label_category_config: 's3://((DATA_BUCKET))/mini-image-classification/ground-truth-demo/class_labels.json' | ||
ground_truth_ui_template: 's3://((DATA_BUCKET))/mini-image-classification/ground-truth-demo/instructions.template' | ||
ground_truth_title: 'Mini image classification' | ||
ground_truth_description: 'Test for Ground Truth KFP component' | ||
ground_truth_num_workers_per_object: '1' | ||
ground_truth_time_limit: '30' | ||
ground_truth_task_availibility: '3600' | ||
ground_truth_max_concurrent_tasks: '20' | ||
workteam_arn: 'workteam-arn' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
59 changes: 59 additions & 0 deletions
59
...onents/aws/sagemaker/tests/integration_tests/resources/definition/groundtruth_pipeline.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
import kfp | ||
import json | ||
import copy | ||
from kfp import components | ||
from kfp import dsl | ||
from kfp.aws import use_aws_secret | ||
|
||
# Component factory for the GroundTruth component. The path is relative, so
# it depends on the working directory the module is loaded from.
_GROUND_TRUTH_COMPONENT_PATH = "../../ground_truth/component.yaml"
sagemaker_gt_op = components.load_component_from_file(_GROUND_TRUTH_COMPONENT_PATH)
|
||
|
||
# Single-step test pipeline: every pipeline parameter is forwarded unchanged
# to the GroundTruth component op. The ``ground_truth_`` prefix is dropped
# when mapping onto the component's own input names.
@dsl.pipeline(
    name="SageMaker GroundTruth image classification test pipeline",
    description="SageMaker GroundTruth image classification test pipeline",
)
def ground_truth_test(
    region="",
    ground_truth_train_job_name="",
    ground_truth_label_attribute_name="",
    ground_truth_train_manifest_location="",
    ground_truth_output_location="",
    ground_truth_task_type="",
    ground_truth_worker_type="",
    ground_truth_label_category_config="",
    ground_truth_ui_template="",
    ground_truth_title="",
    ground_truth_description="",
    ground_truth_num_workers_per_object="",
    ground_truth_time_limit="",
    ground_truth_task_availibility="",
    ground_truth_max_concurrent_tasks="",
    role="",
    workteam_arn="",
):
    # The op's return value is not needed: nothing downstream consumes it.
    sagemaker_gt_op(
        region=region,
        role=role,
        job_name=ground_truth_train_job_name,
        label_attribute_name=ground_truth_label_attribute_name,
        manifest_location=ground_truth_train_manifest_location,
        output_location=ground_truth_output_location,
        task_type=ground_truth_task_type,
        worker_type=ground_truth_worker_type,
        workteam_arn=workteam_arn,
        label_category_config=ground_truth_label_category_config,
        ui_template=ground_truth_ui_template,
        title=ground_truth_title,
        description=ground_truth_description,
        num_workers_per_object=ground_truth_num_workers_per_object,
        time_limit=ground_truth_time_limit,
        task_availibility=ground_truth_task_availibility,
        max_concurrent_tasks=ground_truth_max_concurrent_tasks,
    )
|
||
|
||
if __name__ == "__main__":
    # Write the compiled pipeline next to this source file, as "<file>.yaml".
    compiled_path = __file__ + ".yaml"
    kfp.compiler.Compiler().compile(ground_truth_test, compiled_path)
36 changes: 36 additions & 0 deletions
36
components/aws/sagemaker/tests/integration_tests/resources/definition/workteam_pipeline.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
#!/usr/bin/env python3 | ||
|
||
import kfp | ||
import json | ||
import copy | ||
from kfp import components | ||
from kfp import dsl | ||
from kfp.aws import use_aws_secret | ||
|
||
# Component factory for the WorkTeam component. The path is relative, so it
# depends on the working directory the module is loaded from.
_WORKTEAM_COMPONENT_PATH = "../../workteam/component.yaml"
sagemaker_workteam_op = components.load_component_from_file(_WORKTEAM_COMPONENT_PATH)
|
||
|
||
# Single-step test pipeline: every pipeline parameter is forwarded unchanged,
# under the same name, to the WorkTeam component op.
@dsl.pipeline(
    name="SageMaker WorkTeam test pipeline",
    description="SageMaker WorkTeam test pipeline",
)
def workteam_test(
    region="",
    team_name="",
    description="",
    user_pool="",
    user_groups="",
    client_id="",
):
    # The op's return value is not needed: nothing downstream consumes it.
    sagemaker_workteam_op(
        region=region,
        team_name=team_name,
        description=description,
        user_pool=user_pool,
        user_groups=user_groups,
        client_id=client_id,
    )
|
||
|
||
if __name__ == "__main__":
    # Fixed output name, relative to the current working directory.
    compiled_path = "SageMaker_WorkTeam_Pipelines" + ".yaml"
    kfp.compiler.Compiler().compile(workteam_test, compiled_path)
Oops, something went wrong.