Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions airflow/providers/google/cloud/hooks/cloud_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from __future__ import annotations

import itertools
from typing import TYPE_CHECKING, Iterable, Sequence
from typing import TYPE_CHECKING, Any, Iterable, Sequence

from google.cloud.run_v2 import (
CreateJobRequest,
Expand Down Expand Up @@ -113,9 +113,15 @@ def update_job(

@GoogleBaseHook.fallback_to_default_project_id
def execute_job(
self, job_name: str, region: str, project_id: str = PROVIDE_PROJECT_ID
self,
job_name: str,
region: str,
project_id: str = PROVIDE_PROJECT_ID,
overrides: dict[str, Any] | None = None,
) -> operation.Operation:
run_job_request = RunJobRequest(name=f"projects/{project_id}/locations/{region}/jobs/{job_name}")
run_job_request = RunJobRequest(
name=f"projects/{project_id}/locations/{region}/jobs/{job_name}", overrides=overrides
)
operation = self.get_conn().run_job(request=run_job_request)
return operation

Expand Down
7 changes: 5 additions & 2 deletions airflow/providers/google/cloud/operators/cloud_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# under the License.
from __future__ import annotations

from typing import TYPE_CHECKING, Sequence
from typing import TYPE_CHECKING, Any, Sequence

from google.cloud.run_v2 import Job

Expand Down Expand Up @@ -248,6 +248,7 @@ class CloudRunExecuteJobOperator(GoogleCloudBaseOperator):
:param job_name: Required. The name of the job to update.
:param job: Required. The job descriptor containing the new configuration of the job to update.
The name field will be replaced by job_name
:param overrides: Optional map of override values.
:param gcp_conn_id: The connection ID used to connect to Google Cloud.
:param polling_period_seconds: Optional: Control the rate of the poll for the result of deferrable run.
By default, the trigger will poll every 10 seconds.
Expand All @@ -270,6 +271,7 @@ def __init__(
project_id: str,
region: str,
job_name: str,
overrides: dict[str, Any] | None = None,
polling_period_seconds: float = 10,
timeout_seconds: float | None = None,
gcp_conn_id: str = "google_cloud_default",
Expand All @@ -281,6 +283,7 @@ def __init__(
self.project_id = project_id
self.region = region
self.job_name = job_name
self.overrides = overrides
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain
self.polling_period_seconds = polling_period_seconds
Expand All @@ -293,7 +296,7 @@ def execute(self, context: Context):
gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain
)
self.operation = hook.execute_job(
region=self.region, project_id=self.project_id, job_name=self.job_name
region=self.region, project_id=self.project_id, job_name=self.job_name, overrides=self.overrides
)

if not self.deferrable:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,15 @@ or you can define the same operator in the deferrable mode:
:start-after: [START howto_operator_cloud_run_execute_job_deferrable_mode]
:end-before: [END howto_operator_cloud_run_execute_job_deferrable_mode]

You can also specify overrides that allow you to give a new entrypoint command to the job and more:

:class:`~airflow.providers.google.cloud.operators.cloud_run.CloudRunExecuteJobOperator`

.. exampleinclude:: /../../tests/system/providers/google/cloud/cloud_run/example_cloud_run.py
:language: python
:dedent: 4
:start-after: [START howto_operator_cloud_run_execute_job_with_overrides]
:end-before: [END howto_operator_cloud_run_execute_job_with_overrides]


Update a job
Expand Down
15 changes: 12 additions & 3 deletions tests/providers/google/cloud/hooks/test_cloud_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from tests.providers.google.cloud.utils.base_gcp_mock import mock_base_gcp_hook_default_project_id


class TestCloudBathHook:
class TestCloudRunHook:
def dummy_get_credentials(self):
pass

Expand Down Expand Up @@ -111,9 +111,18 @@ def test_execute_job(self, mock_batch_service_client, cloud_run_hook):
job_name = "job1"
region = "region1"
project_id = "projectid"
run_job_request = RunJobRequest(name=f"projects/{project_id}/locations/{region}/jobs/{job_name}")
overrides = {
"container_overrides": [{"args": ["python", "main.py"]}],
"task_count": 1,
"timeout": "60s",
}
run_job_request = RunJobRequest(
name=f"projects/{project_id}/locations/{region}/jobs/{job_name}", overrides=overrides
)

cloud_run_hook.execute_job(job_name=job_name, region=region, project_id=project_id)
cloud_run_hook.execute_job(
job_name=job_name, region=region, project_id=project_id, overrides=overrides
)
cloud_run_hook._client.run_job.assert_called_once_with(request=run_job_request)

@mock.patch(
Expand Down
68 changes: 67 additions & 1 deletion tests/providers/google/cloud/operators/test_cloud_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def test_execute_success(self, hook_mock):
operator.execute(context=mock.MagicMock())

hook_mock.return_value.execute_job.assert_called_once_with(
job_name=JOB_NAME, region=REGION, project_id=PROJECT_ID
job_name=JOB_NAME, region=REGION, project_id=PROJECT_ID, overrides=None
)

@mock.patch(CLOUD_RUN_HOOK_PATH)
Expand Down Expand Up @@ -209,6 +209,72 @@ def test_execute_deferrable_execute_complete_method_success(self, hook_mock):
result = operator.execute_complete(mock.MagicMock(), event)
assert result["name"] == JOB_NAME

@mock.patch(CLOUD_RUN_HOOK_PATH)
def test_execute_overrides(self, hook_mock):
hook_mock.return_value.get_job.return_value = JOB
hook_mock.return_value.execute_job.return_value = self._mock_operation(3, 3, 0)

overrides = {
"container_overrides": [{"args": ["python", "main.py"]}],
"task_count": 1,
"timeout": "60s",
}

operator = CloudRunExecuteJobOperator(
task_id=TASK_ID, project_id=PROJECT_ID, region=REGION, job_name=JOB_NAME, overrides=overrides
)

operator.execute(context=mock.MagicMock())

hook_mock.return_value.execute_job.assert_called_once_with(
job_name=JOB_NAME, region=REGION, project_id=PROJECT_ID, overrides=overrides
)

@mock.patch(CLOUD_RUN_HOOK_PATH)
def test_execute_overrides_with_invalid_task_count(self, hook_mock):
overrides = {
"container_overrides": [{"args": ["python", "main.py"]}],
"task_count": -1,
"timeout": "60s",
}

operator = CloudRunExecuteJobOperator(
task_id=TASK_ID, project_id=PROJECT_ID, region=REGION, job_name=JOB_NAME, overrides=overrides
)

with pytest.raises(AirflowException):
operator.execute(context=mock.MagicMock())

@mock.patch(CLOUD_RUN_HOOK_PATH)
def test_execute_overrides_with_invalid_timeout(self, hook_mock):
overrides = {
"container_overrides": [{"args": ["python", "main.py"]}],
"task_count": 1,
"timeout": "60",
}

operator = CloudRunExecuteJobOperator(
task_id=TASK_ID, project_id=PROJECT_ID, region=REGION, job_name=JOB_NAME, overrides=overrides
)

with pytest.raises(AirflowException):
operator.execute(context=mock.MagicMock())

@mock.patch(CLOUD_RUN_HOOK_PATH)
def test_execute_overrides_with_invalid_container_args(self, hook_mock):
overrides = {
"container_overrides": [{"name": "job", "args": "python main.py"}],
"task_count": 1,
"timeout": "60s",
}

operator = CloudRunExecuteJobOperator(
task_id=TASK_ID, project_id=PROJECT_ID, region=REGION, job_name=JOB_NAME, overrides=overrides
)

with pytest.raises(AirflowException):
operator.execute(context=mock.MagicMock())

def _mock_operation(self, task_count, succeeded_count, failed_count):
operation = mock.MagicMock()
operation.result.return_value = self._mock_execution(task_count, succeeded_count, failed_count)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,14 @@
job_name_prefix = "cloudrun-system-test-job"
job1_name = f"{job_name_prefix}1"
job2_name = f"{job_name_prefix}2"
job3_name = f"{job_name_prefix}3"

create1_task_name = "create-job1"
create2_task_name = "create-job2"

execute1_task_name = "execute-job1"
execute2_task_name = "execute-job2"
execute3_task_name = "execute-job3"

update_job1_task_name = "update-job1"

Expand All @@ -70,6 +72,9 @@ def _assert_executed_jobs_xcom(ti):
job2_dicts = ti.xcom_pull(task_ids=[execute2_task_name], key="return_value")
assert job2_name in job2_dicts[0]["name"]

job3_dicts = ti.xcom_pull(task_ids=[execute3_task_name], key="return_value")
assert job3_name in job3_dicts[0]["name"]


def _assert_created_jobs_xcom(ti):
job1_dicts = ti.xcom_pull(task_ids=[create1_task_name], key="return_value")
Expand Down Expand Up @@ -181,6 +186,31 @@ def _create_job_with_label():
)
# [END howto_operator_cloud_run_execute_job_deferrable_mode]

# [START howto_operator_cloud_run_execute_job_with_overrides]
overrides = {
"container_overrides": [
{
"name": "job",
"args": ["python", "main.py"],
"env": [{"name": "ENV_VAR", "value": "value"}],
"clearArgs": False,
}
],
"task_count": 1,
"timeout": "60s",
}

execute3 = CloudRunExecuteJobOperator(
task_id=execute3_task_name,
project_id=PROJECT_ID,
region=region,
overrides=overrides,
job_name=job3_name,
dag=dag,
deferrable=False,
)
# [END howto_operator_cloud_run_execute_job_with_overrides]

assert_executed_jobs = PythonOperator(
task_id="assert-executed-jobs", python_callable=_assert_executed_jobs_xcom, dag=dag
)
Expand Down Expand Up @@ -237,7 +267,7 @@ def _create_job_with_label():
(
(create1, create2)
>> assert_created_jobs
>> (execute1, execute2)
>> (execute1, execute2, execute3)
>> assert_executed_jobs
>> list_jobs_limit
>> assert_jobs_limit
Expand Down