BigQueryTask #678

Merged (6 commits) on Feb 22, 2019

Changes from all commits:
1 change: 1 addition & 0 deletions CHANGELOG.md

@@ -9,6 +9,7 @@ These changes are available in the [master branch](https://github.com/PrefectHQ/
- Add `checkpoint` option for individual `Task`s, as well as a global `checkpoint` config setting for storing the results of Tasks using their result handlers - [#649](https://github.com/PrefectHQ/prefect/pull/649)
- Add `defaults_from_attrs` decorator to easily construct `Task`s whose attributes serve as defaults for `Task.run` - [#293](https://github.com/PrefectHQ/prefect/issues/293)
- Add `GCSUploadTask` and `GCSDownloadTask` for uploading / retrieving string data to / from Google Cloud Storage - [#673](https://github.com/PrefectHQ/prefect/pull/673)
+- Add `BigQueryTask` for executing queries against BigQuery tables - [#678](https://github.com/PrefectHQ/prefect/pull/678)

### Enhancements

2 changes: 1 addition & 1 deletion docs/outline.toml

@@ -132,7 +132,7 @@ functions = ["switch", "ifelse", "merge"]
[pages.tasks.google]
title = "Google Cloud Tasks"
module = "prefect.tasks.google"
-classes = ["GCSDownloadTask", "GCSUploadTask"]
+classes = ["GCSDownloadTask", "GCSUploadTask", "BigQueryTask"]

[pages.tasks.sqlite]
title = "SQLite Tasks"
1 change: 1 addition & 0 deletions requirements.txt

@@ -5,6 +5,7 @@ cryptography >= 2.2.2, < 3.0
dask >= 0.18, < 2.0
distributed >= 1.21.8, < 2.0
docker >= 3.4.1, < 4.0
+google-cloud-bigquery >= 1.6.0, < 2.0
google-cloud-storage >= 1.13, < 2.0
idna < 2.8, >= 2.5
marshmallow == 3.0.0b19
1 change: 1 addition & 0 deletions src/prefect/tasks/google/__init__.py

@@ -10,3 +10,4 @@
import prefect.tasks.google.storage

from prefect.tasks.google.storage import GCSDownloadTask, GCSUploadTask
+from prefect.tasks.google.bigquery import BigQueryTask
171 changes: 171 additions & 0 deletions src/prefect/tasks/google/bigquery.py

@@ -0,0 +1,171 @@
import json
import uuid

from google.oauth2.service_account import Credentials
from google.cloud import bigquery
from google.cloud.exceptions import NotFound
from typing import List

from prefect import context
from prefect.client import Secret
from prefect.core import Task
from prefect.utilities.tasks import defaults_from_attrs


class BigQueryTask(Task):
"""
Task for executing queries against a Google BigQuery table and (optionally) returning
the results. Note that _all_ initialization settings can be provided / overwritten at runtime.

Args:
- query (str, optional): a string of the query to execute
- query_params (list[tuple], optional): a list of 3-tuples specifying
BigQuery query parameters; currently only scalar query parameters are supported. See
[the Google documentation](https://cloud.google.com/bigquery/docs/parameterized-queries#bigquery-query-params-python)
for more details on how both the query and the query parameters should be formatted
- project (str, optional): the project to initialize the BigQuery Client with; if not provided,
will default to the one inferred from your credentials
- location (str, optional): location of the dataset which will be queried; defaults to "US"
- dry_run_max_bytes (int, optional): if provided, the maximum number of bytes the query is allowed
to process; this will be determined by executing a dry run and raising a `ValueError` if the
maximum is exceeded
- credentials_secret (str, optional): the name of the Prefect Secret containing a JSON representation
of your Google Application credentials; defaults to `"GOOGLE_APPLICATION_CREDENTIALS"`
- dataset_dest (str, optional): the optional name of a destination dataset to write the
query results to, if you don't want them returned; if provided, `table_dest` must also be
provided
- table_dest (str, optional): the optional name of a destination table to write the
query results to, if you don't want them returned; if provided, `dataset_dest` must also be
provided
- job_config (dict, optional): an optional dictionary of job configuration parameters; note that
the parameters provided here must be pickleable (e.g., dataset references will be rejected)
- **kwargs (optional): additional kwargs to pass to the `Task` constructor
"""

def __init__(
self,
query: str = None,
query_params: List[tuple] = None, # 3-tuples

Review comment (Member): Looks like this arg is unused

Reply (Member, author): Good eye; used now.

project: str = None,
location: str = "US",
dry_run_max_bytes: int = None,
credentials_secret: str = None,
dataset_dest: str = None,
table_dest: str = None,
job_config: dict = None,
**kwargs
):
self.query = query
self.query_params = query_params
self.project = project
self.location = location
self.dry_run_max_bytes = dry_run_max_bytes
self.credentials_secret = credentials_secret or "GOOGLE_APPLICATION_CREDENTIALS"
self.dataset_dest = dataset_dest
self.table_dest = table_dest
self.job_config = job_config or {}
super().__init__(**kwargs)

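    # `defaults_from_attrs` (added in #293) makes each attribute set in
    # `__init__` the default value for the matching keyword argument of `run`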
@defaults_from_attrs(
"query",
"query_params",
"project",
"location",
"dry_run_max_bytes",
"credentials_secret",
"dataset_dest",
"table_dest",
"job_config",
)
def run(
self,
query: str = None,
query_params: List[tuple] = None,
project: str = None,
location: str = "US",
dry_run_max_bytes: int = None,
credentials_secret: str = None,
dataset_dest: str = None,
table_dest: str = None,
job_config: dict = None,
):
"""
Run method for this Task. Invoked by _calling_ this Task within a Flow context, after initialization.

Args:
- query (str, optional): a string of the query to execute
- query_params (list[tuple], optional): a list of 3-tuples specifying
BigQuery query parameters; currently only scalar query parameters are supported. See
[the Google documentation](https://cloud.google.com/bigquery/docs/parameterized-queries#bigquery-query-params-python)
for more details on how both the query and the query parameters should be formatted
- project (str, optional): the project to initialize the BigQuery Client with; if not provided,
will default to the one inferred from your credentials
- location (str, optional): location of the dataset which will be queried; defaults to "US"
- dry_run_max_bytes (int, optional): if provided, the maximum number of bytes the query is allowed
to process; this will be determined by executing a dry run and raising a `ValueError` if the
maximum is exceeded
- credentials_secret (str, optional): the name of the Prefect Secret containing a JSON representation
of your Google Application credentials; defaults to `"GOOGLE_APPLICATION_CREDENTIALS"`
- dataset_dest (str, optional): the optional name of a destination dataset to write the
query results to, if you don't want them returned; if provided, `table_dest` must also be
provided
- table_dest (str, optional): the optional name of a destination table to write the
query results to, if you don't want them returned; if provided, `dataset_dest` must also be
provided
- job_config (dict, optional): an optional dictionary of job configuration parameters; note that
the parameters provided here must be pickleable (e.g., dataset references will be rejected)

Raises:
- ValueError: if the `query` is `None`
- ValueError: if only one of `dataset_dest` / `table_dest` is provided
            - ValueError: if the query will exceed `dry_run_max_bytes`

Returns:
- list: a fully populated list of Query results, with one item per row
"""
## check for any argument inconsistencies
if query is None:
raise ValueError("No query provided.")
if sum([dataset_dest is None, table_dest is None]) == 1:

Review comment (Member): :mind blown:

raise ValueError(
"Both `dataset_dest` and `table_dest` must be provided if writing to a destination table."
)

## create client
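        ## the Secret is expected to hold the JSON contents of a service
        ## account key file (see `credentials_secret` in the docstring)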
creds = json.loads(Secret(credentials_secret).get())
credentials = Credentials.from_service_account_info(creds)
project = project or credentials.project_id
client = bigquery.Client(project=project, credentials=credentials)

## setup jobconfig
job_config = bigquery.QueryJobConfig(**job_config)
if query_params is not None:
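            ## each 3-tuple is unpacked into (name, type, value)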
hydrated_params = [
bigquery.ScalarQueryParameter(*qp) for qp in query_params
]
job_config.query_parameters = hydrated_params

## perform dry_run if requested
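        ## (a dry run has BigQuery estimate `total_bytes_processed`
        ## without executing, or billing for, the query)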
if dry_run_max_bytes is not None:
old_info = dict(
dry_run=job_config.dry_run, use_query_cache=job_config.use_query_cache
)
job_config.dry_run = True
job_config.use_query_cache = False
query_job = client.query(query, location=location, job_config=job_config)
if query_job.total_bytes_processed > dry_run_max_bytes:
raise ValueError(
"Query will process {0} bytes which is above the set maximum of {1} for this task.".format(
query_job.total_bytes_processed, dry_run_max_bytes
)
)
job_config.dry_run = old_info["dry_run"]
job_config.use_query_cache = old_info["use_query_cache"]

## if writing to a destination table
if dataset_dest is not None:
table_ref = client.dataset(dataset_dest).table(table_dest)
job_config.destination = table_ref

query_job = client.query(query, location=location, job_config=job_config)
return list(query_job.result())
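
A minimal usage sketch of the new task (the query, dataset, and byte limit below are hypothetical, and a configured `GOOGLE_APPLICATION_CREDENTIALS` Secret is assumed):

from prefect import Flow
from prefect.tasks.google import BigQueryTask

# scalar query parameters are passed as (name, type, value) 3-tuples
bq_task = BigQueryTask(
    query=(
        "SELECT word, word_count "
        "FROM `bigquery-public-data.samples.shakespeare` "
        "WHERE corpus = @corpus LIMIT 10"
    ),
    query_params=[("corpus", "STRING", "hamlet")],
    dry_run_max_bytes=10 ** 9,  # fail fast if the query would scan over ~1 GB
)

with Flow("bigquery-example") as flow:
    rows = bq_task()

flow.run()  # the task returns one list item per result row

Because all settings are also exposed on `run`, the same task instance can be re-used with a different query or project at call time.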
159 changes: 159 additions & 0 deletions tests/tasks/google/test_bigquery.py

@@ -0,0 +1,159 @@
import pytest
from unittest.mock import MagicMock

import prefect
from prefect.tasks.google import BigQueryTask
from prefect.utilities.configuration import set_temporary_config


class TestInitialization:
def test_initializes_with_nothing_and_sets_defaults(self):
task = BigQueryTask()
assert task.query is None
assert task.query_params is None
assert task.project is None
assert task.location == "US"
assert task.dry_run_max_bytes is None
assert task.credentials_secret == "GOOGLE_APPLICATION_CREDENTIALS"
assert task.dataset_dest is None
assert task.table_dest is None
assert task.job_config == dict()

def test_additional_kwargs_passed_upstream(self):
task = BigQueryTask(name="test-task", checkpoint=True, tags=["bob"])
assert task.name == "test-task"
assert task.checkpoint is True
assert task.tags == {"bob"}

@pytest.mark.parametrize(
"attr",
[
"query",
"query_params",
"project",
"location",
"dry_run_max_bytes",
"credentials_secret",
"dataset_dest",
"table_dest",
"job_config",
],
)
def test_initializes_attr_from_kwargs(self, attr):
task = BigQueryTask(**{attr: "my-value"})
assert getattr(task, attr) == "my-value"

def test_query_is_required_eventually(self):
task = BigQueryTask()
with pytest.raises(ValueError) as exc:
task.run()
assert "query" in str(exc.value)

@pytest.mark.parametrize("attr", ["dataset_dest", "table_dest"])
def test_dataset_dest_and_table_dest_are_required_together_eventually(self, attr):
task = BigQueryTask(**{attr: "some-value"})
with pytest.raises(ValueError) as exc:
task.run(query="SELECT *")
assert attr in str(exc.value)
assert "must be provided" in str(exc.value)


class TestCredentialsandProjects:
def test_creds_are_pulled_from_secret_at_runtime(self, monkeypatch):
task = BigQueryTask()

creds_loader = MagicMock()
monkeypatch.setattr("prefect.tasks.google.bigquery.Credentials", creds_loader)
monkeypatch.setattr(
"prefect.tasks.google.bigquery.bigquery.Client", MagicMock()
)

with set_temporary_config({"cloud.use_local_secrets": True}):
with prefect.context(secrets=dict(GOOGLE_APPLICATION_CREDENTIALS="42")):
task.run(query="SELECT *")

assert creds_loader.from_service_account_info.call_args[0][0] == 42

def test_creds_secret_name_can_be_overwritten_at_anytime(self, monkeypatch):
task = BigQueryTask(credentials_secret="TEST")

creds_loader = MagicMock()
monkeypatch.setattr("prefect.tasks.google.bigquery.Credentials", creds_loader)
monkeypatch.setattr(
"prefect.tasks.google.bigquery.bigquery.Client", MagicMock()
)

with set_temporary_config({"cloud.use_local_secrets": True}):
with prefect.context(secrets=dict(TEST="42", RUN="{}")):
task.run(query="SELECT *")
task.run(query="SELECT *", credentials_secret="RUN")

first_call, second_call = creds_loader.from_service_account_info.call_args_list
assert first_call[0][0] == 42
assert second_call[0][0] == {}

def test_project_is_pulled_from_creds_and_can_be_overriden_at_anytime(
self, monkeypatch
):
task = BigQueryTask()
task_proj = BigQueryTask(project="test-init")

client = MagicMock()
service_account_info = MagicMock(return_value=MagicMock(project_id="default"))
monkeypatch.setattr(
"prefect.tasks.google.bigquery.Credentials",
MagicMock(from_service_account_info=service_account_info),
)
monkeypatch.setattr("prefect.tasks.google.bigquery.bigquery.Client", client)

with set_temporary_config({"cloud.use_local_secrets": True}):
with prefect.context(secrets=dict(GOOGLE_APPLICATION_CREDENTIALS="{}")):
task.run(query="SELECT *")
task_proj.run(query="SELECT *")
task_proj.run(query="SELECT *", project="run-time")

x, y, z = client.call_args_list

assert x[1]["project"] == "default" ## pulled from credentials
assert y[1]["project"] == "test-init" ## pulled from init
assert z[1]["project"] == "run-time" ## pulled from run kwarg


class TestDryRuns:
def test_dry_run_doesnt_raise_if_limit_not_exceeded(self, monkeypatch):
task = BigQueryTask(dry_run_max_bytes=1200)

client = MagicMock(
query=MagicMock(return_value=MagicMock(total_bytes_processed=1200))
)
monkeypatch.setattr("prefect.tasks.google.bigquery.Credentials", MagicMock())
monkeypatch.setattr(
"prefect.tasks.google.bigquery.bigquery.Client",
MagicMock(return_value=client),
)

with set_temporary_config({"cloud.use_local_secrets": True}):
with prefect.context(secrets=dict(GOOGLE_APPLICATION_CREDENTIALS="{}")):
task.run(query="SELECT *")

def test_dry_run_raises_if_limit_is_exceeded(self, monkeypatch):
task = BigQueryTask(dry_run_max_bytes=1200)

client = MagicMock(
query=MagicMock(return_value=MagicMock(total_bytes_processed=21836427))
)
monkeypatch.setattr("prefect.tasks.google.bigquery.Credentials", MagicMock())
monkeypatch.setattr(
"prefect.tasks.google.bigquery.bigquery.Client",
MagicMock(return_value=client),
)

with set_temporary_config({"cloud.use_local_secrets": True}):
with prefect.context(secrets=dict(GOOGLE_APPLICATION_CREDENTIALS="{}")):
with pytest.raises(ValueError) as exc:
task.run(query="SELECT *")

assert (
"Query will process 21836427 bytes which is above the set maximum of 1200 for this task"
in str(exc.value)
)