2 changes: 1 addition & 1 deletion providers/google/README.rst
@@ -62,7 +62,7 @@ PIP package Version required
=========================================== ======================================
``apache-airflow`` ``>=2.10.0``
``apache-airflow-providers-common-compat`` ``>=1.4.0``
``apache-airflow-providers-common-sql`` ``>=1.20.0``
``apache-airflow-providers-common-sql`` ``>=1.27.0``
``asgiref`` ``>=3.5.2``
``dill`` ``>=0.2.3``
``gcloud-aio-auth`` ``>=5.2.0``
8 changes: 2 additions & 6 deletions providers/google/pyproject.toml
@@ -59,7 +59,7 @@ requires-python = "~=3.9"
dependencies = [
"apache-airflow>=2.10.0",
"apache-airflow-providers-common-compat>=1.4.0",
"apache-airflow-providers-common-sql>=1.20.0",
"apache-airflow-providers-common-sql>=1.27.0",
"asgiref>=3.5.2",
"dill>=0.2.3",
"gcloud-aio-auth>=5.2.0",
@@ -129,11 +129,6 @@ dependencies = [
# See https://github.com/looker-open-source/sdk-codegen/issues/1518
"looker-sdk>=22.4.0,!=24.18.0",
"pandas-gbq>=0.7.0",
# In pandas 2.2 minimal version of the sqlalchemy is 2.0
# https://pandas.pydata.org/docs/whatsnew/v2.2.0.html#increased-minimum-versions-for-dependencies
# However Airflow not fully supports it yet: https://github.com/apache/airflow/issues/28723
# In addition FAB also limit sqlalchemy to < 2.0
"pandas>=2.1.2,<2.2",
# A transient dependency of google-cloud-bigquery-datatransfer, but we
# further constrain it since older versions are buggy.
"proto-plus>=1.19.6",
@@ -233,6 +228,7 @@ dev = [
"apache-airflow-providers-trino",
# Additional devel dependencies (do not remove this line and add extra development dependencies)
"apache-airflow-providers-apache-kafka",
"apache-airflow-providers-common-sql[pandas,polars]",
]

# To build docs:
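With the pandas upper pin removed and dataframe handling delegated to common-sql >= 1.27.0, polars is an optional runtime dependency of the provider. A minimal sketch, not part of this diff, of choosing a df_type based on what is importable:

# Sketch: prefer polars when the optional dependency is present; pandas is a
# hard dependency of the provider, so the fallback always works.
try:
    import polars  # noqa: F401

    df_type = "polars"
except ImportError:
    df_type = "pandas"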
providers/google/src/airflow/providers/google/cloud/hooks/bigquery.py
@@ -29,7 +29,7 @@
from collections.abc import Iterable, Mapping, Sequence
from copy import deepcopy
from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Any, NoReturn, Union, cast
from typing import TYPE_CHECKING, Any, NoReturn, Union, cast, overload

from aiohttp import ClientSession as ClientSession
from gcloud.aio.bigquery import Job, Table as Table_async
@@ -57,8 +57,13 @@
from pandas_gbq import read_gbq
from pandas_gbq.gbq import GbqConnector # noqa: F401 used in ``airflow.contrib.hooks.bigquery``
from sqlalchemy import create_engine
from typing_extensions import Literal

from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.exceptions import (
AirflowException,
AirflowOptionalProviderFeatureException,
AirflowProviderDeprecationWarning,
)
from airflow.providers.common.compat.lineage.hook import get_hook_lineage_collector
from airflow.providers.common.sql.hooks.sql import DbApiHook
from airflow.providers.google.cloud.utils.bigquery import bq_cast
@@ -77,6 +82,7 @@

if TYPE_CHECKING:
import pandas as pd
import polars as pl
from google.api_core.page_iterator import HTTPIterator
from google.api_core.retry import Retry
from requests import Session
@@ -275,15 +281,57 @@ def insert_rows(
"""
raise NotImplementedError()

def get_pandas_df(
def _get_pandas_df(
self,
sql: str,
parameters: Iterable | Mapping[str, Any] | None = None,
dialect: str | None = None,
**kwargs,
) -> pd.DataFrame:
if dialect is None:
dialect = "legacy" if self.use_legacy_sql else "standard"

credentials, project_id = self.get_credentials_and_project_id()

return read_gbq(sql, project_id=project_id, dialect=dialect, credentials=credentials, **kwargs)

def _get_polars_df(self, sql, parameters=None, dialect=None, **kwargs) -> pl.DataFrame:
try:
import polars as pl
except ImportError:
raise AirflowOptionalProviderFeatureException(
"Polars is not installed. Please install it with `pip install polars`."
)

if dialect is None:
dialect = "legacy" if self.use_legacy_sql else "standard"

credentials, project_id = self.get_credentials_and_project_id()

pandas_df = read_gbq(sql, project_id=project_id, dialect=dialect, credentials=credentials, **kwargs)
return pl.from_pandas(pandas_df)

@overload
def get_df(
self, sql, parameters=None, dialect=None, *, df_type: Literal["pandas"] = "pandas", **kwargs
) -> pd.DataFrame: ...

@overload
def get_df(
self, sql, parameters=None, dialect=None, *, df_type: Literal["polars"], **kwargs
) -> pl.DataFrame: ...

def get_df(
self,
sql,
parameters=None,
dialect=None,
*,
df_type: Literal["pandas", "polars"] = "pandas",
**kwargs,
) -> pd.DataFrame | pl.DataFrame:
"""
Get a Pandas DataFrame for the BigQuery results.
Get a DataFrame for the BigQuery results.

The DbApiHook method must be overridden because Pandas doesn't support
PEP 249 connections, except for SQLite.
@@ -299,12 +347,19 @@ def get_pandas_df(
defaults to use `self.use_legacy_sql` if not specified
:param kwargs: (optional) passed into pandas_gbq.read_gbq method
"""
if dialect is None:
dialect = "legacy" if self.use_legacy_sql else "standard"
if df_type == "polars":
return self._get_polars_df(sql, parameters, dialect, **kwargs)

credentials, project_id = self.get_credentials_and_project_id()
if df_type == "pandas":
return self._get_pandas_df(sql, parameters, dialect, **kwargs)

return read_gbq(sql, project_id=project_id, dialect=dialect, credentials=credentials, **kwargs)
@deprecated(
planned_removal_date="November 30, 2025",
use_instead="airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_df",
category=AirflowProviderDeprecationWarning,
)
def get_pandas_df(self, sql, parameters=None, dialect=None, **kwargs):
return self._get_pandas_df(sql, parameters, dialect, **kwargs)

@GoogleBaseHook.fallback_to_default_project_id
def table_exists(self, dataset_id: str, table_id: str, project_id: str) -> bool:
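For orientation, a usage sketch of the new API. The connection id and query are illustrative assumptions, and the polars path needs the optional dependency (for example via apache-airflow-providers-common-sql[pandas,polars]):

from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook

# Hypothetical connection id and query, shown only to illustrate the call shape.
hook = BigQueryHook(gcp_conn_id="google_cloud_default", use_legacy_sql=False)
pandas_df = hook.get_df("SELECT 1 AS a")                    # default: pandas.DataFrame
polars_df = hook.get_df("SELECT 1 AS a", df_type="polars")  # polars.DataFrame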
19 changes: 16 additions & 3 deletions providers/google/tests/unit/google/cloud/hooks/test_bigquery.py
@@ -153,9 +153,22 @@ def test_bigquery_table_partition_exists_false_no_partition(self, mock_client):
assert result is False

@mock.patch("airflow.providers.google.cloud.hooks.bigquery.read_gbq")
def test_get_pandas_df(self, mock_read_gbq):
self.hook.get_pandas_df("select 1")

@pytest.mark.parametrize("df_type", ["pandas", "polars"])
def test_get_df(self, mock_read_gbq, df_type):
import pandas as pd
import polars as pl

mock_read_gbq.return_value = pd.DataFrame({"a": [1, 2, 3]})
result = self.hook.get_df("select 1", df_type=df_type)

expected_type = pd.DataFrame if df_type == "pandas" else pl.DataFrame
assert isinstance(result, expected_type)
assert result.shape == (3, 1)
assert result.columns == ["a"]
if df_type == "pandas":
assert result["a"].tolist() == [1, 2, 3]
else:
assert result.to_series().to_list() == [1, 2, 3]
mock_read_gbq.assert_called_once_with(
"select 1", credentials=CREDENTIALS, dialect="legacy", project_id=PROJECT_ID
)
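Because get_pandas_df is now a deprecated wrapper, downstream suites can pin down the warning as well. A sketch assuming pytest and an already configured BigQueryHook fixture named hook:

import pytest

from airflow.exceptions import AirflowProviderDeprecationWarning


def test_get_pandas_df_warns(hook):  # `hook` is an assumed BigQueryHook fixture
    with pytest.warns(AirflowProviderDeprecationWarning):
        hook.get_pandas_df("select 1")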
@@ -33,26 +33,52 @@ class TestBigQueryDataframeResultsSystem(GoogleSystemTest):
def setup_method(self):
self.instance = hook.BigQueryHook()

def test_output_is_dataframe_with_valid_query(self):
import pandas as pd
@pytest.mark.parametrize("df_type", ["pandas", "polars"])
def test_output_is_dataframe_with_valid_query(self, df_type):
df = self.instance.get_df("select 1", df_type=df_type)
if df_type == "polars":
import polars as pl

df = self.instance.get_pandas_df("select 1")
assert isinstance(df, pd.DataFrame)
assert isinstance(df, pl.DataFrame)
else:
import pandas as pd

def test_throws_exception_with_invalid_query(self):
assert isinstance(df, pd.DataFrame)

@pytest.mark.parametrize("df_type", ["pandas", "polars"])
def test_throws_exception_with_invalid_query(self, df_type):
with pytest.raises(Exception) as ctx:
self.instance.get_pandas_df("from `1`")
self.instance.get_df("from `1`", df_type=df_type)
assert "Reason: " in str(ctx.value), ""

def test_succeeds_with_explicit_legacy_query(self):
df = self.instance.get_pandas_df("select 1", dialect="legacy")
assert df.iloc(0)[0][0] == 1
@pytest.mark.parametrize("df_type", ["pandas", "polars"])
def test_succeeds_with_explicit_legacy_query(self, df_type):
df = self.instance.get_df("select 1", df_type=df_type)
if df_type == "polars":
assert df.item(0, 0) == 1
else:
assert df.iloc[0][0] == 1

def test_succeeds_with_explicit_std_query(self):
df = self.instance.get_pandas_df("select * except(b) from (select 1 a, 2 b)", dialect="standard")
assert df.iloc(0)[0][0] == 1
@pytest.mark.parametrize("df_type", ["pandas", "polars"])
def test_succeeds_with_explicit_std_query(self, df_type):
df = self.instance.get_df(
"select * except(b) from (select 1 a, 2 b)",
parameters=None,
dialect="standard",
df_type=df_type,
)
if df_type == "polars":
assert df.item(0, 0) == 1
else:
assert df.iloc[0][0] == 1

def test_throws_exception_with_incompatible_syntax(self):
@pytest.mark.parametrize("df_type", ["pandas", "polars"])
def test_throws_exception_with_incompatible_syntax(self, df_type):
with pytest.raises(Exception) as ctx:
self.instance.get_pandas_df("select * except(b) from (select 1 a, 2 b)", dialect="legacy")
self.instance.get_df(
"select * except(b) from (select 1 a, 2 b)",
parameters=None,
dialect="legacy",
df_type=df_type,
)
assert "Reason: " in str(ctx.value), ""
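The pandas and polars branches above read the same scalar through different accessors. A self-contained sketch of that parity, using inline data instead of BigQuery results:

import pandas as pd
import polars as pl

pdf = pd.DataFrame({"a": [1]})
pldf = pl.from_pandas(pdf)  # the same conversion _get_polars_df performs
assert pdf.iloc[0, 0] == 1   # pandas: positional (row, column) via iloc
assert pldf.item(0, 0) == 1  # polars: item() takes (row, column) and returns a scalar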