diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 29b913a..361df0f 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -59,9 +59,9 @@ jobs: - python_version: "3.8" "ibis_version": "2.1.1" - python_version: "3.9" - "ibis_version": "3.0.0" + "ibis_version": "3.0.2" - python_version: "3.10" - "ibis_version": "3.0.0" + "ibis_version": "3.0.2" - python_version: "3.10" "ibis_version": "github" diff --git a/.github/workflows/system-tests-pr.yml b/.github/workflows/system-tests-pr.yml index 16ae71f..a55330d 100644 --- a/.github/workflows/system-tests-pr.yml +++ b/.github/workflows/system-tests-pr.yml @@ -23,9 +23,9 @@ jobs: - python_version: "3.8" "ibis_version": "2.1.1" - python_version: "3.9" - "ibis_version": "3.0.0" + "ibis_version": "3.0.2" - python_version: "3.10" - "ibis_version": "3.0.0" + "ibis_version": "3.0.2" - python_version: "3.10" "ibis_version": "github" diff --git a/.github/workflows/system-tests.yml b/.github/workflows/system-tests.yml index d922760..e32ffd1 100644 --- a/.github/workflows/system-tests.yml +++ b/.github/workflows/system-tests.yml @@ -21,9 +21,9 @@ jobs: - python_version: "3.8" "ibis_version": "2.1.1" - python_version: "3.9" - "ibis_version": "3.0.0" + "ibis_version": "3.0.2" - python_version: "3.10" - "ibis_version": "3.0.0" + "ibis_version": "3.0.2" - python_version: "3.10" "ibis_version": "github" diff --git a/CHANGELOG.md b/CHANGELOG.md index fd98c81..f72dbb6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,12 @@ # Changelog +### [2.1.3](https://github.com/ibis-project/ibis-bigquery/compare/v2.1.2...v2.1.3) (2022-05-25) + + +### Bug Fixes + +* ensure that ScalarParameter names are used instead of Alias names ([#135](https://github.com/ibis-project/ibis-bigquery/issues/135)) ([bfe539a](https://github.com/ibis-project/ibis-bigquery/commit/bfe539a7c60439f7a521e230736aab3961dbabcc)) + ### [2.1.2](https://github.com/ibis-project/ibis-bigquery/compare/v2.1.1...v2.1.2) (2022-04-26) diff --git a/environment.yml b/environment.yml index 97f177e..48e928a 100644 --- a/environment.yml +++ b/environment.yml @@ -10,7 +10,7 @@ dependencies: - sqlalchemy=1.3.23 # dev -- black=19.10b0 # Same as ibis +- black=22.3.0 # Same as ibis - pytest - pytest-cov - pytest-mock diff --git a/ibis_bigquery/__init__.py b/ibis_bigquery/__init__.py index 8ed5703..53a6728 100644 --- a/ibis_bigquery/__init__.py +++ b/ibis_bigquery/__init__.py @@ -29,6 +29,13 @@ except ImportError: pass +try: + from ibis.expr.operations import Alias +except ImportError: + # Allow older versions of ibis to work with ScalarParameters as well as + # versions >= 3.0.0 + Alias = None + __version__: str = ibis_bigquery_version.__version__ @@ -222,7 +229,24 @@ def _execute(self, stmt, results=True, query_parameters=None): def raw_sql(self, query: str, results=False, params=None): query_parameters = [ - bigquery_param(param, value) for param, value in (params or {}).items() + bigquery_param( + # unwrap Alias instances + # + # Without unwrapping we try to execute compiled code that uses + # the ScalarParameter's raw name (e.g., @param_1) and not the + # alias's name which will fail. By unwrapping, we always use + # the raw name. + # + # This workaround is backwards compatible and doesn't require + # changes to ibis. + ( + param + if Alias is None or not isinstance(param.op(), Alias) + else param.op().arg + ), + value, + ) + for param, value in (params or {}).items() ] return self._execute(query, results=results, query_parameters=query_parameters) diff --git a/ibis_bigquery/udf/__init__.py b/ibis_bigquery/udf/__init__.py index 7f0ccfe..d528a0e 100644 --- a/ibis_bigquery/udf/__init__.py +++ b/ibis_bigquery/udf/__init__.py @@ -23,7 +23,7 @@ def validate_output_type(*args): __all__ = ("udf",) -_udf_name_cache: Dict[str, Iterable[int]] = (collections.defaultdict(itertools.count)) +_udf_name_cache: Dict[str, Iterable[int]] = collections.defaultdict(itertools.count) def create_udf_node(name, fields): @@ -181,21 +181,18 @@ def wrapper(f): signature = inspect.signature(f) parameter_names = signature.parameters.keys() - udf_node_fields = collections.OrderedDict( - [ - (name, Arg(rlz.value(type))) - for name, type in zip(parameter_names, input_type) - ] - + [ - ( - "output_type", - lambda self, output_type=output_type: rlz.shape_like( - self.args, dtype=output_type - ), - ), - ("__slots__", ("js",)), - ] - ) + udf_node_fields = { + name: Arg(rlz.value(type)) + for name, type in zip(parameter_names, input_type) + } + + try: + udf_node_fields["output_type"] = rlz.shape_like("args", dtype=output_type) + except TypeError: + udf_node_fields["output_dtype"] = property(lambda _: output_type) + udf_node_fields["output_shape"] = rlz.shape_like("args") + + udf_node_fields["__slots__"] = ("js",) udf_node = create_udf_node(f.__name__, udf_node_fields) diff --git a/ibis_bigquery/udf/core.py b/ibis_bigquery/udf/core.py index bab59e7..18b25b3 100644 --- a/ibis_bigquery/udf/core.py +++ b/ibis_bigquery/udf/core.py @@ -70,7 +70,9 @@ def wrapper(*args, **kwargs): def rewrite_print(node): return ast.Call( func=ast.Attribute( - value=ast.Name(id="console", ctx=ast.Load()), attr="log", ctx=ast.Load(), + value=ast.Name(id="console", ctx=ast.Load()), + attr="log", + ctx=ast.Load(), ), args=node.args, keywords=node.keywords, @@ -395,7 +397,9 @@ def visit_Compare(self, node): @semicolon def visit_AugAssign(self, node): return "{} {}= {}".format( - self.visit(node.target), self.visit(node.op), self.visit(node.value), + self.visit(node.target), + self.visit(node.op), + self.visit(node.value), ) def visit_Module(self, node): @@ -420,8 +424,7 @@ def visit_Lambda(self, node): @contextlib.contextmanager def local_scope(self): - """Assign symbols to local variables. - """ + """Assign symbols to local variables.""" self.scope = self.scope.new_child() try: yield self.scope @@ -444,7 +447,9 @@ def visit_If(self, node): def visit_IfExp(self, node): return "({} ? {} : {})".format( - self.visit(node.test), self.visit(node.body), self.visit(node.orelse), + self.visit(node.test), + self.visit(node.body), + self.visit(node.orelse), ) def visit_Index(self, node): @@ -604,6 +609,6 @@ def range(n): z = (x if y else b) + 2 + foobar foo = Rectangle(1, 2) nnn = len(values) - return [sum(values) - a + b * y ** -x, z, foo.width, nnn] + return [sum(values) - a + b * y**-x, z, foo.width, nnn] print(my_func.js) diff --git a/ibis_bigquery/udf/find.py b/ibis_bigquery/udf/find.py index 92c1072..bafa1d3 100644 --- a/ibis_bigquery/udf/find.py +++ b/ibis_bigquery/udf/find.py @@ -4,8 +4,7 @@ class NameFinder: - """Helper class to find the unique names in an AST. - """ + """Helper class to find the unique names in an AST.""" __slots__ = () diff --git a/ibis_bigquery/version.py b/ibis_bigquery/version.py index 4eabd0b..e835b9d 100644 --- a/ibis_bigquery/version.py +++ b/ibis_bigquery/version.py @@ -1 +1 @@ -__version__ = "2.1.2" +__version__ = "2.1.3" diff --git a/requirements.txt b/requirements.txt index 3e47b30..27ec92e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,7 +5,7 @@ google-cloud-bigquery pydata-google-auth # dev -black==19.10b0 +black==22.3.0 pytest pytest-cov pytest-mock diff --git a/tests/system/conftest.py b/tests/system/conftest.py index c41900e..13a2396 100644 --- a/tests/system/conftest.py +++ b/tests/system/conftest.py @@ -72,14 +72,18 @@ def credentials(default_credentials): @pytest.fixture(scope="session") def client(credentials, project_id): return bq.connect( - project_id=project_id, dataset_id=DATASET_ID, credentials=credentials, + project_id=project_id, + dataset_id=DATASET_ID, + credentials=credentials, ) @pytest.fixture(scope="session") def client2(credentials, project_id): return bq.connect( - project_id=project_id, dataset_id=DATASET_ID, credentials=credentials, + project_id=project_id, + dataset_id=DATASET_ID, + credentials=credentials, ) @@ -184,7 +188,9 @@ def load_functional_alltypes_data(request, bqclient, create_functional_alltypes_ filepath = download_file("{}/functional_alltypes.csv".format(TESTING_DATA_URI)) with open(filepath.name, "rb") as csvfile: job = bqclient.load_table_from_file( - csvfile, table, job_config=load_config, + csvfile, + table, + job_config=load_config, ).result() if job.error_result: print("error") @@ -238,7 +244,9 @@ def load_functional_alltypes_parted_data( filepath = download_file("{}/functional_alltypes.csv".format(TESTING_DATA_URI)) with open(filepath.name, "rb") as csvfile: job = bqclient.load_table_from_file( - csvfile, table, job_config=load_config, + csvfile, + table, + job_config=load_config, ).result() if job.error_result: print("error") @@ -261,7 +269,9 @@ def load_struct_table_data(request, bqclient, struct_bq_table): filepath = download_file("{}/struct_table.avro".format(TESTING_DATA_URI)) with open(filepath.name, "rb") as avrofile: job = bqclient.load_table_from_file( - avrofile, struct_bq_table, job_config=load_config, + avrofile, + struct_bq_table, + job_config=load_config, ).result() if job.error_result: print("error") diff --git a/tests/system/test_client.py b/tests/system/test_client.py index d994113..602f26c 100644 --- a/tests/system/test_client.py +++ b/tests/system/test_client.py @@ -1,6 +1,7 @@ import collections import datetime import decimal +import re import ibis import ibis.expr.datatypes as dt @@ -11,12 +12,21 @@ import pandas.testing as tm import pytest import pytz +from pytest import param import ibis_bigquery from ibis_bigquery.client import bigquery_param IBIS_VERSION = packaging.version.Version(ibis.__version__) IBIS_1_4_VERSION = packaging.version.Version("1.4.0") +IBIS_3_0_VERSION = packaging.version.Version("3.0.0") + +older_than_3 = pytest.mark.xfail( + IBIS_VERSION < IBIS_3_0_VERSION, reason="requires ibis >= 3" +) +at_least_3 = pytest.mark.xfail( + IBIS_VERSION >= IBIS_3_0_VERSION, reason="requires ibis < 3" +) def test_table(alltypes): @@ -84,7 +94,8 @@ def test_compile_toplevel(): def test_struct_field_access(struct_table): expr = struct_table.struct_col["string_field"] result = expr.execute() - expected = pd.Series([None, "a"], name="tmp") + expected_name = "tmp" if IBIS_VERSION < IBIS_3_0_VERSION else "string_field" + expected = pd.Series([None, "a"], name=expected_name) tm.assert_series_equal(result, expected) @@ -202,7 +213,43 @@ def test_different_partition_col_name(monkeypatch, client): assert col in parted_alltypes.columns -def test_subquery_scalar_params(alltypes, project_id, dataset_id): +def scalar_params_ibis3(project_id, dataset_id): + return f"""\ +SELECT count\\(`foo`\\) AS `count` +FROM \\( + SELECT `string_col`, sum\\(`float_col`\\) AS `foo` + FROM \\( + SELECT `float_col`, `timestamp_col`, `int_col`, `string_col` + FROM `{project_id}\\.{dataset_id}\\.functional_alltypes` + \\) t1 + WHERE `timestamp_col` < @param_\\d+ + GROUP BY 1 +\\) t0""" + + +def scalar_params_not_ibis3(project_id, dataset_id): + return f"""\ +SELECT count\\(`foo`\\) AS `count` +FROM \\( + SELECT `string_col`, sum\\(`float_col`\\) AS `foo` + FROM \\( + SELECT `float_col`, `timestamp_col`, `int_col`, `string_col` + FROM `{project_id}\\.{dataset_id}\\.functional_alltypes` + WHERE `timestamp_col` < @my_param + \\) t1 + GROUP BY 1 +\\) t0""" + + +@pytest.mark.parametrize( + "expected_fn", + [ + param(scalar_params_ibis3, marks=[older_than_3], id="ibis3"), + param(scalar_params_not_ibis3, marks=[at_least_3], id="not_ibis3"), + ], +) +def test_subquery_scalar_params(alltypes, project_id, dataset_id, expected_fn): + expected = expected_fn(project_id, dataset_id) t = alltypes param = ibis.param("timestamp").name("my_param") expr = ( @@ -214,20 +261,7 @@ def test_subquery_scalar_params(alltypes, project_id, dataset_id): .foo.count() ) result = expr.compile(params={param: "20140101"}) - expected = """\ -SELECT count(`foo`) AS `count` -FROM ( - SELECT `string_col`, sum(`float_col`) AS `foo` - FROM ( - SELECT `float_col`, `timestamp_col`, `int_col`, `string_col` - FROM `{}.{}.functional_alltypes` - WHERE `timestamp_col` < @my_param - ) t1 - GROUP BY 1 -) t0""".format( - project_id, dataset_id - ) - assert result == expected + assert re.match(expected, result) is not None def test_scalar_param_string(alltypes, df): @@ -455,18 +489,21 @@ def test_raw_sql(client): assert client.raw_sql("SELECT 1").fetchall() == [(1,)] -def test_scalar_param_scope(alltypes, project_id, dataset_id): +@pytest.mark.parametrize( + "pattern", + [ + param(r"@param_\d+", marks=[older_than_3], id="ibis3"), + param("@param", marks=[at_least_3], id="not_ibis3"), + ], +) +def test_scalar_param_scope(alltypes, project_id, dataset_id, pattern): t = alltypes param = ibis.param("timestamp") - mut = t.mutate(param=param).compile(params={param: "2017-01-01"}) - assert ( - mut - == """\ -SELECT *, @param AS `param` -FROM `{}.{}.functional_alltypes`""".format( - project_id, dataset_id - ) - ) + result = t.mutate(param=param).compile(params={param: "2017-01-01"}) + expected = f"""\ +SELECT \\*, {pattern} AS `param` +FROM `{project_id}\\.{dataset_id}\\.functional_alltypes`""" + assert re.match(expected, result) is not None def test_parted_column_rename(parted_alltypes): @@ -615,7 +652,8 @@ def test_string_to_timestamp(client): assert result == timestamp timestamp_tz = pd.Timestamp( - datetime.datetime(year=2017, month=2, day=6, hour=5), tz=pytz.timezone("UTC"), + datetime.datetime(year=2017, month=2, day=6, hour=5), + tz=pytz.timezone("UTC"), ) expr_tz = ibis.literal("2017-02-06").to_timestamp("%F", "America/New_York") result_tz = client.execute(expr_tz) @@ -718,7 +756,6 @@ def test_boolean_casting(alltypes): count = result["count"] assert count.at[False] == 5840 assert count.at[True] == 730 - assert count.at[None] == 730 def test_approx_median(alltypes): diff --git a/tests/system/test_compiler.py b/tests/system/test_compiler.py index a261efb..629ab3b 100644 --- a/tests/system/test_compiler.py +++ b/tests/system/test_compiler.py @@ -1,24 +1,42 @@ +import re + import ibis import ibis.expr.datatypes as dt import packaging.version import pytest +from pytest import param pytestmark = pytest.mark.bigquery IBIS_VERSION = packaging.version.Version(ibis.__version__) IBIS_1_VERSION = packaging.version.Version("1.4.0") +IBIS_3_0_VERSION = packaging.version.Version("3.0.0") +older_than_3 = pytest.mark.xfail( + IBIS_VERSION < IBIS_3_0_VERSION, reason="requires ibis >= 3" +) +at_least_3 = pytest.mark.xfail( + IBIS_VERSION >= IBIS_3_0_VERSION, reason="requires ibis < 3" +) -def test_timestamp_accepts_date_literals(alltypes, project_id, dataset_id): + +@pytest.mark.parametrize( + "pattern", + [ + param(r"@param_\d+", marks=[older_than_3], id="ibis3"), + param("@param", marks=[at_least_3], id="not_ibis3"), + ], +) +def test_timestamp_accepts_date_literals(alltypes, project_id, dataset_id, pattern): date_string = "2009-03-01" param = ibis.param(dt.timestamp).name("param_0") expr = alltypes.mutate(param=param) params = {param: date_string} result = expr.compile(params=params) expected = f"""\ -SELECT *, @param AS `param` -FROM `{project_id}.{dataset_id}.functional_alltypes`""" - assert result == expected +SELECT \\*, {pattern} AS `param` +FROM `{project_id}\\.{dataset_id}\\.functional_alltypes`""" + assert re.match(expected, result) is not None @pytest.mark.parametrize( diff --git a/tests/system/test_connect.py b/tests/system/test_connect.py index e151ddc..33e288d 100644 --- a/tests/system/test_connect.py +++ b/tests/system/test_connect.py @@ -27,12 +27,17 @@ def mock_credentials(*args, **kwargs): return creds, "default-project-id" monkeypatch.setattr(pydata_google_auth, "default", mock_credentials) - con = ibis_bigquery.connect(project_id="explicit-project-id",) + con = ibis_bigquery.connect( + project_id="explicit-project-id", + ) assert con.billing_project == "explicit-project-id" def test_without_dataset(project_id, credentials): - con = ibis_bigquery.connect(project_id=project_id, credentials=credentials,) + con = ibis_bigquery.connect( + project_id=project_id, + credentials=credentials, + ) with pytest.raises(ValueError, match="Unable to determine BigQuery"): con.list_tables() @@ -62,7 +67,8 @@ def mock_default(*args, **kwargs): monkeypatch.setattr(pydata_google_auth, "default", mock_default) ibis_bigquery.connect( - project_id=project_id, dataset_id="bigquery-public-data.stackoverflow", + project_id=project_id, + dataset_id="bigquery-public-data.stackoverflow", ) assert len(mock_calls) == 1 @@ -73,7 +79,10 @@ def mock_default(*args, **kwargs): auth_local_webserver = kwargs["use_local_webserver"] auth_cache = kwargs["credentials_cache"] assert not auth_local_webserver - assert isinstance(auth_cache, pydata_google_auth.cache.ReadWriteCredentialsCache,) + assert isinstance( + auth_cache, + pydata_google_auth.cache.ReadWriteCredentialsCache, + ) def test_auth_local_webserver(project_id, credentials, monkeypatch): @@ -137,7 +146,10 @@ def mock_default(*args, **kwargs): assert len(mock_calls) == 1 _, kwargs = mock_calls[0] auth_cache = kwargs["credentials_cache"] - assert isinstance(auth_cache, pydata_google_auth.cache.WriteOnlyCredentialsCache,) + assert isinstance( + auth_cache, + pydata_google_auth.cache.WriteOnlyCredentialsCache, + ) def test_auth_cache_none(project_id, credentials, monkeypatch): diff --git a/tests/system/udf/test_udf_execute.py b/tests/system/udf/test_udf_execute.py index 08777bd..5f799af 100644 --- a/tests/system/udf/test_udf_execute.py +++ b/tests/system/udf/test_udf_execute.py @@ -39,7 +39,8 @@ def my_add(a, b): expected = (df.double_col + df.double_col).rename("tmp") tm.assert_series_equal( - result.value_counts().sort_index(), expected.value_counts().sort_index(), + result.value_counts().sort_index(), + expected.value_counts().sort_index(), ) @@ -216,11 +217,15 @@ def my_len(s): param(dt.float64, dt.int64, marks=pytest.mark.xfail(raises=TypeError)), # complex argument type, valid return type param( - dt.Array(dt.int64), dt.float64, marks=pytest.mark.xfail(raises=TypeError), + dt.Array(dt.int64), + dt.float64, + marks=pytest.mark.xfail(raises=TypeError), ), # valid argument type, complex invalid return type param( - dt.float64, dt.Array(dt.int64), marks=pytest.mark.xfail(raises=TypeError), + dt.float64, + dt.Array(dt.int64), + marks=pytest.mark.xfail(raises=TypeError), ), # both invalid param( diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 6471d45..1cfc388 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -8,7 +8,11 @@ ["project", "dataset", "expected"], [ ("my-project", "", ("my-project", "my-project", "")), - ("my-project", "my_dataset", ("my-project", "my-project", "my_dataset"),), + ( + "my-project", + "my_dataset", + ("my-project", "my-project", "my_dataset"), + ), ( "billing-project", "data-project.my_dataset", diff --git a/tests/unit/test_compiler.py b/tests/unit/test_compiler.py index 53c4b8c..f84d192 100644 --- a/tests/unit/test_compiler.py +++ b/tests/unit/test_compiler.py @@ -12,20 +12,29 @@ IBIS_VERSION = packaging.version.Version(ibis.__version__) IBIS_1_4_VERSION = packaging.version.Version("1.4.0") +IBIS_3_0_VERSION = packaging.version.Version("3.0.0") @pytest.mark.parametrize( ("case", "expected", "dtype"), [ (datetime.date(2017, 1, 1), "DATE '2017-01-01'", dt.date), - (pd.Timestamp("2017-01-01"), "DATE '2017-01-01'", dt.date,), + ( + pd.Timestamp("2017-01-01"), + "DATE '2017-01-01'", + dt.date, + ), ("2017-01-01", "DATE '2017-01-01'", dt.date), ( datetime.datetime(2017, 1, 1, 4, 55, 59), "TIMESTAMP '2017-01-01 04:55:59'", dt.timestamp, ), - ("2017-01-01 04:55:59", "TIMESTAMP '2017-01-01 04:55:59'", dt.timestamp,), + ( + "2017-01-01 04:55:59", + "TIMESTAMP '2017-01-01 04:55:59'", + dt.timestamp, + ), ( pd.Timestamp("2017-01-01 04:55:59"), "TIMESTAMP '2017-01-01 04:55:59'", @@ -42,9 +51,24 @@ def test_literal_date(case, expected, dtype): @pytest.mark.parametrize( ("case", "expected", "dtype", "strftime_func"), [ - (datetime.date(2017, 1, 1), "DATE '2017-01-01'", dt.date, "FORMAT_DATE",), - (pd.Timestamp("2017-01-01"), "DATE '2017-01-01'", dt.date, "FORMAT_DATE",), - ("2017-01-01", "DATE '2017-01-01'", dt.date, "FORMAT_DATE",), + ( + datetime.date(2017, 1, 1), + "DATE '2017-01-01'", + dt.date, + "FORMAT_DATE", + ), + ( + pd.Timestamp("2017-01-01"), + "DATE '2017-01-01'", + dt.date, + "FORMAT_DATE", + ), + ( + "2017-01-01", + "DATE '2017-01-01'", + dt.date, + "FORMAT_DATE", + ), ( datetime.datetime(2017, 1, 1, 4, 55, 59), "TIMESTAMP '2017-01-01 04:55:59'", @@ -82,8 +106,16 @@ def test_day_of_week(case, expected, dtype, strftime_func): @pytest.mark.parametrize( ("case", "expected", "dtype"), [ - ("test of hash", "'test of hash'", dt.string,), - (b"test of hash", "FROM_BASE64('dGVzdCBvZiBoYXNo')", dt.binary,), + ( + "test of hash", + "'test of hash'", + dt.string, + ), + ( + b"test of hash", + "FROM_BASE64('dGVzdCBvZiBoYXNo')", + dt.binary, + ), ], ) def test_hash(case, expected, dtype): @@ -98,14 +130,54 @@ def test_hash(case, expected, dtype): @pytest.mark.parametrize( ("case", "expected", "how", "dtype"), [ - ("test", "md5('test')", "md5", dt.string,), - (b"test", "md5(FROM_BASE64('dGVzdA=='))", "md5", dt.binary,), - ("test", "sha1('test')", "sha1", dt.string,), - (b"test", "sha1(FROM_BASE64('dGVzdA=='))", "sha1", dt.binary,), - ("test", "sha256('test')", "sha256", dt.string,), - (b"test", "sha256(FROM_BASE64('dGVzdA=='))", "sha256", dt.binary,), - ("test", "sha512('test')", "sha512", dt.string,), - (b"test", "sha512(FROM_BASE64('dGVzdA=='))", "sha512", dt.binary,), + ( + "test", + "md5('test')", + "md5", + dt.string, + ), + ( + b"test", + "md5(FROM_BASE64('dGVzdA=='))", + "md5", + dt.binary, + ), + ( + "test", + "sha1('test')", + "sha1", + dt.string, + ), + ( + b"test", + "sha1(FROM_BASE64('dGVzdA=='))", + "sha1", + dt.binary, + ), + ( + "test", + "sha256('test')", + "sha256", + dt.string, + ), + ( + b"test", + "sha256(FROM_BASE64('dGVzdA=='))", + "sha256", + dt.binary, + ), + ( + "test", + "sha512('test')", + "sha512", + dt.string, + ), + ( + b"test", + "sha512(FROM_BASE64('dGVzdA=='))", + "sha512", + dt.binary, + ), ], ) def test_hashbytes(case, expected, how, dtype): @@ -125,7 +197,11 @@ def test_hashbytes(case, expected, how, dtype): "TIMESTAMP '2017-01-01 04:55:59'", dt.timestamp, ), - ("2017-01-01 04:55:59", "TIMESTAMP '2017-01-01 04:55:59'", dt.timestamp,), + ( + "2017-01-01 04:55:59", + "TIMESTAMP '2017-01-01 04:55:59'", + dt.timestamp, + ), ( pd.Timestamp("2017-01-01 04:55:59"), "TIMESTAMP '2017-01-01 04:55:59'", @@ -238,8 +314,9 @@ def test_binary(): t = ibis.table([("value", "double")], name="t") expr = t["value"].cast(dt.binary).name("value_hash") result = ibis_bigquery.compile(expr) - expected = """\ -SELECT CAST(`value` AS BYTES) AS `tmp` + expected_name = "tmp" if IBIS_VERSION < IBIS_3_0_VERSION else "value_hash" + expected = f"""\ +SELECT CAST(`value` AS BYTES) AS `{expected_name}` FROM t""" assert result == expected @@ -270,21 +347,22 @@ def test_bucket(): buckets = [0, 1, 3] expr = t.value.bucket(buckets).name("foo") result = ibis_bigquery.compile(expr) - expected = """\ + expected_name = "tmp" if IBIS_VERSION < IBIS_3_0_VERSION else "foo" + expected = f"""\ SELECT CASE WHEN (0 <= `value`) AND (`value` < 1) THEN 0 WHEN (1 <= `value`) AND (`value` <= 3) THEN 1 ELSE CAST(NULL AS INT64) - END AS `tmp` + END AS `{expected_name}` FROM t""" - expected_2 = """\ + expected_2 = f"""\ SELECT CASE WHEN (`value` >= 0) AND (`value` < 1) THEN 0 WHEN (`value` >= 1) AND (`value` <= 3) THEN 1 ELSE CAST(NULL AS INT64) - END AS `tmp` + END AS `{expected_name}` FROM t""" assert result == expected or result == expected_2 @@ -301,10 +379,11 @@ def test_window_unbounded(kind, begin, end, expected): kwargs = {kind: (begin, end)} expr = t.a.sum().over(ibis.window(**kwargs)) result = ibis_bigquery.compile(expr) + expected_name = "tmp" if IBIS_VERSION < IBIS_3_0_VERSION else "sum" assert ( result == f"""\ -SELECT sum(`a`) OVER (ROWS BETWEEN {expected}) AS `tmp` +SELECT sum(`a`) OVER (ROWS BETWEEN {expected}) AS `{expected_name}` FROM t""" ) diff --git a/tests/unit/udf/test_core.py b/tests/unit/udf/test_core.py index e431c6b..fbf5b23 100644 --- a/tests/unit/udf/test_core.py +++ b/tests/unit/udf/test_core.py @@ -133,7 +133,7 @@ def test_binary_operators(op, expected): def test_pow(): def f(): a = 1 - return a ** 2 + return a**2 expected = """\ function f() {