Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bug fix in AWS glue operator when specifying the WorkerType & NumberOfWorkers #19787

Merged
merged 9 commits into from
Dec 6, 2021
50 changes: 37 additions & 13 deletions airflow/providers/amazon/aws/hooks/glue.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def __init__(
concurrent_run_limit: int = 1,
script_location: Optional[str] = None,
retry_limit: int = 0,
num_of_dpus: int = 10,
num_of_dpus: Optional[int] = None,
iam_role_name: Optional[str] = None,
create_job_kwargs: Optional[dict] = None,
*args,
Expand All @@ -70,11 +70,23 @@ def __init__(
self.concurrent_run_limit = concurrent_run_limit
self.script_location = script_location
self.retry_limit = retry_limit
self.num_of_dpus = num_of_dpus
self.s3_bucket = s3_bucket
self.role_name = iam_role_name
self.s3_glue_logs = 'logs/glue-logs/'
self.create_job_kwargs = create_job_kwargs or {}

if "WorkerType" in self.create_job_kwargs and "NumberOfWorkers" in self.create_job_kwargs:
Ritika-Singhal marked this conversation as resolved.
Show resolved Hide resolved
if num_of_dpus is not None:
raise ValueError("Cannot specify num_of_dpus with custom WorkerType")
elif "WorkerType" not in self.create_job_kwargs and "NumberOfWorkers" in self.create_job_kwargs:
raise ValueError("Need to specify custom WorkerType when specifying NumberOfWorkers")
elif "WorkerType" in self.create_job_kwargs and "NumberOfWorkers" not in self.create_job_kwargs:
raise ValueError("Need to specify NumberOfWorkers when specifying custom WorkerType")
elif num_of_dpus is None:
self.num_of_dpus = 10
else:
self.num_of_dpus = num_of_dpus

kwargs['client_type'] = 'glue'
super().__init__(*args, **kwargs)

Expand Down Expand Up @@ -179,17 +191,29 @@ def get_or_create_glue_job(self) -> str:
s3_log_path = f's3://{self.s3_bucket}/{self.s3_glue_logs}{self.job_name}'
execution_role = self.get_iam_execution_role()
try:
create_job_response = glue_client.create_job(
Name=self.job_name,
Description=self.desc,
LogUri=s3_log_path,
Role=execution_role['Role']['Arn'],
ExecutionProperty={"MaxConcurrentRuns": self.concurrent_run_limit},
Command={"Name": "glueetl", "ScriptLocation": self.script_location},
MaxRetries=self.retry_limit,
AllocatedCapacity=self.num_of_dpus,
**self.create_job_kwargs,
)
if "WorkerType" in self.create_job_kwargs and "NumberOfWorkers" in self.create_job_kwargs:
create_job_response = glue_client.create_job(
Name=self.job_name,
Description=self.desc,
LogUri=s3_log_path,
Role=execution_role['Role']['Arn'],
ExecutionProperty={"MaxConcurrentRuns": self.concurrent_run_limit},
Command={"Name": "glueetl", "ScriptLocation": self.script_location},
MaxRetries=self.retry_limit,
**self.create_job_kwargs,
)
else:
create_job_response = glue_client.create_job(
Name=self.job_name,
Description=self.desc,
LogUri=s3_log_path,
Role=execution_role['Role']['Arn'],
ExecutionProperty={"MaxConcurrentRuns": self.concurrent_run_limit},
Command={"Name": "glueetl", "ScriptLocation": self.script_location},
MaxRetries=self.retry_limit,
MaxCapacity=self.num_of_dpus,
**self.create_job_kwargs,
)
return create_job_response['Name']
except Exception as general_error:
self.log.error("Failed to create aws glue job, error: %s", general_error)
Expand Down
42 changes: 42 additions & 0 deletions tests/providers/amazon/aws/hooks/test_glue.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,48 @@ def test_get_or_create_glue_job(self, mock_get_conn, mock_get_iam_execution_role
).get_or_create_glue_job()
assert glue_job == mock_glue_job

@mock.patch.object(AwsGlueJobHook, "get_iam_execution_role")
@mock.patch.object(AwsGlueJobHook, "get_conn")
def test_get_or_create_glue_job_worker_type(self, mock_get_conn, mock_get_iam_execution_role):
    """Hook creation succeeds when WorkerType/NumberOfWorkers are given via create_job_kwargs."""
    mock_get_iam_execution_role.return_value = mock.MagicMock(Role={'RoleName': 'my_test_role'})
    script_path = "s3:/glue-examples/glue-scripts/sample_aws_glue_job.py"
    bucket_name = "my-includes"

    # The mocked client returns the same object for get_job(); capture the
    # name it reports so we can compare against the hook's return value.
    expected_job_name = mock_get_conn.return_value.get_job()['Job']['Name']

    hook = AwsGlueJobHook(
        job_name='aws_test_glue_job',
        desc='This is test case job from Airflow',
        script_location=script_path,
        iam_role_name='my_test_role',
        s3_bucket=bucket_name,
        region_name=self.some_aws_region,
        create_job_kwargs={'WorkerType': 'G.2X', 'NumberOfWorkers': 60},
    )
    result = hook.get_or_create_glue_job()

    assert result == expected_job_name

@mock.patch.object(AwsGlueJobHook, "get_iam_execution_role")
@mock.patch.object(AwsGlueJobHook, "get_conn")
def test_init_worker_type_value_error(self, mock_get_conn, mock_get_iam_execution_role):
    """Supplying num_of_dpus together with a custom WorkerType must raise ValueError.

    The hook treats ``num_of_dpus`` and ``WorkerType``/``NumberOfWorkers`` in
    ``create_job_kwargs`` as mutually exclusive capacity settings, so passing
    both is rejected at construction time.
    """
    mock_get_iam_execution_role.return_value = mock.MagicMock(Role={'RoleName': 'my_test_role'})
    some_script = "s3:/glue-examples/glue-scripts/sample_aws_glue_job.py"
    some_s3_bucket = "my-includes"

    # assertRaises both asserts the exception type and fails the test if
    # nothing is raised — replaces the hand-rolled try/except/else, which
    # tautologically re-checked type(e) inside ``except ValueError``.
    with self.assertRaises(ValueError):
        AwsGlueJobHook(
            job_name='aws_test_glue_job',
            desc='This is test case job from Airflow',
            script_location=some_script,
            iam_role_name='my_test_role',
            s3_bucket=some_s3_bucket,
            region_name=self.some_aws_region,
            num_of_dpus=20,
            create_job_kwargs={'WorkerType': 'G.2X', 'NumberOfWorkers': 60},
        )
Ritika-Singhal marked this conversation as resolved.
Show resolved Hide resolved

@mock.patch.object(AwsGlueJobHook, "get_job_state")
@mock.patch.object(AwsGlueJobHook, "get_or_create_glue_job")
@mock.patch.object(AwsGlueJobHook, "get_conn")
Expand Down