diff --git a/README.md b/README.md index 3defa91a..eb8778cb 100644 --- a/README.md +++ b/README.md @@ -28,14 +28,28 @@ which is located in `tests/conftest.py`. # Setup Create and activate your conda environment with python3.9: ```commandline -conda create -y -n substrait_consumer_testing -c conda-forge python=3.9 +conda create -y -n substrait_consumer_testing -c conda-forge python=3.9 openjdk conda activate substrait_consumer_testing ``` +*Note: Java is used by JPype to access the Isthmus producer. +JPype should work with all versions of Java; for details on which versions are +officially supported, see https://jpype.readthedocs.io/en/latest/install.html* Install requirements from the top level directory: ```commandline pip install -r requirements.txt ``` + +Get the Java dependencies needed by the Isthmus Substrait producer: +1. Clone the substrait-java repo +2. From the consumer-testing repo, run the build-and-copy-isthmus-shadow-jar.sh script +```commandline +git clone https://github.com/substrait-io/substrait-java.git +cd consumer-testing +sh build-and-copy-isthmus-shadow-jar.sh +``` +*This shell script may not work on Windows environments.* + # How to Run Tests TPCH tests are located in the `tests/integration` folder and substrait function tests are located in the `tests/functional` folder.
diff --git a/build-and-copy-isthmus-shadow-jar.sh b/build-and-copy-isthmus-shadow-jar.sh new file mode 100644 index 00000000..72721207 --- /dev/null +++ b/build-and-copy-isthmus-shadow-jar.sh @@ -0,0 +1,8 @@ +#!/bin/bash +set -e # stop at the first failing step instead of copying a stale or missing jar +echo "Enter the absolute path of the substrait-java repo" +read -r substrait_java_path +cd "${substrait_java_path}/isthmus"; ../gradlew shadowJar +cd - +mkdir -p jars +cp "${substrait_java_path}"/isthmus/build/libs/*all.jar ./jars/ diff --git a/requirements.txt b/requirements.txt index 2edd6a29..4431e053 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,6 +2,7 @@ duckdb filelock ibis-framework ibis-substrait +JPype1 protobuf --extra-index-url https://pypi.fury.io/arrow-nightlies --prefer-binary --pre pyarrow diff --git a/tests/conftest.py b/tests/conftest.py index 9d1dcf78..e873063d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,7 +5,7 @@ from filelock import FileLock from tests.consumers import AceroConsumer, DuckDBConsumer -from tests.producers import DuckDBProducer, IbisProducer +from tests.producers import DuckDBProducer, IbisProducer, IsthmusProducer @pytest.fixture(scope="session") @@ -43,7 +43,7 @@ def pytest_addoption(parser): ) -PRODUCERS = [DuckDBProducer, IbisProducer] +PRODUCERS = [DuckDBProducer, IbisProducer, IsthmusProducer] CONSUMERS = [AceroConsumer, DuckDBConsumer] diff --git a/tests/consumers.py b/tests/consumers.py index 39822e77..da99a265 100644 --- a/tests/consumers.py +++ b/tests/consumers.py @@ -81,7 +81,7 @@ class AceroConsumer: def __init__(self): self.created_tables = set() self.tables = {} - self.table_provider = lambda names: self.tables[names[0]] + self.table_provider = lambda names: self.tables[names[0].lower()] def setup(self, db_connection, file_names: Iterable[str]): if len(file_names) > 0: diff --git a/tests/context.py b/tests/context.py new file mode 100644 index 00000000..319d1fde --- /dev/null +++ b/tests/context.py @@ -0,0 +1,57 @@ +from pathlib import Path + +import jpype.imports + +import
tests.java_definitions as java + +REPO_DIR = Path(__file__).parent.parent +from com.google.protobuf.util import JsonFormat as json_formatter + +schema_file = Path.joinpath(REPO_DIR, "tests/data/tpch_parquet/schema.sql") + + +def produce_isthmus_substrait(sql_string, schema_list): + """ + Produce the substrait plan using Isthmus. + + Parameters: + sql_string: + SQL query. + schema_list: + List of schemas. + + Returns: + Substrait plan in json format. + """ + sql_to_substrait = java.SqlToSubstraitClass() + java_sql_string = jpype.java.lang.String(sql_string) + plan = sql_to_substrait.execute(java_sql_string, schema_list) + json_plan = json_formatter.printer().print_(plan) + return json_plan + + +def get_schema(file_names): + """ + Create the list of schemas based on the given file names. If no file names are + given, the schema of the custom table T is used instead. + + Parameters: + file_names: List of file names. + + Returns: + List of all schemas as a java list. + """ + arr = java.ArrayListClass() + if file_names: + with open(schema_file) as text_schema_file: # close the schema file as soon as it has been read + schema_string = text_schema_file.read().replace("\n", " ").split(";")[:-1] + for create_table in schema_string: + java_obj = jpype.JObject @ jpype.JString(create_table) + arr.add(java_obj) + else: + java_obj = jpype.JObject @ jpype.JString( + "CREATE TABLE T(a integer, b integer, c boolean, d boolean)" + ) + arr.add(java_obj) + + return java.ListClass @ arr diff --git a/tests/functional/common.py b/tests/functional/common.py index 7ded576c..1062f91c 100644 --- a/tests/functional/common.py +++ b/tests/functional/common.py @@ -38,7 +38,7 @@ def substrait_function_test( db_con: DuckDBPyConnection, created_tables: set, file_names: Iterable[str], - sql_query: str, + sql_query: tuple, ibis_expr: Callable[[Table], Table], producer, consumer, @@ -69,9 +69,10 @@ def substrait_function_test( """ producer.set_db_connection(db_con) consumer.setup(db_con, file_names) + supported_producers = sql_query[1] # Load the parquet files into
DuckDB and return all the table names as a list - sql_query = producer.format_sql(created_tables, sql_query, file_names) + sql_query = producer.format_sql(created_tables, sql_query[0], file_names) # Convert the SQL/Ibis expression to a substrait query plan if type(producer).__name__ == "IbisProducer": @@ -80,9 +81,13 @@ def substrait_function_test( sql_query, consumer, ibis_expr(*args) ) else: - pytest.skip("ibis expression currently undefined") + pytest.xfail("ibis expression currently undefined") else: - substrait_plan = producer.produce_substrait(sql_query, consumer) + if type(producer) in supported_producers: + substrait_plan = producer.produce_substrait(sql_query, consumer) + else: + pytest.xfail(f"{type(producer).__name__} does not support the following SQL: " + f"{sql_query}") actual_result = consumer.run_substrait_query(substrait_plan) expected_result = db_con.query(f"{sql_query}").arrow() @@ -97,7 +102,7 @@ def substrait_function_test( def load_custom_duckdb_table(db_connection): - db_connection.execute("create table t (a int, b int, c boolean, d boolean)") + db_connection.execute("create table t (a BIGINT, b BIGINT, c boolean, d boolean)") db_connection.execute( "INSERT INTO t VALUES " "(1, 1, TRUE, TRUE), (2, 1, FALSE, TRUE), (3, 1, TRUE, TRUE), " diff --git a/tests/functional/extension_functions/test_approximation_functions.py b/tests/functional/extension_functions/test_approximation_functions.py index 602d5f21..93a7f4e7 100644 --- a/tests/functional/extension_functions/test_approximation_functions.py +++ b/tests/functional/extension_functions/test_approximation_functions.py @@ -35,7 +35,7 @@ def test_approximation_functions( self, test_name: str, file_names: Iterable[str], - sql_query: str, + sql_query: tuple, ibis_expr: Callable[[Table], Table], producer, consumer, diff --git a/tests/functional/extension_functions/test_arithmetic_decimal_functions.py b/tests/functional/extension_functions/test_arithmetic_decimal_functions.py index 3de145dc..11c11cfe 
100644 --- a/tests/functional/extension_functions/test_arithmetic_decimal_functions.py +++ b/tests/functional/extension_functions/test_arithmetic_decimal_functions.py @@ -38,7 +38,7 @@ def test_arithmetic_decimal_functions( self, test_name: str, file_names: Iterable[str], - sql_query: str, + sql_query: tuple, ibis_expr: Callable[[Table], Table], producer, consumer, diff --git a/tests/functional/extension_functions/test_arithmetic_functions.py b/tests/functional/extension_functions/test_arithmetic_functions.py index 97718090..132cd9aa 100644 --- a/tests/functional/extension_functions/test_arithmetic_functions.py +++ b/tests/functional/extension_functions/test_arithmetic_functions.py @@ -41,7 +41,7 @@ def test_arithmetic_functions( self, test_name: str, file_names: Iterable[str], - sql_query: str, + sql_query: tuple, ibis_expr: Callable[[Table], Table], producer, consumer, diff --git a/tests/functional/extension_functions/test_boolean_functions.py b/tests/functional/extension_functions/test_boolean_functions.py index 7fb452fd..61a61c63 100644 --- a/tests/functional/extension_functions/test_boolean_functions.py +++ b/tests/functional/extension_functions/test_boolean_functions.py @@ -17,8 +17,8 @@ class TestBooleanFunctions: """ @staticmethod - @pytest.fixture(scope="function", autouse=True) - def setup_teardown_function(request): + @pytest.fixture(scope="class", autouse=True) + def setup_teardown_class(request): cls = request.cls cls.db_connection = duckdb.connect() @@ -41,10 +41,11 @@ def test_boolean_functions( self, test_name: str, file_names: Iterable[str], - sql_query: str, + sql_query: tuple, ibis_expr: Callable[[Table], Table], producer, consumer, + partsupp ) -> None: substrait_function_test( self.db_connection, @@ -54,5 +55,6 @@ def test_boolean_functions( ibis_expr, producer, consumer, + partsupp, self.table_t, ) diff --git a/tests/functional/extension_functions/test_comparison_functions.py b/tests/functional/extension_functions/test_comparison_functions.py 
index 8f67d60b..fe9419ae 100644 --- a/tests/functional/extension_functions/test_comparison_functions.py +++ b/tests/functional/extension_functions/test_comparison_functions.py @@ -17,8 +17,8 @@ class TestComparisonFunctions: """ @staticmethod - @pytest.fixture(scope="function", autouse=True) - def setup_teardown_function(request): + @pytest.fixture(scope="class", autouse=True) + def setup_teardown_class(request): cls = request.cls cls.db_connection = duckdb.connect() @@ -37,7 +37,7 @@ def test_comparison_functions( self, test_name: str, file_names: Iterable[str], - sql_query: str, + sql_query: tuple, ibis_expr: Callable[[Table], Table], producer, consumer, diff --git a/tests/functional/extension_functions/test_datetime_functions.py b/tests/functional/extension_functions/test_datetime_functions.py index f3d96d16..601ead28 100644 --- a/tests/functional/extension_functions/test_datetime_functions.py +++ b/tests/functional/extension_functions/test_datetime_functions.py @@ -17,8 +17,8 @@ class TestDatetimeFunctions: """ @staticmethod - @pytest.fixture(scope="function", autouse=True) - def setup_teardown_function(request): + @pytest.fixture(scope="class", autouse=True) + def setup_teardown_class(request): cls = request.cls cls.db_connection = duckdb.connect() @@ -35,7 +35,7 @@ def test_datetime_functions( self, test_name: str, file_names: Iterable[str], - sql_query: str, + sql_query: tuple, ibis_expr: Callable[[Table], Table], producer, consumer, diff --git a/tests/functional/extension_functions/test_logarithmic_functions.py b/tests/functional/extension_functions/test_logarithmic_functions.py index 7eaea7b4..55e0b9cf 100644 --- a/tests/functional/extension_functions/test_logarithmic_functions.py +++ b/tests/functional/extension_functions/test_logarithmic_functions.py @@ -35,7 +35,7 @@ def test_logarithmic_functions( self, test_name: str, file_names: Iterable[str], - sql_query: str, + sql_query: tuple, ibis_expr: Callable[[Table], Table], producer, consumer, diff --git 
a/tests/functional/extension_functions/test_rounding_functions.py b/tests/functional/extension_functions/test_rounding_functions.py index decf037e..5af0af60 100644 --- a/tests/functional/extension_functions/test_rounding_functions.py +++ b/tests/functional/extension_functions/test_rounding_functions.py @@ -35,7 +35,7 @@ def test_rounding_functions( self, test_name: str, file_names: Iterable[str], - sql_query: str, + sql_query: tuple, ibis_expr: Callable[[Table], Table], producer, consumer, diff --git a/tests/functional/extension_functions/test_string_functions.py b/tests/functional/extension_functions/test_string_functions.py index 1dde1cdd..41d1d49f 100644 --- a/tests/functional/extension_functions/test_string_functions.py +++ b/tests/functional/extension_functions/test_string_functions.py @@ -35,7 +35,7 @@ def test_string_functions( self, test_name: str, file_names: Iterable[str], - sql_query: str, + sql_query: tuple, ibis_expr: Callable[[Table], Table], producer, consumer, diff --git a/tests/functional/extension_functions/test_substrait_function_names.py b/tests/functional/extension_functions/test_substrait_function_names.py index 5ea6402b..86b16fa8 100644 --- a/tests/functional/extension_functions/test_substrait_function_names.py +++ b/tests/functional/extension_functions/test_substrait_function_names.py @@ -141,7 +141,7 @@ def test_rounding_function_names( self, test_name: str, file_names: Iterable[str], - sql_query: str, + sql_query: tuple, ibis_expr: Callable[[Table], Table], producer, partsupp, @@ -157,7 +157,7 @@ def run_function_name_test( self, test_name: str, file_names: Iterable[str], - sql_query: str, + sql_query: tuple, ibis_expr: Callable[[Table], Table], producer, *args @@ -185,7 +185,7 @@ def run_function_name_test( producer.set_db_connection(self.db_connection) # Load the parquet files into DuckDB and return all the table names as a list - sql_query = producer.format_sql(self.created_tables, sql_query, file_names) + sql_query = 
producer.format_sql(self.created_tables, sql_query[0], file_names) # Grab the json representation of the produced substrait plan to verify # the proper substrait function name. diff --git a/tests/functional/queries/sql/approximation_functions_sql.py b/tests/functional/queries/sql/approximation_functions_sql.py index e3f97fa1..f9cc08b5 100644 --- a/tests/functional/queries/sql/approximation_functions_sql.py +++ b/tests/functional/queries/sql/approximation_functions_sql.py @@ -1,7 +1,11 @@ +from tests.producers import * + SQL_AGGREGATE = { - "approx_count_distinct": + "approx_count_distinct": ( """ SELECT approx_count_distinct(l_comment) FROM '{}'; """, + [DuckDBProducer], + ), } diff --git a/tests/functional/queries/sql/arithmetic_demical_functions_sql.py b/tests/functional/queries/sql/arithmetic_demical_functions_sql.py index 3fe00633..f1e4f99a 100644 --- a/tests/functional/queries/sql/arithmetic_demical_functions_sql.py +++ b/tests/functional/queries/sql/arithmetic_demical_functions_sql.py @@ -1,50 +1,70 @@ +from tests.producers import * + SQL_SCALAR = { - "add": + "add": ( """ SELECT L_TAX, L_DISCOUNT, add(L_TAX, L_DISCOUNT) AS ADD_KEY FROM '{}'; """, - "subtract": + [DuckDBProducer], + ), + "subtract": ( """ SELECT L_TAX, L_DISCOUNT, subtract(L_TAX, L_DISCOUNT) AS SUBTRACT_KEY FROM '{}'; """, - "multiply": + [DuckDBProducer], + ), + "multiply": ( """ SELECT L_TAX, L_EXTENDEDPRICE, multiply(L_TAX, L_EXTENDEDPRICE) AS MULTIPLY_KEY FROM '{}'; """, - "divide": + [DuckDBProducer], + ), + "divide": ( """ SELECT L_TAX, L_EXTENDEDPRICE, divide(L_TAX, L_EXTENDEDPRICE) AS DIVIDE_KEY FROM '{}'; """, - "modulus": + [DuckDBProducer], + ), + "modulus": ( """ SELECT L_EXTENDEDPRICE, L_TAX, mod(L_EXTENDEDPRICE, L_TAX) AS MODULUS_KEY FROM '{}'; """, + [DuckDBProducer], + ), } SQL_AGGREGATE = { - "sum": + "sum": ( """ SELECT sum(L_EXTENDEDPRICE) AS SUM_EXTENDEDPRICE FROM '{}'; """, - "avg": + [DuckDBProducer], + ), + "avg": ( """ SELECT avg(L_EXTENDEDPRICE) AS AVG_EXTENDEDPRICE 
FROM '{}'; """, - "min": + [DuckDBProducer], + ), + "min": ( """ SELECT min(L_EXTENDEDPRICE) AS MIN_EXTENDEDPRICE FROM '{}'; """, - "max": + [DuckDBProducer], + ), + "max": ( """ SELECT max(L_EXTENDEDPRICE) AS MAX_EXTENDEDPRICE FROM '{}'; """, + [DuckDBProducer], + ), } diff --git a/tests/functional/queries/sql/arithmetic_functions_sql.py b/tests/functional/queries/sql/arithmetic_functions_sql.py index 1a79630a..a4e6fb49 100644 --- a/tests/functional/queries/sql/arithmetic_functions_sql.py +++ b/tests/functional/queries/sql/arithmetic_functions_sql.py @@ -1,150 +1,210 @@ +from tests.producers import * + SQL_SCALAR = { - "add": + "add": ( """ SELECT PS_PARTKEY, PS_SUPPKEY, add(PS_PARTKEY, PS_SUPPKEY) AS ADD_KEY FROM '{}'; """, - "subtract": + [DuckDBProducer], + ), + "subtract": ( """ SELECT PS_PARTKEY, PS_SUPPKEY, subtract(PS_PARTKEY, PS_SUPPKEY) AS SUBTRACT_KEY FROM '{}'; """, - "multiply": + [DuckDBProducer], + ), + "multiply": ( """ SELECT PS_PARTKEY, multiply(PS_PARTKEY, 10) AS MULTIPLY_KEY FROM '{}'; """, - "divide": + [DuckDBProducer], + ), + "divide": ( """ SELECT PS_PARTKEY, divide(PS_PARTKEY, 10) AS DIVIDE_KEY FROM '{}'; """, - "modulus": + [DuckDBProducer], + ), + "modulus": ( """ SELECT PS_PARTKEY, mod(PS_PARTKEY, 10) AS MODULO_KEY FROM '{}'; """, - "factorial": + [DuckDBProducer], + ), + "factorial": ( """ SELECT PS_PARTKEY, factorial(PS_PARTKEY) AS FACTORIAL_KEY FROM '{}'; """, - "power": + [DuckDBProducer], + ), + "power": ( """ SELECT PS_PARTKEY, power(PS_PARTKEY, 2) AS POWER_KEY FROM '{}'; """, - "sqrt": + [DuckDBProducer], + ), + "sqrt": ( """ SELECT PS_PARTKEY, sqrt(PS_PARTKEY) AS SQRT_KEY FROM '{}'; """, - "exp": + [DuckDBProducer], + ), + "exp": ( """ SELECT PS_PARTKEY, exp(PS_PARTKEY) AS EXP_KEY FROM '{}'; """, - "negate": + [DuckDBProducer], + ), + "negate": ( """ SELECT PS_PARTKEY, negate(PS_PARTKEY) AS NEGATE_KEY FROM '{}'; """, - "cos": + [DuckDBProducer], + ), + "cos": ( """ SELECT cos(PS_SUPPLYCOST) AS COS_SUPPLY FROM '{}'; """, - "acos": 
+ [DuckDBProducer], + ), + "acos": ( """ SELECT acos(L_TAX) AS ACOS_TAX FROM '{}'; """, - "sin": + [DuckDBProducer], + ), + "sin": ( """ SELECT sin(PS_SUPPLYCOST) AS SIN_SUPPLY FROM '{}'; """, - "asin": + [DuckDBProducer], + ), + "asin": ( """ SELECT asin(L_TAX) AS ASIN_TAX FROM '{}'; """, - "tan": + [DuckDBProducer], + ), + "tan": ( """ SELECT tan(PS_SUPPLYCOST) AS TAN_SUPPLY FROM '{}'; """, - "atan": + [DuckDBProducer], + ), + "atan": ( """ SELECT atan(L_TAX) AS ATAN_TAX FROM '{}'; """, - "atan2": + [DuckDBProducer], + ), + "atan2": ( """ SELECT atan2(L_TAX, L_TAX) AS ATAN2_TAX FROM '{}'; """, - "abs": + [DuckDBProducer], + ), + "abs": ( """ SELECT a, abs(a) AS ABS_A FROM 't'; """, - "sign": + [DuckDBProducer], + ), + "sign": ( """ SELECT a, sign(a) AS SIGN_A FROM 't'; """, + [DuckDBProducer], + ), } SQL_AGGREGATE = { - "sum": + "sum": ( """ SELECT sum(PS_SUPPLYCOST) AS SUM_SUPPLYCOST FROM '{}'; """, - "count": + [DuckDBProducer], + ), + "count": ( """ SELECT count(PS_SUPPLYCOST) AS COUNT_SUPPLYCOST FROM '{}'; """, - "avg": + [DuckDBProducer], + ), + "avg": ( """ SELECT avg(PS_SUPPLYCOST) AS AVG_SUPPLYCOST FROM '{}'; """, - "min": + [DuckDBProducer], + ), + "min": ( """ SELECT min(PS_SUPPLYCOST) AS MIN_SUPPLYCOST FROM '{}'; """, - "max": + [DuckDBProducer], + ), + "max": ( """ SELECT max(PS_SUPPLYCOST) AS MAX_SUPPLYCOST FROM '{}'; """, - "median": + [DuckDBProducer], + ), + "median": ( """ SELECT median(PS_SUPPLYCOST) AS MEDIAN_SUPPLYCOST FROM '{}'; """, - "mode": + [DuckDBProducer], + ), + "mode": ( """ SELECT mode(PS_SUPPLYCOST) AS MODE_SUPPLYCOST FROM '{}'; """, - "product": + [DuckDBProducer], + ), + "product": ( """ SELECT product(PS_SUPPLYCOST) AS PRODUCT_SUPPLYCOST FROM '{}'; """, - "std_dev": + [DuckDBProducer], + ), + "std_dev": ( """ SELECT stddev(PS_SUPPLYCOST) AS STDDEV_SUPPLYCOST FROM '{}'; """, - "variance": + [DuckDBProducer], + ), + "variance": ( """ SELECT variance(PS_SUPPLYCOST) AS VARIANCE_SUPPLYCOST FROM '{}'; """, + [DuckDBProducer], + ), } 
diff --git a/tests/functional/queries/sql/boolean_functions_sql.py b/tests/functional/queries/sql/boolean_functions_sql.py index 7d48b9e4..b40fea56 100644 --- a/tests/functional/queries/sql/boolean_functions_sql.py +++ b/tests/functional/queries/sql/boolean_functions_sql.py @@ -1,36 +1,50 @@ +from tests.producers import * + SQL_SCALAR = { - "or": + "or": ( """ SELECT a FROM 't' WHERE a = 5 OR a = 7; """, - "and": + [DuckDBProducer, IsthmusProducer], + ), + "and": ( """ SELECT a, b FROM 't' WHERE a < 5 AND b = 1; """, - "not": + [DuckDBProducer, IsthmusProducer], + ), + "not": ( """ SELECT c FROM 't' WHERE NOT c """, - "xor": + [DuckDBProducer, IsthmusProducer], + ), + "xor": ( """ SELECT a, b, xor(a, b) AS xor_a_b FROM 't'; """, + [DuckDBProducer], + ), } SQL_AGGREGATE = { - "bool_and": + "bool_and": ( """ SELECT bool_and(c) AS bool_and_c FROM 't' """, - "bool_or": + [DuckDBProducer], + ), + "bool_or": ( """ SELECT bool_or(c) AS bool_or_c FROM 't' """, + [DuckDBProducer], + ), } diff --git a/tests/functional/queries/sql/comparison_functions_sql.py b/tests/functional/queries/sql/comparison_functions_sql.py index e2c02917..5c322868 100644 --- a/tests/functional/queries/sql/comparison_functions_sql.py +++ b/tests/functional/queries/sql/comparison_functions_sql.py @@ -1,75 +1,103 @@ +from tests.producers import * + SQL_SCALAR = { - "not_equal": + "not_equal": ( """ SELECT N_NAME FROM '{}' WHERE NOT N_NAME = 'CANADA' """, - "equal": + [DuckDBProducer], + ), + "equal": ( """ SELECT PS_AVAILQTY, PS_PARTKEY FROM '{}' WHERE PS_AVAILQTY = PS_PARTKEY """, - "is_not_distinct_from": + [DuckDBProducer, IsthmusProducer], + ), + "is_not_distinct_from": ( """ SELECT a FROM 't' WHERE a IS NOT DISTINCT FROM NULL """, - "lt": + [DuckDBProducer], + ), + "lt": ( """ SELECT PS_AVAILQTY FROM '{}' WHERE PS_AVAILQTY < 10 """, - "lte": + [DuckDBProducer, IsthmusProducer], + ), + "lte": ( """ SELECT PS_AVAILQTY FROM '{}' WHERE PS_AVAILQTY <= 10 """, - "gt": + [DuckDBProducer, 
IsthmusProducer], + ), + "gt": ( """ SELECT PS_AVAILQTY FROM '{}' WHERE PS_AVAILQTY > 10 """, - "gte": + [DuckDBProducer, IsthmusProducer], + ), + "gte": ( """ SELECT PS_AVAILQTY FROM '{}' WHERE PS_AVAILQTY >= 10 """, - "is_not_null": + [DuckDBProducer, IsthmusProducer], + ), + "is_not_null": ( """ SELECT a FROM 't' WHERE a IS NOT NULL """, - "is_null": + [DuckDBProducer], + ), + "is_null": ( """ SELECT a FROM 't' WHERE a IS NULL """, - "is_nan": + [DuckDBProducer], + ), + "is_nan": ( """ SELECT a, isnan(a) as isnan_a FROM 't' """, - "is_finite": + [DuckDBProducer], + ), + "is_finite": ( """ SELECT a, isfinite(a) as isfinite_a FROM 't' """, - "is_infinite": + [DuckDBProducer], + ), + "is_infinite": ( """ SELECT a, isinf(a) as isinf_a FROM 't' """, - "coalesce": + [DuckDBProducer], + ), + "coalesce": ( """ SELECT coalesce(NULL,NULL,'test_string') """, + [DuckDBProducer], + ), } diff --git a/tests/functional/queries/sql/datetime_functions_sql.py b/tests/functional/queries/sql/datetime_functions_sql.py index 88349265..609efff4 100644 --- a/tests/functional/queries/sql/datetime_functions_sql.py +++ b/tests/functional/queries/sql/datetime_functions_sql.py @@ -1,41 +1,59 @@ +from tests.producers import * + SQL_SCALAR = { - "extract": + "extract": ( """ SELECT L_SHIPDATE, extract('year' FROM L_SHIPDATE) FROM '{}'; """, - "add": + [DuckDBProducer], + ), + "add": ( """ SELECT L_SHIPDATE, L_SHIPDATE + INTERVAL 5 DAY FROM '{}'; """, - "add_intervals": + [DuckDBProducer], + ), + "add_intervals": ( """ SELECT INTERVAL 1 HOUR + INTERVAL 5 HOUR """, - "subtract": + [DuckDBProducer], + ), + "subtract": ( """ SELECT L_SHIPDATE, L_SHIPDATE - INTERVAL 5 DAY FROM '{}'; """, - "lt": + [DuckDBProducer], + ), + "lt": ( """ SELECT L_COMMITDATE, L_RECEIPTDATE, L_COMMITDATE < L_RECEIPTDATE FROM '{}'; """, - "lte": + [DuckDBProducer, IsthmusProducer], + ), + "lte": ( """ SELECT L_COMMITDATE, L_RECEIPTDATE, L_COMMITDATE <= L_RECEIPTDATE FROM '{}'; """, - "gt": + [DuckDBProducer, 
IsthmusProducer], + ), + "gt": ( """ SELECT L_COMMITDATE, L_RECEIPTDATE, L_COMMITDATE > L_RECEIPTDATE FROM '{}'; """, - "gte": + [DuckDBProducer, IsthmusProducer], + ), + "gte": ( """ SELECT L_COMMITDATE, L_RECEIPTDATE, L_COMMITDATE >= L_RECEIPTDATE FROM '{}'; """, + [DuckDBProducer, IsthmusProducer], + ), } diff --git a/tests/functional/queries/sql/logarithmic_functions_sql.py b/tests/functional/queries/sql/logarithmic_functions_sql.py index 4860e5a6..981c60d5 100644 --- a/tests/functional/queries/sql/logarithmic_functions_sql.py +++ b/tests/functional/queries/sql/logarithmic_functions_sql.py @@ -1,18 +1,32 @@ +from tests.producers import * + SQL_SCALAR = { - "ln": """ + "ln": ( + """ SELECT PS_SUPPLYCOST, ln(PS_SUPPLYCOST) AS LN_SUPPLY FROM '{}'; """, - "log10": """ + [DuckDBProducer], + ), + "log10": ( + """ SELECT PS_SUPPLYCOST, log10(PS_SUPPLYCOST) AS LOG10_SUPPLY FROM '{}'; """, - "log2": """ + [DuckDBProducer], + ), + "log2": ( + """ SELECT PS_SUPPLYCOST, log2(PS_SUPPLYCOST) AS LOG2_SUPPLY FROM '{}'; """, - "logb": """ + [DuckDBProducer], + ), + "logb": ( + """ SELECT PS_SUPPLYCOST, logb(PS_SUPPLYCOST, 10) AS LOGB_SUPPLY FROM '{}'; """, + [DuckDBProducer], + ), } diff --git a/tests/functional/queries/sql/rounding_functions_sql.py b/tests/functional/queries/sql/rounding_functions_sql.py index 4d09c1c3..42b73c5e 100644 --- a/tests/functional/queries/sql/rounding_functions_sql.py +++ b/tests/functional/queries/sql/rounding_functions_sql.py @@ -1,12 +1,18 @@ +from tests.producers import * + SQL_SCALAR = { - "ceil": + "ceil": ( """ SELECT PS_SUPPLYCOST, ceil(PS_SUPPLYCOST) AS CEIL_SUPPLYCOST FROM '{}'; """, - "floor": + [DuckDBProducer], + ), + "floor": ( """ SELECT PS_SUPPLYCOST, floor(PS_SUPPLYCOST) AS FLOOR_SUPPLYCOST FROM '{}'; """, + [DuckDBProducer], + ), } diff --git a/tests/functional/queries/sql/string_functions_sql.py b/tests/functional/queries/sql/string_functions_sql.py index 1cef06a2..f5e9f818 100644 --- 
a/tests/functional/queries/sql/string_functions_sql.py +++ b/tests/functional/queries/sql/string_functions_sql.py @@ -1,102 +1,173 @@ +from tests.producers import * + SQL_SCALAR = { - "concat": """ + "concat": ( + """ SELECT N_NAME, concat(N_NAME, N_COMMENT) AS concat_nation FROM '{}'; """, - "concat_ws": """ + [DuckDBProducer], + ), + "concat_ws": ( + """ SELECT concat_ws('.', N_NAME, N_COMMENT) FROM '{}'; """, - "like": """ + [DuckDBProducer], + ), + "like": ( + """ SELECT N_NAME FROM '{}' WHERE N_NAME LIKE 'ALGERIA'; """, - "starts_with": """ + [DuckDBProducer], + ), + "starts_with": ( + """ SELECT N_NAME FROM '{}' WHERE prefix(N_NAME, 'A'); """, - "ends_with": """ + [DuckDBProducer], + ), + "ends_with": ( + """ SELECT N_NAME FROM '{}' WHERE suffix(N_NAME, 'A'); """, - "substring": """ + [DuckDBProducer], + ), + "substring": ( + """ SELECT N_NAME, substr(N_NAME, 1, 3) AS substr_name FROM '{}'; """, - "contains": """ + [DuckDBProducer], + ), + "contains": ( + """ SELECT N_NAME FROM '{}' WHERE contains(N_NAME, 'IA'); """, - "strpos": """ + [DuckDBProducer], + ), + "strpos": ( + """ SELECT N_NAME, strpos(N_NAME, 'A') AS strpos_name FROM '{}' """, - "replace": """ + [DuckDBProducer], + ), + "replace": ( + """ SELECT N_NAME, replace(N_NAME, 'A', 'a') AS replace_name FROM '{}' """, - "repeat": """ + [DuckDBProducer], + ), + "repeat": ( + """ SELECT N_NAME, repeat(N_NAME, 2) AS repeated_N_NAME FROM '{}' """, - "reverse": """ + [DuckDBProducer], + ), + "reverse": ( + """ SELECT N_NAME, reverse(N_NAME) AS reversed_N_NAME FROM '{}' """, - "lower": """ + [DuckDBProducer], + ), + "lower": ( + """ SELECT N_NAME, lower(N_NAME) AS lowercase_N_NAME FROM '{}' """, - "upper": """ + [DuckDBProducer], + ), + "upper": ( + """ SELECT O_COMMENT, upper(O_COMMENT) AS uppercase_O_COMMENT FROM '{}' """, - "char_length": """ + [DuckDBProducer], + ), + "char_length": ( + """ SELECT N_NAME, length(N_NAME) AS char_length_N_NAME FROM '{}' """, - "bit_length": """ + [DuckDBProducer], + ), + 
"bit_length": ( + """ SELECT N_NAME, bit_length(N_NAME) AS bit_length_N_NAME FROM '{}' """, - "ltrim": """ + [DuckDBProducer], + ), + "ltrim": ( + """ SELECT N_NAME, ltrim(N_NAME, 'A') AS ltrim_N_NAME FROM '{}' """, - "rtrim": """ + [DuckDBProducer], + ), + "rtrim": ( + """ SELECT N_NAME, rtrim(N_NAME, 'A') AS rtrim_N_NAME FROM '{}' """, - "trim": """ + [DuckDBProducer], + ), + "trim": ( + """ SELECT N_NAME, trim(N_NAME, 'A') AS trim_N_NAME FROM '{}' """, - "lpad": """ + [DuckDBProducer], + ), + "lpad": ( + """ SELECT N_NAME, lpad(N_NAME, 10, ' ') AS lpad_N_NAME FROM '{}' """, - "rpad": """ + [DuckDBProducer], + ), + "rpad": ( + """ SELECT N_NAME, rpad(N_NAME, 10, ' ') AS rpad_N_NAME FROM '{}' """, - "left": """ + [DuckDBProducer], + ), + "left": ( + """ SELECT N_NAME, left(N_NAME, 2) AS left_extract_N_NAME FROM '{}' """, - "right": """ + [DuckDBProducer], + ), + "right": ( + """ SELECT N_NAME, right(N_NAME, 2) AS right_extract_N_NAME FROM '{}' """, + [DuckDBProducer], + ), } SQL_AGGREGATE = { - "string_agg": """ + "string_agg": ( + """ SELECT N_NAME, string_agg(N_NAME, ',') FROM '{}' GROUP BY N_NAME """, + [DuckDBProducer], + ), } diff --git a/tests/java_definitions.py b/tests/java_definitions.py new file mode 100644 index 00000000..99a6c88d --- /dev/null +++ b/tests/java_definitions.py @@ -0,0 +1,24 @@ +import os +from pathlib import Path + +import jpype + +REPO_DIR = Path(__file__).parent.parent +isthmus_jars = Path.joinpath(REPO_DIR, "jars/*") + +the_java_home = "CONDA_PREFIX" +if "JAVA_HOME" in os.environ: + the_java_home = "JAVA_HOME" + +java_home_path = os.environ[the_java_home] +jvm_path = java_home_path + +if not os.path.isfile(jvm_path): + jvm_path = java_home_path + "/lib/libjli.dylib" + +jpype.startJVM("--enable-preview", convertStrings=True, jvmpath=jvm_path) +jpype.addClassPath(isthmus_jars) + +ArrayListClass = jpype.JClass("java.util.ArrayList") +ListClass = jpype.JClass("java.util.List") +SqlToSubstraitClass = 
jpype.JClass("io.substrait.isthmus.SqlToSubstrait") diff --git a/tests/producers.py b/tests/producers.py index 22becbd6..a9272abf 100644 --- a/tests/producers.py +++ b/tests/producers.py @@ -8,6 +8,7 @@ from ibis_substrait.compiler.core import SubstraitCompiler from tests.common import SubstraitUtils +from tests.context import get_schema, produce_isthmus_substrait class DuckDBProducer: @@ -43,38 +44,9 @@ def produce_substrait( proto_bytes = duckdb_substrait_plan.fetchone()[0] return proto_bytes - def load_tables_from_parquet( - self, - created_tables: set, - file_names: Iterable[str], - ) -> list: - """ - Load all the parquet files into separate tables in DuckDB. - - Parameters: - created_tables: - The set of tables that have already been created. - file_names: - Name of parquet files. - Returns: - A list of the table names. - """ - parquet_file_paths = SubstraitUtils.get_full_path(file_names) - table_names = [] - for file_name, file_path in zip(file_names, parquet_file_paths): - table_name = Path(file_name).stem - table_name = table_name.translate(str.maketrans("", "", string.punctuation)) - if table_name not in created_tables: - create_table_sql = f"CREATE TABLE {table_name} AS SELECT * FROM read_parquet('{file_path}');" - self.db_connection.execute(create_table_sql) - created_tables.add(table_name) - table_names.append(table_name) - - return table_names - def format_sql(self, created_tables, sql_query, file_names): if len(file_names) > 0: - table_names = self.load_tables_from_parquet(created_tables, file_names) + table_names = load_tables_from_parquet(self.db_connection, created_tables, file_names) sql_query = sql_query.format(*table_names) return sql_query @@ -116,37 +88,84 @@ def produce_substrait( substrait_plan = json_format.MessageToJson(tpch_proto_bytes) return substrait_plan - def load_tables_from_parquet( - self, - created_tables: set, - file_names: Iterable[str], - ) -> list: + def format_sql(self, created_tables, sql_query, file_names): + if 
len(file_names) > 0: + table_names = load_tables_from_parquet(self.db_connection, created_tables, file_names) + sql_query = sql_query.format(*table_names) + return sql_query + + +class IsthmusProducer: + def __init__(self, db_connection=None): + if db_connection is not None: + self.db_connection = db_connection + else: + self.db_connection = duckdb.connect() + + self.db_connection.execute("INSTALL substrait") + self.db_connection.execute("LOAD substrait") + self.compiler = SubstraitCompiler() + self.file_names = None + + def set_db_connection(self, db_connection): + self.db_connection = db_connection + + def produce_substrait( + self, sql_query: str, consumer, ibis_expr: str = None + ) -> str: """ - Load all the parquet files into separate tables in DuckDB. + Produce the Isthmus substrait plan using the given SQL query. Parameters: - created_tables: - The set of tables that have already been created. - file_names: - Name of parquet files. + sql_query: + SQL query. + consumer: + Name of substrait consumer. Returns: - A list of the table names. + Substrait query plan in json format.
""" - parquet_file_paths = SubstraitUtils.get_full_path(file_names) - table_names = [] - for file_name, file_path in zip(file_names, parquet_file_paths): - table_name = Path(file_name).stem - table_name = table_name.translate(str.maketrans("", "", string.punctuation)) - if table_name not in created_tables: - create_table_sql = f"CREATE TABLE {table_name} AS SELECT * FROM read_parquet('{file_path}');" - self.db_connection.execute(create_table_sql) - created_tables.add(table_name) - table_names.append(table_name) - - return table_names + schema_list = get_schema(self.file_names) + substrait_plan_str = produce_isthmus_substrait(sql_query, schema_list) + + return substrait_plan_str def format_sql(self, created_tables, sql_query, file_names): + sql_query = sql_query.replace("'{}'", "{}") + sql_query = sql_query.replace("'t'", "t") + self.file_names = file_names # always track the current query's files (may be empty) so get_schema never sees a previous query's list if len(file_names) > 0: - table_names = self.load_tables_from_parquet(created_tables, file_names) + table_names = load_tables_from_parquet(self.db_connection, created_tables, file_names) sql_query = sql_query.format(*table_names) return sql_query + + +def load_tables_from_parquet( + db_connection, + created_tables: set, + file_names: Iterable[str], +) -> list: + """ + Load all the parquet files into separate tables in DuckDB. + + Parameters: + db_connection: + DuckDB Connection. + created_tables: + The set of tables that have already been created. + file_names: + Name of parquet files. + Returns: + A list of the table names.
+ """ + parquet_file_paths = SubstraitUtils.get_full_path(file_names) + table_names = [] + for file_name, file_path in zip(file_names, parquet_file_paths): + table_name = Path(file_name).stem + table_name = table_name.translate(str.maketrans("", "", string.punctuation)) + if table_name not in created_tables: + create_table_sql = f"CREATE TABLE {table_name} AS SELECT * FROM read_parquet('{file_path}');" + db_connection.execute(create_table_sql) + created_tables.add(table_name) + table_names.append(table_name) + + return table_names