Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@
from __future__ import annotations

from functools import cached_property
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Callable

from requests import Response

from airflow.exceptions import AirflowException

from airflow.providers.google.cloud.hooks.gcs import GCSHook
from airflow.providers.google.version_compat import BaseOperator
Expand Down Expand Up @@ -102,6 +106,8 @@ def __init__(
method: str = "GET",
data: Any = None,
headers: dict[str, str] | None = None,
response_check: Callable[..., bool] | None = None,
response_filter: Callable[..., Any] | None = None,
extra_options: dict[str, Any] | None = None,
http_conn_id: str = "http_default",
log_response: bool = False,
Expand Down Expand Up @@ -130,6 +136,8 @@ def __init__(
self.method = method
self.endpoint = endpoint
self.headers = headers or {}
self.response_check = response_check
self.response_filter = response_filter
self.data = data or {}
self.extra_options = extra_options or {}
self.log_response = log_response
Expand Down Expand Up @@ -172,22 +180,57 @@ def gcs_hook(self) -> GCSHook:

def execute(self, context: Context):
self.log.info("Calling HTTP method")
response = self.http_hook.run(
raw_response = self.http_hook.run(
endpoint=self.endpoint, data=self.data, headers=self.headers, extra_options=self.extra_options
)

self.log.info("Evaluating HTTP response")
processed_response = self.process_response(context=context, response=raw_response)

self.log.info("Uploading to GCS")
self.gcs_hook.upload(
data=response.content,
data=processed_response.content,
bucket_name=self.bucket_name,
object_name=self.object_name,
mime_type=self.mime_type,
gzip=self.gzip,
encoding=self.encoding or response.encoding,
encoding=self.encoding or processed_response.encoding,
chunk_size=self.chunk_size,
timeout=self.timeout,
num_max_attempts=self.num_max_attempts,
metadata=self.metadata,
cache_control=self.cache_control,
user_project=self.user_project,
)

@staticmethod
def _default_response_maker(response: Response | list[Response]) -> Callable:
    """
    Build a zero-argument callable that lazily yields the response text.

    :param response: A single response object, or a list of response objects.
    :return: A callable returning ``response.text`` for a single response, or a
        list holding the ``text`` of every entry otherwise.
    """
    if not isinstance(response, Response):
        responses: list[Response] = response  # explicit annotation keeps mypy satisfied

        def collect_texts() -> list[str]:
            return [item.text for item in responses]

        return collect_texts

    single: Response = response  # narrowed binding for the closure (mypy)

    def read_text() -> str:
        return single.text

    return read_text

def process_response(self, context: Context, response: Response | list[Response]) -> Any:
    """
    Log, validate and optionally transform the HTTP response.

    :param context: Task context; used to resolve the keyword arguments passed to
        ``response_check`` and ``response_filter``.
    :param response: The response object (or list of response objects) to process.
    :return: The result of ``response_filter`` when one is configured; otherwise the
        response itself, unchanged, so that attributes such as ``content`` and
        ``encoding`` remain readable on the returned value.
    :raises AirflowException: If ``response_check`` is configured and returns a
        falsy value.
    """
    from airflow.utils.operator_helpers import determine_kwargs

    if self.log_response:
        # The text rendering is only needed for logging, so build it on demand.
        self.log.info(self._default_response_maker(response=response)())

    if self.response_check:
        kwargs = determine_kwargs(self.response_check, [response], context)
        if not self.response_check(response, **kwargs):
            raise AirflowException("Response check returned False.")

    if self.response_filter:
        # NOTE(review): ``execute`` reads ``.content`` and ``.encoding`` from the value
        # returned here, so a filter presumably has to return an object exposing
        # both -- confirm against the intended contract.
        kwargs = determine_kwargs(self.response_filter, [response], context)
        return self.response_filter(response, **kwargs)

    # Return the response object rather than its text: ``execute`` accesses
    # ``.content`` and ``.encoding`` on the result, which a plain ``str`` lacks.
    return response
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from airflow.providers.google.cloud.transfers.http_to_gcs import HttpToGCSOperator

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please add unit tests to cover the functionality you have introduced? Thanks

TASK_CONTEXT = None
TASK_ID = "test-http-to-gcs-operator"
GCP_CONN_ID = "GCP_CONN_ID"
HTTP_CONN_ID = "HTTP_CONN_ID"
Expand Down Expand Up @@ -69,7 +70,7 @@ def test_execute_copy_single_file(self, http_hook, gcs_hook):
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
)
task.execute(None)
task.execute(TASK_CONTEXT)

# GCS
gcs_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
Expand Down Expand Up @@ -101,3 +102,6 @@ def test_execute_copy_single_file(self, http_hook, gcs_hook):
task.http_hook.run.assert_called_once_with(
endpoint=ENDPOINT, headers=HEADERS, data=DATA, extra_options=EXTRA_OPTIONS
)
task.process_response.assert_called_once_with(
context=TASK_CONTEXT, response=task.http_hook.run.return_value
)
Loading