Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@

from airflow.exceptions import AirflowException
from airflow.providers.google.common.consts import CLIENT_INFO
from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
from airflow.providers.google.common.hooks.base_google import GoogleBaseAsyncHook, GoogleBaseHook
from airflow.version import version as airflow_version

if TYPE_CHECKING:
Expand Down Expand Up @@ -1269,14 +1269,16 @@ def check_error_for_resource_is_not_ready_msg(self, error_msg: str) -> bool:
return all([word in error_msg for word in key_words])


class DataprocAsyncHook(GoogleBaseHook):
class DataprocAsyncHook(GoogleBaseAsyncHook):
"""
Asynchronous interaction with Google Cloud Dataproc APIs.
All the methods in the hook where project_id is used must be called with
keyword arguments rather than positional.
"""

sync_hook_class = DataprocHook

def __init__(
self,
gcp_conn_id: str = "google_cloud_default",
Expand All @@ -1286,53 +1288,90 @@ def __init__(
super().__init__(gcp_conn_id=gcp_conn_id, impersonation_chain=impersonation_chain, **kwargs)
self._cached_client: JobControllerAsyncClient | None = None

def get_cluster_client(self, region: str | None = None) -> ClusterControllerAsyncClient:
async def get_cluster_client(self, region: str | None = None) -> ClusterControllerAsyncClient:
"""Create a ClusterControllerAsyncClient."""
client_options = None
if region and region != "global":
client_options = ClientOptions(api_endpoint=f"{region}-dataproc.googleapis.com:443")

sync_hook = await self.get_sync_hook()
return ClusterControllerAsyncClient(
credentials=self.get_credentials(), client_info=CLIENT_INFO, client_options=client_options
credentials=sync_hook.get_credentials(), client_info=CLIENT_INFO, client_options=client_options
)

def get_template_client(self, region: str | None = None) -> WorkflowTemplateServiceAsyncClient:
async def get_template_client(self, region: str | None = None) -> WorkflowTemplateServiceAsyncClient:
"""Create a WorkflowTemplateServiceAsyncClient."""
client_options = None
if region and region != "global":
client_options = ClientOptions(api_endpoint=f"{region}-dataproc.googleapis.com:443")

sync_hook = await self.get_sync_hook()
return WorkflowTemplateServiceAsyncClient(
credentials=self.get_credentials(), client_info=CLIENT_INFO, client_options=client_options
credentials=sync_hook.get_credentials(), client_info=CLIENT_INFO, client_options=client_options
)

def get_job_client(self, region: str | None = None) -> JobControllerAsyncClient:
async def get_job_client(self, region: str | None = None) -> JobControllerAsyncClient:
"""Create a JobControllerAsyncClient."""
if self._cached_client is None:
client_options = None
if region and region != "global":
client_options = ClientOptions(api_endpoint=f"{region}-dataproc.googleapis.com:443")

sync_hook = await self.get_sync_hook()
self._cached_client = JobControllerAsyncClient(
credentials=self.get_credentials(),
credentials=sync_hook.get_credentials(),
client_info=CLIENT_INFO,
client_options=client_options,
)
return self._cached_client

def get_batch_client(self, region: str | None = None) -> BatchControllerAsyncClient:
async def get_batch_client(self, region: str | None = None) -> BatchControllerAsyncClient:
"""Create a BatchControllerAsyncClient."""
client_options = None
if region and region != "global":
client_options = ClientOptions(api_endpoint=f"{region}-dataproc.googleapis.com:443")

sync_hook = await self.get_sync_hook()
return BatchControllerAsyncClient(
credentials=self.get_credentials(), client_info=CLIENT_INFO, client_options=client_options
credentials=sync_hook.get_credentials(), client_info=CLIENT_INFO, client_options=client_options
)

def get_operations_client(self, region: str) -> OperationsClient:
async def get_operations_client(self, region: str) -> OperationsClient:
"""Create a OperationsClient."""
return self.get_template_client(region=region).transport.operations_client
template_client = await self.get_template_client(region=region)
return template_client.transport.operations_client

@GoogleBaseHook.fallback_to_default_project_id
async def get_cluster(
self,
region: str,
cluster_name: str,
project_id: str,
retry: AsyncRetry | _MethodDefault = DEFAULT,
timeout: float | None = None,
metadata: Sequence[tuple[str, str]] = (),
) -> Cluster:
"""
Get a cluster.
:param region: Cloud Dataproc region in which to handle the request.
:param cluster_name: Name of the cluster to get.
:param project_id: Google Cloud project ID that the cluster belongs to.
:param retry: A retry object used to retry requests. If *None*, requests
will not be retried.
:param timeout: The amount of time, in seconds, to wait for the request
to complete. If *retry* is specified, the timeout applies to each
individual attempt.
:param metadata: Additional metadata that is provided to the method.
"""
client = await self.get_cluster_client(region=region)
result = await client.get_cluster(
request={"project_id": project_id, "region": region, "cluster_name": cluster_name},
retry=retry,
timeout=timeout,
metadata=metadata,
)
return result

@GoogleBaseHook.fallback_to_default_project_id
async def create_cluster(
Expand Down Expand Up @@ -1390,7 +1429,7 @@ async def create_cluster(
cluster["config"] = cluster_config # type: ignore
cluster["labels"] = labels # type: ignore

client = self.get_cluster_client(region=region)
client = await self.get_cluster_client(region=region)
result = await client.create_cluster(
request={
"project_id": project_id,
Expand Down Expand Up @@ -1435,7 +1474,7 @@ async def delete_cluster(
individual attempt.
:param metadata: Additional metadata that is provided to the method.
"""
client = self.get_cluster_client(region=region)
client = await self.get_cluster_client(region=region)
result = await client.delete_cluster(
request={
"project_id": project_id,
Expand Down Expand Up @@ -1483,7 +1522,7 @@ async def diagnose_cluster(
individual attempt.
:param metadata: Additional metadata that is provided to the method.
"""
client = self.get_cluster_client(region=region)
client = await self.get_cluster_client(region=region)
result = await client.diagnose_cluster(
request={
"project_id": project_id,
Expand All @@ -1500,38 +1539,6 @@ async def diagnose_cluster(
)
return result

@GoogleBaseHook.fallback_to_default_project_id
async def get_cluster(
self,
region: str,
cluster_name: str,
project_id: str,
retry: AsyncRetry | _MethodDefault = DEFAULT,
timeout: float | None = None,
metadata: Sequence[tuple[str, str]] = (),
) -> Cluster:
"""
Get the resource representation for a cluster in a project.
:param project_id: Google Cloud project ID that the cluster belongs to.
:param region: Cloud Dataproc region to handle the request.
:param cluster_name: The cluster name.
:param retry: A retry object used to retry requests. If *None*, requests
will not be retried.
:param timeout: The amount of time, in seconds, to wait for the request
to complete. If *retry* is specified, the timeout applies to each
individual attempt.
:param metadata: Additional metadata that is provided to the method.
"""
client = self.get_cluster_client(region=region)
result = await client.get_cluster(
request={"project_id": project_id, "region": region, "cluster_name": cluster_name},
retry=retry,
timeout=timeout,
metadata=metadata,
)
return result

@GoogleBaseHook.fallback_to_default_project_id
async def list_clusters(
self,
Expand Down Expand Up @@ -1561,7 +1568,7 @@ async def list_clusters(
individual attempt.
:param metadata: Additional metadata that is provided to the method.
"""
client = self.get_cluster_client(region=region)
client = await self.get_cluster_client(region=region)
result = await client.list_clusters(
request={"project_id": project_id, "region": region, "filter": filter_, "page_size": page_size},
retry=retry,
Expand Down Expand Up @@ -1638,7 +1645,7 @@ async def update_cluster(
"""
if region is None:
raise TypeError("missing 1 required keyword argument: 'region'")
client = self.get_cluster_client(region=region)
client = await self.get_cluster_client(region=region)
operation = await client.update_cluster(
request={
"project_id": project_id,
Expand Down Expand Up @@ -1680,10 +1687,8 @@ async def create_workflow_template(
individual attempt.
:param metadata: Additional metadata that is provided to the method.
"""
if region is None:
raise TypeError("missing 1 required keyword argument: 'region'")
metadata = metadata or ()
client = self.get_template_client(region)
client = await self.get_template_client(region)
parent = f"projects/{project_id}/regions/{region}"
return await client.create_workflow_template(
request={"parent": parent, "template": template}, retry=retry, timeout=timeout, metadata=metadata
Expand Down Expand Up @@ -1725,10 +1730,8 @@ async def instantiate_workflow_template(
individual attempt.
:param metadata: Additional metadata that is provided to the method.
"""
if region is None:
raise TypeError("missing 1 required keyword argument: 'region'")
metadata = metadata or ()
client = self.get_template_client(region)
client = await self.get_template_client(region)
name = f"projects/{project_id}/regions/{region}/workflowTemplates/{template_name}"
operation = await client.instantiate_workflow_template(
request={"name": name, "version": version, "request_id": request_id, "parameters": parameters},
Expand Down Expand Up @@ -1767,10 +1770,8 @@ async def instantiate_inline_workflow_template(
individual attempt.
:param metadata: Additional metadata that is provided to the method.
"""
if region is None:
raise TypeError("missing 1 required keyword argument: 'region'")
metadata = metadata or ()
client = self.get_template_client(region)
client = await self.get_template_client(region)
parent = f"projects/{project_id}/regions/{region}"
operation = await client.instantiate_inline_workflow_template(
request={"parent": parent, "template": template, "request_id": request_id},
Expand All @@ -1781,7 +1782,8 @@ async def instantiate_inline_workflow_template(
return operation

async def get_operation(self, region, operation_name):
return await self.get_operations_client(region).get_operation(name=operation_name)
operations_client = await self.get_operations_client(region)
return await operations_client.get_operation(name=operation_name)

@GoogleBaseHook.fallback_to_default_project_id
async def get_job(
Expand All @@ -1806,9 +1808,7 @@ async def get_job(
individual attempt.
:param metadata: Additional metadata that is provided to the method.
"""
if region is None:
raise TypeError("missing 1 required keyword argument: 'region'")
client = self.get_job_client(region=region)
client = await self.get_job_client(region=region)
job = await client.get_job(
request={"project_id": project_id, "region": region, "job_id": job_id},
retry=retry,
Expand Down Expand Up @@ -1845,9 +1845,7 @@ async def submit_job(
individual attempt.
:param metadata: Additional metadata that is provided to the method.
"""
if region is None:
raise TypeError("missing 1 required keyword argument: 'region'")
client = self.get_job_client(region=region)
client = await self.get_job_client(region=region)
return await client.submit_job(
request={"project_id": project_id, "region": region, "job": job, "request_id": request_id},
retry=retry,
Expand Down Expand Up @@ -1878,7 +1876,7 @@ async def cancel_job(
individual attempt.
:param metadata: Additional metadata that is provided to the method.
"""
client = self.get_job_client(region=region)
client = await self.get_job_client(region=region)

job = await client.cancel_job(
request={"project_id": project_id, "region": region, "job_id": job_id},
Expand Down Expand Up @@ -1920,7 +1918,7 @@ async def create_batch(
individual attempt.
:param metadata: Additional metadata that is provided to the method.
"""
client = self.get_batch_client(region)
client = await self.get_batch_client(region)
parent = f"projects/{project_id}/regions/{region}"

result = await client.create_batch(
Expand Down Expand Up @@ -1959,7 +1957,7 @@ async def delete_batch(
individual attempt.
:param metadata: Additional metadata that is provided to the method.
"""
client = self.get_batch_client(region)
client = await self.get_batch_client(region)
name = f"projects/{project_id}/locations/{region}/batches/{batch_id}"

await client.delete_batch(
Expand Down Expand Up @@ -1994,7 +1992,7 @@ async def get_batch(
individual attempt.
:param metadata: Additional metadata that is provided to the method.
"""
client = self.get_batch_client(region)
client = await self.get_batch_client(region)
name = f"projects/{project_id}/locations/{region}/batches/{batch_id}"

result = await client.get_batch(
Expand Down Expand Up @@ -2039,7 +2037,7 @@ async def list_batches(
:param filter: Result filters as specified in ListBatchesRequest
:param order_by: How to order results as specified in ListBatchesRequest
"""
client = self.get_batch_client(region)
client = await self.get_batch_client(region)
parent = f"projects/{project_id}/regions/{region}"

result = await client.list_batches(
Expand Down
Loading