Optionally include indexes in table written by `load_table_from_dataframe`. (#9084)

* Specify the index's data type in the partial schema passed to `load_table_from_dataframe` to include the index in the written table.

If an index (or level of a multi-index) has a name and is present in the
schema passed to `load_table_from_dataframe`, then that index will be
serialized and written to the table. Otherwise, the index is omitted
from the serialized table.

* Don't include the index if it has the same name as a column.

* Move `load_table_dataframe` sample from `snippets.py` to `samples/`.

Sample now demonstrates how to manually include the index with a
partial schema definition. Update docs reference to new
`load_table_dataframe` sample location.
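
For context, a minimal sketch of the new opt-in behavior (assuming an existing `bigquery.Client` and a placeholder destination table ID; both are illustrative, not part of this commit):

    import pandas
    from google.cloud import bigquery

    client = bigquery.Client()  # placeholder client
    table_id = "your-project.your_dataset.your_table"  # placeholder table ID

    # A DataFrame with a *named* index; unnamed indexes are never written.
    dataframe = pandas.DataFrame(
        {"release_year": [1975, 1979]},
        index=pandas.Index(["Q25043", "Q24953"], name="wikidata_id"),
    )

    # Naming the index in a (partial) schema opts it into the load job.
    # Leave "wikidata_id" out of the schema and the index is omitted.
    job_config = bigquery.LoadJobConfig(
        schema=[bigquery.SchemaField("wikidata_id", "STRING")]
    )
    client.load_table_from_dataframe(dataframe, table_id, job_config=job_config).result()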
tswast authored Aug 28, 2019
1 parent 0c7dd62 commit a6ed945
Showing 7 changed files with 421 additions and 57 deletions.
47 changes: 0 additions & 47 deletions bigquery/docs/snippets.py
@@ -2536,52 +2536,5 @@ def test_list_rows_as_dataframe(client):
    assert len(df) == table.num_rows  # verify the number of rows


@pytest.mark.skipif(pandas is None, reason="Requires `pandas`")
@pytest.mark.parametrize("parquet_engine", ["pyarrow", "fastparquet"])
def test_load_table_from_dataframe(client, to_delete, parquet_engine):
    if parquet_engine == "pyarrow" and pyarrow is None:
        pytest.skip("Requires `pyarrow`")
    if parquet_engine == "fastparquet" and fastparquet is None:
        pytest.skip("Requires `fastparquet`")

    pandas.set_option("io.parquet.engine", parquet_engine)

    dataset_id = "load_table_from_dataframe_{}".format(_millis())
    dataset = bigquery.Dataset(client.dataset(dataset_id))
    client.create_dataset(dataset)
    to_delete.append(dataset)

    # [START bigquery_load_table_dataframe]
    # from google.cloud import bigquery
    # import pandas
    # client = bigquery.Client()
    # dataset_id = 'my_dataset'

    dataset_ref = client.dataset(dataset_id)
    table_ref = dataset_ref.table("monty_python")
    records = [
        {"title": u"The Meaning of Life", "release_year": 1983},
        {"title": u"Monty Python and the Holy Grail", "release_year": 1975},
        {"title": u"Life of Brian", "release_year": 1979},
        {"title": u"And Now for Something Completely Different", "release_year": 1971},
    ]
    # Optionally set explicit indices.
    # If indices are not specified, a column will be created for the default
    # indices created by pandas.
    index = [u"Q24980", u"Q25043", u"Q24953", u"Q16403"]
    dataframe = pandas.DataFrame(records, index=pandas.Index(index, name="wikidata_id"))

    job = client.load_table_from_dataframe(dataframe, table_ref, location="US")

    job.result()  # Waits for table load to complete.

    assert job.state == "DONE"
    table = client.get_table(table_ref)
    assert table.num_rows == 4
    # [END bigquery_load_table_dataframe]
    column_names = [field.name for field in table.schema]
    assert sorted(column_names) == ["release_year", "title", "wikidata_id"]


if __name__ == "__main__":
    pytest.main()
2 changes: 1 addition & 1 deletion bigquery/docs/usage/pandas.rst
@@ -55,7 +55,7 @@ install the BigQuery python client library with :mod:`pandas` and
The following example demonstrates how to create a :class:`pandas.DataFrame`
and load it into a new table:

.. literalinclude:: ../snippets.py
.. literalinclude:: ../samples/load_table_dataframe.py
    :language: python
    :dedent: 4
    :start-after: [START bigquery_load_table_dataframe]
2 changes: 2 additions & 0 deletions bigquery/google/cloud/bigquery/__init__.py
@@ -36,6 +36,7 @@
from google.cloud.bigquery.dataset import AccessEntry
from google.cloud.bigquery.dataset import Dataset
from google.cloud.bigquery.dataset import DatasetReference
from google.cloud.bigquery import enums
from google.cloud.bigquery.enums import StandardSqlDataTypes
from google.cloud.bigquery.external_config import ExternalConfig
from google.cloud.bigquery.external_config import BigtableOptions
@@ -124,6 +125,7 @@
"GoogleSheetsOptions",
"DEFAULT_RETRY",
# Enum Constants
"enums",
"Compression",
"CreateDisposition",
"DestinationFormat",
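Re-exporting `enums` from the package root is what lets the new sample spell column types as `bigquery.enums.SqlTypeNames.STRING` rather than bare strings. A small illustrative sketch (the two forms are assumed interchangeable when building a schema):

    from google.cloud import bigquery

    # Both fields describe a STRING column named "title".
    field_from_enum = bigquery.SchemaField("title", bigquery.enums.SqlTypeNames.STRING)
    field_from_string = bigquery.SchemaField("title", "STRING")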
64 changes: 57 additions & 7 deletions bigquery/google/cloud/bigquery/_pandas_helpers.py
@@ -187,6 +187,49 @@ def bq_to_arrow_array(series, bq_field):
    return pyarrow.array(series, type=arrow_type)


def get_column_or_index(dataframe, name):
    """Return a column or index as a pandas series."""
    if name in dataframe.columns:
        return dataframe[name].reset_index(drop=True)

    if isinstance(dataframe.index, pandas.MultiIndex):
        if name in dataframe.index.names:
            return (
                dataframe.index.get_level_values(name)
                .to_series()
                .reset_index(drop=True)
            )
    else:
        if name == dataframe.index.name:
            return dataframe.index.to_series().reset_index(drop=True)

    raise ValueError("column or index '{}' not found.".format(name))


def list_columns_and_indexes(dataframe):
    """Return all index and column names with dtypes.

    Returns:
        Sequence[Tuple[str, dtype]]:
            Returns a list of index names, followed by column names, with
            their corresponding dtypes. If an index is missing a name or has
            the same name as a column, the index is omitted.
    """
    column_names = frozenset(dataframe.columns)
    columns_and_indexes = []
    if isinstance(dataframe.index, pandas.MultiIndex):
        for name in dataframe.index.names:
            if name and name not in column_names:
                values = dataframe.index.get_level_values(name)
                columns_and_indexes.append((name, values.dtype))
    else:
        if dataframe.index.name and dataframe.index.name not in column_names:
            columns_and_indexes.append((dataframe.index.name, dataframe.index.dtype))

    columns_and_indexes += zip(dataframe.columns, dataframe.dtypes)
    return columns_and_indexes


def dataframe_to_bq_schema(dataframe, bq_schema):
"""Convert a pandas DataFrame schema to a BigQuery schema.
@@ -217,7 +260,7 @@ def dataframe_to_bq_schema(dataframe, bq_schema):
bq_schema_unused = set()

bq_schema_out = []
for column, dtype in zip(dataframe.columns, dataframe.dtypes):
for column, dtype in list_columns_and_indexes(dataframe):
# Use provided type from schema, if present.
bq_field = bq_schema_index.get(column)
if bq_field:
@@ -229,7 +272,7 @@ def dataframe_to_bq_schema(dataframe, bq_schema):
# pandas dtype.
bq_type = _PANDAS_DTYPE_TO_BQ.get(dtype.name)
if not bq_type:
warnings.warn("Unable to determine type of column '{}'.".format(column))
warnings.warn(u"Unable to determine type of column '{}'.".format(column))
return None
bq_field = schema.SchemaField(column, bq_type)
bq_schema_out.append(bq_field)
@@ -238,7 +281,7 @@ def dataframe_to_bq_schema(dataframe, bq_schema):
# column, but it was not found.
if bq_schema_unused:
raise ValueError(
"bq_schema contains fields not present in dataframe: {}".format(
u"bq_schema contains fields not present in dataframe: {}".format(
bq_schema_unused
)
)
@@ -261,20 +304,25 @@ def dataframe_to_arrow(dataframe, bq_schema):
BigQuery schema.
"""
column_names = set(dataframe.columns)
column_and_index_names = set(
name for name, _ in list_columns_and_indexes(dataframe)
)
bq_field_names = set(field.name for field in bq_schema)

extra_fields = bq_field_names - column_names
extra_fields = bq_field_names - column_and_index_names
if extra_fields:
raise ValueError(
"bq_schema contains fields not present in dataframe: {}".format(
u"bq_schema contains fields not present in dataframe: {}".format(
extra_fields
)
)

# It's okay for indexes to be missing from bq_schema, but it's not okay to
# be missing columns.
missing_fields = column_names - bq_field_names
if missing_fields:
raise ValueError(
"bq_schema is missing fields from dataframe: {}".format(missing_fields)
u"bq_schema is missing fields from dataframe: {}".format(missing_fields)
)

arrow_arrays = []
@@ -283,7 +331,9 @@ def dataframe_to_arrow(dataframe, bq_schema):
for bq_field in bq_schema:
arrow_fields.append(bq_to_arrow_field(bq_field))
arrow_names.append(bq_field.name)
arrow_arrays.append(bq_to_arrow_array(dataframe[bq_field.name], bq_field))
arrow_arrays.append(
bq_to_arrow_array(get_column_or_index(dataframe, bq_field.name), bq_field)
)

if all((field is not None for field in arrow_fields)):
return pyarrow.Table.from_arrays(
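A rough sketch of how the two new helpers behave for a DataFrame with a named index (illustrative only; `_pandas_helpers` is a private module, so calling it directly is an internal detail rather than supported API):

    import pandas
    from google.cloud.bigquery import _pandas_helpers

    dataframe = pandas.DataFrame(
        {"title": ["Life of Brian"], "release_year": [1979]},
        index=pandas.Index(["Q24953"], name="wikidata_id"),
    )

    # Named indexes are listed first, then columns, each with its dtype, e.g.
    # [('wikidata_id', dtype('O')), ('title', dtype('O')), ('release_year', dtype('int64'))]
    print(_pandas_helpers.list_columns_and_indexes(dataframe))

    # Schema fields may now refer to either a column or a named index; both
    # resolve to a pandas Series for Arrow serialization.
    wikidata_ids = _pandas_helpers.get_column_or_index(dataframe, "wikidata_id")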
73 changes: 73 additions & 0 deletions bigquery/samples/load_table_dataframe.py
@@ -0,0 +1,73 @@
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


def load_table_dataframe(client, table_id):
    # [START bigquery_load_table_dataframe]
    from google.cloud import bigquery
    import pandas

    # TODO(developer): Construct a BigQuery client object.
    # client = bigquery.Client()

    # TODO(developer): Set table_id to the ID of the table to create.
    # table_id = "your-project.your_dataset.your_table_name"

    records = [
        {"title": u"The Meaning of Life", "release_year": 1983},
        {"title": u"Monty Python and the Holy Grail", "release_year": 1975},
        {"title": u"Life of Brian", "release_year": 1979},
        {"title": u"And Now for Something Completely Different", "release_year": 1971},
    ]
    dataframe = pandas.DataFrame(
        records,
        # In the loaded table, the column order reflects the order of the
        # columns in the DataFrame.
        columns=["title", "release_year"],
        # Optionally, set a named index, which can also be written to the
        # BigQuery table.
        index=pandas.Index(
            [u"Q24980", u"Q25043", u"Q24953", u"Q16403"], name="wikidata_id"
        ),
    )
    job_config = bigquery.LoadJobConfig(
        # Specify a (partial) schema. All columns are always written to the
        # table. The schema is used to assist in data type definitions.
        schema=[
            # Specify the type of columns whose type cannot be auto-detected. For
            # example the "title" column uses pandas dtype "object", so its
            # data type is ambiguous.
            bigquery.SchemaField("title", bigquery.enums.SqlTypeNames.STRING),
            # Indexes are written if included in the schema by name.
            bigquery.SchemaField("wikidata_id", bigquery.enums.SqlTypeNames.STRING),
        ],
        # Optionally, set the write disposition. BigQuery appends loaded rows
        # to an existing table by default, but with WRITE_TRUNCATE write
        # disposition it replaces the table with the loaded data.
        write_disposition="WRITE_TRUNCATE",
    )

    job = client.load_table_from_dataframe(
        dataframe, table_id, job_config=job_config, location="US"
    )
    job.result()  # Waits for table load to complete.

    table = client.get_table(table_id)
    print(
        "Loaded {} rows and {} columns to {}".format(
            table.num_rows, len(table.schema), table_id
        )
    )
    # [END bigquery_load_table_dataframe]
    return table
30 changes: 30 additions & 0 deletions bigquery/samples/tests/test_load_table_dataframe.py
@@ -0,0 +1,30 @@
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

from .. import load_table_dataframe


pytest.importorskip("pandas")
pytest.importorskip("pyarrow")


def test_load_table_dataframe(capsys, client, random_table_id):
    table = load_table_dataframe.load_table_dataframe(client, random_table_id)
    out, _ = capsys.readouterr()
    assert "Loaded 4 rows and 3 columns" in out

    column_names = [field.name for field in table.schema]
    assert column_names == ["wikidata_id", "title", "release_year"]