12 changes: 9 additions & 3 deletions airflow/providers/google/cloud/hooks/bigquery.py
@@ -3125,20 +3125,26 @@ async def create_job_for_partition_get(
job_query_resp = await job_client.query(query_request, cast(Session, session))
return job_query_resp["jobReference"]["jobId"]

def get_records(self, query_results: dict[str, Any]) -> list[Any]:
def get_records(self, query_results: dict[str, Any], as_dict: bool = False) -> list[Any]:
"""
Given the output query response from gcloud-aio bigquery, convert the response to records.

:param query_results: the results from a SQL query
:param as_dict: if True, returns the result as a list of dictionaries, otherwise as a list of lists.
"""
buffer = []
buffer: list[Any] = []
if "rows" in query_results and query_results["rows"]:
rows = query_results["rows"]
fields = query_results["schema"]["fields"]
col_types = [field["type"] for field in fields]
for dict_row in rows:
typed_row = [bq_cast(vs["v"], col_types[idx]) for idx, vs in enumerate(dict_row["f"])]
buffer.append(typed_row)
if not as_dict:
buffer.append(typed_row)
else:
fields_names = [field["name"] for field in fields]
typed_row_dict = {k: v for k, v in zip(fields_names, typed_row)}
buffer.append(typed_row_dict)
return buffer

def value_check(
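For orientation, here is a minimal sketch (not part of this diff) of how the new ``as_dict`` flag changes the hook output; the payload mirrors the gcloud-aio ``getQueryResults`` response shape used in the hook test further down, and the default ``google_cloud_default`` connection is assumed:

from airflow.providers.google.cloud.hooks.bigquery import BigQueryAsyncHook

# Response shaped like the gcloud-aio getQueryResults payload (see the hook test below).
query_results = {
    "schema": {
        "fields": [
            {"name": "name", "type": "STRING", "mode": "NULLABLE"},
            {"name": "age", "type": "INTEGER", "mode": "NULLABLE"},
        ]
    },
    "rows": [
        {"f": [{"v": "Tony"}, {"v": "10"}]},
        {"f": [{"v": "Mike"}, {"v": "20"}]},
    ],
}

hook = BigQueryAsyncHook()  # get_records itself does not call the API
print(hook.get_records(query_results))                # [['Tony', 10], ['Mike', 20]]
print(hook.get_records(query_results, as_dict=True))  # [{'name': 'Tony', 'age': 10}, {'name': 'Mike', 'age': 20}]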
35 changes: 25 additions & 10 deletions airflow/providers/google/cloud/operators/bigquery.py
@@ -759,12 +759,19 @@ def execute(self, context=None):

class BigQueryGetDataOperator(GoogleCloudBaseOperator):
"""
Fetches the data from a BigQuery table (alternatively fetch data for selected columns)
and returns data in a python list. The number of elements in the returned list will
be equal to the number of rows fetched. Each element in the list will again be a list
where element would represent the columns values for that row.
Fetches the data from a BigQuery table (alternatively fetch data for selected columns) and returns data
in either of the following two formats, based on the "as_dict" value:
1. False (Default) - A Python list of lists, with the number of nested lists equal to the number of rows
fetched. Each nested list represents a row, where the elements within it correspond to the column values
for that particular row.

**Example Result**: ``[['Tony', '10'], ['Mike', '20'], ['Steve', '15']]``
**Example Result**: ``[['Tony', 10], ['Mike', 20]]``


2. True - A Python list of dictionaries, where each dictionary represents a row. In each dictionary,
the keys are the column names and the values are the corresponding values for those columns.

**Example Result**: ``[{'name': 'Tony', 'age': 10}, {'name': 'Mike', 'age': 20}]``

.. seealso::
For more information on how to use this operator, take a look at the guide:
@@ -811,6 +818,8 @@ class BigQueryGetDataOperator(GoogleCloudBaseOperator):
:param deferrable: Run operator in the deferrable mode
:param poll_interval: (Deferrable mode only) polling period in seconds to check for the status of job.
Defaults to 4 seconds.
:param as_dict: if True, returns the result as a list of dictionaries, otherwise as a list of lists
(default: False).
"""

template_fields: Sequence[str] = (
@@ -836,6 +845,7 @@ def __init__(
impersonation_chain: str | Sequence[str] | None = None,
deferrable: bool = False,
poll_interval: float = 4.0,
as_dict: bool = False,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -850,6 +860,7 @@ def __init__(
self.project_id = project_id
self.deferrable = deferrable
self.poll_interval = poll_interval
self.as_dict = as_dict

def _submit_job(
self,
@@ -885,7 +896,6 @@ def execute(self, context: Context):
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
)
self.hook = hook

if not self.deferrable:
self.log.info(
@@ -911,21 +921,26 @@

self.log.info("Total extracted rows: %s", len(rows))

table_data = [row.values() for row in rows]
if self.as_dict:
table_data = [{k: v for k, v in row.items()} for row in rows]
else:
table_data = [row.values() for row in rows]

return table_data

job = self._submit_job(hook, job_id="")
self.job_id = job.job_id
context["ti"].xcom_push(key="job_id", value=self.job_id)

context["ti"].xcom_push(key="job_id", value=job.job_id)
self.defer(
timeout=self.execution_timeout,
trigger=BigQueryGetDataTrigger(
conn_id=self.gcp_conn_id,
job_id=self.job_id,
job_id=job.job_id,
dataset_id=self.dataset_id,
table_id=self.table_id,
project_id=hook.project_id,
poll_interval=self.poll_interval,
as_dict=self.as_dict,
),
method_name="execute_complete",
)
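A hypothetical usage sketch (not part of this diff) of the operator with the new parameter; the project, dataset, table and field names are placeholders, and an Airflow 2.4+ style ``schedule`` argument is assumed:

from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.operators.bigquery import BigQueryGetDataOperator

with DAG(
    dag_id="example_bigquery_get_data_as_dict",  # placeholder dag id
    start_date=datetime(2023, 1, 1),
    schedule=None,
) as dag:
    get_data = BigQueryGetDataOperator(
        task_id="get_data",
        project_id="my-project",   # placeholder
        dataset_id="my_dataset",   # placeholder
        table_id="my_table",       # placeholder
        selected_fields="name,age",
        max_results=10,
        as_dict=True,  # the returned XCom value becomes a list of dicts keyed by column name
    )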
13 changes: 11 additions & 2 deletions airflow/providers/google/cloud/triggers/bigquery.py
@@ -165,7 +165,16 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]


class BigQueryGetDataTrigger(BigQueryInsertJobTrigger):
"""BigQueryGetDataTrigger run on the trigger worker, inherits from BigQueryInsertJobTrigger class"""
"""
BigQueryGetDataTrigger runs on the trigger worker and inherits from the BigQueryInsertJobTrigger class.

:param as_dict: if True, returns the result as a list of dictionaries, otherwise as a list of lists
(default: False).
"""

def __init__(self, as_dict: bool = False, **kwargs):
super().__init__(**kwargs)
self.as_dict = as_dict

def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serializes BigQueryInsertJobTrigger arguments and classpath."""
@@ -190,7 +199,7 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]
response_from_hook = await hook.get_job_status(job_id=self.job_id, project_id=self.project_id)
if response_from_hook == "success":
query_results = await hook.get_job_output(job_id=self.job_id, project_id=self.project_id)
records = hook.get_records(query_results)
records = hook.get_records(query_results=query_results, as_dict=self.as_dict)
self.log.debug("Response from hook: %s", response_from_hook)
yield TriggerEvent(
{
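An illustrative sketch (not part of this diff) of the trigger in isolation, assuming the success event payload carries the records returned by ``get_records``; the job and project ids are placeholders:

import asyncio

from airflow.providers.google.cloud.triggers.bigquery import BigQueryGetDataTrigger

trigger = BigQueryGetDataTrigger(
    conn_id="google_cloud_default",
    job_id="hypothetical-job-id",       # placeholder
    project_id="hypothetical-project",  # placeholder
    dataset_id="my_dataset",            # placeholder
    table_id="my_table",                # placeholder
    poll_interval=4.0,
    as_dict=True,
)

async def first_event():
    # Polls the job until it finishes and returns the first event payload,
    # which is assumed to include the converted records on success.
    async for event in trigger.run():
        return event.payload

# asyncio.run(first_event())  # requires valid credentials and an existing job id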
@@ -208,10 +208,11 @@ To fetch data from a BigQuery table you can use
Alternatively you can fetch data for selected columns if you pass fields to
``selected_fields``.

This operator returns data in a Python list where the number of elements in the
returned list will be equal to the number of rows fetched. Each element in the
list will again be a list where elements would represent the column values for
The result of this operator can be retrieved in two different formats based on the value of the ``as_dict`` parameter:
``False`` (default) - A Python list of lists, where the number of elements in the outer list is equal to the number of rows fetched. Each element is itself a
nested list whose elements represent the column values for
that row.
``True`` - A Python list of dictionaries, where each dictionary represents a row. In each dictionary, the keys are the column names and the values are the corresponding values for those columns.

.. exampleinclude:: /../../tests/system/providers/google/cloud/bigquery/example_bigquery_queries.py
:language: python
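A hypothetical downstream consumer (not from the provider docs) showing why the dictionary format can be convenient; the column name ``age`` and the wiring via ``get_data.output`` are assumptions:

from __future__ import annotations

from airflow.decorators import task

@task
def summarize(rows: list[dict]) -> int:
    # With as_dict=True each pulled row is a dict keyed by column name.
    return sum(row["age"] for row in rows)

# Inside a DAG definition this could be wired as:
# summarize(get_data.output)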
26 changes: 26 additions & 0 deletions tests/providers/google/cloud/hooks/test_bigquery.py
@@ -2348,3 +2348,29 @@ def test_get_records_return_type(self):
assert isinstance(result[0][0], int)
assert isinstance(result[0][1], float)
assert isinstance(result[0][2], str)

def test_get_records_as_dict(self):
query_result = {
"kind": "bigquery#getQueryResultsResponse",
"etag": "test_etag",
"schema": {
"fields": [
{"name": "f0_", "type": "INTEGER", "mode": "NULLABLE"},
{"name": "f1_", "type": "FLOAT", "mode": "NULLABLE"},
{"name": "f2_", "type": "STRING", "mode": "NULLABLE"},
]
},
"jobReference": {
"projectId": "test_airflow-providers",
"jobId": "test_jobid",
"location": "US",
},
"totalRows": "1",
"rows": [{"f": [{"v": "22"}, {"v": "3.14"}, {"v": "PI"}]}],
"totalBytesProcessed": "0",
"jobComplete": True,
"cacheHit": False,
}
hook = BigQueryAsyncHook()
result = hook.get_records(query_result, as_dict=True)
assert result == [{"f0_": 22, "f1_": 3.14, "f2_": "PI"}]
16 changes: 12 additions & 4 deletions tests/providers/google/cloud/operators/test_bigquery.py
@@ -784,8 +784,9 @@ def test_bigquery_operator_extra_link_when_multiple_query(


class TestBigQueryGetDataOperator:
@pytest.mark.parametrize("as_dict", [True, False])
@mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook")
def test_execute(self, mock_hook):
def test_execute(self, mock_hook, as_dict):
max_results = 100
selected_fields = "DATE"
operator = BigQueryGetDataOperator(
@@ -796,6 +797,7 @@ def test_execute(self, mock_hook):
max_results=max_results,
selected_fields=selected_fields,
location=TEST_DATASET_LOCATION,
as_dict=as_dict,
)
operator.execute(None)
mock_hook.return_value.list_rows.assert_called_once_with(
@@ -839,9 +841,10 @@ def test_bigquery_get_data_operator_async_with_selected_fields(
exc.value.trigger, BigQueryGetDataTrigger
), "Trigger is not a BigQueryGetDataTrigger"

@pytest.mark.parametrize("as_dict", [True, False])
@mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook")
def test_bigquery_get_data_operator_async_without_selected_fields(
self, mock_hook, create_task_instance_of_operator
self, mock_hook, create_task_instance_of_operator, as_dict
):
"""
Asserts that a task is deferred and a BigQueryGetDataTrigger will be fired
@@ -861,6 +864,7 @@ def test_bigquery_get_data_operator_async_without_selected_fields(
table_id=TEST_TABLE_ID,
max_results=100,
deferrable=True,
as_dict=as_dict,
)

with pytest.raises(TaskDeferred) as exc:
@@ -870,7 +874,8 @@ def test_bigquery_get_data_operator_async_without_selected_fields(
exc.value.trigger, BigQueryGetDataTrigger
), "Trigger is not a BigQueryGetDataTrigger"

def test_bigquery_get_data_operator_execute_failure(self):
@pytest.mark.parametrize("as_dict", [True, False])
def test_bigquery_get_data_operator_execute_failure(self, as_dict):
"""Tests that an AirflowException is raised in case of error event"""

operator = BigQueryGetDataOperator(
@@ -879,14 +884,16 @@ def test_bigquery_get_data_operator_execute_failure(self):
table_id="any",
max_results=100,
deferrable=True,
as_dict=as_dict,
)

with pytest.raises(AirflowException):
operator.execute_complete(
context=None, event={"status": "error", "message": "test failure message"}
)

def test_bigquery_get_data_op_execute_complete_with_records(self):
@pytest.mark.parametrize("as_dict", [True, False])
def test_bigquery_get_data_op_execute_complete_with_records(self, as_dict):
"""Asserts that exception is raised with correct expected exception message"""

operator = BigQueryGetDataOperator(
@@ -895,6 +902,7 @@ def test_bigquery_get_data_op_execute_complete_with_records(self):
table_id="any",
max_results=100,
deferrable=True,
as_dict=as_dict,
)

with mock.patch.object(operator.log, "info") as mock_log_info: