From 40a0cdd734e5543bf7149c5e2ce346ee6dd4a824 Mon Sep 17 00:00:00 2001
From: Tim Swast
Date: Wed, 21 Aug 2019 11:27:48 -0700
Subject: [PATCH] Determine the schema in `load_table_from_dataframe` based on
 dtypes. (#9049)

* Determine the schema in `load_table_from_dataframe` based on dtypes.

This PR updates `load_table_from_dataframe` to automatically determine the
BigQuery schema based on the DataFrame's dtypes. If any field's type cannot
be determined, fall back to the logic in the pandas `to_parquet` method.

* Fix test coverage.

* Reduce duplication by using OrderedDict

* Add columns option to DataFrame constructor to ensure correct column order.
---
 .../google/cloud/bigquery/_pandas_helpers.py | 40 ++++++++++
 bigquery/google/cloud/bigquery/client.py     | 15 ++++
 bigquery/tests/system.py                     | 76 +++++++++++++++++++
 bigquery/tests/unit/test_client.py           | 74 +++++++++++++++++-
 4 files changed, 203 insertions(+), 2 deletions(-)

diff --git a/bigquery/google/cloud/bigquery/_pandas_helpers.py b/bigquery/google/cloud/bigquery/_pandas_helpers.py
index 5cc69e434b04..db7f36f3d93e 100644
--- a/bigquery/google/cloud/bigquery/_pandas_helpers.py
+++ b/bigquery/google/cloud/bigquery/_pandas_helpers.py
@@ -49,6 +49,21 @@
 _PROGRESS_INTERVAL = 0.2  # Maximum time between download status checks, in seconds.
 
+_PANDAS_DTYPE_TO_BQ = {
+    "bool": "BOOLEAN",
+    "datetime64[ns, UTC]": "TIMESTAMP",
+    "datetime64[ns]": "DATETIME",
+    "float32": "FLOAT",
+    "float64": "FLOAT",
+    "int8": "INTEGER",
+    "int16": "INTEGER",
+    "int32": "INTEGER",
+    "int64": "INTEGER",
+    "uint8": "INTEGER",
+    "uint16": "INTEGER",
+    "uint32": "INTEGER",
+}
+
 
 class _DownloadState(object):
     """Flag to indicate that a thread should exit early."""
 
@@ -172,6 +187,31 @@ def bq_to_arrow_array(series, bq_field):
     return pyarrow.array(series, type=arrow_type)
 
 
+def dataframe_to_bq_schema(dataframe):
+    """Convert a pandas DataFrame schema to a BigQuery schema.
+
+    TODO(GH#8140): Add bq_schema argument to allow overriding autodetected
+    schema for a subset of columns.
+
+    Args:
+        dataframe (pandas.DataFrame):
+            DataFrame to convert to a Parquet file.
+
+    Returns:
+        Optional[Sequence[google.cloud.bigquery.schema.SchemaField]]:
+            The automatically determined schema. Returns None if the type of
+            any column cannot be determined.
+    """
+    bq_schema = []
+    for column, dtype in zip(dataframe.columns, dataframe.dtypes):
+        bq_type = _PANDAS_DTYPE_TO_BQ.get(dtype.name)
+        if not bq_type:
+            return None
+        bq_field = schema.SchemaField(column, bq_type)
+        bq_schema.append(bq_field)
+    return tuple(bq_schema)
+
+
 def dataframe_to_arrow(dataframe, bq_schema):
     """Convert pandas dataframe to Arrow table, using BigQuery schema.
 
diff --git a/bigquery/google/cloud/bigquery/client.py b/bigquery/google/cloud/bigquery/client.py
index ae9adb4da15f..1b13ee126a5d 100644
--- a/bigquery/google/cloud/bigquery/client.py
+++ b/bigquery/google/cloud/bigquery/client.py
@@ -21,6 +21,7 @@
 except ImportError:  # Python 2.7
     import collections as collections_abc
 
+import copy
 import functools
 import gzip
 import io
@@ -1521,11 +1522,25 @@ def load_table_from_dataframe(
 
         if job_config is None:
             job_config = job.LoadJobConfig()
+        else:
+            # Make a copy so that the job config isn't modified in-place.
+            job_config_properties = copy.deepcopy(job_config._properties)
+            job_config = job.LoadJobConfig()
+            job_config._properties = job_config_properties
         job_config.source_format = job.SourceFormat.PARQUET
 
         if location is None:
             location = self.location
 
+        if not job_config.schema:
+            autodetected_schema = _pandas_helpers.dataframe_to_bq_schema(dataframe)
+
+            # Only use an explicit schema if we were able to determine one
+            # matching the dataframe. If not, fallback to the pandas to_parquet
+            # method.
+            if autodetected_schema:
+                job_config.schema = autodetected_schema
+
         tmpfd, tmppath = tempfile.mkstemp(suffix="_job_{}.parquet".format(job_id[:8]))
         os.close(tmpfd)
 
diff --git a/bigquery/tests/system.py b/bigquery/tests/system.py
index fd9efa7752cf..59a72297ed87 100644
--- a/bigquery/tests/system.py
+++ b/bigquery/tests/system.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import base64
+import collections
 import concurrent.futures
 import csv
 import datetime
@@ -634,6 +635,81 @@ def test_load_table_from_local_avro_file_then_dump_table(self):
             sorted(row_tuples, key=by_wavelength), sorted(ROWS, key=by_wavelength)
         )
 
+    @unittest.skipIf(pandas is None, "Requires `pandas`")
+    @unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
+    def test_load_table_from_dataframe_w_automatic_schema(self):
+        """Test that a DataFrame with dtypes that map well to BigQuery types
+        can be uploaded without specifying a schema.
+
+        https://github.com/googleapis/google-cloud-python/issues/9044
+        """
+        df_data = collections.OrderedDict(
+            [
+                ("bool_col", pandas.Series([True, False, True], dtype="bool")),
+                (
+                    "ts_col",
+                    pandas.Series(
+                        [
+                            datetime.datetime(2010, 1, 2, 3, 44, 50),
+                            datetime.datetime(2011, 2, 3, 14, 50, 59),
+                            datetime.datetime(2012, 3, 14, 15, 16),
+                        ],
+                        dtype="datetime64[ns]",
+                    ).dt.tz_localize(pytz.utc),
+                ),
+                (
+                    "dt_col",
+                    pandas.Series(
+                        [
+                            datetime.datetime(2010, 1, 2, 3, 44, 50),
+                            datetime.datetime(2011, 2, 3, 14, 50, 59),
+                            datetime.datetime(2012, 3, 14, 15, 16),
+                        ],
+                        dtype="datetime64[ns]",
+                    ),
+                ),
+                ("float32_col", pandas.Series([1.0, 2.0, 3.0], dtype="float32")),
+                ("float64_col", pandas.Series([4.0, 5.0, 6.0], dtype="float64")),
+                ("int8_col", pandas.Series([-12, -11, -10], dtype="int8")),
+                ("int16_col", pandas.Series([-9, -8, -7], dtype="int16")),
+                ("int32_col", pandas.Series([-6, -5, -4], dtype="int32")),
+                ("int64_col", pandas.Series([-3, -2, -1], dtype="int64")),
+                ("uint8_col", pandas.Series([0, 1, 2], dtype="uint8")),
+                ("uint16_col", pandas.Series([3, 4, 5], dtype="uint16")),
+                ("uint32_col", pandas.Series([6, 7, 8], dtype="uint32")),
+            ]
+        )
+        dataframe = pandas.DataFrame(df_data, columns=df_data.keys())
+
+        dataset_id = _make_dataset_id("bq_load_test")
+        self.temp_dataset(dataset_id)
+        table_id = "{}.{}.load_table_from_dataframe_w_automatic_schema".format(
+            Config.CLIENT.project, dataset_id
+        )
+
+        load_job = Config.CLIENT.load_table_from_dataframe(dataframe, table_id)
+        load_job.result()
+
+        table = Config.CLIENT.get_table(table_id)
+        self.assertEqual(
+            tuple(table.schema),
+            (
+                bigquery.SchemaField("bool_col", "BOOLEAN"),
+                bigquery.SchemaField("ts_col", "TIMESTAMP"),
+                bigquery.SchemaField("dt_col", "DATETIME"),
+                bigquery.SchemaField("float32_col", "FLOAT"),
+                bigquery.SchemaField("float64_col", "FLOAT"),
+                bigquery.SchemaField("int8_col", "INTEGER"),
+                bigquery.SchemaField("int16_col", "INTEGER"),
+                bigquery.SchemaField("int32_col", "INTEGER"),
+                bigquery.SchemaField("int64_col", "INTEGER"),
+                bigquery.SchemaField("uint8_col", "INTEGER"),
+                bigquery.SchemaField("uint16_col", "INTEGER"),
+                bigquery.SchemaField("uint32_col", "INTEGER"),
+            ),
+        )
+        self.assertEqual(table.num_rows, 3)
+
     @unittest.skipIf(pandas is None, "Requires `pandas`")
     @unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
     def test_load_table_from_dataframe_w_nulls(self):
diff --git a/bigquery/tests/unit/test_client.py b/bigquery/tests/unit/test_client.py
index d7ff3d2a90b3..8a2a1228cd65 100644
--- a/bigquery/tests/unit/test_client.py
+++ b/bigquery/tests/unit/test_client.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import copy
+import collections
 import datetime
 import decimal
 import email
@@ -5325,9 +5326,78 @@ def test_load_table_from_dataframe_w_custom_job_config(self):
         )
 
         sent_config = load_table_from_file.mock_calls[0][2]["job_config"]
-        assert sent_config is job_config
         assert sent_config.source_format == job.SourceFormat.PARQUET
 
+    @unittest.skipIf(pandas is None, "Requires `pandas`")
+    @unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
+    def test_load_table_from_dataframe_w_automatic_schema(self):
+        from google.cloud.bigquery.client import _DEFAULT_NUM_RETRIES
+        from google.cloud.bigquery import job
+        from google.cloud.bigquery.schema import SchemaField
+
+        client = self._make_client()
+        df_data = collections.OrderedDict(
+            [
+                ("int_col", [1, 2, 3]),
+                ("float_col", [1.0, 2.0, 3.0]),
+                ("bool_col", [True, False, True]),
+                (
+                    "dt_col",
+                    pandas.Series(
+                        [
+                            datetime.datetime(2010, 1, 2, 3, 44, 50),
+                            datetime.datetime(2011, 2, 3, 14, 50, 59),
+                            datetime.datetime(2012, 3, 14, 15, 16),
+                        ],
+                        dtype="datetime64[ns]",
+                    ),
+                ),
+                (
+                    "ts_col",
+                    pandas.Series(
+                        [
+                            datetime.datetime(2010, 1, 2, 3, 44, 50),
+                            datetime.datetime(2011, 2, 3, 14, 50, 59),
+                            datetime.datetime(2012, 3, 14, 15, 16),
+                        ],
+                        dtype="datetime64[ns]",
+                    ).dt.tz_localize(pytz.utc),
+                ),
+            ]
+        )
+        dataframe = pandas.DataFrame(df_data, columns=df_data.keys())
+        load_patch = mock.patch(
+            "google.cloud.bigquery.client.Client.load_table_from_file", autospec=True
+        )
+
+        with load_patch as load_table_from_file:
+            client.load_table_from_dataframe(
+                dataframe, self.TABLE_REF, location=self.LOCATION
+            )
+
+        load_table_from_file.assert_called_once_with(
+            client,
+            mock.ANY,
+            self.TABLE_REF,
+            num_retries=_DEFAULT_NUM_RETRIES,
+            rewind=True,
+            job_id=mock.ANY,
+            job_id_prefix=None,
+            location=self.LOCATION,
+            project=None,
+            job_config=mock.ANY,
+        )
+
+        sent_config = load_table_from_file.mock_calls[0][2]["job_config"]
+        assert sent_config.source_format == job.SourceFormat.PARQUET
+        assert tuple(sent_config.schema) == (
+            SchemaField("int_col", "INTEGER"),
+            SchemaField("float_col", "FLOAT"),
+            SchemaField("bool_col", "BOOLEAN"),
+            SchemaField("dt_col", "DATETIME"),
+            SchemaField("ts_col", "TIMESTAMP"),
+        )
+
     @unittest.skipIf(pandas is None, "Requires `pandas`")
     @unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
     def test_load_table_from_dataframe_struct_fields_error(self):
@@ -5509,7 +5579,7 @@ def test_load_table_from_dataframe_w_nulls(self):
         )
 
         sent_config = load_table_from_file.mock_calls[0][2]["job_config"]
-        assert sent_config is job_config
+        assert sent_config.schema == schema
         assert sent_config.source_format == job.SourceFormat.PARQUET
 
 # Low-level tests
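
A usage sketch (not part of the patch): with this change applied, a load like the one below should pick up the BigQuery schema from the DataFrame dtypes whenever no schema is set on the job config; columns with dtypes outside _PANDAS_DTYPE_TO_BQ still fall back to pyarrow/to_parquet type inference. The project, dataset, and table names are placeholders, and the snippet assumes pandas, pyarrow, and application-default credentials are available.

# Sketch of schema autodetection from dtypes; identifiers below are placeholders.
import datetime

import pandas
from google.cloud import bigquery

client = bigquery.Client()  # assumes application-default credentials

# Columns whose dtypes map directly to BigQuery types.
dataframe = pandas.DataFrame(
    {
        "int_col": pandas.Series([1, 2, 3], dtype="int64"),
        "float_col": pandas.Series([0.25, 0.5, 0.75], dtype="float64"),
        "bool_col": pandas.Series([True, False, True], dtype="bool"),
        "ts_col": pandas.Series(
            [datetime.datetime(2019, 8, 21, 11, 27)] * 3, dtype="datetime64[ns]"
        ).dt.tz_localize("UTC"),
    }
)

# No job_config.schema is supplied, so load_table_from_dataframe maps the
# dtypes to INTEGER, FLOAT, BOOLEAN, and TIMESTAMP before writing Parquet.
load_job = client.load_table_from_dataframe(
    dataframe, "my-project.my_dataset.my_table"  # placeholder table ID
)
load_job.result()

table = client.get_table("my-project.my_dataset.my_table")
print([(field.name, field.field_type) for field in table.schema])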