Skip to content

Commit 1233182

Browse files
author
Jim Fulton
authored
fix: the unnest function lost needed type information (#298)
1 parent 6ffcef6 commit 1233182

File tree

2 files changed

+68
-0
lines changed

2 files changed

+68
-0
lines changed

sqlalchemy_bigquery/base.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@
3535
from google.api_core.exceptions import NotFound
3636

3737
import sqlalchemy
38+
import sqlalchemy.sql.expression
39+
import sqlalchemy.sql.functions
3840
import sqlalchemy.sql.sqltypes
3941
import sqlalchemy.sql.type_api
4042
from sqlalchemy.exc import NoSuchTableError
@@ -1092,6 +1094,21 @@ def get_view_definition(self, connection, view_name, schema=None, **kw):
10921094
return view.view_query
10931095

10941096

1097+
class unnest(sqlalchemy.sql.functions.GenericFunction):
1098+
def __init__(self, *args, **kwargs):
1099+
expr = kwargs.pop("expr", None)
1100+
if expr is not None:
1101+
args = (expr,) + args
1102+
if len(args) != 1:
1103+
raise TypeError("The unnest function requires a single argument.")
1104+
arg = args[0]
1105+
if isinstance(arg, sqlalchemy.sql.expression.ColumnElement):
1106+
if not isinstance(arg.type, sqlalchemy.sql.sqltypes.ARRAY):
1107+
raise TypeError("The argument to unnest must have an ARRAY type.")
1108+
self.type = arg.type.item_type
1109+
super().__init__(*args, **kwargs)
1110+
1111+
10951112
dialect = BigQueryDialect
10961113

10971114
try:

tests/unit/test_sqlalchemy_bigquery.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from google.cloud import bigquery
1111
from google.cloud.bigquery.dataset import DatasetListItem
1212
from google.cloud.bigquery.table import TableListItem
13+
import packaging.version
1314
import pytest
1415
import sqlalchemy
1516

@@ -178,3 +179,53 @@ def test_follow_dialect_attribute_convention():
178179

179180
assert sqlalchemy_bigquery.dialect is sqlalchemy_bigquery.BigQueryDialect
180181
assert sqlalchemy_bigquery.base.dialect is sqlalchemy_bigquery.BigQueryDialect
182+
183+
184+
@pytest.mark.parametrize(
185+
"args,kw,error",
186+
[
187+
((), {}, "The unnest function requires a single argument."),
188+
((1, 1), {}, "The unnest function requires a single argument."),
189+
((1,), {"expr": 1}, "The unnest function requires a single argument."),
190+
((1, 1), {"expr": 1}, "The unnest function requires a single argument."),
191+
(
192+
(),
193+
{"expr": sqlalchemy.Column("x", sqlalchemy.String)},
194+
"The argument to unnest must have an ARRAY type.",
195+
),
196+
(
197+
(sqlalchemy.Column("x", sqlalchemy.String),),
198+
{},
199+
"The argument to unnest must have an ARRAY type.",
200+
),
201+
],
202+
)
203+
def test_unnest_function_errors(args, kw, error):
204+
# Make sure the unnest function is registered with SQLAlchemy, which
205+
# happens when sqlalchemy_bigquery is imported.
206+
import sqlalchemy_bigquery # noqa
207+
208+
with pytest.raises(TypeError, match=error):
209+
sqlalchemy.func.unnest(*args, **kw)
210+
211+
212+
@pytest.mark.parametrize(
213+
"args,kw",
214+
[
215+
((), {"expr": sqlalchemy.Column("x", sqlalchemy.ARRAY(sqlalchemy.String))}),
216+
((sqlalchemy.Column("x", sqlalchemy.ARRAY(sqlalchemy.String)),), {}),
217+
],
218+
)
219+
def test_unnest_function(args, kw):
220+
# Make sure the unnest function is registered with SQLAlchemy, which
221+
# happens when sqlalchemy_bigquery is imported.
222+
import sqlalchemy_bigquery # noqa
223+
224+
f = sqlalchemy.func.unnest(*args, **kw)
225+
assert isinstance(f.type, sqlalchemy.String)
226+
if packaging.version.parse(sqlalchemy.__version__) >= packaging.version.parse(
227+
"1.4"
228+
):
229+
assert isinstance(
230+
sqlalchemy.select([f]).subquery().c.unnest.type, sqlalchemy.String
231+
)

0 commit comments

Comments
 (0)