Skip to content

Commit

Permalink
cell magic accepts table_ids instead of queries
Browse files Browse the repository at this point in the history
added default patch to unit tests
  • Loading branch information
shubha-rajan committed Sep 6, 2019
1 parent f00b60b commit 77e62c5
Show file tree
Hide file tree
Showing 2 changed files with 139 additions and 7 deletions.
38 changes: 31 additions & 7 deletions bigquery/google/cloud/bigquery/magics.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@

from __future__ import print_function

import re
import ast
import sys
import time
Expand Down Expand Up @@ -266,6 +267,15 @@ def default_query_job_config(self, value):
context = Context()


def _print_error(error, destination_var=None):
    """Report a failed magic invocation on standard error.

    Args:
        error: The error to report. It is handed to ``print`` and is
            therefore rendered with ``str()``.
        destination_var (Optional[str]): Name of the variable the output
            was meant to be stored in. When truthy, a line noting that
            the variable could not be saved is printed first.
    """
    if destination_var:
        print(
            "Could not save output to variable '{}'.".format(destination_var),
            file=sys.stderr,
        )
    # ``print`` joins its positional arguments with a space, so the error
    # text is emitted on its own line preceded by a single space.
    print("\nERROR:\n", error, file=sys.stderr)


def _run_query(client, query, job_config=None):
"""Runs a query while printing status updates
Expand Down Expand Up @@ -434,6 +444,26 @@ def _cell_magic(line, query):
else:
max_results = None

error = None

if not re.search(r"\s", query.rstrip()):
table_id = query.rstrip()

try:
rows = client.list_rows(table_id, max_results=max_results)
except Exception as ex:
error = str(ex)
if error:
_print_error(error, args.destination_var)
return

result = rows.to_dataframe(bqstorage_client=bqstorage_client)
if args.destination_var:
IPython.get_ipython().push({args.destination_var: result})
return
else:
return result

job_config = bigquery.job.QueryJobConfig()
job_config.query_parameters = params
job_config.use_legacy_sql = args.use_legacy_sql
Expand All @@ -445,7 +475,6 @@ def _cell_magic(line, query):
value = int(args.maximum_bytes_billed)
job_config.maximum_bytes_billed = value

error = None
try:
query_job = _run_query(client, query, job_config=job_config)
except Exception as ex:
Expand All @@ -455,12 +484,7 @@ def _cell_magic(line, query):
display.clear_output()

if error:
if args.destination_var:
print(
"Could not save output to variable '{}'.".format(args.destination_var),
file=sys.stderr,
)
print("\nERROR:\n", error, file=sys.stderr)
_print_error(error, args.destination_var)
return

if args.dry_run and args.destination_var:
Expand Down
108 changes: 108 additions & 0 deletions bigquery/tests/unit/test_magics.py
Original file line number Diff line number Diff line change
Expand Up @@ -696,6 +696,114 @@ def test_bigquery_magic_w_max_results_valid_calls_queryjob_result():
query_job_mock.result.assert_called_with(max_results=5)


# Uses the same interactive-shell fixture as the other table-ID tests in
# this module, so ``IPython.get_ipython()`` is set up the same way here.
@pytest.mark.usefixtures("ipython_interactive")
def test_bigquery_magic_w_table_id_invalid():
    """A failing ``list_rows`` call is reported on stderr, not raised.

    The cell body is a single token, so the magic treats it as a table ID.
    ``Client.list_rows`` is patched to raise ``BadRequest``; the test checks
    that the captured stderr carries both the "could not save" notice for
    the destination variable and the error text, and no traceback.
    """
    ip = IPython.get_ipython()
    ip.extension_manager.load_extension("google.cloud.bigquery")
    magics.context._project = None

    credentials_mock = mock.create_autospec(
        google.auth.credentials.Credentials, instance=True
    )
    default_patch = mock.patch(
        "google.auth.default", return_value=(credentials_mock, "general-project")
    )

    # Every list_rows call raises, simulating a rejected table ID.
    list_rows_patch = mock.patch(
        "google.cloud.bigquery.magics.bigquery.Client.list_rows",
        autospec=True,
        side_effect=exceptions.BadRequest("Not a valid table ID"),
    )

    table_id = "not-a-real-table"

    with list_rows_patch, default_patch, io.capture_output() as captured_io:
        # "df" is the destination variable the magic is asked to fill.
        ip.run_cell_magic("bigquery", "df", table_id)

    output = captured_io.stderr
    assert "Could not save output to variable" in output
    assert "400 Not a valid table ID" in output
    assert "Traceback (most recent call last)" not in output


@pytest.mark.usefixtures("ipython_interactive")
def test_bigquery_magic_w_table_id_and_destination_var():
    """A table-ID cell stores the listed rows in the destination variable.

    ``Client`` is replaced by an autospec mock whose ``list_rows`` returns a
    mocked ``RowIterator``; the frame that iterator's ``to_dataframe``
    yields must end up in the user namespace under the requested name.
    """
    ip = IPython.get_ipython()
    ip.extension_manager.load_extension("google.cloud.bigquery")
    magics.context._project = None

    credentials_mock = mock.create_autospec(
        google.auth.credentials.Credentials, instance=True
    )
    default_patch = mock.patch(
        "google.auth.default", return_value=(credentials_mock, "general-project")
    )

    row_iterator_mock = mock.create_autospec(
        google.cloud.bigquery.table.RowIterator, instance=True
    )

    client_patch = mock.patch(
        "google.cloud.bigquery.magics.bigquery.Client", autospec=True
    )

    table_id = "bigquery-public-data.samples.shakespeare"
    result = pandas.DataFrame([17], columns=["num"])

    with client_patch as client_mock, default_patch:
        client_mock().list_rows.return_value = row_iterator_mock
        row_iterator_mock.to_dataframe.return_value = result

        ip.run_cell_magic("bigquery", "df", table_id)

    assert "df" in ip.user_ns
    df = ip.user_ns["df"]

    assert isinstance(df, pandas.DataFrame)
    # The stored frame should be the one produced by the mocked iterator,
    # not merely some DataFrame.
    assert len(df) == len(result)  # row count
    assert list(df) == list(result)  # column names


@pytest.mark.usefixtures("ipython_interactive")
def test_bigquery_magic_w_table_id_and_bqstorage_client():
    """With ``--use_bqstorage_api``, a table-ID cell forwards the storage client.

    ``BigQueryStorageClient`` is patched with a mock class whose return
    value is a known mock instance; the rows returned by the mocked
    ``list_rows`` must be converted with that exact instance passed as
    ``bqstorage_client``.
    """
    ip = IPython.get_ipython()
    ip.extension_manager.load_extension("google.cloud.bigquery")
    magics.context._project = None

    credentials_mock = mock.create_autospec(
        google.auth.credentials.Credentials, instance=True
    )
    default_patch = mock.patch(
        "google.auth.default", return_value=(credentials_mock, "general-project")
    )

    row_iterator_mock = mock.create_autospec(
        google.cloud.bigquery.table.RowIterator, instance=True
    )

    client_patch = mock.patch(
        "google.cloud.bigquery.magics.bigquery.Client", autospec=True
    )

    # Class-level mock whose "constructor" hands back a known instance, so
    # the assertion below can check identity of the forwarded client.
    bqstorage_mock = mock.create_autospec(
        bigquery_storage_v1beta1.BigQueryStorageClient
    )
    bqstorage_instance_mock = mock.create_autospec(
        bigquery_storage_v1beta1.BigQueryStorageClient, instance=True
    )
    bqstorage_mock.return_value = bqstorage_instance_mock
    bqstorage_client_patch = mock.patch(
        "google.cloud.bigquery_storage_v1beta1.BigQueryStorageClient", bqstorage_mock
    )

    table_id = "bigquery-public-data.samples.shakespeare"

    with default_patch, client_patch as client_mock, bqstorage_client_patch:
        client_mock().list_rows.return_value = row_iterator_mock

        # --max_results=5 is supplied as well; only the to_dataframe call is
        # checked by this test.
        ip.run_cell_magic("bigquery", "--use_bqstorage_api --max_results=5", table_id)
        row_iterator_mock.to_dataframe.assert_called_once_with(
            bqstorage_client=bqstorage_instance_mock
        )


@pytest.mark.usefixtures("ipython_interactive")
def test_bigquery_magic_dryrun_option_sets_job_config():
ip = IPython.get_ipython()
Expand Down

0 comments on commit 77e62c5

Please sign in to comment.