Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import itertools
from collections.abc import Iterable, Sequence
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Literal

from google.cloud.run_v2 import (
CreateJobRequest,
Expand Down Expand Up @@ -67,16 +67,21 @@ class CloudRunHook(GoogleBaseHook):
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account.
:param transport: Optional. The transport to use for API requests. Can be 'rest' or 'grpc'.
If set to None, a transport is chosen automatically. Use 'rest' if gRPC is not available
or fails in your environment (e.g., Docker containers with certain network configurations).
"""

def __init__(
self,
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
transport: Literal["rest", "grpc"] | None = None,
**kwargs,
) -> None:
super().__init__(gcp_conn_id=gcp_conn_id, impersonation_chain=impersonation_chain, **kwargs)
self._client: JobsClient | None = None
self.transport = transport

def get_conn(self):
"""
Expand All @@ -85,7 +90,12 @@ def get_conn(self):
:return: Cloud Run Jobs client object.
"""
if self._client is None:
self._client = JobsClient(credentials=self.get_credentials(), client_info=CLIENT_INFO)
client_kwargs = {
"credentials": self.get_credentials(),
"client_info": CLIENT_INFO,
"transport": self.transport,
}
self._client = JobsClient(**client_kwargs)
return self._client

@GoogleBaseHook.fallback_to_default_project_id
Expand Down Expand Up @@ -176,6 +186,9 @@ class CloudRunAsyncHook(GoogleBaseAsyncHook):
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account.
:param transport: Optional. The transport to use for API requests. Can be 'rest' or 'grpc'.
If set to None, a transport is chosen automatically. Use 'rest' if gRPC is not available
or fails in your environment (e.g., Docker containers with certain network configurations).
"""

sync_hook_class = CloudRunHook
Expand All @@ -184,15 +197,24 @@ def __init__(
self,
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
transport: Literal["rest", "grpc"] | None = None,
**kwargs,
):
self._client: JobsAsyncClient | None = None
super().__init__(gcp_conn_id=gcp_conn_id, impersonation_chain=impersonation_chain, **kwargs)
self.transport = transport
super().__init__(
gcp_conn_id=gcp_conn_id, impersonation_chain=impersonation_chain, transport=transport, **kwargs
)

async def get_conn(self):
if self._client is None:
sync_hook = await self.get_sync_hook()
self._client = JobsAsyncClient(credentials=sync_hook.get_credentials(), client_info=CLIENT_INFO)
client_kwargs = {
"credentials": sync_hook.get_credentials(),
"client_info": CLIENT_INFO,
"transport": self.transport,
}
self._client = JobsAsyncClient(**client_kwargs)

return self._client

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from __future__ import annotations

from collections.abc import Sequence
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Literal

import google.cloud.exceptions
from google.api_core.exceptions import AlreadyExists
Expand Down Expand Up @@ -263,6 +263,9 @@ class CloudRunExecuteJobOperator(GoogleCloudBaseOperator):
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
:param deferrable: Run the operator in deferrable mode.
:param transport: Optional. The transport to use for API requests. Can be 'rest' or 'grpc'.
If set to None, a transport is chosen automatically. Use 'rest' if gRPC is not available
or fails in your environment (e.g., Docker containers with certain network configurations).
"""

operator_extra_links = (CloudRunJobLoggingLink(),)
Expand All @@ -275,6 +278,7 @@ class CloudRunExecuteJobOperator(GoogleCloudBaseOperator):
"overrides",
"polling_period_seconds",
"timeout_seconds",
"transport",
)

def __init__(
Expand All @@ -288,6 +292,7 @@ def __init__(
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
transport: Literal["rest", "grpc"] | None = None,
**kwargs,
):
super().__init__(**kwargs)
Expand All @@ -300,11 +305,14 @@ def __init__(
self.polling_period_seconds = polling_period_seconds
self.timeout_seconds = timeout_seconds
self.deferrable = deferrable
self.transport = transport
self.operation: operation.Operation | None = None

def execute(self, context: Context):
hook: CloudRunHook = CloudRunHook(
gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
transport=self.transport,
)
self.operation = hook.execute_job(
region=self.region, project_id=self.project_id, job_name=self.job_name, overrides=self.overrides
Expand Down Expand Up @@ -333,6 +341,7 @@ def execute(self, context: Context):
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
polling_period_seconds=self.polling_period_seconds,
transport=self.transport,
),
method_name="execute_complete",
)
Expand All @@ -350,7 +359,11 @@ def execute_complete(self, context: Context, event: dict):
f"Operation failed with error code [{error_code}] and error message [{error_message}]"
)

hook: CloudRunHook = CloudRunHook(self.gcp_conn_id, self.impersonation_chain)
hook: CloudRunHook = CloudRunHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
transport=self.transport,
)

job = hook.get_job(job_name=event["job_name"], region=self.region, project_id=self.project_id)
return Job.to_dict(job)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import asyncio
from collections.abc import AsyncIterator, Sequence
from enum import Enum
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Literal

from airflow.providers.common.compat.sdk import AirflowException
from airflow.providers.google.cloud.hooks.cloud_run import CloudRunAsyncHook
Expand Down Expand Up @@ -59,6 +59,9 @@ class CloudRunJobFinishedTrigger(BaseTrigger):
account from the list granting this role to the originating account (templated).
:param poll_sleep: Polling period in seconds to check for the status.
:timeout: The time to wait before failing the operation.
:param transport: Optional. The transport to use for API requests. Can be 'rest' or 'grpc'.
Defaults to 'grpc'. Use 'rest' if gRPC is not available or fails in your environment
(e.g., Docker containers with certain network configurations).
"""

def __init__(
Expand All @@ -71,6 +74,7 @@ def __init__(
impersonation_chain: str | Sequence[str] | None = None,
polling_period_seconds: float = 10,
timeout: float | None = None,
transport: Literal["rest", "grpc"] | None = None,
):
super().__init__()
self.project_id = project_id
Expand All @@ -81,6 +85,7 @@ def __init__(
self.polling_period_seconds = polling_period_seconds
self.timeout = timeout
self.impersonation_chain = impersonation_chain
self.transport = transport

def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serialize class arguments and classpath."""
Expand All @@ -95,6 +100,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"polling_period_seconds": self.polling_period_seconds,
"timeout": self.timeout,
"impersonation_chain": self.impersonation_chain,
"transport": self.transport,
},
)

Expand Down Expand Up @@ -143,4 +149,5 @@ def _get_async_hook(self) -> CloudRunAsyncHook:
return CloudRunAsyncHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
transport=self.transport or "grpc",
)
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,22 @@ def test_delete_job(self, mock_batch_service_client, cloud_run_hook):
cloud_run_hook.delete_job(job_name=JOB_NAME, region=REGION, project_id=PROJECT_ID)
cloud_run_hook._client.delete_job.assert_called_once_with(delete_request)

@mock.patch(
"airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__",
new=mock_base_gcp_hook_default_project_id,
)
@mock.patch("airflow.providers.google.cloud.hooks.cloud_run.JobsClient")
@pytest.mark.parametrize(("transport", "expected_transport"), [("rest", "rest"), (None, None)])
def test_get_conn_with_transport(self, mock_jobs_client, transport, expected_transport):
"""Test that transport parameter is passed to JobsClient."""
hook = CloudRunHook(transport=transport)
hook.get_credentials = self.dummy_get_credentials
hook.get_conn()

mock_jobs_client.assert_called_once()
call_kwargs = mock_jobs_client.call_args[1]
assert call_kwargs["transport"] == expected_transport

def _mock_pager(self, number_of_jobs):
mock_pager = []
for i in range(number_of_jobs):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,28 @@ def test_template_fields(self):
assert "overrides" in operator.template_fields
assert "polling_period_seconds" in operator.template_fields
assert "timeout_seconds" in operator.template_fields
assert "transport" in operator.template_fields

@mock.patch(CLOUD_RUN_HOOK_PATH)
def test_execute_with_transport(self, hook_mock):
"""Test that transport parameter is passed to CloudRunHook."""
hook_mock.return_value.get_job.return_value = JOB
hook_mock.return_value.execute_job.return_value = self._mock_operation(3, 3, 0)

operator = CloudRunExecuteJobOperator(
task_id=TASK_ID,
project_id=PROJECT_ID,
region=REGION,
job_name=JOB_NAME,
transport="rest",
)

operator.execute(context=mock.MagicMock())

# Verify that CloudRunHook was instantiated with transport parameter
hook_mock.assert_called_once()
call_kwargs = hook_mock.call_args[1]
assert call_kwargs["transport"] == "rest"

@mock.patch(CLOUD_RUN_HOOK_PATH)
def test_execute_success(self, hook_mock):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def trigger():
polling_period_seconds=POLL_SLEEP,
timeout=TIMEOUT,
impersonation_chain=IMPERSONATION_CHAIN,
transport=None,
)


Expand All @@ -65,6 +66,7 @@ def test_serialization(self, trigger):
"polling_period_seconds": POLL_SLEEP,
"timeout": TIMEOUT,
"impersonation_chain": IMPERSONATION_CHAIN,
"transport": None,
}

@pytest.mark.asyncio
Expand Down