Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add description method in BigQueryCursor class #25366

Merged
merged 5 commits into from
Aug 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 48 additions & 14 deletions airflow/providers/google/cloud/hooks/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -2663,11 +2663,16 @@ def __init__(
self.job_id = None # type: Optional[str]
self.buffer = [] # type: list
self.all_pages_loaded = False # type: bool
self._description = [] # type: List

@property
def description(self) -> None:
"""The schema description method is not currently implemented"""
raise NotImplementedError
def description(self) -> List:
"""Return the cursor description"""
return self._description

@description.setter
def description(self, value):
self._description = value

def close(self) -> None:
"""By default, do nothing"""
Expand All @@ -2688,6 +2693,10 @@ def execute(self, operation: str, parameters: Optional[dict] = None) -> None:
self.flush_results()
self.job_id = self.hook.run_query(sql)

query_results = self._get_query_result()
description = _format_schema_for_description(query_results["schema"])
self.description = description

def executemany(self, operation: str, seq_of_parameters: list) -> None:
"""
Execute a BigQuery query multiple times with different parameters.
Expand Down Expand Up @@ -2723,17 +2732,7 @@ def next(self) -> Union[List, None]:
if self.all_pages_loaded:
return None

query_results = (
self.service.jobs()
.getQueryResults(
projectId=self.project_id,
jobId=self.job_id,
location=self.location,
pageToken=self.page_token,
)
.execute(num_retries=self.num_retries)
)

query_results = self._get_query_result()
if 'rows' in query_results and query_results['rows']:
self.page_token = query_results.get('pageToken')
fields = query_results['schema']['fields']
Expand Down Expand Up @@ -2805,6 +2804,21 @@ def setinputsizes(self, sizes: Any) -> None:
def setoutputsize(self, size: Any, column: Any = None) -> None:
    """No-op by default: ``size`` and ``column`` are accepted and ignored."""
    return None

def _get_query_result(self) -> Dict:
"""Get job query results like data, schema, job type..."""
query_results = (
self.service.jobs()
.getQueryResults(
projectId=self.project_id,
jobId=self.job_id,
location=self.location,
pageToken=self.page_token,
)
.execute(num_retries=self.num_retries)
)

return query_results


def _bind_parameters(operation: str, parameters: dict) -> str:
"""Helper method that binds parameters to a SQL query"""
Expand Down Expand Up @@ -2973,3 +2987,23 @@ def _validate_src_fmt_configs(
raise ValueError(f"{k} is not a valid src_fmt_configs for type {source_format}.")

return src_fmt_configs


def _format_schema_for_description(schema: Dict) -> List:
"""
Reformat the schema to match cursor description standard which is a tuple
of 7 elemenbts (name, type, display_size, internal_size, precision, scale, null_ok)
"""
description = []
for field in schema["fields"]:
field_description = (
field["name"],
field["type"],
None,
None,
None,
None,
field["mode"] == "NULLABLE",
)
description.append(field_description)
return description
29 changes: 26 additions & 3 deletions tests/providers/google/cloud/hooks/test_bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
BigQueryHook,
_api_resource_configs_duplication_check,
_cleanse_time_partitioning,
_format_schema_for_description,
_validate_src_fmt_configs,
_validate_value,
split_tablename,
Expand Down Expand Up @@ -1239,11 +1240,33 @@ def test_execute_many(self, mock_insert, _):
]
)

def test_format_schema_for_description(self):
    """A single NULLABLE STRING field maps to one 7-tuple whose last item is True."""
    schema = {"fields": [{"name": "field_1", "type": "STRING", "mode": "NULLABLE"}]}
    expected = [("field_1", "STRING", None, None, None, None, True)]
    assert _format_schema_for_description(schema) == expected

@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_service")
def test_description(self, mock_get_service):
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.insert_job")
def test_description(self, mock_insert, mock_get_service):
mock_get_query_results = mock_get_service.return_value.jobs.return_value.getQueryResults
mock_execute = mock_get_query_results.return_value.execute
mock_execute.return_value = {
"schema": {
"fields": [
{"name": "ts", "type": "TIMESTAMP", "mode": "NULLABLE"},
]
},
}

bq_cursor = self.hook.get_cursor()
with pytest.raises(NotImplementedError):
bq_cursor.description
bq_cursor.execute("SELECT CURRENT_TIMESTAMP() as ts")
assert bq_cursor.description == [("ts", "TIMESTAMP", None, None, None, None, True)]

@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_service")
def test_close(self, mock_get_service):
Expand Down