Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions airflow/providers/google/cloud/operators/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -819,6 +819,7 @@ class BigQueryGetDataOperator(GoogleCloudBaseOperator):
Defaults to 4 seconds.
:param as_dict: if True returns the result as a list of dictionaries, otherwise as list of lists
(default: False).
:param use_legacy_sql: Whether to use legacy SQL (true) or standard SQL (false).
"""

template_fields: Sequence[str] = (
Expand All @@ -845,6 +846,7 @@ def __init__(
deferrable: bool = False,
poll_interval: float = 4.0,
as_dict: bool = False,
use_legacy_sql: bool = True,
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand All @@ -860,14 +862,15 @@ def __init__(
self.deferrable = deferrable
self.poll_interval = poll_interval
self.as_dict = as_dict
self.use_legacy_sql = use_legacy_sql

def _submit_job(
self,
hook: BigQueryHook,
job_id: str,
) -> BigQueryJob:
get_query = self.generate_query()
configuration = {"query": {"query": get_query}}
configuration = {"query": {"query": get_query, "useLegacySql": self.use_legacy_sql}}
"""Submit a new job and get the job id for polling the status using Triggerer."""
return hook.insert_job(
configuration=configuration,
Expand All @@ -887,18 +890,23 @@ def generate_query(self) -> str:
query += self.selected_fields
else:
query += "*"
query += f" from {self.dataset_id}.{self.table_id} limit {self.max_results}"
query += f" from `{self.project_id}.{self.dataset_id}.{self.table_id}` limit {self.max_results}"
return query

def execute(self, context: Context):
hook = BigQueryHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
use_legacy_sql=self.use_legacy_sql,
)

if not self.deferrable:
self.log.info(
"Fetching Data from %s.%s max results: %s", self.dataset_id, self.table_id, self.max_results
"Fetching Data from %s.%s.%s max results: %s",
self.project_id,
self.dataset_id,
self.table_id,
self.max_results,
)
if not self.selected_fields:
schema: dict[str, list] = hook.get_schema(
Expand Down
20 changes: 8 additions & 12 deletions tests/providers/google/cloud/operators/test_bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@
"refreshIntervalMs": 2000000,
}
TEST_TABLE = "test-table"
GCP_CONN_ID = "google_cloud_default"


class TestBigQueryCreateEmptyTableOperator:
Expand Down Expand Up @@ -791,6 +792,7 @@ def test_execute(self, mock_hook, as_dict):
max_results = 100
selected_fields = "DATE"
operator = BigQueryGetDataOperator(
gcp_conn_id=GCP_CONN_ID,
task_id=TASK_ID,
dataset_id=TEST_DATASET,
table_id=TEST_TABLE_ID,
Expand All @@ -799,8 +801,10 @@ def test_execute(self, mock_hook, as_dict):
selected_fields=selected_fields,
location=TEST_DATASET_LOCATION,
as_dict=as_dict,
use_legacy_sql=False,
)
operator.execute(None)
mock_hook.assert_called_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, use_legacy_sql=False)
mock_hook.return_value.list_rows.assert_called_once_with(
dataset_id=TEST_DATASET,
table_id=TEST_TABLE_ID,
Expand All @@ -818,12 +822,6 @@ def test_bigquery_get_data_operator_async_with_selected_fields(
Asserts that a task is deferred and a BigQuerygetDataTrigger will be fired
when the BigQueryGetDataOperator is executed with deferrable=True.
"""
job_id = "123456"
hash_ = "hash"
real_job_id = f"{job_id}_{hash_}"

mock_hook.return_value.insert_job.return_value = MagicMock(job_id=real_job_id, error_result=False)

ti = create_task_instance_of_operator(
BigQueryGetDataOperator,
dag_id="dag_id",
Expand All @@ -833,6 +831,7 @@ def test_bigquery_get_data_operator_async_with_selected_fields(
max_results=100,
selected_fields="value,name",
deferrable=True,
use_legacy_sql=False,
)

with pytest.raises(TaskDeferred) as exc:
Expand All @@ -851,12 +850,6 @@ def test_bigquery_get_data_operator_async_without_selected_fields(
Asserts that a task is deferred and a BigQueryGetDataTrigger will be fired
when the BigQueryGetDataOperator is executed with deferrable=True.
"""
job_id = "123456"
hash_ = "hash"
real_job_id = f"{job_id}_{hash_}"

mock_hook.return_value.insert_job.return_value = MagicMock(job_id=real_job_id, error_result=False)

ti = create_task_instance_of_operator(
BigQueryGetDataOperator,
dag_id="dag_id",
Expand All @@ -866,6 +859,7 @@ def test_bigquery_get_data_operator_async_without_selected_fields(
max_results=100,
deferrable=True,
as_dict=as_dict,
use_legacy_sql=False,
)

with pytest.raises(TaskDeferred) as exc:
Expand All @@ -886,6 +880,7 @@ def test_bigquery_get_data_operator_execute_failure(self, as_dict):
max_results=100,
deferrable=True,
as_dict=as_dict,
use_legacy_sql=False,
)

with pytest.raises(AirflowException):
Expand All @@ -904,6 +899,7 @@ def test_bigquery_get_data_op_execute_complete_with_records(self, as_dict):
max_results=100,
deferrable=True,
as_dict=as_dict,
use_legacy_sql=False,
)

with mock.patch.object(operator.log, "info") as mock_log_info:
Expand Down