From 0d8d6dcce2101b8db6b26c85012978e65e1f999e Mon Sep 17 00:00:00 2001 From: Tim Swast Date: Fri, 20 May 2022 14:28:55 -0500 Subject: [PATCH] refactor: update black version to fix lint CI (#130) --- environment.yml | 2 +- ibis_bigquery/udf/__init__.py | 2 +- ibis_bigquery/udf/core.py | 17 +++-- ibis_bigquery/udf/find.py | 3 +- requirements.txt | 2 +- tests/system/conftest.py | 20 +++-- tests/system/test_client.py | 3 +- tests/system/test_connect.py | 22 ++++-- tests/system/udf/test_udf_execute.py | 11 ++- tests/unit/test_client.py | 6 +- tests/unit/test_compiler.py | 107 +++++++++++++++++++++++---- tests/unit/udf/test_core.py | 2 +- 12 files changed, 154 insertions(+), 43 deletions(-) diff --git a/environment.yml b/environment.yml index 97f177e..48e928a 100644 --- a/environment.yml +++ b/environment.yml @@ -10,7 +10,7 @@ dependencies: - sqlalchemy=1.3.23 # dev -- black=19.10b0 # Same as ibis +- black=22.3.0 # Same as ibis - pytest - pytest-cov - pytest-mock diff --git a/ibis_bigquery/udf/__init__.py b/ibis_bigquery/udf/__init__.py index 7f0ccfe..457e3ce 100644 --- a/ibis_bigquery/udf/__init__.py +++ b/ibis_bigquery/udf/__init__.py @@ -23,7 +23,7 @@ def validate_output_type(*args): __all__ = ("udf",) -_udf_name_cache: Dict[str, Iterable[int]] = (collections.defaultdict(itertools.count)) +_udf_name_cache: Dict[str, Iterable[int]] = collections.defaultdict(itertools.count) def create_udf_node(name, fields): diff --git a/ibis_bigquery/udf/core.py b/ibis_bigquery/udf/core.py index bab59e7..18b25b3 100644 --- a/ibis_bigquery/udf/core.py +++ b/ibis_bigquery/udf/core.py @@ -70,7 +70,9 @@ def wrapper(*args, **kwargs): def rewrite_print(node): return ast.Call( func=ast.Attribute( - value=ast.Name(id="console", ctx=ast.Load()), attr="log", ctx=ast.Load(), + value=ast.Name(id="console", ctx=ast.Load()), + attr="log", + ctx=ast.Load(), ), args=node.args, keywords=node.keywords, @@ -395,7 +397,9 @@ def visit_Compare(self, node): @semicolon def visit_AugAssign(self, node): return "{} {}= {}".format( - self.visit(node.target), self.visit(node.op), self.visit(node.value), + self.visit(node.target), + self.visit(node.op), + self.visit(node.value), ) def visit_Module(self, node): @@ -420,8 +424,7 @@ def visit_Lambda(self, node): @contextlib.contextmanager def local_scope(self): - """Assign symbols to local variables. - """ + """Assign symbols to local variables.""" self.scope = self.scope.new_child() try: yield self.scope @@ -444,7 +447,9 @@ def visit_If(self, node): def visit_IfExp(self, node): return "({} ? {} : {})".format( - self.visit(node.test), self.visit(node.body), self.visit(node.orelse), + self.visit(node.test), + self.visit(node.body), + self.visit(node.orelse), ) def visit_Index(self, node): @@ -604,6 +609,6 @@ def range(n): z = (x if y else b) + 2 + foobar foo = Rectangle(1, 2) nnn = len(values) - return [sum(values) - a + b * y ** -x, z, foo.width, nnn] + return [sum(values) - a + b * y**-x, z, foo.width, nnn] print(my_func.js) diff --git a/ibis_bigquery/udf/find.py b/ibis_bigquery/udf/find.py index 92c1072..bafa1d3 100644 --- a/ibis_bigquery/udf/find.py +++ b/ibis_bigquery/udf/find.py @@ -4,8 +4,7 @@ class NameFinder: - """Helper class to find the unique names in an AST. - """ + """Helper class to find the unique names in an AST.""" __slots__ = () diff --git a/requirements.txt b/requirements.txt index 3e47b30..27ec92e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,7 +5,7 @@ google-cloud-bigquery pydata-google-auth # dev -black==19.10b0 +black==22.3.0 pytest pytest-cov pytest-mock diff --git a/tests/system/conftest.py b/tests/system/conftest.py index c41900e..13a2396 100644 --- a/tests/system/conftest.py +++ b/tests/system/conftest.py @@ -72,14 +72,18 @@ def credentials(default_credentials): @pytest.fixture(scope="session") def client(credentials, project_id): return bq.connect( - project_id=project_id, dataset_id=DATASET_ID, credentials=credentials, + project_id=project_id, + dataset_id=DATASET_ID, + credentials=credentials, ) @pytest.fixture(scope="session") def client2(credentials, project_id): return bq.connect( - project_id=project_id, dataset_id=DATASET_ID, credentials=credentials, + project_id=project_id, + dataset_id=DATASET_ID, + credentials=credentials, ) @@ -184,7 +188,9 @@ def load_functional_alltypes_data(request, bqclient, create_functional_alltypes_ filepath = download_file("{}/functional_alltypes.csv".format(TESTING_DATA_URI)) with open(filepath.name, "rb") as csvfile: job = bqclient.load_table_from_file( - csvfile, table, job_config=load_config, + csvfile, + table, + job_config=load_config, ).result() if job.error_result: print("error") @@ -238,7 +244,9 @@ def load_functional_alltypes_parted_data( filepath = download_file("{}/functional_alltypes.csv".format(TESTING_DATA_URI)) with open(filepath.name, "rb") as csvfile: job = bqclient.load_table_from_file( - csvfile, table, job_config=load_config, + csvfile, + table, + job_config=load_config, ).result() if job.error_result: print("error") @@ -261,7 +269,9 @@ def load_struct_table_data(request, bqclient, struct_bq_table): filepath = download_file("{}/struct_table.avro".format(TESTING_DATA_URI)) with open(filepath.name, "rb") as avrofile: job = bqclient.load_table_from_file( - avrofile, struct_bq_table, job_config=load_config, + avrofile, + struct_bq_table, + job_config=load_config, ).result() if job.error_result: print("error") diff --git a/tests/system/test_client.py b/tests/system/test_client.py index d994113..cba0015 100644 --- a/tests/system/test_client.py +++ b/tests/system/test_client.py @@ -615,7 +615,8 @@ def test_string_to_timestamp(client): assert result == timestamp timestamp_tz = pd.Timestamp( - datetime.datetime(year=2017, month=2, day=6, hour=5), tz=pytz.timezone("UTC"), + datetime.datetime(year=2017, month=2, day=6, hour=5), + tz=pytz.timezone("UTC"), ) expr_tz = ibis.literal("2017-02-06").to_timestamp("%F", "America/New_York") result_tz = client.execute(expr_tz) diff --git a/tests/system/test_connect.py b/tests/system/test_connect.py index e151ddc..33e288d 100644 --- a/tests/system/test_connect.py +++ b/tests/system/test_connect.py @@ -27,12 +27,17 @@ def mock_credentials(*args, **kwargs): return creds, "default-project-id" monkeypatch.setattr(pydata_google_auth, "default", mock_credentials) - con = ibis_bigquery.connect(project_id="explicit-project-id",) + con = ibis_bigquery.connect( + project_id="explicit-project-id", + ) assert con.billing_project == "explicit-project-id" def test_without_dataset(project_id, credentials): - con = ibis_bigquery.connect(project_id=project_id, credentials=credentials,) + con = ibis_bigquery.connect( + project_id=project_id, + credentials=credentials, + ) with pytest.raises(ValueError, match="Unable to determine BigQuery"): con.list_tables() @@ -62,7 +67,8 @@ def mock_default(*args, **kwargs): monkeypatch.setattr(pydata_google_auth, "default", mock_default) ibis_bigquery.connect( - project_id=project_id, dataset_id="bigquery-public-data.stackoverflow", + project_id=project_id, + dataset_id="bigquery-public-data.stackoverflow", ) assert len(mock_calls) == 1 @@ -73,7 +79,10 @@ def mock_default(*args, **kwargs): auth_local_webserver = kwargs["use_local_webserver"] auth_cache = kwargs["credentials_cache"] assert not auth_local_webserver - assert isinstance(auth_cache, pydata_google_auth.cache.ReadWriteCredentialsCache,) + assert isinstance( + auth_cache, + pydata_google_auth.cache.ReadWriteCredentialsCache, + ) def test_auth_local_webserver(project_id, credentials, monkeypatch): @@ -137,7 +146,10 @@ def mock_default(*args, **kwargs): assert len(mock_calls) == 1 _, kwargs = mock_calls[0] auth_cache = kwargs["credentials_cache"] - assert isinstance(auth_cache, pydata_google_auth.cache.WriteOnlyCredentialsCache,) + assert isinstance( + auth_cache, + pydata_google_auth.cache.WriteOnlyCredentialsCache, + ) def test_auth_cache_none(project_id, credentials, monkeypatch): diff --git a/tests/system/udf/test_udf_execute.py b/tests/system/udf/test_udf_execute.py index 08777bd..5f799af 100644 --- a/tests/system/udf/test_udf_execute.py +++ b/tests/system/udf/test_udf_execute.py @@ -39,7 +39,8 @@ def my_add(a, b): expected = (df.double_col + df.double_col).rename("tmp") tm.assert_series_equal( - result.value_counts().sort_index(), expected.value_counts().sort_index(), + result.value_counts().sort_index(), + expected.value_counts().sort_index(), ) @@ -216,11 +217,15 @@ def my_len(s): param(dt.float64, dt.int64, marks=pytest.mark.xfail(raises=TypeError)), # complex argument type, valid return type param( - dt.Array(dt.int64), dt.float64, marks=pytest.mark.xfail(raises=TypeError), + dt.Array(dt.int64), + dt.float64, + marks=pytest.mark.xfail(raises=TypeError), ), # valid argument type, complex invalid return type param( - dt.float64, dt.Array(dt.int64), marks=pytest.mark.xfail(raises=TypeError), + dt.float64, + dt.Array(dt.int64), + marks=pytest.mark.xfail(raises=TypeError), ), # both invalid param( diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 6471d45..1cfc388 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -8,7 +8,11 @@ ["project", "dataset", "expected"], [ ("my-project", "", ("my-project", "my-project", "")), - ("my-project", "my_dataset", ("my-project", "my-project", "my_dataset"),), + ( + "my-project", + "my_dataset", + ("my-project", "my-project", "my_dataset"), + ), ( "billing-project", "data-project.my_dataset", diff --git a/tests/unit/test_compiler.py b/tests/unit/test_compiler.py index 53c4b8c..22970ba 100644 --- a/tests/unit/test_compiler.py +++ b/tests/unit/test_compiler.py @@ -18,14 +18,22 @@ ("case", "expected", "dtype"), [ (datetime.date(2017, 1, 1), "DATE '2017-01-01'", dt.date), - (pd.Timestamp("2017-01-01"), "DATE '2017-01-01'", dt.date,), + ( + pd.Timestamp("2017-01-01"), + "DATE '2017-01-01'", + dt.date, + ), ("2017-01-01", "DATE '2017-01-01'", dt.date), ( datetime.datetime(2017, 1, 1, 4, 55, 59), "TIMESTAMP '2017-01-01 04:55:59'", dt.timestamp, ), - ("2017-01-01 04:55:59", "TIMESTAMP '2017-01-01 04:55:59'", dt.timestamp,), + ( + "2017-01-01 04:55:59", + "TIMESTAMP '2017-01-01 04:55:59'", + dt.timestamp, + ), ( pd.Timestamp("2017-01-01 04:55:59"), "TIMESTAMP '2017-01-01 04:55:59'", @@ -42,9 +50,24 @@ def test_literal_date(case, expected, dtype): @pytest.mark.parametrize( ("case", "expected", "dtype", "strftime_func"), [ - (datetime.date(2017, 1, 1), "DATE '2017-01-01'", dt.date, "FORMAT_DATE",), - (pd.Timestamp("2017-01-01"), "DATE '2017-01-01'", dt.date, "FORMAT_DATE",), - ("2017-01-01", "DATE '2017-01-01'", dt.date, "FORMAT_DATE",), + ( + datetime.date(2017, 1, 1), + "DATE '2017-01-01'", + dt.date, + "FORMAT_DATE", + ), + ( + pd.Timestamp("2017-01-01"), + "DATE '2017-01-01'", + dt.date, + "FORMAT_DATE", + ), + ( + "2017-01-01", + "DATE '2017-01-01'", + dt.date, + "FORMAT_DATE", + ), ( datetime.datetime(2017, 1, 1, 4, 55, 59), "TIMESTAMP '2017-01-01 04:55:59'", @@ -82,8 +105,16 @@ def test_day_of_week(case, expected, dtype, strftime_func): @pytest.mark.parametrize( ("case", "expected", "dtype"), [ - ("test of hash", "'test of hash'", dt.string,), - (b"test of hash", "FROM_BASE64('dGVzdCBvZiBoYXNo')", dt.binary,), + ( + "test of hash", + "'test of hash'", + dt.string, + ), + ( + b"test of hash", + "FROM_BASE64('dGVzdCBvZiBoYXNo')", + dt.binary, + ), ], ) def test_hash(case, expected, dtype): @@ -98,14 +129,54 @@ def test_hash(case, expected, dtype): @pytest.mark.parametrize( ("case", "expected", "how", "dtype"), [ - ("test", "md5('test')", "md5", dt.string,), - (b"test", "md5(FROM_BASE64('dGVzdA=='))", "md5", dt.binary,), - ("test", "sha1('test')", "sha1", dt.string,), - (b"test", "sha1(FROM_BASE64('dGVzdA=='))", "sha1", dt.binary,), - ("test", "sha256('test')", "sha256", dt.string,), - (b"test", "sha256(FROM_BASE64('dGVzdA=='))", "sha256", dt.binary,), - ("test", "sha512('test')", "sha512", dt.string,), - (b"test", "sha512(FROM_BASE64('dGVzdA=='))", "sha512", dt.binary,), + ( + "test", + "md5('test')", + "md5", + dt.string, + ), + ( + b"test", + "md5(FROM_BASE64('dGVzdA=='))", + "md5", + dt.binary, + ), + ( + "test", + "sha1('test')", + "sha1", + dt.string, + ), + ( + b"test", + "sha1(FROM_BASE64('dGVzdA=='))", + "sha1", + dt.binary, + ), + ( + "test", + "sha256('test')", + "sha256", + dt.string, + ), + ( + b"test", + "sha256(FROM_BASE64('dGVzdA=='))", + "sha256", + dt.binary, + ), + ( + "test", + "sha512('test')", + "sha512", + dt.string, + ), + ( + b"test", + "sha512(FROM_BASE64('dGVzdA=='))", + "sha512", + dt.binary, + ), ], ) def test_hashbytes(case, expected, how, dtype): @@ -125,7 +196,11 @@ def test_hashbytes(case, expected, how, dtype): "TIMESTAMP '2017-01-01 04:55:59'", dt.timestamp, ), - ("2017-01-01 04:55:59", "TIMESTAMP '2017-01-01 04:55:59'", dt.timestamp,), + ( + "2017-01-01 04:55:59", + "TIMESTAMP '2017-01-01 04:55:59'", + dt.timestamp, + ), ( pd.Timestamp("2017-01-01 04:55:59"), "TIMESTAMP '2017-01-01 04:55:59'", diff --git a/tests/unit/udf/test_core.py b/tests/unit/udf/test_core.py index e431c6b..fbf5b23 100644 --- a/tests/unit/udf/test_core.py +++ b/tests/unit/udf/test_core.py @@ -133,7 +133,7 @@ def test_binary_operators(op, expected): def test_pow(): def f(): a = 1 - return a ** 2 + return a**2 expected = """\ function f() {