From 20d8379842fb93645de6dc13005ece4560ee5d62 Mon Sep 17 00:00:00 2001
From: Tishj
Date: Fri, 21 Jul 2023 15:32:22 +0200
Subject: [PATCH 01/11] fix #8321 save the 'column' of a df as object, so we
 keep an owning reference, in case the object gets created when accessed, not
 just referenced

---
 tools/pythonpkg/src/pandas/bind.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/tools/pythonpkg/src/pandas/bind.cpp b/tools/pythonpkg/src/pandas/bind.cpp
index 47c43a48bcee..152c01459339 100644
--- a/tools/pythonpkg/src/pandas/bind.cpp
+++ b/tools/pythonpkg/src/pandas/bind.cpp
@@ -9,13 +9,13 @@ namespace {

 struct PandasBindColumn {
 public:
-	PandasBindColumn(py::handle name, py::handle type, py::handle column) : name(name), type(type), handle(column) {
+	PandasBindColumn(py::handle name, py::handle type, py::object column) : name(name), type(type), handle(column) {
 	}

 public:
 	py::handle name;
 	py::handle type;
-	py::handle handle;
+	py::object handle;
 };

 struct PandasDataFrameBind {
@@ -27,7 +27,7 @@ struct PandasDataFrameBind {
 	}
 	PandasBindColumn operator[](idx_t index) const {
 		D_ASSERT(index < names.size());
-		auto column = getter(names[index]);
+		auto column = py::reinterpret_borrow<py::object>(getter(names[index]));
 		auto type = types[index];
 		auto name = names[index];
 		return PandasBindColumn(name, type, column);

From 2fc8224273b7363f32c77978cdc94ae2f78de384 Mon Sep 17 00:00:00 2001
From: Tishj
Date: Fri, 21 Jul 2023 15:49:43 +0200
Subject: [PATCH 02/11] add test for copy_on_write functionality in pandas

---
 .../tests/fast/pandas/test_copy_on_write.py | 29 +++++++++++++++++++
 1 file changed, 29 insertions(+)
 create mode 100644 tools/pythonpkg/tests/fast/pandas/test_copy_on_write.py

diff --git a/tools/pythonpkg/tests/fast/pandas/test_copy_on_write.py b/tools/pythonpkg/tests/fast/pandas/test_copy_on_write.py
new file mode 100644
index 000000000000..7b2aa77819c6
--- /dev/null
+++ b/tools/pythonpkg/tests/fast/pandas/test_copy_on_write.py
@@ -0,0 +1,29 @@
+import duckdb
+import pytest
+import pandas
+
+# Make sure the variable gets properly reset even in case of error
+@pytest.fixture(autouse=True)
+def scoped_copy_on_write_setting():
+    old_value = pandas.options.mode.copy_on_write
+    pandas.options.mode.copy_on_write = True
+    yield
+    # Reset it at the end of the function
+    pandas.options.mode.copy_on_write = old_value
+    return
+
+class TestCopyOnWrite(object):
+    def test_copy_on_write(self):
+        assert pandas.options.mode.copy_on_write == True
+
+        con = duckdb.connect()
+        df_in = pandas.DataFrame({'numbers': [1,2,3,4,5],})
+        rel = con.sql('select * from df_in')
+        res = rel.fetchall()
+        assert res == [
+            (1,),
+            (2,),
+            (3,),
+            (4,),
+            (5,)
+        ]

From f4d1706b992759110e680c0dbfc9c05c5962e0e2 Mon Sep 17 00:00:00 2001
From: Elliana May
Date: Mon, 17 Jul 2023 05:55:44 +0000
Subject: [PATCH 03/11] build: add black to format.py

---
 .github/workflows/CodeQuality.yml | 3 ++-
 scripts/format.py                 | 7 ++++---
 2 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/.github/workflows/CodeQuality.yml b/.github/workflows/CodeQuality.yml
index 10bdba780f51..9bacece7613d 100644
--- a/.github/workflows/CodeQuality.yml
+++ b/.github/workflows/CodeQuality.yml
@@ -49,13 +49,14 @@ jobs:

     - name: Install
       shell: bash
-      run: sudo apt-get update -y -qq && sudo apt-get install -y -qq ninja-build clang-format && sudo pip3 install cmake-format
+      run: sudo apt-get update -y -qq && sudo apt-get install -y -qq ninja-build clang-format && sudo pip3 install cmake-format black

    - name: Format Check
      shell: bash
      run: |
        clang-format --version
        clang-format --dump-config
+       black --version
        make format-check-silent

    - name: Generated Check

diff --git a/scripts/format.py b/scripts/format.py
index f198dacf032a..2336111159d8 100644
--- a/scripts/format.py
+++ b/scripts/format.py
@@ -13,7 +13,7 @@
 cpp_format_command = 'clang-format --sort-includes=0 -style=file'
 cmake_format_command = 'cmake-format'

-extensions = ['.cpp', '.c', '.hpp', '.h', '.cc', '.hh', 'CMakeLists.txt', '.test', '.test_slow', '.test_coverage', '.benchmark']
+extensions = ['.cpp', '.c', '.hpp', '.h', '.cc', '.hh', 'CMakeLists.txt', '.test', '.test_slow', '.test_coverage', '.benchmark', '.py']
 formatted_directories = ['src', 'benchmark', 'test', 'tools', 'examples', 'extension']
 ignored_files = ['tpch_constants.hpp', 'tpcds_constants.hpp', '_generated', 'tpce_flat_input.hpp', 'test_csv_header.hpp', 'duckdb.cpp', 'duckdb.hpp', 'json.hpp', 'sqlite3.h', 'shell.c',
@@ -161,7 +161,8 @@ def get_changed_files(revision):
     '.h': cpp_format_command,
     '.hh': cpp_format_command,
     '.cc': cpp_format_command,
-    '.txt': cmake_format_command
+    '.txt': cmake_format_command,
+    '.py': 'black --quiet - --skip-string-normalization --line-length 120 --stdin-filename',
 }

 difference_files = []
@@ -234,7 +235,7 @@ def get_formatted_text(f, full_path, directory, ext):
             header.append('\n')
         return ''.join(header + lines)
     proc_command = format_commands[ext].split(' ') + [full_path]
-    proc = subprocess.Popen(proc_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+    proc = subprocess.Popen(proc_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, stdin=open(full_path) if ext == '.py' else None)
     new_text = proc.stdout.read().decode('utf8')
     stderr = proc.stderr.read().decode('utf8')
     if len(stderr) > 0:

From 045a0193a3892057f3b99e7358fa22a7c953017b Mon Sep 17 00:00:00 2001
From: Elliana May
Date: Sat, 22 Jul 2023 19:39:05 +0800
Subject: [PATCH 04/11] fix: format

---
 examples/python/duckdb-python.py | 15 +-
 extension/excel/excel_config.py | 14 +-
 extension/fts/fts_config.py | 55 +-
 extension/httpfs/httpfs_config.py | 11 +-
 extension/icu/icu_config.py | 14 +-
 extension/icu/scripts/inline-data.py | 4 +-
 extension/jemalloc/jemalloc_config.py | 74 +-
 extension/json/json_config.py | 33 +-
 extension/parquet/parquet_config.py | 73 +-
 extension/tpcds/tpcds_config.py | 72 +-
 extension/tpch/tpch_config.py | 21 +-
 .../visualizer/generate_visualizer_header.py | 9 +-
 extension/visualizer/visualizer_config.py | 1 +
 test/memoryleak/test_memory_leaks.py | 49 +-
 test/parquet/generate_parquet_test.py | 63 +-
 test/sqlserver/scrape.py | 178 +--
 tools/juliapkg/release.py | 27 +-
 tools/nodejs/configure.py | 9 +-
 tools/odbc/test/isql-test.py | 296 +++--
 tools/odbc/test/run_psqlodbc_test.py | 67 +-
 tools/pythonpkg/pyduckdb/__init__.py | 4 +-
 tools/pythonpkg/pyduckdb/bytes_io_wrapper.py | 49 +-
 tools/pythonpkg/pyduckdb/filesystem.py | 108 +-
 tools/pythonpkg/pyduckdb/spark/__init__.py | 8 +-
 tools/pythonpkg/pyduckdb/spark/_globals.py | 4 +-
 tools/pythonpkg/pyduckdb/spark/conf.py | 80 +-
 tools/pythonpkg/pyduckdb/spark/context.py | 274 +++--
 tools/pythonpkg/pyduckdb/spark/exception.py | 15 +-
 .../pythonpkg/pyduckdb/spark/sql/__init__.py | 8 +-
 tools/pythonpkg/pyduckdb/spark/sql/catalog.py | 99 +-
 tools/pythonpkg/pyduckdb/spark/sql/conf.py | 26 +-
 .../pythonpkg/pyduckdb/spark/sql/dataframe.py | 32 +-
 .../pyduckdb/spark/sql/readwriter.py | 45 +-
 tools/pythonpkg/pyduckdb/spark/sql/session.py | 227 ++--
 .../pythonpkg/pyduckdb/spark/sql/streaming.py | 39 +-
 tools/pythonpkg/pyduckdb/spark/sql/types.py | 10 +-
 tools/pythonpkg/pyduckdb/spark/sql/udf.py | 10 +-
 .../spark/tests/basic/test_spark_catalog.py | 49 +-
 .../spark/tests/basic/test_spark_dataframe.py | 44 +-
 .../tests/basic/test_spark_runtime_config.py | 29 +-
 .../spark/tests/basic/test_spark_session.py | 145 +--
 .../pyduckdb/spark/tests/conftest.py | 3 +-
 tools/pythonpkg/pyduckdb/udf.py | 37 +-
 tools/pythonpkg/pyduckdb/value/constant.py | 58 +-
 tools/pythonpkg/setup.py | 87 +-
 tools/pythonpkg/tests/conftest.py | 30 +-
 .../test_pandas_categorical_coverage.py | 40 +-
 .../tests/extensions/json/test_read_json.py | 22 +-
 .../pythonpkg/tests/extensions/test_httpfs.py | 86 +-
 tools/pythonpkg/tests/fast/api/test_3324.py | 18 +-
 tools/pythonpkg/tests/fast/api/test_3654.py | 24 +-
 tools/pythonpkg/tests/fast/api/test_3728.py | 6 +-
 tools/pythonpkg/tests/fast/api/test_6315.py | 1 +
 .../tests/fast/api/test_attribute_getter.py | 11 +-
 tools/pythonpkg/tests/fast/api/test_config.py | 13 +-
 .../tests/fast/api/test_connection_close.py | 8 +-
 tools/pythonpkg/tests/fast/api/test_cursor.py | 10 +-
 .../pythonpkg/tests/fast/api/test_dbapi00.py | 39 +-
 .../pythonpkg/tests/fast/api/test_dbapi01.py | 19 +-
 .../pythonpkg/tests/fast/api/test_dbapi04.py | 16 +-
 .../pythonpkg/tests/fast/api/test_dbapi05.py | 28 +-
 .../pythonpkg/tests/fast/api/test_dbapi07.py | 3 +-
 .../pythonpkg/tests/fast/api/test_dbapi08.py | 1 +
 .../pythonpkg/tests/fast/api/test_dbapi09.py | 1 +
 .../pythonpkg/tests/fast/api/test_dbapi10.py | 8 +-
 .../pythonpkg/tests/fast/api/test_dbapi11.py | 8 +-
 .../pythonpkg/tests/fast/api/test_dbapi12.py | 22 +-
 .../pythonpkg/tests/fast/api/test_dbapi13.py | 1 +
 .../tests/fast/api/test_dbapi_fetch.py | 1 +
 .../tests/fast/api/test_duckdb_connection.py | 42 +-
 .../tests/fast/api/test_duckdb_query.py | 100 +-
 .../pythonpkg/tests/fast/api/test_explain.py | 61 +-
 .../tests/fast/api/test_insert_into.py | 5 +-
 tools/pythonpkg/tests/fast/api/test_join.py | 1 +
 .../tests/fast/api/test_query_interrupt.py | 2 +
 .../pythonpkg/tests/fast/api/test_read_csv.py | 751 ++++++------
 .../tests/fast/api/test_streaming_result.py | 101 +-
 tools/pythonpkg/tests/fast/api/test_to_csv.py | 48 +-
 .../tests/fast/api/test_to_parquet.py | 7 +-
 .../api/test_with_propagating_exceptions.py | 2 +-
 .../fast/arrow/parquet_write_roundtrip.py | 37 +-
 tools/pythonpkg/tests/fast/arrow/test_2426.py | 11 +-
 tools/pythonpkg/tests/fast/arrow/test_6584.py | 10 +-
 tools/pythonpkg/tests/fast/arrow/test_6796.py | 23 +-
 tools/pythonpkg/tests/fast/arrow/test_7652.py | 8 +-
 tools/pythonpkg/tests/fast/arrow/test_7699.py | 13 +-
 .../fast/arrow/test_arrow_batch_index.py | 3 +
 .../fast/arrow/test_arrow_case_sensitive.py | 28 +-
 .../tests/fast/arrow/test_arrow_fetch.py | 26 +-
 .../arrow/test_arrow_fetch_recordbatch.py | 95 +-
 .../tests/fast/arrow/test_arrow_list.py | 63 +-
 .../arrow/test_arrow_recordbatchreader.py | 147 ++-
 .../fast/arrow/test_arrow_replacement_scan.py | 23 +-
 .../tests/fast/arrow/test_arrow_scanner.py | 79 +-
 .../tests/fast/arrow/test_arrow_types.py | 11 +-
 .../tests/fast/arrow/test_binary_type.py | 8 +-
 .../fast/arrow/test_buffer_size_option.py | 11 +-
 .../tests/fast/arrow/test_dataset.py | 90 +-
 tools/pythonpkg/tests/fast/arrow/test_date.py | 41 +-
 .../tests/fast/arrow/test_dictionary_arrow.py | 100 +-
 .../tests/fast/arrow/test_filter_pushdown.py | 286 +++--
 .../tests/fast/arrow/test_integration.py | 130 ++-
 .../tests/fast/arrow/test_interval.py | 50 +-
 .../tests/fast/arrow/test_large_string.py | 8 +-
 .../tests/fast/arrow/test_multiple_reads.py | 9 +-
 .../tests/fast/arrow/test_nested_arrow.py | 172 +--
 .../tests/fast/arrow/test_parallel.py | 31 +-
 .../pythonpkg/tests/fast/arrow/test_polars.py | 11 +-
 .../tests/fast/arrow/test_progress.py | 23 +-
 .../fast/arrow/test_projection_pushdown.py | 11 +-
 tools/pythonpkg/tests/fast/arrow/test_time.py | 72 +-
 .../fast/arrow/test_timestamp_timezone.py | 125 +-
 .../tests/fast/arrow/test_timestamps.py | 54 +-
 tools/pythonpkg/tests/fast/arrow/test_tpch.py | 82 +-
 .../tests/fast/arrow/test_unregister.py | 9 +-
 tools/pythonpkg/tests/fast/arrow/test_view.py | 9 +-
 .../tests/fast/numpy/test_numpy_new_path.py | 40 +-
 .../pythonpkg/tests/fast/pandas/test_2304.py | 105 +-
 .../tests/fast/pandas/test_append_df.py | 44 +-
 .../tests/fast/pandas/test_bug2281.py | 3 +-
 .../tests/fast/pandas/test_bug5922.py | 9 +-
 .../pandas/test_create_table_from_pandas.py | 16 +-
 .../fast/pandas/test_date_as_datetime.py | 3 +-
 .../tests/fast/pandas/test_datetime_time.py | 22 +-
 .../fast/pandas/test_datetime_timestamp.py | 63 +-
 .../tests/fast/pandas/test_df_analyze.py | 90 +-
 .../fast/pandas/test_df_object_resolution.py | 444 ++++----
 .../fast/pandas/test_df_recursive_nested.py | 331 ++----
 .../tests/fast/pandas/test_fetch_df_chunk.py | 48 +-
 .../tests/fast/pandas/test_fetch_nested.py | 241 ++--
 .../fast/pandas/test_implicit_pandas_scan.py | 6 +-
 .../tests/fast/pandas/test_issue_1767.py | 2 +-
 .../pythonpkg/tests/fast/pandas/test_limit.py | 20 +-
 .../tests/fast/pandas/test_pandas_arrow.py | 183 ++-
 .../tests/fast/pandas/test_pandas_category.py | 113 +-
 .../tests/fast/pandas/test_pandas_enum.py | 18 +-
 .../tests/fast/pandas/test_pandas_limit.py | 1 +
 .../tests/fast/pandas/test_pandas_na.py | 46 +-
 .../tests/fast/pandas/test_pandas_object.py | 78 +-
 .../tests/fast/pandas/test_pandas_string.py | 21 +-
 .../fast/pandas/test_pandas_timestamp.py | 9 +-
 .../tests/fast/pandas/test_pandas_types.py | 99 +-
 .../fast/pandas/test_pandas_unregister.py | 1 +
 .../tests/fast/pandas/test_pandas_update.py | 15 +-
 .../fast/pandas/test_parallel_pandas_scan.py | 52 +-
 .../pandas/test_partitioned_pandas_scan.py | 1 +
 .../tests/fast/pandas/test_progress_bar.py | 15 +-
 .../pandas/test_pyarrow_filter_pushdown.py | 238 ++--
 .../test_pyarrow_projection_pushdown.py | 6 +-
 .../tests/fast/pandas/test_same_name.py | 79 +-
 .../tests/fast/pandas/test_stride.py | 6 +-
 .../tests/fast/pandas/test_timedelta.py | 25 +-
 .../tests/fast/pandas/test_timestamp.py | 30 +-
 .../relational_api/test_rapi_aggregations.py | 120 +-
 .../fast/relational_api/test_rapi_close.py | 259 ++---
 .../relational_api/test_rapi_description.py | 5 +-
 .../relational_api/test_rapi_functions.py | 1 +
 .../fast/relational_api/test_rapi_query.py | 21 +-
 .../pythonpkg/tests/fast/sqlite/test_types.py | 18 +-
 .../tests/fast/test_alex_multithread.py | 58 +-
 tools/pythonpkg/tests/fast/test_all_types.py | 149 ++-
 .../tests/fast/test_ambiguous_prepare.py | 2 +-
 tools/pythonpkg/tests/fast/test_case_alias.py | 3 +-
 .../tests/fast/test_context_manager.py | 3 +-
 tools/pythonpkg/tests/fast/test_filesystem.py | 19 +-
 .../tests/fast/test_get_table_names.py | 4 +-
 .../tests/fast/test_import_export.py | 28 +-
 .../test_import_without_pyarrow_dataset.py | 22 +-
 tools/pythonpkg/tests/fast/test_insert.py | 30 +-
 .../tests/fast/test_many_con_same_file.py | 14 +-
 tools/pythonpkg/tests/fast/test_map.py | 73 +-
 .../pythonpkg/tests/fast/test_memory_leaks.py | 4 +-
 .../tests/fast/test_multi_statement.py | 8 +-
 .../pythonpkg/tests/fast/test_multithread.py | 195 ++--
 .../tests/fast/test_non_default_conn.py | 84 +-
 .../tests/fast/test_parameter_list.py | 15 +-
 tools/pythonpkg/tests/fast/test_parquet.py | 140 ++-
 tools/pythonpkg/tests/fast/test_pytorch.py | 65 +-
 tools/pythonpkg/tests/fast/test_relation.py | 136 ++-
 .../fast/test_relation_dependency_leak.py | 19 +-
 .../tests/fast/test_replacement_scan.py | 87 +-
 tools/pythonpkg/tests/fast/test_result.py | 32 +-
 .../tests/fast/test_runtime_error.py | 40 +-
 tools/pythonpkg/tests/fast/test_tf.py | 55 +-
 .../pythonpkg/tests/fast/test_transaction.py | 7 +-
 tools/pythonpkg/tests/fast/test_type.py | 46 +-
 tools/pythonpkg/tests/fast/test_unicode.py | 3 +-
 tools/pythonpkg/tests/fast/test_value.py | 123 +-
 .../tests/fast/test_windows_abs_path.py | 3 +-
 tools/pythonpkg/tests/fast/types/test_blob.py | 1 +
 .../tests/fast/types/test_boolean.py | 3 +-
 .../tests/fast/types/test_decimal.py | 27 +-
 .../tests/fast/types/test_hugeint.py | 5 +-
 tools/pythonpkg/tests/fast/types/test_nan.py | 5 +-
 .../pythonpkg/tests/fast/types/test_nested.py | 12 +-
 tools/pythonpkg/tests/fast/types/test_null.py | 9 +-
 .../tests/fast/types/test_numeric.py | 12 +-
 .../pythonpkg/tests/fast/types/test_numpy.py | 10 +-
 .../tests/fast/types/test_object_int.py | 51 +-
 .../tests/fast/types/test_unsigned.py | 5 +-
 .../tests/fast/udf/test_remove_function.py | 33 +-
 tools/pythonpkg/tests/fast/udf/test_scalar.py | 214 ++--
 .../tests/fast/udf/test_scalar_arrow.py | 41 +-
 .../tests/fast/udf/test_scalar_native.py | 108 +-
 .../pythonpkg/tests/slow/test_h2oai_arrow.py | 105 +-
 tools/pythonpkg/tests/stubs/test_stubs.py | 23 +-
 tools/release-pip.py | 50 +-
 tools/rpkg/rconfigure.py | 5 +-
 tools/shell/shell-test.py | 1003 +++++++++++------
 tools/swift/create_package.py | 16 +-
 tools/upload-s3.py | 23 +-
 211 files changed, 7112 insertions(+), 5083 deletions(-)
 mode change 100755 => 100644 tools/pythonpkg/setup.py

diff --git a/examples/python/duckdb-python.py b/examples/python/duckdb-python.py
index 5b8a8fe118e3..6a745b6948ab 100644
--- a/examples/python/duckdb-python.py
+++ b/examples/python/duckdb-python.py
@@ -30,7 +30,8 @@
 # we can query pandas data frames as if they were SQL views
 # create a sample pandas data frame
 import pandas as pd
-test_df = pd.DataFrame.from_dict({"i":[1, 2, 3, 4], "j":["one", "two", "three", "four"]})
+
+test_df = pd.DataFrame.from_dict({"i": [1, 2, 3, 4], "j": ["one", "two", "three", "four"]})

 # make this data frame available as a view in duckdb
 conn.register("test_df", test_df)
@@ -49,8 +50,9 @@

 # create a relation from a CSV file
-# first create a CSV file from our pandas example 
+# first create a CSV file from our pandas example
 import tempfile, os
+
 temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names()))
 test_df.to_csv(temp_file_name, index=False)
@@ -139,9 +141,7 @@
 # turn the relation into something else again

-
-
-# compute the query result from the relation 
+# compute the query result from the relation
 res = rel.execute()
 print(res)
 # res is a query result, you can call fetchdf() or fetchnumpy() or fetchone() on it
@@ -164,7 +164,7 @@
 # Inserting elements into table_3
 print(conn.values([5, 'five']).insert_into("test_table3"))
 rel_3 = conn.table("test_table3")
-rel_3.insert([6,'six'])
+rel_3.insert([6, 'six'])

 # create a SQL-accessible view of the relation
 print(rel.create_view('test_view'))
@@ -183,6 +183,3 @@
 # this also works directly on data frames
 res = duckdb.query(test_df, 'my_name_for_test_df', 'SELECT * FROM my_name_for_test_df')
 print(res.df())
-
-
-

diff --git a/extension/excel/excel_config.py b/extension/excel/excel_config.py
index 1bb4c113ce21..6509469ab042 100644
--- a/extension/excel/excel_config.py
+++ b/extension/excel/excel_config.py
@@ -1,6 +1,16 @@
 import os
+
 # list all include directories
-include_directories = [os.path.sep.join(x.split('/')) for x in ['extension/excel/include', 'extension/excel/numformat/include']]
+include_directories = [
+    os.path.sep.join(x.split('/')) for x in ['extension/excel/include', 'extension/excel/numformat/include']
+]
 # source files
 source_files = [os.path.sep.join(x.split('/')) for x in ['extension/excel/excel_extension.cpp']]
-source_files += [os.path.sep.join(x.split('/')) for x in ['extension/excel/numformat/nf_calendar.cpp', 'extension/excel/numformat/nf_localedata.cpp', 'extension/excel/numformat/nf_zformat.cpp']]
+source_files += [
+    os.path.sep.join(x.split('/'))
+    for x in [
+        'extension/excel/numformat/nf_calendar.cpp',
+        'extension/excel/numformat/nf_localedata.cpp',
+        'extension/excel/numformat/nf_zformat.cpp',
+    ]
+]

diff --git a/extension/fts/fts_config.py b/extension/fts/fts_config.py
index d95908061f6b..1f8f0eb0ec7e 100644
--- a/extension/fts/fts_config.py
+++ b/extension/fts/fts_config.py
@@ -1,8 +1,55 @@
 import os
+
 # list all include directories
-include_directories = [os.path.sep.join(x.split('/')) for x in ['extension/fts/include', 'third_party/snowball/libstemmer', 'third_party/snowball/runtime', 'third_party/snowball/src_c']]
+include_directories = [
+    os.path.sep.join(x.split('/'))
+    for x in [
+        'extension/fts/include',
+        'third_party/snowball/libstemmer',
+        'third_party/snowball/runtime',
+        'third_party/snowball/src_c',
+    ]
+]
 # source files
-source_files = [os.path.sep.join(x.split('/')) for x in ['extension/fts/fts_extension.cpp', 'extension/fts/fts_indexing.cpp']]
+source_files = [
+    os.path.sep.join(x.split('/')) for x in ['extension/fts/fts_extension.cpp', 'extension/fts/fts_indexing.cpp']
+]
 # snowball
-source_files += [os.path.sep.join(x.split('/')) for x in ['third_party/snowball/libstemmer/libstemmer.cpp', 'third_party/snowball/runtime/utilities.cpp', 'third_party/snowball/runtime/api.cpp', 'third_party/snowball/src_c/stem_UTF_8_arabic.cpp', 'third_party/snowball/src_c/stem_UTF_8_basque.cpp', 'third_party/snowball/src_c/stem_UTF_8_catalan.cpp', 'third_party/snowball/src_c/stem_UTF_8_danish.cpp', 'third_party/snowball/src_c/stem_UTF_8_dutch.cpp', 'third_party/snowball/src_c/stem_UTF_8_english.cpp', 'third_party/snowball/src_c/stem_UTF_8_finnish.cpp', 'third_party/snowball/src_c/stem_UTF_8_french.cpp', 'third_party/snowball/src_c/stem_UTF_8_german.cpp', 'third_party/snowball/src_c/stem_UTF_8_german2.cpp', 'third_party/snowball/src_c/stem_UTF_8_greek.cpp', 'third_party/snowball/src_c/stem_UTF_8_hindi.cpp', 'third_party/snowball/src_c/stem_UTF_8_hungarian.cpp', 'third_party/snowball/src_c/stem_UTF_8_indonesian.cpp', 'third_party/snowball/src_c/stem_UTF_8_irish.cpp', 'third_party/snowball/src_c/stem_UTF_8_italian.cpp', 'third_party/snowball/src_c/stem_UTF_8_kraaij_pohlmann.cpp', 'third_party/snowball/src_c/stem_UTF_8_lithuanian.cpp', 'third_party/snowball/src_c/stem_UTF_8_lovins.cpp', 'third_party/snowball/src_c/stem_UTF_8_nepali.cpp', 'third_party/snowball/src_c/stem_UTF_8_norwegian.cpp', 'third_party/snowball/src_c/stem_UTF_8_porter.cpp', 'third_party/snowball/src_c/stem_UTF_8_portuguese.cpp', 'third_party/snowball/src_c/stem_UTF_8_romanian.cpp', 'third_party/snowball/src_c/stem_UTF_8_russian.cpp', 'third_party/snowball/src_c/stem_UTF_8_serbian.cpp', 'third_party/snowball/src_c/stem_UTF_8_spanish.cpp', 'third_party/snowball/src_c/stem_UTF_8_swedish.cpp', 'third_party/snowball/src_c/stem_UTF_8_tamil.cpp', 'third_party/snowball/src_c/stem_UTF_8_turkish.cpp']]
-
+source_files += [
+    os.path.sep.join(x.split('/'))
+    for x in [
+        'third_party/snowball/libstemmer/libstemmer.cpp',
+        'third_party/snowball/runtime/utilities.cpp',
+        'third_party/snowball/runtime/api.cpp',
+        'third_party/snowball/src_c/stem_UTF_8_arabic.cpp',
+        'third_party/snowball/src_c/stem_UTF_8_basque.cpp',
+        'third_party/snowball/src_c/stem_UTF_8_catalan.cpp',
+        'third_party/snowball/src_c/stem_UTF_8_danish.cpp',
+        'third_party/snowball/src_c/stem_UTF_8_dutch.cpp',
+        'third_party/snowball/src_c/stem_UTF_8_english.cpp',
+        'third_party/snowball/src_c/stem_UTF_8_finnish.cpp',
+        'third_party/snowball/src_c/stem_UTF_8_french.cpp',
+        'third_party/snowball/src_c/stem_UTF_8_german.cpp',
+        'third_party/snowball/src_c/stem_UTF_8_german2.cpp',
+        'third_party/snowball/src_c/stem_UTF_8_greek.cpp',
+        'third_party/snowball/src_c/stem_UTF_8_hindi.cpp',
+        'third_party/snowball/src_c/stem_UTF_8_hungarian.cpp',
+        'third_party/snowball/src_c/stem_UTF_8_indonesian.cpp',
+        'third_party/snowball/src_c/stem_UTF_8_irish.cpp',
+        'third_party/snowball/src_c/stem_UTF_8_italian.cpp',
+        'third_party/snowball/src_c/stem_UTF_8_kraaij_pohlmann.cpp',
+        'third_party/snowball/src_c/stem_UTF_8_lithuanian.cpp',
+        'third_party/snowball/src_c/stem_UTF_8_lovins.cpp',
+        'third_party/snowball/src_c/stem_UTF_8_nepali.cpp',
+        'third_party/snowball/src_c/stem_UTF_8_norwegian.cpp',
+        'third_party/snowball/src_c/stem_UTF_8_porter.cpp',
+        'third_party/snowball/src_c/stem_UTF_8_portuguese.cpp',
+        'third_party/snowball/src_c/stem_UTF_8_romanian.cpp',
+        'third_party/snowball/src_c/stem_UTF_8_russian.cpp',
+        'third_party/snowball/src_c/stem_UTF_8_serbian.cpp',
+        'third_party/snowball/src_c/stem_UTF_8_spanish.cpp',
+        'third_party/snowball/src_c/stem_UTF_8_swedish.cpp',
+        'third_party/snowball/src_c/stem_UTF_8_tamil.cpp',
+        'third_party/snowball/src_c/stem_UTF_8_turkish.cpp',
+    ]
+]

diff --git a/extension/httpfs/httpfs_config.py b/extension/httpfs/httpfs_config.py
index effa2c822d3d..894af2e5c13b 100644
--- a/extension/httpfs/httpfs_config.py
+++ b/extension/httpfs/httpfs_config.py
@@ -1,5 +1,12 @@
 import os
+
 # list all include directories
-include_directories = [os.path.sep.join(x.split('/')) for x in ['extension/httpfs/include', 'third_party/httplib', 'extension/parquet/include']]
+include_directories = [
+    os.path.sep.join(x.split('/'))
+    for x in ['extension/httpfs/include', 'third_party/httplib', 'extension/parquet/include']
+]
 # source files
-source_files = [os.path.sep.join(x.split('/')) for x in ['extension/httpfs/' + s for s in ['httpfs_extension.cpp', 'httpfs.cpp', 's3fs.cpp', 'crypto.cpp']]]
+source_files = [
+    os.path.sep.join(x.split('/'))
+    for x in ['extension/httpfs/' + s for s in ['httpfs_extension.cpp', 'httpfs.cpp', 's3fs.cpp', 'crypto.cpp']]
+]

diff --git a/extension/icu/icu_config.py b/extension/icu/icu_config.py
index 4f3d7e57def6..e77321875f6d 100644
--- a/extension/icu/icu_config.py
+++ b/extension/icu/icu_config.py
@@ -1,10 +1,18 @@
 import os

 # list all include directories
-include_directories = [os.path.sep.join(x.split('/')) for x in ['extension/icu/include', 'extension/icu/third_party/icu/common', 'extension/icu/third_party/icu/i18n']]
+include_directories = [
+    os.path.sep.join(x.split('/'))
+    for x in ['extension/icu/include', 'extension/icu/third_party/icu/common', 'extension/icu/third_party/icu/i18n']
+]
 # source files
-source_directories = [os.path.sep.join(x.split('/')) for x in ['.', 'third_party/icu/common', 'third_party/icu/i18n', 'third_party/icu/stubdata']]
+source_directories = [
+    os.path.sep.join(x.split('/'))
+    for x in ['.', 'third_party/icu/common', 'third_party/icu/i18n', 'third_party/icu/stubdata']
+]
 source_files = []
 base_path = os.path.dirname(os.path.abspath(__file__))
 for dir in source_directories:
-    source_files += [os.path.join('extension', 'icu', dir, x) for x in os.listdir(os.path.join(base_path, dir)) if x.endswith('.cpp')]
+    source_files += [
+        os.path.join('extension', 'icu', dir, x) for x in os.listdir(os.path.join(base_path, dir)) if x.endswith('.cpp')
+    ]

diff --git a/extension/icu/scripts/inline-data.py b/extension/icu/scripts/inline-data.py
index 6d7f0fbd654b..15e080efa136 100644
--- a/extension/icu/scripts/inline-data.py
+++ b/extension/icu/scripts/inline-data.py
@@ -15,6 +15,8 @@
 extern "C" U_EXPORT const unsigned char U_ICUDATA_ENTRY_POINT [] = {
 %s
 };
-""" % (result_text,)
+""" % (
+    result_text,
+)

 sys.stdout.write(new_contents)

diff --git a/extension/jemalloc/jemalloc_config.py b/extension/jemalloc/jemalloc_config.py
index 891d3a495498..873b8bb43008 100644
--- a/extension/jemalloc/jemalloc_config.py
+++ b/extension/jemalloc/jemalloc_config.py
@@ -1,5 +1,75 @@
 import os
+
 # list all include directories
-include_directories = [os.path.sep.join(x.split('/')) for x in ['extension/jemalloc/include', 'extension/jemalloc/jemalloc/include']]
+include_directories = [
+    os.path.sep.join(x.split('/')) for x in ['extension/jemalloc/include', 'extension/jemalloc/jemalloc/include']
+]
 # source files
-source_files = [os.path.sep.join(x.split('/')) for x in ['extension/jemalloc/jemalloc_extension.cpp', 'extension/jemalloc/jemalloc/src/arena.cpp', 'extension/jemalloc/jemalloc/src/background_thread.cpp', 'extension/jemalloc/jemalloc/src/base.cpp', 'extension/jemalloc/jemalloc/src/bin.cpp', 'extension/jemalloc/jemalloc/src/bin_info.cpp', 'extension/jemalloc/jemalloc/src/bitmap.cpp', 'extension/jemalloc/jemalloc/src/buf_writer.cpp', 'extension/jemalloc/jemalloc/src/cache_bin.cpp', 'extension/jemalloc/jemalloc/src/ckh.cpp', 'extension/jemalloc/jemalloc/src/counter.cpp', 'extension/jemalloc/jemalloc/src/ctl.cpp', 'extension/jemalloc/jemalloc/src/decay.cpp', 'extension/jemalloc/jemalloc/src/div.cpp', 'extension/jemalloc/jemalloc/src/ecache.cpp', 'extension/jemalloc/jemalloc/src/edata.cpp', 'extension/jemalloc/jemalloc/src/edata_cache.cpp', 'extension/jemalloc/jemalloc/src/ehooks.cpp', 'extension/jemalloc/jemalloc/src/emap.cpp', 'extension/jemalloc/jemalloc/src/eset.cpp', 'extension/jemalloc/jemalloc/src/exp_grow.cpp', 'extension/jemalloc/jemalloc/src/extent.cpp', 'extension/jemalloc/jemalloc/src/extent_dss.cpp', 'extension/jemalloc/jemalloc/src/extent_mmap.cpp', 'extension/jemalloc/jemalloc/src/fxp.cpp', 'extension/jemalloc/jemalloc/src/hook.cpp', 'extension/jemalloc/jemalloc/src/hpa.cpp', 'extension/jemalloc/jemalloc/src/hpa_hooks.cpp', 'extension/jemalloc/jemalloc/src/hpdata.cpp', 'extension/jemalloc/jemalloc/src/inspect.cpp', 'extension/jemalloc/jemalloc/src/jemalloc.cpp', 'extension/jemalloc/jemalloc/src/large.cpp', 'extension/jemalloc/jemalloc/src/log.cpp', 'extension/jemalloc/jemalloc/src/malloc_io.cpp', 'extension/jemalloc/jemalloc/src/mutex.cpp', 'extension/jemalloc/jemalloc/src/nstime.cpp', 'extension/jemalloc/jemalloc/src/pa.cpp', 'extension/jemalloc/jemalloc/src/pa_extra.cpp', 'extension/jemalloc/jemalloc/src/pac.cpp', 'extension/jemalloc/jemalloc/src/pages.cpp', 'extension/jemalloc/jemalloc/src/pai.cpp', 'extension/jemalloc/jemalloc/src/peak_event.cpp', 'extension/jemalloc/jemalloc/src/prof.cpp', 'extension/jemalloc/jemalloc/src/prof_data.cpp', 'extension/jemalloc/jemalloc/src/prof_log.cpp', 'extension/jemalloc/jemalloc/src/prof_recent.cpp', 'extension/jemalloc/jemalloc/src/prof_stats.cpp', 'extension/jemalloc/jemalloc/src/prof_sys.cpp', 'extension/jemalloc/jemalloc/src/psset.cpp', 'extension/jemalloc/jemalloc/src/rtree.cpp', 'extension/jemalloc/jemalloc/src/safety_check.cpp', 'extension/jemalloc/jemalloc/src/san.cpp', 'extension/jemalloc/jemalloc/src/san_bump.cpp', 'extension/jemalloc/jemalloc/src/sc.cpp', 'extension/jemalloc/jemalloc/src/sec.cpp', 'extension/jemalloc/jemalloc/src/stats.cpp', 'extension/jemalloc/jemalloc/src/sz.cpp', 'extension/jemalloc/jemalloc/src/tcache.cpp', 'extension/jemalloc/jemalloc/src/test_hooks.cpp', 'extension/jemalloc/jemalloc/src/thread_event.cpp', 'extension/jemalloc/jemalloc/src/ticker.cpp', 'extension/jemalloc/jemalloc/src/tsd.cpp', 'extension/jemalloc/jemalloc/src/witness.cpp']]
+source_files = [
+    os.path.sep.join(x.split('/'))
+    for x in [
+        'extension/jemalloc/jemalloc_extension.cpp',
+        'extension/jemalloc/jemalloc/src/arena.cpp',
+        'extension/jemalloc/jemalloc/src/background_thread.cpp',
+        'extension/jemalloc/jemalloc/src/base.cpp',
+        'extension/jemalloc/jemalloc/src/bin.cpp',
+        'extension/jemalloc/jemalloc/src/bin_info.cpp',
+        'extension/jemalloc/jemalloc/src/bitmap.cpp',
+        'extension/jemalloc/jemalloc/src/buf_writer.cpp',
+        'extension/jemalloc/jemalloc/src/cache_bin.cpp',
+        'extension/jemalloc/jemalloc/src/ckh.cpp',
+        'extension/jemalloc/jemalloc/src/counter.cpp',
+        'extension/jemalloc/jemalloc/src/ctl.cpp',
+        'extension/jemalloc/jemalloc/src/decay.cpp',
+        'extension/jemalloc/jemalloc/src/div.cpp',
+        'extension/jemalloc/jemalloc/src/ecache.cpp',
+        'extension/jemalloc/jemalloc/src/edata.cpp',
+        'extension/jemalloc/jemalloc/src/edata_cache.cpp',
+        'extension/jemalloc/jemalloc/src/ehooks.cpp',
+        'extension/jemalloc/jemalloc/src/emap.cpp',
+        'extension/jemalloc/jemalloc/src/eset.cpp',
+        'extension/jemalloc/jemalloc/src/exp_grow.cpp',
+        'extension/jemalloc/jemalloc/src/extent.cpp',
+        'extension/jemalloc/jemalloc/src/extent_dss.cpp',
+        'extension/jemalloc/jemalloc/src/extent_mmap.cpp',
+        'extension/jemalloc/jemalloc/src/fxp.cpp',
+        'extension/jemalloc/jemalloc/src/hook.cpp',
+        'extension/jemalloc/jemalloc/src/hpa.cpp',
+        'extension/jemalloc/jemalloc/src/hpa_hooks.cpp',
+        'extension/jemalloc/jemalloc/src/hpdata.cpp',
+        'extension/jemalloc/jemalloc/src/inspect.cpp',
+        'extension/jemalloc/jemalloc/src/jemalloc.cpp',
+        'extension/jemalloc/jemalloc/src/large.cpp',
+        'extension/jemalloc/jemalloc/src/log.cpp',
+        'extension/jemalloc/jemalloc/src/malloc_io.cpp',
+        'extension/jemalloc/jemalloc/src/mutex.cpp',
+        'extension/jemalloc/jemalloc/src/nstime.cpp',
+        'extension/jemalloc/jemalloc/src/pa.cpp',
+        'extension/jemalloc/jemalloc/src/pa_extra.cpp',
+        'extension/jemalloc/jemalloc/src/pac.cpp',
+        'extension/jemalloc/jemalloc/src/pages.cpp',
+        'extension/jemalloc/jemalloc/src/pai.cpp',
+        'extension/jemalloc/jemalloc/src/peak_event.cpp',
+        'extension/jemalloc/jemalloc/src/prof.cpp',
+        'extension/jemalloc/jemalloc/src/prof_data.cpp',
+        'extension/jemalloc/jemalloc/src/prof_log.cpp',
+        'extension/jemalloc/jemalloc/src/prof_recent.cpp',
+        'extension/jemalloc/jemalloc/src/prof_stats.cpp',
+        'extension/jemalloc/jemalloc/src/prof_sys.cpp',
+        'extension/jemalloc/jemalloc/src/psset.cpp',
+        'extension/jemalloc/jemalloc/src/rtree.cpp',
+        'extension/jemalloc/jemalloc/src/safety_check.cpp',
+        'extension/jemalloc/jemalloc/src/san.cpp',
+        'extension/jemalloc/jemalloc/src/san_bump.cpp',
+        'extension/jemalloc/jemalloc/src/sc.cpp',
+        'extension/jemalloc/jemalloc/src/sec.cpp',
+        'extension/jemalloc/jemalloc/src/stats.cpp',
+        'extension/jemalloc/jemalloc/src/sz.cpp',
+        'extension/jemalloc/jemalloc/src/tcache.cpp',
+        'extension/jemalloc/jemalloc/src/test_hooks.cpp',
+        'extension/jemalloc/jemalloc/src/thread_event.cpp',
+        'extension/jemalloc/jemalloc/src/ticker.cpp',
+        'extension/jemalloc/jemalloc/src/tsd.cpp',
+        'extension/jemalloc/jemalloc/src/witness.cpp',
+    ]
+]

diff --git a/extension/json/json_config.py b/extension/json/json_config.py
index 91f4fab32f70..d3003e20c7ac 100644
--- a/extension/json/json_config.py
+++ b/extension/json/json_config.py
@@ -1,5 +1,34 @@
 import os
+
 # list all include directories
-include_directories = [os.path.sep.join(x.split('/')) for x in ['extension/json/include', 'extension/json/yyjson/include']]
+include_directories = [
+    os.path.sep.join(x.split('/')) for x in ['extension/json/include', 'extension/json/yyjson/include']
+]
 # source files
-source_files = [os.path.sep.join(x.split('/')) for x in ['extension/json/buffered_json_reader.cpp', 'extension/json/json_extension.cpp', 'extension/json/json_common.cpp', 'extension/json/json_functions.cpp', 'extension/json/json_scan.cpp', 'extension/json/json_functions/copy_json.cpp', 'extension/json/json_functions/json_array_length.cpp', 'extension/json/json_functions/json_contains.cpp', 'extension/json/json_functions/json_extract.cpp', 'extension/json/json_functions/json_keys.cpp', 'extension/json/json_functions/json_merge_patch.cpp', 'extension/json/json_functions/json_structure.cpp', 'extension/json/json_functions/json_transform.cpp', 'extension/json/json_functions/json_create.cpp', 'extension/json/json_functions/json_type.cpp', 'extension/json/json_functions/json_valid.cpp', 'extension/json/json_functions/read_json_objects.cpp', 'extension/json/json_functions/read_json.cpp', 'extension/json/yyjson/yyjson.cpp', 'extension/json/json_functions/json_serialize_sql.cpp', 'extension/json/json_serializer.cpp', 'extension/json/json_deserializer.cpp']]
+source_files = [
+    os.path.sep.join(x.split('/'))
+    for x in [
+        'extension/json/buffered_json_reader.cpp',
+        'extension/json/json_extension.cpp',
+        'extension/json/json_common.cpp',
+        'extension/json/json_functions.cpp',
+        'extension/json/json_scan.cpp',
+        'extension/json/json_functions/copy_json.cpp',
+        'extension/json/json_functions/json_array_length.cpp',
+        'extension/json/json_functions/json_contains.cpp',
+        'extension/json/json_functions/json_extract.cpp',
+        'extension/json/json_functions/json_keys.cpp',
+        'extension/json/json_functions/json_merge_patch.cpp',
+        'extension/json/json_functions/json_structure.cpp',
+        'extension/json/json_functions/json_transform.cpp',
+        'extension/json/json_functions/json_create.cpp',
+        'extension/json/json_functions/json_type.cpp',
+        'extension/json/json_functions/json_valid.cpp',
+        'extension/json/json_functions/read_json_objects.cpp',
+        'extension/json/json_functions/read_json.cpp',
+        'extension/json/yyjson/yyjson.cpp',
+        'extension/json/json_functions/json_serialize_sql.cpp',
+        'extension/json/json_serializer.cpp',
+        'extension/json/json_deserializer.cpp',
+    ]
+]

diff --git a/extension/parquet/parquet_config.py b/extension/parquet/parquet_config.py
index 239a488a3cd1..0848ff7f06ff 100644
--- a/extension/parquet/parquet_config.py
+++ b/extension/parquet/parquet_config.py
@@ -1,9 +1,72 @@
 import os
+
 # list all include directories
-include_directories = [os.path.sep.join(x.split('/')) for x in ['extension/parquet/include', 'third_party/parquet', 'third_party/snappy', 'third_party/thrift', 'third_party/zstd/include']]
+include_directories = [
+    os.path.sep.join(x.split('/'))
+    for x in [
+        'extension/parquet/include',
+        'third_party/parquet',
+        'third_party/snappy',
+        'third_party/thrift',
+        'third_party/zstd/include',
+    ]
+]
 # source files
-source_files = [os.path.sep.join(x.split('/')) for x in ['extension/parquet/parquet_extension.cpp', 'extension/parquet/column_writer.cpp', 'third_party/parquet/parquet_constants.cpp', 'third_party/parquet/parquet_types.cpp', 'third_party/thrift/thrift/protocol/TProtocol.cpp', 'third_party/thrift/thrift/transport/TTransportException.cpp', 'third_party/thrift/thrift/transport/TBufferTransports.cpp', 'third_party/snappy/snappy.cc', 'third_party/snappy/snappy-sinksource.cc']]
+source_files = [
+    os.path.sep.join(x.split('/'))
+    for x in [
+        'extension/parquet/parquet_extension.cpp',
+        'extension/parquet/column_writer.cpp',
+        'third_party/parquet/parquet_constants.cpp',
+        'third_party/parquet/parquet_types.cpp',
+        'third_party/thrift/thrift/protocol/TProtocol.cpp',
+        'third_party/thrift/thrift/transport/TTransportException.cpp',
+        'third_party/thrift/thrift/transport/TBufferTransports.cpp',
+        'third_party/snappy/snappy.cc',
+        'third_party/snappy/snappy-sinksource.cc',
+    ]
+]
 # zstd
-source_files += [os.path.sep.join(x.split('/')) for x in ['third_party/zstd/decompress/zstd_ddict.cpp', 'third_party/zstd/decompress/huf_decompress.cpp', 'third_party/zstd/decompress/zstd_decompress.cpp', 'third_party/zstd/decompress/zstd_decompress_block.cpp', 'third_party/zstd/common/entropy_common.cpp', 'third_party/zstd/common/fse_decompress.cpp', 'third_party/zstd/common/zstd_common.cpp', 'third_party/zstd/common/error_private.cpp', 'third_party/zstd/common/xxhash.cpp']]
-source_files += [os.path.sep.join(x.split('/')) for x in ['third_party/zstd/compress/fse_compress.cpp', 'third_party/zstd/compress/hist.cpp', 'third_party/zstd/compress/huf_compress.cpp', 'third_party/zstd/compress/zstd_compress.cpp', 'third_party/zstd/compress/zstd_compress_literals.cpp', 'third_party/zstd/compress/zstd_compress_sequences.cpp', 'third_party/zstd/compress/zstd_compress_superblock.cpp', 'third_party/zstd/compress/zstd_double_fast.cpp', 'third_party/zstd/compress/zstd_fast.cpp', 'third_party/zstd/compress/zstd_lazy.cpp', 'third_party/zstd/compress/zstd_ldm.cpp', 'third_party/zstd/compress/zstd_opt.cpp']]
-source_files += [os.path.sep.join(x.split('/')) for x in ['extension/parquet/parquet_reader.cpp', 'extension/parquet/parquet_timestamp.cpp', 'extension/parquet/parquet_writer.cpp', 'extension/parquet/column_reader.cpp', 'extension/parquet/parquet_statistics.cpp', 'extension/parquet/parquet_metadata.cpp', 'extension/parquet/zstd_file_system.cpp']]
+source_files += [
+    os.path.sep.join(x.split('/'))
+    for x in [
+        'third_party/zstd/decompress/zstd_ddict.cpp',
+        'third_party/zstd/decompress/huf_decompress.cpp',
+        'third_party/zstd/decompress/zstd_decompress.cpp',
+        'third_party/zstd/decompress/zstd_decompress_block.cpp',
+        'third_party/zstd/common/entropy_common.cpp',
+        'third_party/zstd/common/fse_decompress.cpp',
+        'third_party/zstd/common/zstd_common.cpp',
+        'third_party/zstd/common/error_private.cpp',
+        'third_party/zstd/common/xxhash.cpp',
+    ]
+]
+source_files += [
+    os.path.sep.join(x.split('/'))
+    for x in [
+        'third_party/zstd/compress/fse_compress.cpp',
+        'third_party/zstd/compress/hist.cpp',
+        'third_party/zstd/compress/huf_compress.cpp',
+        'third_party/zstd/compress/zstd_compress.cpp',
+        'third_party/zstd/compress/zstd_compress_literals.cpp',
+        'third_party/zstd/compress/zstd_compress_sequences.cpp',
+        'third_party/zstd/compress/zstd_compress_superblock.cpp',
+        'third_party/zstd/compress/zstd_double_fast.cpp',
+        'third_party/zstd/compress/zstd_fast.cpp',
+        'third_party/zstd/compress/zstd_lazy.cpp',
+        'third_party/zstd/compress/zstd_ldm.cpp',
+        'third_party/zstd/compress/zstd_opt.cpp',
+    ]
+]
+source_files += [
+    os.path.sep.join(x.split('/'))
+    for x in [
+        'extension/parquet/parquet_reader.cpp',
+        'extension/parquet/parquet_timestamp.cpp',
+        'extension/parquet/parquet_writer.cpp',
+        'extension/parquet/column_reader.cpp',
+        'extension/parquet/parquet_statistics.cpp',
+        'extension/parquet/parquet_metadata.cpp',
+        'extension/parquet/zstd_file_system.cpp',
+    ]
+]

diff --git a/extension/tpcds/tpcds_config.py b/extension/tpcds/tpcds_config.py
index f335fa765dc5..fc463364bc0e 100644
--- a/extension/tpcds/tpcds_config.py
+++ b/extension/tpcds/tpcds_config.py
@@ -1,7 +1,73 @@
 import os
+
 # list all include directories
-include_directories = [os.path.sep.join(x.split('/')) for x in ['extension/tpcds/include', 'extension/tpcds/dsdgen/include', 'extension/tpcds/dsdgen/include/dsdgen-c']]
+include_directories = [
+    os.path.sep.join(x.split('/'))
+    for x in ['extension/tpcds/include', 'extension/tpcds/dsdgen/include', 'extension/tpcds/dsdgen/include/dsdgen-c']
+]
 # source files
 source_files = [os.path.sep.join(x.split('/')) for x in ['extension/tpcds/tpcds_extension.cpp']]
-source_files += [os.path.sep.join(x.split('/')) for x in ['extension/tpcds/dsdgen/dsdgen.cpp', 'extension/tpcds/dsdgen/append_info-c.cpp', 'extension/tpcds/dsdgen/dsdgen_helpers.cpp']]
-source_files += [os.path.sep.join(x.split('/')) for x in ['extension/tpcds/dsdgen/dsdgen-c/skip_days.cpp', 'extension/tpcds/dsdgen/dsdgen-c/address.cpp', 'extension/tpcds/dsdgen/dsdgen-c/build_support.cpp', 'extension/tpcds/dsdgen/dsdgen-c/date.cpp', 'extension/tpcds/dsdgen/dsdgen-c/dbgen_version.cpp', 'extension/tpcds/dsdgen/dsdgen-c/decimal.cpp', 'extension/tpcds/dsdgen/dsdgen-c/dist.cpp', 'extension/tpcds/dsdgen/dsdgen-c/error_msg.cpp', 'extension/tpcds/dsdgen/dsdgen-c/genrand.cpp', 'extension/tpcds/dsdgen/dsdgen-c/join.cpp', 'extension/tpcds/dsdgen/dsdgen-c/list.cpp', 'extension/tpcds/dsdgen/dsdgen-c/load.cpp', 'extension/tpcds/dsdgen/dsdgen-c/misc.cpp', 'extension/tpcds/dsdgen/dsdgen-c/nulls.cpp', 'extension/tpcds/dsdgen/dsdgen-c/parallel.cpp', 'extension/tpcds/dsdgen/dsdgen-c/permute.cpp', 'extension/tpcds/dsdgen/dsdgen-c/pricing.cpp', 'extension/tpcds/dsdgen/dsdgen-c/r_params.cpp', 'extension/tpcds/dsdgen/dsdgen-c/release.cpp', 'extension/tpcds/dsdgen/dsdgen-c/scaling.cpp', 'extension/tpcds/dsdgen/dsdgen-c/scd.cpp', 'extension/tpcds/dsdgen/dsdgen-c/sparse.cpp', 'extension/tpcds/dsdgen/dsdgen-c/StringBuffer.cpp', 'extension/tpcds/dsdgen/dsdgen-c/tdef_functions.cpp', 'extension/tpcds/dsdgen/dsdgen-c/tdefs.cpp', 'extension/tpcds/dsdgen/dsdgen-c/text.cpp', 'extension/tpcds/dsdgen/dsdgen-c/w_call_center.cpp', 'extension/tpcds/dsdgen/dsdgen-c/w_catalog_page.cpp', 'extension/tpcds/dsdgen/dsdgen-c/w_catalog_returns.cpp', 'extension/tpcds/dsdgen/dsdgen-c/w_catalog_sales.cpp', 'extension/tpcds/dsdgen/dsdgen-c/w_customer.cpp', 'extension/tpcds/dsdgen/dsdgen-c/w_customer_address.cpp', 'extension/tpcds/dsdgen/dsdgen-c/w_customer_demographics.cpp', 'extension/tpcds/dsdgen/dsdgen-c/w_datetbl.cpp', 'extension/tpcds/dsdgen/dsdgen-c/w_household_demographics.cpp', 'extension/tpcds/dsdgen/dsdgen-c/w_income_band.cpp', 'extension/tpcds/dsdgen/dsdgen-c/w_inventory.cpp', 'extension/tpcds/dsdgen/dsdgen-c/w_item.cpp', 'extension/tpcds/dsdgen/dsdgen-c/w_promotion.cpp', 'extension/tpcds/dsdgen/dsdgen-c/w_reason.cpp', 'extension/tpcds/dsdgen/dsdgen-c/w_ship_mode.cpp', 'extension/tpcds/dsdgen/dsdgen-c/w_store.cpp', 'extension/tpcds/dsdgen/dsdgen-c/w_store_returns.cpp', 'extension/tpcds/dsdgen/dsdgen-c/w_store_sales.cpp', 'extension/tpcds/dsdgen/dsdgen-c/w_timetbl.cpp', 'extension/tpcds/dsdgen/dsdgen-c/w_warehouse.cpp', 'extension/tpcds/dsdgen/dsdgen-c/w_web_page.cpp', 'extension/tpcds/dsdgen/dsdgen-c/w_web_returns.cpp', 'extension/tpcds/dsdgen/dsdgen-c/w_web_sales.cpp', 'extension/tpcds/dsdgen/dsdgen-c/w_web_site.cpp', 'extension/tpcds/dsdgen/dsdgen-c/init.cpp']]
+source_files += [
+    os.path.sep.join(x.split('/'))
+    for x in [
+        'extension/tpcds/dsdgen/dsdgen.cpp',
+        'extension/tpcds/dsdgen/append_info-c.cpp',
+        'extension/tpcds/dsdgen/dsdgen_helpers.cpp',
+    ]
+]
+source_files += [
+    os.path.sep.join(x.split('/'))
+    for x in [
+        'extension/tpcds/dsdgen/dsdgen-c/skip_days.cpp',
+        'extension/tpcds/dsdgen/dsdgen-c/address.cpp',
+        'extension/tpcds/dsdgen/dsdgen-c/build_support.cpp',
+        'extension/tpcds/dsdgen/dsdgen-c/date.cpp',
+        'extension/tpcds/dsdgen/dsdgen-c/dbgen_version.cpp',
+        'extension/tpcds/dsdgen/dsdgen-c/decimal.cpp',
+        'extension/tpcds/dsdgen/dsdgen-c/dist.cpp',
+        'extension/tpcds/dsdgen/dsdgen-c/error_msg.cpp',
+        'extension/tpcds/dsdgen/dsdgen-c/genrand.cpp',
+        'extension/tpcds/dsdgen/dsdgen-c/join.cpp',
+        'extension/tpcds/dsdgen/dsdgen-c/list.cpp',
+        'extension/tpcds/dsdgen/dsdgen-c/load.cpp',
+        'extension/tpcds/dsdgen/dsdgen-c/misc.cpp',
+        'extension/tpcds/dsdgen/dsdgen-c/nulls.cpp',
+        'extension/tpcds/dsdgen/dsdgen-c/parallel.cpp',
+        'extension/tpcds/dsdgen/dsdgen-c/permute.cpp',
+        'extension/tpcds/dsdgen/dsdgen-c/pricing.cpp',
+        'extension/tpcds/dsdgen/dsdgen-c/r_params.cpp',
+        'extension/tpcds/dsdgen/dsdgen-c/release.cpp',
+        'extension/tpcds/dsdgen/dsdgen-c/scaling.cpp',
+        'extension/tpcds/dsdgen/dsdgen-c/scd.cpp',
+        'extension/tpcds/dsdgen/dsdgen-c/sparse.cpp',
+        'extension/tpcds/dsdgen/dsdgen-c/StringBuffer.cpp',
+        'extension/tpcds/dsdgen/dsdgen-c/tdef_functions.cpp',
+        'extension/tpcds/dsdgen/dsdgen-c/tdefs.cpp',
+        'extension/tpcds/dsdgen/dsdgen-c/text.cpp',
+        'extension/tpcds/dsdgen/dsdgen-c/w_call_center.cpp',
+        'extension/tpcds/dsdgen/dsdgen-c/w_catalog_page.cpp',
+        'extension/tpcds/dsdgen/dsdgen-c/w_catalog_returns.cpp',
+        'extension/tpcds/dsdgen/dsdgen-c/w_catalog_sales.cpp',
+        'extension/tpcds/dsdgen/dsdgen-c/w_customer.cpp',
+        'extension/tpcds/dsdgen/dsdgen-c/w_customer_address.cpp',
+        'extension/tpcds/dsdgen/dsdgen-c/w_customer_demographics.cpp',
+        'extension/tpcds/dsdgen/dsdgen-c/w_datetbl.cpp',
+        'extension/tpcds/dsdgen/dsdgen-c/w_household_demographics.cpp',
+        'extension/tpcds/dsdgen/dsdgen-c/w_income_band.cpp',
+        'extension/tpcds/dsdgen/dsdgen-c/w_inventory.cpp',
+        'extension/tpcds/dsdgen/dsdgen-c/w_item.cpp',
+        'extension/tpcds/dsdgen/dsdgen-c/w_promotion.cpp',
+        'extension/tpcds/dsdgen/dsdgen-c/w_reason.cpp',
+        'extension/tpcds/dsdgen/dsdgen-c/w_ship_mode.cpp',
+        'extension/tpcds/dsdgen/dsdgen-c/w_store.cpp',
+        'extension/tpcds/dsdgen/dsdgen-c/w_store_returns.cpp',
+        'extension/tpcds/dsdgen/dsdgen-c/w_store_sales.cpp',
+        'extension/tpcds/dsdgen/dsdgen-c/w_timetbl.cpp',
+        'extension/tpcds/dsdgen/dsdgen-c/w_warehouse.cpp',
+        'extension/tpcds/dsdgen/dsdgen-c/w_web_page.cpp',
+        'extension/tpcds/dsdgen/dsdgen-c/w_web_returns.cpp',
+        'extension/tpcds/dsdgen/dsdgen-c/w_web_sales.cpp',
+        'extension/tpcds/dsdgen/dsdgen-c/w_web_site.cpp',
+        'extension/tpcds/dsdgen/dsdgen-c/init.cpp',
+    ]
+]

diff --git a/extension/tpch/tpch_config.py b/extension/tpch/tpch_config.py
index dfdc833c4128..4155ebe2732f 100644
--- a/extension/tpch/tpch_config.py
+++ b/extension/tpch/tpch_config.py
@@ -1,5 +1,22 @@
 import os
+
 # list all include directories
-include_directories = [os.path.sep.join(x.split('/')) for x in ['extension/tpch/include', 'extension/tpch/dbgen/include']]
+include_directories = [
+    os.path.sep.join(x.split('/')) for x in ['extension/tpch/include', 'extension/tpch/dbgen/include']
+]
 # source files
-source_files = [os.path.sep.join(x.split('/')) for x in ['extension/tpch/tpch_extension.cpp', 'extension/tpch/dbgen/bm_utils.cpp', 'extension/tpch/dbgen/build.cpp', 'extension/tpch/dbgen/dbgen.cpp', 'extension/tpch/dbgen/dbgen_gunk.cpp', 'extension/tpch/dbgen/permute.cpp', 'extension/tpch/dbgen/rnd.cpp', 'extension/tpch/dbgen/rng64.cpp', 'extension/tpch/dbgen/speed_seed.cpp', 'extension/tpch/dbgen/text.cpp']]
+source_files = [
+    os.path.sep.join(x.split('/'))
+    for x in [
+        'extension/tpch/tpch_extension.cpp',
+        'extension/tpch/dbgen/bm_utils.cpp',
+        'extension/tpch/dbgen/build.cpp',
+        'extension/tpch/dbgen/dbgen.cpp',
+        'extension/tpch/dbgen/dbgen_gunk.cpp',
+        'extension/tpch/dbgen/permute.cpp',
+        'extension/tpch/dbgen/rnd.cpp',
+        'extension/tpch/dbgen/rng64.cpp',
+        'extension/tpch/dbgen/speed_seed.cpp',
+        'extension/tpch/dbgen/text.cpp',
+    ]
+]

diff --git a/extension/visualizer/generate_visualizer_header.py b/extension/visualizer/generate_visualizer_header.py
index f24b88cb32a8..d57ec41cf972 100644
--- a/extension/visualizer/generate_visualizer_header.py
+++ b/extension/visualizer/generate_visualizer_header.py
@@ -7,14 +7,17 @@
 visualizer_script = os.path.join(visualizer_dir, 'script.js')
 visualizer_header = os.path.join(visualizer_dir, 'include', 'visualizer_constants.hpp')

+
 def open_utf8(fpath, flags):
     import sys
+
     if sys.version_info[0] < 3:
         return open(fpath, flags)
     else:
         return open(fpath, flags, encoding="utf8")

-def get_byte_array(fpath, add_null_terminator = True):
+
+def get_byte_array(fpath, add_null_terminator=True):
     with open(fpath, 'rb') as f:
         text = bytearray(f.read())
     result_text = ""
@@ -29,6 +32,7 @@ def get_byte_array(fpath, add_null_terminator = True):
             result_text += ", 0"
     return result_text

+
 def write_file(fname, varname):
     result = "const uint8_t %s[] = {" % (varname,) + get_byte_array(fname) + "};\n"
     return result
@@ -78,4 +82,5 @@ def create_visualizer_header():
     with open_utf8(visualizer_header, 'w+') as f:
         f.write(result)

-create_visualizer_header()
\ No newline at end of file
+
+create_visualizer_header()

diff --git a/extension/visualizer/visualizer_config.py b/extension/visualizer/visualizer_config.py
index 6f14a4080e22..d7c2456810dd 100644
--- a/extension/visualizer/visualizer_config.py
+++ b/extension/visualizer/visualizer_config.py
@@ -1,4 +1,5 @@
 import os
+
 # list all include directories
 include_directories = [os.path.sep.join(x.split('/')) for x in ['extension/visualizer/include']]
 # source files

diff --git a/test/memoryleak/test_memory_leaks.py b/test/memoryleak/test_memory_leaks.py
index a0aac2400a8f..96ed298b24da 100644
--- a/test/memoryleak/test_memory_leaks.py
+++ b/test/memoryleak/test_memory_leaks.py
@@ -5,18 +5,36 @@
 parser = argparse.ArgumentParser(description='Runs the memory leak tests')
-parser.add_argument('--unittest', dest='unittest',
-                    action='store', help='Path to unittest executable', default='build/release/test/unittest')
-parser.add_argument('--test', dest='test',
-                    action='store', help='The name of the tests to run (* is all)', default='*')
-parser.add_argument('--timeout', dest='timeout',
-                    action='store', help='The maximum time to run the test and measure memory usage (in seconds)', default=60)
-parser.add_argument('--threshold-percentage', dest='threshold_percentage',
-                    action='store', help='The percentage threshold before we consider an increase a regression', default=0.01)
-parser.add_argument('--threshold-absolute', dest='threshold_absolute',
-                    action='store', help='The absolute threshold before we consider an increase a regression', default=1000)
-parser.add_argument('--verbose', dest='verbose',
-                    action='store', help='Verbose output', default=True)
+parser.add_argument(
+    '--unittest',
+    dest='unittest',
+    action='store',
+    help='Path to unittest executable',
+    default='build/release/test/unittest',
+)
+parser.add_argument('--test', dest='test', action='store', help='The name of the tests to run (* is all)', default='*')
+parser.add_argument(
+    '--timeout',
+    dest='timeout',
+    action='store',
+    help='The maximum time to run the test and measure memory usage (in seconds)',
+    default=60,
+)
+parser.add_argument(
+    '--threshold-percentage',
+    dest='threshold_percentage',
+    action='store',
+    help='The percentage threshold before we consider an increase a regression',
+    default=0.01,
+)
+parser.add_argument(
+    '--threshold-absolute',
+    dest='threshold_absolute',
+    action='store',
+    help='The absolute threshold before we consider an increase a regression',
+    default=1000,
+)
+parser.add_argument('--verbose', dest='verbose', action='store', help='Verbose output', default=True)

 args = parser.parse_args()
@@ -53,6 +71,7 @@
     print(f"No tests matching filter \"{test_filter}\" found")
     exit(0)

+
 def sizeof_fmt(num, suffix="B"):
     for unit in ["", "K", "M", "G", "T", "P", "E", "Z"]:
         if abs(num) < 1000.0:
@@ -60,9 +79,12 @@ def sizeof_fmt(num, suffix="B"):
             num /= 1000.0
     return f"{num:.1f}Yi{suffix}"

+
 def run_test(test_case):
     # launch the unittest program
-    proc = subprocess.Popen([unittest_program, test_case, '--memory-leak-tests'], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+    proc = subprocess.Popen(
+        [unittest_program, test_case, '--memory-leak-tests'], stdout=subprocess.PIPE, stderr=subprocess.PIPE
+    )
     pid = proc.pid

     # capture the memory output for the duration of the program running
@@ -125,6 +147,7 @@ def has_memory_leak(rss):
     sum_differences = sum(differences[-measurement_count:])
     return sum_differences > (max_memory * args.threshold_percentage + args.threshold_absolute)

+
 try:
     for index, test in enumerate(test_cases):
         print(f"[{index}/{len(test_cases)}] {test}")

diff --git a/test/parquet/generate_parquet_test.py b/test/parquet/generate_parquet_test.py
index 7fcd452afea6..075ece1226b4 100644
--- a/test/parquet/generate_parquet_test.py
+++ b/test/parquet/generate_parquet_test.py
@@ -1,15 +1,19 @@
 import duckdb
 import os
 import sys
+
 try:
     import pyarrow
     import pyarrow.parquet
+
     can_run = True
 except:
     can_run = False

+
 def generate_header(f):
-    f.write('''# name: test/parquet/test_parquet_reader.test
+    f.write(
+        '''# name: test/parquet/test_parquet_reader.test
 # description: Test Parquet Reader with files on data/parquet-testing
 # group: [parquet]

 statement ok
 PRAGMA enable_verification

-''')
+'''
+    )
+

 def get_files():
     files_path = []
     path = os.path.dirname(os.path.realpath(__file__))
-    path = os.path.join(path,'..','..')
+    path = os.path.join(path, '..', '..')
     os.chdir(path)
-    path = os.path.join('data','parquet-testing')
+    path = os.path.join('data', 'parquet-testing')
     for root, dirs, files in os.walk(path):
         for file in files:
-            if file.endswith(".parquet"): 
+            if file.endswith(".parquet"):
                 files_path.append(os.path.join(root, file))
     return files_path

+
 def get_duckdb_answer(file_path):
     answer = []
     try:
-        answer = duckdb.query("SELECT * FROM parquet_scan('"+file_path+"') limit 50").fetchall()
+        answer = duckdb.query("SELECT * FROM parquet_scan('" + file_path + "') limit 50").fetchall()
     except Exception as e:
         print(e)
         answer = 'fail'
     return answer

+
 def get_arrow_answer(file_path):
     answer = []
     try:
@@ -51,60 +59,60 @@
             return answer
     except:
         return 'fail'
-
 def check_result(duckdb_result, arrow_result):
-    if (arrow_result == 'fail'):
+    if arrow_result == 'fail':
         return 'skip'
-    if (duckdb_result == 'fail'):
+    if duckdb_result == 'fail':
         return 'fail'
-    if (duckdb_result != arrow_result):
+    if duckdb_result != arrow_result:
         return 'fail'
     return 'pass'

+
 def sanitize_string(s):
-    return str(s).replace('None','NULL').replace("b'","").replace("'","")
+    return str(s).replace('None', 'NULL').replace("b'", "").replace("'", "")

+
 def result_to_string(arrow_result):
     result = ''
     for row_idx in range(len(arrow_result)):
-        for col_idx in range(len(arrow_result[0])): 
+        for col_idx in range(len(arrow_result[0])):
             value = arrow_result[row_idx][col_idx]
             if isinstance(value, dict):
-                items = [
-                    f"'{k}': {sanitize_string(v)}" # no quotes
-                    for k, v in value.items()
-                ]
+                items = [f"'{k}': {sanitize_string(v)}" for k, v in value.items()]  # no quotes
                 value = "{" + ", ".join(items) + "}"
                 print(type(value), value)
             else:
                 value = sanitize_string(value)
             result += value + "\t"
-        result +="\n"
+        result += "\n"
     result += "\n"
     return result

-def generate_parquet_test_body(result, arrow_result,file_path):
-    columns = 'I'*len(arrow_result[0])
-    test_body = "query " + columns + "\n"
-    test_body += "SELECT * FROM parquet_scan('"+file_path+"') limit 50 \n"
-    test_body += "----\n"
+
+def generate_parquet_test_body(result, arrow_result, file_path):
+    columns = 'I' * len(arrow_result[0])
+    test_body = "query " + columns + "\n"
+    test_body += "SELECT * FROM parquet_scan('" + file_path + "') limit 50 \n"
+    test_body += "----\n"
     test_body += result_to_string(arrow_result)
     return test_body

+
 def generate_test(file_path):
     duckdb_result = get_duckdb_answer(file_path)
     arrow_result = get_arrow_answer(file_path)
-    result = check_result(duckdb_result,arrow_result)
+    result = check_result(duckdb_result, arrow_result)
     test_body = ""
-    if (result == 'skip'):
+    if result == 'skip':
         return
-    if (result == 'fail'):
+    if result == 'fail':
         test_body += "mode skip \n\n"
-        test_body += generate_parquet_test_body(result,arrow_result,file_path)
+        test_body += generate_parquet_test_body(result, arrow_result, file_path)
         test_body += "mode unskip \n\n"
     else:
-        test_body += generate_parquet_test_body(result,duckdb_result,file_path)
+        test_body += generate_parquet_test_body(result, duckdb_result, file_path)
     return test_body
@@ -116,6 +124,7 @@ def generate_body(f):
         if test_body != None:
             f.write(test_body)

+
 f = open("test_parquet_reader.test", "w")

 generate_header(f)

diff --git a/test/sqlserver/scrape.py b/test/sqlserver/scrape.py
index 104de437bcb3..32d89f75c54a 100644
--- a/test/sqlserver/scrape.py
+++ b/test/sqlserver/scrape.py
@@ -3,91 +3,109 @@
 import re

-
-pages = ['avg', 'count', 'max', 'min', 'stdev', 'sum', 'var', 'cume-dist', 'first-value', 'last-value', 'lag', 'lead', 'percent-rank', 'dense-rank', 'ntile', 'rank', 'row-number']
+pages = [
+    'avg',
+    'count',
+    'max',
+    'min',
+    'stdev',
+    'sum',
+    'var',
+    'cume-dist',
+    'first-value',
+    'last-value',
+    'lag',
+    'lead',
+    'percent-rank',
+    'dense-rank',
+    'ntile',
+    'rank',
+    'row-number',
+]

 # crash!
 # url = 'https://docs.microsoft.com/en-us/sql/t-sql/functions/first-value-transact-sql?view=sql-server-2017'
 url = 'https://docs.microsoft.com/en-us/sql/t-sql/functions/%s-transact-sql?view=sql-server-2017'

+
 def transform_result_set(tblstr):
-	# find the row with the --- --- etc that indicates col count
-	cols = len(re.findall(r"-{3,}( |\n)", tblstr))
-	if (cols < 1):
-		return ""
-	print("REQUIRE(result->ColumnCount() == %d);" % cols)
-	lineiterator = iter(tblstr.splitlines())
-	in_data = False
-	result = []
-	for c in range(0, cols):
-		result.append([])
-
-
-	for line in lineiterator:
-		if '---' in line:
-			in_data = True
-			continue
-		if not in_data:
-			continue
-		if 'row(s) affected' in line:
-			continue
-		if re.match(r"^\s*$", line):
-			continue
-
-		# now we have a real data line, split by space and trim
-		fields = re.split(r"\s{2,}",line)
-		if len(fields) < cols:
-			raise ValueError('Not enough fields')
-
-		# print("// ", end='')
-		for c in range(0, cols):
-			f = fields[c].strip()
-			if f == '':
-				raise ValueError('Empty field')
-			if re.match(r"^\d+[\d,]*\.\d+$", f):
-				f = f.replace(',', '')
-			# print(f + '\t', end='')
-			needs_quotes = False
-			try:
-				float(f)
-			except ValueError:
-				needs_quotes = True
-
-			if (f == "NULL"):
-				f = 'Value()'
-				needs_quotes = False
-
-			if needs_quotes:
-				f = '"%s"' % f
-			result[c].append(f)
-		#print()
-
-	for c in range(0, cols):
-		print('REQUIRE(CHECK_COLUMN(result, %d, {%s}));' % (c, ','.join(result[c])))
+    # find the row with the --- --- etc that indicates col count
+    cols = len(re.findall(r"-{3,}( |\n)", tblstr))
+    if cols < 1:
+        return ""
+    print("REQUIRE(result->ColumnCount() == %d);" % cols)
+    lineiterator = iter(tblstr.splitlines())
+    in_data = False
+    result = []
+    for c in range(0, cols):
+        result.append([])
+
+    for line in lineiterator:
+        if '---' in line:
+            in_data = True
+            continue
+        if not in_data:
+            continue
+        if 'row(s) affected' in line:
+            continue
+        if re.match(r"^\s*$", line):
+            continue
+
+        # now we have a real data line, split by space and trim
+        fields = re.split(r"\s{2,}", line)
+        if len(fields) < cols:
+            raise ValueError('Not enough fields')
+
+        # print("// ", end='')
+        for c in range(0, cols):
+            f = fields[c].strip()
+            if f == '':
+                raise ValueError('Empty field')
+            if re.match(r"^\d+[\d,]*\.\d+$", f):
+                f = f.replace(',', '')
+            # print(f + '\t', end='')
+            needs_quotes = False
+            try:
+                float(f)
+            except ValueError:
+                needs_quotes = True
+
+            if f == "NULL":
+                f = 'Value()'
+                needs_quotes = False
+
+            if needs_quotes:
+                f = '"%s"' % f
+            result[c].append(f)
+        # print()
+
+    for c in range(0, cols):
+        print('REQUIRE(CHECK_COLUMN(result, %d, {%s}));' % (c, ','.join(result[c])))
+

 for p in pages:
-	r = requests.get(url % p)
-	print('\n\n// FROM %s\n' % url % p)
-	soup = BeautifulSoup(r.content, 'html.parser')
-	look_for_answer = False
-
-	for code in soup.find_all('code'):
-		classes = code.get('class')
-		text = code.get_text()
-		if (text.count('\n') < 2):
-			continue
-		if ('SELECT ' in text and 'FROM ' in text):
-			if ('dbo.' in text or 'sys.' in text):
-				continue
-			query = text.strip()
-			query = re.sub(r"(^|\n)(GO|USE|DECLARE).*", "", query)
-			query = query.replace('\n', ' ')
-			query = re.sub(r"\s+", " ", query)
-
-			print('\n\nresult = con.Query("%s");\nREQUIRE(!result->HasError());' % query.replace('"', '\\"'))
-
-			look_for_answer = True
-		elif look_for_answer:
-			#print('-- ' + text.replace('\n', '\n-- ') + '\n')
-			transform_result_set(text)
-			look_for_answer = False
+    r = requests.get(url % p)
+    print('\n\n// FROM %s\n' % url % p)
+    soup = BeautifulSoup(r.content, 'html.parser')
+    look_for_answer = False
+
+    for code in soup.find_all('code'):
+        classes = code.get('class')
+        text = code.get_text()
+        if text.count('\n') < 2:
+            continue
+        if 'SELECT ' in text and 'FROM ' in text:
+            if 'dbo.' in text or 'sys.' in text:
+                continue
+            query = text.strip()
+            query = re.sub(r"(^|\n)(GO|USE|DECLARE).*", "", query)
+            query = query.replace('\n', ' ')
+            query = re.sub(r"\s+", " ", query)
+
+            print('\n\nresult = con.Query("%s");\nREQUIRE(!result->HasError());' % query.replace('"', '\\"'))
+
+            look_for_answer = True
+        elif look_for_answer:
+            # print('-- ' + text.replace('\n', '\n-- ') + '\n')
+            transform_result_set(text)
+            look_for_answer = False

diff --git a/tools/juliapkg/release.py b/tools/juliapkg/release.py
index 12f887570135..8781c00d0f8f 100644
--- a/tools/juliapkg/release.py
+++ b/tools/juliapkg/release.py
@@ -4,15 +4,21 @@
 import re

 parser = argparse.ArgumentParser(description='Publish a Julia release.')
-parser.add_argument('--yggdrassil-fork', dest='yggdrassil',
-                    action='store', help='Fork of the Julia Yggdrassil repository (https://github.com/JuliaPackaging/Yggdrasil)', default='/Users/myth/Programs/Yggdrasil')
+parser.add_argument(
+    '--yggdrassil-fork',
+    dest='yggdrassil',
+    action='store',
+    help='Fork of the Julia Yggdrassil repository (https://github.com/JuliaPackaging/Yggdrasil)',
+    default='/Users/myth/Programs/Yggdrasil',
+)
 args = parser.parse_args()

 if not os.path.isfile(os.path.join('tools', 'juliapkg', 'release.py')):
     print('This script must be run from the root DuckDB directory (i.e. `python3 tools/juliapkg/release.py`)')
     exit(1)

-def run_syscall(syscall, ignore_failure = False):
+
+def run_syscall(syscall, ignore_failure=False):
     res = os.system(syscall)
     if ignore_failure:
         return
@@ -20,6 +26,7 @@ def run_syscall(syscall, ignore_failure = False):
         print(f'Failed to execute {syscall}: got exit code {str(res)}')
         exit(1)

+
 # helper script to generate a julia release
 duckdb_path = os.getcwd()
@@ -51,7 +58,11 @@ def run_syscall(syscall, ignore_failure = False):
     text = f.read()

 text = re.sub('\nversion = v["][0-9.]+["]\n', f'\nversion = v"{tag[1:]}"\n', text)
-text = re.sub('GitSource[(]["]https[:][/][/]github[.]com[/]duckdb[/]duckdb[.]git["][,] ["][a-zA-Z0-9]+["][)]', f'GitSource("https://github.com/duckdb/duckdb.git", "{hash}")', text)
+text = re.sub(
+    'GitSource[(]["]https[:][/][/]github[.]com[/]duckdb[/]duckdb[.]git["][,] ["][a-zA-Z0-9]+["][)]',
+    f'GitSource("https://github.com/duckdb/duckdb.git", "{hash}")',
+    text,
+)

 with open(tarball_build, 'w+') as f:
     f.write(text)
@@ -59,10 +70,14 @@ def run_syscall(syscall, ignore_failure = False):
 run_syscall(f'git add {tarball_build}')
 run_syscall(f'git commit -m "[DuckDB] Bump to {tag}"')
 run_syscall(f'git push --set-upstream origin {tag}')
-run_syscall(f'gh pr create --title "[DuckDB] Bump to {tag}" --repo "https://github.com/JuliaPackaging/Yggdrasil" --body ""')
+run_syscall(
+    f'gh pr create --title "[DuckDB] Bump to {tag}" --repo "https://github.com/JuliaPackaging/Yggdrasil" --body ""'
+)

 print('PR has been created.\n')
 print(f'Next up we need to bump the version and DuckDB_jll version to {tag} in `tools/juliapkg/Project.toml`')
 print('This is not yet automated.')
-print('> After that PR is merged - we need to post a comment containing the text `@JuliaRegistrator register subdir=tools/juliapkg`')
+print(
+    '> After that PR is merged - we need to post a comment containing the text `@JuliaRegistrator register subdir=tools/juliapkg`'
+)
 print('> For example, see https://github.com/duckdb/duckdb/commit/0f0461113f3341135471805c9928c4d71d1f5874')

diff --git a/tools/nodejs/configure.py b/tools/nodejs/configure.py
index 754fd090f34a..e4d1b39eae96 100644
--- a/tools/nodejs/configure.py
+++ b/tools/nodejs/configure.py
@@ -31,6 +31,7 @@
     windows_options = cache['windows_options']
     cflags = cache['cflags']
 elif 'DUCKDB_NODE_BINDIR' in os.environ:
+
     def find_library_path(libdir, libname):
         flist = os.listdir(libdir)
         for fname in flist:
@@ -38,6 +39,7 @@ def find_library_path(libdir, libname):
             if os.path.isfile(fpath) and package_build.file_is_lib(fname, libname):
                 return fpath
         raise Exception(f"Failed to find library {libname} in {libdir}")
+
     # existing build
     existing_duckdb_dir = os.environ['DUCKDB_NODE_BINDIR']
     cflags = os.environ['DUCKDB_NODE_CFLAGS']
@@ -48,7 +50,7 @@ def find_library_path(libdir, libname):
     result_libraries = package_build.get_libraries(existing_duckdb_dir, libraries, extensions)

     libraries = []
-    for (libdir, libname) in result_libraries:
+    for libdir, libname in result_libraries:
         if libdir is None:
             continue
         libraries.append(find_library_path(libdir, libname))
@@ -72,7 +74,7 @@ def find_library_path(libdir, libname):
         'include_list': include_list,
         'libraries': libraries,
         'cflags': cflags,
-        'windows_options': windows_options
+        'windows_options': windows_options,
     }
     with open(cache_file, 'wb+') as f:
         pickle.dump(cache, f)
@@ -90,9 +92,11 @@ def find_library_path(libdir, libname):
     windows_options = ['/GR']
     cflags = ['-frtti']

+
 def sanitize_path(x):
     return x.replace('\\', '/')

+
 source_list = [sanitize_path(x) for x in
source_list] include_list = [sanitize_path(x) for x in include_list] libraries = [sanitize_path(x) for x in libraries] @@ -100,6 +104,7 @@ def sanitize_path(x): with open(gyp_in, 'r') as f: input_json = json.load(f) + def replace_entries(node, replacement_map): if type(node) == type([]): for key in replacement_map.keys(): diff --git a/tools/odbc/test/isql-test.py b/tools/odbc/test/isql-test.py index 23e36c46ed28..5d551b61097f 100644 --- a/tools/odbc/test/isql-test.py +++ b/tools/odbc/test/isql-test.py @@ -5,50 +5,54 @@ import shutil if len(sys.argv) < 2: - raise Exception('need shell binary as parameter') + raise Exception('need shell binary as parameter') -extra_parameter="" +extra_parameter = "" if len(sys.argv) == 3: - extra_parameter = sys.argv[2] + extra_parameter = sys.argv[2] + def test_exception(command, input, stdout, stderr, errmsg): - print('--- COMMAND --') - print(' '.join(command)) - print('--- INPUT --') - print(input) - print('--- STDOUT --') - print(stdout) - print('--- STDERR --') - print(stderr) - raise Exception(errmsg) + print('--- COMMAND --') + print(' '.join(command)) + print('--- INPUT --') + print(input) + print('--- STDOUT --') + print(stdout) + print('--- STDERR --') + print(stderr) + raise Exception(errmsg) + def test(cmd, out=None, err=None, extra_commands=None, input_file=None): - ######### isql "DSN=DuckDB;Database=test.db" -k -b -d'|' /dev/null - command = [sys.argv[1], "DSN=DuckDB;Database=test.db", '-k', '-b', '-d|', '/dev/null'] - if extra_parameter: - command.append(extra_parameter) + ######### isql "DSN=DuckDB;Database=test.db" -k -b -d'|' /dev/null + command = [sys.argv[1], "DSN=DuckDB;Database=test.db", '-k', '-b', '-d|', '/dev/null'] + if extra_parameter: + command.append(extra_parameter) - if extra_commands: - command += extra_commands - if input_file: - command += [cmd] - res = subprocess.run(command, input=open(input_file, 'rb').read(), stdout=subprocess.PIPE, stderr=subprocess.PIPE) - else: - res = subprocess.run(command, input=bytearray(cmd, 'utf8'), stdout=subprocess.PIPE, stderr=subprocess.PIPE) - stdout = res.stdout.decode('utf8').strip() - stderr = res.stderr.decode('utf8').strip() + if extra_commands: + command += extra_commands + if input_file: + command += [cmd] + res = subprocess.run( + command, input=open(input_file, 'rb').read(), stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) + else: + res = subprocess.run(command, input=bytearray(cmd, 'utf8'), stdout=subprocess.PIPE, stderr=subprocess.PIPE) + stdout = res.stdout.decode('utf8').strip() + stderr = res.stderr.decode('utf8').strip() - if out and out not in stdout: - test_exception(command, cmd, stdout, stderr, 'out test failed') + if out and out not in stdout: + test_exception(command, cmd, stdout, stderr, 'out test failed') - if err and err not in stderr: - test_exception(command, cmd, stdout, stderr, 'err test failed') + if err and err not in stderr: + test_exception(command, cmd, stdout, stderr, 'err test failed') - if not err and stderr != '': - test_exception(command, cmd, stdout, stderr, 'got err test failed') + if not err and stderr != '': + test_exception(command, cmd, stdout, stderr, 'got err test failed') - if err is None and res.returncode != 0: - test_exception(command, cmd, stdout, stderr, 'process returned non-zero exit code but no error was specified') + if err is None and res.returncode != 0: + test_exception(command, cmd, stdout, stderr, 'process returned non-zero exit code but no error was specified') # basic tests @@ -56,70 +60,66 @@ def test(cmd, out=None, 
err=None, extra_commands=None, input_file=None): test('select 42, 43, 44;', out="42|43|44") -test("""CREATE TABLE people(id INTEGER, name VARCHAR); +test( + """CREATE TABLE people(id INTEGER, name VARCHAR); INSERT INTO people VALUES (1, 'Mark'), (2, 'Hannes'); SELECT * FROM people;""", -out=( -"1|Mark\n" -"2|Hannes") + out=("1|Mark\n" "2|Hannes"), ) -range_out="" +range_out = "" for i in range(10000): - if(i < 9999): - range_out += str(i) + "\n" - else: - range_out += str(i) -test("SELECT * FROM range(10000);", out = range_out) + if i < 9999: + range_out += str(i) + "\n" + else: + range_out += str(i) +test("SELECT * FROM range(10000);", out=range_out) # ### FROM test/sql/projection/test_simple_projection.test ################################# -test("""PRAGMA enable_verification +test( + """PRAGMA enable_verification CREATE TABLE a (i integer, j integer); SELECT * FROM a; -""", out="") +""", + out="", +) -test("""INSERT INTO a VALUES (42, 84); +test( + """INSERT INTO a VALUES (42, 84); SELECT * FROM a; -""", out="42|84") +""", + out="42|84", +) -test("""CREATE TABLE test (a INTEGER, b INTEGER); +test( + """CREATE TABLE test (a INTEGER, b INTEGER); INSERT INTO test VALUES (11, 22), (12, 21), (13, 22); -""", out="") - -test('SELECT a, b FROM test;', -out=( -"11|22\n" -"12|21\n" -"13|22") +""", + out="", ) +test('SELECT a, b FROM test;', out=("11|22\n" "12|21\n" "13|22")) + test('SELECT a + 2, b FROM test WHERE a = 11', out='13|22') test('SELECT a + 2, b FROM test WHERE a = 12', out='14|21') -test('SELECT cast(a AS VARCHAR) FROM test;', -out=( -"11\n" -"12\n" -"13") -) +test('SELECT cast(a AS VARCHAR) FROM test;', out=("11\n" "12\n" "13")) -test('SELECT cast(cast(a AS VARCHAR) as INTEGER) FROM test;', -out=( -"11\n" -"12\n" -"13") -) +test('SELECT cast(cast(a AS VARCHAR) as INTEGER) FROM test;', out=("11\n" "12\n" "13")) ### FROM test/sql/types/timestamp/test_timestamp.test ################################# test( -"""CREATE TABLE IF NOT EXISTS timestamp (t TIMESTAMP); + """CREATE TABLE IF NOT EXISTS timestamp (t TIMESTAMP); INSERT INTO timestamp VALUES ('2008-01-01 00:00:01'), (NULL), ('2007-01-01 00:00:01'), ('2008-02-01 00:00:01'), ('2008-01-02 00:00:01'), ('2008-01-01 10:00:00'), ('2008-01-01 00:10:00'), ('2008-01-01 00:00:10') -""") +""" +) test("SELECT timestamp '2017-07-23 13:10:11';", out='2017-07-23 13:10:11') -test("SELECT timestamp '2017-07-23T13:10:11', timestamp '2017-07-23T13:10:11Z';", -out='2017-07-23 13:10:11|2017-07-23 13:10:11') +test( + "SELECT timestamp '2017-07-23T13:10:11', timestamp '2017-07-23T13:10:11Z';", + out='2017-07-23 13:10:11|2017-07-23 13:10:11', +) test("SELECT timestamp ' 2017-07-23 13:10:11 ';", out='2017-07-23 13:10:11') @@ -127,15 +127,17 @@ def test(cmd, out=None, err=None, extra_commands=None, input_file=None): test("SELECT timestamp 'AA2017-07-23 13:10:11';", err="[ISQL]ERROR") test("SELECT timestamp '2017-07-23A13:10:11';", err="[ISQL]ERROR") -test('SELECT t FROM timestamp ORDER BY t;', -out=( -"2007-01-01 00:00:01\n" -"2008-01-01 00:00:01\n" -"2008-01-01 00:00:10\n" -"2008-01-01 00:10:00\n" -"2008-01-01 10:00:00\n" -"2008-01-02 00:00:01\n" -"2008-02-01 00:00:01") +test( + 'SELECT t FROM timestamp ORDER BY t;', + out=( + "2007-01-01 00:00:01\n" + "2008-01-01 00:00:01\n" + "2008-01-01 00:00:10\n" + "2008-01-01 00:10:00\n" + "2008-01-01 10:00:00\n" + "2008-01-02 00:00:01\n" + "2008-02-01 00:00:01" + ), ) test('SELECT MIN(t) FROM timestamp;', out='2007-01-01 00:00:01') @@ -148,16 +150,9 @@ def test(cmd, out=None, err=None, extra_commands=None, 
input_file=None):
 test('SELECT t/t FROM timestamp', err="[ISQL]ERROR")
 test('SELECT t%t FROM timestamp', err="[ISQL]ERROR")
 
-test('SELECT t-t FROM timestamp',
-out=(
-"00:00:00\n"
-"\n"
-"00:00:00\n"
-"00:00:00\n"
-"00:00:00\n"
-"00:00:00\n"
-"00:00:00\n"
-"00:00:00")
+test(
+    'SELECT t-t FROM timestamp',
+    out=("00:00:00\n" "\n" "00:00:00\n" "00:00:00\n" "00:00:00\n" "00:00:00\n" "00:00:00\n" "00:00:00"),
 )
 
 test("SELECT YEAR(TIMESTAMP '1992-01-01 01:01:01');", out='1992')
@@ -175,25 +170,14 @@ def test(cmd, out=None, err=None, extra_commands=None, input_file=None):
 ### FROM test/sql/types/time/test_time.test #################################
 
 test(
-"""CREATE TABLE times(i TIME);
+    """CREATE TABLE times(i TIME);
 INSERT INTO times VALUES ('00:01:20'), ('20:08:10.998'), ('20:08:10.33'), ('20:08:10.001'), (NULL);
-""")
-
-test("SELECT * FROM times",
-out=(
-"00:01:20\n"
-"20:08:10.998\n"
-"20:08:10.33\n"
-"20:08:10.001"
-))
-
-test("SELECT cast(i AS VARCHAR) FROM times",
-out=(
-"00:01:20\n"
-"20:08:10.998\n"
-"20:08:10.33\n"
-"20:08:10.001"
-))
+"""
+)
+
+test("SELECT * FROM times", out=("00:01:20\n" "20:08:10.998\n" "20:08:10.33\n" "20:08:10.001"))
+
+test("SELECT cast(i AS VARCHAR) FROM times", out=("00:01:20\n" "20:08:10.998\n" "20:08:10.33\n" "20:08:10.001"))
 
 test("SELECT ''::TIME", err="[ISQL]ERROR")
 test("SELECT ' '::TIME", err="[ISQL]ERROR")
@@ -212,33 +196,19 @@ def test(cmd, out=None, err=None, extra_commands=None, input_file=None):
 test("SELECT 4 / 0", out='')
 
 test(
-"""DROP TABLE test
+    """DROP TABLE test
 CREATE TABLE test (a INTEGER, b INTEGER);
 INSERT INTO test VALUES (11, 22), (NULL, 21), (13, 22)
-""")
-
-test("SELECT a FROM test",
-out=(
-"11\n"
-"\n"
-"13")
+"""
 )
 
-test("SELECT cast(a AS BIGINT) FROM test;",
-out=(
-"11\n"
-"\n"
-"13")
-)
+test("SELECT a FROM test", out=("11\n" "\n" "13"))
+
+test("SELECT cast(a AS BIGINT) FROM test;", out=("11\n" "\n" "13"))
 
 test("SELECT a / 0 FROM test;", out='')
 test("SELECT a / (a - a) FROM test;", out='')
-test("SELECT a + b FROM test;",
-out=(
-"33\n"
-"\n"
-"35")
-)
+test("SELECT a + b FROM test;", out=("33\n" "\n" "35"))
 
 ### FROM test/sql/types/decimal/test_decimal.test #################################
 test("SELECT typeof('0.1'::DECIMAL);", out='DECIMAL(18,3)')
@@ -263,7 +233,10 @@ def test(cmd, out=None, err=None, extra_commands=None, input_file=None):
 test("SELECT '123456.789'::DECIMAL(9,3)::VARCHAR;", out='123456.789')
 test("SELECT '123456789'::DECIMAL(9,0)::VARCHAR;", out='123456789')
 test("SELECT '123456789'::DECIMAL(18,3)::VARCHAR;", out='123456789.000')
-test("SELECT '1701411834604692317316873037.1588410572'::DECIMAL(38,10)::VARCHAR;", out='1701411834604692317316873037.1588410572')
+test(
+    "SELECT '1701411834604692317316873037.1588410572'::DECIMAL(38,10)::VARCHAR;",
+    out='1701411834604692317316873037.1588410572',
+)
 
 test("SELECT '0'::DECIMAL(38,10)::VARCHAR;", out='0.0000000000')
 test("SELECT '0.00003'::DECIMAL(38,10)::VARCHAR;", out='0.0000300000')
@@ -276,9 +249,10 @@ def test(cmd, out=None, err=None, extra_commands=None, input_file=None):
 ### FROM test/sql/types/date/test_date.test #################################
 test(
-"""CREATE TABLE dates(i DATE);
+    """CREATE TABLE dates(i DATE);
 INSERT INTO dates VALUES ('1993-08-14'), (NULL);
-""")
+"""
+)
 
 # NULL is printed as an empty string, thus Python removes it from the stdout
 test("SELECT * FROM dates", out='1993-08-14')
@@ -304,41 +278,35 @@ def test(cmd, out=None, err=None, extra_commands=None, input_file=None):
 ### FROM test/sql/types/blob/test_blob.test
################################# test( -"""CREATE TABLE blobs (b BYTEA); + """CREATE TABLE blobs (b BYTEA); INSERT INTO blobs VALUES('\\xaa\\xff\\xaa'), ('\\xAA\\xFF\\xAA\\xAA\\xFF\\xAA'), ('\\xAA\\xFF\\xAA\\xAA\\xFF\\xAA\\xAA\\xFF\\xAA'); -""") +""" +) -test("SELECT * FROM blobs", -out=( -"\\xAA\\xFF\\xAA\n" -"\\xAA\\xFF\\xAA\\xAA\\xFF\\xAA\n" -"\\xAA\\xFF\\xAA\\xAA\\xFF\\xAA\\xAA\\xFF\\xAA") +test( + "SELECT * FROM blobs", + out=("\\xAA\\xFF\\xAA\n" "\\xAA\\xFF\\xAA\\xAA\\xFF\\xAA\n" "\\xAA\\xFF\\xAA\\xAA\\xFF\\xAA\\xAA\\xFF\\xAA"), ) test( -"""DELETE FROM blobs; + """DELETE FROM blobs; INSERT INTO blobs VALUES('\\xaa\\xff\\xaa'), ('\\xaa\\xff\\xaa\\xaa\\xff\\xaa'), ('\\xaa\\xff\\xaa\\xaa\\xff\\xaa\\xaa\\xff\\xaa'); -""") +""" +) -test("SELECT * FROM blobs", -out=( -"\\xAA\\xFF\\xAA\n" -"\\xAA\\xFF\\xAA\\xAA\\xFF\\xAA\n" -"\\xAA\\xFF\\xAA\\xAA\\xFF\\xAA\\xAA\\xFF\\xAA") +test( + "SELECT * FROM blobs", + out=("\\xAA\\xFF\\xAA\n" "\\xAA\\xFF\\xAA\\xAA\\xFF\\xAA\n" "\\xAA\\xFF\\xAA\\xAA\\xFF\\xAA\\xAA\\xFF\\xAA"), ) test( -"""DELETE FROM blobs; + """DELETE FROM blobs; INSERT INTO blobs VALUES('\\xaa1199'), ('\\xaa1199aa1199'), ('\\xaa1199aa1199aa1199'); -""") - -test("SELECT * FROM blobs", -out=( -"\\xAA1199\n" -"\\xAA1199aa1199\n" -"\\xAA1199aa1199aa1199") +""" ) +test("SELECT * FROM blobs", out=("\\xAA1199\n" "\\xAA1199aa1199\n" "\\xAA1199aa1199aa1199")) + test("INSERT INTO blobs VALUES('\\xGA\\xFF\\xAA')", err="[ISQL]ERROR") test("INSERT INTO blobs VALUES('\\xA')", err="[ISQL]ERROR") test("INSERT INTO blobs VALUES('\\xAA\\xA')", err="[ISQL]ERROR") @@ -349,10 +317,11 @@ def test(cmd, out=None, err=None, extra_commands=None, input_file=None): test('SELECT NULL::BLOB', out='') test( -"""CREATE TABLE blob_empty (b BYTEA); + """CREATE TABLE blob_empty (b BYTEA); INSERT INTO blob_empty VALUES(''), (''::BLOB); INSERT INTO blob_empty VALUES(NULL), (NULL::BLOB); -""") +""" +) test("SELECT * FROM blob_empty", out='') @@ -360,20 +329,13 @@ def test(cmd, out=None, err=None, extra_commands=None, input_file=None): ### FROM test/sql/types/string/test_unicode.test ################################# test( -"""CREATE TABLE emojis(id INTEGER, s VARCHAR); + """CREATE TABLE emojis(id INTEGER, s VARCHAR); INSERT INTO emojis VALUES (1, '🦆'), (2, '🦆🍞🦆') -""") - -test("SELECT * FROM emojis ORDER BY id", -out=( -"1|🦆\n" -"2|🦆🍞🦆") +""" ) -test("SELECT substring(s, 1, 1), substring(s, 2, 1) FROM emojis ORDER BY id", -out=( -"🦆|\n" -"🦆|🍞") -) +test("SELECT * FROM emojis ORDER BY id", out=("1|🦆\n" "2|🦆🍞🦆")) + +test("SELECT substring(s, 1, 1), substring(s, 2, 1) FROM emojis ORDER BY id", out=("🦆|\n" "🦆|🍞")) test("SELECT length(s) FROM emojis ORDER BY id", out="1\n3") diff --git a/tools/odbc/test/run_psqlodbc_test.py b/tools/odbc/test/run_psqlodbc_test.py index 9c9ff98bf34f..5c6f73444197 100644 --- a/tools/odbc/test/run_psqlodbc_test.py +++ b/tools/odbc/test/run_psqlodbc_test.py @@ -4,24 +4,52 @@ import shutil import sys -so_suffix = '.dylib' if platform.system()=='Darwin' else '.so' +so_suffix = '.dylib' if platform.system() == 'Darwin' else '.so' parser = argparse.ArgumentParser(description='Run the psqlodbc program.') -parser.add_argument('--build_psqlodbc', dest='build_psqlodbc', - action='store_const', const=True, help='clone and build psqlodbc') -parser.add_argument('--psqlodbc', dest='psqlodbcdir', - default=os.path.join(os.getcwd(), 'psqlodbc'), help='path of the psqlodbc directory') -parser.add_argument('--odbc_lib', dest='odbclib', - default=None, help='path of the odbc .so or .dylib file') 
-parser.add_argument('--release', dest='release', action='store_const', const=True, help='Specify to use release mode instead of debug mode') -parser.add_argument('--overwrite', dest='overwrite', action='store_const', const=True, help='Whether or not to overwrite the ~/.odbc.ini and ~/.odbcinst.ini files') -parser.add_argument('--fix', dest='fix', action='store_const', const=True, help='Whether or not to fix tests, or whether to just run them') +parser.add_argument( + '--build_psqlodbc', dest='build_psqlodbc', action='store_const', const=True, help='clone and build psqlodbc' +) +parser.add_argument( + '--psqlodbc', + dest='psqlodbcdir', + default=os.path.join(os.getcwd(), 'psqlodbc'), + help='path of the psqlodbc directory', +) +parser.add_argument('--odbc_lib', dest='odbclib', default=None, help='path of the odbc .so or .dylib file') +parser.add_argument( + '--release', + dest='release', + action='store_const', + const=True, + help='Specify to use release mode instead of debug mode', +) +parser.add_argument( + '--overwrite', + dest='overwrite', + action='store_const', + const=True, + help='Whether or not to overwrite the ~/.odbc.ini and ~/.odbcinst.ini files', +) +parser.add_argument( + '--fix', + dest='fix', + action='store_const', + const=True, + help='Whether or not to fix tests, or whether to just run them', +) parser.add_argument('--test', dest='test', default=None, help='A specific test to run (if any)') parser.add_argument('--trace_file', dest='tracefile', default='/tmp/odbctrace', help='Path to tracefile of ODBC script') parser.add_argument('--duckdb_dir', dest='duckdbdir', default=os.getcwd(), help='Path to DuckDB directory') parser.add_argument('--no-trace', dest='notrace', action='store_const', const=True, help='Do not print trace') parser.add_argument('--no-exit', dest='noexit', action='store_const', const=True, help='Do not exit on test failure') -parser.add_argument('--debugger', dest='debugger', default=None, choices=['lldb', 'gdb'], help='Debugger to attach (if any). If set, will set up the environment and give you a command to run with the debugger.') +parser.add_argument( + '--debugger', + dest='debugger', + default=None, + choices=['lldb', 'gdb'], + help='Debugger to attach (if any). 
If set, will set up the environment and give you a command to run with the debugger.', +) args = parser.parse_args() @@ -33,13 +61,15 @@ odbc_reset = os.path.join(odbc_build_dir, 'reset-db') odbc_test = os.path.join(odbc_build_dir, 'psql_odbc_test') + def reset_db(): - try_remove(args.tracefile) - try_remove(os.path.join(args.psqlodbcdir, 'contrib_regression')) - try_remove(os.path.join(args.psqlodbcdir, 'contrib_regression.wal')) + try_remove(args.tracefile) + try_remove(os.path.join(args.psqlodbcdir, 'contrib_regression')) + try_remove(os.path.join(args.psqlodbcdir, 'contrib_regression.wal')) + + command = odbc_reset + f' < {args.psqlodbcdir}/sampletables.sql' + syscall(command, 'Failed to reset db') - command = odbc_reset + f' < {args.psqlodbcdir}/sampletables.sql' - syscall(command, 'Failed to reset db') def print_trace_and_exit(): if args.notrace is None: @@ -48,6 +78,7 @@ def print_trace_and_exit(): if args.noexit is None: exit(1) + def syscall(arg, error, print_trace=True): ret = os.system(arg) if ret != 0: @@ -57,7 +88,6 @@ def syscall(arg, error, print_trace=True): exit(1) - if args.build_psqlodbc is not None: if not os.path.isdir('psqlodbc'): syscall('git clone git@github.com:Mytherin/psqlodbc.git', 'Failed to clone psqlodbc', False) @@ -96,12 +126,14 @@ def syscall(arg, error, print_trace=True): with open(os.path.join(user_dir, '.odbc.ini'), 'w+') as f: f.write(odbc_ini) + def try_remove(fpath): try: os.remove(fpath) except: pass + os.chdir(args.psqlodbcdir) os.environ['PSQLODBC_TEST_DSN'] = 'DuckDB' @@ -133,4 +165,3 @@ def try_remove(fpath): fix_suffix = ' --fix' if args.fix is not None else '' print(f"Running test {test}") syscall(odbc_test + ' ' + test + fix_suffix, f"Failed to run test {test}") - diff --git a/tools/pythonpkg/pyduckdb/__init__.py b/tools/pythonpkg/pyduckdb/__init__.py index a7ddf482ef69..e2fac5b9ef7e 100644 --- a/tools/pythonpkg/pyduckdb/__init__.py +++ b/tools/pythonpkg/pyduckdb/__init__.py @@ -26,7 +26,7 @@ TimestampNanosecondValue, TimestampTimeZoneValue, TimeValue, - TimeTimeZoneValue + TimeTimeZoneValue, ) __all__ = [ @@ -57,5 +57,5 @@ "TimestampNanosecondValue", "TimestampTimeZoneValue", "TimeValue", - "TimeTimeZoneValue" + "TimeTimeZoneValue", ] diff --git a/tools/pythonpkg/pyduckdb/bytes_io_wrapper.py b/tools/pythonpkg/pyduckdb/bytes_io_wrapper.py index 000314c8abf8..3657d84d11ff 100644 --- a/tools/pythonpkg/pyduckdb/bytes_io_wrapper.py +++ b/tools/pythonpkg/pyduckdb/bytes_io_wrapper.py @@ -35,30 +35,31 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
""" + class BytesIOWrapper: - # Wrapper that wraps a StringIO buffer and reads bytes from it - # Created for compat with pyarrow read_csv - def __init__(self, buffer: Union[StringIO, TextIOBase], encoding: str = "utf-8") -> None: - self.buffer = buffer - self.encoding = encoding - # Because a character can be represented by more than 1 byte, - # it is possible that reading will produce more bytes than n - # We store the extra bytes in this overflow variable, and append the - # overflow to the front of the bytestring the next time reading is performed - self.overflow = b"" + # Wrapper that wraps a StringIO buffer and reads bytes from it + # Created for compat with pyarrow read_csv + def __init__(self, buffer: Union[StringIO, TextIOBase], encoding: str = "utf-8") -> None: + self.buffer = buffer + self.encoding = encoding + # Because a character can be represented by more than 1 byte, + # it is possible that reading will produce more bytes than n + # We store the extra bytes in this overflow variable, and append the + # overflow to the front of the bytestring the next time reading is performed + self.overflow = b"" - def __getattr__(self, attr: str): - return getattr(self.buffer, attr) + def __getattr__(self, attr: str): + return getattr(self.buffer, attr) - def read(self, n: Union[int, None] = -1) -> bytes: - assert self.buffer is not None - bytestring = self.buffer.read(n).encode(self.encoding) - #When n=-1/n greater than remaining bytes: Read entire file/rest of file - combined_bytestring = self.overflow + bytestring - if n is None or n < 0 or n >= len(combined_bytestring): - self.overflow = b"" - return combined_bytestring - else: - to_return = combined_bytestring[:n] - self.overflow = combined_bytestring[n:] - return to_return + def read(self, n: Union[int, None] = -1) -> bytes: + assert self.buffer is not None + bytestring = self.buffer.read(n).encode(self.encoding) + # When n=-1/n greater than remaining bytes: Read entire file/rest of file + combined_bytestring = self.overflow + bytestring + if n is None or n < 0 or n >= len(combined_bytestring): + self.overflow = b"" + return combined_bytestring + else: + to_return = combined_bytestring[:n] + self.overflow = combined_bytestring[n:] + return to_return diff --git a/tools/pythonpkg/pyduckdb/filesystem.py b/tools/pythonpkg/pyduckdb/filesystem.py index 57216a97a59c..a0c3d57f8121 100644 --- a/tools/pythonpkg/pyduckdb/filesystem.py +++ b/tools/pythonpkg/pyduckdb/filesystem.py @@ -4,59 +4,61 @@ from .bytes_io_wrapper import BytesIOWrapper from io import TextIOBase + def is_file_like(obj): - # We only care that we can read from the file - return hasattr(obj, "read") and hasattr(obj, "seek") + # We only care that we can read from the file + return hasattr(obj, "read") and hasattr(obj, "seek") + class ModifiedMemoryFileSystem(MemoryFileSystem): - protocol = ('DUCKDB_INTERNAL_OBJECTSTORE',) - # defer to the original implementation that doesn't hardcode the protocol - _strip_protocol = classmethod(AbstractFileSystem._strip_protocol.__func__) - - # Add this manually because it's apparently missing on windows??? 
- def unstrip_protocol(self, name): - """Format FS-specific path to generic, including protocol""" - protos = (self.protocol,) if isinstance(self.protocol, str) else self.protocol - for protocol in protos: - if name.startswith(f"{protocol}://"): - return name - return f"{protos[0]}://{name}" - - def info(self, path, **kwargs): - path = self._strip_protocol(path) - if path in self.store: - filelike = self.store[path] - return { - "name": path, - "size": getattr(filelike, "size", 0), - "type": "file", - "created": getattr(filelike, "created", None), - } - else: - raise FileNotFoundError(path) - - def _open( - self, - path, - mode="rb", - block_size=None, - autocommit=True, - cache_options=None, - **kwargs, - ): - path = self._strip_protocol(path) - if path in self.store: - f = self.store[path] - return f - else: - raise FileNotFoundError(path) - - def add_file(self, object, path): - if not is_file_like(object): - raise ValueError("Can not read from a non file-like object") - path = self._strip_protocol(path) - if isinstance(object, TextIOBase): - # Wrap this so that we can return a bytes object from 'read' - self.store[path] = BytesIOWrapper(object) - else: - self.store[path] = object + protocol = ('DUCKDB_INTERNAL_OBJECTSTORE',) + # defer to the original implementation that doesn't hardcode the protocol + _strip_protocol = classmethod(AbstractFileSystem._strip_protocol.__func__) + + # Add this manually because it's apparently missing on windows??? + def unstrip_protocol(self, name): + """Format FS-specific path to generic, including protocol""" + protos = (self.protocol,) if isinstance(self.protocol, str) else self.protocol + for protocol in protos: + if name.startswith(f"{protocol}://"): + return name + return f"{protos[0]}://{name}" + + def info(self, path, **kwargs): + path = self._strip_protocol(path) + if path in self.store: + filelike = self.store[path] + return { + "name": path, + "size": getattr(filelike, "size", 0), + "type": "file", + "created": getattr(filelike, "created", None), + } + else: + raise FileNotFoundError(path) + + def _open( + self, + path, + mode="rb", + block_size=None, + autocommit=True, + cache_options=None, + **kwargs, + ): + path = self._strip_protocol(path) + if path in self.store: + f = self.store[path] + return f + else: + raise FileNotFoundError(path) + + def add_file(self, object, path): + if not is_file_like(object): + raise ValueError("Can not read from a non file-like object") + path = self._strip_protocol(path) + if isinstance(object, TextIOBase): + # Wrap this so that we can return a bytes object from 'read' + self.store[path] = BytesIOWrapper(object) + else: + self.store[path] = object diff --git a/tools/pythonpkg/pyduckdb/spark/__init__.py b/tools/pythonpkg/pyduckdb/spark/__init__.py index e6216c1e81ab..66895dcb0800 100644 --- a/tools/pythonpkg/pyduckdb/spark/__init__.py +++ b/tools/pythonpkg/pyduckdb/spark/__init__.py @@ -4,10 +4,4 @@ from ._globals import _NoValue from .exception import ContributionsAcceptedError -__all__ = [ - "SparkSession", - "DataFrame", - "SparkConf", - "SparkContext", - "ContributionsAcceptedError" -] +__all__ = ["SparkSession", "DataFrame", "SparkConf", "SparkContext", "ContributionsAcceptedError"] diff --git a/tools/pythonpkg/pyduckdb/spark/_globals.py b/tools/pythonpkg/pyduckdb/spark/_globals.py index b14e4cb24ee8..888a4f1fb0fe 100644 --- a/tools/pythonpkg/pyduckdb/spark/_globals.py +++ b/tools/pythonpkg/pyduckdb/spark/_globals.py @@ -61,7 +61,9 @@ def __new__(cls): return cls.__instance # Make the _NoValue instance 
falsey - def __nonzero__(self): return False + def __nonzero__(self): + return False + __bool__ = __nonzero__ # needed for python 2 to preserve identity through a pickle diff --git a/tools/pythonpkg/pyduckdb/spark/conf.py b/tools/pythonpkg/pyduckdb/spark/conf.py index 2e5c6be60f9a..5e6e1ccc9718 100644 --- a/tools/pythonpkg/pyduckdb/spark/conf.py +++ b/tools/pythonpkg/pyduckdb/spark/conf.py @@ -1,43 +1,45 @@ from typing import TYPE_CHECKING, Optional, List, Tuple from pyduckdb.spark.exception import ContributionsAcceptedError + class SparkConf: - def __init__(self): - raise NotImplementedError - - def contains(self, key: str) -> bool: - raise ContributionsAcceptedError - - def get(self, key: str, defaultValue: Optional[str] = None) -> Optional[str]: - raise ContributionsAcceptedError - - def getAll(self) -> List[Tuple[str, str]]: - raise ContributionsAcceptedError - - def set(self, key: str, value: str) -> "SparkConf": - raise ContributionsAcceptedError - - def setAll(self, pairs: List[Tuple[str, str]]) -> "SparkConf": - raise ContributionsAcceptedError - - def setAppName(self, value: str) -> "SparkConf": - raise ContributionsAcceptedError - - def setExecutorEnv(self, key: Optional[str] = None, value: Optional[str] = None, pairs: Optional[List[Tuple[str, str]]] = None) -> "SparkConf": - raise ContributionsAcceptedError - - def setIfMissing(self, key: str, value: str) -> "SparkConf": - raise ContributionsAcceptedError - - def setMaster(self, value: str) -> "SparkConf": - raise ContributionsAcceptedError - - def setSparkHome(self, value: str) -> "SparkConf": - raise ContributionsAcceptedError - - def toDebugString(self) -> str: - raise ContributionsAcceptedError - -__all__ = [ - "SparkConf" -] + def __init__(self): + raise NotImplementedError + + def contains(self, key: str) -> bool: + raise ContributionsAcceptedError + + def get(self, key: str, defaultValue: Optional[str] = None) -> Optional[str]: + raise ContributionsAcceptedError + + def getAll(self) -> List[Tuple[str, str]]: + raise ContributionsAcceptedError + + def set(self, key: str, value: str) -> "SparkConf": + raise ContributionsAcceptedError + + def setAll(self, pairs: List[Tuple[str, str]]) -> "SparkConf": + raise ContributionsAcceptedError + + def setAppName(self, value: str) -> "SparkConf": + raise ContributionsAcceptedError + + def setExecutorEnv( + self, key: Optional[str] = None, value: Optional[str] = None, pairs: Optional[List[Tuple[str, str]]] = None + ) -> "SparkConf": + raise ContributionsAcceptedError + + def setIfMissing(self, key: str, value: str) -> "SparkConf": + raise ContributionsAcceptedError + + def setMaster(self, value: str) -> "SparkConf": + raise ContributionsAcceptedError + + def setSparkHome(self, value: str) -> "SparkConf": + raise ContributionsAcceptedError + + def toDebugString(self) -> str: + raise ContributionsAcceptedError + + +__all__ = ["SparkConf"] diff --git a/tools/pythonpkg/pyduckdb/spark/context.py b/tools/pythonpkg/pyduckdb/spark/context.py index 1601ad9a4811..76d1ed936237 100644 --- a/tools/pythonpkg/pyduckdb/spark/context.py +++ b/tools/pythonpkg/pyduckdb/spark/context.py @@ -5,160 +5,158 @@ from pyduckdb.spark.exception import ContributionsAcceptedError from pyduckdb.spark.conf import SparkConf + class SparkContext: - def __init__(self, master: str): - self._connection = duckdb.connect(master) + def __init__(self, master: str): + self._connection = duckdb.connect(master) + + @property + def connection(self) -> DuckDBPyConnection: + return self._connection + + def stop(self) -> None: + 
self._connection.close() + + @classmethod + def getOrCreate(cls, conf: Optional[SparkConf] = None) -> "SparkContext": + raise ContributionsAcceptedError + + @classmethod + def setSystemProperty(cls, key: str, value: str) -> None: + raise ContributionsAcceptedError + + @property + def applicationId(self) -> str: + raise ContributionsAcceptedError + + @property + def defaultMinPartitions(self) -> int: + raise ContributionsAcceptedError + + @property + def defaultParallelism(self) -> int: + raise ContributionsAcceptedError + + # @property + # def resources(self) -> Dict[str, ResourceInformation]: + # raise ContributionsAcceptedError + + @property + def startTime(self) -> str: + raise ContributionsAcceptedError + + @property + def uiWebUrl(self) -> str: + raise ContributionsAcceptedError + + @property + def version(self) -> str: + raise ContributionsAcceptedError + + def __repr__(self) -> str: + raise ContributionsAcceptedError + + # def accumulator(self, value: ~T, accum_param: Optional[ForwardRef('AccumulatorParam[T]')] = None) -> 'Accumulator[T]': + # pass + + def addArchive(self, path: str) -> None: + raise ContributionsAcceptedError + + def addFile(self, path: str, recursive: bool = False) -> None: + raise ContributionsAcceptedError + + def addPyFile(self, path: str) -> None: + raise ContributionsAcceptedError + + # def binaryFiles(self, path: str, minPartitions: Optional[int] = None) -> pyduckdb.spark.rdd.RDD[typing.Tuple[str, bytes]]: + # pass + + # def binaryRecords(self, path: str, recordLength: int) -> pyduckdb.spark.rdd.RDD[bytes]: + # pass + + # def broadcast(self, value: ~T) -> 'Broadcast[T]': + # pass + + def cancelAllJobs(self) -> None: + raise ContributionsAcceptedError + + def cancelJobGroup(self, groupId: str) -> None: + raise ContributionsAcceptedError + + def dump_profiles(self, path: str) -> None: + raise ContributionsAcceptedError + + # def emptyRDD(self) -> pyduckdb.spark.rdd.RDD[typing.Any]: + # pass + + def getCheckpointDir(self) -> Optional[str]: + raise ContributionsAcceptedError + + def getConf(self) -> SparkConf: + raise ContributionsAcceptedError + + def getLocalProperty(self, key: str) -> Optional[str]: + raise ContributionsAcceptedError + + # def hadoopFile(self, path: str, inputFormatClass: str, keyClass: str, valueClass: str, keyConverter: Optional[str] = None, valueConverter: Optional[str] = None, conf: Optional[Dict[str, str]] = None, batchSize: int = 0) -> pyspark.rdd.RDD[typing.Tuple[~T, ~U]]: + # pass + + # def hadoopRDD(self, inputFormatClass: str, keyClass: str, valueClass: str, keyConverter: Optional[str] = None, valueConverter: Optional[str] = None, conf: Optional[Dict[str, str]] = None, batchSize: int = 0) -> pyspark.rdd.RDD[typing.Tuple[~T, ~U]]: + # pass + + # def newAPIHadoopFile(self, path: str, inputFormatClass: str, keyClass: str, valueClass: str, keyConverter: Optional[str] = None, valueConverter: Optional[str] = None, conf: Optional[Dict[str, str]] = None, batchSize: int = 0) -> pyspark.rdd.RDD[typing.Tuple[~T, ~U]]: + # pass + + # def newAPIHadoopRDD(self, inputFormatClass: str, keyClass: str, valueClass: str, keyConverter: Optional[str] = None, valueConverter: Optional[str] = None, conf: Optional[Dict[str, str]] = None, batchSize: int = 0) -> pyspark.rdd.RDD[typing.Tuple[~T, ~U]]: + # pass - @property - def connection(self) -> DuckDBPyConnection: - return self._connection + # def parallelize(self, c: Iterable[~T], numSlices: Optional[int] = None) -> pyspark.rdd.RDD[~T]: + # pass - def stop(self) -> None: - self._connection.close() + # def 
pickleFile(self, name: str, minPartitions: Optional[int] = None) -> pyspark.rdd.RDD[typing.Any]: + # pass - @classmethod - def getOrCreate(cls, conf: Optional[SparkConf] = None) -> "SparkContext": - raise ContributionsAcceptedError + # def range(self, start: int, end: Optional[int] = None, step: int = 1, numSlices: Optional[int] = None) -> pyspark.rdd.RDD[int]: + # pass - @classmethod - def setSystemProperty(cls, key: str, value: str) -> None: - raise ContributionsAcceptedError + # def runJob(self, rdd: pyspark.rdd.RDD[~T], partitionFunc: Callable[[Iterable[~T]], Iterable[~U]], partitions: Optional[Sequence[int]] = None, allowLocal: bool = False) -> List[~U]: + # pass - @property - def applicationId(self) -> str: - raise ContributionsAcceptedError + # def sequenceFile(self, path: str, keyClass: Optional[str] = None, valueClass: Optional[str] = None, keyConverter: Optional[str] = None, valueConverter: Optional[str] = None, minSplits: Optional[int] = None, batchSize: int = 0) -> pyspark.rdd.RDD[typing.Tuple[~T, ~U]]: + # pass - @property - def defaultMinPartitions(self) -> int: - raise ContributionsAcceptedError - - @property - def defaultParallelism(self) -> int: - raise ContributionsAcceptedError - - #@property - #def resources(self) -> Dict[str, ResourceInformation]: - # raise ContributionsAcceptedError - - @property - def startTime(self) -> str: - raise ContributionsAcceptedError - - @property - def uiWebUrl(self) -> str: - raise ContributionsAcceptedError - - @property - def version(self) -> str: - raise ContributionsAcceptedError - - def __repr__(self) -> str: - raise ContributionsAcceptedError - - #def accumulator(self, value: ~T, accum_param: Optional[ForwardRef('AccumulatorParam[T]')] = None) -> 'Accumulator[T]': - # pass - - def addArchive(self, path: str) -> None: - raise ContributionsAcceptedError + def setCheckpointDir(self, dirName: str) -> None: + raise ContributionsAcceptedError - def addFile(self, path: str, recursive: bool = False) -> None: - raise ContributionsAcceptedError - - def addPyFile(self, path: str) -> None: - raise ContributionsAcceptedError - - #def binaryFiles(self, path: str, minPartitions: Optional[int] = None) -> pyduckdb.spark.rdd.RDD[typing.Tuple[str, bytes]]: - # pass - - #def binaryRecords(self, path: str, recordLength: int) -> pyduckdb.spark.rdd.RDD[bytes]: - # pass - - #def broadcast(self, value: ~T) -> 'Broadcast[T]': - # pass - - def cancelAllJobs(self) -> None: - raise ContributionsAcceptedError - - def cancelJobGroup(self, groupId: str) -> None: - raise ContributionsAcceptedError - - def dump_profiles(self, path: str) -> None: - raise ContributionsAcceptedError - - #def emptyRDD(self) -> pyduckdb.spark.rdd.RDD[typing.Any]: - # pass - - def getCheckpointDir(self) -> Optional[str]: - raise ContributionsAcceptedError - - def getConf(self) -> SparkConf: - raise ContributionsAcceptedError - - def getLocalProperty(self, key: str) -> Optional[str]: - raise ContributionsAcceptedError - - #def hadoopFile(self, path: str, inputFormatClass: str, keyClass: str, valueClass: str, keyConverter: Optional[str] = None, valueConverter: Optional[str] = None, conf: Optional[Dict[str, str]] = None, batchSize: int = 0) -> pyspark.rdd.RDD[typing.Tuple[~T, ~U]]: - # pass - - #def hadoopRDD(self, inputFormatClass: str, keyClass: str, valueClass: str, keyConverter: Optional[str] = None, valueConverter: Optional[str] = None, conf: Optional[Dict[str, str]] = None, batchSize: int = 0) -> pyspark.rdd.RDD[typing.Tuple[~T, ~U]]: - # pass - - #def newAPIHadoopFile(self, path: str, 
inputFormatClass: str, keyClass: str, valueClass: str, keyConverter: Optional[str] = None, valueConverter: Optional[str] = None, conf: Optional[Dict[str, str]] = None, batchSize: int = 0) -> pyspark.rdd.RDD[typing.Tuple[~T, ~U]]: - # pass - - #def newAPIHadoopRDD(self, inputFormatClass: str, keyClass: str, valueClass: str, keyConverter: Optional[str] = None, valueConverter: Optional[str] = None, conf: Optional[Dict[str, str]] = None, batchSize: int = 0) -> pyspark.rdd.RDD[typing.Tuple[~T, ~U]]: - # pass + def setJobDescription(self, value: str) -> None: + raise ContributionsAcceptedError - #def parallelize(self, c: Iterable[~T], numSlices: Optional[int] = None) -> pyspark.rdd.RDD[~T]: - # pass + def setJobGroup(self, groupId: str, description: str, interruptOnCancel: bool = False) -> None: + raise ContributionsAcceptedError - #def pickleFile(self, name: str, minPartitions: Optional[int] = None) -> pyspark.rdd.RDD[typing.Any]: - # pass - - #def range(self, start: int, end: Optional[int] = None, step: int = 1, numSlices: Optional[int] = None) -> pyspark.rdd.RDD[int]: - # pass - - #def runJob(self, rdd: pyspark.rdd.RDD[~T], partitionFunc: Callable[[Iterable[~T]], Iterable[~U]], partitions: Optional[Sequence[int]] = None, allowLocal: bool = False) -> List[~U]: - # pass - - #def sequenceFile(self, path: str, keyClass: Optional[str] = None, valueClass: Optional[str] = None, keyConverter: Optional[str] = None, valueConverter: Optional[str] = None, minSplits: Optional[int] = None, batchSize: int = 0) -> pyspark.rdd.RDD[typing.Tuple[~T, ~U]]: - # pass + def setLocalProperty(self, key: str, value: str) -> None: + raise ContributionsAcceptedError - def setCheckpointDir(self, dirName: str) -> None: - raise ContributionsAcceptedError - - def setJobDescription(self, value: str) -> None: - raise ContributionsAcceptedError - - def setJobGroup(self, groupId: str, description: str, interruptOnCancel: bool = False) -> None: - raise ContributionsAcceptedError - - def setLocalProperty(self, key: str, value: str) -> None: - raise ContributionsAcceptedError - - def setLogLevel(self, logLevel: str) -> None: - raise ContributionsAcceptedError - - def show_profiles(self) -> None: - raise ContributionsAcceptedError + def setLogLevel(self, logLevel: str) -> None: + raise ContributionsAcceptedError - def sparkUser(self) -> str: - raise ContributionsAcceptedError + def show_profiles(self) -> None: + raise ContributionsAcceptedError - #def statusTracker(self) -> pyduckdb.spark.status.StatusTracker: - # raise ContributionsAcceptedError + def sparkUser(self) -> str: + raise ContributionsAcceptedError - #def textFile(self, name: str, minPartitions: Optional[int] = None, use_unicode: bool = True) -> pyspark.rdd.RDD[str]: - # pass + # def statusTracker(self) -> pyduckdb.spark.status.StatusTracker: + # raise ContributionsAcceptedError - #def union(self, rdds: List[pyspark.rdd.RDD[~T]]) -> pyspark.rdd.RDD[~T]: - # pass + # def textFile(self, name: str, minPartitions: Optional[int] = None, use_unicode: bool = True) -> pyspark.rdd.RDD[str]: + # pass - #def wholeTextFiles(self, path: str, minPartitions: Optional[int] = None, use_unicode: bool = True) -> pyspark.rdd.RDD[typing.Tuple[str, str]]: - # pass + # def union(self, rdds: List[pyspark.rdd.RDD[~T]]) -> pyspark.rdd.RDD[~T]: + # pass + # def wholeTextFiles(self, path: str, minPartitions: Optional[int] = None, use_unicode: bool = True) -> pyspark.rdd.RDD[typing.Tuple[str, str]]: + # pass -__all__ = [ - "SparkContext" -] +__all__ = ["SparkContext"] diff --git 
a/tools/pythonpkg/pyduckdb/spark/exception.py b/tools/pythonpkg/pyduckdb/spark/exception.py index 64d797a95305..33a92ae328ef 100644 --- a/tools/pythonpkg/pyduckdb/spark/exception.py +++ b/tools/pythonpkg/pyduckdb/spark/exception.py @@ -1,10 +1,9 @@ class ContributionsAcceptedError(NotImplementedError): - """ - This method is not planned to be implemented, if you would like to implement this method - or show your interest in this method to other members of the community, - feel free to open up a PR or a Discussion over on https://github.com/duckdb/duckdb - """ + """ + This method is not planned to be implemented, if you would like to implement this method + or show your interest in this method to other members of the community, + feel free to open up a PR or a Discussion over on https://github.com/duckdb/duckdb + """ -__all__ = [ - "ContributionsAcceptedError" -] + +__all__ = ["ContributionsAcceptedError"] diff --git a/tools/pythonpkg/pyduckdb/spark/sql/__init__.py b/tools/pythonpkg/pyduckdb/spark/sql/__init__.py index 2167d19b4add..2312ee509788 100644 --- a/tools/pythonpkg/pyduckdb/spark/sql/__init__.py +++ b/tools/pythonpkg/pyduckdb/spark/sql/__init__.py @@ -4,10 +4,4 @@ from .conf import RuntimeConfig from .catalog import Catalog -__all__ = [ - "SparkSession", - "DataFrame", - "RuntimeConfig", - "DataFrameWriter", - "Catalog" -] +__all__ = ["SparkSession", "DataFrame", "RuntimeConfig", "DataFrameWriter", "Catalog"] diff --git a/tools/pythonpkg/pyduckdb/spark/sql/catalog.py b/tools/pythonpkg/pyduckdb/spark/sql/catalog.py index b99c81a438ef..5a80b745a4d1 100644 --- a/tools/pythonpkg/pyduckdb/spark/sql/catalog.py +++ b/tools/pythonpkg/pyduckdb/spark/sql/catalog.py @@ -1,11 +1,13 @@ from typing import List, NamedTuple, Optional from pyduckdb.spark.sql.session import SparkSession + class Database(NamedTuple): name: str description: Optional[str] locationUri: str + class Table(NamedTuple): name: str database: Optional[str] @@ -13,6 +15,7 @@ class Table(NamedTuple): tableType: str isTemporary: bool + class Column(NamedTuple): name: str description: Optional[str] @@ -21,65 +24,55 @@ class Column(NamedTuple): isPartition: bool isBucket: bool + class Function(NamedTuple): name: str description: Optional[str] className: str isTemporary: bool + class Catalog: - def __init__(self, session: SparkSession): - self._session = session - - def listDatabases(self) -> List[Database]: - res = self._session.conn.sql('select * from duckdb_databases()').fetchall() - def transform_to_database(x) -> Database: - return Database(name=x[0], description=None, locationUri='') - databases = [transform_to_database(x) for x in res] - return databases - - def listTables(self) -> List[Table]: - res = self._session.conn.sql('select * from duckdb_tables()').fetchall() - def transform_to_table(x) -> Table: - return Table( - name=x[4], - database=x[0], - description=x[13], - tableType='', - isTemporary=x[7] - ) - tables = [transform_to_table(x) for x in res] - return tables - - def listColumns(self, tableName: str, dbName: Optional[str] = None) -> List[Column]: - query = f""" + def __init__(self, session: SparkSession): + self._session = session + + def listDatabases(self) -> List[Database]: + res = self._session.conn.sql('select * from duckdb_databases()').fetchall() + + def transform_to_database(x) -> Database: + return Database(name=x[0], description=None, locationUri='') + + databases = [transform_to_database(x) for x in res] + return databases + + def listTables(self) -> List[Table]: + res = self._session.conn.sql('select 
* from duckdb_tables()').fetchall() + + def transform_to_table(x) -> Table: + return Table(name=x[4], database=x[0], description=x[13], tableType='', isTemporary=x[7]) + + tables = [transform_to_table(x) for x in res] + return tables + + def listColumns(self, tableName: str, dbName: Optional[str] = None) -> List[Column]: + query = f""" select * from duckdb_columns() where table_name = '{tableName}' """ - if dbName: - query += f" and database_name = '{dbName}'" - res = self._session.conn.sql(query).fetchall() - def transform_to_column(x) -> Column: - return Column( - name=x[6], - description=None, - dataType=x[11], - nullable=x[8], - isPartition=False, - isBucket=False - ) - columns = [transform_to_column(x) for x in res] - return columns - - def listFunctions(self, dbName: Optional[str] = None) -> List[Function]: - raise NotImplementedError - - def setCurrentDatabase(self, dbName: str) -> None: - raise NotImplementedError - -__all__ = [ - "Catalog", - "Table", - "Column", - "Function", - "Database" -] + if dbName: + query += f" and database_name = '{dbName}'" + res = self._session.conn.sql(query).fetchall() + + def transform_to_column(x) -> Column: + return Column(name=x[6], description=None, dataType=x[11], nullable=x[8], isPartition=False, isBucket=False) + + columns = [transform_to_column(x) for x in res] + return columns + + def listFunctions(self, dbName: Optional[str] = None) -> List[Function]: + raise NotImplementedError + + def setCurrentDatabase(self, dbName: str) -> None: + raise NotImplementedError + + +__all__ = ["Catalog", "Table", "Column", "Function", "Database"] diff --git a/tools/pythonpkg/pyduckdb/spark/sql/conf.py b/tools/pythonpkg/pyduckdb/spark/sql/conf.py index 2301c60a77fa..7a404ff4a09f 100644 --- a/tools/pythonpkg/pyduckdb/spark/sql/conf.py +++ b/tools/pythonpkg/pyduckdb/spark/sql/conf.py @@ -2,22 +2,22 @@ from pyduckdb.spark._globals import _NoValueType, _NoValue from duckdb import DuckDBPyConnection + class RuntimeConfig: - def __init__(self, connection: DuckDBPyConnection): - self._connection = connection + def __init__(self, connection: DuckDBPyConnection): + self._connection = connection + + def set(self, key: str, value: str) -> None: + raise NotImplementedError - def set(self, key: str, value: str) -> None: - raise NotImplementedError + def isModifiable(self, key: str) -> bool: + raise NotImplementedError - def isModifiable(self, key: str) -> bool: - raise NotImplementedError + def unset(self, key: str) -> None: + raise NotImplementedError - def unset(self, key: str) -> None: - raise NotImplementedError + def get(self, key: str, default: Union[Optional[str], _NoValueType] = _NoValue) -> str: + raise NotImplementedError - def get(self, key: str, default: Union[Optional[str], _NoValueType] = _NoValue) -> str: - raise NotImplementedError -__all__ = [ - "RuntimeConfig" -] +__all__ = ["RuntimeConfig"] diff --git a/tools/pythonpkg/pyduckdb/spark/sql/dataframe.py b/tools/pythonpkg/pyduckdb/spark/sql/dataframe.py index edc87c57ecaa..d7e201be285f 100644 --- a/tools/pythonpkg/pyduckdb/spark/sql/dataframe.py +++ b/tools/pythonpkg/pyduckdb/spark/sql/dataframe.py @@ -4,26 +4,26 @@ import duckdb if TYPE_CHECKING: - from pyduckdb.spark.sql.session import SparkSession + from pyduckdb.spark.sql.session import SparkSession + class DataFrame: - def __init__(self, relation: duckdb.DuckDBPyRelation, session: "SparkSession"): - self.relation = relation - self.session = session + def __init__(self, relation: duckdb.DuckDBPyRelation, session: "SparkSession"): + self.relation = 
relation + self.session = session + + def show(self) -> None: + self.relation.show() - def show(self) -> None: - self.relation.show() + def createOrReplaceTempView(self, name: str) -> None: + raise NotImplementedError - def createOrReplaceTempView(self, name: str) -> None: - raise NotImplementedError + def createGlobalTempView(self, name: str) -> None: + raise NotImplementedError - def createGlobalTempView(self, name: str) -> None: - raise NotImplementedError + @property + def write(self) -> DataFrameWriter: + return DataFrameWriter(self) - @property - def write(self) -> DataFrameWriter: - return DataFrameWriter(self) -__all__ = [ - "DataFrame" -] +__all__ = ["DataFrame"] diff --git a/tools/pythonpkg/pyduckdb/spark/sql/readwriter.py b/tools/pythonpkg/pyduckdb/spark/sql/readwriter.py index 101f4941a050..142ca0ce51e3 100644 --- a/tools/pythonpkg/pyduckdb/spark/sql/readwriter.py +++ b/tools/pythonpkg/pyduckdb/spark/sql/readwriter.py @@ -5,27 +5,34 @@ OptionalPrimitiveType = Optional[PrimitiveType] if TYPE_CHECKING: - from pyduckdb.spark.sql.dataframe import DataFrame - from pyduckdb.spark.sql.session import SparkSession + from pyduckdb.spark.sql.dataframe import DataFrame + from pyduckdb.spark.sql.session import SparkSession + class DataFrameWriter: - def __init__(self, dataframe: "DataFrame"): - self.dataframe = dataframe + def __init__(self, dataframe: "DataFrame"): + self.dataframe = dataframe + + def saveAsTable(self, table_name: str) -> None: + relation = self.dataframe.relation + relation.create(table_name) - def saveAsTable(self, table_name: str) -> None: - relation = self.dataframe.relation - relation.create(table_name) class DataFrameReader: - def __init__(self, session: "SparkSession"): - raise NotImplementedError - self.session = session - - def load(self, path: Union[str, List[str], None] = None, format: Optional[str] = None, schema: Union[StructType, str, None] = None, **options: OptionalPrimitiveType) -> "DataFrame": - from pyduckdb.spark.sql.dataframe import DataFrame - raise NotImplementedError - -__all__ = [ - "DataFrameWriter", - "DataFrameReader" -] + def __init__(self, session: "SparkSession"): + raise NotImplementedError + self.session = session + + def load( + self, + path: Union[str, List[str], None] = None, + format: Optional[str] = None, + schema: Union[StructType, str, None] = None, + **options: OptionalPrimitiveType + ) -> "DataFrame": + from pyduckdb.spark.sql.dataframe import DataFrame + + raise NotImplementedError + + +__all__ = ["DataFrameWriter", "DataFrameReader"] diff --git a/tools/pythonpkg/pyduckdb/spark/sql/session.py b/tools/pythonpkg/pyduckdb/spark/sql/session.py index b047a54008e5..69d9217d5d55 100644 --- a/tools/pythonpkg/pyduckdb/spark/sql/session.py +++ b/tools/pythonpkg/pyduckdb/spark/sql/session.py @@ -1,7 +1,7 @@ from typing import Optional, List, Tuple, Any, TYPE_CHECKING if TYPE_CHECKING: - from pyduckdb.spark.sql.catalog import Catalog + from pyduckdb.spark.sql.catalog import Catalog from pyduckdb.spark.exception import ContributionsAcceptedError @@ -22,116 +22,121 @@ # For us this is done inside of `duckdb.connect`, based on the passed in path + configuration # SparkContext can be compared to our Connection class, and SparkConf to our ClientContext class + class SparkSession: - def __init__(self, context : SparkContext): - self.conn = context.connection - self._context = context - self._conf = RuntimeConfig(self.conn) - - def createDataFrame(self, tuples: List[Tuple[Any, ...]]) -> DataFrame: - parameter_count = len(tuples) - parameters = 
[f'${x+1}' for x in range(parameter_count)] - parameters = ', '.join(parameters) - query = f""" + def __init__(self, context: SparkContext): + self.conn = context.connection + self._context = context + self._conf = RuntimeConfig(self.conn) + + def createDataFrame(self, tuples: List[Tuple[Any, ...]]) -> DataFrame: + parameter_count = len(tuples) + parameters = [f'${x+1}' for x in range(parameter_count)] + parameters = ', '.join(parameters) + query = f""" select {parameters} """ - # FIXME: we can't add prepared parameters to a relation - # or extract the relation from a connection after 'execute' - raise NotImplementedError() - - def newSession(self) -> "SparkSession": - return SparkSession(self._context) - - def range(self, start: int, end: Optional[int] = None, step: int = 1, numPartitions: Optional[int] = None) -> "DataFrame": - raise ContributionsAcceptedError - - def sql(self, sqlQuery: str, **kwargs: Any) -> DataFrame: - if kwargs: - raise NotImplementedError - relation = self.conn.sql(sqlQuery) - return DataFrame(relation, self) - - def stop(self) -> None: - self._context.stop() - - def table(self, tableName: str) -> DataFrame: - relation = self.conn.table(tableName) - return DataFrame(relation, self) - - def getActiveSession(self) -> "SparkSession": - return self - - @property - def catalog(self) -> "Catalog": - if not hasattr(self, "_catalog"): - from pyduckdb.spark.sql.catalog import Catalog - self._catalog = Catalog(self) - return self._catalog - - @property - def conf(self) -> RuntimeConfig: - return self._conf - - @property - def read(self) -> DataFrameReader: - return DataFrameReader(self) - - @property - def readStream(self) -> DataStreamReader: - return DataStreamReader(self) - - @property - def sparkContext(self) -> SparkContext: - return self._context - - @property - def streams(self) -> Any: - raise ContributionsAcceptedError - - @property - def udf(self) -> UDFRegistration: - return UDFRegistration() - - @property - def version(self) -> str: - return '1.0.0' - - class Builder: - def __init__(self): - self.name = "builder" - self._master = ':memory:' - self._config = {} - - def master(self, name: str) -> "SparkSession.Builder": - self._master = name - return self - - def appName(self, name: str) -> "SparkSession.Builder": - # no-op - return self - - def remote(self, url: str) -> "SparkSession.Builder": - # no-op - return self - - def getOrCreate(self) -> "SparkSession": - # TODO: use the config to pass in methods to 'connect' - context = SparkContext(self._master) - return SparkSession(context) - - def config(self, key: Optional[str] = None, value: Optional[Any] = None, conf: Optional[SparkConf] = None) -> "SparkSession.Builder": - if conf: - raise NotImplementedError - if (key and value): - self._config[key] = value - return self - - def enableHiveSupport(self) -> "SparkSession.Builder": - # no-op - return self - - builder = Builder() - -__all__ = [ - "SparkSession" -] + # FIXME: we can't add prepared parameters to a relation + # or extract the relation from a connection after 'execute' + raise NotImplementedError() + + def newSession(self) -> "SparkSession": + return SparkSession(self._context) + + def range( + self, start: int, end: Optional[int] = None, step: int = 1, numPartitions: Optional[int] = None + ) -> "DataFrame": + raise ContributionsAcceptedError + + def sql(self, sqlQuery: str, **kwargs: Any) -> DataFrame: + if kwargs: + raise NotImplementedError + relation = self.conn.sql(sqlQuery) + return DataFrame(relation, self) + + def stop(self) -> None: + 
self._context.stop() + + def table(self, tableName: str) -> DataFrame: + relation = self.conn.table(tableName) + return DataFrame(relation, self) + + def getActiveSession(self) -> "SparkSession": + return self + + @property + def catalog(self) -> "Catalog": + if not hasattr(self, "_catalog"): + from pyduckdb.spark.sql.catalog import Catalog + + self._catalog = Catalog(self) + return self._catalog + + @property + def conf(self) -> RuntimeConfig: + return self._conf + + @property + def read(self) -> DataFrameReader: + return DataFrameReader(self) + + @property + def readStream(self) -> DataStreamReader: + return DataStreamReader(self) + + @property + def sparkContext(self) -> SparkContext: + return self._context + + @property + def streams(self) -> Any: + raise ContributionsAcceptedError + + @property + def udf(self) -> UDFRegistration: + return UDFRegistration() + + @property + def version(self) -> str: + return '1.0.0' + + class Builder: + def __init__(self): + self.name = "builder" + self._master = ':memory:' + self._config = {} + + def master(self, name: str) -> "SparkSession.Builder": + self._master = name + return self + + def appName(self, name: str) -> "SparkSession.Builder": + # no-op + return self + + def remote(self, url: str) -> "SparkSession.Builder": + # no-op + return self + + def getOrCreate(self) -> "SparkSession": + # TODO: use the config to pass in methods to 'connect' + context = SparkContext(self._master) + return SparkSession(context) + + def config( + self, key: Optional[str] = None, value: Optional[Any] = None, conf: Optional[SparkConf] = None + ) -> "SparkSession.Builder": + if conf: + raise NotImplementedError + if key and value: + self._config[key] = value + return self + + def enableHiveSupport(self) -> "SparkSession.Builder": + # no-op + return self + + builder = Builder() + + +__all__ = ["SparkSession"] diff --git a/tools/pythonpkg/pyduckdb/spark/sql/streaming.py b/tools/pythonpkg/pyduckdb/spark/sql/streaming.py index 13e27539db97..f136b47fe780 100644 --- a/tools/pythonpkg/pyduckdb/spark/sql/streaming.py +++ b/tools/pythonpkg/pyduckdb/spark/sql/streaming.py @@ -2,29 +2,36 @@ from pyduckdb.spark.sql.types import StructType if TYPE_CHECKING: - from pyduckdb.spark.sql.dataframe import DataFrame - from pyduckdb.spark.sql.session import SparkSession + from pyduckdb.spark.sql.dataframe import DataFrame + from pyduckdb.spark.sql.session import SparkSession PrimitiveType = Union[bool, float, int, str] OptionalPrimitiveType = Optional[PrimitiveType] + class DataStreamWriter: - def __init__(self, dataframe: "DataFrame"): - self.dataframe = dataframe + def __init__(self, dataframe: "DataFrame"): + self.dataframe = dataframe + + def toTable(self, table_name: str) -> None: + # Should we register the dataframe or create a table from the contents? + raise NotImplementedError - def toTable(self, table_name: str) -> None: - # Should we register the dataframe or create a table from the contents? 
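The comment above leaves toTable's strategy open; both candidate behaviours already exist in duckdb's Python API and are used elsewhere in this patch. A minimal sketch of the two options, assuming an in-memory connection and with 'events' / 'events_view' as purely illustrative names:

    import duckdb
    import pandas

    df = pandas.DataFrame({'id': [1, 2, 3]})
    con = duckdb.connect()
    # Option 1: register the dataframe -- it is scanned lazily and lives only
    # as long as the registered object does
    con.register('events_view', df)
    # Option 2: create a table from the contents -- the data is copied into
    # the database, independent of the original dataframe
    con.sql('select * from events_view').create('events')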
- raise NotImplementedError class DataStreamReader: - def __init__(self, session: "SparkSession"): - self.session = session + def __init__(self, session: "SparkSession"): + self.session = session + + def load( + self, + path: Optional[str] = None, + format: Optional[str] = None, + schema: Union[StructType, str, None] = None, + **options: OptionalPrimitiveType + ) -> "DataFrame": + from pyduckdb.spark.sql.dataframe import DataFrame + + raise NotImplementedError - def load(self, path: Optional[str] = None, format: Optional[str] = None, schema: Union[StructType, str, None] = None, **options: OptionalPrimitiveType) -> "DataFrame": - from pyduckdb.spark.sql.dataframe import DataFrame - raise NotImplementedError -__all__ = [ - "DataStreamReader", - "DataStreamWriter" -] +__all__ = ["DataStreamReader", "DataStreamWriter"] diff --git a/tools/pythonpkg/pyduckdb/spark/sql/types.py b/tools/pythonpkg/pyduckdb/spark/sql/types.py index 7dbd0abf8d86..d0e9d92d0f82 100644 --- a/tools/pythonpkg/pyduckdb/spark/sql/types.py +++ b/tools/pythonpkg/pyduckdb/spark/sql/types.py @@ -1,9 +1,9 @@ from typing import Optional, List, Tuple, Any, TYPE_CHECKING + class StructType: - def __init__(self): - raise NotImplementedError + def __init__(self): + raise NotImplementedError + -__all__ = [ - "StructType" -] +__all__ = ["StructType"] diff --git a/tools/pythonpkg/pyduckdb/spark/sql/udf.py b/tools/pythonpkg/pyduckdb/spark/sql/udf.py index 534b4d419763..f2e5335a9510 100644 --- a/tools/pythonpkg/pyduckdb/spark/sql/udf.py +++ b/tools/pythonpkg/pyduckdb/spark/sql/udf.py @@ -1,9 +1,9 @@ # https://sparkbyexamples.com/pyspark/pyspark-udf-user-defined-function/ + class UDFRegistration: - def __init__(self): - raise NotImplementedError + def __init__(self): + raise NotImplementedError + -__all__ = [ - "UDFRegistration" -] +__all__ = ["UDFRegistration"] diff --git a/tools/pythonpkg/pyduckdb/spark/tests/basic/test_spark_catalog.py b/tools/pythonpkg/pyduckdb/spark/tests/basic/test_spark_catalog.py index a44bceaed67e..4e8745162e82 100644 --- a/tools/pythonpkg/pyduckdb/spark/tests/basic/test_spark_catalog.py +++ b/tools/pythonpkg/pyduckdb/spark/tests/basic/test_spark_catalog.py @@ -1,28 +1,33 @@ - - import pytest from pyduckdb.spark.sql.catalog import Table, Database, Column + class TestSparkCatalog(object): - def test_list_databases(self, spark): - dbs = spark.catalog.listDatabases() - assert dbs == [ - Database(name='memory', description=None, locationUri=''), - Database(name='system', description=None, locationUri=''), - Database(name='temp', description=None, locationUri='') - ] + def test_list_databases(self, spark): + dbs = spark.catalog.listDatabases() + assert dbs == [ + Database(name='memory', description=None, locationUri=''), + Database(name='system', description=None, locationUri=''), + Database(name='temp', description=None, locationUri=''), + ] - def test_list_tables(self, spark): - spark.sql('create table tbl(a varchar)') - tbls = spark.catalog.listTables() - assert tbls == [ - Table(name='tbl', database='memory', description='CREATE TABLE tbl(a VARCHAR);', tableType='', isTemporary=False) - ] + def test_list_tables(self, spark): + spark.sql('create table tbl(a varchar)') + tbls = spark.catalog.listTables() + assert tbls == [ + Table( + name='tbl', + database='memory', + description='CREATE TABLE tbl(a VARCHAR);', + tableType='', + isTemporary=False, + ) + ] - def test_list_columns(self, spark): - spark.sql('create table tbl(a varchar, b bool)') - columns = spark.catalog.listColumns('tbl') - assert columns == [ - 
Column(name='a', description=None, dataType='VARCHAR', nullable=False, isPartition=False, isBucket=False), - Column(name='b', description=None, dataType='BOOLEAN', nullable=False, isPartition=False, isBucket=False) - ] + def test_list_columns(self, spark): + spark.sql('create table tbl(a varchar, b bool)') + columns = spark.catalog.listColumns('tbl') + assert columns == [ + Column(name='a', description=None, dataType='VARCHAR', nullable=False, isPartition=False, isBucket=False), + Column(name='b', description=None, dataType='BOOLEAN', nullable=False, isPartition=False, isBucket=False), + ] diff --git a/tools/pythonpkg/pyduckdb/spark/tests/basic/test_spark_dataframe.py b/tools/pythonpkg/pyduckdb/spark/tests/basic/test_spark_dataframe.py index 20c4c9c29d29..64f9ece8dfad 100644 --- a/tools/pythonpkg/pyduckdb/spark/tests/basic/test_spark_dataframe.py +++ b/tools/pythonpkg/pyduckdb/spark/tests/basic/test_spark_dataframe.py @@ -1,27 +1,29 @@ import pytest + class TestDataFrame(object): - @pytest.mark.skip("can't create a dataframe from a list of tuples yet") - def test_dataframe(self, spark): - # Create DataFrame - df = spark.createDataFrame( - [("Scala", 25000), ("Spark", 35000), ("PHP", 21000)]) - df.show() + @pytest.mark.skip("can't create a dataframe from a list of tuples yet") + def test_dataframe(self, spark): + # Create DataFrame + df = spark.createDataFrame([("Scala", 25000), ("Spark", 35000), ("PHP", 21000)]) + df.show() - # Output - #+-----+-----+ - #| _1| _2| - #+-----+-----+ - #|Scala|25000| - #|Spark|35000| - #| PHP|21000| - #+-----+-----+ + # Output + # +-----+-----+ + # | _1| _2| + # +-----+-----+ + # |Scala|25000| + # |Spark|35000| + # | PHP|21000| + # +-----+-----+ - def test_writing_to_table(self, spark): - # Create Hive table & query it. - spark.sql(""" + def test_writing_to_table(self, spark): + # Create Hive table & query it. 
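For context, the saveAsTable call exercised by this test reduces (per the readwriter.py hunk earlier in this patch) to a single DuckDBPyRelation.create call. A minimal sketch of that underlying step in plain duckdb; the literal row is purely illustrative:

    import duckdb

    con = duckdb.connect()
    rel = con.sql("select true as _1, 42 as _2")  # stand-in for DataFrame.relation
    rel.create("sample_hive_table")  # what write.saveAsTable("sample_hive_table") performs
    print(con.table("sample_hive_table").fetchall())  # [(True, 42)]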
+ spark.sql( + """ create table sample_table("_1" bool, "_2" integer) - """) - spark.table("sample_table").write.saveAsTable("sample_hive_table") - df3 = spark.sql("SELECT _1,_2 FROM sample_hive_table") - df3.show() + """ + ) + spark.table("sample_table").write.saveAsTable("sample_hive_table") + df3 = spark.sql("SELECT _1,_2 FROM sample_hive_table") + df3.show() diff --git a/tools/pythonpkg/pyduckdb/spark/tests/basic/test_spark_runtime_config.py b/tools/pythonpkg/pyduckdb/spark/tests/basic/test_spark_runtime_config.py index 8d982dc944f9..60d4bbe56574 100644 --- a/tools/pythonpkg/pyduckdb/spark/tests/basic/test_spark_runtime_config.py +++ b/tools/pythonpkg/pyduckdb/spark/tests/basic/test_spark_runtime_config.py @@ -1,19 +1,18 @@ - - import pytest + class TestSparkRuntimeConfig(object): - def test_spark_runtime_config(self, spark): - # This fetches the internal runtime config from the session - spark.conf - - def test_spark_runtime_config_set(self, spark): - # Set Config - with pytest.raises(NotImplementedError): - spark.conf.set("spark.executor.memory", "5g") + def test_spark_runtime_config(self, spark): + # This fetches the internal runtime config from the session + spark.conf + + def test_spark_runtime_config_set(self, spark): + # Set Config + with pytest.raises(NotImplementedError): + spark.conf.set("spark.executor.memory", "5g") - @pytest.mark.skip(reason="RuntimeConfig is not implemented yet") - def test_spark_runtime_config_get(self, spark): - # Get a Spark Config - with pytest.raises(KeyError): - partitions = spark.conf.get("spark.sql.shuffle.partitions") + @pytest.mark.skip(reason="RuntimeConfig is not implemented yet") + def test_spark_runtime_config_get(self, spark): + # Get a Spark Config + with pytest.raises(KeyError): + partitions = spark.conf.get("spark.sql.shuffle.partitions") diff --git a/tools/pythonpkg/pyduckdb/spark/tests/basic/test_spark_session.py b/tools/pythonpkg/pyduckdb/spark/tests/basic/test_spark_session.py index 5d475e8edb39..dcd94fee73c0 100644 --- a/tools/pythonpkg/pyduckdb/spark/tests/basic/test_spark_session.py +++ b/tools/pythonpkg/pyduckdb/spark/tests/basic/test_spark_session.py @@ -1,76 +1,77 @@ import pytest from pyduckdb.spark.sql import SparkSession + class TestSparkSession(object): - def test_spark_session_default(self): - session = SparkSession.builder.getOrCreate() - - def test_spark_session(self): - session = SparkSession.builder.master("local[1]") \ - .appName('SparkByExamples.com') \ - .getOrCreate() - - def test_new_session(self, spark: SparkSession): - session = spark.newSession() - print(session) - - @pytest.mark.skip(reason='not tested yet') - def test_retrieve_same_session(self): - spark = SparkSession.builder.master('test').appName('test2').getOrCreate() - spark2 = SparkSession.builder.getOrCreate() - # Same connection should be returned - assert spark == spark2 - - def test_config(self): - # Usage of config() - spark = SparkSession.builder \ - .master("local[1]") \ - .appName("SparkByExamples.com") \ - .config("spark.some.config.option", "config-value") \ - .getOrCreate() - - @pytest.mark.skip(reason="enableHiveSupport is not implemented yet") - def test_hive_support(self): - # Enabling Hive to use in Spark - spark = SparkSession.builder \ - .master("local[1]") \ - .appName("SparkByExamples.com") \ - .config("spark.sql.warehouse.dir", "/spark-warehouse") \ - .enableHiveSupport() \ - .getOrCreate() - - def test_version(self, spark): - version = spark.version - assert version == '1.0.0' - - def test_get_active_session(self, spark): - 
active_session = spark.getActiveSession() - - def test_read(self, spark): - with pytest.raises(NotImplementedError): - reader = spark.read - - def test_write(self, spark): - df = spark.sql('select 42') - writer = df.write - - def test_read_stream(self, spark): - reader = spark.readStream - - def test_spark_context(self, spark): - context = spark.sparkContext - - def test_sql(self, spark): - df = spark.sql('select 42') - - def test_stop_context(self, spark): - context = spark.sparkContext - spark.stop() - - def test_table(self, spark): - spark.sql('create table tbl(a varchar(10))') - df = spark.table('tbl') - - def test_udf(self, spark): - with pytest.raises(NotImplementedError): - udf_registration = spark.udf + def test_spark_session_default(self): + session = SparkSession.builder.getOrCreate() + + def test_spark_session(self): + session = SparkSession.builder.master("local[1]").appName('SparkByExamples.com').getOrCreate() + + def test_new_session(self, spark: SparkSession): + session = spark.newSession() + print(session) + + @pytest.mark.skip(reason='not tested yet') + def test_retrieve_same_session(self): + spark = SparkSession.builder.master('test').appName('test2').getOrCreate() + spark2 = SparkSession.builder.getOrCreate() + # Same connection should be returned + assert spark == spark2 + + def test_config(self): + # Usage of config() + spark = ( + SparkSession.builder.master("local[1]") + .appName("SparkByExamples.com") + .config("spark.some.config.option", "config-value") + .getOrCreate() + ) + + @pytest.mark.skip(reason="enableHiveSupport is not implemented yet") + def test_hive_support(self): + # Enabling Hive to use in Spark + spark = ( + SparkSession.builder.master("local[1]") + .appName("SparkByExamples.com") + .config("spark.sql.warehouse.dir", "/spark-warehouse") + .enableHiveSupport() + .getOrCreate() + ) + + def test_version(self, spark): + version = spark.version + assert version == '1.0.0' + + def test_get_active_session(self, spark): + active_session = spark.getActiveSession() + + def test_read(self, spark): + with pytest.raises(NotImplementedError): + reader = spark.read + + def test_write(self, spark): + df = spark.sql('select 42') + writer = df.write + + def test_read_stream(self, spark): + reader = spark.readStream + + def test_spark_context(self, spark): + context = spark.sparkContext + + def test_sql(self, spark): + df = spark.sql('select 42') + + def test_stop_context(self, spark): + context = spark.sparkContext + spark.stop() + + def test_table(self, spark): + spark.sql('create table tbl(a varchar(10))') + df = spark.table('tbl') + + def test_udf(self, spark): + with pytest.raises(NotImplementedError): + udf_registration = spark.udf diff --git a/tools/pythonpkg/pyduckdb/spark/tests/conftest.py b/tools/pythonpkg/pyduckdb/spark/tests/conftest.py index cac66bb8175e..4359879a96a1 100644 --- a/tools/pythonpkg/pyduckdb/spark/tests/conftest.py +++ b/tools/pythonpkg/pyduckdb/spark/tests/conftest.py @@ -1,7 +1,8 @@ import pytest from pyduckdb.spark.sql import SparkSession + # By making the scope 'function' we ensure that a new connection gets created for every function that uses the fixture @pytest.fixture(scope='function', autouse=True) def spark(): - return SparkSession.builder.master(':memory:').appName('pyspark').getOrCreate() + return SparkSession.builder.master(':memory:').appName('pyspark').getOrCreate() diff --git a/tools/pythonpkg/pyduckdb/udf.py b/tools/pythonpkg/pyduckdb/udf.py index 073db1576ee6..bbf05c7da74b 100644 --- a/tools/pythonpkg/pyduckdb/udf.py +++ 
b/tools/pythonpkg/pyduckdb/udf.py @@ -1,24 +1,19 @@ def vectorized(func): - """ - Decorate a function with annotated function parameters, so DuckDB can infer that the function should be provided with pyarrow arrays and should expect pyarrow array(s) as output - """ - from inspect import signature - import types - new_func = types.FunctionType( - func.__code__, - func.__globals__, - func.__name__, - func.__defaults__, - func.__closure__ - ) - # Construct the annotations: - import pyarrow as pa + """ + Decorate a function with annotated function parameters, so DuckDB can infer that the function should be provided with pyarrow arrays and should expect pyarrow array(s) as output + """ + from inspect import signature + import types - new_annotations = {} - sig = signature(func) - sig.parameters - for param in sig.parameters: - new_annotations[param] = pa.lib.ChunkedArray + new_func = types.FunctionType(func.__code__, func.__globals__, func.__name__, func.__defaults__, func.__closure__) + # Construct the annotations: + import pyarrow as pa - new_func.__annotations__ = new_annotations - return new_func + new_annotations = {} + sig = signature(func) + sig.parameters + for param in sig.parameters: + new_annotations[param] = pa.lib.ChunkedArray + + new_func.__annotations__ = new_annotations + return new_func diff --git a/tools/pythonpkg/pyduckdb/value/constant.py b/tools/pythonpkg/pyduckdb/value/constant.py index 2a58a8c6ea3d..6b2c8a408cf0 100644 --- a/tools/pythonpkg/pyduckdb/value/constant.py +++ b/tools/pythonpkg/pyduckdb/value/constant.py @@ -1,7 +1,4 @@ -from typing import ( - Any, - Dict -) +from typing import Any, Dict from duckdb.typing import DuckDBPyType from duckdb.typing import ( BIGINT, @@ -29,163 +26,208 @@ USMALLINT, UTINYINT, UUID, - VARCHAR + VARCHAR, ) + class Value: def __init__(self, object: Any, type: DuckDBPyType): self.object = object self.type = type + def __repr__(self) -> str: - return str(self.object) + return str(self.object) + # Miscellaneous + class NullValue(Value): def __init__(self): super().__init__(None, SQLNULL) + class BooleanValue(Value): def __init__(self, object: Any): super().__init__(object, BOOLEAN) + # Unsigned numerics + class UnsignedBinaryValue(Value): def __init__(self, object: Any): super().__init__(object, UTINYINT) + class UnsignedShortValue(Value): def __init__(self, object: Any): super().__init__(object, USMALLINT) + class UnsignedIntegerValue(Value): def __init__(self, object: Any): super().__init__(object, UINTEGER) + class UnsignedLongValue(Value): def __init__(self, object: Any): super().__init__(object, UBIGINT) + # Signed numerics + class BinaryValue(Value): def __init__(self, object: Any): super().__init__(object, TINYINT) + class ShortValue(Value): def __init__(self, object: Any): super().__init__(object, SMALLINT) + class IntegerValue(Value): def __init__(self, object: Any): super().__init__(object, INTEGER) + class LongValue(Value): def __init__(self, object: Any): super().__init__(object, BIGINT) + class HugeIntegerValue(Value): def __init__(self, object: Any): super().__init__(object, HUGEINT) + # Fractional + class FloatValue(Value): def __init__(self, object: Any): super().__init__(object, FLOAT) + class DoubleValue(Value): def __init__(self, object: Any): super().__init__(object, DOUBLE) + class DecimalValue(Value): def __init__(self, object: Any, width: int, scale: int): import duckdb + decimal_type = duckdb.decimal_type(width, scale) super().__init__(object, decimal_type) + # String + class StringValue(Value): def __init__(self, 
object: Any): super().__init__(object, VARCHAR) + class UUIDValue(Value): def __init__(self, object: Any): super().__init__(object, UUID) + class BitValue(Value): def __init__(self, object: Any): super().__init__(object, BIT) + class BlobValue(Value): def __init__(self, object: Any): super().__init__(object, BLOB) + # Temporal + class DateValue(Value): def __init__(self, object: Any): super().__init__(object, DATE) + class IntervalValue(Value): def __init__(self, object: Any): super().__init__(object, INTERVAL) + class TimestampValue(Value): def __init__(self, object: Any): super().__init__(object, TIMESTAMP) + class TimestampSecondValue(Value): def __init__(self, object: Any): super().__init__(object, TIMESTAMP_S) + class TimestampMilisecondValue(Value): def __init__(self, object: Any): super().__init__(object, TIMESTAMP_MS) + class TimestampNanosecondValue(Value): def __init__(self, object: Any): super().__init__(object, TIMESTAMP_NS) + class TimestampTimeZoneValue(Value): def __init__(self, object: Any): super().__init__(object, TIMESTAMP_TZ) + class TimeValue(Value): def __init__(self, object: Any): super().__init__(object, TIME) + class TimeTimeZoneValue(Value): def __init__(self, object: Any): super().__init__(object, TIME_TZ) + class ListValue(Value): def __init__(self, object: Any, child_type: DuckDBPyType): import duckdb + list_type = duckdb.list_type(child_type) super().__init__(object, list_type) + class StructValue(Value): def __init__(self, object: Any, children: Dict[str, DuckDBPyType]): import duckdb + struct_type = duckdb.struct_type(children) super().__init__(object, struct_type) + class MapValue(Value): def __init__(self, object: Any, key_type: DuckDBPyType, value_type: DuckDBPyType): import duckdb + map_type = duckdb.map_type(key_type, value_type) super().__init__(object, map_type) + class UnionType(Value): def __init__(self, object: Any, members: Dict[str, DuckDBPyType]): import duckdb + union_type = duckdb.union_type(members) super().__init__(object, union_type) -#TODO: add EnumValue once `duckdb.enum_type` is added + +# TODO: add EnumValue once `duckdb.enum_type` is added __all__ = [ "Value", @@ -215,5 +257,5 @@ def __init__(self, object: Any, members: Dict[str, DuckDBPyType]): "TimestampNanosecondValue", "TimestampTimeZoneValue", "TimeValue", - "TimeTimeZoneValue" + "TimeTimeZoneValue", ] diff --git a/tools/pythonpkg/setup.py b/tools/pythonpkg/setup.py old mode 100755 new mode 100644 index aa563924f0a6..4cc3aea73340 --- a/tools/pythonpkg/setup.py +++ b/tools/pythonpkg/setup.py @@ -70,11 +70,22 @@ class build_ext(CompilerLauncherMixin, _build_ext): if 'DUCKDB_BUILD_UNITY' in os.environ: unity_build = 16 -def parallel_cpp_compile(self, sources, output_dir=None, macros=None, include_dirs=None, debug=0, - extra_preargs=None, extra_postargs=None, depends=None): + +def parallel_cpp_compile( + self, + sources, + output_dir=None, + macros=None, + include_dirs=None, + debug=0, + extra_preargs=None, + extra_postargs=None, + depends=None, +): # Copied from distutils.ccompiler.CCompiler macros, objects, extra_postargs, pp_opts, build = self._setup_compile( - output_dir, macros, include_dirs, sources, depends, extra_postargs) + output_dir, macros, include_dirs, sources, depends, extra_postargs + ) cc_args = self._get_cc_args(pp_opts, debug, extra_preargs) @@ -92,15 +103,19 @@ def _single_compile(obj): # speed up compilation with: -j = cpu_number() on non Windows machines if os.name != 'nt' and os.environ.get('DUCKDB_DISABLE_PARALLEL_COMPILE', '') != '1': import 
distutils.ccompiler + distutils.ccompiler.CCompiler.compile = parallel_cpp_compile + def open_utf8(fpath, flags): import sys + if sys.version_info[0] < 3: return open(fpath, flags) else: return open(fpath, flags, encoding="utf8") + # make sure we are in the right directory os.chdir(os.path.dirname(os.path.realpath(__file__))) @@ -121,22 +136,24 @@ def open_utf8(fpath, flags): for i in range(len(sys.argv)): if sys.argv[i].startswith("--binary-dir="): existing_duckdb_dir = sys.argv[i].split('=', 1)[1] - elif sys.argv[i].startswith('--package_name=') : + elif sys.argv[i].startswith('--package_name='): lib_name = sys.argv[i].split('=', 1)[1] elif sys.argv[i].startswith("--compile-flags="): - toolchain_args = ['-std=c++11'] + [x.strip() for x in sys.argv[i].split('=', 1)[1].split(' ') if len(x.strip()) > 0] + toolchain_args = ['-std=c++11'] + [ + x.strip() for x in sys.argv[i].split('=', 1)[1].split(' ') if len(x.strip()) > 0 + ] elif sys.argv[i].startswith("--libs="): libraries = [x.strip() for x in sys.argv[i].split('=', 1)[1].split(' ') if len(x.strip()) > 0] else: new_sys_args.append(sys.argv[i]) sys.argv = new_sys_args -toolchain_args.append('-DDUCKDB_PYTHON_LIB_NAME='+lib_name) +toolchain_args.append('-DDUCKDB_PYTHON_LIB_NAME=' + lib_name) if platform.system() == 'Darwin': toolchain_args.extend(['-stdlib=libc++', '-mmacosx-version-min=10.7']) if platform.system() == 'Windows': - toolchain_args.extend(['-DDUCKDB_BUILD_LIBRARY','-DWIN32']) + toolchain_args.extend(['-DDUCKDB_BUILD_LIBRARY', '-DWIN32']) if 'BUILD_HTTPFS' in os.environ: libraries += ['crypto', 'ssl'] @@ -145,21 +162,26 @@ def open_utf8(fpath, flags): for ext in extensions: toolchain_args.extend(['-DDUCKDB_EXTENSION_{}_LINKED'.format(ext.upper())]) + class get_pybind_include(object): def __init__(self, user=False): self.user = user def __str__(self): import pybind11 + return pybind11.get_include(self.user) + extra_files = [] header_files = [] + def list_source_files(directory): sources = glob('src/**/*.cpp', recursive=True) return sources + script_path = os.path.dirname(os.path.abspath(__file__)) main_include_path = os.path.join(script_path, 'src', 'include') main_source_path = os.path.join(script_path, 'src') @@ -176,9 +198,14 @@ def list_source_files(directory): # copy all source files to the current directory sys.path.append(os.path.join(script_path, '..', '..', 'scripts')) import package_build - (source_list, include_list, original_sources) = package_build.build_package(os.path.join(script_path, lib_name), extensions, False, unity_build) - duckdb_sources = [os.path.sep.join(package_build.get_relative_path(script_path, x).split('/')) for x in source_list] + (source_list, include_list, original_sources) = package_build.build_package( + os.path.join(script_path, lib_name), extensions, False, unity_build + ) + + duckdb_sources = [ + os.path.sep.join(package_build.get_relative_path(script_path, x).split('/')) for x in source_list + ] duckdb_sources.sort() original_sources = [os.path.join(lib_name, x) for x in original_sources] @@ -188,6 +215,7 @@ def list_source_files(directory): # gather the include files import amalgamation + header_files = amalgamation.list_includes_files(duckdb_includes) # write the source list, include list and git hash to separate files @@ -216,13 +244,15 @@ def list_source_files(directory): source_files += duckdb_sources include_directories = duckdb_includes + include_directories - libduckdb = Extension(lib_name + '.duckdb', + libduckdb = Extension( + lib_name + '.duckdb', 
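        # Note: toolchain_args below doubles as both the compile and the link
        # flags, and include_directories was extended above with the generated
        # duckdb amalgamation headers.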
include_dirs=include_directories, sources=source_files, extra_compile_args=toolchain_args, extra_link_args=toolchain_args, libraries=libraries, - language='c++') + language='c++', + ) else: sys.path.append(os.path.join(script_path, '..', '..', 'scripts')) import package_build @@ -234,14 +264,16 @@ def list_source_files(directory): library_dirs = [x[0] for x in result_libraries if x[0] is not None] libnames = [x[1] for x in result_libraries if x[1] is not None] - libduckdb = Extension(lib_name + '.duckdb', + libduckdb = Extension( + lib_name + '.duckdb', include_dirs=include_directories, sources=main_source_files, extra_compile_args=toolchain_args, extra_link_args=toolchain_args, libraries=libnames, library_dirs=library_dirs, - language='c++') + language='c++', + ) # Only include pytest-runner in setup_requires if we're invoking tests if {'pytest', 'test', 'ptr'}.intersection(sys.argv): @@ -253,6 +285,7 @@ def list_source_files(directory): if os.getenv('SETUPTOOLS_SCM_NO_LOCAL', 'no') != 'no': setuptools_scm_conf['local_scheme'] = 'no-local-version' + # data files need to be formatted as [(directory, [files...]), (directory2, [files...])] # no clue why the setup script can't do this automatically, but hey def setup_data_files(data_files): @@ -275,6 +308,7 @@ def setup_data_files(data_files): new_data_files.append((kv, directory_map[kv])) return new_data_files + data_files = setup_data_files(extra_files + header_files) packages = [ @@ -285,37 +319,34 @@ def setup_data_files(data_files): 'pyduckdb.value', 'duckdb-stubs', 'duckdb-stubs.functional', - 'duckdb-stubs.typing' + 'duckdb-stubs.typing', ] -spark_packages = [ - 'pyduckdb.spark', - 'pyduckdb.spark.sql' -] +spark_packages = ['pyduckdb.spark', 'pyduckdb.spark.sql'] packages.extend(spark_packages) setup( - name = lib_name, - description = 'DuckDB embedded database', - keywords = 'DuckDB Database SQL OLAP', + name=lib_name, + description='DuckDB embedded database', + keywords='DuckDB Database SQL OLAP', url="https://www.duckdb.org", - long_description = 'See here for an introduction: https://duckdb.org/docs/api/python/overview', + long_description='See here for an introduction: https://duckdb.org/docs/api/python/overview', license='MIT', - data_files = data_files, + data_files=data_files, packages=packages, include_package_data=True, setup_requires=setup_requires + ["setuptools_scm<7.0.0", 'pybind11>=2.6.0'], - use_scm_version = setuptools_scm_conf, + use_scm_version=setuptools_scm_conf, tests_require=['google-cloud-storage', 'mypy', 'pytest'], - classifiers = [ + classifiers=[ 'Topic :: Database :: Database Engines/Servers', 'Intended Audience :: Developers', 'License :: OSI Approved :: MIT License', ], - ext_modules = [libduckdb], - maintainer = "Hannes Muehleisen", - maintainer_email = "hannes@cwi.nl", + ext_modules=[libduckdb], + maintainer="Hannes Muehleisen", + maintainer_email="hannes@cwi.nl", cmdclass={"build_ext": build_ext}, project_urls={ "Documentation": "https://duckdb.org/docs/api/python/overview", diff --git a/tools/pythonpkg/tests/conftest.py b/tools/pythonpkg/tests/conftest.py index 6d1d4acd9bd4..0eaac8f5df23 100644 --- a/tools/pythonpkg/tests/conftest.py +++ b/tools/pythonpkg/tests/conftest.py @@ -8,6 +8,7 @@ try: import pandas + pyarrow_dtype = pandas.core.arrays.arrow.dtype.ArrowDtype except: pyarrow_dtype = None @@ -15,15 +16,18 @@ # Check if pandas has arrow dtypes enabled try: from pandas.compat import pa_version_under7p0 + pyarrow_dtypes_enabled = not pa_version_under7p0 except: pyarrow_dtypes_enabled = False + # 
https://docs.pytest.org/en/latest/example/simple.html#control-skipping-of-tests-according-to-command-line-option # https://stackoverflow.com/a/47700320 def pytest_addoption(parser): parser.addoption("--skiplist", action="append", nargs="+", type=str, help="skip listed tests") + def pytest_collection_modifyitems(config, items): tests_to_skip = config.getoption("--skiplist") if not tests_to_skip: @@ -44,39 +48,48 @@ def pytest_collection_modifyitems(config, items): # the class is named specifically item.add_marker(skip_listed) + @pytest.fixture(scope="function") def duckdb_empty_cursor(request): connection = duckdb.connect('') cursor = connection.cursor() return cursor + def pandas_supports_arrow_backend(): try: from pandas.compat import pa_version_under7p0 + if pa_version_under7p0 == True: return False except: return False import pandas as pd + return Version(pd.__version__) >= Version('2.0.0') + def numpy_pandas_df(*args, **kwargs): pandas = pytest.importorskip("pandas") return pandas.DataFrame(*args, **kwargs) + def arrow_pandas_df(*args, **kwargs): - df = numpy_pandas_df(*args, **kwargs); + df = numpy_pandas_df(*args, **kwargs) return df.convert_dtypes(dtype_backend="pyarrow") + class NumpyPandas: def __init__(self): self.backend = 'numpy_nullable' self.DataFrame = numpy_pandas_df self.pandas = pytest.importorskip("pandas") + def __getattr__(self, __name: str): item = eval(f'self.pandas.{__name}') return item + def convert_arrow_to_numpy_backend(df): pandas = pytest.importorskip("pandas") names = df.columns @@ -86,24 +99,33 @@ def convert_arrow_to_numpy_backend(df): # This should convert the pyarrow chunked arrays into numpy arrays return pandas.DataFrame(df_content) + def convert_to_numpy(df): - if pyarrow_dtypes_enabled and pyarrow_dtype != None and any([True for x in df.dtypes if isinstance(x, pyarrow_dtype)]): + if ( + pyarrow_dtypes_enabled + and pyarrow_dtype != None + and any([True for x in df.dtypes if isinstance(x, pyarrow_dtype)]) + ): return convert_arrow_to_numpy_backend(df) return df + def convert_and_equal(df1, df2, **kwargs): df1 = convert_to_numpy(df1) df2 = convert_to_numpy(df2) pytest.importorskip("pandas").testing.assert_frame_equal(df1, df2, **kwargs) + class ArrowMockTesting: def __init__(self): self.testing = pytest.importorskip("pandas").testing self.assert_frame_equal = convert_and_equal + def __getattr__(self, __name: str): item = eval(f'self.testing.{__name}') return item + # This converts dataframes constructed with 'DataFrame(...)' to pyarrow backed dataframes # Assert equal does the opposite, turning all pyarrow backed dataframes into numpy backed ones # this is done because we don't produce pyarrow backed dataframes yet @@ -118,10 +140,12 @@ def __init__(self): self.backend = 'numpy_nullable' self.DataFrame = self.pandas.DataFrame self.testing = ArrowMockTesting() + def __getattr__(self, __name: str): item = eval(f'self.pandas.{__name}') return item + @pytest.fixture(scope="function") def require(): def _require(extension_name, db_name=''): @@ -151,7 +175,7 @@ def _require(extension_name, db_name=''): for path in extension_paths_found: print(path) - if (path.endswith(extension_name + ".duckdb_extension")): + if path.endswith(extension_name + ".duckdb_extension"): conn = duckdb.connect(db_name, config={'allow_unsigned_extensions': 'true'}) conn.execute(f"LOAD '{path}'") return conn diff --git a/tools/pythonpkg/tests/coverage/test_pandas_categorical_coverage.py b/tools/pythonpkg/tests/coverage/test_pandas_categorical_coverage.py index 
2bc71051f44b..e20afa726c03 100644 --- a/tools/pythonpkg/tests/coverage/test_pandas_categorical_coverage.py +++ b/tools/pythonpkg/tests/coverage/test_pandas_categorical_coverage.py @@ -3,25 +3,31 @@ import pytest from conftest import NumpyPandas, ArrowPandas + def check_result_list(res): for res_item in res: assert res_item[0] == res_item[1] + def check_create_table(category, pandas): conn = duckdb.connect() - conn.execute ("PRAGMA enable_verification") - df_in = pandas.DataFrame({ - 'x': pandas.Categorical(category, ordered=True), - 'y': pandas.Categorical(category, ordered=True), - 'z': category - }) + conn.execute("PRAGMA enable_verification") + df_in = pandas.DataFrame( + { + 'x': pandas.Categorical(category, ordered=True), + 'y': pandas.Categorical(category, ordered=True), + 'z': category, + } + ) category.append('bla') - df_in_diff = pandas.DataFrame({ - 'k': pandas.Categorical(category, ordered=True), - }) + df_in_diff = pandas.DataFrame( + { + 'k': pandas.Categorical(category, ordered=True), + } + ) df_out = duckdb.query_df(df_in, "data", "SELECT * FROM data") df_out = df_out.df() @@ -31,18 +37,18 @@ def check_create_table(category, pandas): conn.execute("CREATE TABLE t2 AS SELECT * FROM df_in") # Check fetchall - res = conn.execute("SELECT x,z FROM t1").fetchall() + res = conn.execute("SELECT x,z FROM t1").fetchall() check_result_list(res) - # Do a insert to trigger string -> cat + # Do a insert to trigger string -> cat conn.execute("INSERT INTO t1 VALUES ('2','2','2')") res = conn.execute("SELECT x FROM t1 where x = '1'").fetchall() assert res == [('1',)] - res = conn.execute("SELECT t1.x FROM t1 inner join t2 on (t1.x = t2.x) order by t1.x").fetchall() + res = conn.execute("SELECT t1.x FROM t1 inner join t2 on (t1.x = t2.x) order by t1.x").fetchall() assert res == conn.execute("SELECT x FROM t1 order by t1.x").fetchall() - + res = conn.execute("SELECT t1.x FROM t1 inner join t2 on (t1.x = t2.y) order by t1.x").fetchall() correct_res = conn.execute("SELECT x FROM t1 order by x").fetchall() assert res == correct_res @@ -61,19 +67,19 @@ def check_create_table(category, pandas): # We should be able to drop the table without any dependencies conn.execute("DROP TABLE t1") + # TODO: extend tests with ArrowPandas class TestCategory(object): - @pytest.mark.parametrize('pandas', [NumpyPandas()]) def test_category_string_uint16(self, duckdb_cursor, pandas): category = [] - for i in range (300): + for i in range(300): category.append(str(i)) check_create_table(category, pandas) @pytest.mark.parametrize('pandas', [NumpyPandas()]) def test_category_string_uint32(self, duckdb_cursor, pandas): category = [] - for i in range (70000): + for i in range(70000): category.append(str(i)) - check_create_table(category, pandas) \ No newline at end of file + check_create_table(category, pandas) diff --git a/tools/pythonpkg/tests/extensions/json/test_read_json.py b/tools/pythonpkg/tests/extensions/json/test_read_json.py index 4d67e0cd1c40..ca46e07ab717 100644 --- a/tools/pythonpkg/tests/extensions/json/test_read_json.py +++ b/tools/pythonpkg/tests/extensions/json/test_read_json.py @@ -5,14 +5,17 @@ import duckdb import re + def TestFile(name): import os - filename = os.path.join(os.path.dirname(os.path.realpath(__file__)),'data',name) + + filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', name) return filename + class TestReadJSON(object): def test_read_json_columns(self): - rel = duckdb.read_json(TestFile('example.json'), columns={'id':'integer', 'name':'varchar'}) + rel = 
duckdb.read_json(TestFile('example.json'), columns={'id': 'integer', 'name': 'varchar'}) res = rel.fetchone() print(res) assert res == (1, 'O Brother, Where Art Thou?') @@ -34,7 +37,7 @@ def test_read_json_sample_size(self): res = rel.fetchone() print(res) assert res == (1, 'O Brother, Where Art Thou?') - + def test_read_json_format(self): # Wrong option with pytest.raises(duckdb.BinderException, match="format must be one of .* not 'test'"): @@ -43,8 +46,15 @@ def test_read_json_format(self): rel = duckdb.read_json(TestFile('example.json'), format='unstructured') res = rel.fetchone() print(res) - assert res == ([{'id': 1, 'name': 'O Brother, Where Art Thou?'}, {'id': 2, 'name': 'Home for the Holidays'}, {'id': 3, 'name': 'The Firm'}, {'id': 4, 'name': 'Broadcast News'}, {'id': 5, 'name': 'Raising Arizona'}],) - + assert res == ( + [ + {'id': 1, 'name': 'O Brother, Where Art Thou?'}, + {'id': 2, 'name': 'Home for the Holidays'}, + {'id': 3, 'name': 'The Firm'}, + {'id': 4, 'name': 'Broadcast News'}, + {'id': 5, 'name': 'Raising Arizona'}, + ], + ) def test_read_json_records(self): # Wrong option @@ -55,5 +65,3 @@ def test_read_json_records(self): res = rel.fetchone() print(res) assert res == (1, 'O Brother, Where Art Thou?') - - diff --git a/tools/pythonpkg/tests/extensions/test_httpfs.py b/tools/pythonpkg/tests/extensions/test_httpfs.py index 1d4f78919d8b..2cc406a51f71 100644 --- a/tools/pythonpkg/tests/extensions/test_httpfs.py +++ b/tools/pythonpkg/tests/extensions/test_httpfs.py @@ -6,47 +6,51 @@ # We only run this test if this env var is set pytestmark = mark.skipif( - not os.getenv('DUCKDB_PYTHON_TEST_EXTENSION_REQUIRED', False), - reason='httpfs extension not available' + not os.getenv('DUCKDB_PYTHON_TEST_EXTENSION_REQUIRED', False), reason='httpfs extension not available' ) + class TestHTTPFS(object): - def test_read_json_httpfs(self, require): - connection = require('httpfs') - # FIXME: add test back - # res = connection.read_json('https://jsonplaceholder.typicode.com/todos') - # assert len(res.types) == 4 - - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) - def test_httpfs(self, require, pandas): - connection = require('httpfs') - try: - connection.execute("SELECT id, first_name, last_name FROM PARQUET_SCAN('https://raw.githubusercontent.com/cwida/duckdb/master/data/parquet-testing/userdata1.parquet') LIMIT 3;") - except RuntimeError as e: - # Test will ignore result if it fails due to networking issues while running the test. 
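The networking-tolerance pattern above, and the HTTPException fields asserted in test_http_exception further down, can be sketched standalone. A minimal example, assuming the httpfs extension is loadable and using an illustrative URL:

    import duckdb

    con = duckdb.connect()
    try:
        con.execute("SELECT * FROM PARQUET_SCAN('https://example.com/missing.parquet')")
    except duckdb.HTTPException as exc:
        # the same fields test_http_exception asserts on
        print(exc.status_code, exc.reason, 'Content-Length' in exc.headers)
    except RuntimeError as exc:
        # networking failures surface as plain RuntimeErrors, as guarded above
        print('network issue:', exc)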
- if (str(e).startswith("HTTP HEAD error")): - return - elif (str(e).startswith("Unable to connect")): - return - else: - raise e - - result_df = connection.fetchdf() - exp_result = pandas.DataFrame({ - 'id': pandas.Series([1, 2, 3], dtype="int32"), - 'first_name': ['Amanda', 'Albert', 'Evelyn'], - 'last_name': ['Jordan', 'Freeman', 'Morgan'] - }) - pandas.testing.assert_frame_equal(result_df, exp_result) - - def test_http_exception(self, require): - connection = require('httpfs') - - with raises(duckdb.HTTPException) as exc: - connection.execute("SELECT * FROM PARQUET_SCAN('https://example.com/userdata1.parquet')") - - value = exc.value - assert value.status_code == 404 - assert value.reason == 'Not Found' - assert value.body == '' - assert 'Content-Length' in value.headers + def test_read_json_httpfs(self, require): + connection = require('httpfs') + # FIXME: add test back + # res = connection.read_json('https://jsonplaceholder.typicode.com/todos') + # assert len(res.types) == 4 + + @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + def test_httpfs(self, require, pandas): + connection = require('httpfs') + try: + connection.execute( + "SELECT id, first_name, last_name FROM PARQUET_SCAN('https://raw.githubusercontent.com/cwida/duckdb/master/data/parquet-testing/userdata1.parquet') LIMIT 3;" + ) + except RuntimeError as e: + # Test will ignore result if it fails due to networking issues while running the test. + if str(e).startswith("HTTP HEAD error"): + return + elif str(e).startswith("Unable to connect"): + return + else: + raise e + + result_df = connection.fetchdf() + exp_result = pandas.DataFrame( + { + 'id': pandas.Series([1, 2, 3], dtype="int32"), + 'first_name': ['Amanda', 'Albert', 'Evelyn'], + 'last_name': ['Jordan', 'Freeman', 'Morgan'], + } + ) + pandas.testing.assert_frame_equal(result_df, exp_result) + + def test_http_exception(self, require): + connection = require('httpfs') + + with raises(duckdb.HTTPException) as exc: + connection.execute("SELECT * FROM PARQUET_SCAN('https://example.com/userdata1.parquet')") + + value = exc.value + assert value.status_code == 404 + assert value.reason == 'Not Found' + assert value.body == '' + assert 'Content-Length' in value.headers diff --git a/tools/pythonpkg/tests/fast/api/test_3324.py b/tools/pythonpkg/tests/fast/api/test_3324.py index 1832017caf37..52b5ea28a407 100644 --- a/tools/pythonpkg/tests/fast/api/test_3324.py +++ b/tools/pythonpkg/tests/fast/api/test_3324.py @@ -1,18 +1,21 @@ import pytest import duckdb -class Test3324(object): +class Test3324(object): def test_3324(self, duckdb_cursor): - create_output = duckdb_cursor.execute(""" + create_output = duckdb_cursor.execute( + """ create or replace table my_table as select 'test1' as column1, 1 as column2, 'quack' as column3 union all select 'test2' as column1, 2 as column2, 'quacks' as column3 union all select 'test3' as column1, 3 as column2, 'quacking' as column3 - """).fetch_df() - prepare_output = duckdb_cursor.execute(""" + """ + ).fetch_df() + prepare_output = duckdb_cursor.execute( + """ prepare v1 as select column1 @@ -20,9 +23,10 @@ def test_3324(self, duckdb_cursor): , column3 from my_table where - column1 = $1""").fetch_df() + column1 = $1""" + ).fetch_df() with pytest.raises(duckdb.BinderException, match="Unexpected prepared parameter"): - duckdb_cursor.execute("""execute v1(?)""",'test1').fetch_df() + duckdb_cursor.execute("""execute v1(?)""", 'test1').fetch_df() with pytest.raises(duckdb.BinderException, match="Unexpected prepared parameter"): - 
duckdb_cursor.execute("""execute v1(?)""",('test1',)).fetch_df() \ No newline at end of file + duckdb_cursor.execute("""execute v1(?)""", ('test1',)).fetch_df() diff --git a/tools/pythonpkg/tests/fast/api/test_3654.py b/tools/pythonpkg/tests/fast/api/test_3654.py index e3d817271b4e..e63f0cd12dab 100644 --- a/tools/pythonpkg/tests/fast/api/test_3654.py +++ b/tools/pythonpkg/tests/fast/api/test_3654.py @@ -1,21 +1,25 @@ import duckdb import pytest + try: import pyarrow as pa + can_run = True except: can_run = False from conftest import NumpyPandas, ArrowPandas + class Test3654(object): - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_3654_pandas(self, duckdb_cursor, pandas): - df1 = pandas.DataFrame({ - 'id': [1, 1, 2], - }) + df1 = pandas.DataFrame( + { + 'id': [1, 1, 2], + } + ) con = duckdb.connect() - con.register("df1",df1) + con.register("df1", df1) rel = con.view("df1") print(rel.execute().fetchall()) assert rel.execute().fetchall() == [(1,), (1,), (2,)] @@ -25,12 +29,14 @@ def test_3654_arrow(self, duckdb_cursor, pandas): if not can_run: return - df1 = pandas.DataFrame({ - 'id': [1, 1, 2], - }) + df1 = pandas.DataFrame( + { + 'id': [1, 1, 2], + } + ) table = pa.Table.from_pandas(df1) con = duckdb.connect() - con.register("df1",table) + con.register("df1", table) rel = con.view("df1") print(rel.execute().fetchall()) assert rel.execute().fetchall() == [(1,), (1,), (2,)] diff --git a/tools/pythonpkg/tests/fast/api/test_3728.py b/tools/pythonpkg/tests/fast/api/test_3728.py index 78e9ed26f99d..aa5864c23f1c 100644 --- a/tools/pythonpkg/tests/fast/api/test_3728.py +++ b/tools/pythonpkg/tests/fast/api/test_3728.py @@ -1,5 +1,6 @@ import duckdb + class Test3728(object): def test_3728_describe_enum(self, duckdb_cursor): # Create an in-memory database, but the problem is also present in file-backed DBs @@ -12,4 +13,7 @@ def test_3728_describe_enum(self, duckdb_cursor): cursor.execute("CREATE TABLE person (name text, current_mood mood);") # This fails with "RuntimeError: Not implemented Error: unsupported type: mood" - assert cursor.table("person").execute().description == [('name', 'STRING', None, None, None, None, None), ('current_mood', 'mood', None, None, None, None, None)] + assert cursor.table("person").execute().description == [ + ('name', 'STRING', None, None, None, None, None), + ('current_mood', 'mood', None, None, None, None, None), + ] diff --git a/tools/pythonpkg/tests/fast/api/test_6315.py b/tools/pythonpkg/tests/fast/api/test_6315.py index ed5487d6057f..e8eaff591ffd 100644 --- a/tools/pythonpkg/tests/fast/api/test_6315.py +++ b/tools/pythonpkg/tests/fast/api/test_6315.py @@ -1,5 +1,6 @@ import duckdb + class Test6315(object): def test_6315(self, duckdb_cursor): # segfault when accessing description after fetching rows diff --git a/tools/pythonpkg/tests/fast/api/test_attribute_getter.py b/tools/pythonpkg/tests/fast/api/test_attribute_getter.py index 5fb40df6a23e..dd2955db1d09 100644 --- a/tools/pythonpkg/tests/fast/api/test_attribute_getter.py +++ b/tools/pythonpkg/tests/fast/api/test_attribute_getter.py @@ -8,19 +8,20 @@ import csv import pytest + class TestGetAttribute(object): def test_basic_getattr(self): rel = duckdb.sql('select i as a, (i + 5) % 10 as b, (i + 2) % 3 as c from range(100) tbl(i)') assert rel.a.fetchmany(5) == [(0,), (1,), (2,), (3,), (4,)] assert rel.b.fetchmany(5) == [(5,), (6,), (7,), (8,), (9,)] assert rel.c.fetchmany(5) == [(2,), (0,), (1,), (2,), (0,)] - + def test_basic_getitem(self): rel = duckdb.sql('select i as a, (i + 5) % 
10 as b, (i + 2) % 3 as c from range(100) tbl(i)') assert rel['a'].fetchmany(5) == [(0,), (1,), (2,), (3,), (4,)] assert rel['b'].fetchmany(5) == [(5,), (6,), (7,), (8,), (9,)] assert rel['c'].fetchmany(5) == [(2,), (0,), (1,), (2,), (0,)] - + def test_getitem_nonexistant(self): rel = duckdb.sql('select i as a, (i + 5) % 10 as b, (i + 2) % 3 as c from range(100) tbl(i)') with pytest.raises(AttributeError): @@ -33,13 +34,13 @@ def test_getattr_nonexistant(self): def test_getattr_collision(self): rel = duckdb.sql('select i as df from range(100) tbl(i)') - + # 'df' also exists as a method on DuckDBPyRelation assert rel.df.__class__ != duckdb.DuckDBPyRelation def test_getitem_collision(self): rel = duckdb.sql('select i as df from range(100) tbl(i)') - + # this case is not an issue on __getitem__ assert rel['df'].__class__ == duckdb.DuckDBPyRelation @@ -51,4 +52,4 @@ def test_getitem_struct(self): def test_getattr_struct(self): rel = duckdb.sql("select {'a':5, 'b':6} as a, 5 as b") assert rel.a.a.fetchall()[0][0] == 5 - assert rel.a.b.fetchall()[0][0] == 6 \ No newline at end of file + assert rel.a.b.fetchall()[0][0] == 6 diff --git a/tools/pythonpkg/tests/fast/api/test_config.py b/tools/pythonpkg/tests/fast/api/test_config.py index a37bb796ec90..bad8d0d51543 100644 --- a/tools/pythonpkg/tests/fast/api/test_config.py +++ b/tools/pythonpkg/tests/fast/api/test_config.py @@ -5,31 +5,32 @@ import pytest from conftest import NumpyPandas, ArrowPandas + class TestDBConfig(object): @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_default_order(self, duckdb_cursor, pandas): - df = pandas.DataFrame({'a': [1,2,3]}) + df = pandas.DataFrame({'a': [1, 2, 3]}) con = duckdb.connect(':memory:', config={'default_order': 'desc'}) result = con.execute('select * from df order by a').fetchall() assert result == [(3,), (2,), (1,)] @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_null_order(self, duckdb_cursor, pandas): - df = pandas.DataFrame({'a': [1,2,3,None]}) + df = pandas.DataFrame({'a': [1, 2, 3, None]}) con = duckdb.connect(':memory:', config={'default_null_order': 'nulls_last'}) result = con.execute('select * from df order by a').fetchall() assert result == [(1,), (2,), (3,), (None,)] @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_multiple_options(self, duckdb_cursor, pandas): - df = pandas.DataFrame({'a': [1,2,3,None]}) + df = pandas.DataFrame({'a': [1, 2, 3, None]}) con = duckdb.connect(':memory:', config={'default_null_order': 'nulls_last', 'default_order': 'desc'}) result = con.execute('select * from df order by a').fetchall() assert result == [(3,), (2,), (1,), (None,)] @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_external_access(self, duckdb_cursor, pandas): - df = pandas.DataFrame({'a': [1,2,3]}) + df = pandas.DataFrame({'a': [1, 2, 3]}) # this works (replacement scan) con_regular = duckdb.connect(':memory:', config={}) con_regular.execute('select * from df') @@ -49,7 +50,7 @@ def test_unrecognized_option(self, duckdb_cursor): con_regular = duckdb.connect(':memory:', config={'thisoptionisprobablynotthere': '42'}) except: success = False - assert success==False + assert success == False def test_incorrect_parameter(self, duckdb_cursor): success = True @@ -57,4 +58,4 @@ def test_incorrect_parameter(self, duckdb_cursor): con_regular = duckdb.connect(':memory:', config={'default_null_order': '42'}) except: success = False - assert success==False + assert success == False diff --git 
a/tools/pythonpkg/tests/fast/api/test_connection_close.py b/tools/pythonpkg/tests/fast/api/test_connection_close.py index adc35df2a6e9..8d0a7c7de5d4 100644 --- a/tools/pythonpkg/tests/fast/api/test_connection_close.py +++ b/tools/pythonpkg/tests/fast/api/test_connection_close.py @@ -5,13 +5,15 @@ import os import pytest + def check_exception(f): had_exception = False try: f() except BaseException: had_exception = True - assert(had_exception) + assert had_exception + class TestConnectionClose(object): def test_connection_close(self, duckdb_cursor): @@ -23,14 +25,14 @@ def test_connection_close(self, duckdb_cursor): cursor.execute("create table a (i integer)") cursor.execute("insert into a values (42)") con.close() - check_exception(lambda :cursor.execute("select * from a")) + check_exception(lambda: cursor.execute("select * from a")) def test_open_and_exit(self): with pytest.raises(TypeError): with duckdb.connect() as connection: connection.execute("select 42") # This exception does not get swallowed by __exit__ - raise TypeError(); + raise TypeError() def test_reopen_connection(self, duckdb_cursor): fd, db = tempfile.mkstemp() diff --git a/tools/pythonpkg/tests/fast/api/test_cursor.py b/tools/pythonpkg/tests/fast/api/test_cursor.py index aa95de79e2ce..549c8f071e97 100644 --- a/tools/pythonpkg/tests/fast/api/test_cursor.py +++ b/tools/pythonpkg/tests/fast/api/test_cursor.py @@ -12,13 +12,13 @@ def test_cursor_basic(self): cursor = con.cursor() # Use the cursor for queries res = cursor.execute("select [1,2,3,NULL,4]").fetchall() - assert res == [([1,2,3,None,4],)] + assert res == [([1, 2, 3, None, 4],)] def test_cursor_preexisting(self): con = duckdb.connect(':memory:') con.execute("create table tbl as select i a, i+1 b, i+2 c from range(5) tbl(i)") cursor = con.cursor() - res = cursor.execute("select * from tbl").fetchall(); + res = cursor.execute("select * from tbl").fetchall() assert res == [(0, 1, 2), (1, 2, 3), (2, 3, 4), (3, 4, 5), (4, 5, 6)] def test_cursor_after_creation(self): @@ -27,7 +27,7 @@ def test_cursor_after_creation(self): cursor = con.cursor() # Then create table on the source connection con.execute("create table tbl as select i a, i+1 b, i+2 c from range(5) tbl(i)") - res = cursor.execute("select * from tbl").fetchall(); + res = cursor.execute("select * from tbl").fetchall() assert res == [(0, 1, 2), (1, 2, 3), (2, 3, 4), (3, 4, 5), (4, 5, 6)] def test_cursor_mixed(self): @@ -39,7 +39,7 @@ def test_cursor_mixed(self): # Close the cursor and create a new one cursor.close() cursor = con.cursor() - res = cursor.execute("select * from tbl").fetchall(); + res = cursor.execute("select * from tbl").fetchall() assert res == [(0, 1, 2), (1, 2, 3), (2, 3, 4), (3, 4, 5), (4, 5, 6)] def test_cursor_temp_schema_closed(self): @@ -59,7 +59,7 @@ def test_cursor_temp_schema_open(self): cursor.execute("create temp table tbl as select * from range(100)") other_cursor = con.cursor() # Connection that created the table is still open - #cursor.close() + # cursor.close() with pytest.raises(duckdb.CatalogException): # This table does not exist in this cursor res = other_cursor.execute("select * from tbl").fetchall() diff --git a/tools/pythonpkg/tests/fast/api/test_dbapi00.py b/tools/pythonpkg/tests/fast/api/test_dbapi00.py index 291a2e28d2e3..f13f5079d711 100644 --- a/tools/pythonpkg/tests/fast/api/test_dbapi00.py +++ b/tools/pythonpkg/tests/fast/api/test_dbapi00.py @@ -5,6 +5,7 @@ import duckdb from conftest import NumpyPandas, ArrowPandas + def assert_result_equal(result): assert result == 
[(0,), (1,), (2,), (3,), (4,), (5,), (6,), (7,), (8,), (9,), (None,)], "Incorrect result returned" @@ -23,18 +24,18 @@ def test_fetchmany_default(self, duckdb_cursor): # by default 'size' is 1 arraysize = 1 list_of_results = [] - while (True): + while True: res = duckdb_cursor.fetchmany() - assert(isinstance(res, list)) + assert isinstance(res, list) list_of_results.extend(res) - if (len(res) == 0): + if len(res) == 0: break - assert(len(list_of_results) == truth_value) + assert len(list_of_results) == truth_value assert_result_equal(list_of_results) res = duckdb_cursor.fetchmany(2) - assert(len(res) == 0) + assert len(res) == 0 res = duckdb_cursor.fetchmany(3) - assert(len(res) == 0) + assert len(res) == 0 def test_fetchmany(self, duckdb_cursor): # Get truth value @@ -44,33 +45,33 @@ def test_fetchmany(self, duckdb_cursor): arraysize = 3 expected_iteration_count = 1 + (int)(truth_value / arraysize) + (1 if truth_value % arraysize else 0) iteration_count = 0 - print("truth_value:",truth_value) - print("expected_iteration_count:",expected_iteration_count) - while (True): + print("truth_value:", truth_value) + print("expected_iteration_count:", expected_iteration_count) + while True: print(iteration_count) res = duckdb_cursor.fetchmany(3) print(res) iteration_count += 1 - assert(isinstance(res, list)) + assert isinstance(res, list) list_of_results.extend(res) - if (len(res) == 0): + if len(res) == 0: break - assert(iteration_count == expected_iteration_count) - assert(len(list_of_results) == truth_value) + assert iteration_count == expected_iteration_count + assert len(list_of_results) == truth_value assert_result_equal(list_of_results) res = duckdb_cursor.fetchmany(3) - assert(len(res) == 0) + assert len(res) == 0 def test_fetchmany_too_many(self, duckdb_cursor): truth_value = len(duckdb_cursor.execute('select * from integers').fetchall()) duckdb_cursor.execute('select * from integers') res = duckdb_cursor.fetchmany(truth_value * 5) - assert(len(res) == truth_value) + assert len(res) == truth_value assert_result_equal(res) res = duckdb_cursor.fetchmany(2) - assert(len(res) == 0) + assert len(res) == 0 res = duckdb_cursor.fetchmany(3) - assert(len(res) == 0) + assert len(res) == 0 def test_numpy_selection(self, duckdb_cursor): duckdb_cursor.execute('SELECT * FROM integers') @@ -97,9 +98,7 @@ def test_pandas_selection(self, duckdb_cursor, pandas): duckdb_cursor.execute('SELECT * FROM timestamps') result = duckdb_cursor.fetchdf() - df = pandas.DataFrame({ - 't': pandas.to_datetime(['1992-10-03 18:34:45', '2010-01-01 00:00:01', None]) - }) + df = pandas.DataFrame({'t': pandas.to_datetime(['1992-10-03 18:34:45', '2010-01-01 00:00:01', None])}) pandas.testing.assert_frame_equal(result, df) # def test_numpy_creation(self, duckdb_cursor): diff --git a/tools/pythonpkg/tests/fast/api/test_dbapi01.py b/tools/pythonpkg/tests/fast/api/test_dbapi01.py index 8a6a877ad10d..0c34b3373767 100644 --- a/tools/pythonpkg/tests/fast/api/test_dbapi01.py +++ b/tools/pythonpkg/tests/fast/api/test_dbapi01.py @@ -3,18 +3,31 @@ import numpy import duckdb + class TestMultipleResultSets(object): def test_regular_selection(self, duckdb_cursor): duckdb_cursor.execute('SELECT * FROM integers') duckdb_cursor.execute('SELECT * FROM integers') result = duckdb_cursor.fetchall() - assert result == [(0,), (1,), (2,), (3,), (4,), (5,), (6,), (7,), (8,), (9,), (None,)], "Incorrect result returned" + assert result == [ + (0,), + (1,), + (2,), + (3,), + (4,), + (5,), + (6,), + (7,), + (8,), + (9,), + (None,), + ], "Incorrect result 
returned" def test_numpy_selection(self, duckdb_cursor): duckdb_cursor.execute('SELECT * FROM integers') duckdb_cursor.execute('SELECT * FROM integers') result = duckdb_cursor.fetchnumpy() - expected = numpy.ma.masked_array(numpy.arange(11), mask=([False]*10 + [True])) + expected = numpy.ma.masked_array(numpy.arange(11), mask=([False] * 10 + [True])) numpy.testing.assert_array_equal(result['i'], expected) @@ -25,4 +38,4 @@ def test_numpy_materialized(self, duckdb_cursor): cursor.execute('INSERT INTO integers VALUES (0),(1),(2),(3),(4),(5),(6),(7),(8),(9),(NULL)') rel = connection.table("integers") res = rel.aggregate("sum(i)").execute().fetchnumpy() - assert res['sum(i)'][0] == 45 \ No newline at end of file + assert res['sum(i)'][0] == 45 diff --git a/tools/pythonpkg/tests/fast/api/test_dbapi04.py b/tools/pythonpkg/tests/fast/api/test_dbapi04.py index 855ba6e229e4..4e91b9274ecb 100644 --- a/tools/pythonpkg/tests/fast/api/test_dbapi04.py +++ b/tools/pythonpkg/tests/fast/api/test_dbapi04.py @@ -1,8 +1,20 @@ -#simple DB API testcase +# simple DB API testcase class TestSimpleDBAPI(object): def test_regular_selection(self, duckdb_cursor): duckdb_cursor.execute('SELECT * FROM integers') result = duckdb_cursor.fetchall() - assert result == [(0,), (1,), (2,), (3,), (4,), (5,), (6,), (7,), (8,), (9,), (None,)], "Incorrect result returned" + assert result == [ + (0,), + (1,), + (2,), + (3,), + (4,), + (5,), + (6,), + (7,), + (8,), + (9,), + (None,), + ], "Incorrect result returned" diff --git a/tools/pythonpkg/tests/fast/api/test_dbapi05.py b/tools/pythonpkg/tests/fast/api/test_dbapi05.py index 8efcb45e059b..0de217f22e3f 100644 --- a/tools/pythonpkg/tests/fast/api/test_dbapi05.py +++ b/tools/pythonpkg/tests/fast/api/test_dbapi05.py @@ -1,32 +1,40 @@ -#simple DB API testcase +# simple DB API testcase + class TestSimpleDBAPI(object): def test_prepare(self, duckdb_cursor): result = duckdb_cursor.execute('SELECT CAST(? AS INTEGER), CAST(? 
AS INTEGER)', ['42', '84']).fetchall() - assert result == [(42, 84, )], "Incorrect result returned" + assert result == [ + ( + 42, + 84, + ) + ], "Incorrect result returned" c = duckdb_cursor # from python docs - c.execute('''CREATE TABLE stocks - (date text, trans text, symbol text, qty real, price real)''') + c.execute( + '''CREATE TABLE stocks + (date text, trans text, symbol text, qty real, price real)''' + ) c.execute("INSERT INTO stocks VALUES ('2006-01-05','BUY','RHAT',100,35.14)") t = ('RHAT',) result = c.execute('SELECT COUNT(*) FROM stocks WHERE symbol=?', t).fetchone() assert result == (1,) - t = ['RHAT'] result = c.execute('SELECT COUNT(*) FROM stocks WHERE symbol=?', t).fetchone() assert result == (1,) # Larger example that inserts many records at a time - purchases = [('2006-03-28', 'BUY', 'IBM', 1000, 45.00), - ('2006-04-05', 'BUY', 'MSFT', 1000, 72.00), - ('2006-04-06', 'SELL', 'IBM', 500, 53.00), - ] + purchases = [ + ('2006-03-28', 'BUY', 'IBM', 1000, 45.00), + ('2006-04-05', 'BUY', 'MSFT', 1000, 72.00), + ('2006-04-06', 'SELL', 'IBM', 500, 53.00), + ] c.executemany('INSERT INTO stocks VALUES (?,?,?,?,?)', purchases) result = c.execute('SELECT count(*) FROM stocks').fetchone() - assert result == (4, ) + assert result == (4,) diff --git a/tools/pythonpkg/tests/fast/api/test_dbapi07.py b/tools/pythonpkg/tests/fast/api/test_dbapi07.py index 46ca13a759f9..7792b8de7656 100644 --- a/tools/pythonpkg/tests/fast/api/test_dbapi07.py +++ b/tools/pythonpkg/tests/fast/api/test_dbapi07.py @@ -1,8 +1,9 @@ # timestamp ms precision -import numpy +import numpy from datetime import datetime + class TestNumpyTimestampMilliseconds(object): def test_numpy_timestamp(self, duckdb_cursor): res = duckdb_cursor.execute("SELECT TIMESTAMP '2019-11-26 21:11:42.501' as test_time").fetchnumpy() diff --git a/tools/pythonpkg/tests/fast/api/test_dbapi08.py b/tools/pythonpkg/tests/fast/api/test_dbapi08.py index 27fd03ad9dd9..0f4ef7aa06de 100644 --- a/tools/pythonpkg/tests/fast/api/test_dbapi08.py +++ b/tools/pythonpkg/tests/fast/api/test_dbapi08.py @@ -4,6 +4,7 @@ import duckdb from conftest import NumpyPandas, ArrowPandas + class TestType(object): @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_fetchdf(self, pandas): diff --git a/tools/pythonpkg/tests/fast/api/test_dbapi09.py b/tools/pythonpkg/tests/fast/api/test_dbapi09.py index a58997e7607a..02899a3868cc 100644 --- a/tools/pythonpkg/tests/fast/api/test_dbapi09.py +++ b/tools/pythonpkg/tests/fast/api/test_dbapi09.py @@ -4,6 +4,7 @@ import datetime import pandas + class TestNumpyDate(object): def test_fetchall_date(self, duckdb_cursor): res = duckdb_cursor.execute("SELECT DATE '2020-01-10' as test_date").fetchall() diff --git a/tools/pythonpkg/tests/fast/api/test_dbapi10.py b/tools/pythonpkg/tests/fast/api/test_dbapi10.py index 8c1d946ec925..433b663cae42 100644 --- a/tools/pythonpkg/tests/fast/api/test_dbapi10.py +++ b/tools/pythonpkg/tests/fast/api/test_dbapi10.py @@ -18,13 +18,9 @@ class TestCursorDescription(object): ["SELECT union_value(tag := 1) AS union_col", "union_col", "UNION(tag INTEGER)", int], ], ) - def test_description( - self, query, column_name, string_type, real_type, duckdb_cursor - ): + def test_description(self, query, column_name, string_type, real_type, duckdb_cursor): duckdb_cursor.execute(query) - assert duckdb_cursor.description == [ - (column_name, string_type, None, None, None, None, None) - ] + assert duckdb_cursor.description == [(column_name, string_type, None, None, None, None, None)] assert 
isinstance(duckdb_cursor.fetchone()[0], real_type) def test_none_description(self, duckdb_empty_cursor): diff --git a/tools/pythonpkg/tests/fast/api/test_dbapi11.py b/tools/pythonpkg/tests/fast/api/test_dbapi11.py index 243b2c769876..91237b9e0323 100644 --- a/tools/pythonpkg/tests/fast/api/test_dbapi11.py +++ b/tools/pythonpkg/tests/fast/api/test_dbapi11.py @@ -4,13 +4,15 @@ import tempfile import os + def check_exception(f): had_exception = False try: f() except: had_exception = True - assert(had_exception) + assert had_exception + class TestReadOnly(object): def test_readonly(self, duckdb_cursor): @@ -19,7 +21,7 @@ def test_readonly(self, duckdb_cursor): os.remove(db) # this is forbidden - check_exception(lambda :duckdb.connect(":memory:", True)) + check_exception(lambda: duckdb.connect(":memory:", True)) con_rw = duckdb.connect(db, False) con_rw.cursor().execute("create table a (i integer)") @@ -28,7 +30,7 @@ def test_readonly(self, duckdb_cursor): con_ro = duckdb.connect(db, True) con_ro.cursor().execute("select * from a").fetchall() - check_exception(lambda : con_ro.execute("delete from a")) + check_exception(lambda: con_ro.execute("delete from a")) con_ro.close() con_rw = duckdb.connect(db, False) diff --git a/tools/pythonpkg/tests/fast/api/test_dbapi12.py b/tools/pythonpkg/tests/fast/api/test_dbapi12.py index f2f2a2b17af5..62bfa73e56bf 100644 --- a/tools/pythonpkg/tests/fast/api/test_dbapi12.py +++ b/tools/pythonpkg/tests/fast/api/test_dbapi12.py @@ -3,12 +3,21 @@ import os import pandas as pd + class TestRelationApi(object): def test_readonly(self, duckdb_cursor): - test_df = pd.DataFrame.from_dict({"i":[1, 2, 3], "j":["one", "two", "three"]}) + test_df = pd.DataFrame.from_dict({"i": [1, 2, 3], "j": ["one", "two", "three"]}) def test_rel(rel, duckdb_cursor): - res = rel.filter('i < 3').order('j').project('i').union(rel.filter('i > 2').project('i')).join(rel.set_alias('a1'), 'i').project('CAST(i as BIGINT) i, j').order('i') + res = ( + rel.filter('i < 3') + .order('j') + .project('i') + .union(rel.filter('i > 2').project('i')) + .join(rel.set_alias('a1'), 'i') + .project('CAST(i as BIGINT) i, j') + .order('i') + ) pd.testing.assert_frame_equal(res.to_df(), test_df) res3 = duckdb_cursor.from_df(res.to_df()).to_df() pd.testing.assert_frame_equal(res3, test_df) @@ -17,7 +26,7 @@ def test_rel(rel, duckdb_cursor): pd.testing.assert_frame_equal(df_sql.df(), test_df) res2 = res.aggregate('i, count(j) as cj', 'i').order('i') - cmp_df = pd.DataFrame.from_dict({"i":[1, 2, 3], "cj":[1, 1, 1]}) + cmp_df = pd.DataFrame.from_dict({"i": [1, 2, 3], "cj": [1, 1, 1]}) pd.testing.assert_frame_equal(res2.to_df(), cmp_df) duckdb_cursor.execute('DROP TABLE IF EXISTS a2') @@ -40,8 +49,8 @@ def test_rel(rel, duckdb_cursor): rel_a = duckdb_cursor.table('a') rel_v = duckdb_cursor.view('v') - #rel_at = duckdb_cursor.table('at') - #rel_vt = duckdb_cursor.view('vt') + # rel_at = duckdb_cursor.table('at') + # rel_vt = duckdb_cursor.view('vt') rel_df = duckdb_cursor.from_df(test_df) @@ -58,6 +67,5 @@ def test_fromquery(self, duckdb_cursor): # assert duckdb_cursor.from_query('select 45').execute().fetchone()[0] == 45 - # cursor = duckdb.connect().cursor() -# TestRelationApi().test_readonly(cursor) \ No newline at end of file +# TestRelationApi().test_readonly(cursor) diff --git a/tools/pythonpkg/tests/fast/api/test_dbapi13.py b/tools/pythonpkg/tests/fast/api/test_dbapi13.py index 39733e2c4f23..3271c59b5f6f 100644 --- a/tools/pythonpkg/tests/fast/api/test_dbapi13.py +++ 
b/tools/pythonpkg/tests/fast/api/test_dbapi13.py @@ -4,6 +4,7 @@ import datetime import pandas + class TestNumpyTime(object): def test_fetchall_date(self, duckdb_cursor): res = duckdb_cursor.execute("SELECT TIME '13:06:40' as test_time").fetchall() diff --git a/tools/pythonpkg/tests/fast/api/test_dbapi_fetch.py b/tools/pythonpkg/tests/fast/api/test_dbapi_fetch.py index 3f89b9dea9cb..e43cd0558e7d 100644 --- a/tools/pythonpkg/tests/fast/api/test_dbapi_fetch.py +++ b/tools/pythonpkg/tests/fast/api/test_dbapi_fetch.py @@ -1,6 +1,7 @@ import duckdb import pytest + class TestDBApiFetch(object): def test_multiple_fetch_one(self, duckdb_cursor): con = duckdb.connect() diff --git a/tools/pythonpkg/tests/fast/api/test_duckdb_connection.py b/tools/pythonpkg/tests/fast/api/test_duckdb_connection.py index 910627c619d2..20a628e2bb7a 100644 --- a/tools/pythonpkg/tests/fast/api/test_duckdb_connection.py +++ b/tools/pythonpkg/tests/fast/api/test_duckdb_connection.py @@ -2,19 +2,25 @@ import pytest from conftest import NumpyPandas, ArrowPandas + def is_dunder_method(method_name: str) -> bool: - if (len(method_name) < 4): + if len(method_name) < 4: return False return method_name[:2] == '__' and method_name[:-3:-1] == '__' + # This file contains tests for DuckDBPyConnection methods, # wrapped by the 'duckdb' module, to execute with the 'default_connection' class TestDuckDBConnection(object): @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_append(self, pandas): duckdb.execute("Create table integers (i integer)") - df_in = pandas.DataFrame({'numbers': [1,2,3,4,5],}) - duckdb.append('integers',df_in) + df_in = pandas.DataFrame( + { + 'numbers': [1, 2, 3, 4, 5], + } + ) + duckdb.append('integers', df_in) assert duckdb.execute('select count(*) from integers').fetchone()[0] == 5 # cleanup duckdb.execute("drop table integers") @@ -28,7 +34,9 @@ def test_default_connection_from_connect(self): con.sql('select i from connect_default_connect') # not allowed with additional options - with pytest.raises(duckdb.InvalidInputException, match='Default connection fetching is only allowed without additional options'): + with pytest.raises( + duckdb.InvalidInputException, match='Default connection fetching is only allowed without additional options' + ): con = duckdb.connect(':default:', read_only=True) def test_arrow(self): @@ -62,7 +70,7 @@ def test_cursor(self): duckdb.table("tbl") def test_df(self): - ref = [([1,2,3],)] + ref = [([1, 2, 3],)] duckdb.execute("select [1,2,3]") res_df = duckdb.fetch_df() res = duckdb.query("select * from res_df").fetchall() @@ -77,7 +85,7 @@ def test_duplicate(self): dup_conn.table("tbl").fetchall() def test_execute(self): - assert [([4,2],)] == duckdb.execute("select [4,2]").fetchall() + assert [([4, 2],)] == duckdb.execute("select [4,2]").fetchall() def test_executemany(self): # executemany does not keep an open result set @@ -94,9 +102,9 @@ def test_fetch_arrow_table(self): duckdb.execute("Create Table test (a integer)") - for i in range (1024): + for i in range(1024): for j in range(2): - duckdb.execute("Insert Into test values ('"+str(i)+"')") + duckdb.execute("Insert Into test values ('" + str(i) + "')") duckdb.execute("Insert Into test values ('5000')") duckdb.execute("Insert Into test values ('6000')") sql = ''' @@ -114,7 +122,7 @@ def test_fetch_arrow_table(self): duckdb.execute("drop table test") def test_fetch_df(self): - ref = [([1,2,3],)] + ref = [([1, 2, 3],)] duckdb.execute("select [1,2,3]") res_df = duckdb.fetch_df() res = duckdb.query("select * from 
res_df").fetchall() @@ -124,11 +132,11 @@ def test_fetch_df_chunk(self): duckdb.execute("CREATE table t as select range a from range(3000);") query = duckdb.execute("SELECT a FROM t") cur_chunk = query.fetch_df_chunk() - assert(cur_chunk['a'][0] == 0) - assert(len(cur_chunk) == 2048) + assert cur_chunk['a'][0] == 0 + assert len(cur_chunk) == 2048 cur_chunk = query.fetch_df_chunk() - assert(cur_chunk['a'][0] == 2048) - assert(len(cur_chunk) == 952) + assert cur_chunk['a'][0] == 2048 + assert len(cur_chunk) == 952 duckdb.execute("DROP TABLE t") def test_fetch_record_batch(self): @@ -139,13 +147,13 @@ def test_fetch_record_batch(self): duckdb.execute("SELECT a FROM t") record_batch_reader = duckdb.fetch_record_batch(1024) chunk = record_batch_reader.read_all() - assert(len(chunk) == 3000) + assert len(chunk) == 3000 def test_fetchall(self): assert [([1, 2, 3],)] == duckdb.execute("select [1,2,3]").fetchall() def test_fetchdf(self): - ref = [([1,2,3],)] + ref = [([1, 2, 3],)] duckdb.execute("select [1,2,3]") res_df = duckdb.fetchdf() res = duckdb.query("select * from res_df").fetchall() @@ -220,14 +228,14 @@ def temporary_scope(): # Create a connection, we will return this con = duckdb.connect() # Create a dataframe - df = pandas.DataFrame({'a': [1,2,3]}) + df = pandas.DataFrame({'a': [1, 2, 3]}) # The dataframe has to be registered as well # making sure it does not go out of scope con.register("df", df) rel = con.sql('select * from df') con.register("relation", rel) return con - + con = temporary_scope() res = con.sql('select * from relation').fetchall() print(res) diff --git a/tools/pythonpkg/tests/fast/api/test_duckdb_query.py b/tools/pythonpkg/tests/fast/api/test_duckdb_query.py index c8a353c4c7a0..b841df676dc5 100644 --- a/tools/pythonpkg/tests/fast/api/test_duckdb_query.py +++ b/tools/pythonpkg/tests/fast/api/test_duckdb_query.py @@ -1,20 +1,20 @@ - import duckdb import pytest from conftest import NumpyPandas, ArrowPandas from pyduckdb import Value + class TestDuckDBQuery(object): def test_duckdb_query(self, duckdb_cursor): # we can use duckdb.query to run both DDL statements and select statements duckdb.query('create view v1 as select 42 i') rel = duckdb.query('select * from v1') - assert rel.fetchall()[0][0] == 42; + assert rel.fetchall()[0][0] == 42 # also multiple statements duckdb.query('create view v2 as select i*2 j from v1; create view v3 as select j * 2 from v2;') rel = duckdb.query('select * from v3') - assert rel.fetchall()[0][0] == 168; + assert rel.fetchall()[0][0] == 168 # we can run multiple select statements - we get only the last result res = duckdb.query('select 42; select 84;').fetchall() @@ -22,15 +22,17 @@ def test_duckdb_query(self, duckdb_cursor): @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_duckdb_from_query_multiple_statements(self, pandas): - tst_df = pandas.DataFrame({'a':[1,23,3,5]}) + tst_df = pandas.DataFrame({'a': [1, 23, 3, 5]}) - res = duckdb.sql(''' + res = duckdb.sql( + ''' select 42; select * from tst_df union all select * from tst_df; - ''').fetchall() + ''' + ).fetchall() assert res == [(1,), (23,), (3,), (5,), (1,), (23,), (3,), (5,)] def test_duckdb_query_empty_result(self): @@ -41,17 +43,21 @@ def test_duckdb_query_empty_result(self): def test_duckdb_from_query(self, duckdb_cursor): # duckdb.from_query cannot be used to run arbitrary queries - with pytest.raises(duckdb.ParserException, match='duckdb.from_query cannot be used to run arbitrary SQL queries'): + with pytest.raises( + duckdb.ParserException, 
match='duckdb.from_query cannot be used to run arbitrary SQL queries' + ): duckdb.from_query('create view v1 as select 42 i') # ... or multiple select statements - with pytest.raises(duckdb.ParserException, match='duckdb.from_query cannot be used to run arbitrary SQL queries'): + with pytest.raises( + duckdb.ParserException, match='duckdb.from_query cannot be used to run arbitrary SQL queries' + ): duckdb.from_query('select 42; select 84;') def test_named_param(self): con = duckdb.connect() original_res = con.execute( - """ + """ select count(*) FILTER (WHERE i >= $1), sum(i) FILTER (WHERE i < $2), @@ -59,11 +65,11 @@ def test_named_param(self): from range(100) tbl(i) """, - [5, 10] + [5, 10], ).fetchall() res = con.execute( - """ + """ select count(*) FILTER (WHERE i >= $param), sum(i) FILTER (WHERE i < $other_param), @@ -71,48 +77,59 @@ def test_named_param(self): from range(100) tbl(i) """, - { - 'param': 5, - 'other_param': 10 - } + {'param': 5, 'other_param': 10}, ).fetchall() - assert(res == original_res) + assert res == original_res def test_named_param_not_dict(self): con = duckdb.connect() - with pytest.raises(duckdb.InvalidInputException, match="Named parameters found, but param is not of type 'dict'"): + with pytest.raises( + duckdb.InvalidInputException, match="Named parameters found, but param is not of type 'dict'" + ): con.execute("select $name1, $name2, $name3", ['name1', 'name2', 'name3']) def test_named_param_basic(self): con = duckdb.connect() res = con.execute("select $name1, $name2, $name3", {'name1': 5, 'name2': 3, 'name3': 'a'}).fetchall() - assert res == [(5,3,'a'),] + assert res == [ + (5, 3, 'a'), + ] def test_named_param_not_exhaustive(self): con = duckdb.connect() - with pytest.raises(duckdb.InvalidInputException, match="Not all named parameters have been located, missing: name3"): + with pytest.raises( + duckdb.InvalidInputException, match="Not all named parameters have been located, missing: name3" + ): con.execute("select $name1, $name2, $name3", {'name1': 5, 'name2': 3}) def test_named_param_excessive(self): con = duckdb.connect() - with pytest.raises(duckdb.InvalidInputException, match="Named parameters could not be transformed, because query string is missing named parameter 'not_a_named_param'"): + with pytest.raises( + duckdb.InvalidInputException, + match="Named parameters could not be transformed, because query string is missing named parameter 'not_a_named_param'", + ): con.execute("select $name1, $name2, $name3", {'name1': 5, 'name2': 3, 'not_a_named_param': 5}) def test_named_param_not_named(self): con = duckdb.connect() - with pytest.raises(duckdb.InvalidInputException, match="Invalid Input Error: Param is of type 'dict', but no named parameters were found in the query"): + with pytest.raises( + duckdb.InvalidInputException, + match="Invalid Input Error: Param is of type 'dict', but no named parameters were found in the query", + ): con.execute("select $1, $1, $2", {'name1': 5, 'name2': 3}) def test_named_param_mixed(self): con = duckdb.connect() - with pytest.raises(duckdb.NotImplementedException, match="Mixing positional and named parameters is not supported yet"): + with pytest.raises( + duckdb.NotImplementedException, match="Mixing positional and named parameters is not supported yet" + ): con.execute("select $name1, $1, $2", {'name1': 5, 'name2': 3}) def test_named_param_strings_with_dollarsign(self): @@ -128,12 +145,11 @@ def test_named_param_case_insensivity(self): """ select $NaMe1, $NAME2, $name3 """, - { - 'name1': 5, - 'nAmE2': 3, - 
'NAME3': 'a' - }).fetchall() - assert res == [(5,3,'a'),] + {'name1': 5, 'nAmE2': 3, 'NAME3': 'a'}, + ).fetchall() + assert res == [ + (5, 3, 'a'), + ] def test_named_param_keyword(self): con = duckdb.connect() @@ -148,33 +164,21 @@ def test_conversion_from_tuple(self): con = duckdb.connect() # Tuple converts to list - result = con.execute("select $1", [(21,22,42)]).fetchall() + result = con.execute("select $1", [(21, 22, 42)]).fetchall() assert result == [([21, 22, 42],)] # If wrapped in a Value, it can convert to a struct - result = con.execute("select $1", [ - Value( - ('a', 21, True), - {'v1': str, 'v2': int, 'v3': bool} - ) - ]).fetchall() + result = con.execute("select $1", [Value(('a', 21, True), {'v1': str, 'v2': int, 'v3': bool})]).fetchall() assert result == [({'v1': 'a', 'v2': 21, 'v3': True},)] # If the amount of items in the tuple and the children of the struct don't match # we throw an error - with pytest.raises(duckdb.InvalidInputException, match='Tried to create a STRUCT value from a tuple containing 3 elements, but the STRUCT consists of 2 children'): - result = con.execute("select $1", [ - Value( - ('a', 21, True), - {'v1': str, 'v2': int} - ) - ]).fetchall() + with pytest.raises( + duckdb.InvalidInputException, + match='Tried to create a STRUCT value from a tuple containing 3 elements, but the STRUCT consists of 2 children', + ): + result = con.execute("select $1", [Value(('a', 21, True), {'v1': str, 'v2': int})]).fetchall() # If we try to create anything other than a STRUCT or a LIST out of the tuple, we throw an error with pytest.raises(duckdb.InvalidInputException, match="Can't convert tuple to a Value of type VARCHAR"): - result = con.execute("select $1", [ - Value( - (21, 42), - str - ) - ]) \ No newline at end of file + result = con.execute("select $1", [Value((21, 42), str)]) diff --git a/tools/pythonpkg/tests/fast/api/test_explain.py b/tools/pythonpkg/tests/fast/api/test_explain.py index dab6d54bf690..5fe73fd10d5a 100644 --- a/tools/pythonpkg/tests/fast/api/test_explain.py +++ b/tools/pythonpkg/tests/fast/api/test_explain.py @@ -1,45 +1,46 @@ import pytest import duckdb + class TestExplain(object): - def test_explain_basic(self): - res = duckdb.sql('select 42').explain() - assert isinstance(res, str) + def test_explain_basic(self): + res = duckdb.sql('select 42').explain() + assert isinstance(res, str) - def test_explain_standard(self): - res = duckdb.sql('select 42').explain('standard') - assert isinstance(res, str) + def test_explain_standard(self): + res = duckdb.sql('select 42').explain('standard') + assert isinstance(res, str) - res = duckdb.sql('select 42').explain('STANDARD') - assert isinstance(res, str) + res = duckdb.sql('select 42').explain('STANDARD') + assert isinstance(res, str) - res = duckdb.sql('select 42').explain(duckdb.STANDARD) - assert isinstance(res, str) + res = duckdb.sql('select 42').explain(duckdb.STANDARD) + assert isinstance(res, str) - res = duckdb.sql('select 42').explain(duckdb.ExplainType.STANDARD) - assert isinstance(res, str) + res = duckdb.sql('select 42').explain(duckdb.ExplainType.STANDARD) + assert isinstance(res, str) - res = duckdb.sql('select 42').explain(0) - assert isinstance(res, str) + res = duckdb.sql('select 42').explain(0) + assert isinstance(res, str) - def test_explain_analyze(self): - res = duckdb.sql('select 42').explain('analyze') - assert isinstance(res, str) + def test_explain_analyze(self): + res = duckdb.sql('select 42').explain('analyze') + assert isinstance(res, str) - res = duckdb.sql('select 
42').explain('ANALYZE') - assert isinstance(res, str) + res = duckdb.sql('select 42').explain('ANALYZE') + assert isinstance(res, str) - res = duckdb.sql('select 42').explain(duckdb.ANALYZE) - assert isinstance(res, str) + res = duckdb.sql('select 42').explain(duckdb.ANALYZE) + assert isinstance(res, str) - res = duckdb.sql('select 42').explain(duckdb.ExplainType.ANALYZE) - assert isinstance(res, str) + res = duckdb.sql('select 42').explain(duckdb.ExplainType.ANALYZE) + assert isinstance(res, str) - res = duckdb.sql('select 42').explain(1) - assert isinstance(res, str) + res = duckdb.sql('select 42').explain(1) + assert isinstance(res, str) - def test_explain_df(self): - pd = pytest.importorskip("pandas") - df = pd.DataFrame({'a': [42]}) - res = duckdb.sql('select * from df').explain('ANALYZE') - assert isinstance(res, str) + def test_explain_df(self): + pd = pytest.importorskip("pandas") + df = pd.DataFrame({'a': [42]}) + res = duckdb.sql('select * from df').explain('ANALYZE') + assert isinstance(res, str) diff --git a/tools/pythonpkg/tests/fast/api/test_insert_into.py b/tools/pythonpkg/tests/fast/api/test_insert_into.py index 4b16f89981a2..e6d4c6ba1d75 100644 --- a/tools/pythonpkg/tests/fast/api/test_insert_into.py +++ b/tools/pythonpkg/tests/fast/api/test_insert_into.py @@ -2,6 +2,7 @@ from pandas import DataFrame import pytest + class TestInsertInto(object): def test_insert_into_schema(self, duckdb_cursor): # open connection @@ -10,7 +11,7 @@ def test_insert_into_schema(self, duckdb_cursor): con.execute('CREATE TABLE s.t (id INTEGER PRIMARY KEY)') # make relation - df = DataFrame([1],columns=['id']) + df = DataFrame([1], columns=['id']) rel = con.from_df(df) rel.insert_into('s.t') @@ -21,7 +22,7 @@ def test_insert_into_schema(self, duckdb_cursor): with pytest.raises(duckdb.CatalogException): rel.insert_into('t') - #If we add t in the default schema it should work. + # If we add t in the default schema it should work. 
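(For reference, a minimal standalone sketch of the insert_into pattern this test exercises; the schema s, table t, and one-row frame mirror the test itself, and the snippet assumes pandas is installed:

import duckdb
import pandas as pd

con = duckdb.connect()
con.execute('CREATE SCHEMA s')
con.execute('CREATE TABLE s.t (id INTEGER PRIMARY KEY)')
rel = con.from_df(pd.DataFrame([1], columns=['id']))
rel.insert_into('s.t')  # schema-qualified target
assert con.execute('select * from s.t').fetchall() == [(1,)]

An unqualified insert_into('t') resolves against the default schema, which is why the test only expects it to succeed once t has been created there.)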
con.execute('CREATE TABLE t (id INTEGER PRIMARY KEY)') rel.insert_into('t') assert con.execute("select * from t").fetchall() == [(1,)] diff --git a/tools/pythonpkg/tests/fast/api/test_join.py b/tools/pythonpkg/tests/fast/api/test_join.py index 16b2632f0cd8..106e71968ba2 100644 --- a/tools/pythonpkg/tests/fast/api/test_join.py +++ b/tools/pythonpkg/tests/fast/api/test_join.py @@ -1,6 +1,7 @@ import duckdb import pytest + class TestJoin(object): def test_alias_from_sql(self): con = duckdb.connect() diff --git a/tools/pythonpkg/tests/fast/api/test_query_interrupt.py b/tools/pythonpkg/tests/fast/api/test_query_interrupt.py index c28ea7bb9cb1..52089022af1d 100644 --- a/tools/pythonpkg/tests/fast/api/test_query_interrupt.py +++ b/tools/pythonpkg/tests/fast/api/test_query_interrupt.py @@ -5,12 +5,14 @@ import threading import _thread as thread + def send_keyboard_interrupt(): # Wait a little, so we're sure the 'execute' has started time.sleep(0.1) # Send an interrupt to the main thread thread.interrupt_main() + class TestQueryInterruption(object): def test_query_interruption(self): con = duckdb.connect() diff --git a/tools/pythonpkg/tests/fast/api/test_read_csv.py b/tools/pythonpkg/tests/fast/api/test_read_csv.py index 3e60f98aa26a..cc8f94dc54ce 100644 --- a/tools/pythonpkg/tests/fast/api/test_read_csv.py +++ b/tools/pythonpkg/tests/fast/api/test_read_csv.py @@ -6,348 +6,415 @@ import duckdb from io import StringIO, BytesIO + def TestFile(name): - import os - filename = os.path.join(os.path.dirname(os.path.realpath(__file__)),'..','data',name) - return filename + import os + + filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', 'data', name) + return filename + class TestReadCSV(object): - def test_using_connection_wrapper(self): - rel = duckdb.read_csv(TestFile('category.csv')) - res = rel.fetchone() - print(res) - assert res == (1, 'Action', datetime.datetime(2006, 2, 15, 4, 46, 27)) - - def test_using_connection_wrapper_with_keyword(self): - rel = duckdb.read_csv(TestFile('category.csv'), dtype={'category_id': 'string'}) - res = rel.fetchone() - print(res) - assert res == ('1', 'Action', datetime.datetime(2006, 2, 15, 4, 46, 27)) - - def test_no_options(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('category.csv')) - res = rel.fetchone() - print(res) - assert res == (1, 'Action', datetime.datetime(2006, 2, 15, 4, 46, 27)) - - def test_dtype(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('category.csv'), dtype={'category_id': 'string'}) - res = rel.fetchone() - print(res) - assert res == ('1', 'Action', datetime.datetime(2006, 2, 15, 4, 46, 27)) - - def test_dtype_as_list(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('category.csv'), dtype=['string']) - res = rel.fetchone() - print(res) - assert res == ('1', 'Action', datetime.datetime(2006, 2, 15, 4, 46, 27)) - - rel = duckdb_cursor.read_csv(TestFile('category.csv'), dtype=['double']) - res = rel.fetchone() - print(res) - assert res == (1.0, 'Action', datetime.datetime(2006, 2, 15, 4, 46, 27)) - - def test_sep(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('category.csv'), sep=" ") - res = rel.fetchone() - print(res) - assert res == ('1|Action|2006-02-15', datetime.time(4, 46, 27)) - - def test_delimiter(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('category.csv'), delimiter=" ") - res = rel.fetchone() - print(res) - assert res == ('1|Action|2006-02-15', datetime.time(4, 46, 27)) - - def test_delimiter_and_sep(self, duckdb_cursor): - with 
pytest.raises(duckdb.InvalidInputException, match="read_csv takes either 'delimiter' or 'sep', not both"): - rel = duckdb_cursor.read_csv(TestFile('category.csv'), delimiter=" ", sep=" ") - - def test_header_true(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('category.csv'), header=True) - res = rel.fetchone() - print(res) - assert res == (1, 'Action', datetime.datetime(2006, 2, 15, 4, 46, 27)) - - @pytest.mark.skip(reason="Issue #6011 needs to be fixed first, header=False doesn't work correctly") - def test_header_false(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('category.csv'), header=False) - - def test_na_values(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('category.csv'), na_values='Action') - res = rel.fetchone() - print(res) - assert res == (1, None, datetime.datetime(2006, 2, 15, 4, 46, 27)) - - def test_skiprows(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('category.csv'), skiprows=1) - res = rel.fetchone() - print(res) - assert res == (1, 'Action', datetime.datetime(2006, 2, 15, 4, 46, 27)) - - # We want to detect this at bind time - def test_compression_wrong(self, duckdb_cursor): - with pytest.raises(duckdb.Error, match="Input is not a GZIP stream"): - rel = duckdb_cursor.read_csv(TestFile('category.csv'), compression='gzip') - - def test_quotechar(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('unquote_without_delimiter.csv'), quotechar="") - res = rel.fetchone() - print(res) - assert res == ('"AAA"BB',) - - def test_escapechar(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('quote_escape.csv'), escapechar=";") - res = rel.limit(1,1).fetchone() - print(res) - assert res == ('345', 'TEST6', '"text""2""text"') - - def test_encoding_wrong(self, duckdb_cursor): - with pytest.raises(duckdb.BinderException, match="Copy is only supported for UTF-8 encoded files, ENCODING 'UTF-8'"): - rel = duckdb_cursor.read_csv(TestFile('quote_escape.csv'), encoding=";") - - def test_encoding_correct(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('quote_escape.csv'), encoding="UTF-8") - res = rel.limit(1,1).fetchone() - print(res) - assert res == (345, 'TEST6', 'text"2"text') - - def test_parallel_true(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('category.csv'), parallel=True) - res = rel.fetchone() - print(res) - assert res == (1, 'Action', datetime.datetime(2006, 2, 15, 4, 46, 27)) - - def test_parallel_true(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('category.csv'), parallel=False) - res = rel.fetchone() - print(res) - assert res == (1, 'Action', datetime.datetime(2006, 2, 15, 4, 46, 27)) - - def test_date_format_as_datetime(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('datetime.csv'), date_format='%m/%d/%Y') - res = rel.fetchone() - print(res) - assert res == (123, 'TEST2', datetime.time(12, 12, 12), datetime.datetime(2000, 1, 1, 0, 0), datetime.datetime(2000, 1, 1, 12, 12)) - - def test_date_format_as_date(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('datetime.csv'), date_format='%Y-%m-%d') - res = rel.fetchone() - print(res) - assert res == (123, 'TEST2', datetime.time(12, 12, 12), datetime.date(2000, 1, 1), datetime.datetime(2000, 1, 1, 12, 12)) - - def test_timestamp_format(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('datetime.csv'), timestamp_format='%m/%d/%Y') - res = rel.fetchone() - print(res) - assert res == (123, 'TEST2', datetime.time(12, 12, 12), datetime.date(2000, 1, 1), '2000-01-01 
12:12:00') - - def test_sample_size_incorrect(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('problematic.csv'), header=True, sample_size=1) - with pytest.raises(duckdb.InvalidInputException): - # The sniffer couldn't detect that this column contains non-integer values - while True: - res = rel.fetchone() - if res is None: - break - - def test_sample_size_correct(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('problematic.csv'), header=True, sample_size=-1) - res = rel.fetchone() - print(res) - assert res == ('1', '1', '1') - - def test_all_varchar(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('category.csv'), all_varchar=True) - res = rel.fetchone() - print(res) - assert res == ('1', 'Action', '2006-02-15 04:46:27') - - def test_null_padding(self, duckdb_cursor): - - rel = duckdb_cursor.read_csv(TestFile('nullpadding.csv'), null_padding=False) - res = rel.fetchall() - assert res == [('# this file has a bunch of gunk at the top',), ('one,two,three,four',), ('1,a,alice',), ('2,b,bob',)] - - rel = duckdb_cursor.read_csv(TestFile('nullpadding.csv'), null_padding=True) - res = rel.fetchall() - assert res == [(1, 'a', 'alice', None), (2, 'b', 'bob', None)] - - rel = duckdb.read_csv(TestFile('nullpadding.csv'), null_padding=False) - res = rel.fetchall() - assert res == [('# this file has a bunch of gunk at the top',), ('one,two,three,four',), ('1,a,alice',), ('2,b,bob',)] - - rel = duckdb.read_csv(TestFile('nullpadding.csv'), null_padding=True) - res = rel.fetchall() - assert res == [(1, 'a', 'alice', None), (2, 'b', 'bob', None)] - - rel = duckdb_cursor.from_csv_auto(TestFile('nullpadding.csv'), null_padding=False) - res = rel.fetchall() - assert res == [('# this file has a bunch of gunk at the top',), ('one,two,three,four',), ('1,a,alice',), ('2,b,bob',)] - - rel = duckdb_cursor.from_csv_auto(TestFile('nullpadding.csv'), null_padding=True) - res = rel.fetchall() - assert res == [(1, 'a', 'alice', None), (2, 'b', 'bob', None)] - - rel = duckdb.from_csv_auto(TestFile('nullpadding.csv'), null_padding=False) - res = rel.fetchall() - assert res == [('# this file has a bunch of gunk at the top',), ('one,two,three,four',), ('1,a,alice',), ('2,b,bob',)] - - rel = duckdb.from_csv_auto(TestFile('nullpadding.csv'), null_padding=True) - res = rel.fetchall() - assert res == [(1, 'a', 'alice', None), (2, 'b', 'bob', None)] - - def test_normalize_names(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('category.csv'), normalize_names=False) - df = rel.df() - column_names = list(df.columns.values) - # The names are not normalized, so they are capitalized - assert 'CATEGORY_ID' in column_names - - rel = duckdb_cursor.read_csv(TestFile('category.csv'), normalize_names=True) - df = rel.df() - column_names = list(df.columns.values) - # The capitalized names are normalized to lowercase instead - assert 'CATEGORY_ID' not in column_names - - def test_filename(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('category.csv'), filename=False) - df = rel.df() - column_names = list(df.columns.values) - # The filename is not included in the returned columns - assert 'filename' not in column_names - - rel = duckdb_cursor.read_csv(TestFile('category.csv'), filename=True) - df = rel.df() - column_names = list(df.columns.values) - # The filename is included in the returned columns - assert 'filename' in column_names - - def test_read_pathlib_path(self, duckdb_cursor): - pathlib = pytest.importorskip("pathlib") - path = 
pathlib.Path(TestFile('category.csv')) - rel = duckdb_cursor.read_csv(path) - res = rel.fetchone() - print(res) - assert res == (1, 'Action', datetime.datetime(2006, 2, 15, 4, 46, 27)) - - def test_read_filelike(self, duckdb_cursor): - _ = pytest.importorskip("fsspec") - string = StringIO("c1,c2,c3\na,b,c") - res = duckdb_cursor.read_csv(string, header=True).fetchall() - assert res == [('a', 'b', 'c')] - - def test_read_filelike_rel_out_of_scope(self, duckdb_cursor): - _ = pytest.importorskip("fsspec") - def keep_in_scope(): - string = StringIO("c1,c2,c3\na,b,c") - # Create a ReadCSVRelation on a file-like object - # this will add the object to our internal object filesystem - rel = duckdb_cursor.read_csv(string, header=True) - # The file-like object will still exist, so we can execute this later - return rel - - def close_scope(): - string = StringIO("c1,c2,c3\na,b,c") - # Create a ReadCSVRelation on a file-like object - # this will add the object to our internal object filesystem - res = duckdb_cursor.read_csv(string, header=True).fetchall() - # When the relation goes out of scope - we delete the file-like object from our filesystem - return res - - relation = keep_in_scope() - res = relation.fetchall() - - res2 = close_scope() - assert res == res2 - - def test_filelike_bytesio(self, duckdb_cursor): - _ = pytest.importorskip("fsspec") - string = BytesIO(b"c1,c2,c3\na,b,c") - res = duckdb_cursor.read_csv(string, header=True).fetchall() - assert res == [('a', 'b', 'c')] - - def test_filelike_exception(self, duckdb_cursor): - _ = pytest.importorskip("fsspec") - class ReadError: - def __init__(self): - pass - def read(self, amount): - raise ValueError(amount) - def seek(self, loc): - return 0 - - class SeekError: - def __init__(self): - pass - def read(self, amount): - return b'test' - def seek(self, loc): - raise ValueError(loc) - - obj = ReadError() - with pytest.raises(ValueError): - res = duckdb_cursor.read_csv(obj, header=True).fetchall() - - obj = SeekError() - with pytest.raises(ValueError): - res = duckdb_cursor.read_csv(obj, header=True).fetchall() - - def test_filelike_custom(self, duckdb_cursor): - _ = pytest.importorskip("fsspec") - class CustomIO: - def __init__(self): - self.loc = 0 - pass - def seek(self, loc): - self.loc = loc - return loc - def read(self, amount): - out = b"c1,c2,c3\na,b,c"[self.loc : self.loc + amount : 1] - self.loc += amount - return out - - obj = CustomIO() - res = duckdb_cursor.read_csv(obj, header=True).fetchall() - assert res == [('a', 'b', 'c')] - - def test_filelike_non_readable(self, duckdb_cursor): - _ = pytest.importorskip("fsspec") - obj = 5; - with pytest.raises(ValueError, match="Can not read from a non file-like object"): - res = duckdb_cursor.read_csv(obj, header=True).fetchall() - - def test_filelike_none(self, duckdb_cursor): - _ = pytest.importorskip("fsspec") - obj = None; - with pytest.raises(ValueError, match="Can not read from a non file-like object"): - res = duckdb_cursor.read_csv(obj, header=True).fetchall() - - def test_internal_object_filesystem_cleanup(self, duckdb_cursor): - _ = pytest.importorskip("fsspec") - class CountedObject(StringIO): - instance_count = 0 - def __init__(self, str): - CountedObject.instance_count += 1 - super().__init__(str) - def __del__(self): - CountedObject.instance_count -= 1 - - def scoped_objects(duckdb_cursor): - obj = CountedObject("a,b,c") - rel1 = duckdb_cursor.read_csv(obj) - assert rel1.fetchall() == [('a','b','c',)] - assert CountedObject.instance_count == 1 - - obj = CountedObject("a,b,c") - 
rel2 = duckdb_cursor.read_csv(obj) - assert rel2.fetchall() == [('a','b','c',)] - assert CountedObject.instance_count == 2 - - obj = CountedObject("a,b,c") - rel3 = duckdb_cursor.read_csv(obj) - assert rel3.fetchall() == [('a','b','c',)] - assert CountedObject.instance_count == 3 - assert CountedObject.instance_count == 0 - scoped_objects(duckdb_cursor) - assert CountedObject.instance_count == 0 + def test_using_connection_wrapper(self): + rel = duckdb.read_csv(TestFile('category.csv')) + res = rel.fetchone() + print(res) + assert res == (1, 'Action', datetime.datetime(2006, 2, 15, 4, 46, 27)) + + def test_using_connection_wrapper_with_keyword(self): + rel = duckdb.read_csv(TestFile('category.csv'), dtype={'category_id': 'string'}) + res = rel.fetchone() + print(res) + assert res == ('1', 'Action', datetime.datetime(2006, 2, 15, 4, 46, 27)) + + def test_no_options(self, duckdb_cursor): + rel = duckdb_cursor.read_csv(TestFile('category.csv')) + res = rel.fetchone() + print(res) + assert res == (1, 'Action', datetime.datetime(2006, 2, 15, 4, 46, 27)) + + def test_dtype(self, duckdb_cursor): + rel = duckdb_cursor.read_csv(TestFile('category.csv'), dtype={'category_id': 'string'}) + res = rel.fetchone() + print(res) + assert res == ('1', 'Action', datetime.datetime(2006, 2, 15, 4, 46, 27)) + + def test_dtype_as_list(self, duckdb_cursor): + rel = duckdb_cursor.read_csv(TestFile('category.csv'), dtype=['string']) + res = rel.fetchone() + print(res) + assert res == ('1', 'Action', datetime.datetime(2006, 2, 15, 4, 46, 27)) + + rel = duckdb_cursor.read_csv(TestFile('category.csv'), dtype=['double']) + res = rel.fetchone() + print(res) + assert res == (1.0, 'Action', datetime.datetime(2006, 2, 15, 4, 46, 27)) + + def test_sep(self, duckdb_cursor): + rel = duckdb_cursor.read_csv(TestFile('category.csv'), sep=" ") + res = rel.fetchone() + print(res) + assert res == ('1|Action|2006-02-15', datetime.time(4, 46, 27)) + + def test_delimiter(self, duckdb_cursor): + rel = duckdb_cursor.read_csv(TestFile('category.csv'), delimiter=" ") + res = rel.fetchone() + print(res) + assert res == ('1|Action|2006-02-15', datetime.time(4, 46, 27)) + + def test_delimiter_and_sep(self, duckdb_cursor): + with pytest.raises(duckdb.InvalidInputException, match="read_csv takes either 'delimiter' or 'sep', not both"): + rel = duckdb_cursor.read_csv(TestFile('category.csv'), delimiter=" ", sep=" ") + + def test_header_true(self, duckdb_cursor): + rel = duckdb_cursor.read_csv(TestFile('category.csv'), header=True) + res = rel.fetchone() + print(res) + assert res == (1, 'Action', datetime.datetime(2006, 2, 15, 4, 46, 27)) + + @pytest.mark.skip(reason="Issue #6011 needs to be fixed first, header=False doesn't work correctly") + def test_header_false(self, duckdb_cursor): + rel = duckdb_cursor.read_csv(TestFile('category.csv'), header=False) + + def test_na_values(self, duckdb_cursor): + rel = duckdb_cursor.read_csv(TestFile('category.csv'), na_values='Action') + res = rel.fetchone() + print(res) + assert res == (1, None, datetime.datetime(2006, 2, 15, 4, 46, 27)) + + def test_skiprows(self, duckdb_cursor): + rel = duckdb_cursor.read_csv(TestFile('category.csv'), skiprows=1) + res = rel.fetchone() + print(res) + assert res == (1, 'Action', datetime.datetime(2006, 2, 15, 4, 46, 27)) + + # We want to detect this at bind time + def test_compression_wrong(self, duckdb_cursor): + with pytest.raises(duckdb.Error, match="Input is not a GZIP stream"): + rel = duckdb_cursor.read_csv(TestFile('category.csv'), compression='gzip') + + def 
test_quotechar(self, duckdb_cursor): + rel = duckdb_cursor.read_csv(TestFile('unquote_without_delimiter.csv'), quotechar="") + res = rel.fetchone() + print(res) + assert res == ('"AAA"BB',) + + def test_escapechar(self, duckdb_cursor): + rel = duckdb_cursor.read_csv(TestFile('quote_escape.csv'), escapechar=";") + res = rel.limit(1, 1).fetchone() + print(res) + assert res == ('345', 'TEST6', '"text""2""text"') + + def test_encoding_wrong(self, duckdb_cursor): + with pytest.raises( + duckdb.BinderException, match="Copy is only supported for UTF-8 encoded files, ENCODING 'UTF-8'" + ): + rel = duckdb_cursor.read_csv(TestFile('quote_escape.csv'), encoding=";") + + def test_encoding_correct(self, duckdb_cursor): + rel = duckdb_cursor.read_csv(TestFile('quote_escape.csv'), encoding="UTF-8") + res = rel.limit(1, 1).fetchone() + print(res) + assert res == (345, 'TEST6', 'text"2"text') + + def test_parallel_true(self, duckdb_cursor): + rel = duckdb_cursor.read_csv(TestFile('category.csv'), parallel=True) + res = rel.fetchone() + print(res) + assert res == (1, 'Action', datetime.datetime(2006, 2, 15, 4, 46, 27)) + + def test_parallel_false(self, duckdb_cursor): + rel = duckdb_cursor.read_csv(TestFile('category.csv'), parallel=False) + res = rel.fetchone() + print(res) + assert res == (1, 'Action', datetime.datetime(2006, 2, 15, 4, 46, 27)) + + def test_date_format_as_datetime(self, duckdb_cursor): + rel = duckdb_cursor.read_csv(TestFile('datetime.csv'), date_format='%m/%d/%Y') + res = rel.fetchone() + print(res) + assert res == ( + 123, + 'TEST2', + datetime.time(12, 12, 12), + datetime.datetime(2000, 1, 1, 0, 0), + datetime.datetime(2000, 1, 1, 12, 12), + ) + + def test_date_format_as_date(self, duckdb_cursor): + rel = duckdb_cursor.read_csv(TestFile('datetime.csv'), date_format='%Y-%m-%d') + res = rel.fetchone() + print(res) + assert res == ( + 123, + 'TEST2', + datetime.time(12, 12, 12), + datetime.date(2000, 1, 1), + datetime.datetime(2000, 1, 1, 12, 12), + ) + + def test_timestamp_format(self, duckdb_cursor): + rel = duckdb_cursor.read_csv(TestFile('datetime.csv'), timestamp_format='%m/%d/%Y') + res = rel.fetchone() + print(res) + assert res == (123, 'TEST2', datetime.time(12, 12, 12), datetime.date(2000, 1, 1), '2000-01-01 12:12:00') + + def test_sample_size_incorrect(self, duckdb_cursor): + rel = duckdb_cursor.read_csv(TestFile('problematic.csv'), header=True, sample_size=1) + with pytest.raises(duckdb.InvalidInputException): + # The sniffer couldn't detect that this column contains non-integer values + while True: + res = rel.fetchone() + if res is None: + break + + def test_sample_size_correct(self, duckdb_cursor): + rel = duckdb_cursor.read_csv(TestFile('problematic.csv'), header=True, sample_size=-1) + res = rel.fetchone() + print(res) + assert res == ('1', '1', '1') + + def test_all_varchar(self, duckdb_cursor): + rel = duckdb_cursor.read_csv(TestFile('category.csv'), all_varchar=True) + res = rel.fetchone() + print(res) + assert res == ('1', 'Action', '2006-02-15 04:46:27') + + def test_null_padding(self, duckdb_cursor): + rel = duckdb_cursor.read_csv(TestFile('nullpadding.csv'), null_padding=False) + res = rel.fetchall() + assert res == [ + ('# this file has a bunch of gunk at the top',), + ('one,two,three,four',), + ('1,a,alice',), + ('2,b,bob',), + ] + + rel = duckdb_cursor.read_csv(TestFile('nullpadding.csv'), null_padding=True) + res = rel.fetchall() + assert res == [(1, 'a', 'alice', None), (2, 'b', 'bob', None)] + + rel = duckdb.read_csv(TestFile('nullpadding.csv'), 
null_padding=False) + res = rel.fetchall() + assert res == [ + ('# this file has a bunch of gunk at the top',), + ('one,two,three,four',), + ('1,a,alice',), + ('2,b,bob',), + ] + + rel = duckdb.read_csv(TestFile('nullpadding.csv'), null_padding=True) + res = rel.fetchall() + assert res == [(1, 'a', 'alice', None), (2, 'b', 'bob', None)] + + rel = duckdb_cursor.from_csv_auto(TestFile('nullpadding.csv'), null_padding=False) + res = rel.fetchall() + assert res == [ + ('# this file has a bunch of gunk at the top',), + ('one,two,three,four',), + ('1,a,alice',), + ('2,b,bob',), + ] + + rel = duckdb_cursor.from_csv_auto(TestFile('nullpadding.csv'), null_padding=True) + res = rel.fetchall() + assert res == [(1, 'a', 'alice', None), (2, 'b', 'bob', None)] + + rel = duckdb.from_csv_auto(TestFile('nullpadding.csv'), null_padding=False) + res = rel.fetchall() + assert res == [ + ('# this file has a bunch of gunk at the top',), + ('one,two,three,four',), + ('1,a,alice',), + ('2,b,bob',), + ] + + rel = duckdb.from_csv_auto(TestFile('nullpadding.csv'), null_padding=True) + res = rel.fetchall() + assert res == [(1, 'a', 'alice', None), (2, 'b', 'bob', None)] + + def test_normalize_names(self, duckdb_cursor): + rel = duckdb_cursor.read_csv(TestFile('category.csv'), normalize_names=False) + df = rel.df() + column_names = list(df.columns.values) + # The names are not normalized, so they are capitalized + assert 'CATEGORY_ID' in column_names + + rel = duckdb_cursor.read_csv(TestFile('category.csv'), normalize_names=True) + df = rel.df() + column_names = list(df.columns.values) + # The capitalized names are normalized to lowercase instead + assert 'CATEGORY_ID' not in column_names + + def test_filename(self, duckdb_cursor): + rel = duckdb_cursor.read_csv(TestFile('category.csv'), filename=False) + df = rel.df() + column_names = list(df.columns.values) + # The filename is not included in the returned columns + assert 'filename' not in column_names + + rel = duckdb_cursor.read_csv(TestFile('category.csv'), filename=True) + df = rel.df() + column_names = list(df.columns.values) + # The filename is included in the returned columns + assert 'filename' in column_names + + def test_read_pathlib_path(self, duckdb_cursor): + pathlib = pytest.importorskip("pathlib") + path = pathlib.Path(TestFile('category.csv')) + rel = duckdb_cursor.read_csv(path) + res = rel.fetchone() + print(res) + assert res == (1, 'Action', datetime.datetime(2006, 2, 15, 4, 46, 27)) + + def test_read_filelike(self, duckdb_cursor): + _ = pytest.importorskip("fsspec") + string = StringIO("c1,c2,c3\na,b,c") + res = duckdb_cursor.read_csv(string, header=True).fetchall() + assert res == [('a', 'b', 'c')] + + def test_read_filelike_rel_out_of_scope(self, duckdb_cursor): + _ = pytest.importorskip("fsspec") + + def keep_in_scope(): + string = StringIO("c1,c2,c3\na,b,c") + # Create a ReadCSVRelation on a file-like object + # this will add the object to our internal object filesystem + rel = duckdb_cursor.read_csv(string, header=True) + # The file-like object will still exist, so we can execute this later + return rel + + def close_scope(): + string = StringIO("c1,c2,c3\na,b,c") + # Create a ReadCSVRelation on a file-like object + # this will add the object to our internal object filesystem + res = duckdb_cursor.read_csv(string, header=True).fetchall() + # When the relation goes out of scope - we delete the file-like object from our filesystem + return res + + relation = keep_in_scope() + res = relation.fetchall() + + res2 = close_scope() + assert res == 
res2 + + def test_filelike_bytesio(self, duckdb_cursor): + _ = pytest.importorskip("fsspec") + string = BytesIO(b"c1,c2,c3\na,b,c") + res = duckdb_cursor.read_csv(string, header=True).fetchall() + assert res == [('a', 'b', 'c')] + + def test_filelike_exception(self, duckdb_cursor): + _ = pytest.importorskip("fsspec") + + class ReadError: + def __init__(self): + pass + + def read(self, amount): + raise ValueError(amount) + + def seek(self, loc): + return 0 + + class SeekError: + def __init__(self): + pass + + def read(self, amount): + return b'test' + + def seek(self, loc): + raise ValueError(loc) + + obj = ReadError() + with pytest.raises(ValueError): + res = duckdb_cursor.read_csv(obj, header=True).fetchall() + + obj = SeekError() + with pytest.raises(ValueError): + res = duckdb_cursor.read_csv(obj, header=True).fetchall() + + def test_filelike_custom(self, duckdb_cursor): + _ = pytest.importorskip("fsspec") + + class CustomIO: + def __init__(self): + self.loc = 0 + pass + + def seek(self, loc): + self.loc = loc + return loc + + def read(self, amount): + out = b"c1,c2,c3\na,b,c"[self.loc : self.loc + amount : 1] + self.loc += amount + return out + + obj = CustomIO() + res = duckdb_cursor.read_csv(obj, header=True).fetchall() + assert res == [('a', 'b', 'c')] + + def test_filelike_non_readable(self, duckdb_cursor): + _ = pytest.importorskip("fsspec") + obj = 5 + with pytest.raises(ValueError, match="Can not read from a non file-like object"): + res = duckdb_cursor.read_csv(obj, header=True).fetchall() + + def test_filelike_none(self, duckdb_cursor): + _ = pytest.importorskip("fsspec") + obj = None + with pytest.raises(ValueError, match="Can not read from a non file-like object"): + res = duckdb_cursor.read_csv(obj, header=True).fetchall() + + def test_internal_object_filesystem_cleanup(self, duckdb_cursor): + _ = pytest.importorskip("fsspec") + + class CountedObject(StringIO): + instance_count = 0 + + def __init__(self, str): + CountedObject.instance_count += 1 + super().__init__(str) + + def __del__(self): + CountedObject.instance_count -= 1 + + def scoped_objects(duckdb_cursor): + obj = CountedObject("a,b,c") + rel1 = duckdb_cursor.read_csv(obj) + assert rel1.fetchall() == [ + ( + 'a', + 'b', + 'c', + ) + ] + assert CountedObject.instance_count == 1 + + obj = CountedObject("a,b,c") + rel2 = duckdb_cursor.read_csv(obj) + assert rel2.fetchall() == [ + ( + 'a', + 'b', + 'c', + ) + ] + assert CountedObject.instance_count == 2 + + obj = CountedObject("a,b,c") + rel3 = duckdb_cursor.read_csv(obj) + assert rel3.fetchall() == [ + ( + 'a', + 'b', + 'c', + ) + ] + assert CountedObject.instance_count == 3 + + assert CountedObject.instance_count == 0 + scoped_objects(duckdb_cursor) + assert CountedObject.instance_count == 0 diff --git a/tools/pythonpkg/tests/fast/api/test_streaming_result.py b/tools/pythonpkg/tests/fast/api/test_streaming_result.py index b2ff40ff1aa3..80c689a899a3 100644 --- a/tools/pythonpkg/tests/fast/api/test_streaming_result.py +++ b/tools/pythonpkg/tests/fast/api/test_streaming_result.py @@ -1,56 +1,63 @@ import pytest import duckdb + class TestStreamingResult(object): - def test_fetch_one(self): - # fetch one - res = duckdb.sql('SELECT * FROM range(100000)') - result = [] - while len(result) < 5000: - tpl = res.fetchone() - result.append(tpl[0]) - assert result == list(range(5000)) + def test_fetch_one(self): + # fetch one + res = duckdb.sql('SELECT * FROM range(100000)') + result = [] + while len(result) < 5000: + tpl = res.fetchone() + result.append(tpl[0]) + assert result 
== list(range(5000)) - # fetch one with error - res = duckdb.sql("SELECT CASE WHEN i < 10000 THEN i ELSE concat('hello', i::VARCHAR)::INT END FROM range(100000) t(i)") - with pytest.raises(duckdb.ConversionException): - while True: - tpl = res.fetchone() - if tpl is None: - break + # fetch one with error + res = duckdb.sql( + "SELECT CASE WHEN i < 10000 THEN i ELSE concat('hello', i::VARCHAR)::INT END FROM range(100000) t(i)" + ) + with pytest.raises(duckdb.ConversionException): + while True: + tpl = res.fetchone() + if tpl is None: + break - def test_fetch_many(self): - # fetch many - res = duckdb.sql('SELECT * FROM range(100000)') - result = [] - while len(result) < 5000: - tpl = res.fetchmany(10) - result += [x[0] for x in tpl] - assert result == list(range(5000)) + def test_fetch_many(self): + # fetch many + res = duckdb.sql('SELECT * FROM range(100000)') + result = [] + while len(result) < 5000: + tpl = res.fetchmany(10) + result += [x[0] for x in tpl] + assert result == list(range(5000)) - # fetch many with error - res = duckdb.sql("SELECT CASE WHEN i < 10000 THEN i ELSE concat('hello', i::VARCHAR)::INT END FROM range(100000) t(i)") - with pytest.raises(duckdb.ConversionException): - while True: - tpl = res.fetchmany(10) - if tpl is None: - break + # fetch many with error + res = duckdb.sql( + "SELECT CASE WHEN i < 10000 THEN i ELSE concat('hello', i::VARCHAR)::INT END FROM range(100000) t(i)" + ) + with pytest.raises(duckdb.ConversionException): + while True: + tpl = res.fetchmany(10) + if tpl is None: + break - def test_record_batch_reader(self): - pytest.importorskip("pyarrow") - pytest.importorskip("pyarrow.dataset") - # record batch reader - res = duckdb.sql('SELECT * FROM range(100000) t(i)') - reader = res.fetch_arrow_reader(batch_size=16_384) - result = [] - for batch in reader: - result += batch.to_pydict()['i'] - assert result == list(range(100000)) + def test_record_batch_reader(self): + pytest.importorskip("pyarrow") + pytest.importorskip("pyarrow.dataset") + # record batch reader + res = duckdb.sql('SELECT * FROM range(100000) t(i)') + reader = res.fetch_arrow_reader(batch_size=16_384) + result = [] + for batch in reader: + result += batch.to_pydict()['i'] + assert result == list(range(100000)) - # record batch reader with error - res = duckdb.sql("SELECT CASE WHEN i < 10000 THEN i ELSE concat('hello', i::VARCHAR)::INT END FROM range(100000) t(i)") - reader = res.fetch_arrow_reader(batch_size=16_384) - with pytest.raises(OSError): - result = [] - for batch in reader: - result += batch.to_pydict()['i'] + # record batch reader with error + res = duckdb.sql( + "SELECT CASE WHEN i < 10000 THEN i ELSE concat('hello', i::VARCHAR)::INT END FROM range(100000) t(i)" + ) + reader = res.fetch_arrow_reader(batch_size=16_384) + with pytest.raises(OSError): + result = [] + for batch in reader: + result += batch.to_pydict()['i'] diff --git a/tools/pythonpkg/tests/fast/api/test_to_csv.py b/tools/pythonpkg/tests/fast/api/test_to_csv.py index 025f13092a5a..8230df6aa60d 100644 --- a/tools/pythonpkg/tests/fast/api/test_to_csv.py +++ b/tools/pythonpkg/tests/fast/api/test_to_csv.py @@ -8,21 +8,23 @@ import pytest from conftest import NumpyPandas, ArrowPandas + class TestToCSV(object): @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_basic_to_csv(self, pandas): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) - df = pandas.DataFrame({'a': [5,3,23,2], 'b': [45,234,234,2]}) + df = pandas.DataFrame({'a': [5, 3, 23, 2], 'b': [45, 
234, 234, 2]}) rel = duckdb.from_df(df) rel.to_csv(temp_file_name) csv_rel = duckdb.read_csv(temp_file_name) assert rel.execute().fetchall() == csv_rel.execute().fetchall() + @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_to_csv_sep(self, pandas): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) - df = pandas.DataFrame({'a': [5,3,23,2], 'b': [45,234,234,2]}) + df = pandas.DataFrame({'a': [5, 3, 23, 2], 'b': [45, 234, 234, 2]}) rel = duckdb.from_df(df) rel.to_csv(temp_file_name, sep=',') @@ -33,7 +35,7 @@ def test_to_csv_sep(self, pandas): @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_to_csv_na_rep(self, pandas): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) - df = pandas.DataFrame({'a': [5,None,23,2], 'b': [45,234,234,2]}) + df = pandas.DataFrame({'a': [5, None, 23, 2], 'b': [45, 234, 234, 2]}) rel = duckdb.from_df(df) rel.to_csv(temp_file_name, na_rep="test") @@ -44,7 +46,7 @@ def test_to_csv_na_rep(self, pandas): @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_to_csv_header(self, pandas): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) - df = pandas.DataFrame({'a': [5,None,23,2], 'b': [45,234,234,2]}) + df = pandas.DataFrame({'a': [5, None, 23, 2], 'b': [45, 234, 234, 2]}) rel = duckdb.from_df(df) rel.to_csv(temp_file_name, header=True) @@ -55,7 +57,7 @@ def test_to_csv_header(self, pandas): @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_to_csv_quotechar(self, pandas): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) - df = pandas.DataFrame({'a': ["\'a,b,c\'",None,"hello","bye"], 'b': [45,234,234,2]}) + df = pandas.DataFrame({'a': ["\'a,b,c\'", None, "hello", "bye"], 'b': [45, 234, 234, 2]}) rel = duckdb.from_df(df) rel.to_csv(temp_file_name, quotechar='\'', sep=',') @@ -84,9 +86,7 @@ def test_to_csv_date_format(self, pandas): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) df = pandas.DataFrame(tm.getTimeSeriesData()) dt_index = df.index - df = pandas.DataFrame( - {"A": dt_index, "B": dt_index.shift(1)}, index=dt_index - ) + df = pandas.DataFrame({"A": dt_index, "B": dt_index.shift(1)}, index=dt_index) rel = duckdb.from_df(df) rel.to_csv(temp_file_name, date_format="%Y%m%d") @@ -98,9 +98,7 @@ def test_to_csv_date_format(self, pandas): def test_to_csv_timestamp_format(self, pandas): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) data = [datetime.time(hour=23, minute=1, second=34, microsecond=234345)] - df = pandas.DataFrame( - {'0': pandas.Series(data=data, dtype='object')} - ) + df = pandas.DataFrame({'0': pandas.Series(data=data, dtype='object')}) rel = duckdb.from_df(df) rel.to_csv(temp_file_name, timestamp_format='%m/%d/%Y') @@ -111,9 +109,7 @@ def test_to_csv_timestamp_format(self, pandas): @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_to_csv_quoting_off(self, pandas): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) - df = pandas.DataFrame( - {'a': ['string1', 'string2', 'string3']} - ) + df = pandas.DataFrame({'a': ['string1', 'string2', 'string3']}) rel = duckdb.from_df(df) rel.to_csv(temp_file_name, quoting=None) @@ -123,9 +119,7 @@ def test_to_csv_quoting_off(self, pandas): @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def 
test_to_csv_quoting_on(self, pandas): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) - df = pandas.DataFrame( - {'a': ['string1', 'string2', 'string3']} - ) + df = pandas.DataFrame({'a': ['string1', 'string2', 'string3']}) rel = duckdb.from_df(df) rel.to_csv(temp_file_name, quoting="force") @@ -135,9 +129,7 @@ def test_to_csv_quoting_on(self, pandas): @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_to_csv_quoting_quote_all(self, pandas): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) - df = pandas.DataFrame( - {'a': ['string1', 'string2', 'string3']} - ) + df = pandas.DataFrame({'a': ['string1', 'string2', 'string3']}) rel = duckdb.from_df(df) rel.to_csv(temp_file_name, quoting=csv.QUOTE_ALL) @@ -147,19 +139,17 @@ def test_to_csv_quoting_quote_all(self, pandas): @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_to_csv_encoding_incorrect(self, pandas): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) - df = pandas.DataFrame( - {'a': ['string1', 'string2', 'string3']} - ) + df = pandas.DataFrame({'a': ['string1', 'string2', 'string3']}) rel = duckdb.from_df(df) - with pytest.raises(duckdb.InvalidInputException, match="Invalid Input Error: The only supported encoding option is 'UTF8"): + with pytest.raises( + duckdb.InvalidInputException, match="Invalid Input Error: The only supported encoding option is 'UTF8" + ): rel.to_csv(temp_file_name, encoding="nope") @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_to_csv_encoding_correct(self, pandas): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) - df = pandas.DataFrame( - {'a': ['string1', 'string2', 'string3']} - ) + df = pandas.DataFrame({'a': ['string1', 'string2', 'string3']}) rel = duckdb.from_df(df) rel.to_csv(temp_file_name, encoding="UTF-8") csv_rel = duckdb.read_csv(temp_file_name) @@ -168,9 +158,7 @@ def test_to_csv_encoding_correct(self, pandas): @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_compression_gzip(self, pandas): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) - df = pandas.DataFrame( - {'a': ['string1', 'string2', 'string3']} - ) + df = pandas.DataFrame({'a': ['string1', 'string2', 'string3']}) rel = duckdb.from_df(df) rel.to_csv(temp_file_name, compression="gzip") csv_rel = duckdb.read_csv(temp_file_name, compression="gzip") diff --git a/tools/pythonpkg/tests/fast/api/test_to_parquet.py b/tools/pythonpkg/tests/fast/api/test_to_parquet.py index ab6d384b2583..9eace0f31fe8 100644 --- a/tools/pythonpkg/tests/fast/api/test_to_parquet.py +++ b/tools/pythonpkg/tests/fast/api/test_to_parquet.py @@ -8,10 +8,11 @@ import csv import pytest + class TestToParquet(object): def test_basic_to_parquet(self): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) - df = pd.DataFrame({'a': [5,3,23,2], 'b': [45,234,234,2]}) + df = pd.DataFrame({'a': [5, 3, 23, 2], 'b': [45, 234, 234, 2]}) rel = duckdb.from_df(df) rel.to_parquet(temp_file_name) @@ -21,9 +22,7 @@ def test_basic_to_parquet(self): def test_compression_gzip(self): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) - df = pd.DataFrame( - {'a': ['string1', 'string2', 'string3']} - ) + df = pd.DataFrame({'a': ['string1', 'string2', 'string3']}) rel = duckdb.from_df(df) rel.to_parquet(temp_file_name, 
compression="gzip") csv_rel = duckdb.read_parquet(temp_file_name, compression="gzip") diff --git a/tools/pythonpkg/tests/fast/api/test_with_propagating_exceptions.py b/tools/pythonpkg/tests/fast/api/test_with_propagating_exceptions.py index c9fa98d4d1d6..e9cfb3c0ccfe 100644 --- a/tools/pythonpkg/tests/fast/api/test_with_propagating_exceptions.py +++ b/tools/pythonpkg/tests/fast/api/test_with_propagating_exceptions.py @@ -1,8 +1,8 @@ import pytest import duckdb -class TestWithPropagatingExceptions(object): +class TestWithPropagatingExceptions(object): def test_with(self): # Should propagate exception raised in the 'with duckdb.connect() ..' with pytest.raises(duckdb.ParserException, match="syntax error at or near *"): diff --git a/tools/pythonpkg/tests/fast/arrow/parquet_write_roundtrip.py b/tools/pythonpkg/tests/fast/arrow/parquet_write_roundtrip.py index ca88abd39b0e..093040c040db 100644 --- a/tools/pythonpkg/tests/fast/arrow/parquet_write_roundtrip.py +++ b/tools/pythonpkg/tests/fast/arrow/parquet_write_roundtrip.py @@ -4,8 +4,10 @@ import numpy import pandas import datetime + pa = pytest.importorskip("pyarrow") + def parquet_types_test(type_list): temp = tempfile.NamedTemporaryFile() temp_name = temp.name @@ -15,9 +17,7 @@ def parquet_types_test(type_list): sql_type = type_pair[2] add_cast = len(type_pair) > 3 and type_pair[3] add_sql_cast = len(type_pair) > 4 and type_pair[4] - df = pandas.DataFrame.from_dict({ - 'val': numpy.array(value_list, dtype=numpy_type) - }) + df = pandas.DataFrame.from_dict({'val': numpy.array(value_list, dtype=numpy_type)}) duckdb_cursor = duckdb.connect() duckdb_cursor.execute(f"CREATE TABLE tmp AS SELECT val::{sql_type} val FROM df") duckdb_cursor.execute(f"COPY tmp TO '{temp_name}' (FORMAT PARQUET)") @@ -40,14 +40,14 @@ def parquet_types_test(type_list): class TestParquetRoundtrip(object): def test_roundtrip_numeric(self, duckdb_cursor): type_list = [ - ([-2**7, 0, 2**7-1], numpy.int8, 'TINYINT'), - ([-2**15, 0, 2**15-1], numpy.int16, 'SMALLINT'), - ([-2**31, 0, 2**31-1], numpy.int32, 'INTEGER'), - ([-2**63, 0, 2**63-1], numpy.int64, 'BIGINT'), - ([0, 42, 2**8-1], numpy.uint8, 'UTINYINT'), - ([0, 42, 2**16-1], numpy.uint16, 'USMALLINT'), - ([0, 42, 2**32-1], numpy.uint32, 'UINTEGER', False, True), - ([0, 42, 2**64-1], numpy.uint64, 'UBIGINT'), + ([-(2**7), 0, 2**7 - 1], numpy.int8, 'TINYINT'), + ([-(2**15), 0, 2**15 - 1], numpy.int16, 'SMALLINT'), + ([-(2**31), 0, 2**31 - 1], numpy.int32, 'INTEGER'), + ([-(2**63), 0, 2**63 - 1], numpy.int64, 'BIGINT'), + ([0, 42, 2**8 - 1], numpy.uint8, 'UTINYINT'), + ([0, 42, 2**16 - 1], numpy.uint16, 'USMALLINT'), + ([0, 42, 2**32 - 1], numpy.uint32, 'UINTEGER', False, True), + ([0, 42, 2**64 - 1], numpy.uint64, 'UBIGINT'), ([0, 0.5, -0.5], numpy.float32, 'REAL'), ([0, 0.5, -0.5], numpy.float64, 'DOUBLE'), ] @@ -58,25 +58,18 @@ def test_roundtrip_timestamp(self, duckdb_cursor): datetime.datetime(2018, 3, 10, 11, 17, 54), datetime.datetime(1900, 12, 12, 23, 48, 42), None, - datetime.datetime(1992, 7, 9, 7, 5, 33) + datetime.datetime(1992, 7, 9, 7, 5, 33), ] type_list = [ (date_time_list, 'datetime64[ns]', 'TIMESTAMP_NS'), (date_time_list, 'datetime64[us]', 'TIMESTAMP'), (date_time_list, 'datetime64[ms]', 'TIMESTAMP_MS'), (date_time_list, 'datetime64[s]', 'TIMESTAMP_S'), - (date_time_list, 'datetime64[D]', 'DATE', True) + (date_time_list, 'datetime64[D]', 'DATE', True), ] parquet_types_test(type_list) def test_roundtrip_varchar(self, duckdb_cursor): - varchar_list = [ - 'hello', - 'this is a very long string', - 'hello', - 
None - ] - type_list = [ - (varchar_list, object, 'VARCHAR') - ] + varchar_list = ['hello', 'this is a very long string', 'hello', None] + type_list = [(varchar_list, object, 'VARCHAR')] parquet_types_test(type_list) diff --git a/tools/pythonpkg/tests/fast/arrow/test_2426.py b/tools/pythonpkg/tests/fast/arrow/test_2426.py index 8732a3e8dabf..cdef8da7688a 100644 --- a/tools/pythonpkg/tests/fast/arrow/test_2426.py +++ b/tools/pythonpkg/tests/fast/arrow/test_2426.py @@ -1,22 +1,25 @@ import duckdb import os + try: import pyarrow as pa + can_run = True except: can_run = False + class Test2426(object): - def test_2426(self,duckdb_cursor): + def test_2426(self, duckdb_cursor): if not can_run: return - + con = duckdb.connect() con.execute("Create Table test (a integer)") - for i in range (1024): + for i in range(1024): for j in range(2): - con.execute("Insert Into test values ('"+str(i)+"')") + con.execute("Insert Into test values ('" + str(i) + "')") con.execute("Insert Into test values ('5000')") con.execute("Insert Into test values ('6000')") sql = ''' diff --git a/tools/pythonpkg/tests/fast/arrow/test_6584.py b/tools/pythonpkg/tests/fast/arrow/test_6584.py index 93571968e8d9..f0da385c6c0c 100644 --- a/tools/pythonpkg/tests/fast/arrow/test_6584.py +++ b/tools/pythonpkg/tests/fast/arrow/test_6584.py @@ -4,19 +4,21 @@ pyarrow = pytest.importorskip('pyarrow') -def f(cur, i, data): + +def f(cur, i, data): cur.execute(f"create table t_{i} as select * from data") return cur.execute(f"select * from t_{i}").arrow() + def test_6584(): pool = ThreadPoolExecutor(max_workers=2) - data = pyarrow.Table.from_pydict({"a": [1,2,3]}) + data = pyarrow.Table.from_pydict({"a": [1, 2, 3]}) c = duckdb.connect() futures = [] for i in range(2): - fut = pool.submit(f, c.cursor(), i,data) + fut = pool.submit(f, c.cursor(), i, data) futures.append(fut) for fut in futures: arrow_res = fut.result() - assert data.equals(arrow_res) \ No newline at end of file + assert data.equals(arrow_res) diff --git a/tools/pythonpkg/tests/fast/arrow/test_6796.py b/tools/pythonpkg/tests/fast/arrow/test_6796.py index 73b0ac61c6cd..6690f22cd75f 100644 --- a/tools/pythonpkg/tests/fast/arrow/test_6796.py +++ b/tools/pythonpkg/tests/fast/arrow/test_6796.py @@ -4,26 +4,27 @@ pyarrow = pytest.importorskip('pyarrow') + @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_6796(pandas): - conn = duckdb.connect() - input_df = pandas.DataFrame({ "foo": ["bar"] }) - conn.register("input_df", input_df) + conn = duckdb.connect() + input_df = pandas.DataFrame({"foo": ["bar"]}) + conn.register("input_df", input_df) - query = """ + query = """ select * from input_df union all select * from input_df """ - # fetching directly into Pandas works - res_df = conn.execute(query).fetch_df() - res_arrow = conn.execute(query).fetch_arrow_table() + # fetching directly into Pandas works + res_df = conn.execute(query).fetch_df() + res_arrow = conn.execute(query).fetch_arrow_table() - df_arrow_table = pyarrow.Table.from_pandas(res_df) + df_arrow_table = pyarrow.Table.from_pandas(res_df) - result_1 = conn.execute("select * from df_arrow_table order by all").fetchall() + result_1 = conn.execute("select * from df_arrow_table order by all").fetchall() - result_2 = conn.execute("select * from res_arrow order by all").fetchall() + result_2 = conn.execute("select * from res_arrow order by all").fetchall() - assert result_1 == result_2 \ No newline at end of file + assert result_1 == result_2 diff --git a/tools/pythonpkg/tests/fast/arrow/test_7652.py 
b/tools/pythonpkg/tests/fast/arrow/test_7652.py index b8a2a228698d..60fce51a7b00 100644 --- a/tools/pythonpkg/tests/fast/arrow/test_7652.py +++ b/tools/pythonpkg/tests/fast/arrow/test_7652.py @@ -6,6 +6,7 @@ pa = pytest.importorskip("pyarrow", minversion="11") pq = pytest.importorskip("pyarrow.parquet", minversion="11") + class Test7652(object): def test_7652(self): temp_file_name = tempfile.NamedTemporaryFile(suffix='.parquet').name @@ -19,10 +20,9 @@ def test_7652(self): fake_table = pa.Table.from_arrays([pa.array(generated_list, pa.int64())], names=['n0']) # Write that column with DELTA_BINARY_PACKED encoding - with pq.ParquetWriter(temp_file_name, - fake_table.schema, - column_encoding={"n0": "DELTA_BINARY_PACKED"}, - use_dictionary=False) as writer: + with pq.ParquetWriter( + temp_file_name, fake_table.schema, column_encoding={"n0": "DELTA_BINARY_PACKED"}, use_dictionary=False + ) as writer: writer.write_table(fake_table) # Check to make sure that PyArrow can read the file and retrieve the expected values. diff --git a/tools/pythonpkg/tests/fast/arrow/test_7699.py b/tools/pythonpkg/tests/fast/arrow/test_7699.py index 8e3996026d51..36a7a47b336f 100644 --- a/tools/pythonpkg/tests/fast/arrow/test_7699.py +++ b/tools/pythonpkg/tests/fast/arrow/test_7699.py @@ -6,13 +6,16 @@ pq = pytest.importorskip("pyarrow.parquet") pl = pytest.importorskip("polars") + class Test7699(object): def test_7699(self): - pl_tbl = pl.DataFrame({ - "col1" : pl.Series([ - string.ascii_uppercase[ix+10] for ix in list(range(2)) + list(range(3)) - ]).cast(pl.Categorical), - }) + pl_tbl = pl.DataFrame( + { + "col1": pl.Series([string.ascii_uppercase[ix + 10] for ix in list(range(2)) + list(range(3))]).cast( + pl.Categorical + ), + } + ) nickname = "df1234" duckdb.register(nickname, pl_tbl) diff --git a/tools/pythonpkg/tests/fast/arrow/test_arrow_batch_index.py b/tools/pythonpkg/tests/fast/arrow/test_arrow_batch_index.py index 1e9a043ab449..e9e92923e3db 100644 --- a/tools/pythonpkg/tests/fast/arrow/test_arrow_batch_index.py +++ b/tools/pythonpkg/tests/fast/arrow/test_arrow_batch_index.py @@ -2,12 +2,15 @@ import pytest import pandas as pd import duckdb + try: import pyarrow as pa + can_run = True except: can_run = False + class TestArrowBatchIndex(object): def test_arrow_batch_index(self, duckdb_cursor): if not can_run: diff --git a/tools/pythonpkg/tests/fast/arrow/test_arrow_case_sensitive.py b/tools/pythonpkg/tests/fast/arrow/test_arrow_case_sensitive.py index 9c1538edddc1..627861d57a8d 100644 --- a/tools/pythonpkg/tests/fast/arrow/test_arrow_case_sensitive.py +++ b/tools/pythonpkg/tests/fast/arrow/test_arrow_case_sensitive.py @@ -1,22 +1,28 @@ import duckdb import pytest + try: import pyarrow as pa + can_run = True except: can_run = False + class TestArrowCaseSensitive(object): def test_arrow_case_sensitive(self, duckdb_cursor): if not can_run: return - data = (pa.array([1], type=pa.int32()),pa.array([1000], type=pa.int32())) - arrow_table = pa.Table.from_arrays([data[0],data[1]],['A1','a1']) + data = (pa.array([1], type=pa.int32()), pa.array([1000], type=pa.int32())) + arrow_table = pa.Table.from_arrays([data[0], data[1]], ['A1', 'a1']) con = duckdb.connect() con.register('arrow_tbl', arrow_table) - print (con.execute("DESCRIBE arrow_tbl;").fetchall()) - assert con.execute("DESCRIBE arrow_tbl;").fetchall() == [('A1', 'INTEGER', 'YES', None, None, None), ('a1_1', 'INTEGER', 'YES', None, None, None)] + print(con.execute("DESCRIBE arrow_tbl;").fetchall()) + assert con.execute("DESCRIBE arrow_tbl;").fetchall() == [ + 
('A1', 'INTEGER', 'YES', None, None, None), + ('a1_1', 'INTEGER', 'YES', None, None, None), + ] assert con.execute("select A1 from arrow_tbl;").fetchall() == [(1,)] assert con.execute("select a1_1 from arrow_tbl;").fetchall() == [(1000,)] assert arrow_table.column_names == ['A1', 'a1'] @@ -24,11 +30,15 @@ def test_arrow_case_sensitive(self, duckdb_cursor): def test_arrow_case_sensitive_repeated(self, duckdb_cursor): if not can_run: return - data = (pa.array([1], type=pa.int32()),pa.array([1000], type=pa.int32())) - arrow_table = pa.Table.from_arrays([data[0],data[1],data[1]],['A1','a1_1','a1']) + data = (pa.array([1], type=pa.int32()), pa.array([1000], type=pa.int32())) + arrow_table = pa.Table.from_arrays([data[0], data[1], data[1]], ['A1', 'a1_1', 'a1']) con = duckdb.connect() con.register('arrow_tbl', arrow_table) - print (con.execute("DESCRIBE arrow_tbl;").fetchall()) - assert con.execute("DESCRIBE arrow_tbl;").fetchall() == [('A1', 'INTEGER', 'YES', None, None, None), ('a1_1', 'INTEGER', 'YES', None, None, None), ('a1_2', 'INTEGER', 'YES', None, None, None)] - assert arrow_table.column_names == ['A1','a1_1','a1'] \ No newline at end of file + print(con.execute("DESCRIBE arrow_tbl;").fetchall()) + assert con.execute("DESCRIBE arrow_tbl;").fetchall() == [ + ('A1', 'INTEGER', 'YES', None, None, None), + ('a1_1', 'INTEGER', 'YES', None, None, None), + ('a1_2', 'INTEGER', 'YES', None, None, None), + ] + assert arrow_table.column_names == ['A1', 'a1_1', 'a1'] diff --git a/tools/pythonpkg/tests/fast/arrow/test_arrow_fetch.py b/tools/pythonpkg/tests/fast/arrow/test_arrow_fetch.py index 6df4f3d05ed9..093989f1e06d 100644 --- a/tools/pythonpkg/tests/fast/arrow/test_arrow_fetch.py +++ b/tools/pythonpkg/tests/fast/arrow/test_arrow_fetch.py @@ -1,18 +1,22 @@ import duckdb import pytest + try: import pyarrow as pa + can_run = True except: can_run = False + def check_equal(duckdb_conn): true_result = duckdb_conn.execute("SELECT * from test").fetchall() duck_tbl = duckdb_conn.table("test") duck_from_arrow = duckdb_conn.from_arrow(duck_tbl.arrow()) duck_from_arrow.create("testarrow") arrow_result = duckdb_conn.execute("SELECT * from testarrow").fetchall() - assert(arrow_result == true_result) + assert arrow_result == true_result + class TestArrowFetch(object): def test_over_vector_size(self, duckdb_cursor): @@ -21,8 +25,8 @@ def test_over_vector_size(self, duckdb_cursor): duckdb_conn = duckdb.connect() duckdb_conn.execute("CREATE TABLE test (a INTEGER)") - for value in range (10000): - duckdb_conn.execute("INSERT INTO test VALUES ("+str(value) + ");") + for value in range(10000): + duckdb_conn.execute("INSERT INTO test VALUES (" + str(value) + ");") duckdb_conn.execute("INSERT INTO test VALUES(NULL);") check_equal(duckdb_conn) @@ -33,7 +37,7 @@ def test_empty_table(self, duckdb_cursor): duckdb_conn = duckdb.connect() duckdb_conn.execute("CREATE TABLE test (a INTEGER)") - + check_equal(duckdb_conn) def test_over_vector_size(self, duckdb_cursor): @@ -43,10 +47,10 @@ def test_over_vector_size(self, duckdb_cursor): duckdb_conn = duckdb.connect() duckdb_conn.execute("CREATE TABLE test (a INTEGER)") - for value in range (10000): - duckdb_conn.execute("INSERT INTO test VALUES ("+str(value) + ");") + for value in range(10000): + duckdb_conn.execute("INSERT INTO test VALUES (" + str(value) + ");") duckdb_conn.execute("INSERT INTO test VALUES(NULL);") - + check_equal(duckdb_conn) def test_table_nulls(self, duckdb_cursor): @@ -57,7 +61,7 @@ def test_table_nulls(self, duckdb_cursor): duckdb_conn.execute("CREATE 
TABLE test (a INTEGER)") duckdb_conn.execute("INSERT INTO test VALUES(NULL);") - + check_equal(duckdb_conn) def test_table_without_nulls(self, duckdb_cursor): @@ -68,7 +72,7 @@ def test_table_without_nulls(self, duckdb_cursor): duckdb_conn.execute("CREATE TABLE test (a INTEGER)") duckdb_conn.execute("INSERT INTO test VALUES(1);") - + check_equal(duckdb_conn) def test_table_with_prepared_statements(self, duckdb_cursor): @@ -80,8 +84,8 @@ def test_table_with_prepared_statements(self, duckdb_cursor): duckdb_conn.execute("CREATE TABLE test (a INTEGER)") duckdb_conn.execute("PREPARE s1 AS INSERT INTO test VALUES ($1), ($2 / 2)") - for value in range (10000): - duckdb_conn.execute("EXECUTE s1("+str(value) + "," + str(value*2)+ ");") + for value in range(10000): + duckdb_conn.execute("EXECUTE s1(" + str(value) + "," + str(value * 2) + ");") check_equal(duckdb_conn) diff --git a/tools/pythonpkg/tests/fast/arrow/test_arrow_fetch_recordbatch.py b/tools/pythonpkg/tests/fast/arrow/test_arrow_fetch_recordbatch.py index d4ea64c2dece..a5357820a845 100644 --- a/tools/pythonpkg/tests/fast/arrow/test_arrow_fetch_recordbatch.py +++ b/tools/pythonpkg/tests/fast/arrow/test_arrow_fetch_recordbatch.py @@ -1,5 +1,6 @@ import duckdb import pytest + pa = pytest.importorskip('pyarrow') @@ -13,17 +14,17 @@ def test_record_batch_next_batch_numeric(self, duckdb_cursor): record_batch_reader = query.fetch_record_batch(1024) assert record_batch_reader.schema.names == ['a'] chunk = record_batch_reader.read_next_batch() - assert(len(chunk) == 1024) + assert len(chunk) == 1024 chunk = record_batch_reader.read_next_batch() - assert(len(chunk) == 1024) + assert len(chunk) == 1024 chunk = record_batch_reader.read_next_batch() - assert(len(chunk) == 952) + assert len(chunk) == 952 with pytest.raises(StopIteration): chunk = record_batch_reader.read_next_batch() # Check if we are producing the correct thing query = duckdb_cursor.execute("SELECT a FROM t") record_batch_reader = query.fetch_record_batch(1024) - + res = duckdb_cursor_check.execute("select * from record_batch_reader").fetchall() correct = duckdb_cursor.execute("select * from t").fetchall() assert res == correct @@ -32,23 +33,25 @@ def test_record_batch_next_batch_numeric(self, duckdb_cursor): def test_record_batch_next_batch_bool(self, duckdb_cursor): duckdb_cursor = duckdb.connect() duckdb_cursor_check = duckdb.connect() - duckdb_cursor.execute("CREATE table t as SELECT CASE WHEN i % 2 = 0 THEN true ELSE false END AS a from range(3000) as tbl(i);") + duckdb_cursor.execute( + "CREATE table t as SELECT CASE WHEN i % 2 = 0 THEN true ELSE false END AS a from range(3000) as tbl(i);" + ) query = duckdb_cursor.execute("SELECT a FROM t") record_batch_reader = query.fetch_record_batch(1024) assert record_batch_reader.schema.names == ['a'] chunk = record_batch_reader.read_next_batch() - assert(len(chunk) == 1024) + assert len(chunk) == 1024 chunk = record_batch_reader.read_next_batch() - assert(len(chunk) == 1024) + assert len(chunk) == 1024 chunk = record_batch_reader.read_next_batch() - assert(len(chunk) == 952) + assert len(chunk) == 952 with pytest.raises(StopIteration): chunk = record_batch_reader.read_next_batch() # Check if we are producing the correct thing query = duckdb_cursor.execute("SELECT a FROM t") record_batch_reader = query.fetch_record_batch(1024) - + res = duckdb_cursor_check.execute("select * from record_batch_reader").fetchall() correct = duckdb_cursor.execute("select * from t").fetchall() assert res == correct @@ -62,11 +65,11 @@ def 
test_record_batch_next_batch_varchar(self, duckdb_cursor): record_batch_reader = query.fetch_record_batch(1024) assert record_batch_reader.schema.names == ['a'] chunk = record_batch_reader.read_next_batch() - assert(len(chunk) == 1024) + assert len(chunk) == 1024 chunk = record_batch_reader.read_next_batch() - assert(len(chunk) == 1024) + assert len(chunk) == 1024 chunk = record_batch_reader.read_next_batch() - assert(len(chunk) == 952) + assert len(chunk) == 952 with pytest.raises(StopIteration): chunk = record_batch_reader.read_next_batch() @@ -82,16 +85,18 @@ def test_record_batch_next_batch_varchar(self, duckdb_cursor): def test_record_batch_next_batch_struct(self, duckdb_cursor): duckdb_cursor = duckdb.connect() duckdb_cursor_check = duckdb.connect() - duckdb_cursor.execute("CREATE table t as select {'x': i, 'y': i::varchar, 'z': i+1} as a from range(3000) as tbl(i);") + duckdb_cursor.execute( + "CREATE table t as select {'x': i, 'y': i::varchar, 'z': i+1} as a from range(3000) as tbl(i);" + ) query = duckdb_cursor.execute("SELECT a FROM t") record_batch_reader = query.fetch_record_batch(1024) assert record_batch_reader.schema.names == ['a'] chunk = record_batch_reader.read_next_batch() - assert(len(chunk) == 1024) + assert len(chunk) == 1024 chunk = record_batch_reader.read_next_batch() - assert(len(chunk) == 1024) + assert len(chunk) == 1024 chunk = record_batch_reader.read_next_batch() - assert(len(chunk) == 952) + assert len(chunk) == 952 with pytest.raises(StopIteration): chunk = record_batch_reader.read_next_batch() @@ -112,11 +117,11 @@ def test_record_batch_next_batch_list(self, duckdb_cursor): record_batch_reader = query.fetch_record_batch(1024) assert record_batch_reader.schema.names == ['a'] chunk = record_batch_reader.read_next_batch() - assert(len(chunk) == 1024) + assert len(chunk) == 1024 chunk = record_batch_reader.read_next_batch() - assert(len(chunk) == 1024) + assert len(chunk) == 1024 chunk = record_batch_reader.read_next_batch() - assert(len(chunk) == 952) + assert len(chunk) == 952 with pytest.raises(StopIteration): chunk = record_batch_reader.read_next_batch() @@ -138,11 +143,11 @@ def test_record_batch_next_batch_list(self, duckdb_cursor): record_batch_reader = query.fetch_record_batch(1024) assert record_batch_reader.schema.names == ['a'] chunk = record_batch_reader.read_next_batch() - assert(len(chunk) == 1024) + assert len(chunk) == 1024 chunk = record_batch_reader.read_next_batch() - assert(len(chunk) == 1024) + assert len(chunk) == 1024 chunk = record_batch_reader.read_next_batch() - assert(len(chunk) == 952) + assert len(chunk) == 952 with pytest.raises(StopIteration): chunk = record_batch_reader.read_next_batch() @@ -159,16 +164,18 @@ def test_record_batch_next_batch_list(self, duckdb_cursor): def test_record_batch_next_batch_with_null(self, duckdb_cursor): duckdb_cursor = duckdb.connect() duckdb_cursor_check = duckdb.connect() - duckdb_cursor.execute("CREATE table t as SELECT CASE WHEN i % 2 = 0 THEN i ELSE NULL END AS a from range(3000) as tbl(i);") + duckdb_cursor.execute( + "CREATE table t as SELECT CASE WHEN i % 2 = 0 THEN i ELSE NULL END AS a from range(3000) as tbl(i);" + ) query = duckdb_cursor.execute("SELECT a FROM t") record_batch_reader = query.fetch_record_batch(1024) assert record_batch_reader.schema.names == ['a'] chunk = record_batch_reader.read_next_batch() - assert(len(chunk) == 1024) + assert len(chunk) == 1024 chunk = record_batch_reader.read_next_batch() - assert(len(chunk) == 1024) + assert len(chunk) == 1024 chunk = 
record_batch_reader.read_next_batch() - assert(len(chunk) == 952) + assert len(chunk) == 952 with pytest.raises(StopIteration): chunk = record_batch_reader.read_next_batch() @@ -187,7 +194,7 @@ def test_record_batch_read_default(self, duckdb_cursor): query = duckdb_cursor.execute("SELECT a FROM t") record_batch_reader = query.fetch_record_batch() chunk = record_batch_reader.read_next_batch() - assert(len(chunk) == 3000) + assert len(chunk) == 3000 def test_record_batch_next_batch_multiple_vectors_per_chunk(self, duckdb_cursor): duckdb_cursor = duckdb.connect() @@ -195,24 +202,23 @@ def test_record_batch_next_batch_multiple_vectors_per_chunk(self, duckdb_cursor) query = duckdb_cursor.execute("SELECT a FROM t") record_batch_reader = query.fetch_record_batch(2048) chunk = record_batch_reader.read_next_batch() - assert(len(chunk) == 2048) + assert len(chunk) == 2048 chunk = record_batch_reader.read_next_batch() - assert(len(chunk) == 2048) + assert len(chunk) == 2048 chunk = record_batch_reader.read_next_batch() - assert(len(chunk) == 904) + assert len(chunk) == 904 with pytest.raises(StopIteration): chunk = record_batch_reader.read_next_batch() query = duckdb_cursor.execute("SELECT a FROM t") record_batch_reader = query.fetch_record_batch(1) chunk = record_batch_reader.read_next_batch() - assert(len(chunk) == 1) + assert len(chunk) == 1 query = duckdb_cursor.execute("SELECT a FROM t") record_batch_reader = query.fetch_record_batch(2000) chunk = record_batch_reader.read_next_batch() - assert(len(chunk) == 2000) - + assert len(chunk) == 2000 def test_record_batch_next_batch_multiple_vectors_per_chunk_error(self, duckdb_cursor): duckdb_cursor = duckdb.connect() @@ -229,7 +235,7 @@ def test_record_batch_reader_from_relation(self, duckdb_cursor): relation = duckdb_cursor.table('t') record_batch_reader = relation.record_batch() chunk = record_batch_reader.read_next_batch() - assert(len(chunk) == 3000) + assert len(chunk) == 3000 def test_record_coverage(self, duckdb_cursor): duckdb_cursor = duckdb.connect() @@ -238,7 +244,7 @@ def test_record_coverage(self, duckdb_cursor): record_batch_reader = query.fetch_record_batch(1024) chunk = record_batch_reader.read_all() - assert(len(chunk) == 2048) + assert len(chunk) == 2048 def test_record_batch_query_error(self): duckdb_cursor = duckdb.connect() @@ -251,9 +257,11 @@ def test_record_batch_query_error(self): def test_many_list_batches(self): conn = duckdb.connect() - conn.execute(""" + conn.execute( + """ create or replace table tbl as select * from (select {'a': [5,4,3,2,1]}), range(10000000) - """) + """ + ) query = "SELECT * FROM tbl" chunk_size = 1_000_000 @@ -264,22 +272,21 @@ def test_many_list_batches(self): for batch in batch_iter: del batch - def test_many_chunk_sizes(self): object_size = 1000000 duckdb_cursor = duckdb.connect() query = duckdb_cursor.execute(f"CREATE table t as select range a from range({object_size});") for i in [1, 2, 4, 8, 16, 32, 33, 77, 999, 999999]: query = duckdb_cursor.execute("SELECT a FROM t") - record_batch_reader = query.fetch_record_batch(i) - num_loops = int(object_size/i) + record_batch_reader = query.fetch_record_batch(i) + num_loops = int(object_size / i) for j in range(num_loops): assert record_batch_reader.schema.names == ['a'] chunk = record_batch_reader.read_next_batch() - assert(len(chunk) == i) - remainder = object_size%i + assert len(chunk) == i + remainder = object_size % i if remainder > 0: chunk = record_batch_reader.read_next_batch() - assert(len(chunk) == remainder) + assert len(chunk) == remainder 
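             # Worked example of the arithmetic above: for i = 999, num_loops is
             # int(1_000_000 / 999) = 1001 full batches of 999 rows (999,999 rows in
             # total), remainder = 1_000_000 % 999 = 1, so a final partial batch of
             # one row is read before read_next_batch() raises StopIteration.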
             with pytest.raises(StopIteration):
-                chunk = record_batch_reader.read_next_batch()
+                chunk = record_batch_reader.read_next_batch()
diff --git a/tools/pythonpkg/tests/fast/arrow/test_arrow_list.py b/tools/pythonpkg/tests/fast/arrow/test_arrow_list.py
index 6f12f6c2867b..a4860a15c399 100644
--- a/tools/pythonpkg/tests/fast/arrow/test_arrow_list.py
+++ b/tools/pythonpkg/tests/fast/arrow/test_arrow_list.py
@@ -1,31 +1,33 @@
 import duckdb
 import numpy as np
+
 try:
     import pyarrow as pa
+
     can_run = True
 except:
     can_run = False
+
 def check_equal(duckdb_conn):
     true_result = duckdb_conn.execute("SELECT * from test").fetchall()
     arrow_result = duckdb_conn.execute("SELECT * from testarrow").fetchall()
-    assert(arrow_result == true_result)
+    assert arrow_result == true_result
 
-def create_and_register_arrow_table(column_list, duckdb_conn):
+def create_and_register_arrow_table(column_list, duckdb_conn):
     pydict = {name: data for (name, _, data) in column_list}
-    arrow_schema = pa.schema([
-        (name, dtype) for (name, dtype, _) in column_list
-    ])
+    arrow_schema = pa.schema([(name, dtype) for (name, dtype, _) in column_list])
     res = pa.Table.from_pydict(pydict, schema=arrow_schema)
     duck_from_arrow = duckdb_conn.from_arrow(res)
     duck_from_arrow.create("testarrow")
+
 def create_and_register_comparison_result(column_list, duckdb_conn):
     columns = ",".join([f'{name} {dtype}' for (name, dtype, _) in column_list])
     column_amount = len(column_list)
-    assert(column_amount)
+    assert column_amount
     row_amount = len(column_list[0][2])
     inserted_values = []
     for row in range(row_amount):
@@ -41,25 +43,32 @@
 
     duckdb_conn.execute(query, inserted_values)
 
+
 class TestArrowListType(object):
     def test_regular_list(self):
         if not can_run:
             return
         duckdb_conn = duckdb.connect()
 
-        n = 5 #Amount of lists
-        generated_size = 3 #Size of each list
-        list_size = -1 #Argument passed to `pa._list()`
+        n = 5  # Amount of lists
+        generated_size = 3  # Size of each list
+        list_size = -1  # Argument passed to `pa.list_()`
 
         data = [np.random.random((generated_size)) for _ in range(n)]
         list_type = pa.list_(pa.float32(), list_size=list_size)
-        create_and_register_arrow_table([
-            ('a', list_type, data),
-        ], duckdb_conn)
-        create_and_register_comparison_result([
-            ('a', 'FLOAT[]', data),
-        ], duckdb_conn)
+        create_and_register_arrow_table(
+            [
+                ('a', list_type, data),
+            ],
+            duckdb_conn,
+        )
+        create_and_register_comparison_result(
+            [
+                ('a', 'FLOAT[]', data),
+            ],
+            duckdb_conn,
+        )
 
         check_equal(duckdb_conn)
 
@@ -68,18 +77,24 @@ def test_fixedsize_list(self):
             return
         duckdb_conn = duckdb.connect()
 
-        n = 5 #Amount of lists
-        generated_size = 3 #Size of each list
-        list_size = 3 #Argument passed to `pa._list()`
+        n = 5  # Amount of lists
+        generated_size = 3  # Size of each list
+        list_size = 3  # Argument passed to `pa.list_()`
 
         data = [np.random.random((generated_size)) for _ in range(n)]
         list_type = pa.list_(pa.float32(), list_size=list_size)
-        create_and_register_arrow_table([
-            ('a', list_type, data),
-        ], duckdb_conn)
-        create_and_register_comparison_result([
-            ('a', 'FLOAT[]', data),
-        ], duckdb_conn)
+        create_and_register_arrow_table(
+            [
+                ('a', list_type, data),
+            ],
+            duckdb_conn,
+        )
+        create_and_register_comparison_result(
+            [
+                ('a', 'FLOAT[]', data),
+            ],
+            duckdb_conn,
+        )
 
         check_equal(duckdb_conn)
diff --git a/tools/pythonpkg/tests/fast/arrow/test_arrow_recordbatchreader.py b/tools/pythonpkg/tests/fast/arrow/test_arrow_recordbatchreader.py
index 052fcacb813b..05325375a3b5 100644
--- 
a/tools/pythonpkg/tests/fast/arrow/test_arrow_recordbatchreader.py +++ b/tools/pythonpkg/tests/fast/arrow/test_arrow_recordbatchreader.py @@ -1,106 +1,143 @@ import duckdb import os + try: import pyarrow import pyarrow.parquet import pyarrow.dataset import numpy as np + can_run = True except: can_run = False -class TestArrowRecordBatchReader(object): - def test_parallel_reader(self,duckdb_cursor): +class TestArrowRecordBatchReader(object): + def test_parallel_reader(self, duckdb_cursor): if not can_run: return duckdb_conn = duckdb.connect() duckdb_conn.execute("PRAGMA threads=4") - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)),'data','userdata1.parquet') + parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'userdata1.parquet') - userdata_parquet_dataset= pyarrow.dataset.dataset([ - parquet_filename, - parquet_filename, - parquet_filename, - ] - , format="parquet") + userdata_parquet_dataset = pyarrow.dataset.dataset( + [ + parquet_filename, + parquet_filename, + parquet_filename, + ], + format="parquet", + ) - batches= [r for r in userdata_parquet_dataset.to_batches()] - reader=pyarrow.dataset.Scanner.from_batches(batches,schema=userdata_parquet_dataset.schema).to_reader() + batches = [r for r in userdata_parquet_dataset.to_batches()] + reader = pyarrow.dataset.Scanner.from_batches(batches, schema=userdata_parquet_dataset.schema).to_reader() rel = duckdb_conn.from_arrow(reader) - assert rel.filter("first_name=\'Jose\' and salary > 134708.82").aggregate('count(*)').execute().fetchone()[0] == 12 + assert ( + rel.filter("first_name=\'Jose\' and salary > 134708.82").aggregate('count(*)').execute().fetchone()[0] == 12 + ) # The reader is already consumed so this should be 0 - assert rel.filter("first_name=\'Jose\' and salary > 134708.82").aggregate('count(*)').execute().fetchone()[0] == 0 + assert ( + rel.filter("first_name=\'Jose\' and salary > 134708.82").aggregate('count(*)').execute().fetchone()[0] == 0 + ) - def test_parallel_reader_replacement_scans(self,duckdb_cursor): + def test_parallel_reader_replacement_scans(self, duckdb_cursor): if not can_run: return duckdb_conn = duckdb.connect() duckdb_conn.execute("PRAGMA threads=4") - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)),'data','userdata1.parquet') - - userdata_parquet_dataset= pyarrow.dataset.dataset([ - parquet_filename, - parquet_filename, - parquet_filename, - ] - , format="parquet") - - batches= [r for r in userdata_parquet_dataset.to_batches()] - reader=pyarrow.dataset.Scanner.from_batches(batches,schema=userdata_parquet_dataset.schema).to_reader() - - assert duckdb_conn.execute("select count(*) from reader where first_name=\'Jose\' and salary > 134708.82").fetchone()[0] == 12 - assert duckdb_conn.execute("select count(*) from reader where first_name=\'Jose\' and salary > 134708.82").fetchone()[0] == 0 - - def test_parallel_reader_register(self,duckdb_cursor): + parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'userdata1.parquet') + + userdata_parquet_dataset = pyarrow.dataset.dataset( + [ + parquet_filename, + parquet_filename, + parquet_filename, + ], + format="parquet", + ) + + batches = [r for r in userdata_parquet_dataset.to_batches()] + reader = pyarrow.dataset.Scanner.from_batches(batches, schema=userdata_parquet_dataset.schema).to_reader() + + assert ( + duckdb_conn.execute( + "select count(*) from reader where first_name=\'Jose\' and salary > 134708.82" + ).fetchone()[0] + == 12 + ) + assert 
( + duckdb_conn.execute( + "select count(*) from reader where first_name=\'Jose\' and salary > 134708.82" + ).fetchone()[0] + == 0 + ) + + def test_parallel_reader_register(self, duckdb_cursor): if not can_run: return duckdb_conn = duckdb.connect() duckdb_conn.execute("PRAGMA threads=4") - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)),'data','userdata1.parquet') + parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'userdata1.parquet') - userdata_parquet_dataset= pyarrow.dataset.dataset([ - parquet_filename, - parquet_filename, - parquet_filename, - ] - , format="parquet") + userdata_parquet_dataset = pyarrow.dataset.dataset( + [ + parquet_filename, + parquet_filename, + parquet_filename, + ], + format="parquet", + ) - batches= [r for r in userdata_parquet_dataset.to_batches()] - reader=pyarrow.dataset.Scanner.from_batches(batches,schema=userdata_parquet_dataset.schema).to_reader() + batches = [r for r in userdata_parquet_dataset.to_batches()] + reader = pyarrow.dataset.Scanner.from_batches(batches, schema=userdata_parquet_dataset.schema).to_reader() duckdb_conn.register("bla", reader) - assert duckdb_conn.execute("select count(*) from bla where first_name=\'Jose\' and salary > 134708.82").fetchone()[0] == 12 - assert duckdb_conn.execute("select count(*) from bla where first_name=\'Jose\' and salary > 134708.82").fetchone()[0] == 0 - - def test_parallel_reader_default_conn(self,duckdb_cursor): + assert ( + duckdb_conn.execute("select count(*) from bla where first_name=\'Jose\' and salary > 134708.82").fetchone()[ + 0 + ] + == 12 + ) + assert ( + duckdb_conn.execute("select count(*) from bla where first_name=\'Jose\' and salary > 134708.82").fetchone()[ + 0 + ] + == 0 + ) + + def test_parallel_reader_default_conn(self, duckdb_cursor): if not can_run: return - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)),'data','userdata1.parquet') + parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'userdata1.parquet') - userdata_parquet_dataset= pyarrow.dataset.dataset([ - parquet_filename, - parquet_filename, - parquet_filename, - ] - , format="parquet") + userdata_parquet_dataset = pyarrow.dataset.dataset( + [ + parquet_filename, + parquet_filename, + parquet_filename, + ], + format="parquet", + ) - batches= [r for r in userdata_parquet_dataset.to_batches()] - reader=pyarrow.dataset.Scanner.from_batches(batches,schema=userdata_parquet_dataset.schema).to_reader() + batches = [r for r in userdata_parquet_dataset.to_batches()] + reader = pyarrow.dataset.Scanner.from_batches(batches, schema=userdata_parquet_dataset.schema).to_reader() rel = duckdb.from_arrow(reader) - assert rel.filter("first_name=\'Jose\' and salary > 134708.82").aggregate('count(*)').execute().fetchone()[0] == 12 + assert ( + rel.filter("first_name=\'Jose\' and salary > 134708.82").aggregate('count(*)').execute().fetchone()[0] == 12 + ) # The reader is already consumed so this should be 0 - assert rel.filter("first_name=\'Jose\' and salary > 134708.82").aggregate('count(*)').execute().fetchone()[0] == 0 - + assert ( + rel.filter("first_name=\'Jose\' and salary > 134708.82").aggregate('count(*)').execute().fetchone()[0] == 0 + ) diff --git a/tools/pythonpkg/tests/fast/arrow/test_arrow_replacement_scan.py b/tools/pythonpkg/tests/fast/arrow/test_arrow_replacement_scan.py index d635042f8a2f..b3a4abd4388d 100644 --- a/tools/pythonpkg/tests/fast/arrow/test_arrow_replacement_scan.py +++ 
b/tools/pythonpkg/tests/fast/arrow/test_arrow_replacement_scan.py @@ -2,37 +2,40 @@ import pytest import os import pandas as pd + try: import pyarrow.parquet as pq import pyarrow.dataset as ds + can_run = True except: can_run = False + class TestArrowReplacementScan(object): def test_arrow_table_replacement_scan(self, duckdb_cursor): if not can_run: return - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)),'data','userdata1.parquet') + parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'userdata1.parquet') userdata_parquet_table = pq.read_table(parquet_filename) df = userdata_parquet_table.to_pandas() con = duckdb.connect() - - for i in range (5): - assert con.execute("select count(*) from userdata_parquet_table").fetchone() == (1000,) - assert con.execute("select count(*) from df").fetchone() == (1000,) + + for i in range(5): + assert con.execute("select count(*) from userdata_parquet_table").fetchone() == (1000,) + assert con.execute("select count(*) from df").fetchone() == (1000,) def test_arrow_table_replacement_scan_view(self, duckdb_cursor): if not can_run: return - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)),'data','userdata1.parquet') + parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'userdata1.parquet') userdata_parquet_table = pq.read_table(parquet_filename) con = duckdb.connect() - + con.execute("create view x as select * from userdata_parquet_table") del userdata_parquet_table with pytest.raises(duckdb.CatalogException, match='Table with name userdata_parquet_table does not exist'): @@ -41,9 +44,9 @@ def test_arrow_table_replacement_scan_view(self, duckdb_cursor): def test_arrow_dataset_replacement_scan(self, duckdb_cursor): if not can_run: return - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)),'data','userdata1.parquet') + parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'userdata1.parquet') userdata_parquet_table = pq.read_table(parquet_filename) - userdata_parquet_dataset= ds.dataset(parquet_filename) + userdata_parquet_dataset = ds.dataset(parquet_filename) con = duckdb.connect() - assert con.execute("select count(*) from userdata_parquet_dataset").fetchone() == (1000,) + assert con.execute("select count(*) from userdata_parquet_dataset").fetchone() == (1000,) diff --git a/tools/pythonpkg/tests/fast/arrow/test_arrow_scanner.py b/tools/pythonpkg/tests/fast/arrow/test_arrow_scanner.py index 9522127bcc9d..6d74ddb50751 100644 --- a/tools/pythonpkg/tests/fast/arrow/test_arrow_scanner.py +++ b/tools/pythonpkg/tests/fast/arrow/test_arrow_scanner.py @@ -1,5 +1,6 @@ import duckdb import os + try: import pyarrow import pyarrow.parquet @@ -7,27 +8,30 @@ from pyarrow.dataset import Scanner import pyarrow.compute as pc import numpy as np + can_run = True except: can_run = False -class TestArrowScanner(object): - def test_parallel_scanner(self,duckdb_cursor): +class TestArrowScanner(object): + def test_parallel_scanner(self, duckdb_cursor): if not can_run: return duckdb_conn = duckdb.connect() duckdb_conn.execute("PRAGMA threads=4") - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)),'data','userdata1.parquet') + parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'userdata1.parquet') - arrow_dataset= pyarrow.dataset.dataset([ - parquet_filename, - parquet_filename, - parquet_filename, - ] - , format="parquet") + arrow_dataset = 
pyarrow.dataset.dataset( + [ + parquet_filename, + parquet_filename, + parquet_filename, + ], + format="parquet", + ) scanner_filter = (pc.field("first_name") == pc.scalar('Jose')) & (pc.field("salary") > pc.scalar(134708.82)) @@ -37,21 +41,23 @@ def test_parallel_scanner(self,duckdb_cursor): assert rel.aggregate('count(*)').execute().fetchone()[0] == 12 - def test_parallel_scanner_replacement_scans(self,duckdb_cursor): + def test_parallel_scanner_replacement_scans(self, duckdb_cursor): if not can_run: return duckdb_conn = duckdb.connect() duckdb_conn.execute("PRAGMA threads=4") - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)),'data','userdata1.parquet') + parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'userdata1.parquet') - arrow_dataset= pyarrow.dataset.dataset([ - parquet_filename, - parquet_filename, - parquet_filename, - ] - , format="parquet") + arrow_dataset = pyarrow.dataset.dataset( + [ + parquet_filename, + parquet_filename, + parquet_filename, + ], + format="parquet", + ) scanner_filter = (pc.field("first_name") == pc.scalar('Jose')) & (pc.field("salary") > pc.scalar(134708.82)) @@ -59,22 +65,23 @@ def test_parallel_scanner_replacement_scans(self,duckdb_cursor): assert duckdb_conn.execute("select count(*) from arrow_scanner").fetchone()[0] == 12 - - def test_parallel_scanner_register(self,duckdb_cursor): + def test_parallel_scanner_register(self, duckdb_cursor): if not can_run: return duckdb_conn = duckdb.connect() duckdb_conn.execute("PRAGMA threads=4") - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)),'data','userdata1.parquet') + parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'userdata1.parquet') - arrow_dataset= pyarrow.dataset.dataset([ - parquet_filename, - parquet_filename, - parquet_filename, - ] - , format="parquet") + arrow_dataset = pyarrow.dataset.dataset( + [ + parquet_filename, + parquet_filename, + parquet_filename, + ], + format="parquet", + ) scanner_filter = (pc.field("first_name") == pc.scalar('Jose')) & (pc.field("salary") > pc.scalar(134708.82)) @@ -84,18 +91,20 @@ def test_parallel_scanner_register(self,duckdb_cursor): assert duckdb_conn.execute("select count(*) from bla").fetchone()[0] == 12 - def test_parallel_scanner_default_conn(self,duckdb_cursor): + def test_parallel_scanner_default_conn(self, duckdb_cursor): if not can_run: return - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)),'data','userdata1.parquet') + parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'userdata1.parquet') - arrow_dataset= pyarrow.dataset.dataset([ - parquet_filename, - parquet_filename, - parquet_filename, - ] - , format="parquet") + arrow_dataset = pyarrow.dataset.dataset( + [ + parquet_filename, + parquet_filename, + parquet_filename, + ], + format="parquet", + ) scanner_filter = (pc.field("first_name") == pc.scalar('Jose')) & (pc.field("salary") > pc.scalar(134708.82)) @@ -104,5 +113,3 @@ def test_parallel_scanner_default_conn(self,duckdb_cursor): rel = duckdb.from_arrow(arrow_scanner) assert rel.aggregate('count(*)').execute().fetchone()[0] == 12 - - \ No newline at end of file diff --git a/tools/pythonpkg/tests/fast/arrow/test_arrow_types.py b/tools/pythonpkg/tests/fast/arrow/test_arrow_types.py index bb8bade80822..379f517167ae 100644 --- a/tools/pythonpkg/tests/fast/arrow/test_arrow_types.py +++ b/tools/pythonpkg/tests/fast/arrow/test_arrow_types.py @@ -1,26 +1,27 @@ import 
duckdb + try: import pyarrow as pa import pyarrow.dataset as ds + can_run = True except: can_run = False + class TestArrowTypes(object): - def test_null_type(self, duckdb_cursor): if not can_run: return schema = pa.schema([("data", pa.null())]) - inputs = [pa.array([None,None,None], type=pa.null())] + inputs = [pa.array([None, None, None], type=pa.null())] arrow_table = pa.Table.from_arrays(inputs, schema=schema) duckdb_conn = duckdb.connect() - duckdb_conn.register("testarrow",arrow_table) + duckdb_conn.register("testarrow", arrow_table) rel = duckdb.from_arrow(arrow_table).arrow() # We turn it to an array of int32 nulls schema = pa.schema([("data", pa.int32())]) - inputs = [pa.array([None,None,None], type=pa.null())] + inputs = [pa.array([None, None, None], type=pa.null())] arrow_table = pa.Table.from_arrays(inputs, schema=schema) assert rel['data'] == arrow_table['data'] - diff --git a/tools/pythonpkg/tests/fast/arrow/test_binary_type.py b/tools/pythonpkg/tests/fast/arrow/test_binary_type.py index 3cf2b11746cd..489d4caf826a 100644 --- a/tools/pythonpkg/tests/fast/arrow/test_binary_type.py +++ b/tools/pythonpkg/tests/fast/arrow/test_binary_type.py @@ -1,20 +1,24 @@ import duckdb import os + try: import pyarrow as pa from pyarrow import parquet as pq import numpy as np + can_run = True except: can_run = False + def create_binary_table(type): schema = pa.schema([("data", type)]) inputs = [pa.array([b"foo", b"bar", b"baz"], type=type)] return pa.Table.from_arrays(inputs, schema=schema) + class TestArrowBinary(object): - def test_binary_types(self,duckdb_cursor): + def test_binary_types(self, duckdb_cursor): if not can_run: return @@ -35,5 +39,3 @@ def test_binary_types(self,duckdb_cursor): rel = duckdb.from_arrow(arrow_table) res = rel.execute().fetchall() assert res == [(b"foo",), (b"bar",), (b"baz",)] - - \ No newline at end of file diff --git a/tools/pythonpkg/tests/fast/arrow/test_buffer_size_option.py b/tools/pythonpkg/tests/fast/arrow/test_buffer_size_option.py index a8d2544c7720..f1bcf6dd878f 100644 --- a/tools/pythonpkg/tests/fast/arrow/test_buffer_size_option.py +++ b/tools/pythonpkg/tests/fast/arrow/test_buffer_size_option.py @@ -4,16 +4,17 @@ pa = pytest.importorskip("pyarrow") from duckdb.typing import * + class TestArrowBufferSize(object): def test_arrow_buffer_size(self): con = duckdb.connect() - + # All small string res = con.query("select 'bla'").arrow() assert res[0][0].type == pa.string() res = con.query("select 'bla'").record_batch() assert res.schema[0].type == pa.string() - + # All Large String con.execute("SET arrow_large_buffer_size=True") res = con.query("select 'bla'").arrow() @@ -31,7 +32,7 @@ def test_arrow_buffer_size(self): def test_arrow_buffer_size_udf(self): def just_return(x): return x - + con = duckdb.connect() con.create_function('just_return', just_return, [VARCHAR], VARCHAR, type='arrow') @@ -41,6 +42,6 @@ def just_return(x): # All Large String con.execute("SET arrow_large_buffer_size=True") - + res = con.query("select just_return('bla')").arrow() - assert res[0][0].type == pa.large_string() \ No newline at end of file + assert res[0][0].type == pa.large_string() diff --git a/tools/pythonpkg/tests/fast/arrow/test_dataset.py b/tools/pythonpkg/tests/fast/arrow/test_dataset.py index bef307f3a0f5..2f3d7a53664f 100644 --- a/tools/pythonpkg/tests/fast/arrow/test_dataset.py +++ b/tools/pythonpkg/tests/fast/arrow/test_dataset.py @@ -1,6 +1,7 @@ import duckdb import os import pytest + pyarrow = pytest.importorskip("pyarrow") np = pytest.importorskip("numpy") 
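 # pytest.importorskip imports the named module and, if the import fails, skips
 # this whole test module at collection time, so the returned handle is always
 # safe to use below (unlike the try/except-import plus can_run flag pattern in
 # the older arrow test files above).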
pyarrow.parquet = pytest.importorskip("pyarrow.parquet") @@ -8,60 +9,72 @@ class TestArrowDataset(object): - - def test_parallel_dataset(self,duckdb_cursor): + def test_parallel_dataset(self, duckdb_cursor): duckdb_conn = duckdb.connect() duckdb_conn.execute("PRAGMA threads=4") duckdb_conn.execute("PRAGMA verify_parallelism") - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)),'data','userdata1.parquet') + parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'userdata1.parquet') - userdata_parquet_dataset= pyarrow.dataset.dataset([ - parquet_filename, - parquet_filename, - parquet_filename, - ] - , format="parquet") + userdata_parquet_dataset = pyarrow.dataset.dataset( + [ + parquet_filename, + parquet_filename, + parquet_filename, + ], + format="parquet", + ) rel = duckdb_conn.from_arrow(userdata_parquet_dataset) - assert rel.filter("first_name=\'Jose\' and salary > 134708.82").aggregate('count(*)').execute().fetchone()[0] == 12 + assert ( + rel.filter("first_name=\'Jose\' and salary > 134708.82").aggregate('count(*)').execute().fetchone()[0] == 12 + ) - def test_parallel_dataset_register(self,duckdb_cursor): + def test_parallel_dataset_register(self, duckdb_cursor): duckdb_conn = duckdb.connect() duckdb_conn.execute("PRAGMA threads=4") duckdb_conn.execute("PRAGMA verify_parallelism") - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)),'data','userdata1.parquet') + parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'userdata1.parquet') - userdata_parquet_dataset= pyarrow.dataset.dataset([ - parquet_filename, - parquet_filename, - parquet_filename, - ] - , format="parquet") + userdata_parquet_dataset = pyarrow.dataset.dataset( + [ + parquet_filename, + parquet_filename, + parquet_filename, + ], + format="parquet", + ) - rel = duckdb_conn.register("dataset",userdata_parquet_dataset) + rel = duckdb_conn.register("dataset", userdata_parquet_dataset) - assert duckdb_conn.execute("Select count(*) from dataset where first_name = 'Jose' and salary > 134708.82").fetchone()[0] == 12 + assert ( + duckdb_conn.execute( + "Select count(*) from dataset where first_name = 'Jose' and salary > 134708.82" + ).fetchone()[0] + == 12 + ) - def test_parallel_dataset_roundtrip(self,duckdb_cursor): + def test_parallel_dataset_roundtrip(self, duckdb_cursor): duckdb_conn = duckdb.connect() duckdb_conn.execute("PRAGMA threads=4") duckdb_conn.execute("PRAGMA verify_parallelism") - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)),'data','userdata1.parquet') + parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'userdata1.parquet') - userdata_parquet_dataset= pyarrow.dataset.dataset([ - parquet_filename, - parquet_filename, - parquet_filename, - ] - , format="parquet") + userdata_parquet_dataset = pyarrow.dataset.dataset( + [ + parquet_filename, + parquet_filename, + parquet_filename, + ], + format="parquet", + ) - rel = duckdb_conn.register("dataset",userdata_parquet_dataset) + rel = duckdb_conn.register("dataset", userdata_parquet_dataset) - query = duckdb_conn.execute("SELECT * FROM dataset order by id" ) + query = duckdb_conn.execute("SELECT * FROM dataset order by id") record_batch_reader = query.fetch_record_batch(2048) arrow_table = record_batch_reader.read_all() @@ -75,7 +88,6 @@ def test_parallel_dataset_roundtrip(self,duckdb_cursor): assert result_1 == result_2 - def test_ducktyping(self, duckdb_cursor): duckdb_conn = duckdb.connect() 
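         # CustomDataset (defined below) is a minimal duck-typed stand-in: per the
         # comment on that class, only the dataset/scanner surface that DuckDB
         # probes (a schema and a scanner) has to behave like pyarrow's, so no
         # file-backed dataset is needed here.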
         dataset = CustomDataset()
@@ -87,11 +99,9 @@ def test_ducktyping(self, duckdb_cursor):
 
 class CustomDataset(pyarrow.dataset.Dataset):
     # For testing duck-typing of dataset/scanner https://github.com/duckdb/duckdb/pull/5998
-    SCHEMA = pyarrow.schema([pyarrow.field("a", pyarrow.int64(), True),
-                             pyarrow.field("b", pyarrow.float64(), True)])
-    DATA = pyarrow.Table.from_arrays([pyarrow.array(range(100)),
-                                      pyarrow.array(np.arange(100)*1.0)],
-                                     schema=SCHEMA)
+    SCHEMA = pyarrow.schema([pyarrow.field("a", pyarrow.int64(), True), pyarrow.field("b", pyarrow.float64(), True)])
+    DATA = pyarrow.Table.from_arrays([pyarrow.array(range(100)), pyarrow.array(np.arange(100) * 1.0)], schema=SCHEMA)
+
     def __init__(self):
         pass
 
@@ -104,7 +114,6 @@ def schema(self):
 
 
 class CustomScanner(pyarrow.dataset.Scanner):
-
     def __init__(self, filter=None, columns=None, **kwargs):
         self.filter = filter
         self.columns = columns
@@ -115,10 +124,7 @@ def projected_schema(self):
         if self.columns is None:
             return CustomDataset.SCHEMA
         else:
-            return pyarrow.schema([f for f in CustomDataset.SCHEMA.fields
-                                   if f.name in self.columns])
+            return pyarrow.schema([f for f in CustomDataset.SCHEMA.fields if f.name in self.columns])
 
     def to_reader(self):
-        return pyarrow.dataset.dataset(CustomDataset.DATA).scanner(
-            filter=self.filter, columns=self.columns
-        ).to_reader()
\ No newline at end of file
+        return pyarrow.dataset.dataset(CustomDataset.DATA).scanner(filter=self.filter, columns=self.columns).to_reader()
diff --git a/tools/pythonpkg/tests/fast/arrow/test_date.py b/tools/pythonpkg/tests/fast/arrow/test_date.py
index 277139f13667..ad5ec96a4d93 100644
--- a/tools/pythonpkg/tests/fast/arrow/test_date.py
+++ b/tools/pythonpkg/tests/fast/arrow/test_date.py
@@ -2,41 +2,46 @@
 import os
 import datetime
 import pytest
+
 try:
     import pyarrow as pa
     import pandas as pd
+
     can_run = True
 except:
     can_run = False
 
+
 class TestArrowDate(object):
     def test_date_types(self, duckdb_cursor):
         if not can_run:
             return
-
-        data = (pa.array([1000*60*60*24], type=pa.date64()),pa.array([1], type=pa.date32()))
-        arrow_table = pa.Table.from_arrays([data[0],data[1]],['a','b'])
-        rel = duckdb.from_arrow(arrow_table).arrow()
-        assert (rel['a'] == arrow_table['b'])
-        assert (rel['b'] == arrow_table['b'])
+        data = (pa.array([1000 * 60 * 60 * 24], type=pa.date64()), pa.array([1], type=pa.date32()))
+        arrow_table = pa.Table.from_arrays([data[0], data[1]], ['a', 'b'])
+        rel = duckdb.from_arrow(arrow_table).arrow()
+        assert rel['a'] == arrow_table['b']
+        assert rel['b'] == arrow_table['b']
 
     def test_date_null(self, duckdb_cursor):
         if not can_run:
-            return
-        data = (pa.array([None], type=pa.date64()),pa.array([None], type=pa.date32()))
-        arrow_table = pa.Table.from_arrays([data[0],data[1]],['a','b'])
+            return
+        data = (pa.array([None], type=pa.date64()), pa.array([None], type=pa.date32()))
+        arrow_table = pa.Table.from_arrays([data[0], data[1]], ['a', 'b'])
         rel = duckdb.from_arrow(arrow_table).arrow()
-        assert (rel['a'] == arrow_table['b'])
-        assert (rel['b'] == arrow_table['b'])
+        assert rel['a'] == arrow_table['b']
+        assert rel['b'] == arrow_table['b']
 
     def test_max_date(self, duckdb_cursor):
         if not can_run:
-            return
-        data = (pa.array([2147483647], type=pa.date32()),pa.array([2147483647], type=pa.date32()))
-        result = pa.Table.from_arrays([data[0],data[1]],['a','b'])
-        data = (pa.array([2147483647*(1000*60*60*24)], type=pa.date64()),pa.array([2147483647], type=pa.date32()))
-        arrow_table = pa.Table.from_arrays([data[0],data[1]],['a','b'])
+            return
+        data = (pa.array([2147483647], type=pa.date32()), pa.array([2147483647], type=pa.date32()))
+        result = pa.Table.from_arrays([data[0], data[1]], ['a', 'b'])
+        data = (
+            pa.array([2147483647 * (1000 * 60 * 60 * 24)], type=pa.date64()),
+            pa.array([2147483647], type=pa.date32()),
+        )
+        arrow_table = pa.Table.from_arrays([data[0], data[1]], ['a', 'b'])
         rel = duckdb.from_arrow(arrow_table).arrow()
-        assert (rel['a'] == result['a'])
-        assert (rel['b'] == result['b'])
\ No newline at end of file
+        assert rel['a'] == result['a']
+        assert rel['b'] == result['b']
diff --git a/tools/pythonpkg/tests/fast/arrow/test_dictionary_arrow.py b/tools/pythonpkg/tests/fast/arrow/test_dictionary_arrow.py
index 7ab29e507504..f389836529fe 100644
--- a/tools/pythonpkg/tests/fast/arrow/test_dictionary_arrow.py
+++ b/tools/pythonpkg/tests/fast/arrow/test_dictionary_arrow.py
@@ -1,4 +1,5 @@
 import duckdb
+
 try:
     import pyarrow as pa
     import pyarrow.parquet
@@ -6,47 +7,48 @@
     from pandas import Timestamp
     import datetime
     import pandas as pd
+
     can_run = True
 except:
     can_run = False
 
-class TestArrowDictionary(object):
-    def test_dictionary(self,duckdb_cursor):
+
+class TestArrowDictionary(object):
+    def test_dictionary(self, duckdb_cursor):
         if not can_run:
             return
         indices = pa.array([0, 1, 0, 1, 2, 1, 0, 2])
         dictionary = pa.array([10, 100, None])
         dict_array = pa.DictionaryArray.from_arrays(indices, dictionary)
-        arrow_table = pa.Table.from_arrays([dict_array],['a'])
+        arrow_table = pa.Table.from_arrays([dict_array], ['a'])
         rel = duckdb.from_arrow(arrow_table)
 
         assert rel.execute().fetchall() == [(10,), (100,), (10,), (100,), (None,), (100,), (10,), (None,)]
 
         # Bigger than Vector Size
-        indices_list = [0, 1, 0, 1, 2, 1, 0, 2,3] * 10000
+        indices_list = [0, 1, 0, 1, 2, 1, 0, 2, 3] * 10000
         indices = pa.array(indices_list)
-        dictionary = pa.array([10, 100, None,999999])
+        dictionary = pa.array([10, 100, None, 999999])
         dict_array = pa.DictionaryArray.from_arrays(indices, dictionary)
-        arrow_table = pa.Table.from_arrays([dict_array],['a'])
+        arrow_table = pa.Table.from_arrays([dict_array], ['a'])
         rel = duckdb.from_arrow(arrow_table)
         result = [(10,), (100,), (10,), (100,), (None,), (100,), (10,), (None,), (999999,)] * 10000
         assert rel.execute().fetchall() == result
 
-        #Table with dictionary and normal array
+        # Table with dictionary and normal array
 
-        arrow_table = pa.Table.from_arrays([dict_array,pa.array(indices_list)],['a','b'])
+        arrow_table = pa.Table.from_arrays([dict_array, pa.array(indices_list)], ['a', 'b'])
         rel = duckdb.from_arrow(arrow_table)
-        result = [(10,0), (100,1), (10,0), (100,1), (None,2), (100,1), (10,0), (None,2), (999999,3)] * 10000
+        result = [(10, 0), (100, 1), (10, 0), (100, 1), (None, 2), (100, 1), (10, 0), (None, 2), (999999, 3)] * 10000
         assert rel.execute().fetchall() == result
 
-    def test_dictionary_null_index(self,duckdb_cursor):
+    def test_dictionary_null_index(self, duckdb_cursor):
         if not can_run:
             return
         indices = pa.array([None, 1, 0, 1, 2, 1, 0, 2])
         dictionary = pa.array([10, 100, None])
         dict_array = pa.DictionaryArray.from_arrays(indices, dictionary)
-        arrow_table = pa.Table.from_arrays([dict_array],['a'])
+        arrow_table = pa.Table.from_arrays([dict_array], ['a'])
         rel = duckdb.from_arrow(arrow_table)
 
         assert rel.execute().fetchall() == [(None,), (100,), (10,), (100,), (None,), (100,), (10,), (None,)]
@@ -54,9 +56,9 @@ def test_dictionary_null_index(self,duckdb_cursor):
         indices = pa.array([None, 1, None, 1, 2, 1, 0])
         dictionary = pa.array([10, 100, 100])
         dict_array = pa.DictionaryArray.from_arrays(indices, dictionary)
-        arrow_table = pa.Table.from_arrays([dict_array],['a'])
+        arrow_table = pa.Table.from_arrays([dict_array], ['a'])
         rel = duckdb.from_arrow(arrow_table)
-        print (rel.execute().fetchall())
+        print(rel.execute().fetchall())
         assert rel.execute().fetchall() == [(None,), (100,), (None,), (100,), (100,), (100,), (10,)]
 
         # Test Big Vector
@@ -64,18 +66,18 @@ def test_dictionary_null_index(self,duckdb_cursor):
         indices = pa.array(indices_list * 1000)
         dictionary = pa.array([10, 100, 100])
         dict_array = pa.DictionaryArray.from_arrays(indices, dictionary)
-        arrow_table = pa.Table.from_arrays([dict_array],['a'])
+        arrow_table = pa.Table.from_arrays([dict_array], ['a'])
         rel = duckdb.from_arrow(arrow_table)
         result = [(None,), (100,), (None,), (100,), (100,), (100,), (10,)] * 1000
         assert rel.execute().fetchall() == result
 
-        #Table with dictionary and normal array
-        arrow_table = pa.Table.from_arrays([dict_array,indices],['a','b'])
+        # Table with dictionary and normal array
+        arrow_table = pa.Table.from_arrays([dict_array, indices], ['a', 'b'])
         rel = duckdb.from_arrow(arrow_table)
-        result = [(None,None), (100,1), (None,None), (100,1), (100,2), (100,1), (10,0)] * 1000
+        result = [(None, None), (100, 1), (None, None), (100, 1), (100, 2), (100, 1), (10, 0)] * 1000
         assert rel.execute().fetchall() == result
 
-    def test_dictionary_batches(self,duckdb_cursor):
+    def test_dictionary_batches(self, duckdb_cursor):
         if not can_run:
             return
 
@@ -83,20 +85,20 @@ def test_dictionary_batches(self,duckdb_cursor):
         indices = pa.array(indices_list * 10000)
         dictionary = pa.array([10, 100, 100])
         dict_array = pa.DictionaryArray.from_arrays(indices, dictionary)
-        arrow_table = pa.Table.from_arrays([dict_array],['a'])
+        arrow_table = pa.Table.from_arrays([dict_array], ['a'])
         batch_arrow_table = pyarrow.Table.from_batches(arrow_table.to_batches(10))
         rel = duckdb.from_arrow(batch_arrow_table)
         result = [(None,), (100,), (None,), (100,), (100,), (100,), (10,)] * 10000
         assert rel.execute().fetchall() == result
 
-        #Table with dictionary and normal array
-        arrow_table = pa.Table.from_arrays([dict_array,indices],['a','b'])
+        # Table with dictionary and normal array
+        arrow_table = pa.Table.from_arrays([dict_array, indices], ['a', 'b'])
         batch_arrow_table = pyarrow.Table.from_batches(arrow_table.to_batches(10))
         rel = duckdb.from_arrow(batch_arrow_table)
-        result = [(None,None), (100,1), (None,None), (100,1), (100,2), (100,1), (10,0)] * 10000
+        result = [(None, None), (100, 1), (None, None), (100, 1), (100, 2), (100, 1), (10, 0)] * 10000
         assert rel.execute().fetchall() == result
 
-    def test_dictionary_batches_parallel(self,duckdb_cursor):
+    def test_dictionary_batches_parallel(self, duckdb_cursor):
         if not can_run:
             return
 
@@ -108,20 +110,20 @@ def test_dictionary_batches_parallel(self,duckdb_cursor):
         indices = pa.array(indices_list * 10000)
         dictionary = pa.array([10, 100, 100])
         dict_array = pa.DictionaryArray.from_arrays(indices, dictionary)
-        arrow_table = pa.Table.from_arrays([dict_array],['a'])
+        arrow_table = pa.Table.from_arrays([dict_array], ['a'])
         batch_arrow_table = pyarrow.Table.from_batches(arrow_table.to_batches(10))
         rel = duckdb_conn.from_arrow(batch_arrow_table)
         result = [(None,), (100,), (None,), (100,), (100,), (100,), (10,)] * 10000
         assert rel.execute().fetchall() == result
 
-        #Table with dictionary and normal array
-        arrow_table = pa.Table.from_arrays([dict_array,indices],['a','b'])
+        # Table with dictionary and normal array
+        arrow_table = pa.Table.from_arrays([dict_array, indices], ['a', 'b'])
         batch_arrow_table = pyarrow.Table.from_batches(arrow_table.to_batches(10))
         rel = duckdb_conn.from_arrow(batch_arrow_table)
-        result = [(None,None), (100,1), (None,None), (100,1), (100,2), (100,1), (10,0)] * 10000
+        result = [(None, None), (100, 1), (None, None), (100, 1), (100, 2), (100, 1), (10, 0)] * 10000
         assert rel.execute().fetchall() == result
 
-    def test_dictionary_index_types(self,duckdb_cursor):
+    def test_dictionary_index_types(self, duckdb_cursor):
         if not can_run:
             return
         indices_list = [None, 1, None, 1, 2, 1, 0]
@@ -138,13 +140,12 @@ def test_dictionary_index_types(self,duckdb_cursor):
 
         for index_type in index_types:
             dict_array = pa.DictionaryArray.from_arrays(index_type, dictionary)
-            arrow_table = pa.Table.from_arrays([dict_array],['a'])
+            arrow_table = pa.Table.from_arrays([dict_array], ['a'])
             rel = duckdb.from_arrow(arrow_table)
-            result = [(None,), (100,), (None,), (100,), (100,), (100,), (10,)]* 10000
+            result = [(None,), (100,), (None,), (100,), (100,), (100,), (10,)] * 10000
             assert rel.execute().fetchall() == result
 
-
-    def test_dictionary_strings(self,duckdb_cursor):
+    def test_dictionary_strings(self, duckdb_cursor):
         if not can_run:
             return
 
@@ -152,21 +153,42 @@ def test_dictionary_strings(self,duckdb_cursor):
         indices = pa.array(indices_list * 1000)
         dictionary = pa.array(['Matt Daaaaaaaaamon', 'Alec Baldwin', 'Sean Penn', 'Tim Robbins', 'Samuel L. Jackson'])
         dict_array = pa.DictionaryArray.from_arrays(indices, dictionary)
-        arrow_table = pa.Table.from_arrays([dict_array],['a'])
+        arrow_table = pa.Table.from_arrays([dict_array], ['a'])
         rel = duckdb.from_arrow(arrow_table)
-        result = [(None,), ('Matt Daaaaaaaaamon',), ( 'Alec Baldwin',), ('Sean Penn',), ('Tim Robbins',), ('Samuel L. Jackson',), (None,)] * 1000
+        result = [
+            (None,),
+            ('Matt Daaaaaaaaamon',),
+            ('Alec Baldwin',),
+            ('Sean Penn',),
+            ('Tim Robbins',),
+            ('Samuel L. Jackson',),
+            (None,),
+        ] * 1000
         assert rel.execute().fetchall() == result
 
-    def test_dictionary_timestamps(self,duckdb_cursor):
+    def test_dictionary_timestamps(self, duckdb_cursor):
         if not can_run:
             return
         indices_list = [None, 0, 1, 2, None]
         indices = pa.array(indices_list * 1000)
-        dictionary = pa.array([Timestamp(year=2001, month=9, day=25),Timestamp(year=2006, month=11, day=14),Timestamp(year=2012, month=5, day=15),Timestamp(year=2018, month=11, day=2)])
+        dictionary = pa.array(
+            [
+                Timestamp(year=2001, month=9, day=25),
+                Timestamp(year=2006, month=11, day=14),
+                Timestamp(year=2012, month=5, day=15),
+                Timestamp(year=2018, month=11, day=2),
+            ]
+        )
         dict_array = pa.DictionaryArray.from_arrays(indices, dictionary)
-        arrow_table = pa.Table.from_arrays([dict_array],['a'])
+        arrow_table = pa.Table.from_arrays([dict_array], ['a'])
         rel = duckdb.from_arrow(arrow_table)
-        print (rel.execute().fetchall())
-        expected = [(None,), (datetime.datetime(2001, 9, 25, 0, 0),), (datetime.datetime(2006, 11, 14, 0, 0),), (datetime.datetime(2012, 5, 15, 0, 0),), (None,)] * 1000
+        print(rel.execute().fetchall())
+        expected = [
+            (None,),
+            (datetime.datetime(2001, 9, 25, 0, 0),),
+            (datetime.datetime(2006, 11, 14, 0, 0),),
+            (datetime.datetime(2012, 5, 15, 0, 0),),
+            (None,),
+        ] * 1000
         result = rel.execute().fetchall()
         assert result == expected
diff --git a/tools/pythonpkg/tests/fast/arrow/test_filter_pushdown.py b/tools/pythonpkg/tests/fast/arrow/test_filter_pushdown.py
index ddec7739c47c..5d0049caebda 100644
--- a/tools/pythonpkg/tests/fast/arrow/test_filter_pushdown.py
+++ b/tools/pythonpkg/tests/fast/arrow/test_filter_pushdown.py
@@ -2,6 +2,7 @@
 import os
 import pytest
 import tempfile
+
 pa = pytest.importorskip("pyarrow")
 pq = pytest.importorskip("pyarrow.parquet")
 ds = pytest.importorskip("pyarrow.dataset")
@@ -12,39 +13,40 @@
 ## DuckDB connection used in this test
 duckdb_conn = duckdb.connect()
 
+
 def numeric_operators(data_type, tbl_name):
-	duckdb_conn.execute("CREATE TABLE " +tbl_name+ " (a "+data_type+", b "+data_type+", c "+data_type+")")
-	duckdb_conn.execute("INSERT INTO " +tbl_name+ " VALUES (1,1,1),(10,10,10),(100,10,100),(NULL,NULL,NULL)")
-	duck_tbl = duckdb_conn.table(tbl_name)
-	arrow_table = duck_tbl.arrow()
-	print (arrow_table)
+    duckdb_conn.execute("CREATE TABLE " + tbl_name + " (a " + data_type + ", b " + data_type + ", c " + data_type + ")")
+    duckdb_conn.execute("INSERT INTO " + tbl_name + " VALUES (1,1,1),(10,10,10),(100,10,100),(NULL,NULL,NULL)")
+    duck_tbl = duckdb_conn.table(tbl_name)
+    arrow_table = duck_tbl.arrow()
+    print(arrow_table)
 
-	duckdb_conn.register("testarrow",arrow_table)
-	# Try ==
-	assert duckdb_conn.execute("SELECT count(*) from testarrow where a =1").fetchone()[0] == 1
-	# Try >
-	assert duckdb_conn.execute("SELECT count(*) from testarrow where a >1").fetchone()[0] == 2
-	# Try >=
-	assert duckdb_conn.execute("SELECT count(*) from testarrow where a >=10").fetchone()[0] == 2
-	# Try <
-	assert duckdb_conn.execute("SELECT count(*) from testarrow where a <10").fetchone()[0] == 1
-	# Try <=
-	assert duckdb_conn.execute("SELECT count(*) from testarrow where a <=10").fetchone()[0] == 2
+    duckdb_conn.register("testarrow", arrow_table)
+    # Try ==
+    assert duckdb_conn.execute("SELECT count(*) from testarrow where a =1").fetchone()[0] == 1
+    # Try >
+    assert duckdb_conn.execute("SELECT count(*) from testarrow where a >1").fetchone()[0] == 2
+    # Try >=
+    assert duckdb_conn.execute("SELECT count(*) from testarrow where a >=10").fetchone()[0] == 2
+    # Try <
+    assert duckdb_conn.execute("SELECT count(*) from testarrow where a <10").fetchone()[0] == 1
+    # Try <=
+    assert duckdb_conn.execute("SELECT count(*) from testarrow where a <=10").fetchone()[0] == 2
 
-	# Try Is Null
-	assert duckdb_conn.execute("SELECT count(*) from testarrow where a IS NULL").fetchone()[0] == 1
-	# Try Is Not Null
-	assert duckdb_conn.execute("SELECT count(*) from testarrow where a IS NOT NULL").fetchone()[0] == 3
+    # Try Is Null
+    assert duckdb_conn.execute("SELECT count(*) from testarrow where a IS NULL").fetchone()[0] == 1
+    # Try Is Not Null
+    assert duckdb_conn.execute("SELECT count(*) from testarrow where a IS NOT NULL").fetchone()[0] == 3
 
-	# Try And
-	assert duckdb_conn.execute("SELECT count(*) from testarrow where a=10 and b =1").fetchone()[0] == 0
-	assert duckdb_conn.execute("SELECT count(*) from testarrow where a =100 and b = 10 and c = 100").fetchone()[0] == 1
+    # Try And
+    assert duckdb_conn.execute("SELECT count(*) from testarrow where a=10 and b =1").fetchone()[0] == 0
+    assert duckdb_conn.execute("SELECT count(*) from testarrow where a =100 and b = 10 and c = 100").fetchone()[0] == 1
 
-	# Try Or
-	assert duckdb_conn.execute("SELECT count(*) from testarrow where a = 100 or b =1").fetchone()[0] == 2
+    # Try Or
+    assert duckdb_conn.execute("SELECT count(*) from testarrow where a = 100 or b =1").fetchone()[0] == 2
 
-	duckdb_conn.execute("EXPLAIN SELECT count(*) from testarrow where a = 100 or b =1")
-	print(duckdb_conn.fetchall())
+    duckdb_conn.execute("EXPLAIN SELECT count(*) from testarrow where a = 100 or b =1")
+    print(duckdb_conn.fetchall())
 
 
 def numeric_check_or_pushdown(tbl_name):
@@ -52,35 +54,41 @@ def numeric_check_or_pushdown(tbl_name):
     duck_tbl = duckdb_conn.table(tbl_name)
     arrow_table = duck_tbl.arrow()
     arrow_tbl_name = "testarrow_" + tbl_name
 
-    duckdb_conn.register(arrow_tbl_name ,arrow_table)
+    duckdb_conn.register(arrow_tbl_name, arrow_table)
 
     # Multiple column in the root OR node, don't push down
-    query_res = duckdb_conn.execute("EXPLAIN SELECT * FROM " +arrow_tbl_name+ " WHERE a=1 OR b=2 AND (a>3 OR b<5)").fetchall()
+    query_res = duckdb_conn.execute(
+        "EXPLAIN SELECT * FROM " + arrow_tbl_name + " WHERE a=1 OR b=2 AND (a>3 OR b<5)"
+    ).fetchall()
     match = re.search(".*ARROW_SCAN.*Filters:.*", query_res[0][1])
     assert not match
 
     # Single column in the root OR node
-    query_res = duckdb_conn.execute("EXPLAIN SELECT * FROM " +arrow_tbl_name+ " WHERE a=1 OR a=10").fetchall()
+    query_res = duckdb_conn.execute("EXPLAIN SELECT * FROM " + arrow_tbl_name + " WHERE a=1 OR a=10").fetchall()
     match = re.search(".*ARROW_SCAN.*Filters: a=1 OR a=10.*|$", query_res[0][1])
     assert match
 
     # Single column + root OR node with AND
-    query_res = duckdb_conn.execute("EXPLAIN SELECT * FROM " +arrow_tbl_name+ " WHERE a=1 OR (a>3 AND a<5)").fetchall()
+    query_res = duckdb_conn.execute(
+        "EXPLAIN SELECT * FROM " + arrow_tbl_name + " WHERE a=1 OR (a>3 AND a<5)"
+    ).fetchall()
     match = re.search(".*ARROW_SCAN.*Filters: a=1 OR a>3 AND a<5.*|$", query_res[0][1])
     assert match
 
     # Single column multiple ORs
-    query_res = duckdb_conn.execute("EXPLAIN SELECT * FROM " +arrow_tbl_name+ " WHERE a=1 OR a>3 OR a<5").fetchall()
+    query_res = duckdb_conn.execute("EXPLAIN SELECT * FROM " + arrow_tbl_name + " WHERE a=1 OR a>3 OR a<5").fetchall()
     match = re.search(".*ARROW_SCAN.*Filters: a=1 OR a>3 OR a<5.*|$", query_res[0][1])
     assert match
 
     # Testing not equal
-    query_res = duckdb_conn.execute("EXPLAIN SELECT * FROM " +arrow_tbl_name+ " WHERE a!=1 OR a>3 OR a<2").fetchall()
+    query_res = duckdb_conn.execute("EXPLAIN SELECT * FROM " + arrow_tbl_name + " WHERE a!=1 OR a>3 OR a<2").fetchall()
     match = re.search(".*ARROW_SCAN.*Filters: a!=1 OR a>3 OR a<2.*|$", query_res[0][1])
     assert match
 
     # Multiple OR filters connected with ANDs
-    query_res = duckdb_conn.execute("EXPLAIN SELECT * FROM " +arrow_tbl_name+ " WHERE (a<2 OR a>3) AND (a=1 OR a=4) AND (b=1 OR b<5)").fetchall()
+    query_res = duckdb_conn.execute(
+        "EXPLAIN SELECT * FROM " + arrow_tbl_name + " WHERE (a<2 OR a>3) AND (a=1 OR a=4) AND (b=1 OR b<5)"
+    ).fetchall()
     match = re.search(".*ARROW_SCAN.*Filters: a<2 OR a>3 AND a=1|\n.*OR a=4.*\n.*b=2 OR b<5.*|$", query_res[0][1])
     assert match
 
@@ -90,54 +98,72 @@ def string_check_or_pushdown(tbl_name):
     arrow_table = duck_tbl.arrow()
     arrow_tbl_name = "testarrow_varchar"
 
-    duckdb_conn.register(arrow_tbl_name ,arrow_table)
+    duckdb_conn.register(arrow_tbl_name, arrow_table)
 
     # Check string zonemap
-    query_res = duckdb_conn.execute("EXPLAIN SELECT * FROM " +arrow_tbl_name+ " WHERE a>='1' OR a<='10'").fetchall()
+    query_res = duckdb_conn.execute("EXPLAIN SELECT * FROM " + arrow_tbl_name + " WHERE a>='1' OR a<='10'").fetchall()
    match = re.search(".*ARROW_SCAN.*Filters: a>=1 OR a<=10.*|$", query_res[0][1])
     assert match
 
     # No support for OR with is null
-    query_res = duckdb_conn.execute("EXPLAIN SELECT * FROM " +arrow_tbl_name+ " WHERE a IS NULL or a='1'").fetchall()
+    query_res = duckdb_conn.execute("EXPLAIN SELECT * FROM " + arrow_tbl_name + " WHERE a IS NULL or a='1'").fetchall()
     match = re.search(".*ARROW_SCAN.*Filters:.*", query_res[0][1])
     assert not match
 
     # No support for OR with is not null
-    query_res = duckdb_conn.execute("EXPLAIN SELECT * FROM " +arrow_tbl_name+ " WHERE a IS NOT NULL OR a='1'").fetchall()
+    query_res = duckdb_conn.execute(
+        "EXPLAIN SELECT * FROM " + arrow_tbl_name + " WHERE a IS NOT NULL OR a='1'"
+    ).fetchall()
     match = re.search(".*ARROW_SCAN.*Filters:.*", query_res[0][1])
     assert not match
 
     # OR with the like operator
-    query_res = duckdb_conn.execute("EXPLAIN SELECT * FROM " +arrow_tbl_name+ " WHERE a=1 OR a LIKE '10%'").fetchall()
+    query_res = duckdb_conn.execute("EXPLAIN SELECT * FROM " + arrow_tbl_name + " WHERE a=1 OR a LIKE '10%'").fetchall()
     match = re.search(".*ARROW_SCAN.*Filters:.*", query_res[0][1])
     assert not match
 
 
 class TestArrowFilterPushdown(object):
-    def test_filter_pushdown_numeric(self,duckdb_cursor):
-
-        numeric_types = ['TINYINT', 'SMALLINT', 'INTEGER', 'BIGINT', 'UTINYINT', 'USMALLINT', 'UINTEGER', 'UBIGINT',
-                         'FLOAT', 'DOUBLE', 'HUGEINT']
+    def test_filter_pushdown_numeric(self, duckdb_cursor):
+        numeric_types = [
+            'TINYINT',
+            'SMALLINT',
+            'INTEGER',
+            'BIGINT',
+            'UTINYINT',
+            'USMALLINT',
+            'UINTEGER',
+            'UBIGINT',
+            'FLOAT',
+            'DOUBLE',
+            'HUGEINT',
+        ]
 
         for data_type in numeric_types:
             tbl_name = "test_" + data_type
             numeric_operators(data_type, tbl_name)
             numeric_check_or_pushdown(tbl_name)
 
-    def test_filter_pushdown_decimal(self,duckdb_cursor):
-        numeric_types = {'DECIMAL(4,1)': 'test_decimal_4_1', 'DECIMAL(9,1)': 'test_decimal_9_1',
-                         'DECIMAL(18,4)': 'test_decimal_18_4','DECIMAL(30,12)': 'test_decimal_30_12'}
+    def test_filter_pushdown_decimal(self, duckdb_cursor):
+        numeric_types = {
+            'DECIMAL(4,1)': 'test_decimal_4_1',
+            'DECIMAL(9,1)': 'test_decimal_9_1',
+            'DECIMAL(18,4)': 'test_decimal_18_4',
+            'DECIMAL(30,12)': 'test_decimal_30_12',
+        }
 
         for data_type in numeric_types:
            tbl_name = numeric_types[data_type]
            numeric_operators(data_type, tbl_name)
            numeric_check_or_pushdown(tbl_name)
 
-    def test_filter_pushdown_varchar(self,duckdb_cursor):
+    def test_filter_pushdown_varchar(self, duckdb_cursor):
         duckdb_conn.execute("CREATE TABLE test_varchar (a VARCHAR, b VARCHAR, c VARCHAR)")
-        duckdb_conn.execute("INSERT INTO test_varchar VALUES ('1','1','1'),('10','10','10'),('100','10','100'),(NULL,NULL,NULL)")
+        duckdb_conn.execute(
+            "INSERT INTO test_varchar VALUES ('1','1','1'),('10','10','10'),('100','10','100'),(NULL,NULL,NULL)"
+        )
         duck_tbl = duckdb_conn.table("test_varchar")
         arrow_table = duck_tbl.arrow()
 
-        duckdb_conn.register("testarrow",arrow_table)
+        duckdb_conn.register("testarrow", arrow_table)
         # Try ==
         assert duckdb_conn.execute("SELECT count(*) from testarrow where a ='1'").fetchone()[0] == 1
         # Try >
@@ -156,21 +182,25 @@ def test_filter_pushdown_varchar(self, duckdb_cursor):
 
         # Try And
         assert duckdb_conn.execute("SELECT count(*) from testarrow where a='10' and b ='1'").fetchone()[0] == 0
-        assert duckdb_conn.execute("SELECT count(*) from testarrow where a ='100' and b = '10' and c = '100'").fetchone()[0] == 1
+        assert (
+            duckdb_conn.execute("SELECT count(*) from testarrow where a ='100' and b = '10' and c = '100'").fetchone()[
+                0
+            ]
+            == 1
+        )
 
         # Try Or
         assert duckdb_conn.execute("SELECT count(*) from testarrow where a = '100' or b ='1'").fetchone()[0] == 2
 
         # More complex tests for OR pushed down on string
         string_check_or_pushdown("test_varchar")
 
-
-    def test_filter_pushdown_bool(self,duckdb_cursor):
+    def test_filter_pushdown_bool(self, duckdb_cursor):
         duckdb_conn.execute("CREATE TABLE test_bool (a BOOL, b BOOL)")
         duckdb_conn.execute("INSERT INTO test_bool VALUES (TRUE,TRUE),(TRUE,FALSE),(FALSE,TRUE),(NULL,NULL)")
         duck_tbl = duckdb_conn.table("test_bool")
         arrow_table = duck_tbl.arrow()
 
-        duckdb_conn.register("testarrow",arrow_table)
+        duckdb_conn.register("testarrow", arrow_table)
         # Try ==
         assert duckdb_conn.execute("SELECT count(*) from testarrow where a =True").fetchone()[0] == 2
@@ -184,13 +214,15 @@ def test_filter_pushdown_bool(self,duckdb_cursor):
         # Try Or
         assert duckdb_conn.execute("SELECT count(*) from testarrow where a = True or b =True").fetchone()[0] == 3
 
-    def test_filter_pushdown_time(self,duckdb_cursor):
+    def test_filter_pushdown_time(self, duckdb_cursor):
         duckdb_conn.execute("CREATE TABLE test_time (a TIME, b TIME, c TIME)")
-        duckdb_conn.execute("INSERT INTO test_time VALUES ('00:01:00','00:01:00','00:01:00'),('00:10:00','00:10:00','00:10:00'),('01:00:00','00:10:00','01:00:00'),(NULL,NULL,NULL)")
+        duckdb_conn.execute(
+            "INSERT INTO test_time VALUES ('00:01:00','00:01:00','00:01:00'),('00:10:00','00:10:00','00:10:00'),('01:00:00','00:10:00','01:00:00'),(NULL,NULL,NULL)"
+        )
         duck_tbl = duckdb_conn.table("test_time")
         arrow_table = duck_tbl.arrow()
 
-        duckdb_conn.register("testarrow",arrow_table)
+        duckdb_conn.register("testarrow", arrow_table)
         # Try ==
         assert duckdb_conn.execute("SELECT count(*) from testarrow where a ='00:01:00'").fetchone()[0] == 1
         # Try >
@@ -208,19 +240,32 @@ def test_filter_pushdown_time(self,duckdb_cursor):
         assert duckdb_conn.execute("SELECT count(*) from testarrow where a IS NOT NULL").fetchone()[0] == 3
 
         # Try And
-        assert duckdb_conn.execute("SELECT count(*) from testarrow where a='00:10:00' and b ='00:01:00'").fetchone()[0] == 0
-        assert duckdb_conn.execute("SELECT count(*) from testarrow where a ='01:00:00' and b = '00:10:00' and c = '01:00:00'").fetchone()[0] == 1
+        assert (
+            duckdb_conn.execute("SELECT count(*) from testarrow where a='00:10:00' and b ='00:01:00'").fetchone()[0]
+            == 0
+        )
+        assert (
+            duckdb_conn.execute(
+                "SELECT count(*) from testarrow where a ='01:00:00' and b = '00:10:00' and c = '01:00:00'"
+            ).fetchone()[0]
+            == 1
+        )
 
         # Try Or
-        assert duckdb_conn.execute("SELECT count(*) from testarrow where a = '01:00:00' or b ='00:01:00'").fetchone()[0] == 2
+        assert (
+            duckdb_conn.execute("SELECT count(*) from testarrow where a = '01:00:00' or b ='00:01:00'").fetchone()[0]
+            == 2
+        )
 
-    def test_filter_pushdown_timestamp(self,duckdb_cursor):
+    def test_filter_pushdown_timestamp(self, duckdb_cursor):
         duckdb_conn.execute("CREATE TABLE test_timestamp (a TIMESTAMP, b TIMESTAMP, c TIMESTAMP)")
-        duckdb_conn.execute("INSERT INTO test_timestamp VALUES ('2008-01-01 00:00:01','2008-01-01 00:00:01','2008-01-01 00:00:01'),('2010-01-01 10:00:01','2010-01-01 10:00:01','2010-01-01 10:00:01'),('2020-03-01 10:00:01','2010-01-01 10:00:01','2020-03-01 10:00:01'),(NULL,NULL,NULL)")
+        duckdb_conn.execute(
+            "INSERT INTO test_timestamp VALUES ('2008-01-01 00:00:01','2008-01-01 00:00:01','2008-01-01 00:00:01'),('2010-01-01 10:00:01','2010-01-01 10:00:01','2010-01-01 10:00:01'),('2020-03-01 10:00:01','2010-01-01 10:00:01','2020-03-01 10:00:01'),(NULL,NULL,NULL)"
+        )
         duck_tbl = duckdb_conn.table("test_timestamp")
         arrow_table = duck_tbl.arrow()
-        print (arrow_table)
+        print(arrow_table)
 
-        duckdb_conn.register("testarrow",arrow_table)
+        duckdb_conn.register("testarrow", arrow_table)
         # Try ==
         assert duckdb_conn.execute("SELECT count(*) from testarrow where a ='2008-01-01 00:00:01'").fetchone()[0] == 1
         # Try >
@@ -238,19 +283,36 @@ def test_filter_pushdown_timestamp(self,duckdb_cursor):
         assert duckdb_conn.execute("SELECT count(*) from testarrow where a IS NOT NULL").fetchone()[0] == 3
 
         # Try And
-        assert duckdb_conn.execute("SELECT count(*) from testarrow where a='2010-01-01 10:00:01' and b ='2008-01-01 00:00:01'").fetchone()[0] == 0
-        assert duckdb_conn.execute("SELECT count(*) from testarrow where a ='2020-03-01 10:00:01' and b = '2010-01-01 10:00:01' and c = '2020-03-01 10:00:01'").fetchone()[0] == 1
+        assert (
+            duckdb_conn.execute(
+                "SELECT count(*) from testarrow where a='2010-01-01 10:00:01' and b ='2008-01-01 00:00:01'"
+            ).fetchone()[0]
+            == 0
+        )
+        assert (
+            duckdb_conn.execute(
+                "SELECT count(*) from testarrow where a ='2020-03-01 10:00:01' and b = '2010-01-01 10:00:01' and c = '2020-03-01 10:00:01'"
+            ).fetchone()[0]
+            == 1
+        )
 
         # Try Or
-        assert duckdb_conn.execute("SELECT count(*) from testarrow where a = '2020-03-01 10:00:01' or b ='2008-01-01 00:00:01'").fetchone()[0] == 2
-
-    def test_filter_pushdown_timestamp_TZ(self,duckdb_cursor):
+        assert (
+            duckdb_conn.execute(
+                "SELECT count(*) from testarrow where a = '2020-03-01 10:00:01' or b ='2008-01-01 00:00:01'"
+            ).fetchone()[0]
+            == 2
+        )
+
+    def test_filter_pushdown_timestamp_TZ(self, duckdb_cursor):
         duckdb_conn.execute("CREATE TABLE test_timestamptz (a TIMESTAMPTZ, b TIMESTAMPTZ, c TIMESTAMPTZ)")
-        duckdb_conn.execute("INSERT INTO test_timestamptz VALUES ('2008-01-01 00:00:01','2008-01-01 00:00:01','2008-01-01 00:00:01'),('2010-01-01 10:00:01','2010-01-01 10:00:01','2010-01-01 10:00:01'),('2020-03-01 10:00:01','2010-01-01 10:00:01','2020-03-01 10:00:01'),(NULL,NULL,NULL)")
+        duckdb_conn.execute(
+            "INSERT INTO test_timestamptz VALUES ('2008-01-01 00:00:01','2008-01-01 00:00:01','2008-01-01 00:00:01'),('2010-01-01 10:00:01','2010-01-01 10:00:01','2010-01-01 10:00:01'),('2020-03-01 10:00:01','2010-01-01 10:00:01','2020-03-01 10:00:01'),(NULL,NULL,NULL)"
        )
         duck_tbl = duckdb_conn.table("test_timestamptz")
         arrow_table = duck_tbl.arrow()
-        print (arrow_table)
+        print(arrow_table)
 
-        duckdb_conn.register("testarrow",arrow_table)
+        duckdb_conn.register("testarrow", arrow_table)
         # Try ==
         assert duckdb_conn.execute("SELECT count(*) from testarrow where a ='2008-01-01 00:00:01'").fetchone()[0] == 1
         # Try >
@@ -268,19 +330,35 @@ def test_filter_pushdown_timestamp_TZ(self,duckdb_cursor):
         assert duckdb_conn.execute("SELECT count(*) from testarrow where a IS NOT NULL").fetchone()[0] == 3
 
         # Try And
-        assert duckdb_conn.execute("SELECT count(*) from testarrow where a='2010-01-01 10:00:01' and b ='2008-01-01 00:00:01'").fetchone()[0] == 0
-        assert duckdb_conn.execute("SELECT count(*) from testarrow where a ='2020-03-01 10:00:01' and b = '2010-01-01 10:00:01' and c = '2020-03-01 10:00:01'").fetchone()[0] == 1
+        assert (
+            duckdb_conn.execute(
+                "SELECT count(*) from testarrow where a='2010-01-01 10:00:01' and b ='2008-01-01 00:00:01'"
+            ).fetchone()[0]
+            == 0
+        )
+        assert (
+            duckdb_conn.execute(
+                "SELECT count(*) from testarrow where a ='2020-03-01 10:00:01' and b = '2010-01-01 10:00:01' and c = '2020-03-01 10:00:01'"
+            ).fetchone()[0]
+            == 1
+        )
 
         # Try Or
-        assert duckdb_conn.execute("SELECT count(*) from testarrow where a = '2020-03-01 10:00:01' or b ='2008-01-01 00:00:01'").fetchone()[0] == 2
-
-
-    def test_filter_pushdown_date(self,duckdb_cursor):
+        assert (
+            duckdb_conn.execute(
+                "SELECT count(*) from testarrow where a = '2020-03-01 10:00:01' or b ='2008-01-01 00:00:01'"
+            ).fetchone()[0]
+            == 2
+        )
+
+    def test_filter_pushdown_date(self, duckdb_cursor):
         duckdb_conn.execute("CREATE TABLE test_date (a DATE, b DATE, c DATE)")
-        duckdb_conn.execute("INSERT INTO test_date VALUES ('2000-01-01','2000-01-01','2000-01-01'),('2000-10-01','2000-10-01','2000-10-01'),('2010-01-01','2000-10-01','2010-01-01'),(NULL,NULL,NULL)")
+        duckdb_conn.execute(
+            "INSERT INTO test_date VALUES ('2000-01-01','2000-01-01','2000-01-01'),('2000-10-01','2000-10-01','2000-10-01'),('2010-01-01','2000-10-01','2010-01-01'),(NULL,NULL,NULL)"
+        )
         duck_tbl = duckdb_conn.table("test_date")
         arrow_table = duck_tbl.arrow()
 
-        duckdb_conn.register("testarrow",arrow_table)
+        duckdb_conn.register("testarrow", arrow_table)
         # Try ==
         assert duckdb_conn.execute("SELECT count(*) from testarrow where a ='2000-01-01'").fetchone()[0] == 1
         # Try >
@@ -298,14 +376,33 @@ def test_filter_pushdown_date(self,duckdb_cursor):
         assert duckdb_conn.execute("SELECT count(*) from testarrow where a IS NOT NULL").fetchone()[0] == 3
 
         # Try And
-        assert duckdb_conn.execute("SELECT count(*) from testarrow where a='2000-10-01' and b ='2000-01-01'").fetchone()[0] == 0
-        assert duckdb_conn.execute("SELECT count(*) from testarrow where a ='2010-01-01' and b = '2000-10-01' and c = '2010-01-01'").fetchone()[0] == 1
+        assert (
+            duckdb_conn.execute("SELECT count(*) from testarrow where a='2000-10-01' and b ='2000-01-01'").fetchone()[0]
+            == 0
+        )
+        assert (
+            duckdb_conn.execute(
+                "SELECT count(*) from testarrow where a ='2010-01-01' and b = '2000-10-01' and c = '2010-01-01'"
+            ).fetchone()[0]
+            == 1
+        )
 
         # Try Or
-        assert duckdb_conn.execute("SELECT count(*) from testarrow where a = '2010-01-01' or b ='2000-01-01'").fetchone()[0] == 2
+        assert (
+            duckdb_conn.execute("SELECT count(*) from testarrow where a = '2010-01-01' or b ='2000-01-01'").fetchone()[
+                0
+            ]
+            == 2
+        )
 
     @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()])
     def test_filter_pushdown_blob(self, pandas):
-        df = pandas.DataFrame({'a': [bytes([1]), bytes([2]), bytes([3]), None], 'b': [bytes([1]), bytes([2]), bytes([3]), None],'c': [bytes([1]), bytes([2]), bytes([3]), None]})
+        df = pandas.DataFrame(
+            {
+                'a': [bytes([1]), bytes([2]), bytes([3]), None],
+                'b': [bytes([1]), bytes([2]), bytes([3]), None],
+                'c': [bytes([1]), bytes([2]), bytes([3]), None],
+            }
+        )
         arrow_table = pa.Table.from_pandas(df)
 
         # Try ==
@@ -326,25 +423,28 @@ def test_filter_pushdown_blob(self, pandas):
 
         # Try And
         assert duckdb_conn.execute("SELECT count(*) from arrow_table where a='\x02' and b ='\x01'").fetchone()[0] == 0
-        assert duckdb_conn.execute("SELECT count(*) from arrow_table where a ='\x02' and b = '\x02' and c = '\x02'").fetchone()[0] == 1
+        assert (
+            duckdb_conn.execute(
+                "SELECT count(*) from arrow_table where a ='\x02' and b = '\x02' and c = '\x02'"
+            ).fetchone()[0]
+            == 1
+        )
 
         # Try Or
         assert duckdb_conn.execute("SELECT count(*) from arrow_table where a = '\x01' or b ='\x02'").fetchone()[0] == 2
 
-
-    def test_filter_pushdown_no_projection(self,duckdb_cursor):
+    def test_filter_pushdown_no_projection(self, duckdb_cursor):
         duckdb_conn.execute("CREATE TABLE test_int (a INTEGER, b INTEGER, c INTEGER)")
         duckdb_conn.execute("INSERT INTO test_int VALUES (1,1,1),(10,10,10),(100,10,100),(NULL,NULL,NULL)")
         duck_tbl = duckdb_conn.table("test_int")
         arrow_table = duck_tbl.arrow()
 
-        duckdb_conn.register("testarrowtable",arrow_table)
+        duckdb_conn.register("testarrowtable", arrow_table)
         assert duckdb_conn.execute("SELECT * FROM testarrowtable VALUES where a =1").fetchall() == [(1, 1, 1)]
         arrow_dataset = ds.dataset(arrow_table)
-        duckdb_conn.register("testarrowdataset",arrow_dataset)
+        duckdb_conn.register("testarrowdataset", arrow_dataset)
         assert duckdb_conn.execute("SELECT * FROM testarrowdataset VALUES where a =1").fetchall() == [(1, 1, 1)]
 
     @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()])
-    def test_filter_pushdown_2145(self,duckdb_cursor, pandas):
-
+    def test_filter_pushdown_2145(self, duckdb_cursor, pandas):
         date1 = pandas.date_range("2018-01-01", "2018-12-31", freq="B")
         df1 = pandas.DataFrame(np.random.randn(date1.shape[0], 5), columns=list("ABCDE"))
         df1["date"] = date1
@@ -359,7 +459,7 @@ def test_filter_pushdown_2145(self,duckdb_cursor, pandas):
         table = pq.ParquetDataset(["data1.parquet", "data2.parquet"]).read()
 
         con = duckdb.connect()
-        con.register("testarrow",table)
+        con.register("testarrow", table)
 
         output_df = duckdb.arrow(table).filter("date > '2019-01-01'").df()
         expected_df = duckdb.from_parquet("data*.parquet").filter("date > '2019-01-01'").df()
@@ -368,11 +468,11 @@ def test_filter_pushdown_2145(self,duckdb_cursor, pandas):
         os.remove("data1.parquet")
         os.remove("data2.parquet")
 
-    def test_filter_column_removal(self,duckdb_cursor):
+    def test_filter_column_removal(self, duckdb_cursor):
         duckdb_conn.execute("CREATE TABLE test AS SELECT range i, range j FROM range(5)")
         duck_test_table = duckdb_conn.table("test")
         arrow_test_table = duck_test_table.arrow()
-        duckdb_conn.register("arrow_test_table",arrow_test_table)
+        duckdb_conn.register("arrow_test_table", arrow_test_table)
 
         # PR 4817 - remove filter columns that are unused in the remainder of the query plan from the table function
         query_res = duckdb_conn.execute("EXPLAIN SELECT count(*) from testarrow where a = 100 or b =1").fetchall()
diff --git a/tools/pythonpkg/tests/fast/arrow/test_integration.py b/tools/pythonpkg/tests/fast/arrow/test_integration.py
index a5087b8fff16..badff4ee899f 100644
--- a/tools/pythonpkg/tests/fast/arrow/test_integration.py
+++ b/tools/pythonpkg/tests/fast/arrow/test_integration.py
@@ -1,19 +1,22 @@
 import duckdb
 import os
 import datetime
+
 try:
     import pyarrow
     import pyarrow.parquet
     import numpy as np
+
     can_run = True
 except:
     can_run = False
 
+
 class TestArrowIntegration(object):
     def test_parquet_roundtrip(self, duckdb_cursor):
         if not can_run:
             return
-        parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)),'data','userdata1.parquet')
+        parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'userdata1.parquet')
         cols = 'id, first_name, last_name, email, gender, ip_address, cc, country, birthdate, salary, title, comments'
 
         # TODO timestamp
@@ -37,10 +40,10 @@ def test_parquet_roundtrip(self, duckdb_cursor):
         assert rel_from_arrow.equals(rel_from_arrow2, check_metadata=True)
         assert rel_from_arrow.equals(rel_from_duckdb, check_metadata=True)
 
-    def test_unsigned_roundtrip(self,duckdb_cursor):
+    def test_unsigned_roundtrip(self, duckdb_cursor):
         if not can_run:
             return
-        parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)),'data','unsigned.parquet')
+        parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'unsigned.parquet')
         cols = 'a, b, c, d'
 
         unsigned_parquet_table = pyarrow.parquet.read_table(parquet_filename)
@@ -54,7 +57,9 @@ def test_unsigned_roundtrip(self,duckdb_cursor):
         assert rel_from_arrow.equals(rel_from_duckdb, check_metadata=True)
 
         con = duckdb.connect()
-        con.execute("select NULL c_null, (c % 4 = 0)::bool c_bool, (c%128)::tinyint c_tinyint, c::smallint*1000 c_smallint, c::integer*100000 c_integer, c::bigint*1000000000000 c_bigint, c::float c_float, c::double c_double, 'c_' || c::string c_string from (select case when range % 2 == 0 then range else null end as c from range(-10000, 10000)) sq")
+        con.execute(
+            "select NULL c_null, (c % 4 = 0)::bool c_bool, (c%128)::tinyint c_tinyint, c::smallint*1000 c_smallint, c::integer*100000 c_integer, c::bigint*1000000000000 c_bigint, c::float c_float, c::double c_double, 'c_' || c::string c_string from (select case when range % 2 == 0 then range else null end as c from range(-10000, 10000)) sq"
+        )
         arrow_result = con.fetch_arrow_table()
         arrow_result.validate(full=True)
         arrow_result.combine_chunks()
@@ -65,7 +70,7 @@ def test_unsigned_roundtrip(self,duckdb_cursor):
 
         assert round_tripping.equals(arrow_result, check_metadata=True)
 
-    def test_decimals_roundtrip(self,duckdb_cursor):
+    def test_decimals_roundtrip(self, duckdb_cursor):
         if not can_run:
             return
 
@@ -85,77 +90,86 @@ def test_decimals_roundtrip(self,duckdb_cursor):
 
         arrow_result = duckdb_conn.execute("SELECT sum(a), sum(b), sum(c),sum(d) from testarrow").fetchall()
 
-        assert(arrow_result == true_result)
+        assert arrow_result == true_result
 
         arrow_result = duckdb_conn.execute("SELECT typeof(a), typeof(b), typeof(c),typeof(d) from testarrow").fetchone()
-        assert (arrow_result[0] == 'DECIMAL(4,2)')
-        assert (arrow_result[1] == 'DECIMAL(9,2)')
-        assert (arrow_result[2] == 'DECIMAL(18,2)')
-        assert (arrow_result[3] == 'DECIMAL(30,2)')
+        assert arrow_result[0] == 'DECIMAL(4,2)'
+        assert arrow_result[1] == 'DECIMAL(9,2)'
+        assert arrow_result[2] == 'DECIMAL(18,2)'
+        assert arrow_result[3] == 'DECIMAL(30,2)'
 
-        #Lets also test big number comming from arrow land
-        data = (pyarrow.array(np.array([9999999999999999999999999999999999]), type=pyarrow.decimal128(38,0)))
-        arrow_tbl = pyarrow.Table.from_arrays([data],['a'])
+        # Lets also test big number comming from arrow land
+        data = pyarrow.array(np.array([9999999999999999999999999999999999]), type=pyarrow.decimal128(38, 0))
+        arrow_tbl = pyarrow.Table.from_arrays([data], ['a'])
         duckdb_conn = duckdb.connect()
         duckdb_conn.from_arrow(arrow_tbl).create("bigdecimal")
         result = duckdb_conn.execute('select * from bigdecimal')
-        assert (result.fetchone()[0] == 9999999999999999999999999999999999)
+        assert result.fetchone()[0] == 9999999999999999999999999999999999
 
-    def test_intervals_roundtrip(self,duckdb_cursor):
+    def test_intervals_roundtrip(self, duckdb_cursor):
         if not can_run:
             return
         duckdb_conn = duckdb.connect()
         # test for import from apache arrow
-        expected_value = pyarrow.MonthDayNano([2, 8,
-                                               (datetime.timedelta(seconds=1, microseconds=1,
-                                                                   milliseconds=1, minutes=1,
-                                                                   hours=1) //
-                                                datetime.timedelta(microseconds=1)) * 1000])
+        expected_value = pyarrow.MonthDayNano(
+            [
+                2,
+                8,
+                (
+                    datetime.timedelta(seconds=1, microseconds=1, milliseconds=1, minutes=1, hours=1)
+                    // datetime.timedelta(microseconds=1)
+                )
+                * 1000,
+            ]
+        )
         arr = [expected_value]
         data = pyarrow.array(arr, pyarrow.month_day_nano_interval())
-        arrow_tbl = pyarrow.Table.from_arrays([data],['a'])
+        arrow_tbl = pyarrow.Table.from_arrays([data], ['a'])
         duckdb_conn = duckdb.connect()
         duckdb_conn.from_arrow(arrow_tbl).create("intervaltbl")
         duck_arrow_tbl = duckdb_conn.table("intervaltbl").arrow()['a']
-        assert (duck_arrow_tbl[0].value == expected_value)
+        assert duck_arrow_tbl[0].value == expected_value
 
         # test for select interval from duckdb
         duckdb_conn.execute("CREATE TABLE test (a INTERVAL)")
         duckdb_conn.execute("INSERT INTO test VALUES (INTERVAL 1 YEAR + INTERVAL 1 DAY + INTERVAL 1 SECOND)")
         expected_value = pyarrow.MonthDayNano([12, 1, 1000000000])
         duck_tbl_arrow = duckdb_conn.table("test").arrow()['a']
-        assert (duck_tbl_arrow[0].value.months == expected_value.months)
-        assert (duck_tbl_arrow[0].value.days == expected_value.days)
-        assert (duck_tbl_arrow[0].value.nanoseconds == expected_value.nanoseconds)
+        assert duck_tbl_arrow[0].value.months == expected_value.months
+        assert duck_tbl_arrow[0].value.days == expected_value.days
+        assert duck_tbl_arrow[0].value.nanoseconds == expected_value.nanoseconds
 
-    def test_null_intervals_roundtrip(self,duckdb_cursor):
+    def test_null_intervals_roundtrip(self, duckdb_cursor):
         if not can_run:
             return
         # test for null interval
-        expected_value = pyarrow.MonthDayNano([2, 8,
-                                               (datetime.timedelta(seconds=1, microseconds=1,
-                                                                   milliseconds=1, minutes=1,
-                                                                   hours=1) //
-                                                datetime.timedelta(microseconds=1)) * 1000])
-        arr = [
-            None,
-            expected_value
-        ]
+        expected_value = pyarrow.MonthDayNano(
+            [
+                2,
+                8,
+                (
+                    datetime.timedelta(seconds=1, microseconds=1, milliseconds=1, minutes=1, hours=1)
+                    // datetime.timedelta(microseconds=1)
+                )
+                * 1000,
+            ]
+        )
+        arr = [None, expected_value]
         data = pyarrow.array(arr, pyarrow.month_day_nano_interval())
-        arrow_tbl = pyarrow.Table.from_arrays([data],['a'])
+        arrow_tbl = pyarrow.Table.from_arrays([data], ['a'])
         duckdb_conn = duckdb.connect()
         duckdb_conn.from_arrow(arrow_tbl).create("intervalnulltbl")
         duckdb_tbl_arrow = duckdb_conn.table("intervalnulltbl").arrow()['a']
-        assert (duckdb_tbl_arrow[0].value == None)
-        assert (duckdb_tbl_arrow[1].value == expected_value)
+        assert duckdb_tbl_arrow[0].value == None
+        assert duckdb_tbl_arrow[1].value == expected_value
 
-    def test_nested_interval_roundtrip(self,duckdb_cursor):
+    def test_nested_interval_roundtrip(self, duckdb_cursor):
         if not can_run:
             return
         # Dictionary
@@ -165,26 +179,28 @@ def test_nested_interval_roundtrip(self,duckdb_cursor):
         second_value = pyarrow.MonthDayNano([90, 12, 0])
         dictionary = pyarrow.array([first_value, second_value, None])
         dict_array = pyarrow.DictionaryArray.from_arrays(indices, dictionary)
-        arrow_table = pyarrow.Table.from_arrays([dict_array],['a'])
+        arrow_table = pyarrow.Table.from_arrays([dict_array], ['a'])
         duckdb_conn.from_arrow(arrow_table).create("dictionarytbl")
         duckdb_tbl_arrow = duckdb_conn.table("dictionarytbl").arrow()['a']
-        assert duckdb_tbl_arrow[0].value == first_value
-        assert duckdb_tbl_arrow[1].value == second_value
-        assert duckdb_tbl_arrow[2].value == first_value
-        assert duckdb_tbl_arrow[3].value == second_value
-        assert duckdb_tbl_arrow[4].value == None
-        assert duckdb_tbl_arrow[5].value == second_value
-        assert duckdb_tbl_arrow[6].value == first_value
-        assert duckdb_tbl_arrow[7].value == None
-
+        assert duckdb_tbl_arrow[0].value == first_value
+        assert duckdb_tbl_arrow[1].value == second_value
+        assert duckdb_tbl_arrow[2].value == first_value
+        assert duckdb_tbl_arrow[3].value == second_value
+        assert duckdb_tbl_arrow[4].value == None
+        assert duckdb_tbl_arrow[5].value == second_value
+        assert duckdb_tbl_arrow[6].value == first_value
+        assert duckdb_tbl_arrow[7].value == None
+
         # List
-        query = duckdb.query("SELECT a from (select list_value(INTERVAL 3 MONTHS, INTERVAL 5 DAYS, INTERVAL 10 SECONDS, NULL) as a) as t").arrow()['a']
+        query = duckdb.query(
+            "SELECT a from (select list_value(INTERVAL 3 MONTHS, INTERVAL 5 DAYS, INTERVAL 10 SECONDS, NULL) as a) as t"
+        ).arrow()['a']
         assert query[0][0].value == pyarrow.MonthDayNano([3, 0, 0])
         assert query[0][1].value == pyarrow.MonthDayNano([0, 5, 0])
         assert query[0][2].value == pyarrow.MonthDayNano([0, 0, 10000000000])
         assert query[0][3].value == None
-
+
         # Struct
         query = "SELECT a from (SELECT STRUCT_PACK(a := INTERVAL 1 MONTHS, b := INTERVAL 10 DAYS, c:= INTERVAL 20 SECONDS) as a) as t"
         true_answer = duckdb.query(query).fetchall()
@@ -201,7 +217,7 @@ def test_min_max_interval_roundtrip(self, duckdb_cursor):
         interval_min_value = pyarrow.MonthDayNano([0, 0, 0])
         interval_max_value = pyarrow.MonthDayNano([2147483647, 2147483647, 9223372036854775000])
         data = pyarrow.array([interval_min_value, interval_max_value], pyarrow.month_day_nano_interval())
-        arrow_tbl = pyarrow.Table.from_arrays([data],['a'])
+        arrow_tbl = pyarrow.Table.from_arrays([data], ['a'])
         duckdb_conn = duckdb.connect()
         duckdb_conn.from_arrow(arrow_tbl).create("intervalminmaxtbl")
@@ -209,7 +225,7 @@ def test_min_max_interval_roundtrip(self, duckdb_cursor):
         assert duck_arrow_tbl[0].value == pyarrow.MonthDayNano([0, 0, 0])
         assert duck_arrow_tbl[1].value == pyarrow.MonthDayNano([2147483647, 2147483647, 9223372036854775000])
 
-    def test_strings_roundtrip(self,duckdb_cursor):
+    def test_strings_roundtrip(self, duckdb_cursor):
         if not can_run:
             return
 
@@ -218,8 +234,10 @@ def test_strings_roundtrip(self,duckdb_cursor):
         duckdb_conn.execute("CREATE TABLE test (a varchar)")
 
         # Test Small, Null and Very Big String
-        for i in range (0,1000):
-            duckdb_conn.execute("INSERT INTO test VALUES ('Matt Damon'),(NULL), ('Jeffffreeeey Jeeeeef Baaaaaaazos'), ('X-Content-Type-Options')")
+        for i in range(0, 1000):
+            duckdb_conn.execute(
+                "INSERT INTO test VALUES ('Matt Damon'),(NULL), ('Jeffffreeeey Jeeeeef Baaaaaaazos'), ('X-Content-Type-Options')"
+            )
 
         true_result = duckdb_conn.execute("SELECT * from test").fetchall()
 
@@ -231,4 +249,4 @@ def test_strings_roundtrip(self,duckdb_cursor):
 
         arrow_result = duckdb_conn.execute("SELECT * from testarrow").fetchall()
 
-        assert(arrow_result == true_result)
+        assert arrow_result == true_result
diff --git a/tools/pythonpkg/tests/fast/arrow/test_interval.py b/tools/pythonpkg/tests/fast/arrow/test_interval.py
index 7dacb0318e74..7a891a610021 100644
--- a/tools/pythonpkg/tests/fast/arrow/test_interval.py
+++ b/tools/pythonpkg/tests/fast/arrow/test_interval.py
@@ -2,37 +2,52 @@
 import os
 import datetime
 import pytest
+
 try:
     import pyarrow as pa
     import pandas as pd
+
     can_run = True
 except:
     can_run = False
 
+
 class TestArrowInterval(object):
     def test_duration_types(self, duckdb_cursor):
         if not can_run:
             return
-        expected_arrow = pa.Table.from_arrays([pa.array([pa.MonthDayNano([0, 0, 1000000000])], type=pa.month_day_nano_interval())],['a'])
-        data = (pa.array([1000000000], type=pa.duration('ns')),pa.array([1000000], type=pa.duration('us')),pa.array([1000], pa.duration('ms')),pa.array([1], pa.duration('s')))
-        arrow_table = pa.Table.from_arrays([data[0],data[1],data[2],data[3]],['a','b','c','d'])
+        expected_arrow = pa.Table.from_arrays(
+            [pa.array([pa.MonthDayNano([0, 0, 1000000000])], type=pa.month_day_nano_interval())], ['a']
+        )
+        data = (
+            pa.array([1000000000], type=pa.duration('ns')),
+            pa.array([1000000], type=pa.duration('us')),
+            pa.array([1000], pa.duration('ms')),
+            pa.array([1], pa.duration('s')),
+        )
+        arrow_table = pa.Table.from_arrays([data[0], data[1], data[2], data[3]], ['a', 'b', 'c', 'd'])
         rel = duckdb.from_arrow(arrow_table).arrow()
-        assert (rel['a'] == expected_arrow['a'])
-        assert (rel['b'] == expected_arrow['a'])
-        assert (rel['c'] == expected_arrow['a'])
-        assert (rel['d'] == expected_arrow['a'])
+        assert rel['a'] == expected_arrow['a']
+        assert rel['b'] == expected_arrow['a']
+        assert rel['c'] == expected_arrow['a']
+        assert rel['d'] == expected_arrow['a']
 
     def test_duration_null(self, duckdb_cursor):
         if not can_run:
-            return
-        expected_arrow = pa.Table.from_arrays([pa.array([None], type=pa.month_day_nano_interval())],['a'])
-        data = (pa.array([None], type=pa.duration('ns')),pa.array([None], type=pa.duration('us')),pa.array([None], pa.duration('ms')),pa.array([None], pa.duration('s')))
-        arrow_table = pa.Table.from_arrays([data[0],data[1],data[2],data[3]],['a','b','c','d'])
+            return
+        expected_arrow = pa.Table.from_arrays([pa.array([None], type=pa.month_day_nano_interval())], ['a'])
+        data = (
+            pa.array([None], type=pa.duration('ns')),
+            pa.array([None], type=pa.duration('us')),
+            pa.array([None], pa.duration('ms')),
+            pa.array([None], pa.duration('s')),
+        )
+        arrow_table = pa.Table.from_arrays([data[0], data[1], data[2], data[3]], ['a', 'b', 'c', 'd'])
         rel = duckdb.from_arrow(arrow_table).arrow()
-        assert (rel['a'] == expected_arrow['a'])
-        assert (rel['b'] == expected_arrow['a'])
-        assert (rel['c'] == expected_arrow['a'])
-        assert (rel['d'] == expected_arrow['a'])
+        assert rel['a'] == expected_arrow['a']
+        assert rel['b'] == expected_arrow['a']
+        assert rel['c'] == expected_arrow['a']
+        assert rel['d'] == expected_arrow['a']
 
     def test_duration_overflow(self, duckdb_cursor):
         if not can_run:
@@ -40,8 +55,7 @@ def test_duration_overflow(self, duckdb_cursor):
 
         # Only seconds can overflow
         data = pa.array([9223372036854775807], pa.duration('s'))
-        arrow_table = pa.Table.from_arrays([data],['a'])
+        arrow_table = pa.Table.from_arrays([data], ['a'])
 
         with pytest.raises(duckdb.ConversionException, match='Could not convert Interval to Microsecond'):
-            arrow_from_duck = duckdb.from_arrow(arrow_table).arrow()
-
\ No newline at end of file
+            arrow_from_duck = duckdb.from_arrow(arrow_table).arrow()
diff --git a/tools/pythonpkg/tests/fast/arrow/test_large_string.py b/tools/pythonpkg/tests/fast/arrow/test_large_string.py
index 593da316fad8..4836048de63d 100644
--- a/tools/pythonpkg/tests/fast/arrow/test_large_string.py
+++ b/tools/pythonpkg/tests/fast/arrow/test_large_string.py
@@ -1,18 +1,21 @@
 import duckdb
 import os
+
 try:
     import pyarrow as pa
     from pyarrow import parquet as pq
     import numpy as np
+
     can_run = True
 except:
     can_run = False
 
+
 class TestArrowLargeString(object):
-    def test_large_string_type(self,duckdb_cursor):
+    def test_large_string_type(self, duckdb_cursor):
         if not can_run:
             return
-
+
         schema = pa.schema([("data", pa.large_string())])
         inputs = [pa.array(["foo", "baaaar", "b"], type=pa.large_string())]
         arrow_table = pa.Table.from_arrays(inputs, schema=schema)
@@ -20,4 +23,3 @@ def test_large_string_type(self,duckdb_cursor):
         rel = duckdb.from_arrow(arrow_table)
         res = rel.execute().fetchall()
         assert res == [('foo',), ('baaaar',), ('b',)]
-
\ No newline at end of file
diff --git a/tools/pythonpkg/tests/fast/arrow/test_multiple_reads.py b/tools/pythonpkg/tests/fast/arrow/test_multiple_reads.py
index 28dae2079243..935a8a9c3542 100644
--- a/tools/pythonpkg/tests/fast/arrow/test_multiple_reads.py
+++ b/tools/pythonpkg/tests/fast/arrow/test_multiple_reads.py
@@ -1,21 +1,24 @@
 import duckdb
 import os
+
 try:
     import pyarrow
     import pyarrow.parquet
+
     can_run = True
 except:
     can_run = False
 
+
 class TestArrowReads(object):
     def test_multiple_queries_same_relation(self, duckdb_cursor):
         if not can_run:
             return
-        parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)),'data','userdata1.parquet')
+        parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'userdata1.parquet')
         cols = 'id, first_name, last_name, email, gender, ip_address, cc, country, birthdate, salary, title, comments'
 
         userdata_parquet_table = pyarrow.parquet.read_table(parquet_filename)
         userdata_parquet_table.validate(full=True)
         rel = duckdb.from_arrow(userdata_parquet_table)
-        assert(rel.aggregate("(avg(salary))::INT").execute().fetchone()[0] == 149005)
-        assert(rel.aggregate("(avg(salary))::INT").execute().fetchone()[0] == 149005)
+        assert rel.aggregate("(avg(salary))::INT").execute().fetchone()[0] == 149005
+        assert rel.aggregate("(avg(salary))::INT").execute().fetchone()[0] == 149005
diff --git a/tools/pythonpkg/tests/fast/arrow/test_nested_arrow.py b/tools/pythonpkg/tests/fast/arrow/test_nested_arrow.py
index 123fbb105c98..3033b975a824 100644
--- a/tools/pythonpkg/tests/fast/arrow/test_nested_arrow.py
+++ b/tools/pythonpkg/tests/fast/arrow/test_nested_arrow.py
@@ -1,80 +1,92 @@
 import duckdb
+
 try:
     import pyarrow as pa
     import pyarrow.parquet
     import numpy as np
     import pandas as pd
     import pytest
+
     can_run = True
 except:
     can_run = False
 
+
 def compare_results(query):
     true_answer = duckdb.query(query).fetchall()
     from_arrow = duckdb.from_arrow(duckdb.query(query).arrow()).fetchall()
     assert true_answer == from_arrow
 
+
 def arrow_to_pandas(query):
     return duckdb.query(query).arrow().to_pandas()['a'].values.tolist()
 
+
 class TestArrowNested(object):
-    def test_lists_basic(self,duckdb_cursor):
+    def test_lists_basic(self, duckdb_cursor):
         if not can_run:
             return
-
-        #Test Constant List
-        query = duckdb.query("SELECT a from (select list_value(3,5,10) as a) as t").arrow()['a'].to_numpy()
+
+        # Test Constant List
+        query = duckdb.query("SELECT a from (select list_value(3,5,10) as a) as t").arrow()['a'].to_numpy()
         assert query[0][0] == 3
         assert query[0][1] == 5
         assert query[0][2] == 10
 
         # Empty List
-        query = duckdb.query("SELECT a from (select list_value() as a) as t").arrow()['a'].to_numpy()
+        query = duckdb.query("SELECT a from (select list_value() as a) as t").arrow()['a'].to_numpy()
         assert len(query[0]) == 0
 
-        #Test Constant List With Null
-        query = duckdb.query("SELECT a from (select list_value(3,NULL) as a) as t").arrow()['a'].to_numpy()
+        # Test Constant List With Null
+        query = duckdb.query("SELECT a from (select list_value(3,NULL) as a) as t").arrow()['a'].to_numpy()
         assert query[0][0] == 3
         assert np.isnan(query[0][1])
 
-    def test_list_types(self,duckdb_cursor):
+    def test_list_types(self, duckdb_cursor):
         if not can_run:
             return
 
-        #Large Lists
-        data = pyarrow.array([[1],None, [2]], type=pyarrow.large_list(pyarrow.int64()))
-        arrow_table = pa.Table.from_arrays([data],['a'])
+        # Large Lists
+        data = pyarrow.array([[1], None, [2]], type=pyarrow.large_list(pyarrow.int64()))
+        arrow_table = pa.Table.from_arrays([data], ['a'])
         rel = duckdb.from_arrow(arrow_table)
         res = rel.execute().fetchall()
         assert res == [([1],), (None,), ([2],)]
 
-        #Fixed Size Lists
-        data = pyarrow.array([[1],None, [2]], type=pyarrow.list_(pyarrow.int64(),1))
-        arrow_table = pa.Table.from_arrays([data],['a'])
+        # Fixed Size Lists
+        data = pyarrow.array([[1], None, [2]], type=pyarrow.list_(pyarrow.int64(), 1))
+        arrow_table = pa.Table.from_arrays([data], ['a'])
         rel = duckdb.from_arrow(arrow_table)
         res = rel.execute().fetchall()
         assert res == [([1],), (None,), ([2],)]
 
-        #Complex nested structures with different list types
-        data = [pyarrow.array([[1],None, [2]], type=pyarrow.list_(pyarrow.int64(),1)),pyarrow.array([[1],None, [2]], type=pyarrow.large_list(pyarrow.int64())),pyarrow.array([[1,2,3],None, [2,1]], type=pyarrow.list_(pyarrow.int64()))]
-        arrow_table = pa.Table.from_arrays([data[0],data[1],data[2]],['a','b','c'])
+        # Complex nested structures with different list types
+        data = [
+            pyarrow.array([[1], None, [2]], type=pyarrow.list_(pyarrow.int64(), 1)),
+            pyarrow.array([[1], None, [2]], type=pyarrow.large_list(pyarrow.int64())),
+            pyarrow.array([[1, 2, 3], None, [2, 1]], type=pyarrow.list_(pyarrow.int64())),
+        ]
+        arrow_table = pa.Table.from_arrays([data[0], data[1], data[2]], ['a', 'b', 'c'])
         rel = duckdb.from_arrow(arrow_table)
         res = rel.project('a').execute().fetchall()
         assert res == [([1],), (None,), ([2],)]
         res = rel.project('b').execute().fetchall()
         assert res == [([1],), (None,), ([2],)]
         res = rel.project('c').execute().fetchall()
-        assert res == [([1,2,3],), (None,), ([2,1],)]
+        assert res == [([1, 2, 3],), (None,), ([2, 1],)]
 
-        #Struct Holding different List Types
-        struct = [pa.StructArray.from_arrays( data,['fixed', 'large','normal'])]
-        arrow_table = pa.Table.from_arrays(struct,['a'])
+        # Struct Holding different List Types
+        struct = [pa.StructArray.from_arrays(data, ['fixed', 'large', 'normal'])]
+        arrow_table = pa.Table.from_arrays(struct, ['a'])
         rel = duckdb.from_arrow(arrow_table)
         res = rel.execute().fetchall()
-        assert res == [({'fixed': [1], 'large': [1], 'normal': [1, 2, 3]},), ({'fixed': None, 'large': None, 'normal': None},), ({'fixed': [2], 'large': [2], 'normal': [2, 1]},)]
-
+        assert res == [
+            ({'fixed': [1], 'large': [1], 'normal': [1, 2, 3]},),
+            ({'fixed': None, 'large': None, 'normal': None},),
+            ({'fixed': [2], 'large': [2], 'normal': [2, 1]},),
+        ]
 
-    def test_lists_roundtrip(self,duckdb_cursor):
+    def test_lists_roundtrip(self, duckdb_cursor):
         if not can_run:
             return
         # Integers
@@ -82,78 +94,92 @@ def test_lists_roundtrip(self,duckdb_cursor):
         compare_results("SELECT a from (select list_value(3,5,NULL) as a) as t")
         compare_results("SELECT a from (select list_value(NULL,NULL,NULL) as a) as t")
         compare_results("SELECT a from (select list_value() as a) as t")
-        #Strings
+        # Strings
         compare_results("SELECT a from (select list_value('test','test_one','test_two') as a) as t")
         compare_results("SELECT a from (select list_value('test','test_one',NULL) as a) as t")
-        #Big Lists
+        # Big Lists
         compare_results("SELECT a from (SELECT LIST(i) as a FROM range(10000) tbl(i)) as t")
-        #Multiple Lists
+        # Multiple Lists
         compare_results("SELECT a from (SELECT LIST(i) as a FROM range(10000) tbl(i) group by i%10) as t")
-        #Unique Constants
+        # Unique Constants
         compare_results("SELECT a from (SELECT list_value(1) as a FROM range(10) tbl(i)) as t")
-        #Nested Lists
+        # Nested Lists
         compare_results("SELECT LIST(le) FROM (SELECT LIST(i) le from range(100) tbl(i) group by i%10) as t")
-        #LIST[LIST[LIST[LIST[LIST[INTEGER]]]]]]
-        compare_results("SELECT list (lllle) llllle from (SELECT list (llle) lllle from (SELECT list(lle) llle from (SELECT LIST(le) lle FROM (SELECT LIST(i) le from range(100) tbl(i) group by i%10) as t) as t1) as t2) as t3")
+        # LIST[LIST[LIST[LIST[LIST[INTEGER]]]]]]
+        compare_results(
+            "SELECT list (lllle) llllle from (SELECT list (llle) lllle from (SELECT list(lle) llle from (SELECT LIST(le) lle FROM (SELECT LIST(i) le from range(100) tbl(i) group by i%10) as t) as t1) as t2) as t3"
+        )
 
-        compare_results('''SELECT grp,lst,cs FROM (select grp, lst, case when grp>1 then lst else list_value(null) end as cs
-            from (SELECT a%4 as grp, list(a) as lst FROM range(7) tbl(a) group by grp) as lst_tbl) as T;''')
-        #Tests for converting multiple lists to/from Arrow with NULL values and/or strings
-        compare_results("SELECT list(st) from (select i, case when i%10 then NULL else i::VARCHAR end as st from range(1000) tbl(i)) as t group by i%5")
+        compare_results(
+            '''SELECT grp,lst,cs FROM (select grp, lst, case when grp>1 then lst else list_value(null) end as cs
+            from (SELECT a%4 as grp, list(a) as lst FROM range(7) tbl(a) group by grp) as lst_tbl) as T;'''
+        )
+        # Tests for converting multiple lists to/from Arrow with NULL values and/or strings
+        compare_results(
+            "SELECT list(st) from (select i, case when i%10 then NULL else i::VARCHAR end as st from range(1000) tbl(i)) as t group by i%5"
+        )
 
-    def test_struct_roundtrip(self,duckdb_cursor):
+    def test_struct_roundtrip(self, duckdb_cursor):
         if not can_run:
             return
         compare_results("SELECT a from (SELECT STRUCT_PACK(a := 42, b := 43) as a) as t")
         compare_results("SELECT a from (SELECT STRUCT_PACK(a := NULL, b := 43) as a) as t")
         compare_results("SELECT a from (SELECT STRUCT_PACK(a := NULL) as a) as t")
         compare_results("SELECT a from (SELECT STRUCT_PACK(a := i, b := i) as a FROM range(10000) tbl(i)) as t")
-        compare_results("SELECT a from (SELECT STRUCT_PACK(a := LIST_VALUE(1,2,3), b := i) as a FROM range(10000) tbl(i)) as t")
+        compare_results(
+            "SELECT a from (SELECT STRUCT_PACK(a := LIST_VALUE(1,2,3), b := i) as a FROM range(10000) tbl(i)) as t"
+        )
 
-    def test_map_roundtrip(self,duckdb_cursor):
+    def test_map_roundtrip(self, duckdb_cursor):
         if not can_run:
             return
         compare_results("SELECT a from (select MAP(LIST_VALUE(1, 2, 3, 4),LIST_VALUE(10, 9, 8, 7)) as a) as t")
-
+
         compare_results("SELECT a from (select MAP(LIST_VALUE(1, 2, 3, 4),LIST_VALUE(10, 9, 8, 7)) as a) as t")
-
+
         compare_results("SELECT a from (select MAP(LIST_VALUE(),LIST_VALUE()) as a) as t")
-        compare_results("SELECT a from (select MAP(LIST_VALUE('Jon Lajoie', 'Backstreet Boys', 'Tenacious D'),LIST_VALUE(10,9,10)) as a) as t")
-        compare_results("SELECT a from (select MAP(LIST_VALUE('Jon Lajoie','Tenacious D'),LIST_VALUE(10,10)) as a) as t")
+        compare_results(
+            "SELECT a from (select MAP(LIST_VALUE('Jon Lajoie', 'Backstreet Boys', 'Tenacious D'),LIST_VALUE(10,9,10)) as a) as t"
+        )
+        compare_results(
+            "SELECT a from (select MAP(LIST_VALUE('Jon Lajoie','Tenacious D'),LIST_VALUE(10,10)) as a) as t"
+        )
         compare_results("SELECT m from (select MAP(list_value(1), list_value(2)) from range(5) tbl(i)) tbl(m)")
-        compare_results("SELECT m from (select MAP(lsta,lstb) as m from (SELECT list(i) as lsta, list(i) as lstb from range(10000) tbl(i) group by i%5) as lst_tbl) as T")
+        compare_results(
+            "SELECT m from (select MAP(lsta,lstb) as m from (SELECT list(i) as lsta, list(i) as lstb from range(10000) tbl(i) group by i%5) as lst_tbl) as T"
+        )
 
     def test_map_arrow_to_duckdb(self, duckdb_cursor):
         if not can_run:
             return
         map_type = pa.map_(pa.int32(), pa.int32())
-        values = [
-            [
-                (3, 12),
-                (3, 21)
-            ],
-            [
-                (5, 42)
-            ]
-        ]
-        arrow_table = pa.table(
-            {'detail': pa.array(values, map_type)}
-        )
-        with pytest.raises(duckdb.InvalidInputException, match="Arrow map contains duplicate key, which isn't supported by DuckDB map type"):
+        values = [[(3, 12), (3, 21)], [(5, 42)]]
+        arrow_table = pa.table({'detail': pa.array(values, map_type)})
+        with pytest.raises(
+            duckdb.InvalidInputException,
+            match="Arrow map contains duplicate key, which isn't supported by DuckDB map type",
+        ):
             rel = duckdb.from_arrow(arrow_table).fetchall()
 
-
-    def test_map_arrow_to_pandas(self,duckdb_cursor):
+    def test_map_arrow_to_pandas(self, duckdb_cursor):
         if not can_run:
             return
-        assert arrow_to_pandas("SELECT a from (select MAP(LIST_VALUE(1, 2, 3, 4),LIST_VALUE(10, 9, 8, 7)) as a) as t") == [[(1, 10), (2, 9), (3, 8), (4, 7)]]
+        assert arrow_to_pandas(
+            "SELECT a from (select MAP(LIST_VALUE(1, 2, 3, 4),LIST_VALUE(10, 9, 8, 7)) as a) as t"
+        ) == [[(1, 10), (2, 9), (3, 8), (4, 7)]]
         assert arrow_to_pandas("SELECT a from (select MAP(LIST_VALUE(),LIST_VALUE()) as a) as t") == [[]]
-        assert arrow_to_pandas("SELECT a from (select MAP(LIST_VALUE('Jon Lajoie', 'Backstreet Boys', 'Tenacious D'),LIST_VALUE(10,9,10)) as a) as t") == [[('Jon Lajoie', 10), ('Backstreet Boys', 9), ('Tenacious D', 10)]]
-        assert arrow_to_pandas("SELECT a from (select MAP(list_value(1), list_value(2)) from range(5) tbl(i)) tbl(a)") == [[(1, 2)], [(1, 2)], [(1, 2)], [(1, 2)], [(1, 2)]]
-        assert arrow_to_pandas("SELECT MAP(LIST_VALUE({'i':1,'j':2},{'i':3,'j':4}),LIST_VALUE({'i':1,'j':2},{'i':3,'j':4})) as a") == [[({'i': 1, 'j': 2}, {'i': 1, 'j': 2}), ({'i': 3, 'j': 4}, {'i': 3, 'j': 4})]]
-
-
-    def test_frankstein_nested(self,duckdb_cursor):
+        assert arrow_to_pandas(
+            "SELECT a from (select MAP(LIST_VALUE('Jon Lajoie', 'Backstreet Boys', 'Tenacious D'),LIST_VALUE(10,9,10)) as a) as t"
+        ) == [[('Jon Lajoie', 10), ('Backstreet Boys', 9), ('Tenacious D', 10)]]
+        assert arrow_to_pandas(
+            "SELECT a from (select MAP(list_value(1), list_value(2)) from range(5) tbl(i)) tbl(a)"
+        ) == [[(1, 2)], [(1, 2)], [(1, 2)], [(1, 2)], [(1, 2)]]
+        assert arrow_to_pandas(
+            "SELECT MAP(LIST_VALUE({'i':1,'j':2},{'i':3,'j':4}),LIST_VALUE({'i':1,'j':2},{'i':3,'j':4})) as a"
+        ) == [[({'i': 1, 'j': 2}, {'i': 1, 'j': 2}), ({'i': 3, 'j': 4}, {'i': 3, 'j': 4})]]
 
+    def test_frankstein_nested(self, duckdb_cursor):
         if not can_run:
             return
         # List of structs W/ Struct that is NULL entirely
@@ -161,12 +187,16 @@ def test_frankstein_nested(self,duckdb_cursor):
 
         # Lists of structs with lists
         compare_results("SELECT [{'i':1,'j':[2,3]},NULL]")
-
+
         # Maps embedded in a struct
-        compare_results("SELECT {'i':mp,'j':mp2} FROM (SELECT MAP(LIST_VALUE(1, 2, 3, 4),LIST_VALUE(10, 9, 8, 7)) as mp, MAP(LIST_VALUE(1, 2, 3, 5),LIST_VALUE(10, 9, 8, 7)) as mp2) as t")
+        compare_results(
+            "SELECT {'i':mp,'j':mp2} FROM (SELECT MAP(LIST_VALUE(1, 2, 3, 4),LIST_VALUE(10, 9, 8, 7)) as mp, MAP(LIST_VALUE(1, 2, 3, 5),LIST_VALUE(10, 9, 8, 7)) as mp2) as t"
+        )
 
-        # List of maps
-        compare_results("SELECT [mp,mp2] FROM (SELECT MAP(LIST_VALUE(1, 2, 3, 4),LIST_VALUE(10, 9, 8, 7)) as mp, MAP(LIST_VALUE(1, 2, 3, 5),LIST_VALUE(10, 9, 8, 7)) as mp2) as t")
+        # List of maps
+        compare_results(
+            "SELECT [mp,mp2] FROM (SELECT MAP(LIST_VALUE(1, 2, 3, 4),LIST_VALUE(10, 9, 8, 7)) as mp, MAP(LIST_VALUE(1, 2, 3, 5),LIST_VALUE(10, 9, 8, 7)) as mp2) as t"
+        )
 
         # Map with list as key and/or value
         compare_results("SELECT MAP(LIST_VALUE([1,2],[3,4],[5,4]),LIST_VALUE([1,2],[3,4],[5,4]))")
@@ -181,4 +211,6 @@ def test_frankstein_nested(self,duckdb_cursor):
         compare_results("SELECT [{'i':1,'j':[2,3]},NULL,{'i':1,'j':[2,3]}]")
 
         # MAP that is NULL entirely
-        compare_results("SELECT * FROM (VALUES (MAP(LIST_VALUE(1,2),LIST_VALUE(3,4))),(NULL), (MAP(LIST_VALUE(1,2),LIST_VALUE(3,4))), (NULL)) as a")
+        compare_results(
+            "SELECT * FROM (VALUES (MAP(LIST_VALUE(1,2),LIST_VALUE(3,4))),(NULL), (MAP(LIST_VALUE(1,2),LIST_VALUE(3,4))), (NULL)) as a"
+        )
diff --git a/tools/pythonpkg/tests/fast/arrow/test_parallel.py b/tools/pythonpkg/tests/fast/arrow/test_parallel.py
index fa2fc8c43998..2609d1aebcc8 100644
--- a/tools/pythonpkg/tests/fast/arrow/test_parallel.py
+++ b/tools/pythonpkg/tests/fast/arrow/test_parallel.py
@@ -1,56 +1,59 @@
 import duckdb
 import os
+
 try:
     import pyarrow
     import pyarrow.parquet
     import numpy as np
+
     can_run = True
 except:
     can_run = False
 
+
 class TestArrowParallel(object):
-    def test_parallel_run(self,duckdb_cursor):
+    def test_parallel_run(self, duckdb_cursor):
         if not can_run:
             return
         duckdb_conn = duckdb.connect()
         duckdb_conn.execute("PRAGMA threads=4")
         duckdb_conn.execute("PRAGMA verify_parallelism")
-        data = (pyarrow.array(np.random.randint(800, size=1000000), type=pyarrow.int32()))
-        tbl = pyarrow.Table.from_batches(pyarrow.Table.from_arrays([data],['a']).to_batches(10000))
+        data = pyarrow.array(np.random.randint(800, size=1000000), type=pyarrow.int32())
+        tbl = pyarrow.Table.from_batches(pyarrow.Table.from_arrays([data], ['a']).to_batches(10000))
         rel = duckdb_conn.from_arrow(tbl)
         # Also test multiple reads
-        assert(rel.aggregate("(count(a))::INT").execute().fetchone()[0] == 1000000)
-        assert(rel.aggregate("(count(a))::INT").execute().fetchone()[0] == 1000000)
+        assert rel.aggregate("(count(a))::INT").execute().fetchone()[0] == 1000000
+        assert rel.aggregate("(count(a))::INT").execute().fetchone()[0] == 1000000
 
-    def test_parallel_types_and_different_batches(self,duckdb_cursor):
+    def test_parallel_types_and_different_batches(self, duckdb_cursor):
         if not can_run:
             return
         duckdb_conn = duckdb.connect()
         duckdb_conn.execute("PRAGMA threads=4")
         duckdb_conn.execute("PRAGMA verify_parallelism")
-        parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)),'data','userdata1.parquet')
+        parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'userdata1.parquet')
         cols = 'id, first_name, last_name, email, gender, ip_address, cc, country, birthdate, salary, title, comments'
         userdata_parquet_table = pyarrow.parquet.read_table(parquet_filename)
         for i in [7, 51, 99, 100, 101, 500, 1000, 2000]:
-            data = (pyarrow.array(np.arange(3,7), type=pyarrow.int32()))
-            tbl = 
pyarrow.Table.from_arrays([data],['a']) + data = pyarrow.array(np.arange(3, 7), type=pyarrow.int32()) + tbl = pyarrow.Table.from_arrays([data], ['a']) rel_id = duckdb_conn.from_arrow(tbl) userdata_parquet_table2 = pyarrow.Table.from_batches(userdata_parquet_table.to_batches(i)) rel = duckdb_conn.from_arrow(userdata_parquet_table2) result = rel.filter("first_name=\'Jose\' and salary > 134708.82").aggregate('count(*)') - assert (result.execute().fetchone()[0] == 4) + assert result.execute().fetchone()[0] == 4 - def test_parallel_fewer_batches_than_threads(self,duckdb_cursor): + def test_parallel_fewer_batches_than_threads(self, duckdb_cursor): if not can_run: return duckdb_conn = duckdb.connect() duckdb_conn.execute("PRAGMA threads=4") duckdb_conn.execute("PRAGMA verify_parallelism") - data = (pyarrow.array(np.random.randint(800, size=1000), type=pyarrow.int32())) - tbl = pyarrow.Table.from_batches(pyarrow.Table.from_arrays([data],['a']).to_batches(2)) + data = pyarrow.array(np.random.randint(800, size=1000), type=pyarrow.int32()) + tbl = pyarrow.Table.from_batches(pyarrow.Table.from_arrays([data], ['a']).to_batches(2)) rel = duckdb_conn.from_arrow(tbl) # Also test multiple reads - assert(rel.aggregate("(count(a))::INT").execute().fetchone()[0] == 1000) \ No newline at end of file + assert rel.aggregate("(count(a))::INT").execute().fetchone()[0] == 1000 diff --git a/tools/pythonpkg/tests/fast/arrow/test_polars.py b/tools/pythonpkg/tests/fast/arrow/test_polars.py index 9edb176448c8..991d356ed18c 100644 --- a/tools/pythonpkg/tests/fast/arrow/test_polars.py +++ b/tools/pythonpkg/tests/fast/arrow/test_polars.py @@ -1,11 +1,13 @@ import duckdb import pytest + pl = pytest.importorskip("polars") arrow = pytest.importorskip("pyarrow") pl_testing = pytest.importorskip("polars.testing") + class TestPolars(object): - def test_polars(self,duckdb_cursor): + def test_polars(self, duckdb_cursor): df = pl.DataFrame( { "A": [1, 2, 3, 4, 5], @@ -27,7 +29,7 @@ def test_polars(self,duckdb_cursor): con_result = con.execute('SELECT * FROM df').pl() pl_testing.assert_frame_equal(df, con_result) - def test_register_polars(self,duckdb_cursor): + def test_register_polars(self, duckdb_cursor): con = duckdb.connect() df = pl.DataFrame( { @@ -38,14 +40,13 @@ def test_register_polars(self,duckdb_cursor): } ) # scan plus return a polars dataframe - con.register('polars_df',df) + con.register('polars_df', df) polars_result = con.execute('select * from polars_df').pl() pl_testing.assert_frame_equal(df, polars_result) con.unregister('polars_df') with pytest.raises(duckdb.CatalogException, match='Table with name polars_df does not exist'): con.execute("SELECT * FROM polars_df;").pl() - con.register('polars_df',df.lazy()) + con.register('polars_df', df.lazy()) polars_result = con.execute('select * from polars_df').pl() pl_testing.assert_frame_equal(df, polars_result) - diff --git a/tools/pythonpkg/tests/fast/arrow/test_progress.py b/tools/pythonpkg/tests/fast/arrow/test_progress.py index c9acad08000a..c20ebe510a59 100644 --- a/tools/pythonpkg/tests/fast/arrow/test_progress.py +++ b/tools/pythonpkg/tests/fast/arrow/test_progress.py @@ -1,42 +1,43 @@ import duckdb import os import pytest + pyarrow_parquet = pytest.importorskip("pyarrow.parquet") import sys -class TestProgressBarArrow(object): +class TestProgressBarArrow(object): def test_progress_arrow(self): if os.name == 'nt': return np = pytest.importorskip("numpy") pyarrow = pytest.importorskip("pyarrow") - data = (pyarrow.array(np.arange(10000000), type=pyarrow.int32())) + 
data = pyarrow.array(np.arange(10000000), type=pyarrow.int32()) duckdb_conn = duckdb.connect() duckdb_conn.execute("PRAGMA progress_bar_time=1") duckdb_conn.execute("PRAGMA disable_print_progress_bar") - tbl = pyarrow.Table.from_arrays([data],['a']) + tbl = pyarrow.Table.from_arrays([data], ['a']) rel = duckdb_conn.from_arrow(tbl) result = rel.aggregate('sum(a)') - assert (result.execute().fetchone()[0] == 49999995000000) + assert result.execute().fetchone()[0] == 49999995000000 # Multiple Threads duckdb_conn.execute("PRAGMA threads=4") duckdb_conn.execute("PRAGMA verify_parallelism") - assert (result.execute().fetchone()[0] == 49999995000000) + assert result.execute().fetchone()[0] == 49999995000000 # More than one batch - tbl = pyarrow.Table.from_batches(pyarrow.Table.from_arrays([data],['a']).to_batches(100)) + tbl = pyarrow.Table.from_batches(pyarrow.Table.from_arrays([data], ['a']).to_batches(100)) rel = duckdb_conn.from_arrow(tbl) result = rel.aggregate('sum(a)') - assert (result.execute().fetchone()[0] == 49999995000000) + assert result.execute().fetchone()[0] == 49999995000000 # Single Thread duckdb_conn.execute("PRAGMA threads=1") duck_res = result.execute() py_res = duck_res.fetchone()[0] - assert (py_res == 49999995000000) + assert py_res == 49999995000000 def test_progress_arrow_empty(self): if os.name == 'nt': @@ -44,12 +45,12 @@ def test_progress_arrow_empty(self): np = pytest.importorskip("numpy") pyarrow = pytest.importorskip("pyarrow") - data = (pyarrow.array(np.arange(0), type=pyarrow.int32())) + data = pyarrow.array(np.arange(0), type=pyarrow.int32()) duckdb_conn = duckdb.connect() duckdb_conn.execute("PRAGMA progress_bar_time=1") duckdb_conn.execute("PRAGMA disable_print_progress_bar") - tbl = pyarrow.Table.from_arrays([data],['a']) + tbl = pyarrow.Table.from_arrays([data], ['a']) rel = duckdb_conn.from_arrow(tbl) result = rel.aggregate('sum(a)') - assert (result.execute().fetchone()[0] == None) + assert result.execute().fetchone()[0] == None diff --git a/tools/pythonpkg/tests/fast/arrow/test_projection_pushdown.py b/tools/pythonpkg/tests/fast/arrow/test_projection_pushdown.py index 0aeaa14c5ece..069c4a10112a 100644 --- a/tools/pythonpkg/tests/fast/arrow/test_projection_pushdown.py +++ b/tools/pythonpkg/tests/fast/arrow/test_projection_pushdown.py @@ -1,15 +1,18 @@ import duckdb import os import pytest + try: import pyarrow as pa import pyarrow.dataset as ds + can_run = True except: can_run = False + class TestArrowProjectionPushdown(object): - def test_projection_pushdown_no_filter(self,duckdb_cursor): + def test_projection_pushdown_no_filter(self, duckdb_cursor): if not can_run: return duckdb_conn = duckdb.connect() @@ -17,9 +20,9 @@ def test_projection_pushdown_no_filter(self,duckdb_cursor): duckdb_conn.execute("INSERT INTO test VALUES (1,1,1),(10,10,10),(100,10,100),(NULL,NULL,NULL)") duck_tbl = duckdb_conn.table("test") arrow_table = duck_tbl.arrow() - duckdb_conn.register("testarrowtable",arrow_table) + duckdb_conn.register("testarrowtable", arrow_table) assert duckdb_conn.execute("SELECT sum(a) FROM testarrowtable").fetchall() == [(111,)] arrow_dataset = ds.dataset(arrow_table) - duckdb_conn.register("testarrowdataset",arrow_dataset) - assert duckdb_conn.execute("SELECT sum(a) FROM testarrowdataset").fetchall() == [(111,)] \ No newline at end of file + duckdb_conn.register("testarrowdataset", arrow_dataset) + assert duckdb_conn.execute("SELECT sum(a) FROM testarrowdataset").fetchall() == [(111,)] diff --git a/tools/pythonpkg/tests/fast/arrow/test_time.py 
b/tools/pythonpkg/tests/fast/arrow/test_time.py index ecfd7f15edb9..d575fc0f8a7f 100644 --- a/tools/pythonpkg/tests/fast/arrow/test_time.py +++ b/tools/pythonpkg/tests/fast/arrow/test_time.py @@ -2,64 +2,76 @@ import os import datetime import pytest + try: import pyarrow as pa import pandas as pd + can_run = True except: can_run = False + class TestArrowTime(object): def test_time_types(self, duckdb_cursor): if not can_run: return - - data = (pa.array([1], type=pa.time32('s')),pa.array([1000], type=pa.time32('ms')),pa.array([1000000], pa.time64('us')),pa.array([1000000000], pa.time64('ns'))) - arrow_table = pa.Table.from_arrays([data[0],data[1],data[2],data[3]],['a','b','c','d']) - rel = duckdb.from_arrow(arrow_table).arrow() - assert (rel['a'] == arrow_table['c']) - assert (rel['b'] == arrow_table['c']) - assert (rel['c'] == arrow_table['c']) - assert (rel['d'] == arrow_table['c']) + data = ( + pa.array([1], type=pa.time32('s')), + pa.array([1000], type=pa.time32('ms')), + pa.array([1000000], pa.time64('us')), + pa.array([1000000000], pa.time64('ns')), + ) + arrow_table = pa.Table.from_arrays([data[0], data[1], data[2], data[3]], ['a', 'b', 'c', 'd']) + rel = duckdb.from_arrow(arrow_table).arrow() + assert rel['a'] == arrow_table['c'] + assert rel['b'] == arrow_table['c'] + assert rel['c'] == arrow_table['c'] + assert rel['d'] == arrow_table['c'] def test_time_null(self, duckdb_cursor): if not can_run: - return - data = (pa.array([None], type=pa.time32('s')),pa.array([None], type=pa.time32('ms')),pa.array([None], pa.time64('us')),pa.array([None], pa.time64('ns'))) - arrow_table = pa.Table.from_arrays([data[0],data[1],data[2],data[3]],['a','b','c','d']) + return + data = ( + pa.array([None], type=pa.time32('s')), + pa.array([None], type=pa.time32('ms')), + pa.array([None], pa.time64('us')), + pa.array([None], pa.time64('ns')), + ) + arrow_table = pa.Table.from_arrays([data[0], data[1], data[2], data[3]], ['a', 'b', 'c', 'd']) rel = duckdb.from_arrow(arrow_table).arrow() - assert (rel['a'] == arrow_table['c']) - assert (rel['b'] == arrow_table['c']) - assert (rel['c'] == arrow_table['c']) - assert (rel['d'] == arrow_table['c']) + assert rel['a'] == arrow_table['c'] + assert rel['b'] == arrow_table['c'] + assert rel['c'] == arrow_table['c'] + assert rel['d'] == arrow_table['c'] def test_max_times(self, duckdb_cursor): if not can_run: - return + return data = pa.array([2147483647000000], type=pa.time64('us')) - result = pa.Table.from_arrays([data],['a']) - #Max Sec + result = pa.Table.from_arrays([data], ['a']) + # Max Sec data = pa.array([2147483647], type=pa.time32('s')) - arrow_table = pa.Table.from_arrays([data],['a']) + arrow_table = pa.Table.from_arrays([data], ['a']) rel = duckdb.from_arrow(arrow_table).arrow() - assert (rel['a'] == result['a']) + assert rel['a'] == result['a'] - #Max MSec + # Max MSec data = pa.array([2147483647000], type=pa.time64('us')) - result = pa.Table.from_arrays([data],['a']) + result = pa.Table.from_arrays([data], ['a']) data = pa.array([2147483647], type=pa.time32('ms')) - arrow_table = pa.Table.from_arrays([data],['a']) + arrow_table = pa.Table.from_arrays([data], ['a']) rel = duckdb.from_arrow(arrow_table).arrow() - assert (rel['a'] == result['a']) + assert rel['a'] == result['a'] - #Max NSec + # Max NSec data = pa.array([9223372036854774], type=pa.time64('us')) - result = pa.Table.from_arrays([data],['a']) + result = pa.Table.from_arrays([data], ['a']) data = pa.array([9223372036854774000], type=pa.time64('ns')) - arrow_table = 
pa.Table.from_arrays([data],['a']) + arrow_table = pa.Table.from_arrays([data], ['a']) rel = duckdb.from_arrow(arrow_table).arrow() - print (rel['a']) - print (result['a']) - assert (rel['a'] == result['a']) \ No newline at end of file + print(rel['a']) + print(result['a']) + assert rel['a'] == result['a'] diff --git a/tools/pythonpkg/tests/fast/arrow/test_timestamp_timezone.py b/tools/pythonpkg/tests/fast/arrow/test_timestamp_timezone.py index 51279ed78ec5..70bc9752530b 100644 --- a/tools/pythonpkg/tests/fast/arrow/test_timestamp_timezone.py +++ b/tools/pythonpkg/tests/fast/arrow/test_timestamp_timezone.py @@ -3,77 +3,80 @@ import datetime try: - import pyarrow as pa - can_run = True + import pyarrow as pa + + can_run = True except: - can_run = False + can_run = False + def generate_table(current_time, precision, timezone): - timestamp_type = pa.timestamp(precision, tz=timezone) - schema = pa.schema([("data",timestamp_type)]) - inputs = [pa.array([current_time], type=timestamp_type)] - return pa.Table.from_arrays(inputs, schema=schema) + timestamp_type = pa.timestamp(precision, tz=timezone) + schema = pa.schema([("data", timestamp_type)]) + inputs = [pa.array([current_time], type=timestamp_type)] + return pa.Table.from_arrays(inputs, schema=schema) + timezones = ['UTC', 'BET', 'CET', 'Asia/Kathmandu'] -class TestArrowTimestampsTimezone(object): - def test_timestamp_timezone(self, duckdb_cursor): - if not can_run: - return - precisions = ['us','s','ns','ms'] - current_time = datetime.datetime(2017, 11, 28, 23, 55, 59) - con = duckdb.connect() - con.execute("SET TimeZone = 'UTC'") - for precision in precisions: - arrow_table = generate_table(current_time,precision,'UTC') - res_utc = con.from_arrow(arrow_table).execute().fetchall() - assert res_utc[0][0] == current_time - def test_timestamp_timezone_overflow(self, duckdb_cursor): - if not can_run: - return - precisions = ['s','ms'] - current_time = 9223372036854775807 - for precision in precisions: - with pytest.raises(duckdb.ConversionException, match='Could not convert'): - arrow_table = generate_table(current_time,precision,'UTC') - res_utc = duckdb.from_arrow(arrow_table).execute().fetchall() +class TestArrowTimestampsTimezone(object): + def test_timestamp_timezone(self, duckdb_cursor): + if not can_run: + return + precisions = ['us', 's', 'ns', 'ms'] + current_time = datetime.datetime(2017, 11, 28, 23, 55, 59) + con = duckdb.connect() + con.execute("SET TimeZone = 'UTC'") + for precision in precisions: + arrow_table = generate_table(current_time, precision, 'UTC') + res_utc = con.from_arrow(arrow_table).execute().fetchall() + assert res_utc[0][0] == current_time - def test_timestamp_tz_to_arrow(self, duckdb_cursor): - if not can_run: - return - precisions = ['us','s','ns','ms'] - current_time = datetime.datetime(2017, 11, 28, 23, 55, 59) - con = duckdb.connect() - for precision in precisions: - for timezone in timezones: - con.execute("SET TimeZone = '"+timezone+"'") - arrow_table = generate_table(current_time,precision,timezone) - res = con.from_arrow(arrow_table).arrow() - assert res[0].type == pa.timestamp('us', tz=timezone) - assert res == generate_table(current_time,'us',timezone) + def test_timestamp_timezone_overflow(self, duckdb_cursor): + if not can_run: + return + precisions = ['s', 'ms'] + current_time = 9223372036854775807 + for precision in precisions: + with pytest.raises(duckdb.ConversionException, match='Could not convert'): + arrow_table = generate_table(current_time, precision, 'UTC') + res_utc = 
duckdb.from_arrow(arrow_table).execute().fetchall() - def test_timestamp_tz_with_null(self, duckdb_cursor): - if not can_run: - return - con = duckdb.connect() - con.execute("create table t (i timestamptz)") - con.execute("insert into t values (NULL),('2021-11-15 02:30:00'::timestamptz)") - rel = con.table('t') - arrow_tbl = rel.arrow() - con.register('t2',arrow_tbl) + def test_timestamp_tz_to_arrow(self, duckdb_cursor): + if not can_run: + return + precisions = ['us', 's', 'ns', 'ms'] + current_time = datetime.datetime(2017, 11, 28, 23, 55, 59) + con = duckdb.connect() + for precision in precisions: + for timezone in timezones: + con.execute("SET TimeZone = '" + timezone + "'") + arrow_table = generate_table(current_time, precision, timezone) + res = con.from_arrow(arrow_table).arrow() + assert res[0].type == pa.timestamp('us', tz=timezone) + assert res == generate_table(current_time, 'us', timezone) - assert con.execute("select * from t").fetchall() == con.execute("select * from t2").fetchall() + def test_timestamp_tz_with_null(self, duckdb_cursor): + if not can_run: + return + con = duckdb.connect() + con.execute("create table t (i timestamptz)") + con.execute("insert into t values (NULL),('2021-11-15 02:30:00'::timestamptz)") + rel = con.table('t') + arrow_tbl = rel.arrow() + con.register('t2', arrow_tbl) - def test_timestamp_stream(self, duckdb_cursor): - if not can_run: - return - con = duckdb.connect() - con.execute("create table t (i timestamptz)") - con.execute("insert into t values (NULL),('2021-11-15 02:30:00'::timestamptz)") - rel = con.table('t') - arrow_tbl = rel.record_batch().read_all() - con.register('t2',arrow_tbl) + assert con.execute("select * from t").fetchall() == con.execute("select * from t2").fetchall() - assert con.execute("select * from t").fetchall() == con.execute("select * from t2").fetchall() + def test_timestamp_stream(self, duckdb_cursor): + if not can_run: + return + con = duckdb.connect() + con.execute("create table t (i timestamptz)") + con.execute("insert into t values (NULL),('2021-11-15 02:30:00'::timestamptz)") + rel = con.table('t') + arrow_tbl = rel.record_batch().read_all() + con.register('t2', arrow_tbl) + assert con.execute("select * from t").fetchall() == con.execute("select * from t2").fetchall() diff --git a/tools/pythonpkg/tests/fast/arrow/test_timestamps.py b/tools/pythonpkg/tests/fast/arrow/test_timestamps.py index c6b6d538ea5f..5d404112bdd8 100644 --- a/tools/pythonpkg/tests/fast/arrow/test_timestamps.py +++ b/tools/pythonpkg/tests/fast/arrow/test_timestamps.py @@ -2,47 +2,62 @@ import os import datetime import pytest + try: import pyarrow as pa import pandas as pd + can_run = True except: can_run = False + class TestArrowTimestamps(object): def test_timestamp_types(self, duckdb_cursor): if not can_run: return - data = (pa.array([datetime.datetime.now()], type=pa.timestamp('ns')),pa.array([datetime.datetime.now()], type=pa.timestamp('us')),pa.array([datetime.datetime.now()], pa.timestamp('ms')),pa.array([datetime.datetime.now()], pa.timestamp('s'))) - arrow_table = pa.Table.from_arrays([data[0],data[1],data[2],data[3]],['a','b','c','d']) + data = ( + pa.array([datetime.datetime.now()], type=pa.timestamp('ns')), + pa.array([datetime.datetime.now()], type=pa.timestamp('us')), + pa.array([datetime.datetime.now()], pa.timestamp('ms')), + pa.array([datetime.datetime.now()], pa.timestamp('s')), + ) + arrow_table = pa.Table.from_arrays([data[0], data[1], data[2], data[3]], ['a', 'b', 'c', 'd']) rel = duckdb.from_arrow(arrow_table).arrow() - 
assert (rel['a'] == arrow_table['a']) - assert (rel['b'] == arrow_table['b']) - assert (rel['c'] == arrow_table['c']) - assert (rel['d'] == arrow_table['d']) + assert rel['a'] == arrow_table['a'] + assert rel['b'] == arrow_table['b'] + assert rel['c'] == arrow_table['c'] + assert rel['d'] == arrow_table['d'] def test_timestamp_nulls(self, duckdb_cursor): if not can_run: return - data = (pa.array([None], type=pa.timestamp('ns')),pa.array([None], type=pa.timestamp('us')),pa.array([None], pa.timestamp('ms')),pa.array([None], pa.timestamp('s'))) - arrow_table = pa.Table.from_arrays([data[0],data[1],data[2],data[3]],['a','b','c','d']) + data = ( + pa.array([None], type=pa.timestamp('ns')), + pa.array([None], type=pa.timestamp('us')), + pa.array([None], pa.timestamp('ms')), + pa.array([None], pa.timestamp('s')), + ) + arrow_table = pa.Table.from_arrays([data[0], data[1], data[2], data[3]], ['a', 'b', 'c', 'd']) rel = duckdb.from_arrow(arrow_table).arrow() - assert (rel['a'] == arrow_table['a']) - assert (rel['b'] == arrow_table['b']) - assert (rel['c'] == arrow_table['c']) - assert (rel['d'] == arrow_table['d']) + assert rel['a'] == arrow_table['a'] + assert rel['b'] == arrow_table['b'] + assert rel['c'] == arrow_table['c'] + assert rel['d'] == arrow_table['d'] def test_timestamp_overflow(self, duckdb_cursor): if not can_run: return - data = (pa.array([9223372036854775807], pa.timestamp('s')),pa.array([9223372036854775807], pa.timestamp('ms')),pa.array([9223372036854775807], pa.timestamp('us'))) - arrow_table = pa.Table.from_arrays([data[0],data[1],data[2]],['a','b','c']) + data = ( + pa.array([9223372036854775807], pa.timestamp('s')), + pa.array([9223372036854775807], pa.timestamp('ms')), + pa.array([9223372036854775807], pa.timestamp('us')), + ) + arrow_table = pa.Table.from_arrays([data[0], data[1], data[2]], ['a', 'b', 'c']) arrow_from_duck = duckdb.from_arrow(arrow_table).arrow() - assert (arrow_from_duck['a'] == arrow_table['a']) - assert (arrow_from_duck['b'] == arrow_table['b']) - assert (arrow_from_duck['c'] == arrow_table['c']) - - + assert arrow_from_duck['a'] == arrow_table['a'] + assert arrow_from_duck['b'] == arrow_table['b'] + assert arrow_from_duck['c'] == arrow_table['c'] with pytest.raises(duckdb.ConversionException, match='Could not convert'): duck_rel = duckdb.from_arrow(arrow_table) @@ -58,4 +73,3 @@ def test_timestamp_overflow(self, duckdb_cursor): duck_rel = duckdb.from_arrow(arrow_table) res = duck_rel.project('c::TIMESTAMP_NS') res.fetchone() - \ No newline at end of file diff --git a/tools/pythonpkg/tests/fast/arrow/test_tpch.py b/tools/pythonpkg/tests/fast/arrow/test_tpch.py index ff7bfc7126a0..62713163e520 100644 --- a/tools/pythonpkg/tests/fast/arrow/test_tpch.py +++ b/tools/pythonpkg/tests/fast/arrow/test_tpch.py @@ -1,12 +1,15 @@ import duckdb + try: import pyarrow import pyarrow.parquet import numpy as np + can_run = True except: can_run = False + def munge(cell): try: cell = round(float(cell), 2) @@ -14,7 +17,8 @@ def munge(cell): cell = str(cell) return cell -def check_result(result,answers): + +def check_result(result, answers): for q_res in answers: db_result = result.fetchone() cq_results = q_res.split("|") @@ -27,9 +31,9 @@ def check_result(result,answers): assert ans_result == db_result return True -class TestTPCHArrow(object): - def test_tpch_arrow(self,duckdb_cursor): +class TestTPCHArrow(object): + def test_tpch_arrow(self, duckdb_cursor): if not can_run: return @@ -43,17 +47,23 @@ def test_tpch_arrow(self,duckdb_cursor): duck_tbl = 
duckdb_conn.table(tpch_table) arrow_tables.append(duck_tbl.arrow()) duck_arrow_table = duckdb_conn.from_arrow(arrow_tables[-1]) - duckdb_conn.execute("DROP TABLE "+tpch_table) + duckdb_conn.execute("DROP TABLE " + tpch_table) duck_arrow_table.create(tpch_table) - for i in range (1,23): - query = duckdb_conn.execute("select query from tpch_queries() where query_nr="+str(i)).fetchone()[0] - answers = duckdb_conn.execute("select answer from tpch_answers() where scale_factor = 0.01 and query_nr="+str(i)).fetchone()[0].split("\n")[1:] + for i in range(1, 23): + query = duckdb_conn.execute("select query from tpch_queries() where query_nr=" + str(i)).fetchone()[0] + answers = ( + duckdb_conn.execute( + "select answer from tpch_answers() where scale_factor = 0.01 and query_nr=" + str(i) + ) + .fetchone()[0] + .split("\n")[1:] + ) result = duckdb_conn.execute(query) - assert(check_result(result,answers)) - print ("Query " + str(i) + " works") + assert check_result(result, answers) + print("Query " + str(i) + " works") - def test_tpch_arrow_01(self,duckdb_cursor): + def test_tpch_arrow_01(self, duckdb_cursor): if not can_run: return @@ -67,17 +77,21 @@ def test_tpch_arrow_01(self,duckdb_cursor): duck_tbl = duckdb_conn.table(tpch_table) arrow_tables.append(duck_tbl.arrow()) duck_arrow_table = duckdb_conn.from_arrow(arrow_tables[-1]) - duckdb_conn.execute("DROP TABLE "+tpch_table) + duckdb_conn.execute("DROP TABLE " + tpch_table) duck_arrow_table.create(tpch_table) - for i in range (1,23): - query = duckdb_conn.execute("select query from tpch_queries() where query_nr="+str(i)).fetchone()[0] - answers = duckdb_conn.execute("select answer from tpch_answers() where scale_factor = 0.1 and query_nr="+str(i)).fetchone()[0].split("\n")[1:] + for i in range(1, 23): + query = duckdb_conn.execute("select query from tpch_queries() where query_nr=" + str(i)).fetchone()[0] + answers = ( + duckdb_conn.execute("select answer from tpch_answers() where scale_factor = 0.1 and query_nr=" + str(i)) + .fetchone()[0] + .split("\n")[1:] + ) result = duckdb_conn.execute(query) - assert(check_result(result,answers)) - print ("Query " + str(i) + " works") + assert check_result(result, answers) + print("Query " + str(i) + " works") - def test_tpch_arrow_batch(self,duckdb_cursor): + def test_tpch_arrow_batch(self, duckdb_cursor): if not can_run: return @@ -91,22 +105,34 @@ def test_tpch_arrow_batch(self,duckdb_cursor): duck_tbl = duckdb_conn.table(tpch_table) arrow_tables.append(pyarrow.Table.from_batches(duck_tbl.arrow().to_batches(10))) duck_arrow_table = duckdb_conn.from_arrow(arrow_tables[-1]) - duckdb_conn.execute("DROP TABLE "+tpch_table) + duckdb_conn.execute("DROP TABLE " + tpch_table) duck_arrow_table.create(tpch_table) - for i in range (1,23): - query = duckdb_conn.execute("select query from tpch_queries() where query_nr="+str(i)).fetchone()[0] - answers = duckdb_conn.execute("select answer from tpch_answers() where scale_factor = 0.01 and query_nr="+str(i)).fetchone()[0].split("\n")[1:] + for i in range(1, 23): + query = duckdb_conn.execute("select query from tpch_queries() where query_nr=" + str(i)).fetchone()[0] + answers = ( + duckdb_conn.execute( + "select answer from tpch_answers() where scale_factor = 0.01 and query_nr=" + str(i) + ) + .fetchone()[0] + .split("\n")[1:] + ) result = duckdb_conn.execute(query) - assert(check_result(result,answers)) - print ("Query " + str(i) + " works") + assert check_result(result, answers) + print("Query " + str(i) + " works") duckdb_conn.execute("PRAGMA threads=4") 
duckdb_conn.execute("PRAGMA verify_parallelism") - for i in range (1,23): - query = duckdb_conn.execute("select query from tpch_queries() where query_nr="+str(i)).fetchone()[0] - answers = duckdb_conn.execute("select answer from tpch_answers() where scale_factor = 0.01 and query_nr="+str(i)).fetchone()[0].split("\n")[1:] + for i in range(1, 23): + query = duckdb_conn.execute("select query from tpch_queries() where query_nr=" + str(i)).fetchone()[0] + answers = ( + duckdb_conn.execute( + "select answer from tpch_answers() where scale_factor = 0.01 and query_nr=" + str(i) + ) + .fetchone()[0] + .split("\n")[1:] + ) result = duckdb_conn.execute(query) - assert(check_result(result,answers)) - print ("Query " + str(i) + " works (Parallel)") + assert check_result(result, answers) + print("Query " + str(i) + " works (Parallel)") diff --git a/tools/pythonpkg/tests/fast/arrow/test_unregister.py b/tools/pythonpkg/tests/fast/arrow/test_unregister.py index 81d08d90567e..c63ef0d6e64a 100644 --- a/tools/pythonpkg/tests/fast/arrow/test_unregister.py +++ b/tools/pythonpkg/tests/fast/arrow/test_unregister.py @@ -3,18 +3,21 @@ import gc import duckdb import os + try: import pyarrow import pyarrow.parquet + can_run = True except: can_run = False + class TestArrowUnregister(object): def test_arrow_unregister1(self, duckdb_cursor): if not can_run: return - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)),'data','userdata1.parquet') + parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'userdata1.parquet') cols = 'id, first_name, last_name, email, gender, ip_address, cc, country, birthdate, salary, title, comments' arrow_table_obj = pyarrow.parquet.read_table(parquet_filename) @@ -37,7 +40,7 @@ def test_arrow_unregister2(self, duckdb_cursor): os.remove(db) connection = duckdb.connect(db) - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)),'data','userdata1.parquet') + parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'userdata1.parquet') cols = 'id, first_name, last_name, email, gender, ip_address, cc, country, birthdate, salary, title, comments' arrow_table_obj = pyarrow.parquet.read_table(parquet_filename) connection.register("arrow_table", arrow_table_obj) @@ -56,4 +59,4 @@ def test_arrow_unregister2(self, duckdb_cursor): assert len(connection.execute("PRAGMA show_tables;").fetchall()) == 0 with pytest.raises(duckdb.CatalogException, match='Table with name arrow_table does not exist'): connection.execute("SELECT * FROM arrow_table;").fetch_arrow_table() - connection.close() \ No newline at end of file + connection.close() diff --git a/tools/pythonpkg/tests/fast/arrow/test_view.py b/tools/pythonpkg/tests/fast/arrow/test_view.py index 5b1825820b1b..a5d42af42d78 100644 --- a/tools/pythonpkg/tests/fast/arrow/test_view.py +++ b/tools/pythonpkg/tests/fast/arrow/test_view.py @@ -1,21 +1,24 @@ import duckdb import os + try: import pyarrow import pyarrow.parquet + can_run = True except: can_run = False + class TestArrowView(object): def test_arrow_view(self, duckdb_cursor): if not can_run: return - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)),'data','userdata1.parquet') + parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'userdata1.parquet') cols = 'id, first_name, last_name, email, gender, ip_address, cc, country, birthdate, salary, title, comments' duckdb_conn = duckdb.connect() userdata_parquet_table = 
pyarrow.parquet.read_table(parquet_filename) userdata_parquet_table.validate(full=True) duckdb_conn.from_arrow(userdata_parquet_table).create_view('arrow_view') - assert (duckdb_conn.execute("PRAGMA show_tables").fetchone() == ('arrow_view',)) - assert(duckdb_conn.execute("select avg(salary)::INT from arrow_view").fetchone()[0] == 149005) + assert duckdb_conn.execute("PRAGMA show_tables").fetchone() == ('arrow_view',) + assert duckdb_conn.execute("select avg(salary)::INT from arrow_view").fetchone()[0] == 149005 diff --git a/tools/pythonpkg/tests/fast/numpy/test_numpy_new_path.py b/tools/pythonpkg/tests/fast/numpy/test_numpy_new_path.py index 78ef284cfb10..c557f7033191 100644 --- a/tools/pythonpkg/tests/fast/numpy/test_numpy_new_path.py +++ b/tools/pythonpkg/tests/fast/numpy/test_numpy_new_path.py @@ -7,21 +7,22 @@ from datetime import timedelta import pytest + class TestScanNumpy(object): def test_scan_numpy(self, duckdb_cursor): - z = np.array([1,2,3]) + z = np.array([1, 2, 3]) res = duckdb.sql("select * from z").fetchall() assert res == [(1,), (2,), (3,)] - z = np.array([[1,2,3], [4,5,6]]) + z = np.array([[1, 2, 3], [4, 5, 6]]) res = duckdb.sql("select * from z").fetchall() assert res == [(1, 4), (2, 5), (3, 6)] - z = [np.array([1,2,3]), np.array([4,5,6])] + z = [np.array([1, 2, 3]), np.array([4, 5, 6])] res = duckdb.sql("select * from z").fetchall() assert res == [(1, 4), (2, 5), (3, 6)] - z = {"z": np.array([1,2,3]), "x": np.array([4,5,6])} + z = {"z": np.array([1, 2, 3]), "x": np.array([4, 5, 6])} res = duckdb.sql("select * from z").fetchall() assert res == [(1, 4), (2, 5), (3, 6)] @@ -29,7 +30,7 @@ def test_scan_numpy(self, duckdb_cursor): res = duckdb.sql("select * from z").fetchall() assert res == [('zzz',), ('xxx',)] - z = [np.array(["zzz", "xxx"]), np.array([1,2])] + z = [np.array(["zzz", "xxx"]), np.array([1, 2])] res = duckdb.sql("select * from z").fetchall() assert res == [('zzz', 1), ('xxx', 2)] @@ -39,14 +40,21 @@ def test_scan_numpy(self, duckdb_cursor): z.append({str(3 - i): i}) z = np.array(z) res = duckdb.sql("select * from z").fetchall() - assert res == [({'key': ['3'], 'value': [0]},), ({'key': ['2'], 'value': [1]},), ({'key': ['1'], 'value': [2]},)] + assert res == [ + ({'key': ['3'], 'value': [0]},), + ({'key': ['2'], 'value': [1]},), + ({'key': ['1'], 'value': [2]},), + ] # test timedelta - delta = timedelta(days=50, seconds = 27,microseconds=10,milliseconds=29000,minutes=5,hours=8,weeks=2) - delta2 = timedelta(days=5, seconds = 27,microseconds=10,milliseconds=29000,minutes=5,hours=8,weeks=2) + delta = timedelta(days=50, seconds=27, microseconds=10, milliseconds=29000, minutes=5, hours=8, weeks=2) + delta2 = timedelta(days=5, seconds=27, microseconds=10, milliseconds=29000, minutes=5, hours=8, weeks=2) z = np.array([delta, delta2]) res = duckdb.sql("select * from z").fetchall() - assert res == [(timedelta(days=64, seconds=29156, microseconds=10),), (timedelta(days=19, seconds=29156, microseconds=10),)] + assert res == [ + (timedelta(days=64, seconds=29156, microseconds=10),), + (timedelta(days=19, seconds=29156, microseconds=10),), + ] # np.empty z = np.empty((3,)) @@ -64,38 +72,38 @@ def test_scan_numpy(self, duckdb_cursor): assert res == [(None,)] # dict of mixed types - z = {"z": np.array([1,2,3]), "x": np.array(["z", "x", "c"])} + z = {"z": np.array([1, 2, 3]), "x": np.array(["z", "x", "c"])} res = duckdb.sql("select * from z").fetchall() assert res == [(1, 'z'), (2, 'x'), (3, 'c')] # list of mixed types - z = [np.array([1,2,3]), np.array(["z", "x", "c"])] + z 
= [np.array([1, 2, 3]), np.array(["z", "x", "c"])] res = duckdb.sql("select * from z").fetchall() assert res == [(1, 'z'), (2, 'x'), (3, 'c')] # currently unsupported formats, will throw duckdb.InvalidInputException # list of arrays with different length - z = [np.array([1,2]), np.array([3])] + z = [np.array([1, 2]), np.array([3])] with pytest.raises(duckdb.InvalidInputException): duckdb.sql("select * from z") # dict of ndarrays of different length - z = {"z":np.array([1,2]), "x":np.array([3])} + z = {"z": np.array([1, 2]), "x": np.array([3])} with pytest.raises(duckdb.InvalidInputException): duckdb.sql("select * from z") # high dimensional tensors - z = np.array([[[1,2]]]) + z = np.array([[[1, 2]]]) with pytest.raises(duckdb.InvalidInputException): duckdb.sql("select * from z") # list of ndarrys with len(shape) > 1 - z = [np.array([[1,2],[3,4]])] + z = [np.array([[1, 2], [3, 4]])] with pytest.raises(duckdb.InvalidInputException): duckdb.sql("select * from z") # dict of ndarrays with len(shape) > 1 - z = {"x":np.array([[1,2],[3,4]])} + z = {"x": np.array([[1, 2], [3, 4]])} with pytest.raises(duckdb.InvalidInputException): duckdb.sql("select * from z") diff --git a/tools/pythonpkg/tests/fast/pandas/test_2304.py b/tools/pythonpkg/tests/fast/pandas/test_2304.py index c44fb07f7829..631cdd1aae60 100644 --- a/tools/pythonpkg/tests/fast/pandas/test_2304.py +++ b/tools/pythonpkg/tests/fast/pandas/test_2304.py @@ -3,21 +3,28 @@ import pytest from conftest import NumpyPandas, ArrowPandas + class TestPandasMergeSameName(object): @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_2304(self, duckdb_cursor, pandas): - df1 = pandas.DataFrame({ - 'id_1': [1, 1, 1, 2, 2], - 'agedate': np.array(['2010-01-01','2010-02-01','2010-03-01','2020-02-01', '2020-03-01']).astype('datetime64[D]'), - 'age': [1, 2, 3, 1, 2], - 'v': [1.1, 1.2, 1.3, 2.1, 2.2] - }) - - df2 = pandas.DataFrame({ - 'id_1': [1, 1, 2], - 'agedate': np.array(['2010-01-01','2010-02-01', '2020-03-01']).astype('datetime64[D]'), - 'v2': [11.1, 11.2, 21.2] - }) + df1 = pandas.DataFrame( + { + 'id_1': [1, 1, 1, 2, 2], + 'agedate': np.array(['2010-01-01', '2010-02-01', '2010-03-01', '2020-02-01', '2020-03-01']).astype( + 'datetime64[D]' + ), + 'age': [1, 2, 3, 1, 2], + 'v': [1.1, 1.2, 1.3, 2.1, 2.2], + } + ) + + df2 = pandas.DataFrame( + { + 'id_1': [1, 1, 2], + 'agedate': np.array(['2010-01-01', '2010-02-01', '2020-03-01']).astype('datetime64[D]'), + 'v2': [11.1, 11.2, 21.2], + } + ) con = duckdb.connect() con.register('df1', df1) @@ -35,26 +42,26 @@ def test_2304(self, duckdb_cursor, pandas): @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_pd_names(self, duckdb_cursor, pandas): - df1 = pandas.DataFrame({ - 'id': [1, 1, 2], - 'id_1': [1, 1, 2], - 'id_3': [1, 1, 2], - }) - - df2 = pandas.DataFrame({ - 'id': [1, 1, 2], - 'id_1': [1, 1, 2], - 'id_2': [1, 1, 1] - }) - - exp_result = pandas.DataFrame({ - 'id': [1, 1, 2, 1, 1], - 'id_1': [1, 1, 2, 1, 1], - 'id_3': [1, 1, 2, 1, 1], - 'id_2': [1, 1, 2, 1, 1], - 'id_1_2': [1, 1, 2, 1, 1], - 'id_2_2': [1, 1, 1, 1, 1] - }) + df1 = pandas.DataFrame( + { + 'id': [1, 1, 2], + 'id_1': [1, 1, 2], + 'id_3': [1, 1, 2], + } + ) + + df2 = pandas.DataFrame({'id': [1, 1, 2], 'id_1': [1, 1, 2], 'id_2': [1, 1, 1]}) + + exp_result = pandas.DataFrame( + { + 'id': [1, 1, 2, 1, 1], + 'id_1': [1, 1, 2, 1, 1], + 'id_3': [1, 1, 2, 1, 1], + 'id_2': [1, 1, 2, 1, 1], + 'id_1_2': [1, 1, 2, 1, 1], + 'id_2_2': [1, 1, 1, 1, 1], + } + ) con = duckdb.connect() con.register('df1', df1) @@ 
-68,22 +75,24 @@ def test_pd_names(self, duckdb_cursor, pandas): @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_repeat_name(self, duckdb_cursor, pandas): - df1 = pandas.DataFrame({ - 'id': [1], - 'id_1': [1], - 'id_2': [1], - }) - - df2 = pandas.DataFrame({ - 'id': [1] - }) - - exp_result = pandas.DataFrame({ - 'id': [1], - 'id_1': [1], - 'id_2': [1], - 'id_2_1': [1], - }) + df1 = pandas.DataFrame( + { + 'id': [1], + 'id_1': [1], + 'id_2': [1], + } + ) + + df2 = pandas.DataFrame({'id': [1]}) + + exp_result = pandas.DataFrame( + { + 'id': [1], + 'id_1': [1], + 'id_2': [1], + 'id_2_1': [1], + } + ) con = duckdb.connect() con.register('df1', df1) diff --git a/tools/pythonpkg/tests/fast/pandas/test_append_df.py b/tools/pythonpkg/tests/fast/pandas/test_append_df.py index 51faeb37214e..18805a5acf05 100644 --- a/tools/pythonpkg/tests/fast/pandas/test_append_df.py +++ b/tools/pythonpkg/tests/fast/pandas/test_append_df.py @@ -2,27 +2,29 @@ import pytest from conftest import NumpyPandas, ArrowPandas -class TestAppendDF(object): +class TestAppendDF(object): @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_df_to_table_append(self, duckdb_cursor, pandas): conn = duckdb.connect() conn.execute("Create table integers (i integer)") - df_in = pandas.DataFrame({'numbers': [1,2,3,4,5],}) - conn.append('integers',df_in) + df_in = pandas.DataFrame( + { + 'numbers': [1, 2, 3, 4, 5], + } + ) + conn.append('integers', df_in) assert conn.execute('select count(*) from integers').fetchone()[0] == 5 @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_append_by_name(self, pandas): con = duckdb.connect() con.execute("create table tbl (a integer, b bool, c varchar)") - df_in = pandas.DataFrame({ - 'c': ['duck', 'db'], - 'b': [False, True], - 'a': [4,2] - }) + df_in = pandas.DataFrame({'c': ['duck', 'db'], 'b': [False, True], 'a': [4, 2]}) # By default we append by position, causing the following exception: - with pytest.raises(duckdb.ConversionException, match="Conversion Error: Could not convert string 'duck' to INT32"): + with pytest.raises( + duckdb.ConversionException, match="Conversion Error: Could not convert string 'duck' to INT32" + ): con.append('tbl', df_in) # When we use 'by_name' we instead append by name @@ -33,12 +35,12 @@ def test_append_by_name(self, pandas): @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_append_by_name_quoted(self, pandas): con = duckdb.connect() - con.execute(""" + con.execute( + """ create table tbl ("needs to be quoted" integer, other varchar) - """) - df_in = pandas.DataFrame({ - "needs to be quoted": [1,2,3] - }) + """ + ) + df_in = pandas.DataFrame({"needs to be quoted": [1, 2, 3]}) con.append('tbl', df_in, by_name=True) res = con.table('tbl').fetchall() assert res == [(1, None), (2, None), (3, None)] @@ -47,18 +49,12 @@ def test_append_by_name_quoted(self, pandas): def test_append_by_name_no_exact_match(self, pandas): con = duckdb.connect() con.execute("create table tbl (a integer, b bool)") - df_in = pandas.DataFrame({ - 'c': ['a', 'b'], - 'b': [True, False], - 'a': [42, 1337] - }) + df_in = pandas.DataFrame({'c': ['a', 'b'], 'b': [True, False], 'a': [42, 1337]}) # Too many columns raises an error, because the columns cant be found in the targeted table with pytest.raises(duckdb.BinderException, match='Table "tbl" does not have a column with name "c"'): con.append('tbl', df_in, by_name=True) - df_in = pandas.DataFrame({ - 'b': [False, False, False] - }) + df_in = 
pandas.DataFrame({'b': [False, False, False]}) # Not matching all columns is not a problem, as they will be filled with NULL instead con.append('tbl', df_in, by_name=True) @@ -69,9 +65,7 @@ def test_append_by_name_no_exact_match(self, pandas): # Empty the table con.execute("create or replace table tbl (a integer, b bool)") - df_in = pandas.DataFrame({ - 'a': [1,2,3] - }) + df_in = pandas.DataFrame({'a': [1, 2, 3]}) con.append('tbl', df_in, by_name=True) res = con.table('tbl').fetchall() # Also works for missing columns *after* the supplied ones diff --git a/tools/pythonpkg/tests/fast/pandas/test_bug2281.py b/tools/pythonpkg/tests/fast/pandas/test_bug2281.py index e975db61827f..703baf4b6e41 100644 --- a/tools/pythonpkg/tests/fast/pandas/test_bug2281.py +++ b/tools/pythonpkg/tests/fast/pandas/test_bug2281.py @@ -5,6 +5,7 @@ import pandas as pd import io + class TestPandasStringNull(object): def test_pandas_string_null(self, duckdb_cursor): csv = u'''what,is_control,is_test @@ -14,4 +15,4 @@ def test_pandas_string_null(self, duckdb_cursor): duckdb_cursor.register("c", df) duckdb_cursor.execute('select what, count(*) from c group by what') df_result = duckdb_cursor.fetchdf() - assert(True) # Should not crash ^^ + assert True # Should not crash ^^ diff --git a/tools/pythonpkg/tests/fast/pandas/test_bug5922.py b/tools/pythonpkg/tests/fast/pandas/test_bug5922.py index ab57263f6054..af9be1672ed8 100644 --- a/tools/pythonpkg/tests/fast/pandas/test_bug5922.py +++ b/tools/pythonpkg/tests/fast/pandas/test_bug5922.py @@ -2,14 +2,15 @@ import pytest from conftest import NumpyPandas, ArrowPandas + class TestPandasAcceptFloat16(object): @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_pandas_accept_float16(self, duckdb_cursor, pandas): - df = pandas.DataFrame({'col': [1,2,3]}) - df16 = df.astype({'col':'float16'}) + df = pandas.DataFrame({'col': [1, 2, 3]}) + df16 = df.astype({'col': 'float16'}) con = duckdb.connect() con.execute('CREATE TABLE tbl AS SELECT * FROM df16') con.execute('select * from tbl') df_result = con.fetchdf() - df32 = df.astype({'col':'float32'}) - assert((df32['col'] == df_result['col']).all()) \ No newline at end of file + df32 = df.astype({'col': 'float32'}) + assert (df32['col'] == df_result['col']).all() diff --git a/tools/pythonpkg/tests/fast/pandas/test_create_table_from_pandas.py b/tools/pythonpkg/tests/fast/pandas/test_create_table_from_pandas.py index 0eab9e61d7b2..69234dc7ac59 100644 --- a/tools/pythonpkg/tests/fast/pandas/test_create_table_from_pandas.py +++ b/tools/pythonpkg/tests/fast/pandas/test_create_table_from_pandas.py @@ -4,6 +4,7 @@ import sys from conftest import NumpyPandas, ArrowPandas + def assert_create(internal_data, expected_result, data_type, pandas): conn = duckdb.connect() df_in = pandas.DataFrame(data=internal_data, dtype=data_type) @@ -13,6 +14,7 @@ def assert_create(internal_data, expected_result, data_type, pandas): result = conn.execute("SELECT * FROM t").fetchall() assert result == expected_result + def assert_create_register(internal_data, expected_result, data_type, pandas): conn = duckdb.connect() df_in = pandas.DataFrame(data=internal_data, dtype=data_type) @@ -22,19 +24,19 @@ def assert_create_register(internal_data, expected_result, data_type, pandas): result = conn.execute("SELECT * FROM t").fetchall() assert result == expected_result -class TestCreateTableFromPandas(object): +class TestCreateTableFromPandas(object): @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def 
test_integer_create_table(self, duckdb_cursor, pandas): if sys.version_info.major < 3: return - #FIXME: This should work with other data types e.g., int8... - data_types = ['Int8','Int16','Int32','Int64'] - internal_data = [1,2,3,4] + # FIXME: This should work with other data types e.g., int8... + data_types = ['Int8', 'Int16', 'Int32', 'Int64'] + internal_data = [1, 2, 3, 4] expected_result = [(1,), (2,), (3,), (4,)] for data_type in data_types: print(data_type) - assert_create_register(internal_data,expected_result,data_type, pandas) - assert_create(internal_data,expected_result,data_type, pandas) + assert_create_register(internal_data, expected_result, data_type, pandas) + assert_create(internal_data, expected_result, data_type, pandas) - #FIXME: Also test other data types \ No newline at end of file + # FIXME: Also test other data types diff --git a/tools/pythonpkg/tests/fast/pandas/test_date_as_datetime.py b/tools/pythonpkg/tests/fast/pandas/test_date_as_datetime.py index 9227875e68c7..038f24a81b37 100644 --- a/tools/pythonpkg/tests/fast/pandas/test_date_as_datetime.py +++ b/tools/pythonpkg/tests/fast/pandas/test_date_as_datetime.py @@ -3,11 +3,13 @@ import datetime import pytest + def run_checks(df): assert type(df['d'][0]) is datetime.date assert df['d'][0] == datetime.date(1992, 7, 30) assert pd.isnull(df['d'][1]) + def test_date_as_datetime(): con = duckdb.connect() con.execute("create table t (d date)") @@ -26,4 +28,3 @@ def test_date_as_datetime(): # Result Methods run_checks(rel.query("t_1", "select * from t_1").df(date_as_object=True)) - diff --git a/tools/pythonpkg/tests/fast/pandas/test_datetime_time.py b/tools/pythonpkg/tests/fast/pandas/test_datetime_time.py index dbcb9e8342c6..33b6ca6a69c9 100644 --- a/tools/pythonpkg/tests/fast/pandas/test_datetime_time.py +++ b/tools/pythonpkg/tests/fast/pandas/test_datetime_time.py @@ -4,15 +4,13 @@ from conftest import NumpyPandas, ArrowPandas from datetime import datetime, timezone, time, timedelta -class TestDateTimeTime(object): +class TestDateTimeTime(object): @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_time_high(self, duckdb_cursor, pandas): duckdb_time = duckdb.query("SELECT make_time(23, 1, 34.234345) AS '0'").df() data = [time(hour=23, minute=1, second=34, microsecond=234345)] - df_in = pandas.DataFrame( - {'0': pandas.Series(data=data, dtype='object')} - ) + df_in = pandas.DataFrame({'0': pandas.Series(data=data, dtype='object')}) df_out = duckdb.query_df(df_in, "df", "select * from df").df() pandas.testing.assert_frame_equal(df_out, duckdb_time) @@ -20,9 +18,7 @@ def test_time_high(self, duckdb_cursor, pandas): def test_time_low(self, duckdb_cursor, pandas): duckdb_time = duckdb.query("SELECT make_time(00, 01, 1.000) AS '0'").df() data = [time(hour=0, minute=1, second=1)] - df_in = pandas.DataFrame( - {'0': pandas.Series(data=data, dtype='object')} - ) + df_in = pandas.DataFrame({'0': pandas.Series(data=data, dtype='object')}) df_out = duckdb.query_df(df_in, "df", "select * from df").df() pandas.testing.assert_frame_equal(df_out, duckdb_time) @@ -33,9 +29,7 @@ def test_time_timezone_regular(self, duckdb_cursor, pandas): offset = timedelta(hours=3) tz = timezone(offset) data = [time(hour=3, minute=1, second=1, tzinfo=tz)] - df_in = pandas.DataFrame( - {'0': pandas.Series(data=data, dtype='object')} - ) + df_in = pandas.DataFrame({'0': pandas.Series(data=data, dtype='object')}) df_out = duckdb.query_df(df_in, "df", "select * from df").df() pandas.testing.assert_frame_equal(df_out, duckdb_time) 
@@ -46,9 +40,7 @@ def test_time_timezone_negative_extreme(self, duckdb_cursor, pandas): offset = timedelta(hours=-14) tz = timezone(offset) data = [time(hour=22, minute=1, second=1, tzinfo=tz)] - df_in = pandas.DataFrame( - {'0': pandas.Series(data=data, dtype='object')} - ) + df_in = pandas.DataFrame({'0': pandas.Series(data=data, dtype='object')}) df_out = duckdb.query_df(df_in, "df", "select * from df").df() pandas.testing.assert_frame_equal(df_out, duckdb_time) @@ -59,9 +51,7 @@ def test_time_timezone_positive_extreme(self, duckdb_cursor, pandas): offset = timedelta(hours=20) tz = timezone(offset) data = [time(hour=8, minute=1, second=1, tzinfo=tz)] - df_in = pandas.DataFrame( - {'0': pandas.Series(data=data, dtype='object')} - ) + df_in = pandas.DataFrame({'0': pandas.Series(data=data, dtype='object')}) df_out = duckdb.query_df(df_in, "df", "select * from df").df() pandas.testing.assert_frame_equal(df_out, duckdb_time) diff --git a/tools/pythonpkg/tests/fast/pandas/test_datetime_timestamp.py b/tools/pythonpkg/tests/fast/pandas/test_datetime_timestamp.py index 2feccef05cbc..76dcab868c82 100644 --- a/tools/pythonpkg/tests/fast/pandas/test_datetime_timestamp.py +++ b/tools/pythonpkg/tests/fast/pandas/test_datetime_timestamp.py @@ -4,10 +4,11 @@ import pytest from conftest import NumpyPandas, ArrowPandas from packaging.version import Version + pd = pytest.importorskip("pandas") -class TestDateTimeTimeStamp(object): +class TestDateTimeTimeStamp(object): @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_timestamp_high(self, pandas): duckdb_time = duckdb.query("SELECT '2260-01-01 23:59:00'::TIMESTAMP AS '0'").df() @@ -26,64 +27,76 @@ def test_timestamp_low(self, pandas): df_out = duckdb.query_df(df_in, "df", "select * from df").df() pandas.testing.assert_frame_equal(df_out, duckdb_time) - @pytest.mark.skipif(Version(pd.__version__) < Version('2.0.2'), reason="pandas < 2.0.2 does not properly convert timezones") + @pytest.mark.skipif( + Version(pd.__version__) < Version('2.0.2'), reason="pandas < 2.0.2 does not properly convert timezones" + ) @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_timestamp_timezone_regular(self, pandas): - duckdb_time = duckdb.query(""" + duckdb_time = duckdb.query( + """ SELECT timestamp '2022-01-01 12:00:00' AT TIME ZONE 'Pacific/Easter' as "0" - """).df() + """ + ).df() offset = datetime.timedelta(hours=-2) timezone = datetime.timezone(offset) df_in = pandas.DataFrame( - {0: pandas.Series(data=[datetime.datetime(year=2022, month=1, day=1, hour=15, tzinfo=timezone)], dtype='object')} + { + 0: pandas.Series( + data=[datetime.datetime(year=2022, month=1, day=1, hour=15, tzinfo=timezone)], dtype='object' + ) + } ) df_out = duckdb.query_df(df_in, "df", "select * from df").df() print(df_out) print(duckdb_time) pandas.testing.assert_frame_equal(df_out, duckdb_time) - @pytest.mark.skipif(Version(pd.__version__) < Version('2.0.2'), reason="pandas < 2.0.2 does not properly convert timezones") + @pytest.mark.skipif( + Version(pd.__version__) < Version('2.0.2'), reason="pandas < 2.0.2 does not properly convert timezones" + ) @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_timestamp_timezone_negative_extreme(self, pandas): - duckdb_time = duckdb.query(""" + duckdb_time = duckdb.query( + """ SELECT timestamp '2022-01-01 12:00:00' AT TIME ZONE 'Chile/EasterIsland' as "0" - """).df() + """ + ).df() offset = datetime.timedelta(hours=-19) timezone = datetime.timezone(offset) df_in = pandas.DataFrame( 
- {0: pandas.Series(data=[datetime.datetime( - year=2021, - month=12, - day=31, - hour=22, - tzinfo=timezone - )], dtype='object')} + { + 0: pandas.Series( + data=[datetime.datetime(year=2021, month=12, day=31, hour=22, tzinfo=timezone)], dtype='object' + ) + } ) df_out = duckdb.query_df(df_in, "df", "select * from df").df() pandas.testing.assert_frame_equal(df_out, duckdb_time) - @pytest.mark.skipif(Version(pd.__version__) < Version('2.0.2'), reason="pandas < 2.0.2 does not properly convert timezones") + @pytest.mark.skipif( + Version(pd.__version__) < Version('2.0.2'), reason="pandas < 2.0.2 does not properly convert timezones" + ) @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_timestamp_timezone_positive_extreme(self, pandas): - duckdb_time = duckdb.query(""" + duckdb_time = duckdb.query( + """ SELECT timestamp '2021-12-31 23:00:00' AT TIME ZONE 'kea_CV' as "0" - """).df() + """ + ).df() # 'kea_CV' is 20 hours ahead of UTC offset = datetime.timedelta(hours=20) timezone = datetime.timezone(offset) df_in = pandas.DataFrame( - {0: pandas.Series(data=[datetime.datetime( - year=2022, - month=1, - day=1, - hour=19, - tzinfo=timezone - )], dtype='object')} + { + 0: pandas.Series( + data=[datetime.datetime(year=2022, month=1, day=1, hour=19, tzinfo=timezone)], dtype='object' + ) + } ) df_out = duckdb.query_df(df_in, "df", """select * from df""").df() pandas.testing.assert_frame_equal(df_out, duckdb_time) diff --git a/tools/pythonpkg/tests/fast/pandas/test_df_analyze.py b/tools/pythonpkg/tests/fast/pandas/test_df_analyze.py index 2275519cc763..19fb6ec182e3 100644 --- a/tools/pythonpkg/tests/fast/pandas/test_df_analyze.py +++ b/tools/pythonpkg/tests/fast/pandas/test_df_analyze.py @@ -4,53 +4,55 @@ import pytest from conftest import NumpyPandas, ArrowPandas + def create_generic_dataframe(data, pandas): - return pandas.DataFrame({'col0': pandas.Series(data=data, dtype='object')}) + return pandas.DataFrame({'col0': pandas.Series(data=data, dtype='object')}) + class TestResolveObjectColumns(object): - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) - def test_sample_low_correct(self, duckdb_cursor, pandas): - print(pandas.backend) - duckdb_conn = duckdb.connect() - duckdb_conn.execute("SET GLOBAL pandas_analyze_sample=3") - data = [1000008, 6, 9, 4, 1, 6] - df = create_generic_dataframe(data, pandas) - roundtripped_df = duckdb.query_df(df, "x", "select * from x", connection=duckdb_conn).df() - duckdb_df = duckdb_conn.query("select * FROM (VALUES (1000008), (6), (9), (4), (1), (6)) as '0'").df() - pandas.testing.assert_frame_equal(duckdb_df, roundtripped_df, check_dtype=False) + @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + def test_sample_low_correct(self, duckdb_cursor, pandas): + print(pandas.backend) + duckdb_conn = duckdb.connect() + duckdb_conn.execute("SET GLOBAL pandas_analyze_sample=3") + data = [1000008, 6, 9, 4, 1, 6] + df = create_generic_dataframe(data, pandas) + roundtripped_df = duckdb.query_df(df, "x", "select * from x", connection=duckdb_conn).df() + duckdb_df = duckdb_conn.query("select * FROM (VALUES (1000008), (6), (9), (4), (1), (6)) as '0'").df() + pandas.testing.assert_frame_equal(duckdb_df, roundtripped_df, check_dtype=False) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) - def test_sample_low_incorrect_detected(self, duckdb_cursor, pandas): - duckdb_conn = duckdb.connect() - duckdb_conn.execute("SET GLOBAL pandas_analyze_sample=2") - # size of list (6) divided by 
'pandas_analyze_sample' (2) is the increment used
-        # in this case index 0 (1000008) and index 3 ([4]) are checked, which dont match
-        data = [1000008, 6, 9, [4], 1, 6]
-        df = create_generic_dataframe(data, pandas)
-        roundtripped_df = duckdb.query_df(df, "x", "select * from x", connection=duckdb_conn).df()
-        # Sample high enough to detect mismatch in types, fallback to VARCHAR
-        assert(roundtripped_df['col0'].dtype == np.dtype('object'))
+    @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()])
+    def test_sample_low_incorrect_detected(self, duckdb_cursor, pandas):
+        duckdb_conn = duckdb.connect()
+        duckdb_conn.execute("SET GLOBAL pandas_analyze_sample=2")
+        # size of list (6) divided by 'pandas_analyze_sample' (2) is the increment used
+        # in this case index 0 (1000008) and index 3 ([4]) are checked, which don't match
+        data = [1000008, 6, 9, [4], 1, 6]
+        df = create_generic_dataframe(data, pandas)
+        roundtripped_df = duckdb.query_df(df, "x", "select * from x", connection=duckdb_conn).df()
+        # Sample high enough to detect mismatch in types, fallback to VARCHAR
+        assert roundtripped_df['col0'].dtype == np.dtype('object')

-    @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()])
-    def test_sample_zero(self, duckdb_cursor, pandas):
-        duckdb_conn = duckdb.connect()
-        # Disable dataframe analyze
-        duckdb_conn.execute("SET GLOBAL pandas_analyze_sample=0")
-        data = [1000008, 6, 9, 3, 1, 6]
-        df = create_generic_dataframe(data, pandas)
-        roundtripped_df = duckdb.query_df(df, "x", "select * from x", connection=duckdb_conn).df()
-        # Always converts to VARCHAR
-        if (pandas.backend == 'pyarrow'):
-            assert(roundtripped_df['col0'].dtype == np.dtype('int64'))
-        else:
-            assert(roundtripped_df['col0'].dtype == np.dtype('object'))
+    @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()])
+    def test_sample_zero(self, duckdb_cursor, pandas):
+        duckdb_conn = duckdb.connect()
+        # Disable dataframe analyze
+        duckdb_conn.execute("SET GLOBAL pandas_analyze_sample=0")
+        data = [1000008, 6, 9, 3, 1, 6]
+        df = create_generic_dataframe(data, pandas)
+        roundtripped_df = duckdb.query_df(df, "x", "select * from x", connection=duckdb_conn).df()
+        # Always converts to VARCHAR
+        if pandas.backend == 'pyarrow':
+            assert roundtripped_df['col0'].dtype == np.dtype('int64')
+        else:
+            assert roundtripped_df['col0'].dtype == np.dtype('object')

-    @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()])
-    def test_sample_low_incorrect_undetected(self, duckdb_cursor, pandas):
-        duckdb_conn = duckdb.connect()
-        duckdb_conn.execute("SET GLOBAL pandas_analyze_sample=1")
-        data = [1000008, 6, 9, [4], [1], 6]
-        df = create_generic_dataframe(data, pandas)
-        # Sample size is too low to detect the mismatch, exception is raised when trying to convert
-        with pytest.raises(duckdb.InvalidInputException, match="Failed to cast value: Unimplemented type for cast"):
-            roundtripped_df = duckdb.query_df(df, "x", "select * from x", connection=duckdb_conn).df()
\ No newline at end of file
+    @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()])
+    def test_sample_low_incorrect_undetected(self, duckdb_cursor, pandas):
+        duckdb_conn = duckdb.connect()
+        duckdb_conn.execute("SET GLOBAL pandas_analyze_sample=1")
+        data = [1000008, 6, 9, [4], [1], 6]
+        df = create_generic_dataframe(data, pandas)
+        # Sample size is too low to detect the mismatch, exception is raised when trying to convert
+        with pytest.raises(duckdb.InvalidInputException, match="Failed to cast value: Unimplemented type for cast"):
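# A hedged aside on the sampling stride these tests rely on (inferred from the
# comments above, not from the scanner's C++ source): the stride is the list
# size divided by pandas_analyze_sample, so only every stride-th value is
# inspected during type analysis:
#
#     data = [1000008, 6, 9, [4], 1, 6]
#     data[:: len(data) // 2]  # sample=2 -> stride 3 -> [1000008, [4]], mismatch detected
#     data[:: len(data) // 1]  # sample=1 -> stride 6 -> [1000008], the nested lists
#                              # are never sampled, so the cast only fails later
+            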
roundtripped_df = duckdb.query_df(df, "x", "select * from x", connection=duckdb_conn).df() diff --git a/tools/pythonpkg/tests/fast/pandas/test_df_object_resolution.py b/tools/pythonpkg/tests/fast/pandas/test_df_object_resolution.py index 367bae598163..d42cddc5bf28 100644 --- a/tools/pythonpkg/tests/fast/pandas/test_df_object_resolution.py +++ b/tools/pythonpkg/tests/fast/pandas/test_df_object_resolution.py @@ -10,15 +10,19 @@ standard_vector_size = duckdb.__standard_vector_size__ + def create_generic_dataframe(data, pandas): return pandas.DataFrame({'0': pandas.Series(data=data, dtype='object')}) + class IntString: def __init__(self, value: int): self.value = value + def __str__(self): return str(self.value) + # To avoid DECIMAL being upgraded to DOUBLE (because DOUBLE outranks DECIMAL as a LogicalType) # These floats had their precision preserved as string and are now cast to decimal.Decimal def ConvertStringToDecimal(data: list, pandas): @@ -28,8 +32,8 @@ def ConvertStringToDecimal(data: list, pandas): data = pandas.Series(data=data, dtype='object') return data -class TestResolveObjectColumns(object): +class TestResolveObjectColumns(object): # TODO: add support for ArrowPandas @pytest.mark.parametrize('pandas', [NumpyPandas()]) def test_integers(self, pandas): @@ -55,9 +59,9 @@ def test_map_fallback_different_keys(self, pandas): [ [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], - [{'a': 1, 'b': 3, 'c': 3, 'e': 7}], #'e' instead of 'd' as key + [{'a': 1, 'b': 3, 'c': 3, 'e': 7}], #'e' instead of 'd' as key + [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], - [{'a': 1, 'b': 3, 'c': 3, 'd': 7}] ] ) @@ -80,9 +84,9 @@ def test_map_fallback_incorrect_amount_of_keys(self, pandas): [ [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], - [{'a': 1, 'b': 3, 'c': 3}], #incorrect amount of keys + [{'a': 1, 'b': 3, 'c': 3}], # incorrect amount of keys + [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], - [{'a': 1, 'b': 3, 'c': 3, 'd': 7}] ] ) converted_df = duckdb.query_df(x, "x", "SELECT * FROM x").df() @@ -106,7 +110,7 @@ def test_struct_value_upgrade(self, pandas): [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], - [{'a': 1, 'b': 3, 'c': 3, 'd': 7}] + [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], ] ) y = pandas.DataFrame( @@ -115,7 +119,7 @@ def test_struct_value_upgrade(self, pandas): [{'a': 1, 'b': 3, 'c': 3, 'd': '7'}], [{'a': 1, 'b': 3, 'c': 3, 'd': '7'}], [{'a': 1, 'b': 3, 'c': 3, 'd': '7'}], - [{'a': 1, 'b': 3, 'c': 3, 'd': '7'}] + [{'a': 1, 'b': 3, 'c': 3, 'd': '7'}], ] ) converted_df = duckdb.query_df(x, "x", "SELECT * FROM x").df() @@ -130,7 +134,7 @@ def test_struct_null(self, pandas): [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], - [{'a': 1, 'b': 3, 'c': 3, 'd': 7}] + [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], ] ) y = pandas.DataFrame( @@ -139,7 +143,7 @@ def test_struct_null(self, pandas): [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], - [{'a': 1, 'b': 3, 'c': 3, 'd': 7}] + [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], ] ) converted_df = duckdb.query_df(x, "x", "SELECT * FROM x").df() @@ -154,7 +158,7 @@ def test_map_fallback_value_upgrade(self, pandas): [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], [{'a': 1, 'b': 3, 'c': 3}], [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], - [{'a': 1, 'b': 3, 'c': 3, 'd': 7}] + [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], ] ) y = 
pandas.DataFrame( @@ -163,7 +167,7 @@ def test_map_fallback_value_upgrade(self, pandas): [{'a': '1', 'b': '3', 'c': '3', 'd': '7'}], [{'a': '1', 'b': '3', 'c': '3'}], [{'a': '1', 'b': '3', 'c': '3', 'd': '7'}], - [{'a': '1', 'b': '3', 'c': '3', 'd': '7'}] + [{'a': '1', 'b': '3', 'c': '3', 'd': '7'}], ] ) converted_df = duckdb.query_df(x, "df", "SELECT * FROM df").df() @@ -179,20 +183,24 @@ def test_map_correct(self, pandas): [{'key': ['a', 'b', 'c', 'd'], 'value': [1, 3, 3, 7]}], [{'key': ['a', 'b', 'c', 'd'], 'value': [1, 3, 3, 7]}], [{'key': ['a', 'b', 'c', 'd'], 'value': [1, 3, 3, 7]}], - [{'key': ['a', 'b', 'c', 'd'], 'value': [1, 3, 3, 7]}] + [{'key': ['a', 'b', 'c', 'd'], 'value': [1, 3, 3, 7]}], ] ) - x.rename(columns = {0 : 'a'}, inplace = True) + x.rename(columns={0: 'a'}, inplace=True) converted_col = duckdb.query_df(x, "x", "select * from x as 'a'", connection=con).df() - con.query(""" + con.query( + """ CREATE TABLE tmp( a MAP(VARCHAR, INTEGER) ); - """) + """ + ) for _ in range(5): - con.query(""" + con.query( + """ INSERT INTO tmp VALUES (MAP(['a', 'b', 'c', 'd'], [1, 3, 3, 7])) - """) + """ + ) duckdb_col = con.query("select a from tmp AS '0'").df() print(duckdb_col.columns) print(converted_col.columns) @@ -207,23 +215,29 @@ def test_map_value_upgrade(self, pandas): [{'key': ['a', 'b', 'c', 'd'], 'value': [1, 3, 3, 7]}], [{'key': ['a', 'b', 'c', 'd'], 'value': [1, 3, 3, 7]}], [{'key': ['a', 'b', 'c', 'd'], 'value': [1, 3, 3, 7]}], - [{'key': ['a', 'b', 'c', 'd'], 'value': [1, 3, 3, 7]}] + [{'key': ['a', 'b', 'c', 'd'], 'value': [1, 3, 3, 7]}], ] ) - x.rename(columns = {0 : 'a'}, inplace = True) + x.rename(columns={0: 'a'}, inplace=True) converted_col = duckdb.query_df(x, "x", "select * from x", connection=con).df() - con.query(""" + con.query( + """ CREATE TABLE tmp2( a MAP(VARCHAR, VARCHAR) ); - """) - con.query(""" + """ + ) + con.query( + """ INSERT INTO tmp2 VALUES (MAP(['a', 'b', 'c', 'd'], ['1', '3', '3', 'test'])) - """) + """ + ) for _ in range(4): - con.query(""" + con.query( + """ INSERT INTO tmp2 VALUES (MAP(['a', 'b', 'c', 'd'], ['1', '3', '3', '7'])) - """) + """ + ) duckdb_col = con.query("select a from tmp2 AS '0'").df() print(duckdb_col.columns) print(converted_col.columns) @@ -231,31 +245,23 @@ def test_map_value_upgrade(self, pandas): @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_map_duplicate(self, pandas): - x = pandas.DataFrame( - [ - [{'key': ['a', 'a', 'b'], 'value': [4, 0, 4]}] - ] - ) - with pytest.raises(duckdb.InvalidInputException, match="Dict->Map conversion failed because 'key' list contains duplicates"): + x = pandas.DataFrame([[{'key': ['a', 'a', 'b'], 'value': [4, 0, 4]}]]) + with pytest.raises( + duckdb.InvalidInputException, match="Dict->Map conversion failed because 'key' list contains duplicates" + ): converted_col = duckdb.query_df(x, "x", "select * from x").df() @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_map_nullkey(self, pandas): - x = pandas.DataFrame( - [ - [{'key': [None, 'a', 'b'], 'value': [4, 0, 4]}] - ] - ) - with pytest.raises(duckdb.InvalidInputException, match="Dict->Map conversion failed because 'key' list contains None"): + x = pandas.DataFrame([[{'key': [None, 'a', 'b'], 'value': [4, 0, 4]}]]) + with pytest.raises( + duckdb.InvalidInputException, match="Dict->Map conversion failed because 'key' list contains None" + ): converted_col = duckdb.query_df(x, "x", "select * from x").df() @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def 
test_map_nullkeylist(self, pandas): - x = pandas.DataFrame( - [ - [{'key': None, 'value': None}] - ] - ) + x = pandas.DataFrame([[{'key': None, 'value': None}]]) # Isn't actually converted to MAP because isinstance(None, list) != True converted_col = duckdb.query_df(x, "x", "select * from x").df() duckdb_col = duckdb.query("SELECT {key: NULL, value: NULL} as '0'").df() @@ -263,13 +269,10 @@ def test_map_nullkeylist(self, pandas): @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_map_fallback_nullkey(self, pandas): - x = pandas.DataFrame( - [ - [{'a': 4, None: 0, 'c': 4}], - [{'a': 4, None: 0, 'd': 4}] - ] - ) - with pytest.raises(duckdb.InvalidInputException, match="Dict->Map conversion failed because 'key' list contains None"): + x = pandas.DataFrame([[{'a': 4, None: 0, 'c': 4}], [{'a': 4, None: 0, 'd': 4}]]) + with pytest.raises( + duckdb.InvalidInputException, match="Dict->Map conversion failed because 'key' list contains None" + ): converted_col = duckdb.query_df(x, "x", "select * from x").df() @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) @@ -280,19 +283,16 @@ def test_map_fallback_nullkey_coverage(self, pandas): [{'key': None, None: 5}], ] ) - with pytest.raises(duckdb.InvalidInputException, match="Dict->Map conversion failed because 'key' list contains None"): + with pytest.raises( + duckdb.InvalidInputException, match="Dict->Map conversion failed because 'key' list contains None" + ): converted_col = duckdb.query_df(x, "x", "select * from x").df() @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_struct_key_conversion(self, pandas): x = pandas.DataFrame( [ - [{ - IntString(5) : 1, - IntString(-25): 3, - IntString(32): 3, - IntString(32456): 7 - }], + [{IntString(5): 1, IntString(-25): 3, IntString(32): 3, IntString(32456): 7}], ] ) duckdb_col = duckdb.query("select {'5':1, '-25':3, '32':3, '32456':7} as '0'").df() @@ -302,11 +302,7 @@ def test_struct_key_conversion(self, pandas): @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_list_correct(self, pandas): - x = pandas.DataFrame( - [ - {'0': [[5], [34], [-245]]} - ] - ) + x = pandas.DataFrame([{'0': [[5], [34], [-245]]}]) duckdb_col = duckdb.query("select [[5], [34], [-245]] as '0'").df() converted_col = duckdb.query_df(x, "tbl", "select * from tbl").df() duckdb.query("drop view if exists tbl") @@ -314,11 +310,7 @@ def test_list_correct(self, pandas): @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_list_contains_null(self, pandas): - x = pandas.DataFrame( - [ - {'0': [[5], None, [-245]]} - ] - ) + x = pandas.DataFrame([{'0': [[5], None, [-245]]}]) duckdb_col = duckdb.query("select [[5], NULL, [-245]] as '0'").df() converted_col = duckdb.query_df(x, "tbl", "select * from tbl").df() duckdb.query("drop view if exists tbl") @@ -326,11 +318,7 @@ def test_list_contains_null(self, pandas): @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_list_starts_with_null(self, pandas): - x = pandas.DataFrame( - [ - {'0': [None, [5], [-245]]} - ] - ) + x = pandas.DataFrame([{'0': [None, [5], [-245]]}]) duckdb_col = duckdb.query("select [NULL, [5], [-245]] as '0'").df() converted_col = duckdb.query_df(x, "tbl", "select * from tbl").df() duckdb.query("drop view if exists tbl") @@ -338,11 +326,7 @@ def test_list_starts_with_null(self, pandas): @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_list_value_upgrade(self, pandas): - x = pandas.DataFrame( - [ - {'0': [['5'], [34], 
[-245]]} - ] - ) + x = pandas.DataFrame([{'0': [['5'], [34], [-245]]}]) duckdb_col = duckdb.query("select [['5'], ['34'], ['-245']] as '0'").df() converted_col = duckdb.query_df(x, "tbl", "select * from tbl").df() duckdb.query("drop view if exists tbl") @@ -353,27 +337,35 @@ def test_list_column_value_upgrade(self, pandas): con = duckdb.connect() x = pandas.DataFrame( [ - [ [1, 25, 300] ], - [ [500, 345, 30] ], - [ [50, 'a', 67] ], + [[1, 25, 300]], + [[500, 345, 30]], + [[50, 'a', 67]], ] ) - x.rename(columns = {0 : 'a'}, inplace = True) + x.rename(columns={0: 'a'}, inplace=True) converted_col = duckdb.query_df(x, "x", "select * from x", connection=con).df() - con.query(""" + con.query( + """ CREATE TABLE tmp3( a VARCHAR[] ); - """) - con.query(""" + """ + ) + con.query( + """ INSERT INTO tmp3 VALUES (['1', '25', '300']) - """) - con.query(""" + """ + ) + con.query( + """ INSERT INTO tmp3 VALUES (['500', '345', '30']) - """) - con.query(""" + """ + ) + con.query( + """ INSERT INTO tmp3 VALUES (['50', 'a', '67']) - """) + """ + ) duckdb_col = con.query("select a from tmp3 AS '0'").df() print(duckdb_col.columns) print(converted_col.columns) @@ -426,10 +418,9 @@ def test_numpy_object_with_stride(self, pandas): (6, 12, 0), (7, 14, 0), (8, 16, 0), - (9, 18, 0) + (9, 18, 0), ] - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_numpy_stringliterals(self, pandas): con = duckdb.connect() @@ -449,8 +440,8 @@ def test_integer_conversion_fail(self, pandas): # Most of the time numpy.datetime64 is just a wrapper around a datetime.datetime object # But to support arbitrary precision, it can fall back to using an `int` internally - - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()])# Which we don't support yet + + @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) # Which we don't support yet def test_numpy_datetime(self, pandas): numpy = pytest.importorskip("numpy") @@ -461,7 +452,7 @@ def test_numpy_datetime(self, pandas): data += [numpy.datetime64('2049-01-13T00:24:31.999999')] * standard_vector_size x = pandas.DataFrame({'dates': pandas.Series(data=data, dtype='object')}) res = duckdb.query_df(x, "x", "select distinct * from x").df() - assert(len(res['dates'].__array__()) == 4) + assert len(res['dates'].__array__()) == 4 @pytest.mark.parametrize('pandas', [NumpyPandas()]) def test_numpy_datetime_int_internally(self, pandas): @@ -469,7 +460,10 @@ def test_numpy_datetime_int_internally(self, pandas): data = [numpy.datetime64('2022-12-10T21:38:24.0000000000001')] x = pandas.DataFrame({'dates': pandas.Series(data=data, dtype='object')}) - with pytest.raises(duckdb.ConversionException, match=re.escape("Conversion Error: Unimplemented type for cast (BIGINT -> TIMESTAMP)")): + with pytest.raises( + duckdb.ConversionException, + match=re.escape("Conversion Error: Unimplemented type for cast (BIGINT -> TIMESTAMP)"), + ): rel = duckdb.query_df(x, "x", "create table dates as select dates::TIMESTAMP WITHOUT TIME ZONE from x") @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) @@ -482,7 +476,7 @@ def test_fallthrough_object_conversion(self, pandas): ] ) duckdb_col = duckdb.query_df(x, "x", "select * from x").df() - df_expected_res = pandas.DataFrame({'0': pandas.Series(['4','2','0'])}) + df_expected_res = pandas.DataFrame({'0': pandas.Series(['4', '2', '0'])}) pandas.testing.assert_frame_equal(duckdb_col, df_expected_res) @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) @@ -503,38 +497,46 @@ def 
test_numeric_decimal(self, pandas): """ duckdb_conn.execute(reference_query) # Because of this we need to wrap these native floats as DECIMAL for this test, to avoid these decimals being "upgraded" to DOUBLE - x = pandas.DataFrame({ - '0': ConvertStringToDecimal([5, '12.0', '-123.0', '-234234.0', None, '1.234'], pandas), - '1': ConvertStringToDecimal([5002340, 13, '-12.0000000005', '7453324234.0', None, '-324234234'], pandas), - '2': ConvertStringToDecimal(['-234234234234.0', '324234234.00000005', -128, 345345, '1E5', '1324234359'], pandas) - }) + x = pandas.DataFrame( + { + '0': ConvertStringToDecimal([5, '12.0', '-123.0', '-234234.0', None, '1.234'], pandas), + '1': ConvertStringToDecimal( + [5002340, 13, '-12.0000000005', '7453324234.0', None, '-324234234'], pandas + ), + '2': ConvertStringToDecimal( + ['-234234234234.0', '324234234.00000005', -128, 345345, '1E5', '1324234359'], pandas + ), + } + ) reference = duckdb.query("select * from tbl", connection=duckdb_conn).fetchall() conversion = duckdb.query_df(x, "x", "select * from x").fetchall() - assert(conversion == reference) + assert conversion == reference @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_numeric_decimal_coverage(self, pandas): duckdb_conn = duckdb.connect() - x = pandas.DataFrame({ - '0': [Decimal("nan"), Decimal("+nan"), Decimal("-nan"), Decimal("inf"), Decimal("+inf"), Decimal("-inf")] - }) + x = pandas.DataFrame( + {'0': [Decimal("nan"), Decimal("+nan"), Decimal("-nan"), Decimal("inf"), Decimal("+inf"), Decimal("-inf")]} + ) conversion = duckdb.query_df(x, "x", "select * from x").fetchall() print(conversion[0][0].__class__) for item in conversion: - assert(isinstance(item[0], float)) - assert(math.isnan(conversion[0][0])) - assert(math.isnan(conversion[1][0])) - assert(math.isnan(conversion[2][0])) - assert(math.isinf(conversion[3][0])) - assert(math.isinf(conversion[4][0])) - assert(math.isinf(conversion[5][0])) - assert(str(conversion) == '[(nan,), (nan,), (nan,), (inf,), (inf,), (inf,)]') + assert isinstance(item[0], float) + assert math.isnan(conversion[0][0]) + assert math.isnan(conversion[1][0]) + assert math.isnan(conversion[2][0]) + assert math.isinf(conversion[3][0]) + assert math.isinf(conversion[4][0]) + assert math.isinf(conversion[5][0]) + assert str(conversion) == '[(nan,), (nan,), (nan,), (inf,), (inf,), (inf,)]' # Test that the column 'offset' is actually used when converting, - - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()])# and that the same 2048 (STANDARD_VECTOR_SIZE) values are not being scanned over and over again + + @pytest.mark.parametrize( + 'pandas', [NumpyPandas(), ArrowPandas()] + ) # and that the same 2048 (STANDARD_VECTOR_SIZE) values are not being scanned over and over again def test_multiple_chunks(self, pandas): data = [] data += [datetime.date(2022, 9, 13) for x in range(standard_vector_size)] @@ -543,131 +545,103 @@ def test_multiple_chunks(self, pandas): data += [datetime.date(2022, 9, 16) for x in range(standard_vector_size)] x = pandas.DataFrame({'dates': pandas.Series(data=data, dtype='object')}) res = duckdb.query_df(x, "x", "select distinct * from x").df() - assert(len(res['dates'].__array__()) == 4) + assert len(res['dates'].__array__()) == 4 @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_multiple_chunks_aggregate(self, pandas): conn = duckdb.connect() - conn.execute("create table dates as select '2022-09-14'::DATE + INTERVAL (i::INTEGER) DAY as i from range(0, 4096) tbl(i);") + conn.execute( + 
"create table dates as select '2022-09-14'::DATE + INTERVAL (i::INTEGER) DAY as i from range(0, 4096) tbl(i);" + ) res = duckdb.query("select * from dates", connection=conn).df() date_df = res.copy() # Convert the values to `datetime.date` values, and the dtype of the column to 'object' date_df['i'] = pandas.to_datetime(res['i']).dt.date - assert(str(date_df['i'].dtype) == 'object') - expected_res = duckdb.query('select avg(epoch(i)), min(epoch(i)), max(epoch(i)) from dates;', connection=conn).fetchall() - actual_res = duckdb.query_df(date_df, 'x', 'select avg(epoch(i)), min(epoch(i)), max(epoch(i)) from x').fetchall() - assert(expected_res == actual_res) + assert str(date_df['i'].dtype) == 'object' + expected_res = duckdb.query( + 'select avg(epoch(i)), min(epoch(i)), max(epoch(i)) from dates;', connection=conn + ).fetchall() + actual_res = duckdb.query_df( + date_df, 'x', 'select avg(epoch(i)), min(epoch(i)), max(epoch(i)) from x' + ).fetchall() + assert expected_res == actual_res conn.execute('drop table dates') # Now with nulls interleaved for i in range(0, len(res['i']), 2): res['i'][i] = None - date_view = conn.register("date_view", res) date_view.execute('create table dates as select * from date_view') - expected_res = duckdb.query("select avg(epoch(i)), min(epoch(i)), max(epoch(i)) from dates", connection=conn).fetchall() + expected_res = duckdb.query( + "select avg(epoch(i)), min(epoch(i)), max(epoch(i)) from dates", connection=conn + ).fetchall() date_df = res.copy() # Convert the values to `datetime.date` values, and the dtype of the column to 'object' date_df['i'] = pandas.to_datetime(res['i']).dt.date - assert(str(date_df['i'].dtype) == 'object') - actual_res = duckdb.query_df(date_df, 'x', 'select avg(epoch(i)), min(epoch(i)), max(epoch(i)) from x').fetchall() - assert(expected_res == actual_res) + assert str(date_df['i'].dtype) == 'object' + actual_res = duckdb.query_df( + date_df, 'x', 'select avg(epoch(i)), min(epoch(i)), max(epoch(i)) from x' + ).fetchall() + assert expected_res == actual_res @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_mixed_object_types(self, pandas): - x = pandas.DataFrame({ - 'nested': pandas.Series(data=[{'a': 1, 'b': 2}, [5, 4, 3], {'key': [1,2,3], 'value': ['a', 'b', 'c']}], dtype='object'), - }) + x = pandas.DataFrame( + { + 'nested': pandas.Series( + data=[{'a': 1, 'b': 2}, [5, 4, 3], {'key': [1, 2, 3], 'value': ['a', 'b', 'c']}], dtype='object' + ), + } + ) res = duckdb.query_df(x, "x", "select * from x").df() - assert(res['nested'].dtype == np.dtype('object')) - + assert res['nested'].dtype == np.dtype('object') @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_struct_deeply_nested_in_struct(self, pandas): - x = pandas.DataFrame([ - { - # STRUCT(b STRUCT(x VARCHAR, y VARCHAR)) - 'a': { - 'b': { - 'x': 'A', - 'y': 'B' - } - } - }, - { - # STRUCT(b STRUCT(x VARCHAR)) - 'a': { - 'b': { - 'x': 'A' - } - } - } - ]) - # The dataframe has incompatible struct schemas in the nested child - # This gets upgraded to STRUCT(b MAP(VARCHAR, VARCHAR)) - con = duckdb.connect() - res = con.sql("select * from x").fetchall() - assert res == [ - ( + x = pandas.DataFrame( + [ { - 'b': { - 'key': ['x', 'y'], - 'value': ['A', 'B'] - } + # STRUCT(b STRUCT(x VARCHAR, y VARCHAR)) + 'a': {'b': {'x': 'A', 'y': 'B'}} }, - ), - ( { - 'b': { - 'key': ['x'], - 'value': ['A'] - } + # STRUCT(b STRUCT(x VARCHAR)) + 'a': {'b': {'x': 'A'}} }, - ) - ] + ] + ) + # The dataframe has incompatible struct schemas in the nested 
child + # This gets upgraded to STRUCT(b MAP(VARCHAR, VARCHAR)) + con = duckdb.connect() + res = con.sql("select * from x").fetchall() + assert res == [({'b': {'key': ['x', 'y'], 'value': ['A', 'B']}},), ({'b': {'key': ['x'], 'value': ['A']}},)] @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_struct_deeply_nested_in_list(self, pandas): - x = pandas.DataFrame({'a': [ - [ - # STRUCT(x VARCHAR, y VARCHAR)[] - { - 'x': 'A', - 'y': 'B' - }, - # STRUCT(x VARCHAR)[] - { - 'x': 'A' - } - ] - ]}) + x = pandas.DataFrame( + { + 'a': [ + [ + # STRUCT(x VARCHAR, y VARCHAR)[] + {'x': 'A', 'y': 'B'}, + # STRUCT(x VARCHAR)[] + {'x': 'A'}, + ] + ] + } + ) # The dataframe has incompatible struct schemas in the nested child # This gets upgraded to STRUCT(b MAP(VARCHAR, VARCHAR)) con = duckdb.connect() res = con.sql("select * from x").fetchall() - assert res == [ - ( - [ - { - 'key': ['x', 'y'], - 'value': ['A', 'B'] - }, - { - 'key': ['x'], - 'value': ['A'] - } - ], - ) - ] + assert res == [([{'key': ['x', 'y'], 'value': ['A', 'B']}, {'key': ['x'], 'value': ['A']}],)] @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_analyze_sample_too_small(self, pandas): - data = [1 for _ in range(9)] + [[1,2,3]] + [1 for _ in range(9991)] - x = pandas.DataFrame({ - 'a': pandas.Series(data=data) - }) + data = [1 for _ in range(9)] + [[1, 2, 3]] + [1 for _ in range(9991)] + x = pandas.DataFrame({'a': pandas.Series(data=data)}) with pytest.raises(duckdb.InvalidInputException, match="Failed to cast value: Unimplemented type for cast"): res = duckdb.query_df(x, "x", "select * from x").df() @@ -703,7 +677,7 @@ def test_numeric_decimal_zero_fractional(self, pandas): reference = duckdb.query("select * from tbl", connection=duckdb_conn).fetchall() conversion = duckdb.query_df(decimals, "x", "select * from x").fetchall() - assert(conversion == reference) + assert conversion == reference @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_numeric_decimal_incompatible(self, pandas): @@ -720,29 +694,29 @@ def test_numeric_decimal_incompatible(self, pandas): ) tbl(a, b, c); """ duckdb_conn.execute(reference_query) - x = pandas.DataFrame({ - '0': ConvertStringToDecimal(['5', '12.0', '-123.0', '-234234.0', None, '1.234'], pandas), - '1': ConvertStringToDecimal([5002340, 13, '-12.0000000005', 7453324234, None, '-324234234'], pandas), - '2': ConvertStringToDecimal([-234234234234, '324234234.00000005', -128, 345345, 0, '1324234359'], pandas) - }) + x = pandas.DataFrame( + { + '0': ConvertStringToDecimal(['5', '12.0', '-123.0', '-234234.0', None, '1.234'], pandas), + '1': ConvertStringToDecimal([5002340, 13, '-12.0000000005', 7453324234, None, '-324234234'], pandas), + '2': ConvertStringToDecimal( + [-234234234234, '324234234.00000005', -128, 345345, 0, '1324234359'], pandas + ), + } + ) reference = duckdb.query("select * from tbl", connection=duckdb_conn).fetchall() conversion = duckdb.query_df(x, "x", "select * from x").fetchall() - assert(conversion == reference) + assert conversion == reference print(reference) print(conversion) - - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()])#result: [('1E-28',), ('10000000000000000000000000.0',)] + @pytest.mark.parametrize( + 'pandas', [NumpyPandas(), ArrowPandas()] + ) # result: [('1E-28',), ('10000000000000000000000000.0',)] def test_numeric_decimal_combined(self, pandas): duckdb_conn = duckdb.connect() decimals = pandas.DataFrame( - data={ - "0": [ - Decimal("0.0000000000000000000000000001"), - 
Decimal("10000000000000000000000000.0") - ] - } + data={"0": [Decimal("0.0000000000000000000000000001"), Decimal("10000000000000000000000000.0")]} ) reference_query = """ CREATE TABLE tbl AS SELECT * FROM ( @@ -754,11 +728,11 @@ def test_numeric_decimal_combined(self, pandas): duckdb_conn.execute(reference_query) reference = duckdb.query("select * from tbl", connection=duckdb_conn).fetchall() conversion = duckdb.query_df(decimals, "x", "select * from x").fetchall() - assert(conversion == reference) + assert conversion == reference print(reference) print(conversion) - #result: [('1234.0',), ('123456789.0',), ('1234567890123456789.0',), ('0.1234567890123456789',)] + # result: [('1234.0',), ('123456789.0',), ('1234567890123456789.0',), ('0.1234567890123456789',)] @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_numeric_decimal_varying_sizes(self, pandas): duckdb_conn = duckdb.connect() @@ -768,7 +742,7 @@ def test_numeric_decimal_varying_sizes(self, pandas): Decimal("1234.0"), Decimal("123456789.0"), Decimal("1234567890123456789.0"), - Decimal("0.1234567890123456789") + Decimal("0.1234567890123456789"), ] } ) @@ -784,7 +758,7 @@ def test_numeric_decimal_varying_sizes(self, pandas): duckdb_conn.execute(reference_query) reference = duckdb.query("select * from tbl", connection=duckdb_conn).fetchall() conversion = duckdb.query_df(decimals, "x", "select * from x").fetchall() - assert(conversion == reference) + assert conversion == reference print(reference) print(conversion) @@ -792,12 +766,11 @@ def test_numeric_decimal_varying_sizes(self, pandas): def test_numeric_decimal_fallback_to_double(self, pandas): duckdb_conn = duckdb.connect() # The widths of these decimal values are bigger than the max supported width for DECIMAL - data = [Decimal("1.234567890123456789012345678901234567890123456789"), Decimal("123456789012345678901234567890123456789012345678.0")] - decimals = pandas.DataFrame( - data={ - "0": data - } - ) + data = [ + Decimal("1.234567890123456789012345678901234567890123456789"), + Decimal("123456789012345678901234567890123456789012345678.0"), + ] + decimals = pandas.DataFrame(data={"0": data}) reference_query = """ CREATE TABLE tbl AS SELECT * FROM ( VALUES @@ -808,8 +781,8 @@ def test_numeric_decimal_fallback_to_double(self, pandas): duckdb_conn.execute(reference_query) reference = duckdb.query("select * from tbl", connection=duckdb_conn).fetchall() conversion = duckdb.query_df(decimals, "x", "select * from x").fetchall() - assert(conversion == reference) - assert(isinstance(conversion[0][0], float)) + assert conversion == reference + assert isinstance(conversion[0][0], float) @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_numeric_decimal_double_mixed(self, pandas): @@ -822,13 +795,9 @@ def test_numeric_decimal_double_mixed(self, pandas): Decimal("1234543534535213412342342.2345456"), Decimal("123456789123456789123456789123456789123456789123456789123456789123456789"), Decimal("1232354.000000000000000000000000000035"), - Decimal("123.5e300") + Decimal("123.5e300"), ] - decimals = pandas.DataFrame( - data={ - "0": data - } - ) + decimals = pandas.DataFrame(data={"0": data}) reference_query = """ CREATE TABLE tbl AS SELECT * FROM ( VALUES @@ -845,18 +814,14 @@ def test_numeric_decimal_double_mixed(self, pandas): duckdb_conn.execute(reference_query) reference = duckdb.query("select * from tbl", connection=duckdb_conn).fetchall() conversion = duckdb.query_df(decimals, "x", "select * from x").fetchall() - assert(conversion == 
reference) - assert(isinstance(conversion[0][0], float)) + assert conversion == reference + assert isinstance(conversion[0][0], float) @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_numeric_decimal_out_of_range(self, pandas): duckdb_conn = duckdb.connect() data = [Decimal("1.234567890123456789012345678901234567"), Decimal("123456789012345678901234567890123456.0")] - decimals = pandas.DataFrame( - data={ - "0": data - } - ) + decimals = pandas.DataFrame(data={"0": data}) reference_query = """ CREATE TABLE tbl AS SELECT * FROM ( VALUES @@ -867,5 +832,4 @@ def test_numeric_decimal_out_of_range(self, pandas): duckdb_conn.execute(reference_query) reference = duckdb.query("select * from tbl", connection=duckdb_conn).fetchall() conversion = duckdb.query_df(decimals, "x", "select * from x").fetchall() - assert(conversion == reference) - + assert conversion == reference diff --git a/tools/pythonpkg/tests/fast/pandas/test_df_recursive_nested.py b/tools/pythonpkg/tests/fast/pandas/test_df_recursive_nested.py index 92c1a390a537..264a4db74e58 100644 --- a/tools/pythonpkg/tests/fast/pandas/test_df_recursive_nested.py +++ b/tools/pythonpkg/tests/fast/pandas/test_df_recursive_nested.py @@ -7,246 +7,111 @@ NULL = None + def check_equal(df, reference_query): - duckdb_conn = duckdb.connect() - duckdb_conn.execute(reference_query) - res = duckdb.query('SELECT * FROM tbl', connection=duckdb_conn).fetchall() - df_res = duckdb.query('SELECT * FROM tbl', connection=duckdb_conn).df() - out = duckdb.query_df(df, 'x', "SELECT * FROM x").fetchall() - assert(res == out) + duckdb_conn = duckdb.connect() + duckdb_conn.execute(reference_query) + res = duckdb.query('SELECT * FROM tbl', connection=duckdb_conn).fetchall() + df_res = duckdb.query('SELECT * FROM tbl', connection=duckdb_conn).df() + out = duckdb.query_df(df, 'x', "SELECT * FROM x").fetchall() + assert res == out + def create_reference_query(data): - query = "CREATE TABLE tbl AS SELECT " + str(data).replace("None", "NULL") - return query + query = "CREATE TABLE tbl AS SELECT " + str(data).replace("None", "NULL") + return query + class TestDFRecursiveNested(object): - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) - def test_list_of_structs(self, duckdb_cursor, pandas): - data = [ - [ - {'a': 5}, - NULL, - {'a': NULL} - ], - NULL, - [ - {'b': 5}, - NULL, - {'b': NULL} - ] - ] - reference_query = create_reference_query(data) - df = pandas.DataFrame([ - { 'a': data} - ]) - check_equal(df, reference_query) + @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + def test_list_of_structs(self, duckdb_cursor, pandas): + data = [[{'a': 5}, NULL, {'a': NULL}], NULL, [{'b': 5}, NULL, {'b': NULL}]] + reference_query = create_reference_query(data) + df = pandas.DataFrame([{'a': data}]) + check_equal(df, reference_query) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) - def test_list_of_map(self, duckdb_cursor, pandas): - # LIST(MAP(VARCHAR, VARCHAR)) - data = [ - [ - {'key': [5], 'value': [NULL]}, - NULL, - {'key': [], 'value': []} - ], - NULL, - [ - NULL, - {'key': [3,2,4], 'value': [NULL, 'a', NULL]}, - {'key': ['a', 'b', 'c'], 'value': [1, 2, 3]} - ] - ] - reference_query = create_reference_query(data) - df = pandas.DataFrame([ - { 'a': data} - ]) - check_equal(df, reference_query) + @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + def test_list_of_map(self, duckdb_cursor, pandas): + # LIST(MAP(VARCHAR, VARCHAR)) + data = [ + [{'key': [5], 'value': [NULL]}, NULL, 
{'key': [], 'value': []}], + NULL, + [NULL, {'key': [3, 2, 4], 'value': [NULL, 'a', NULL]}, {'key': ['a', 'b', 'c'], 'value': [1, 2, 3]}], + ] + reference_query = create_reference_query(data) + df = pandas.DataFrame([{'a': data}]) + check_equal(df, reference_query) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) - def test_recursive_list(self, duckdb_cursor, pandas): - # LIST(LIST(LIST(LIST(INTEGER)))) - data = [ - [ - [ - [ - 3, - NULL, - 5 - ], - NULL - ], - NULL, - [ - [ - 5, - -20, - NULL - ] - ] - ], - NULL, - [ - [ - [ - NULL - ] - ], - [ - [] - ], - NULL - ] - ] - reference_query = create_reference_query(data) - df = pandas.DataFrame([ - { 'a': data} - ]) - check_equal(df, reference_query) + @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + def test_recursive_list(self, duckdb_cursor, pandas): + # LIST(LIST(LIST(LIST(INTEGER)))) + data = [[[[3, NULL, 5], NULL], NULL, [[5, -20, NULL]]], NULL, [[[NULL]], [[]], NULL]] + reference_query = create_reference_query(data) + df = pandas.DataFrame([{'a': data}]) + check_equal(df, reference_query) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) - def test_recursive_struct(self, duckdb_cursor, pandas): - #STRUCT(STRUCT(STRUCT(LIST))) - data = { - 'A': { - 'a': { - '1': [ - 1,2,3 - ] - }, - 'b': NULL, - 'c': { - '1': NULL - } - }, - 'B': { - 'a': { - '1': [ - 1, NULL, 3 - ] - }, - 'b': NULL, - 'c': { - '1': NULL - } - } - } - reference_query = create_reference_query(data) - df = pandas.DataFrame([ - { 'a': data} - ]) - check_equal(df, reference_query) + @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + def test_recursive_struct(self, duckdb_cursor, pandas): + # STRUCT(STRUCT(STRUCT(LIST))) + data = { + 'A': {'a': {'1': [1, 2, 3]}, 'b': NULL, 'c': {'1': NULL}}, + 'B': {'a': {'1': [1, NULL, 3]}, 'b': NULL, 'c': {'1': NULL}}, + } + reference_query = create_reference_query(data) + df = pandas.DataFrame([{'a': data}]) + check_equal(df, reference_query) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) - def test_recursive_map(self, duckdb_cursor, pandas): - #MAP( - # MAP( - # INTEGER, - # MAP(INTEGER) - # ), - # INTEGER - #) - data = { - 'key': [ - { - 'key': [ - 5, - 6, - 7 - ], - 'value': [ - { - 'key': [ - 8 - ], - 'value': [ - NULL - ] - }, - NULL, - { - 'key': [ - 9 - ], - 'value': [ - 'a' - ] - } - ] - }, - { - 'key': [ - ], - 'value': [ - ] - } - ], - 'value': [ - 1, - 2 - ] - } - reference_query = create_reference_query(data) - df = pandas.DataFrame([ - { 'a': data} - ]) - check_equal(df, reference_query) + @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + def test_recursive_map(self, duckdb_cursor, pandas): + # MAP( + # MAP( + # INTEGER, + # MAP(INTEGER) + # ), + # INTEGER + # ) + data = { + 'key': [ + {'key': [5, 6, 7], 'value': [{'key': [8], 'value': [NULL]}, NULL, {'key': [9], 'value': ['a']}]}, + {'key': [], 'value': []}, + ], + 'value': [1, 2], + } + reference_query = create_reference_query(data) + df = pandas.DataFrame([{'a': data}]) + check_equal(df, reference_query) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) - def test_recursive_stresstest(self, duckdb_cursor, pandas): - #LIST( - # STRUCT( - # MAP( - # STRUCT( - # LIST( - # INTEGER - # ) - # ) - # LIST( - # STRUCT( - # VARCHAR - # ) - # ) - # ) - # ) - #) - data = [ - { - 'a': { - 'key': [ - { - '1': [5,4,3], - '2': [8,7,6], - '3': [1,2,3] - }, - { - '1': [], - '2': NULL, - '3': [NULL, 0, NULL] - } - ], - 'value': [ - [ - { - 'A': 'abc', - 
'B': 'def', - 'C': NULL - } - ], - [ - NULL - ] - ] - }, - 'b': NULL, - 'c': { - 'key': [], - 'value': [] - } - } - ] - reference_query = create_reference_query(data) - df = pandas.DataFrame([ - { 'a': data} - ]) - check_equal(df, reference_query) + @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + def test_recursive_stresstest(self, duckdb_cursor, pandas): + # LIST( + # STRUCT( + # MAP( + # STRUCT( + # LIST( + # INTEGER + # ) + # ) + # LIST( + # STRUCT( + # VARCHAR + # ) + # ) + # ) + # ) + # ) + data = [ + { + 'a': { + 'key': [ + {'1': [5, 4, 3], '2': [8, 7, 6], '3': [1, 2, 3]}, + {'1': [], '2': NULL, '3': [NULL, 0, NULL]}, + ], + 'value': [[{'A': 'abc', 'B': 'def', 'C': NULL}], [NULL]], + }, + 'b': NULL, + 'c': {'key': [], 'value': []}, + } + ] + reference_query = create_reference_query(data) + df = pandas.DataFrame([{'a': data}]) + check_equal(df, reference_query) diff --git a/tools/pythonpkg/tests/fast/pandas/test_fetch_df_chunk.py b/tools/pythonpkg/tests/fast/pandas/test_fetch_df_chunk.py index 4b41266f4cb1..cf225459d7a3 100644 --- a/tools/pythonpkg/tests/fast/pandas/test_fetch_df_chunk.py +++ b/tools/pythonpkg/tests/fast/pandas/test_fetch_df_chunk.py @@ -1,20 +1,20 @@ import pytest import duckdb -class TestType(object): +class TestType(object): def test_fetch_df_chunk(self, duckdb_cursor): duckdb_cursor.execute("CREATE table t as select range a from range(3000);") query = duckdb_cursor.execute("SELECT a FROM t") cur_chunk = query.fetch_df_chunk() - assert(cur_chunk['a'][0] == 0) - assert(len(cur_chunk) == 2048) + assert cur_chunk['a'][0] == 0 + assert len(cur_chunk) == 2048 cur_chunk = query.fetch_df_chunk() - assert(cur_chunk['a'][0] == 2048) - assert(len(cur_chunk) == 952) + assert cur_chunk['a'][0] == 2048 + assert len(cur_chunk) == 952 duckdb_cursor.execute("DROP TABLE t") - def test_monahan(self,duckdb_cursor): + def test_monahan(self, duckdb_cursor): duckdb_cursor.execute("CREATE table t as select range a from range(3000);") query = duckdb_cursor.execute("SELECT a FROM t") cur_chunk = query.fetch_df_chunk() @@ -23,60 +23,60 @@ def test_monahan(self,duckdb_cursor): print(cur_chunk) cur_chunk = query.fetch_df_chunk() print(cur_chunk) - #Should be empty by now + # Should be empty by now try: cur_chunk = query.fetch_df_chunk() print(cur_chunk) except Exception as err: print(err) - - #Should be empty by now + + # Should be empty by now try: cur_chunk = query.fetch_df_chunk() print(cur_chunk) except Exception as err: print(err) duckdb_cursor.execute("DROP TABLE t") - + def test_fetch_df_chunk_parameter(self, duckdb_cursor): duckdb_cursor.execute("CREATE table t as select range a from range(10000);") query = duckdb_cursor.execute("SELECT a FROM t") - + # Return 2 vectors cur_chunk = query.fetch_df_chunk(2) - assert(cur_chunk['a'][0] == 0) - assert(len(cur_chunk) == 4096) - + assert cur_chunk['a'][0] == 0 + assert len(cur_chunk) == 4096 + # Return Default 1 vector cur_chunk = query.fetch_df_chunk() - assert(cur_chunk['a'][0] == 4096) - assert(len(cur_chunk) == 2048) + assert cur_chunk['a'][0] == 4096 + assert len(cur_chunk) == 2048 # Return 0 vectors cur_chunk = query.fetch_df_chunk(0) - assert(len(cur_chunk) == 0) + assert len(cur_chunk) == 0 # Return more vectors than we have remaining cur_chunk = query.fetch_df_chunk(3) - assert(cur_chunk['a'][0] == 6144) - assert(len(cur_chunk) == 3856) + assert cur_chunk['a'][0] == 6144 + assert len(cur_chunk) == 3856 # These shouldn't throw errors (Just emmit empty chunks) cur_chunk = query.fetch_df_chunk(100) - 
assert(len(cur_chunk) == 0) + assert len(cur_chunk) == 0 cur_chunk = query.fetch_df_chunk(0) - assert(len(cur_chunk) == 0) + assert len(cur_chunk) == 0 cur_chunk = query.fetch_df_chunk() - assert(len(cur_chunk) == 0) + assert len(cur_chunk) == 0 duckdb_cursor.execute("DROP TABLE t") def test_fetch_df_chunk_negative_parameter(self, duckdb_cursor): duckdb_cursor.execute("CREATE table t as select range a from range(100);") query = duckdb_cursor.execute("SELECT a FROM t") - + # Return -1 vector should not work with pytest.raises(TypeError, match='incompatible function arguments'): cur_chunk = query.fetch_df_chunk(-1) - duckdb_cursor.execute("DROP TABLE t") \ No newline at end of file + duckdb_cursor.execute("DROP TABLE t") diff --git a/tools/pythonpkg/tests/fast/pandas/test_fetch_nested.py b/tools/pythonpkg/tests/fast/pandas/test_fetch_nested.py index a70f811a095c..6b989b8b7587 100644 --- a/tools/pythonpkg/tests/fast/pandas/test_fetch_nested.py +++ b/tools/pythonpkg/tests/fast/pandas/test_fetch_nested.py @@ -3,106 +3,207 @@ import pandas as pd import numpy as np + def compare_results(query, list_values=[]): df_duck = duckdb.query(query).df() counter = 0 duck_values = df_duck['a'] for duck_value in duck_values: assert duck_value == list_values[counter] - counter+=1 + counter += 1 -class TestFetchNested(object): +class TestFetchNested(object): def test_fetch_df_list(self, duckdb_cursor): # Integers - compare_results("SELECT a from (select list_value(3,5,10) as a) as t",[[3,5,10]]) - compare_results("SELECT a from (select list_value(3,5,NULL) as a) as t",[[3,5,None]]) - compare_results("SELECT a from (select list_value(NULL,NULL,NULL) as a) as t",[[None,None,None]]) - compare_results("SELECT a from (select list_value() as a) as t",[[]]) - - #Strings - compare_results("SELECT a from (select list_value('test','test_one','test_two') as a) as t",[['test','test_one','test_two']]) - compare_results("SELECT a from (select list_value('test','test_one',NULL) as a) as t",[['test','test_one',None]]) - - #Big Lists - compare_results("SELECT a from (SELECT LIST(i) as a FROM range(10000) tbl(i)) as t",[list(range(0, 10000))]) - - # #Multiple Lists - compare_results("SELECT a from (SELECT LIST(i) as a FROM range(5) tbl(i) group by i%2) as t",[[0,2,4],[1,3]]) - - #Unique Constants - compare_results("SELECT a from (SELECT list_value(1) as a FROM range(5) tbl(i)) as t",[[1],[1],[1],[1],[1]]) - - #Nested Lists - compare_results("SELECT LIST(le) as a FROM (SELECT LIST(i) le from range(5) tbl(i) group by i%2) as t",[[[0, 2, 4], [1, 3]]]) - - #LIST[LIST[LIST[LIST[LIST[INTEGER]]]]]] - compare_results("SELECT list (lllle) as a from (SELECT list (llle) lllle from (SELECT list(lle) llle from (SELECT LIST(le) lle FROM (SELECT LIST(i) le from range(5) tbl(i) group by i%2) as t) as t1) as t2) as t3",[[[[[[0, 2, 4], [1, 3]]]]] ]) - - compare_results('''SELECT grp,lst,a FROM (select grp, lst, case when grp>1 then lst else list_value(null) end as a - from (SELECT a_1%4 as grp, list(a_1) as lst FROM range(7) tbl(a_1) group by grp) as lst_tbl) as T;''',[[None],[None],[2, 6],[3]]) - - #Tests for converting multiple lists to/from Pandas with NULL values and/or strings - compare_results("SELECT list(st) as a from (select i, case when i%5 then NULL else i::VARCHAR end as st from range(10) tbl(i)) as t group by i%2",[['0', None, None, None, None],[None, None, '5', None, None]]) - - def test_struct_df(self,duckdb_cursor): - compare_results("SELECT a from (SELECT STRUCT_PACK(a := 42, b := 43) as a) as t",[{'a': 42, 'b': 43}]) - - 
compare_results("SELECT a from (SELECT STRUCT_PACK(a := NULL, b := 43) as a) as t",[{'a': None, 'b': 43}]) - - compare_results("SELECT a from (SELECT STRUCT_PACK(a := NULL) as a) as t",[{'a': None}]) + compare_results("SELECT a from (select list_value(3,5,10) as a) as t", [[3, 5, 10]]) + compare_results("SELECT a from (select list_value(3,5,NULL) as a) as t", [[3, 5, None]]) + compare_results("SELECT a from (select list_value(NULL,NULL,NULL) as a) as t", [[None, None, None]]) + compare_results("SELECT a from (select list_value() as a) as t", [[]]) + + # Strings + compare_results( + "SELECT a from (select list_value('test','test_one','test_two') as a) as t", + [['test', 'test_one', 'test_two']], + ) + compare_results( + "SELECT a from (select list_value('test','test_one',NULL) as a) as t", [['test', 'test_one', None]] + ) + + # Big Lists + compare_results("SELECT a from (SELECT LIST(i) as a FROM range(10000) tbl(i)) as t", [list(range(0, 10000))]) - compare_results("SELECT a from (SELECT STRUCT_PACK(a := i, b := i) as a FROM range(10) tbl(i)) as t",[{'a': 0, 'b': 0}, {'a': 1, 'b': 1}, {'a': 2, 'b': 2}, {'a': 3, 'b': 3}, {'a': 4, 'b': 4}, {'a': 5, 'b': 5}, {'a': 6, 'b': 6}, {'a': 7, 'b': 7}, {'a': 8, 'b': 8}, {'a': 9, 'b': 9}]) - - compare_results("SELECT a from (SELECT STRUCT_PACK(a := LIST_VALUE(1,2,3), b := i) as a FROM range(10) tbl(i)) as t",[{'a': [1, 2, 3], 'b': 0}, {'a': [1, 2, 3], 'b': 1}, {'a': [1, 2, 3], 'b': 2}, {'a': [1, 2, 3], 'b': 3}, {'a': [1, 2, 3], 'b': 4}, {'a': [1, 2, 3], 'b': 5}, {'a': [1, 2, 3], 'b': 6}, {'a': [1, 2, 3], 'b': 7}, {'a': [1, 2, 3], 'b': 8}, {'a': [1, 2, 3], 'b': 9}]) + # #Multiple Lists + compare_results( + "SELECT a from (SELECT LIST(i) as a FROM range(5) tbl(i) group by i%2) as t", [[0, 2, 4], [1, 3]] + ) + + # Unique Constants + compare_results( + "SELECT a from (SELECT list_value(1) as a FROM range(5) tbl(i)) as t", [[1], [1], [1], [1], [1]] + ) + + # Nested Lists + compare_results( + "SELECT LIST(le) as a FROM (SELECT LIST(i) le from range(5) tbl(i) group by i%2) as t", + [[[0, 2, 4], [1, 3]]], + ) + + # LIST[LIST[LIST[LIST[LIST[INTEGER]]]]]] + compare_results( + "SELECT list (lllle) as a from (SELECT list (llle) lllle from (SELECT list(lle) llle from (SELECT LIST(le) lle FROM (SELECT LIST(i) le from range(5) tbl(i) group by i%2) as t) as t1) as t2) as t3", + [[[[[[0, 2, 4], [1, 3]]]]]], + ) + + compare_results( + '''SELECT grp,lst,a FROM (select grp, lst, case when grp>1 then lst else list_value(null) end as a + from (SELECT a_1%4 as grp, list(a_1) as lst FROM range(7) tbl(a_1) group by grp) as lst_tbl) as T;''', + [[None], [None], [2, 6], [3]], + ) + + # Tests for converting multiple lists to/from Pandas with NULL values and/or strings + compare_results( + "SELECT list(st) as a from (select i, case when i%5 then NULL else i::VARCHAR end as st from range(10) tbl(i)) as t group by i%2", + [['0', None, None, None, None], [None, None, '5', None, None]], + ) + + def test_struct_df(self, duckdb_cursor): + compare_results("SELECT a from (SELECT STRUCT_PACK(a := 42, b := 43) as a) as t", [{'a': 42, 'b': 43}]) + + compare_results("SELECT a from (SELECT STRUCT_PACK(a := NULL, b := 43) as a) as t", [{'a': None, 'b': 43}]) + + compare_results("SELECT a from (SELECT STRUCT_PACK(a := NULL) as a) as t", [{'a': None}]) + + compare_results( + "SELECT a from (SELECT STRUCT_PACK(a := i, b := i) as a FROM range(10) tbl(i)) as t", + [ + {'a': 0, 'b': 0}, + {'a': 1, 'b': 1}, + {'a': 2, 'b': 2}, + {'a': 3, 'b': 3}, + {'a': 4, 'b': 4}, + {'a': 5, 'b': 5}, + {'a': 6, 'b': 
6}, + {'a': 7, 'b': 7}, + {'a': 8, 'b': 8}, + {'a': 9, 'b': 9}, + ], + ) + + compare_results( + "SELECT a from (SELECT STRUCT_PACK(a := LIST_VALUE(1,2,3), b := i) as a FROM range(10) tbl(i)) as t", + [ + {'a': [1, 2, 3], 'b': 0}, + {'a': [1, 2, 3], 'b': 1}, + {'a': [1, 2, 3], 'b': 2}, + {'a': [1, 2, 3], 'b': 3}, + {'a': [1, 2, 3], 'b': 4}, + {'a': [1, 2, 3], 'b': 5}, + {'a': [1, 2, 3], 'b': 6}, + {'a': [1, 2, 3], 'b': 7}, + {'a': [1, 2, 3], 'b': 8}, + {'a': [1, 2, 3], 'b': 9}, + ], + ) + + def test_map_df(self, duckdb_cursor): + compare_results( + "SELECT a from (select MAP(LIST_VALUE(1, 2, 3, 4),LIST_VALUE(10, 9, 8, 7)) as a) as t", + [{'key': [1, 2, 3, 4], 'value': [10, 9, 8, 7]}], + ) - def test_map_df(self,duckdb_cursor): - compare_results("SELECT a from (select MAP(LIST_VALUE(1, 2, 3, 4),LIST_VALUE(10, 9, 8, 7)) as a) as t",[{'key': [1, 2, 3, 4], 'value': [10, 9, 8, 7]}]) - with pytest.raises(duckdb.InvalidInputException, match="Map keys have to be unique"): - compare_results("SELECT a from (select MAP(LIST_VALUE(1, 2, 3, 4,2, NULL),LIST_VALUE(10, 9, 8, 7,11,42)) as a) as t",[{'key': [1, 2, 3, 4, 2, None], 'value': [10, 9, 8, 7, 11, 42]}]) - - compare_results("SELECT a from (select MAP(LIST_VALUE(),LIST_VALUE()) as a) as t",[{'key': [], 'value': []}]) + compare_results( + "SELECT a from (select MAP(LIST_VALUE(1, 2, 3, 4,2, NULL),LIST_VALUE(10, 9, 8, 7,11,42)) as a) as t", + [{'key': [1, 2, 3, 4, 2, None], 'value': [10, 9, 8, 7, 11, 42]}], + ) + + compare_results("SELECT a from (select MAP(LIST_VALUE(),LIST_VALUE()) as a) as t", [{'key': [], 'value': []}]) with pytest.raises(duckdb.InvalidInputException, match="Map keys have to be unique"): - compare_results("SELECT a from (select MAP(LIST_VALUE('Jon Lajoie', 'Backstreet Boys', 'Tenacious D','Jon Lajoie' ),LIST_VALUE(10,9,10,11)) as a) as t", [{'key': ['Jon Lajoie', 'Backstreet Boys', 'Tenacious D', 'Jon Lajoie'], 'value': [10, 9, 10, 11]}]) + compare_results( + "SELECT a from (select MAP(LIST_VALUE('Jon Lajoie', 'Backstreet Boys', 'Tenacious D','Jon Lajoie' ),LIST_VALUE(10,9,10,11)) as a) as t", + [{'key': ['Jon Lajoie', 'Backstreet Boys', 'Tenacious D', 'Jon Lajoie'], 'value': [10, 9, 10, 11]}], + ) with pytest.raises(duckdb.InvalidInputException, match="Map keys can not be NULL"): - compare_results("SELECT a from (select MAP(LIST_VALUE('Jon Lajoie', NULL, 'Tenacious D',NULL,NULL ),LIST_VALUE(10,9,10,11,13)) as a) as t", [{'key': ['Jon Lajoie', None, 'Tenacious D', None, None], 'value': [10, 9, 10, 11, 13]}]) + compare_results( + "SELECT a from (select MAP(LIST_VALUE('Jon Lajoie', NULL, 'Tenacious D',NULL,NULL ),LIST_VALUE(10,9,10,11,13)) as a) as t", + [{'key': ['Jon Lajoie', None, 'Tenacious D', None, None], 'value': [10, 9, 10, 11, 13]}], + ) with pytest.raises(duckdb.InvalidInputException, match="Map keys can not be NULL"): - compare_results("SELECT a from (select MAP(LIST_VALUE(NULL, NULL, NULL,NULL,NULL ),LIST_VALUE(10,9,10,11,13)) as a) as t",[{'key': [None, None, None, None, None], 'value': [10, 9, 10, 11, 13]}]) + compare_results( + "SELECT a from (select MAP(LIST_VALUE(NULL, NULL, NULL,NULL,NULL ),LIST_VALUE(10,9,10,11,13)) as a) as t", + [{'key': [None, None, None, None, None], 'value': [10, 9, 10, 11, 13]}], + ) with pytest.raises(duckdb.InvalidInputException, match="Map keys can not be NULL"): - compare_results("SELECT a from (select MAP(LIST_VALUE(NULL, NULL, NULL,NULL,NULL ),LIST_VALUE(NULL, NULL, NULL,NULL,NULL )) as a) as t", [{'key': [None, None, None, None, None], 'value': [None, None, None, None, None]}]) - - 
compare_results("SELECT m as a from (select MAP(list_value(1), list_value(2)) from range(5) tbl(i)) tbl(m)", [{'key': [1], 'value': [2]}, {'key': [1], 'value': [2]}, {'key': [1], 'value': [2]}, {'key': [1], 'value': [2]}, {'key': [1], 'value': [2]}]) - - compare_results("SELECT m as a from (select MAP(lsta,lstb) as m from (SELECT list(i) as lsta, list(i) as lstb from range(10) tbl(i) group by i%5) as lst_tbl) as T", [{'key': [0, 5], 'value': [0, 5]}, {'key': [1, 6], 'value': [1, 6]}, {'key': [2, 7], 'value': [2, 7]}, {'key': [3, 8], 'value': [3, 8]}, {'key': [4, 9], 'value': [4, 9]}]) - - def test_nested_mix(self,duckdb_cursor): + compare_results( + "SELECT a from (select MAP(LIST_VALUE(NULL, NULL, NULL,NULL,NULL ),LIST_VALUE(NULL, NULL, NULL,NULL,NULL )) as a) as t", + [{'key': [None, None, None, None, None], 'value': [None, None, None, None, None]}], + ) + + compare_results( + "SELECT m as a from (select MAP(list_value(1), list_value(2)) from range(5) tbl(i)) tbl(m)", + [ + {'key': [1], 'value': [2]}, + {'key': [1], 'value': [2]}, + {'key': [1], 'value': [2]}, + {'key': [1], 'value': [2]}, + {'key': [1], 'value': [2]}, + ], + ) + + compare_results( + "SELECT m as a from (select MAP(lsta,lstb) as m from (SELECT list(i) as lsta, list(i) as lstb from range(10) tbl(i) group by i%5) as lst_tbl) as T", + [ + {'key': [0, 5], 'value': [0, 5]}, + {'key': [1, 6], 'value': [1, 6]}, + {'key': [2, 7], 'value': [2, 7]}, + {'key': [3, 8], 'value': [3, 8]}, + {'key': [4, 9], 'value': [4, 9]}, + ], + ) + + def test_nested_mix(self, duckdb_cursor): # List of structs W/ Struct that is NULL entirely - compare_results("SELECT [{'i':1,'j':2},NULL,{'i':2,'j':NULL}] as a", [[{'i': 1, 'j': 2}, None, {'i': 2, 'j': None}]]) + compare_results( + "SELECT [{'i':1,'j':2},NULL,{'i':2,'j':NULL}] as a", [[{'i': 1, 'j': 2}, None, {'i': 2, 'j': None}]] + ) # Lists of structs with lists compare_results("SELECT [{'i':1,'j':[2,3]},NULL] as a", [[{'i': 1, 'j': [2, 3]}, None]]) - + # Maps embedded in a struct - compare_results("SELECT {'i':mp,'j':mp2} as a FROM (SELECT MAP(LIST_VALUE(1, 2, 3, 4),LIST_VALUE(10, 9, 8, 7)) as mp, MAP(LIST_VALUE(1, 2, 3, 5),LIST_VALUE(10, 9, 8, 7)) as mp2) as t", [{'i': {'key': [1, 2, 3, 4], 'value': [10, 9, 8, 7]}, 'j': {'key': [1, 2, 3, 5], 'value': [10, 9, 8, 7]}}]) + compare_results( + "SELECT {'i':mp,'j':mp2} as a FROM (SELECT MAP(LIST_VALUE(1, 2, 3, 4),LIST_VALUE(10, 9, 8, 7)) as mp, MAP(LIST_VALUE(1, 2, 3, 5),LIST_VALUE(10, 9, 8, 7)) as mp2) as t", + [{'i': {'key': [1, 2, 3, 4], 'value': [10, 9, 8, 7]}, 'j': {'key': [1, 2, 3, 5], 'value': [10, 9, 8, 7]}}], + ) - # List of maps - compare_results("SELECT [mp,mp2] as a FROM (SELECT MAP(LIST_VALUE(1, 2, 3, 4),LIST_VALUE(10, 9, 8, 7)) as mp, MAP(LIST_VALUE(1, 2, 3, 5),LIST_VALUE(10, 9, 8, 7)) as mp2) as t", [[{'key': [1, 2, 3, 4], 'value': [10, 9, 8, 7]}, {'key': [1, 2, 3, 5], 'value': [10, 9, 8, 7]}]]) + # List of maps + compare_results( + "SELECT [mp,mp2] as a FROM (SELECT MAP(LIST_VALUE(1, 2, 3, 4),LIST_VALUE(10, 9, 8, 7)) as mp, MAP(LIST_VALUE(1, 2, 3, 5),LIST_VALUE(10, 9, 8, 7)) as mp2) as t", + [[{'key': [1, 2, 3, 4], 'value': [10, 9, 8, 7]}, {'key': [1, 2, 3, 5], 'value': [10, 9, 8, 7]}]], + ) # Map with list as key and/or value - compare_results("SELECT MAP(LIST_VALUE([1,2],[3,4],[5,4]),LIST_VALUE([1,2],[3,4],[5,4])) as a", [{'key': [[1, 2], [3, 4], [5, 4]], 'value': [[1, 2], [3, 4], [5, 4]]}]) + compare_results( + "SELECT MAP(LIST_VALUE([1,2],[3,4],[5,4]),LIST_VALUE([1,2],[3,4],[5,4])) as a", + [{'key': [[1, 2], [3, 4], [5, 4]], 'value': 
[[1, 2], [3, 4], [5, 4]]}], + ) # Map with struct as key and/or value - compare_results("SELECT MAP(LIST_VALUE({'i':1,'j':2},{'i':3,'j':4}),LIST_VALUE({'i':1,'j':2},{'i':3,'j':4})) as a", [{'key': [{'i': 1, 'j': 2}, {'i': 3, 'j': 4}], 'value': [{'i': 1, 'j': 2}, {'i': 3, 'j': 4}]}]) - - + compare_results( + "SELECT MAP(LIST_VALUE({'i':1,'j':2},{'i':3,'j':4}),LIST_VALUE({'i':1,'j':2},{'i':3,'j':4})) as a", + [{'key': [{'i': 1, 'j': 2}, {'i': 3, 'j': 4}], 'value': [{'i': 1, 'j': 2}, {'i': 3, 'j': 4}]}], + ) # Null checks on lists with structs - compare_results("SELECT [{'i':1,'j':[2,3]},NULL,{'i':1,'j':[2,3]}] as a",[[{'i': 1, 'j': [2, 3]}, None, {'i': 1, 'j': [2, 3]}]]) + compare_results( + "SELECT [{'i':1,'j':[2,3]},NULL,{'i':1,'j':[2,3]}] as a", + [[{'i': 1, 'j': [2, 3]}, None, {'i': 1, 'j': [2, 3]}]], + ) # Struct that is NULL entirely df_duck = duckdb.query("SELECT col0 as a FROM (VALUES ({'i':1,'j':2}), (NULL), ({'i':1,'j':2}), (NULL))").df() @@ -113,9 +214,11 @@ def test_nested_mix(self,duckdb_cursor): assert np.isnan(duck_values[3]) # MAP that is NULL entirely - df_duck = duckdb.query("SELECT col0 as a FROM (VALUES (MAP(LIST_VALUE(1,2),LIST_VALUE(3,4))),(NULL), (MAP(LIST_VALUE(1,2),LIST_VALUE(3,4))), (NULL))").df() + df_duck = duckdb.query( + "SELECT col0 as a FROM (VALUES (MAP(LIST_VALUE(1,2),LIST_VALUE(3,4))),(NULL), (MAP(LIST_VALUE(1,2),LIST_VALUE(3,4))), (NULL))" + ).df() duck_values = df_duck['a'] assert duck_values[0] == {'key': [1, 2], 'value': [3, 4]} assert np.isnan(duck_values[1]) assert duck_values[2] == {'key': [1, 2], 'value': [3, 4]} - assert np.isnan(duck_values[3]) \ No newline at end of file + assert np.isnan(duck_values[3]) diff --git a/tools/pythonpkg/tests/fast/pandas/test_implicit_pandas_scan.py b/tools/pythonpkg/tests/fast/pandas/test_implicit_pandas_scan.py index 0b41f9a17bd1..e6f0b9f40726 100644 --- a/tools/pythonpkg/tests/fast/pandas/test_implicit_pandas_scan.py +++ b/tools/pythonpkg/tests/fast/pandas/test_implicit_pandas_scan.py @@ -6,10 +6,11 @@ from conftest import NumpyPandas, ArrowPandas from packaging.version import Version -numpy_nullable_df = pd.DataFrame([{"COL1": "val1", "CoL2": 1.05},{"COL1": "val4", "CoL2": 17}]) +numpy_nullable_df = pd.DataFrame([{"COL1": "val1", "CoL2": 1.05}, {"COL1": "val4", "CoL2": 17}]) try: from pandas.compat import pa_version_under7p0 + pyarrow_dtypes_enabled = not pa_version_under7p0 except: pyarrow_dtypes_enabled = False @@ -20,11 +21,12 @@ # dtype_backend is not supported in pandas < 2.0.0 pyarrow_df = numpy_nullable_df + class TestImplicitPandasScan(object): @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_local_pandas_scan(self, duckdb_cursor, pandas): con = duckdb.connect() - df = pandas.DataFrame([{"COL1": "val1", "CoL2": 1.05},{"COL1": "val3", "CoL2": 17}]) + df = pandas.DataFrame([{"COL1": "val1", "CoL2": 1.05}, {"COL1": "val3", "CoL2": 17}]) r1 = con.execute('select * from df').fetchdf() assert r1["COL1"][0] == "val1" assert r1["COL1"][1] == "val3" diff --git a/tools/pythonpkg/tests/fast/pandas/test_issue_1767.py b/tools/pythonpkg/tests/fast/pandas/test_issue_1767.py index 828510693299..e37f19e1d9d0 100644 --- a/tools/pythonpkg/tests/fast/pandas/test_issue_1767.py +++ b/tools/pythonpkg/tests/fast/pandas/test_issue_1767.py @@ -6,9 +6,9 @@ import pytest from conftest import NumpyPandas, ArrowPandas + # Join from pandas not matching identical strings #1767 class TestIssue1767(object): - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_unicode_join_pandas(self, 
duckdb_cursor, pandas): A = pandas.DataFrame({"key": ["a", "п"]}) diff --git a/tools/pythonpkg/tests/fast/pandas/test_limit.py b/tools/pythonpkg/tests/fast/pandas/test_limit.py index aa7cd4da2314..2d33126c3a44 100644 --- a/tools/pythonpkg/tests/fast/pandas/test_limit.py +++ b/tools/pythonpkg/tests/fast/pandas/test_limit.py @@ -2,16 +2,24 @@ import pytest from conftest import NumpyPandas, ArrowPandas -class TestLimitPandas(object): +class TestLimitPandas(object): @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_limit_df(self, duckdb_cursor, pandas): - df_in = pandas.DataFrame({'numbers': [1,2,3,4,5],}) - limit_df = duckdb.limit(df_in,2) + df_in = pandas.DataFrame( + { + 'numbers': [1, 2, 3, 4, 5], + } + ) + limit_df = duckdb.limit(df_in, 2) assert len(limit_df.execute().fetchall()) == 2 @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_aggregate_df(self, duckdb_cursor, pandas): - df_in = pandas.DataFrame({'numbers': [1,2,2,2],}) - aggregate_df = duckdb.aggregate(df_in,'count(numbers)','numbers') - assert aggregate_df.execute().fetchall() == [(1,), (3,)] \ No newline at end of file + df_in = pandas.DataFrame( + { + 'numbers': [1, 2, 2, 2], + } + ) + aggregate_df = duckdb.aggregate(df_in, 'count(numbers)', 'numbers') + assert aggregate_df.execute().fetchall() == [(1,), (3,)] diff --git a/tools/pythonpkg/tests/fast/pandas/test_pandas_arrow.py b/tools/pythonpkg/tests/fast/pandas/test_pandas_arrow.py index e88f8d0d8500..d4b29473ad0a 100644 --- a/tools/pythonpkg/tests/fast/pandas/test_pandas_arrow.py +++ b/tools/pythonpkg/tests/fast/pandas/test_pandas_arrow.py @@ -3,53 +3,57 @@ import datetime from conftest import pandas_supports_arrow_backend + pd = pytest.importorskip("pandas", '2.0.0') import numpy as np + @pytest.mark.skipif(not pandas_supports_arrow_backend(), reason="pandas does not support the 'pyarrow' backend") class TestPandasArrow(object): def test_pandas_arrow(self, duckdb_cursor): pd = pytest.importorskip("pandas") - df = pd.DataFrame({'a': pd.Series([5,4,3])}).convert_dtypes() + df = pd.DataFrame({'a': pd.Series([5, 4, 3])}).convert_dtypes() con = duckdb.connect() res = con.sql("select * from df").fetchall() - assert res == [(5,),(4,),(3,)] + assert res == [(5,), (4,), (3,)] def test_mixed_columns(self): - df = pd.DataFrame({ - 'strings': pd.Series([ - 'abc', - 'DuckDB', - 'quack', - 'quack' - ]), - 'timestamps': pd.Series([ - datetime.datetime(1990, 10, 21), - datetime.datetime(2023, 1, 11), - datetime.datetime(2001, 2, 5), - datetime.datetime(1990, 10, 21), - ]), - 'objects': pd.Series([ - [5,4,3], - 'test', - None, - {'a': 42} - ]), - 'integers': np.ndarray((4,), buffer=np.array([1,2,3,4,5]), offset=np.int_().itemsize, dtype=int) - }) + df = pd.DataFrame( + { + 'strings': pd.Series(['abc', 'DuckDB', 'quack', 'quack']), + 'timestamps': pd.Series( + [ + datetime.datetime(1990, 10, 21), + datetime.datetime(2023, 1, 11), + datetime.datetime(2001, 2, 5), + datetime.datetime(1990, 10, 21), + ] + ), + 'objects': pd.Series([[5, 4, 3], 'test', None, {'a': 42}]), + 'integers': np.ndarray((4,), buffer=np.array([1, 2, 3, 4, 5]), offset=np.int_().itemsize, dtype=int), + } + ) pyarrow_df = df.convert_dtypes(dtype_backend='pyarrow') con = duckdb.connect() with pytest.raises(duckdb.InvalidInputException, match='Conversion failed for column objects with type object'): res = con.sql('select * from pyarrow_df').fetchall() - numpy_df = pd.DataFrame({'a': np.ndarray((2,), buffer=np.array([1,2,3]), offset=np.int_().itemsize, 
dtype=int)}).convert_dtypes(dtype_backend='numpy_nullable') - arrow_df = pd.DataFrame({'a': pd.Series([ - datetime.datetime(1990, 10, 21), - datetime.datetime(2023, 1, 11), - datetime.datetime(2001, 2, 5), - datetime.datetime(1990, 10, 21), - ])}).convert_dtypes(dtype_backend='pyarrow') - python_df = pd.DataFrame({'a': pd.Series(['test', [5,4,3], {'a': 42}])}).convert_dtypes() + numpy_df = pd.DataFrame( + {'a': np.ndarray((2,), buffer=np.array([1, 2, 3]), offset=np.int_().itemsize, dtype=int)} + ).convert_dtypes(dtype_backend='numpy_nullable') + arrow_df = pd.DataFrame( + { + 'a': pd.Series( + [ + datetime.datetime(1990, 10, 21), + datetime.datetime(2023, 1, 11), + datetime.datetime(2001, 2, 5), + datetime.datetime(1990, 10, 21), + ] + ) + } + ).convert_dtypes(dtype_backend='pyarrow') + python_df = pd.DataFrame({'a': pd.Series(['test', [5, 4, 3], {'a': 42}])}).convert_dtypes() df = pd.concat([numpy_df['a'], arrow_df['a'], python_df['a']], axis=1, keys=['numpy', 'arrow', 'python']) assert isinstance(df.dtypes[0], pd.core.arrays.integer.IntegerDtype) @@ -60,33 +64,39 @@ def test_mixed_columns(self): res = con.sql('select * from df').fetchall() def test_empty_df(self): - df = pd.DataFrame({ - 'string' : pd.Series(data=[], dtype='string'), - 'object' : pd.Series(data=[], dtype='object'), - 'Int64' : pd.Series(data=[], dtype='Int64'), - 'Float64' : pd.Series(data=[], dtype='Float64'), - 'bool' : pd.Series(data=[], dtype='bool'), - 'datetime64[ns]' : pd.Series(data=[], dtype='datetime64[ns]'), - 'datetime64[ms]' : pd.Series(data=[], dtype='datetime64[ms]'), - 'datetime64[us]' : pd.Series(data=[], dtype='datetime64[us]'), - 'datetime64[s]' : pd.Series(data=[], dtype='datetime64[s]'), - 'category' : pd.Series(data=[], dtype='category'), - 'timedelta64[ns]' : pd.Series(data=[], dtype='timedelta64[ns]'), - }) + df = pd.DataFrame( + { + 'string': pd.Series(data=[], dtype='string'), + 'object': pd.Series(data=[], dtype='object'), + 'Int64': pd.Series(data=[], dtype='Int64'), + 'Float64': pd.Series(data=[], dtype='Float64'), + 'bool': pd.Series(data=[], dtype='bool'), + 'datetime64[ns]': pd.Series(data=[], dtype='datetime64[ns]'), + 'datetime64[ms]': pd.Series(data=[], dtype='datetime64[ms]'), + 'datetime64[us]': pd.Series(data=[], dtype='datetime64[us]'), + 'datetime64[s]': pd.Series(data=[], dtype='datetime64[s]'), + 'category': pd.Series(data=[], dtype='category'), + 'timedelta64[ns]': pd.Series(data=[], dtype='timedelta64[ns]'), + } + ) pyarrow_df = df.convert_dtypes(dtype_backend='pyarrow') con = duckdb.connect() res = con.sql('select * from pyarrow_df').fetchall() assert res == [] - + def test_completely_null_df(self): - df = pd.DataFrame({ - 'a' : pd.Series(data=[ - None, - np.nan, - pd.NA, - ]) - }) + df = pd.DataFrame( + { + 'a': pd.Series( + data=[ + None, + np.nan, + pd.NA, + ] + ) + } + ) pyarrow_df = df.convert_dtypes(dtype_backend='pyarrow') con = duckdb.connect() @@ -94,53 +104,28 @@ def test_completely_null_df(self): assert res == [(None,), (None,), (None,)] def test_mixed_nulls(self): - df = pd.DataFrame({ - 'float': pd.Series(data=[ - 4.123123, - None, - 7.23456 - ], dtype='Float64'), - 'int64': pd.Series(data=[ - -234234124, - 709329413, - pd.NA - ], dtype='Int64'), - 'bool': pd.Series(data=[ - np.nan, - True, - False - ], dtype='boolean'), - 'string': pd.Series(data=[ - 'NULL', - None, - 'quack' - ]), - 'list[str]': pd.Series(data=[ - [ - 'Huey', - 'Dewey', - 'Louie' - ], - [ - None, - pd.NA, - np.nan, - 'DuckDB' - ], - None - ]), - 'datetime64' : pd.Series(data=[ - 
datetime.datetime(2011, 8, 16, 22, 7, 8),
-                None,
-                datetime.datetime(2010, 4, 26, 18, 14, 14)
-            ]),
-            'date' : pd.Series(data=[
-                datetime.date(2008, 5, 28),
-                datetime.date(2013, 7, 14),
-                None
-            ]),
-        })
+        df = pd.DataFrame(
+            {
+                'float': pd.Series(data=[4.123123, None, 7.23456], dtype='Float64'),
+                'int64': pd.Series(data=[-234234124, 709329413, pd.NA], dtype='Int64'),
+                'bool': pd.Series(data=[np.nan, True, False], dtype='boolean'),
+                'string': pd.Series(data=['NULL', None, 'quack']),
+                'list[str]': pd.Series(data=[['Huey', 'Dewey', 'Louie'], [None, pd.NA, np.nan, 'DuckDB'], None]),
+                'datetime64': pd.Series(
+                    data=[datetime.datetime(2011, 8, 16, 22, 7, 8), None, datetime.datetime(2010, 4, 26, 18, 14, 14)]
+                ),
+                'date': pd.Series(data=[datetime.date(2008, 5, 28), datetime.date(2013, 7, 14), None]),
+            }
+        )
         pyarrow_df = df.convert_dtypes(dtype_backend='pyarrow')
         con = duckdb.connect()
         res = con.sql('select * from pyarrow_df').fetchone()
-        assert res == (4.123123, -234234124, None, 'NULL', ['Huey', 'Dewey', 'Louie'], datetime.datetime(2011, 8, 16, 22, 7, 8), datetime.date(2008, 5, 28))
+        assert res == (
+            4.123123,
+            -234234124,
+            None,
+            'NULL',
+            ['Huey', 'Dewey', 'Louie'],
+            datetime.datetime(2011, 8, 16, 22, 7, 8),
+            datetime.date(2008, 5, 28),
+        )
diff --git a/tools/pythonpkg/tests/fast/pandas/test_pandas_category.py b/tools/pythonpkg/tests/fast/pandas/test_pandas_category.py
index 7d141a864e37..848412a99108 100644
--- a/tools/pythonpkg/tests/fast/pandas/test_pandas_category.py
+++ b/tools/pythonpkg/tests/fast/pandas/test_pandas_category.py
@@ -3,25 +3,27 @@
 import numpy
 import pytest
 
+
 def check_category_equal(category):
-    df_in = pd.DataFrame({
-        'x': pd.Categorical(category, ordered=True),
-    })
+    df_in = pd.DataFrame(
+        {
+            'x': pd.Categorical(category, ordered=True),
+        }
+    )
     df_out = duckdb.query_df(df_in, "data", "SELECT * FROM data").df()
     assert df_in.equals(df_out)
 
-def check_result_list(category,res):
-    for i in range (len(category)):
+
+def check_result_list(category, res):
+    for i in range(len(category)):
         assert category[i][0] == res[i]
 
+
 def check_create_table(category):
     conn = duckdb.connect()
-    conn.execute ("PRAGMA enable_verification")
-    df_in = pd.DataFrame({
-        'x': pd.Categorical(category, ordered=True),
-        'y': pd.Categorical(category, ordered=True)
-    })
+    conn.execute("PRAGMA enable_verification")
+    df_in = pd.DataFrame({'x': pd.Categorical(category, ordered=True), 'y': pd.Categorical(category, ordered=True)})
     df_out = duckdb.query_df(df_in, "data", "SELECT * FROM data").df()
     assert df_in.equals(df_out)
 
@@ -30,101 +32,112 @@ def check_create_table(category):
     conn.execute("CREATE TABLE t2 AS SELECT * FROM df_in")
 
     # Check fetchall
-    res = conn.execute("SELECT t1.x FROM t1").fetchall() 
-    check_result_list(res,category)
+    res = conn.execute("SELECT t1.x FROM t1").fetchall()
+    check_result_list(res, category)
 
-    # Do a insert to trigger string -> cat 
+    # Do an insert to trigger string -> cat
    conn.execute("INSERT INTO t1 VALUES ('2','2')")
    res = conn.execute("SELECT x FROM t1 where x = '1'").fetchall()
    assert res == [('1',)]
 
-    res = conn.execute("SELECT t1.x FROM t1 inner join t2 on (t1.x = t2.x)").fetchall() 
+    res = conn.execute("SELECT t1.x FROM t1 inner join t2 on (t1.x = t2.x)").fetchall()
    assert res == conn.execute("SELECT x FROM t1").fetchall()
-    
+
    res = conn.execute("SELECT t1.x FROM t1 inner join t2 on (t1.x = t2.y)").fetchall()
    assert res == conn.execute("SELECT x FROM t1").fetchall()
-    assert res == conn.execute("SELECT x FROM t1").fetchall()
    # Triggering the 
cast with ENUM as a src conn.execute("ALTER TABLE t1 ALTER x SET DATA TYPE VARCHAR") # We should be able to drop the table without any dependencies conn.execute("DROP TABLE t1") -class TestCategory(object): +class TestCategory(object): def test_category_simple(self, duckdb_cursor): - df_in = pd.DataFrame({ - 'float': [1.0, 2.0, 1.0], - 'int': pd.Series([1, 2, 1], dtype="category") - }) + df_in = pd.DataFrame({'float': [1.0, 2.0, 1.0], 'int': pd.Series([1, 2, 1], dtype="category")}) df_out = duckdb.query_df(df_in, "data", "SELECT * FROM data").df() - print (duckdb.query_df(df_in, "data", "SELECT * FROM data").fetchall()) - print (df_out['int']) + print(duckdb.query_df(df_in, "data", "SELECT * FROM data").fetchall()) + print(df_out['int']) assert numpy.all(df_out['float'] == numpy.array([1.0, 2.0, 1.0])) assert numpy.all(df_out['int'] == numpy.array([1, 2, 1])) def test_category_nulls(self, duckdb_cursor): - df_in = pd.DataFrame({ - 'int': pd.Series([1, 2, None], dtype="category") - }) + df_in = pd.DataFrame({'int': pd.Series([1, 2, None], dtype="category")}) df_out = duckdb.query_df(df_in, "data", "SELECT * FROM data").df() - print (duckdb.query_df(df_in, "data", "SELECT * FROM data").fetchall()) + print(duckdb.query_df(df_in, "data", "SELECT * FROM data").fetchall()) assert df_out['int'][0] == 1 assert df_out['int'][1] == 2 assert numpy.isnan(df_out['int'][2]) def test_category_string(self, duckdb_cursor): - check_category_equal(['foo','bla','zoo', 'foo', 'foo', 'bla']) + check_category_equal(['foo', 'bla', 'zoo', 'foo', 'foo', 'bla']) def test_category_string_null(self, duckdb_cursor): - check_category_equal(['foo','bla',None,'zoo', 'foo', 'foo',None, 'bla']) + check_category_equal(['foo', 'bla', None, 'zoo', 'foo', 'foo', None, 'bla']) def test_category_string_null_bug_4747(self, duckdb_cursor): check_category_equal([str(i) for i in range(160)] + [None]) def test_categorical_fetchall(self, duckdb_cursor): - df_in = pd.DataFrame({ - 'x': pd.Categorical(['foo','bla',None,'zoo', 'foo', 'foo',None, 'bla'], ordered=True), - }) - assert duckdb.query_df(df_in, "data", "SELECT * FROM data").fetchall() == [('foo',), ('bla',), (None,), ('zoo',), ('foo',), ('foo',), (None,), ('bla',)] - + df_in = pd.DataFrame( + { + 'x': pd.Categorical(['foo', 'bla', None, 'zoo', 'foo', 'foo', None, 'bla'], ordered=True), + } + ) + assert duckdb.query_df(df_in, "data", "SELECT * FROM data").fetchall() == [ + ('foo',), + ('bla',), + (None,), + ('zoo',), + ('foo',), + ('foo',), + (None,), + ('bla',), + ] + def test_category_string_uint8(self, duckdb_cursor): category = [] - for i in range (10): + for i in range(10): category.append(str(i)) check_create_table(category) def test_category_fetch_df_chunk(self, duckdb_cursor): con = duckdb.connect() - categories = ['foo','bla',None,'zoo', 'foo', 'foo',None, 'bla'] - result = categories*256 + categories = ['foo', 'bla', None, 'zoo', 'foo', 'foo', None, 'bla'] + result = categories * 256 categories = result * 2 - df_result = pd.DataFrame({ - 'x': pd.Categorical(result, ordered=True), - }) - df_in = pd.DataFrame({ - 'x': pd.Categorical(categories, ordered=True), - }) + df_result = pd.DataFrame( + { + 'x': pd.Categorical(result, ordered=True), + } + ) + df_in = pd.DataFrame( + { + 'x': pd.Categorical(categories, ordered=True), + } + ) con.register("data", df_in) query = con.execute("SELECT * FROM data") cur_chunk = query.fetch_df_chunk() - assert(cur_chunk.equals(df_result)) + assert cur_chunk.equals(df_result) cur_chunk = query.fetch_df_chunk() - 
assert(cur_chunk.equals(df_result)) + assert cur_chunk.equals(df_result) cur_chunk = query.fetch_df_chunk() - assert(cur_chunk.empty) + assert cur_chunk.empty def test_category_mix(self, duckdb_cursor): - df_in = pd.DataFrame({ - 'float': [1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 0.0], - 'x': pd.Categorical(['foo','bla',None,'zoo', 'foo', 'foo',None, 'bla'], ordered=True), - }) + df_in = pd.DataFrame( + { + 'float': [1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 0.0], + 'x': pd.Categorical(['foo', 'bla', None, 'zoo', 'foo', 'foo', None, 'bla'], ordered=True), + } + ) df_out = duckdb.query_df(df_in, "data", "SELECT * FROM data").df() assert df_out.equals(df_in) diff --git a/tools/pythonpkg/tests/fast/pandas/test_pandas_enum.py b/tools/pythonpkg/tests/fast/pandas/test_pandas_enum.py index 46e45dd9bebc..9dc13a642ceb 100644 --- a/tools/pythonpkg/tests/fast/pandas/test_pandas_enum.py +++ b/tools/pythonpkg/tests/fast/pandas/test_pandas_enum.py @@ -2,10 +2,11 @@ import pytest import duckdb + class TestPandasEnum(object): def test_3480(self, duckdb_cursor): duckdb_cursor.execute( - """ + """ create type cat as enum ('marie', 'duchess', 'toulouse'); create table tab ( cat cat, @@ -20,7 +21,7 @@ def test_3480(self, duckdb_cursor): def test_3479(self, duckdb_cursor): duckdb_cursor.execute( - """ + """ create type cat as enum ('marie', 'duchess', 'toulouse'); create table tab ( cat cat, @@ -29,12 +30,19 @@ def test_3479(self, duckdb_cursor): """ ) - df = pd.DataFrame({"cat2": pd.Series(['duchess', 'toulouse', 'marie', None, "berlioz", "o_malley"], dtype="category"), "amt": [1, 2, 3, 4, 5, 6]}) + df = pd.DataFrame( + { + "cat2": pd.Series(['duchess', 'toulouse', 'marie', None, "berlioz", "o_malley"], dtype="category"), + "amt": [1, 2, 3, 4, 5, 6], + } + ) duckdb_cursor.register('df', df) - with pytest.raises(duckdb.ConversionException, match='Type UINT8 with value 0 can\'t be cast because the value is out of range for the destination type UINT8'): + with pytest.raises( + duckdb.ConversionException, + match='Type UINT8 with value 0 can\'t be cast because the value is out of range for the destination type UINT8', + ): duckdb_cursor.execute(f"INSERT INTO tab SELECT * FROM df;") assert duckdb_cursor.execute("select * from tab").fetchall() == [] duckdb_cursor.execute("DROP TABLE tab") duckdb_cursor.execute("DROP TYPE cat") - diff --git a/tools/pythonpkg/tests/fast/pandas/test_pandas_limit.py b/tools/pythonpkg/tests/fast/pandas/test_pandas_limit.py index fcb2d91f0db5..506d5dd5b2b1 100644 --- a/tools/pythonpkg/tests/fast/pandas/test_pandas_limit.py +++ b/tools/pythonpkg/tests/fast/pandas/test_pandas_limit.py @@ -2,6 +2,7 @@ import pandas as pd import pytest + class TestPandasLimit(object): def test_pandas_limit(self, duckdb_cursor): con = duckdb.connect() diff --git a/tools/pythonpkg/tests/fast/pandas/test_pandas_na.py b/tools/pythonpkg/tests/fast/pandas/test_pandas_na.py index 70c3a056a88e..f76fd98077dd 100644 --- a/tools/pythonpkg/tests/fast/pandas/test_pandas_na.py +++ b/tools/pythonpkg/tests/fast/pandas/test_pandas_na.py @@ -3,12 +3,14 @@ import duckdb import pytest + def assert_nullness(items, null_indices): for i in range(len(items)): if i in null_indices: - assert(items[i] == None) + assert items[i] == None else: - assert(items[i] != None) + assert items[i] != None + class TestPandasNA(object): def test_pandas_na(self, duckdb_cursor): @@ -19,20 +21,46 @@ def test_pandas_na(self, duckdb_cursor): conn = duckdb.connect() res = conn.execute("select * from df").fetchall() - assert(res[0][0] == None) + assert res[0][0] == 
None # DataFrame containing multiple values, with a pd.NA mixed in null_index = 3 - df = pd.DataFrame(pd.Series([3,1,2,pd.NA,8,6])) + df = pd.DataFrame(pd.Series([3, 1, 2, pd.NA, 8, 6])) res = conn.execute("select * from df").fetchall() items = [x[0] for x in [y for y in res]] assert_nullness(items, [null_index]) # Test if pd.NA behaves the same as np.NaN once converted - nan_df = pd.DataFrame({'a': [1.123, 5.23234, np.NaN, 7234.0000124, 0.000000124, 0000000000000.0000001, np.NaN, -2342349234.00934580345]}) - na_df = pd.DataFrame({'a': [1.123, 5.23234, pd.NA, 7234.0000124, 0.000000124, 0000000000000.0000001, pd.NA, -2342349234.00934580345]}) - assert(str(nan_df['a'].dtype) == 'float64') - assert(str(na_df['a'].dtype) == 'object') # pd.NA values turn the column into 'object' + nan_df = pd.DataFrame( + { + 'a': [ + 1.123, + 5.23234, + np.NaN, + 7234.0000124, + 0.000000124, + 0000000000000.0000001, + np.NaN, + -2342349234.00934580345, + ] + } + ) + na_df = pd.DataFrame( + { + 'a': [ + 1.123, + 5.23234, + pd.NA, + 7234.0000124, + 0.000000124, + 0000000000000.0000001, + pd.NA, + -2342349234.00934580345, + ] + } + ) + assert str(nan_df['a'].dtype) == 'float64' + assert str(na_df['a'].dtype) == 'object' # pd.NA values turn the column into 'object' nan_result = conn.execute("select * from nan_df").df() na_result = conn.execute("select * from na_df").df() @@ -40,7 +68,7 @@ def test_pandas_na(self, duckdb_cursor): # Mixed with stringified pd.NA values na_string_df = pd.DataFrame({'a': [str(pd.NA), str(pd.NA), pd.NA, str(pd.NA), pd.NA, pd.NA, pd.NA, str(pd.NA)]}) - null_indices = [2,4,5,6] + null_indices = [2, 4, 5, 6] res = conn.execute("select * from na_string_df").fetchall() items = [x[0] for x in [y for y in res]] assert_nullness(items, null_indices) diff --git a/tools/pythonpkg/tests/fast/pandas/test_pandas_object.py b/tools/pythonpkg/tests/fast/pandas/test_pandas_object.py index 8f1aa2e88ad6..35e20074d103 100644 --- a/tools/pythonpkg/tests/fast/pandas/test_pandas_object.py +++ b/tools/pythonpkg/tests/fast/pandas/test_pandas_object.py @@ -4,73 +4,91 @@ import numpy as np import random -class TestPandasObject(object): - def test_object_to_string(self, duckdb_cursor): +class TestPandasObject(object): + def test_object_to_string(self, duckdb_cursor): con = duckdb.connect(database=':memory:', read_only=False) - x = pd.DataFrame( - [ - [1, 'a', 2], - [1, None, 2], - [1, 1.1, 2], - [1, 1.1, 2], - [1, 1.1, 2] - ] - ) - x = x.iloc[1:].copy() # middle col now entirely native float items + x = pd.DataFrame([[1, 'a', 2], [1, None, 2], [1, 1.1, 2], [1, 1.1, 2], [1, 1.1, 2]]) + x = x.iloc[1:].copy() # middle col now entirely native float items con.register('view2', x) df = con.execute('select * from view2').fetchall() - assert df == [(1, None, 2),(1, 1.1, 2), (1, 1.1, 2), (1, 1.1, 2)] + assert df == [(1, None, 2), (1, 1.1, 2), (1, 1.1, 2), (1, 1.1, 2)] def test_tuple_to_list(self, duckdb_cursor): - tuple_df = pd.DataFrame.from_dict(dict(nums=[(1,2,3,),(4,5,6,)])) - duckdb_cursor.execute("CREATE TABLE test as SELECT * FROM tuple_df"); + tuple_df = pd.DataFrame.from_dict( + dict( + nums=[ + ( + 1, + 2, + 3, + ), + ( + 4, + 5, + 6, + ), + ] + ) + ) + duckdb_cursor.execute("CREATE TABLE test as SELECT * FROM tuple_df") res = duckdb_cursor.table('test').fetchall() assert res == [([1, 2, 3],), ([4, 5, 6],)] - def test_2273(self, duckdb_cursor): + def test_2273(self, duckdb_cursor): df_in = pd.DataFrame([[datetime.date(1992, 7, 30)]]) assert duckdb.query("Select * from df_in").fetchall() == [(datetime.date(1992, 7, 
30),)] def test_object_to_string_with_stride(self, duckdb_cursor): - data = np.array([["a", "b", "c"], [1,2,3], [1, 2, 3], [11, 22, 33]]) + data = np.array([["a", "b", "c"], [1, 2, 3], [1, 2, 3], [11, 22, 33]]) df = pd.DataFrame(data=data[1:,], columns=data[0]) duckdb_cursor.register("object_with_strides", df) res = duckdb_cursor.sql('select * from object_with_strides').fetchall() assert res == [('1', '2', '3'), ('1', '2', '3'), ('11', '22', '33')] - def test_2499(self, duckdb_cursor): + def test_2499(self, duckdb_cursor): df = pd.DataFrame( [ [ - np.array([ + np.array( + [ {'a': 0.881040697801939}, {'a': 0.9922600577751953}, {'a': 0.1589674833259317}, {'a': 0.8928451262745073}, - {'a': 0.07022897889168278} - ], dtype=object) + {'a': 0.07022897889168278}, + ], + dtype=object, + ) ], [ - np.array([ + np.array( + [ {'a': 0.8759413504156746}, {'a': 0.055784331256246156}, {'a': 0.8605151517439655}, {'a': 0.40807139339337695}, - {'a': 0.8429048322459952} - ], dtype=object) + {'a': 0.8429048322459952}, + ], + dtype=object, + ) ], [ - np.array([ + np.array( + [ {'a': 0.9697093934032401}, {'a': 0.9529257667149468}, {'a': 0.21398182248591713}, {'a': 0.6328512122275955}, - {'a': 0.5146953214092728} - ], dtype=object) - ] - ], columns=['col']) + {'a': 0.5146953214092728}, + ], + dtype=object, + ) + ], + ], + columns=['col'], + ) con = duckdb.connect(database=':memory:', read_only=False) con.register('df', df) - assert (con.execute('select count(*) from df').fetchone() == (3,) ) + assert con.execute('select count(*) from df').fetchone() == (3,) diff --git a/tools/pythonpkg/tests/fast/pandas/test_pandas_string.py b/tools/pythonpkg/tests/fast/pandas/test_pandas_string.py index 8097c38a0ccb..494823ad6bbf 100644 --- a/tools/pythonpkg/tests/fast/pandas/test_pandas_string.py +++ b/tools/pythonpkg/tests/fast/pandas/test_pandas_string.py @@ -2,14 +2,17 @@ import pandas as pd import numpy + class TestPandasString(object): def test_pandas_string(self, duckdb_cursor): strings = numpy.array(['foo', 'bar', 'baz']) # https://pandas.pydata.org/pandas-docs/stable/user_guide/text.html - df_in = pd.DataFrame({ - 'object': pd.Series(strings, dtype='object'), - }) + df_in = pd.DataFrame( + { + 'object': pd.Series(strings, dtype='object'), + } + ) # Only available in pandas 1.0.0 if hasattr(pd, 'StringDtype'): df_in['string'] = pd.Series(strings, dtype=pd.StringDtype()) @@ -27,10 +30,16 @@ def test_bug_2467(self, duckdb_cursor): # Copy Dataframe to DuckDB con = duckdb.connect() con.register("df", df) - con.execute(f""" + con.execute( + f""" CREATE TABLE t1 AS SELECT * FROM df """ ) - assert con.execute(f""" + assert ( + con.execute( + f""" SELECT count(*) from t1 - """).fetchall() == [(3000000,)] \ No newline at end of file + """ + ).fetchall() + == [(3000000,)] + ) diff --git a/tools/pythonpkg/tests/fast/pandas/test_pandas_timestamp.py b/tools/pythonpkg/tests/fast/pandas/test_pandas_timestamp.py index 29c02b5ca1af..5ab311d1f994 100644 --- a/tools/pythonpkg/tests/fast/pandas/test_pandas_timestamp.py +++ b/tools/pythonpkg/tests/fast/pandas/test_pandas_timestamp.py @@ -2,14 +2,15 @@ import pandas as pd from pytest import mark + @mark.parametrize('timezone', ['UTC', 'CET', 'Asia/Kathmandu']) def run_pandas_with_tz(timezone): con = duckdb.connect() - con.execute("SET TimeZone = '"+timezone+"'") + con.execute("SET TimeZone = '" + timezone + "'") df = pd.DataFrame({"timestamp": [pd.Timestamp("2022-01-01 10:15", tz=timezone)]}) duck_df = con.from_df(df).df() - print (df['timestamp'].dtype) + print(df['timestamp'].dtype) 
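    # Note: the dtype prints above and below bracket the DuckDB round trip; the
    # final df.equals(duck_df) assertion only holds because SET TimeZone above is
    # pointed at the same zone as the Timestamp column, so the value survives the
    # TIMESTAMPTZ round trip unchanged.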
print(duck_df['timestamp'].dtype)
-    print (df)
+    print(df)
     print(duck_df)
-    assert df.equals(duck_df)
\ No newline at end of file
+    assert df.equals(duck_df)
diff --git a/tools/pythonpkg/tests/fast/pandas/test_pandas_types.py b/tools/pythonpkg/tests/fast/pandas/test_pandas_types.py
index 5b9819d3661c..124b296eb33e 100644
--- a/tools/pythonpkg/tests/fast/pandas/test_pandas_types.py
+++ b/tools/pythonpkg/tests/fast/pandas/test_pandas_types.py
@@ -4,21 +4,23 @@
 import string
 from packaging import version
 
-def round_trip(data,pandas_type):
-    df_in = pd.DataFrame({
-        'object': pd.Series(data, dtype=pandas_type),
-    })
+
+def round_trip(data, pandas_type):
+    df_in = pd.DataFrame(
+        {
+            'object': pd.Series(data, dtype=pandas_type),
+        }
+    )
     df_out = duckdb.query_df(df_in, "data", "SELECT * FROM data").df()
-    print (df_out)
-    print (df_in)
+    print(df_out)
+    print(df_in)
     assert df_out.equals(df_in)
 
+
 class TestNumpyNullableTypes(object):
     def test_pandas_numeric(self):
-        base_df = pd.DataFrame(
-            {'a':range(10)}
-        )
+        base_df = pd.DataFrame({'a': range(10)})
 
         data_types = [
             "uint8",
@@ -42,11 +44,8 @@
         ]
 
         if version.parse(pd.__version__) >= version.parse('1.2.0'):
-            # These DTypes where added in 1.2.0
-            data_types.extend([
-                "Float32",
-                "Float64"
-            ])
+            # These DTypes were added in 1.2.0
+            data_types.extend(["Float32", "Float64"])
         # Generate a dataframe with all the types, in the form of:
         # b=type1,
         # c=type2
@@ -63,23 +62,25 @@
         # FIXME: we don't support outputting pandas specific types (i.e UInt64)
         for letter, item in zip(string.ascii_lowercase, data_types):
             column_name = letter
-            assert(str(out_df[column_name].dtype) == item.lower())
+            assert str(out_df[column_name].dtype) == item.lower()
 
     def test_pandas_unsigned(self, duckdb_cursor):
-        unsigned_types = ['uint8','uint16','uint32','uint64']
-        data = numpy.array([0,1,2,3])
+        unsigned_types = ['uint8', 'uint16', 'uint32', 'uint64']
+        data = numpy.array([0, 1, 2, 3])
         for u_type in unsigned_types:
-            round_trip(data,u_type)
+            round_trip(data, u_type)
 
     def test_pandas_bool(self, duckdb_cursor):
-        data = numpy.array([True,False,False,True])
-        round_trip(data,'bool')
-
+        data = numpy.array([True, False, False, True])
+        round_trip(data, 'bool')
+
     def test_pandas_boolean(self, duckdb_cursor):
-        data = numpy.array([True,None,pd.NA,numpy.nan,True])
-        df_in = pd.DataFrame({
-            'object': pd.Series(data, dtype='boolean'),
-        })
+        data = numpy.array([True, None, pd.NA, numpy.nan, True])
+        df_in = pd.DataFrame(
+            {
+                'object': pd.Series(data, dtype='boolean'),
+            }
+        )
         df_out = duckdb.query_df(df_in, "data", "SELECT * FROM data").df()
 
         assert df_out['object'][0] == df_in['object'][0]
@@ -89,13 +90,15 @@
         assert df_out['object'][4] == df_in['object'][4]
 
     def test_pandas_float32(self, duckdb_cursor):
-        data = numpy.array([0.1,0.32,0.78, numpy.nan])
-        df_in = pd.DataFrame({
-            'object': pd.Series(data, dtype='float32'),
-        })
+        data = numpy.array([0.1, 0.32, 0.78, numpy.nan])
+        df_in = pd.DataFrame(
+            {
+                'object': pd.Series(data, dtype='float32'),
+            }
+        )
         df_out = duckdb.query_df(df_in, "data", "SELECT * FROM data").df()
-        
+
         assert df_out['object'][0] == df_in['object'][0]
         assert df_out['object'][1] == df_in['object'][1]
         assert df_out['object'][2] == df_in['object'][2]
@@ -103,37 +106,39 @@
     def test_pandas_float64(self):
         data = numpy.array([0.233, numpy.nan, 3456.2341231, float('-inf'), -23424.45345, float('+inf'), 
0.0000000001]) - df_in = pd.DataFrame({ - 'object': pd.Series(data, dtype='float64'), - }) + df_in = pd.DataFrame( + { + 'object': pd.Series(data, dtype='float64'), + } + ) df_out = duckdb.query_df(df_in, "data", "SELECT * FROM data").df() - + for i in range(len(data)): - if (numpy.isnan(df_out['object'][i])): - assert(i == 1) + if numpy.isnan(df_out['object'][i]): + assert i == 1 continue assert df_out['object'][i] == df_in['object'][i] def test_pandas_interval(self, duckdb_cursor): - if pd. __version__ != '1.2.4': + if pd.__version__ != '1.2.4': return - - data = numpy.array([2069211000000000,numpy.datetime64("NaT")]) - df_in = pd.DataFrame({ - 'object': pd.Series(data, dtype='timedelta64[ns]'), - }) + + data = numpy.array([2069211000000000, numpy.datetime64("NaT")]) + df_in = pd.DataFrame( + { + 'object': pd.Series(data, dtype='timedelta64[ns]'), + } + ) df_out = duckdb.query_df(df_in, "data", "SELECT * FROM data").df() - + assert df_out['object'][0] == df_in['object'][0] - assert pd.isnull(df_out['object'][1]) + assert pd.isnull(df_out['object'][1]) def test_pandas_encoded_utf8(self, duckdb_cursor): - data = u'\u00c3' # Unicode data + data = u'\u00c3' # Unicode data data = [data.encode('utf8')] expected_result = data[0] df_in = pd.DataFrame({'object': pd.Series(data, dtype='object')}) result = duckdb.query_df(df_in, "data", "SELECT * FROM data").fetchone()[0] assert result == expected_result - - diff --git a/tools/pythonpkg/tests/fast/pandas/test_pandas_unregister.py b/tools/pythonpkg/tests/fast/pandas/test_pandas_unregister.py index c0a9770351e4..794e59100ad0 100644 --- a/tools/pythonpkg/tests/fast/pandas/test_pandas_unregister.py +++ b/tools/pythonpkg/tests/fast/pandas/test_pandas_unregister.py @@ -6,6 +6,7 @@ import pytest from conftest import NumpyPandas, ArrowPandas + class TestPandasUnregister(object): @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_pandas_unregister1(self, duckdb_cursor, pandas): diff --git a/tools/pythonpkg/tests/fast/pandas/test_pandas_update.py b/tools/pythonpkg/tests/fast/pandas/test_pandas_update.py index e0be4be2bbfe..d12fa72781e3 100644 --- a/tools/pythonpkg/tests/fast/pandas/test_pandas_update.py +++ b/tools/pythonpkg/tests/fast/pandas/test_pandas_update.py @@ -1,12 +1,11 @@ import duckdb import pandas as pd -class TestPandasUpdateList(object): - def test_pandas_update_list(self, duckdb_cursor): - duckdb_cursor = duckdb.connect(':memory:') - duckdb_cursor.execute('create table t (l int[])') - duckdb_cursor.execute('insert into t values ([1, 2]), ([3,4])') - duckdb_cursor.execute('update t set l = [5, 6]') - assert duckdb_cursor.execute('select * from t').fetchdf()['l'].tolist() == [[5, 6], [5, 6]] - +class TestPandasUpdateList(object): + def test_pandas_update_list(self, duckdb_cursor): + duckdb_cursor = duckdb.connect(':memory:') + duckdb_cursor.execute('create table t (l int[])') + duckdb_cursor.execute('insert into t values ([1, 2]), ([3,4])') + duckdb_cursor.execute('update t set l = [5, 6]') + assert duckdb_cursor.execute('select * from t').fetchdf()['l'].tolist() == [[5, 6], [5, 6]] diff --git a/tools/pythonpkg/tests/fast/pandas/test_parallel_pandas_scan.py b/tools/pythonpkg/tests/fast/pandas/test_parallel_pandas_scan.py index 3fa7fb954459..a9fd99b9789c 100644 --- a/tools/pythonpkg/tests/fast/pandas/test_parallel_pandas_scan.py +++ b/tools/pythonpkg/tests/fast/pandas/test_parallel_pandas_scan.py @@ -6,7 +6,8 @@ import pytest from conftest import NumpyPandas, ArrowPandas -def run_parallel_queries(main_table, 
left_join_table, expected_df, pandas, iteration_count = 5): + +def run_parallel_queries(main_table, left_join_table, expected_df, pandas, iteration_count=5): for i in range(0, iteration_count): output_df = None sql = """ @@ -33,54 +34,69 @@ def run_parallel_queries(main_table, left_join_table, expected_df, pandas, itera finally: duckdb_conn.close() + class TestParallelPandasScan(object): @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_parallel_numeric_scan(self, duckdb_cursor, pandas): main_table = pandas.DataFrame([{"join_column": 3}]) - left_join_table = pandas.DataFrame([{"join_column": 3,"other_column": 4}]) + left_join_table = pandas.DataFrame([{"join_column": 3, "other_column": 4}]) run_parallel_queries(main_table, left_join_table, left_join_table, pandas) @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_parallel_ascii_text(self, duckdb_cursor, pandas): - main_table = pandas.DataFrame([{"join_column":"text"}]) - left_join_table = pandas.DataFrame([{"join_column":"text","other_column":"more text"}]) + main_table = pandas.DataFrame([{"join_column": "text"}]) + left_join_table = pandas.DataFrame([{"join_column": "text", "other_column": "more text"}]) run_parallel_queries(main_table, left_join_table, left_join_table, pandas) @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_parallel_unicode_text(self, duckdb_cursor, pandas): - main_table = pandas.DataFrame([{"join_column":u"mühleisen"}]) - left_join_table = pandas.DataFrame([{"join_column": u"mühleisen","other_column":u"höhöhö"}]) + main_table = pandas.DataFrame([{"join_column": u"mühleisen"}]) + left_join_table = pandas.DataFrame([{"join_column": u"mühleisen", "other_column": u"höhöhö"}]) run_parallel_queries(main_table, left_join_table, left_join_table, pandas) @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_parallel_complex_unicode_text(self, duckdb_cursor, pandas): - main_table = pandas.DataFrame([{"join_column":u"鴨"}]) - left_join_table = pandas.DataFrame([{"join_column": u"鴨","other_column":u"數據庫"}]) + main_table = pandas.DataFrame([{"join_column": u"鴨"}]) + left_join_table = pandas.DataFrame([{"join_column": u"鴨", "other_column": u"數據庫"}]) run_parallel_queries(main_table, left_join_table, left_join_table, pandas) @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_parallel_emojis(self, duckdb_cursor, pandas): - main_table = pandas.DataFrame([{"join_column":u"🤦🏼‍♂️ L🤦🏼‍♂️R 🤦🏼‍♂️"}]) - left_join_table = pandas.DataFrame([{"join_column": u"🤦🏼‍♂️ L🤦🏼‍♂️R 🤦🏼‍♂️","other_column":u"🦆🍞🦆"}]) + main_table = pandas.DataFrame([{"join_column": u"🤦🏼‍♂️ L🤦🏼‍♂️R 🤦🏼‍♂️"}]) + left_join_table = pandas.DataFrame([{"join_column": u"🤦🏼‍♂️ L🤦🏼‍♂️R 🤦🏼‍♂️", "other_column": u"🦆🍞🦆"}]) run_parallel_queries(main_table, left_join_table, left_join_table, pandas) @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_parallel_numeric_object(self, duckdb_cursor, pandas): - main_table = pandas.DataFrame({ 'join_column': pandas.Series([3], dtype="Int8") }) - left_join_table = pandas.DataFrame({ 'join_column': pandas.Series([3], dtype="Int8"), 'other_column': pandas.Series([4], dtype="Int8") }) - expected_df = pandas.DataFrame({ "join_column": numpy.array([3], dtype=numpy.int8), "other_column": numpy.array([4], dtype=numpy.int8)}) + main_table = pandas.DataFrame({'join_column': pandas.Series([3], dtype="Int8")}) + left_join_table = pandas.DataFrame( + {'join_column': pandas.Series([3], dtype="Int8"), 'other_column': 
pandas.Series([4], dtype="Int8")} + ) + expected_df = pandas.DataFrame( + {"join_column": numpy.array([3], dtype=numpy.int8), "other_column": numpy.array([4], dtype=numpy.int8)} + ) run_parallel_queries(main_table, left_join_table, expected_df, pandas) @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_parallel_timestamp(self, duckdb_cursor, pandas): - main_table = pandas.DataFrame({ 'join_column': [pandas.Timestamp('20180310T11:17:54Z')] }) - left_join_table = pandas.DataFrame({ 'join_column': [pandas.Timestamp('20180310T11:17:54Z')], 'other_column': [pandas.Timestamp('20190310T11:17:54Z')] }) - expected_df = pandas.DataFrame({ "join_column": numpy.array([datetime.datetime(2018, 3, 10, 11, 17, 54)], dtype='datetime64[ns]'), "other_column": numpy.array([datetime.datetime(2019, 3, 10, 11, 17, 54)], dtype='datetime64[ns]')}) + main_table = pandas.DataFrame({'join_column': [pandas.Timestamp('20180310T11:17:54Z')]}) + left_join_table = pandas.DataFrame( + { + 'join_column': [pandas.Timestamp('20180310T11:17:54Z')], + 'other_column': [pandas.Timestamp('20190310T11:17:54Z')], + } + ) + expected_df = pandas.DataFrame( + { + "join_column": numpy.array([datetime.datetime(2018, 3, 10, 11, 17, 54)], dtype='datetime64[ns]'), + "other_column": numpy.array([datetime.datetime(2019, 3, 10, 11, 17, 54)], dtype='datetime64[ns]'), + } + ) run_parallel_queries(main_table, left_join_table, expected_df, pandas) @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) - def test_parallel_empty(self,duckdb_cursor, pandas): - df_empty = pandas.DataFrame({'A' : []}) + def test_parallel_empty(self, duckdb_cursor, pandas): + df_empty = pandas.DataFrame({'A': []}) duckdb_conn = duckdb.connect() duckdb_conn.execute("PRAGMA threads=4") duckdb_conn.execute("PRAGMA verify_parallelism") diff --git a/tools/pythonpkg/tests/fast/pandas/test_partitioned_pandas_scan.py b/tools/pythonpkg/tests/fast/pandas/test_partitioned_pandas_scan.py index 83d2ee512c03..32c5352f50e4 100644 --- a/tools/pythonpkg/tests/fast/pandas/test_partitioned_pandas_scan.py +++ b/tools/pythonpkg/tests/fast/pandas/test_partitioned_pandas_scan.py @@ -4,6 +4,7 @@ import datetime import time + class TestPartitionedPandasScan(object): def test_parallel_pandas(self, duckdb_cursor): con = duckdb.connect() diff --git a/tools/pythonpkg/tests/fast/pandas/test_progress_bar.py b/tools/pythonpkg/tests/fast/pandas/test_progress_bar.py index 2769b2e89ed6..241cedd6bfb4 100644 --- a/tools/pythonpkg/tests/fast/pandas/test_progress_bar.py +++ b/tools/pythonpkg/tests/fast/pandas/test_progress_bar.py @@ -4,8 +4,8 @@ import datetime import time -class TestProgressBarPandas(object): +class TestProgressBarPandas(object): def test_progress_pandas_single(self, duckdb_cursor): con = duckdb.connect() df = pd.DataFrame({'i': numpy.arange(10000000)}) @@ -14,11 +14,10 @@ def test_progress_pandas_single(self, duckdb_cursor): con.register('df_2', df) con.execute("PRAGMA progress_bar_time=1") con.execute("PRAGMA disable_print_progress_bar") - result = con.execute("SELECT SUM(df.i) FROM df inner join df_2 on (df.i = df_2.i)").fetchall() + result = con.execute("SELECT SUM(df.i) FROM df inner join df_2 on (df.i = df_2.i)").fetchall() assert result[0][0] == 49999995000000 - - def test_progress_pandas_parallel(self,duckdb_cursor): + def test_progress_pandas_parallel(self, duckdb_cursor): con = duckdb.connect() df = pd.DataFrame({'i': numpy.arange(10000000)}) @@ -30,11 +29,11 @@ def test_progress_pandas_parallel(self,duckdb_cursor): parallel_results = 
con.execute("SELECT SUM(df.i) FROM df inner join df_2 on (df.i = df_2.i)").fetchall() assert parallel_results[0][0] == 49999995000000 - def test_progress_pandas_empty(self,duckdb_cursor): + def test_progress_pandas_empty(self, duckdb_cursor): con = duckdb.connect() - df = pd.DataFrame({'i' : []}) + df = pd.DataFrame({'i': []}) con.register('df', df) con.execute("PRAGMA progress_bar_time=1") con.execute("PRAGMA disable_print_progress_bar") - result = con.execute("SELECT SUM(df.i) from df").fetchall() - assert result[0][0] == None \ No newline at end of file + result = con.execute("SELECT SUM(df.i) from df").fetchall() + assert result[0][0] == None diff --git a/tools/pythonpkg/tests/fast/pandas/test_pyarrow_filter_pushdown.py b/tools/pythonpkg/tests/fast/pandas/test_pyarrow_filter_pushdown.py index 3e5aa0b47711..1462a0dd4ecb 100644 --- a/tools/pythonpkg/tests/fast/pandas/test_pyarrow_filter_pushdown.py +++ b/tools/pythonpkg/tests/fast/pandas/test_pyarrow_filter_pushdown.py @@ -4,6 +4,7 @@ import tempfile from conftest import pandas_supports_arrow_backend + pa = pytest.importorskip("pyarrow") pq = pytest.importorskip("pyarrow.parquet") ds = pytest.importorskip("pyarrow.dataset") @@ -15,14 +16,15 @@ ## DuckDB connection used in this test duckdb_conn = duckdb.connect() + def numeric_operators(data_type, tbl_name): duckdb_conn.execute(f"CREATE TABLE {tbl_name} (a {data_type}, b {data_type}, c {data_type})") - duckdb_conn.execute("INSERT INTO " +tbl_name+ " VALUES (1,1,1),(10,10,10),(100,10,100),(NULL,NULL,NULL)") + duckdb_conn.execute("INSERT INTO " + tbl_name + " VALUES (1,1,1),(10,10,10),(100,10,100),(NULL,NULL,NULL)") duck_tbl = duckdb_conn.table(tbl_name) arrow_df = duck_tbl.df().convert_dtypes(dtype_backend='pyarrow') - print (arrow_df) + print(arrow_df) - duckdb_conn.register("testarrow",arrow_df) + duckdb_conn.register("testarrow", arrow_df) # Try == assert duckdb_conn.execute("SELECT count(*) from testarrow where a =1").fetchone()[0] == 1 # Try > @@ -49,40 +51,47 @@ def numeric_operators(data_type, tbl_name): duckdb_conn.execute("EXPLAIN SELECT count(*) from testarrow where a = 100 or b =1") print(duckdb_conn.fetchall()) + def numeric_check_or_pushdown(tbl_name): duck_tbl = duckdb_conn.table(tbl_name) arrow_df = duck_tbl.df().convert_dtypes(dtype_backend='pyarrow') arrow_tbl_name = "testarrow_" + tbl_name - duckdb_conn.register(arrow_tbl_name ,arrow_df) + duckdb_conn.register(arrow_tbl_name, arrow_df) # Multiple column in the root OR node, don't push down - query_res = duckdb_conn.execute("EXPLAIN SELECT * FROM " +arrow_tbl_name+ " WHERE a=1 OR b=2 AND (a>3 OR b<5)").fetchall() + query_res = duckdb_conn.execute( + "EXPLAIN SELECT * FROM " + arrow_tbl_name + " WHERE a=1 OR b=2 AND (a>3 OR b<5)" + ).fetchall() match = re.search(".*ARROW_SCAN.*Filters:.*", query_res[0][1]) assert not match # Single column in the root OR node - query_res = duckdb_conn.execute("EXPLAIN SELECT * FROM " +arrow_tbl_name+ " WHERE a=1 OR a=10").fetchall() + query_res = duckdb_conn.execute("EXPLAIN SELECT * FROM " + arrow_tbl_name + " WHERE a=1 OR a=10").fetchall() match = re.search(".*ARROW_SCAN.*Filters: a=1 OR a=10.*|$", query_res[0][1]) assert match # Single column + root OR node with AND - query_res = duckdb_conn.execute("EXPLAIN SELECT * FROM " +arrow_tbl_name+ " WHERE a=1 OR (a>3 AND a<5)").fetchall() + query_res = duckdb_conn.execute( + "EXPLAIN SELECT * FROM " + arrow_tbl_name + " WHERE a=1 OR (a>3 AND a<5)" + ).fetchall() match = re.search(".*ARROW_SCAN.*Filters: a=1 OR a>3 AND a<5.*|$", query_res[0][1]) 
assert match # Single column multiple ORs - query_res = duckdb_conn.execute("EXPLAIN SELECT * FROM " +arrow_tbl_name+ " WHERE a=1 OR a>3 OR a<5").fetchall() + query_res = duckdb_conn.execute("EXPLAIN SELECT * FROM " + arrow_tbl_name + " WHERE a=1 OR a>3 OR a<5").fetchall() match = re.search(".*ARROW_SCAN.*Filters: a=1 OR a>3 OR a<5.*|$", query_res[0][1]) assert match # Testing not equal - query_res = duckdb_conn.execute("EXPLAIN SELECT * FROM " +arrow_tbl_name+ " WHERE a!=1 OR a>3 OR a<2").fetchall() + query_res = duckdb_conn.execute("EXPLAIN SELECT * FROM " + arrow_tbl_name + " WHERE a!=1 OR a>3 OR a<2").fetchall() match = re.search(".*ARROW_SCAN.*Filters: a!=1 OR a>3 OR a<2.*|$", query_res[0][1]) assert match # Multiple OR filters connected with ANDs - query_res = duckdb_conn.execute("EXPLAIN SELECT * FROM " +arrow_tbl_name+ " WHERE (a<2 OR a>3) AND (a=1 OR a=4) AND (b=1 OR b<5)").fetchall() + query_res = duckdb_conn.execute( + "EXPLAIN SELECT * FROM " + arrow_tbl_name + " WHERE (a<2 OR a>3) AND (a=1 OR a=4) AND (b=1 OR b<5)" + ).fetchall() match = re.search(".*ARROW_SCAN.*Filters: a<2 OR a>3 AND a=1|\n.*OR a=4.*\n.*b=2 OR b<5.*|$", query_res[0][1]) assert match @@ -92,55 +101,73 @@ def string_check_or_pushdown(tbl_name): arrow_df = duck_tbl.df().convert_dtypes(dtype_backend='pyarrow') arrow_tbl_name = "testarrow_varchar" - duckdb_conn.register(arrow_tbl_name ,arrow_df) + duckdb_conn.register(arrow_tbl_name, arrow_df) # Check string zonemap - query_res = duckdb_conn.execute("EXPLAIN SELECT * FROM " +arrow_tbl_name+ " WHERE a>='1' OR a<='10'").fetchall() + query_res = duckdb_conn.execute("EXPLAIN SELECT * FROM " + arrow_tbl_name + " WHERE a>='1' OR a<='10'").fetchall() match = re.search(".*ARROW_SCAN.*Filters: a>=1 OR a<=10.*|$", query_res[0][1]) assert match # No support for OR with is null - query_res = duckdb_conn.execute("EXPLAIN SELECT * FROM " +arrow_tbl_name+ " WHERE a IS NULL or a='1'").fetchall() + query_res = duckdb_conn.execute("EXPLAIN SELECT * FROM " + arrow_tbl_name + " WHERE a IS NULL or a='1'").fetchall() match = re.search(".*ARROW_SCAN.*Filters:.*", query_res[0][1]) assert not match # No support for OR with is not null - query_res = duckdb_conn.execute("EXPLAIN SELECT * FROM " +arrow_tbl_name+ " WHERE a IS NOT NULL OR a='1'").fetchall() + query_res = duckdb_conn.execute( + "EXPLAIN SELECT * FROM " + arrow_tbl_name + " WHERE a IS NOT NULL OR a='1'" + ).fetchall() match = re.search(".*ARROW_SCAN.*Filters:.*", query_res[0][1]) assert not match # OR with the like operator - query_res = duckdb_conn.execute("EXPLAIN SELECT * FROM " +arrow_tbl_name+ " WHERE a=1 OR a LIKE '10%'").fetchall() + query_res = duckdb_conn.execute("EXPLAIN SELECT * FROM " + arrow_tbl_name + " WHERE a=1 OR a LIKE '10%'").fetchall() match = re.search(".*ARROW_SCAN.*Filters:.*", query_res[0][1]) assert not match @pytest.mark.skipif(not pandas_supports_arrow_backend(), reason="pandas does not support the 'pyarrow' backend") class TestArrowDFFilterPushdown(object): - def test_filter_pushdown_numeric(self,duckdb_cursor): - - numeric_types = ['TINYINT', 'SMALLINT', 'INTEGER', 'BIGINT', 'UTINYINT', 'USMALLINT', 'UINTEGER', 'UBIGINT', - 'FLOAT', 'DOUBLE', 'HUGEINT'] + def test_filter_pushdown_numeric(self, duckdb_cursor): + numeric_types = [ + 'TINYINT', + 'SMALLINT', + 'INTEGER', + 'BIGINT', + 'UTINYINT', + 'USMALLINT', + 'UINTEGER', + 'UBIGINT', + 'FLOAT', + 'DOUBLE', + 'HUGEINT', + ] for data_type in numeric_types: tbl_name = "test_" + data_type numeric_operators(data_type, tbl_name) 
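            # numeric_operators() above checks result counts for the comparison,
            # NULL and AND/OR predicates against a pyarrow-backed frame;
            # numeric_check_or_pushdown() below inspects EXPLAIN output to see
            # which OR filters actually reach the ARROW_SCAN.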
numeric_check_or_pushdown(tbl_name) - def test_filter_pushdown_decimal(self,duckdb_cursor): - numeric_types = {'DECIMAL(4,1)': 'test_decimal_4_1', 'DECIMAL(9,1)': 'test_decimal_9_1', - 'DECIMAL(18,4)': 'test_decimal_18_4','DECIMAL(30,12)': 'test_decimal_30_12'} + def test_filter_pushdown_decimal(self, duckdb_cursor): + numeric_types = { + 'DECIMAL(4,1)': 'test_decimal_4_1', + 'DECIMAL(9,1)': 'test_decimal_9_1', + 'DECIMAL(18,4)': 'test_decimal_18_4', + 'DECIMAL(30,12)': 'test_decimal_30_12', + } for data_type in numeric_types: tbl_name = numeric_types[data_type] numeric_operators(data_type, tbl_name) numeric_check_or_pushdown(tbl_name) - def test_filter_pushdown_varchar(self,duckdb_cursor): + def test_filter_pushdown_varchar(self, duckdb_cursor): duckdb_conn.execute("CREATE TABLE test_varchar (a VARCHAR, b VARCHAR, c VARCHAR)") - duckdb_conn.execute("INSERT INTO test_varchar VALUES ('1','1','1'),('10','10','10'),('100','10','100'),(NULL,NULL,NULL)") + duckdb_conn.execute( + "INSERT INTO test_varchar VALUES ('1','1','1'),('10','10','10'),('100','10','100'),(NULL,NULL,NULL)" + ) duck_tbl = duckdb_conn.table("test_varchar") arrow_table = duck_tbl.df().convert_dtypes(dtype_backend='pyarrow') - duckdb_conn.register("testarrow",arrow_table) + duckdb_conn.register("testarrow", arrow_table) # Try == assert duckdb_conn.execute("SELECT count(*) from testarrow where a ='1'").fetchone()[0] == 1 # Try > @@ -159,21 +186,25 @@ def test_filter_pushdown_varchar(self,duckdb_cursor): # Try And assert duckdb_conn.execute("SELECT count(*) from testarrow where a='10' and b ='1'").fetchone()[0] == 0 - assert duckdb_conn.execute("SELECT count(*) from testarrow where a ='100' and b = '10' and c = '100'").fetchone()[0] == 1 + assert ( + duckdb_conn.execute("SELECT count(*) from testarrow where a ='100' and b = '10' and c = '100'").fetchone()[ + 0 + ] + == 1 + ) # Try Or assert duckdb_conn.execute("SELECT count(*) from testarrow where a = '100' or b ='1'").fetchone()[0] == 2 # More complex tests for OR pushed down on string string_check_or_pushdown("test_varchar") - - def test_filter_pushdown_bool(self,duckdb_cursor): + def test_filter_pushdown_bool(self, duckdb_cursor): duckdb_conn.execute("CREATE TABLE test_bool (a BOOL, b BOOL)") duckdb_conn.execute("INSERT INTO test_bool VALUES (TRUE,TRUE),(TRUE,FALSE),(FALSE,TRUE),(NULL,NULL)") duck_tbl = duckdb_conn.table("test_bool") arrow_table = duck_tbl.df().convert_dtypes(dtype_backend='pyarrow') - duckdb_conn.register("testarrow",arrow_table) + duckdb_conn.register("testarrow", arrow_table) # Try == assert duckdb_conn.execute("SELECT count(*) from testarrow where a =True").fetchone()[0] == 2 @@ -187,13 +218,15 @@ def test_filter_pushdown_bool(self,duckdb_cursor): # Try Or assert duckdb_conn.execute("SELECT count(*) from testarrow where a = True or b =True").fetchone()[0] == 3 - def test_filter_pushdown_time(self,duckdb_cursor): + def test_filter_pushdown_time(self, duckdb_cursor): duckdb_conn.execute("CREATE TABLE test_time (a TIME, b TIME, c TIME)") - duckdb_conn.execute("INSERT INTO test_time VALUES ('00:01:00','00:01:00','00:01:00'),('00:10:00','00:10:00','00:10:00'),('01:00:00','00:10:00','01:00:00'),(NULL,NULL,NULL)") + duckdb_conn.execute( + "INSERT INTO test_time VALUES ('00:01:00','00:01:00','00:01:00'),('00:10:00','00:10:00','00:10:00'),('01:00:00','00:10:00','01:00:00'),(NULL,NULL,NULL)" + ) duck_tbl = duckdb_conn.table("test_time") arrow_table = duck_tbl.df().convert_dtypes(dtype_backend='pyarrow') - duckdb_conn.register("testarrow",arrow_table) + 
duckdb_conn.register("testarrow", arrow_table) # Try == assert duckdb_conn.execute("SELECT count(*) from testarrow where a ='00:01:00'").fetchone()[0] == 1 # Try > @@ -211,19 +244,32 @@ def test_filter_pushdown_time(self,duckdb_cursor): assert duckdb_conn.execute("SELECT count(*) from testarrow where a IS NOT NULL").fetchone()[0] == 3 # Try And - assert duckdb_conn.execute("SELECT count(*) from testarrow where a='00:10:00' and b ='00:01:00'").fetchone()[0] == 0 - assert duckdb_conn.execute("SELECT count(*) from testarrow where a ='01:00:00' and b = '00:10:00' and c = '01:00:00'").fetchone()[0] == 1 + assert ( + duckdb_conn.execute("SELECT count(*) from testarrow where a='00:10:00' and b ='00:01:00'").fetchone()[0] + == 0 + ) + assert ( + duckdb_conn.execute( + "SELECT count(*) from testarrow where a ='01:00:00' and b = '00:10:00' and c = '01:00:00'" + ).fetchone()[0] + == 1 + ) # Try Or - assert duckdb_conn.execute("SELECT count(*) from testarrow where a = '01:00:00' or b ='00:01:00'").fetchone()[0] == 2 + assert ( + duckdb_conn.execute("SELECT count(*) from testarrow where a = '01:00:00' or b ='00:01:00'").fetchone()[0] + == 2 + ) - def test_filter_pushdown_timestamp(self,duckdb_cursor): + def test_filter_pushdown_timestamp(self, duckdb_cursor): duckdb_conn.execute("CREATE TABLE test_timestamp (a TIMESTAMP, b TIMESTAMP, c TIMESTAMP)") - duckdb_conn.execute("INSERT INTO test_timestamp VALUES ('2008-01-01 00:00:01','2008-01-01 00:00:01','2008-01-01 00:00:01'),('2010-01-01 10:00:01','2010-01-01 10:00:01','2010-01-01 10:00:01'),('2020-03-01 10:00:01','2010-01-01 10:00:01','2020-03-01 10:00:01'),(NULL,NULL,NULL)") + duckdb_conn.execute( + "INSERT INTO test_timestamp VALUES ('2008-01-01 00:00:01','2008-01-01 00:00:01','2008-01-01 00:00:01'),('2010-01-01 10:00:01','2010-01-01 10:00:01','2010-01-01 10:00:01'),('2020-03-01 10:00:01','2010-01-01 10:00:01','2020-03-01 10:00:01'),(NULL,NULL,NULL)" + ) duck_tbl = duckdb_conn.table("test_timestamp") arrow_table = duck_tbl.df().convert_dtypes(dtype_backend='pyarrow') - print (arrow_table) + print(arrow_table) - duckdb_conn.register("testarrow",arrow_table) + duckdb_conn.register("testarrow", arrow_table) # Try == assert duckdb_conn.execute("SELECT count(*) from testarrow where a ='2008-01-01 00:00:01'").fetchone()[0] == 1 # Try > @@ -241,34 +287,55 @@ def test_filter_pushdown_timestamp(self,duckdb_cursor): assert duckdb_conn.execute("SELECT count(*) from testarrow where a IS NOT NULL").fetchone()[0] == 3 # Try And - assert duckdb_conn.execute("SELECT count(*) from testarrow where a='2010-01-01 10:00:01' and b ='2008-01-01 00:00:01'").fetchone()[0] == 0 - assert duckdb_conn.execute("SELECT count(*) from testarrow where a ='2020-03-01 10:00:01' and b = '2010-01-01 10:00:01' and c = '2020-03-01 10:00:01'").fetchone()[0] == 1 + assert ( + duckdb_conn.execute( + "SELECT count(*) from testarrow where a='2010-01-01 10:00:01' and b ='2008-01-01 00:00:01'" + ).fetchone()[0] + == 0 + ) + assert ( + duckdb_conn.execute( + "SELECT count(*) from testarrow where a ='2020-03-01 10:00:01' and b = '2010-01-01 10:00:01' and c = '2020-03-01 10:00:01'" + ).fetchone()[0] + == 1 + ) # Try Or - assert duckdb_conn.execute("SELECT count(*) from testarrow where a = '2020-03-01 10:00:01' or b ='2008-01-01 00:00:01'").fetchone()[0] == 2 - - def test_filter_pushdown_timestamp_TZ(self,duckdb_cursor): - duckdb_conn.execute(""" + assert ( + duckdb_conn.execute( + "SELECT count(*) from testarrow where a = '2020-03-01 10:00:01' or b ='2008-01-01 00:00:01'" + ).fetchone()[0] + == 2 + ) + 
+ def test_filter_pushdown_timestamp_TZ(self, duckdb_cursor): + duckdb_conn.execute( + """ CREATE TABLE test_timestamptz ( a TIMESTAMPTZ, b TIMESTAMPTZ, c TIMESTAMPTZ ) - """) - duckdb_conn.execute(""" + """ + ) + duckdb_conn.execute( + """ INSERT INTO test_timestamptz VALUES ('2008-01-01 00:00:01','2008-01-01 00:00:01','2008-01-01 00:00:01'), ('2010-01-01 10:00:01','2010-01-01 10:00:01','2010-01-01 10:00:01'), ('2020-03-01 10:00:01','2010-01-01 10:00:01','2020-03-01 10:00:01'), (NULL,NULL,NULL) - """) + """ + ) # Have to fetch as naive here, or the times will be converted into UTC and our predicates dont match - duck_tbl = duckdb_conn.sql(""" + duck_tbl = duckdb_conn.sql( + """ select a::TIMESTAMP a, b::TIMESTAMP b, c::TIMESTAMP c from test_timestamptz - """) + """ + ) arrow_table = duck_tbl.df().convert_dtypes(dtype_backend='pyarrow') - print (arrow_table) + print(arrow_table) - duckdb_conn.register("testarrow",arrow_table) + duckdb_conn.register("testarrow", arrow_table) # Try == assert duckdb_conn.execute("SELECT count(*) from testarrow where a ='2008-01-01 00:00:01'").fetchone()[0] == 1 # Try > @@ -286,19 +353,35 @@ def test_filter_pushdown_timestamp_TZ(self,duckdb_cursor): assert duckdb_conn.execute("SELECT count(*) from testarrow where a IS NOT NULL").fetchone()[0] == 3 # Try And - assert duckdb_conn.execute("SELECT count(*) from testarrow where a='2010-01-01 10:00:01' and b ='2008-01-01 00:00:01'").fetchone()[0] == 0 - assert duckdb_conn.execute("SELECT count(*) from testarrow where a ='2020-03-01 10:00:01' and b = '2010-01-01 10:00:01' and c = '2020-03-01 10:00:01'").fetchone()[0] == 1 + assert ( + duckdb_conn.execute( + "SELECT count(*) from testarrow where a='2010-01-01 10:00:01' and b ='2008-01-01 00:00:01'" + ).fetchone()[0] + == 0 + ) + assert ( + duckdb_conn.execute( + "SELECT count(*) from testarrow where a ='2020-03-01 10:00:01' and b = '2010-01-01 10:00:01' and c = '2020-03-01 10:00:01'" + ).fetchone()[0] + == 1 + ) # Try Or - assert duckdb_conn.execute("SELECT count(*) from testarrow where a = '2020-03-01 10:00:01' or b ='2008-01-01 00:00:01'").fetchone()[0] == 2 - - - def test_filter_pushdown_date(self,duckdb_cursor): + assert ( + duckdb_conn.execute( + "SELECT count(*) from testarrow where a = '2020-03-01 10:00:01' or b ='2008-01-01 00:00:01'" + ).fetchone()[0] + == 2 + ) + + def test_filter_pushdown_date(self, duckdb_cursor): duckdb_conn.execute("CREATE TABLE test_date (a DATE, b DATE, c DATE)") - duckdb_conn.execute("INSERT INTO test_date VALUES ('2000-01-01','2000-01-01','2000-01-01'),('2000-10-01','2000-10-01','2000-10-01'),('2010-01-01','2000-10-01','2010-01-01'),(NULL,NULL,NULL)") + duckdb_conn.execute( + "INSERT INTO test_date VALUES ('2000-01-01','2000-01-01','2000-01-01'),('2000-10-01','2000-10-01','2000-10-01'),('2010-01-01','2000-10-01','2010-01-01'),(NULL,NULL,NULL)" + ) duck_tbl = duckdb_conn.table("test_date") arrow_table = duck_tbl.df().convert_dtypes(dtype_backend='pyarrow') - duckdb_conn.register("testarrow",arrow_table) + duckdb_conn.register("testarrow", arrow_table) # Try == assert duckdb_conn.execute("SELECT count(*) from testarrow where a ='2000-01-01'").fetchone()[0] == 1 # Try > @@ -316,23 +399,34 @@ def test_filter_pushdown_date(self,duckdb_cursor): assert duckdb_conn.execute("SELECT count(*) from testarrow where a IS NOT NULL").fetchone()[0] == 3 # Try And - assert duckdb_conn.execute("SELECT count(*) from testarrow where a='2000-10-01' and b ='2000-01-01'").fetchone()[0] == 0 - assert duckdb_conn.execute("SELECT count(*) from testarrow where 
a ='2010-01-01' and b = '2000-10-01' and c = '2010-01-01'").fetchone()[0] == 1 + assert ( + duckdb_conn.execute("SELECT count(*) from testarrow where a='2000-10-01' and b ='2000-01-01'").fetchone()[0] + == 0 + ) + assert ( + duckdb_conn.execute( + "SELECT count(*) from testarrow where a ='2010-01-01' and b = '2000-10-01' and c = '2010-01-01'" + ).fetchone()[0] + == 1 + ) # Try Or - assert duckdb_conn.execute("SELECT count(*) from testarrow where a = '2010-01-01' or b ='2000-01-01'").fetchone()[0] == 2 - - - def test_filter_pushdown_no_projection(self,duckdb_cursor): + assert ( + duckdb_conn.execute("SELECT count(*) from testarrow where a = '2010-01-01' or b ='2000-01-01'").fetchone()[ + 0 + ] + == 2 + ) + + def test_filter_pushdown_no_projection(self, duckdb_cursor): duckdb_conn.execute("CREATE TABLE test_int (a INTEGER, b INTEGER, c INTEGER)") duckdb_conn.execute("INSERT INTO test_int VALUES (1,1,1),(10,10,10),(100,10,100),(NULL,NULL,NULL)") duck_tbl = duckdb_conn.table("test_int") arrow_table = duck_tbl.df().convert_dtypes(dtype_backend='pyarrow') - duckdb_conn.register("testarrowtable",arrow_table) + duckdb_conn.register("testarrowtable", arrow_table) assert duckdb_conn.execute("SELECT * FROM testarrowtable VALUES where a =1").fetchall() == [(1, 1, 1)] @pytest.mark.parametrize('pandas', [ArrowPandas()]) - def test_filter_pushdown_2145(self,duckdb_cursor, pandas): - + def test_filter_pushdown_2145(self, duckdb_cursor, pandas): date1 = pandas.date_range("2018-01-01", "2018-12-31", freq="B") df1 = pandas.DataFrame(np.random.randn(date1.shape[0], 5), columns=list("ABCDE")) df1["date"] = date1 @@ -347,7 +441,7 @@ def test_filter_pushdown_2145(self,duckdb_cursor, pandas): table = pq.ParquetDataset(["data1.parquet", "data2.parquet"]).read() con = duckdb.connect() - con.register("testarrow",table) + con.register("testarrow", table) output_df = duckdb.arrow(table).filter("date > '2019-01-01'").df() expected_df = duckdb.from_parquet("data*.parquet").filter("date > '2019-01-01'").df() @@ -356,11 +450,11 @@ def test_filter_pushdown_2145(self,duckdb_cursor, pandas): os.remove("data1.parquet") os.remove("data2.parquet") - def test_filter_column_removal(self,duckdb_cursor): + def test_filter_column_removal(self, duckdb_cursor): duckdb_conn.execute("CREATE TABLE test AS SELECT range i, range j FROM range(5)") duck_test_table = duckdb_conn.table("test") arrow_test_table = duck_test_table.df().convert_dtypes(dtype_backend='pyarrow') - duckdb_conn.register("arrow_test_table",arrow_test_table) + duckdb_conn.register("arrow_test_table", arrow_test_table) # PR 4817 - remove filter columns that are unused in the remainder of the query plan from the table function query_res = duckdb_conn.execute("EXPLAIN SELECT count(*) from testarrow where a = 100 or b =1").fetchall() diff --git a/tools/pythonpkg/tests/fast/pandas/test_pyarrow_projection_pushdown.py b/tools/pythonpkg/tests/fast/pandas/test_pyarrow_projection_pushdown.py index 861939a29f34..e693e75c6856 100644 --- a/tools/pythonpkg/tests/fast/pandas/test_pyarrow_projection_pushdown.py +++ b/tools/pythonpkg/tests/fast/pandas/test_pyarrow_projection_pushdown.py @@ -3,17 +3,19 @@ import pytest from conftest import pandas_supports_arrow_backend + pa = pytest.importorskip("pyarrow") ds = pytest.importorskip("pyarrow.dataset") _ = pytest.importorskip("pandas", '2.0.0') + @pytest.mark.skipif(not pandas_supports_arrow_backend(), reason="pandas does not support the 'pyarrow' backend") class TestArrowDFProjectionPushdown(object): - def 
test_projection_pushdown_no_filter(self,duckdb_cursor): + def test_projection_pushdown_no_filter(self, duckdb_cursor): duckdb_conn = duckdb.connect() duckdb_conn.execute("CREATE TABLE test (a INTEGER, b INTEGER, c INTEGER)") duckdb_conn.execute("INSERT INTO test VALUES (1,1,1),(10,10,10),(100,10,100),(NULL,NULL,NULL)") duck_tbl = duckdb_conn.table("test") arrow_table = duck_tbl.df().convert_dtypes(dtype_backend='pyarrow') - duckdb_conn.register("testarrowtable",arrow_table) + duckdb_conn.register("testarrowtable", arrow_table) assert duckdb_conn.execute("SELECT sum(a) FROM testarrowtable").fetchall() == [(111,)] diff --git a/tools/pythonpkg/tests/fast/pandas/test_same_name.py b/tools/pythonpkg/tests/fast/pandas/test_same_name.py index af0b57fdbac6..0ffbc2936a23 100644 --- a/tools/pythonpkg/tests/fast/pandas/test_same_name.py +++ b/tools/pythonpkg/tests/fast/pandas/test_same_name.py @@ -2,14 +2,19 @@ import duckdb import pandas as pd + class TestMultipleColumnsSameName(object): def test_multiple_columns_with_same_name(self, duckdb_cursor): df = pd.DataFrame({'a': [1, 2, 3, 4], 'b': [5, 6, 7, 8], 'd': [9, 10, 11, 12]}) - df = df.rename(columns={ df.columns[1]: "a" }) + df = df.rename(columns={df.columns[1]: "a"}) con = duckdb.connect() con.register('df_view', df) - assert con.execute("DESCRIBE df_view;").fetchall() == [('a', 'BIGINT', 'YES', None, None, None), ('a_1', 'BIGINT', 'YES', None, None, None), ('d', 'BIGINT', 'YES', None, None, None)] + assert con.execute("DESCRIBE df_view;").fetchall() == [ + ('a', 'BIGINT', 'YES', None, None, None), + ('a_1', 'BIGINT', 'YES', None, None, None), + ('d', 'BIGINT', 'YES', None, None, None), + ] assert con.execute("select a_1 from df_view;").fetchall() == [(5,), (6,), (7,), (8,)] assert con.execute("select a from df_view;").fetchall() == [(1,), (2,), (3,), (4,)] # Verify we are not changing original dataframe @@ -17,75 +22,85 @@ def test_multiple_columns_with_same_name(self, duckdb_cursor): def test_multiple_columns_with_same_name_relation(self, duckdb_cursor): df = pd.DataFrame({'a': [1, 2, 3, 4], 'b': [5, 6, 7, 8], 'd': [9, 10, 11, 12]}) - df = df.rename(columns={ df.columns[1]: "a" }) + df = df.rename(columns={df.columns[1]: "a"}) con = duckdb.connect() rel = con.from_df(df) - assert(rel.query("df_view","DESCRIBE df_view;").fetchall() == [('a', 'BIGINT', 'YES', None, None, None), ('a_1', 'BIGINT', 'YES', None, None, None), ('d', 'BIGINT', 'YES', None, None, None)]) + assert rel.query("df_view", "DESCRIBE df_view;").fetchall() == [ + ('a', 'BIGINT', 'YES', None, None, None), + ('a_1', 'BIGINT', 'YES', None, None, None), + ('d', 'BIGINT', 'YES', None, None, None), + ] + + assert rel.query("df_view", "select a_1 from df_view;").fetchall() == [(5,), (6,), (7,), (8,)] + assert rel.query("df_view", "select a from df_view;").fetchall() == [(1,), (2,), (3,), (4,)] - assert rel.query("df_view","select a_1 from df_view;").fetchall() == [(5,), (6,), (7,), (8,)] - assert rel.query("df_view","select a from df_view;").fetchall() == [(1,), (2,), (3,), (4,)] - # Verify we are not changing original dataframe assert all(df.columns == ['a', 'a', 'd']), df.columns def test_multiple_columns_with_same_name_replacement_scans(self, duckdb_cursor): df = pd.DataFrame({'a': [1, 2, 3, 4], 'b': [5, 6, 7, 8], 'd': [9, 10, 11, 12]}) - df = df.rename(columns={ df.columns[1]: "a" }) + df = df.rename(columns={df.columns[1]: "a"}) con = duckdb.connect() assert con.execute("select a_1 from df;").fetchall() == [(5,), (6,), (7,), (8,)] assert con.execute("select a from 
df;").fetchall() == [(1,), (2,), (3,), (4,)] # Verify we are not changing original dataframe assert all(df.columns == ['a', 'a', 'd']), df.columns - def test_3669(self, duckdb_cursor): - df = pd.DataFrame([(1, 5, 9), - (2, 6, 10), - (3, 7, 11), - (4, 8, 12)], - columns=['a_1', 'a', 'a']) + df = pd.DataFrame([(1, 5, 9), (2, 6, 10), (3, 7, 11), (4, 8, 12)], columns=['a_1', 'a', 'a']) con = duckdb.connect() con.register('df_view', df) - assert con.execute("DESCRIBE df_view;").fetchall() == [('a_1', 'BIGINT', 'YES', None, None, None), ('a', 'BIGINT', 'YES', None, None, None), ('a_2', 'BIGINT', 'YES', None, None, None)] - assert con.execute("select a_1 from df_view;").fetchall() == [(1,), (2,), (3,), (4,)] + assert con.execute("DESCRIBE df_view;").fetchall() == [ + ('a_1', 'BIGINT', 'YES', None, None, None), + ('a', 'BIGINT', 'YES', None, None, None), + ('a_2', 'BIGINT', 'YES', None, None, None), + ] + assert con.execute("select a_1 from df_view;").fetchall() == [(1,), (2,), (3,), (4,)] assert con.execute("select a from df_view;").fetchall() == [(5,), (6,), (7,), (8,)] # Verify we are not changing original dataframe assert all(df.columns == ['a_1', 'a', 'a']), df.columns def test_minimally_rename(self, duckdb_cursor): - df = pd.DataFrame([(1, 5, 9, 13), - (2, 6, 10, 14), - (3, 7, 11, 15), - (4, 8, 12, 16)], - columns=['a_1', 'a', 'a', 'a_2']) + df = pd.DataFrame( + [(1, 5, 9, 13), (2, 6, 10, 14), (3, 7, 11, 15), (4, 8, 12, 16)], columns=['a_1', 'a', 'a', 'a_2'] + ) con = duckdb.connect() con.register('df_view', df) - assert con.execute("DESCRIBE df_view;").fetchall() == [('a_1', 'BIGINT', 'YES', None, None, None), - ('a', 'BIGINT', 'YES', None, None, None), - ('a_3', 'BIGINT', 'YES', None, None, None), - ('a_2', 'BIGINT', 'YES', None, None, None)] - assert con.execute("select a_1 from df_view;").fetchall() == [(1,), (2,), (3,), (4,)] - assert con.execute("select a from df_view;").fetchall() == [(5,), (6,), (7,), (8,)] - assert con.execute("select a_3 from df_view;").fetchall() == [(9,), (10,), (11,), (12,)] + assert con.execute("DESCRIBE df_view;").fetchall() == [ + ('a_1', 'BIGINT', 'YES', None, None, None), + ('a', 'BIGINT', 'YES', None, None, None), + ('a_3', 'BIGINT', 'YES', None, None, None), + ('a_2', 'BIGINT', 'YES', None, None, None), + ] + assert con.execute("select a_1 from df_view;").fetchall() == [(1,), (2,), (3,), (4,)] + assert con.execute("select a from df_view;").fetchall() == [(5,), (6,), (7,), (8,)] + assert con.execute("select a_3 from df_view;").fetchall() == [(9,), (10,), (11,), (12,)] assert con.execute("select a_2 from df_view;").fetchall() == [(13,), (14,), (15,), (16,)] # Verify we are not changing original dataframe assert all(df.columns == ['a_1', 'a', 'a', 'a_2']), df.columns def test_multiple_columns_with_same_name_2(self, duckdb_cursor): df = pd.DataFrame({'a': [1, 2, 3, 4], 'b': [5, 6, 7, 8], 'a_1': [9, 10, 11, 12]}) - df = df.rename(columns={ df.columns[1]: "a_1" }) + df = df.rename(columns={df.columns[1]: "a_1"}) con = duckdb.connect() con.register('df_view', df) - assert con.execute("DESCRIBE df_view;").fetchall() == [('a', 'BIGINT', 'YES', None, None, None), ('a_1', 'BIGINT', 'YES', None, None, None), ('a_1_1', 'BIGINT', 'YES', None, None, None)] + assert con.execute("DESCRIBE df_view;").fetchall() == [ + ('a', 'BIGINT', 'YES', None, None, None), + ('a_1', 'BIGINT', 'YES', None, None, None), + ('a_1_1', 'BIGINT', 'YES', None, None, None), + ] assert con.execute("select a_1 from df_view;").fetchall() == [(5,), (6,), (7,), (8,)] assert con.execute("select a from 
df_view;").fetchall() == [(1,), (2,), (3,), (4,)] assert con.execute("select a_1_1 from df_view;").fetchall() == [(9,), (10,), (11,), (12,)] def test_case_insensitive(self, duckdb_cursor): - df = pd.DataFrame({'A_1': [1, 2, 3, 4], 'a_1': [9, 10, 11, 12]}) + df = pd.DataFrame({'A_1': [1, 2, 3, 4], 'a_1': [9, 10, 11, 12]}) con = duckdb.connect() con.register('df_view', df) - assert con.execute("DESCRIBE df_view;").fetchall() == [('A_1', 'BIGINT', 'YES', None, None, None), ('a_1_1', 'BIGINT', 'YES', None, None, None)] + assert con.execute("DESCRIBE df_view;").fetchall() == [ + ('A_1', 'BIGINT', 'YES', None, None, None), + ('a_1_1', 'BIGINT', 'YES', None, None, None), + ] assert con.execute("select a_1 from df_view;").fetchall() == [(1,), (2,), (3,), (4,)] assert con.execute("select a_1_1 from df_view;").fetchall() == [(9,), (10,), (11,), (12,)] diff --git a/tools/pythonpkg/tests/fast/pandas/test_stride.py b/tools/pythonpkg/tests/fast/pandas/test_stride.py index 08107cfe2040..e8968ee793e1 100644 --- a/tools/pythonpkg/tests/fast/pandas/test_stride.py +++ b/tools/pythonpkg/tests/fast/pandas/test_stride.py @@ -2,8 +2,8 @@ import duckdb import numpy as np -class TestPandasStride(object): +class TestPandasStride(object): def test_stride(self, duckdb_cursor): expected_df = pd.DataFrame(np.arange(20).reshape(5, 4), columns=["a", "b", "c", "d"]) con = duckdb.connect() @@ -17,7 +17,7 @@ def test_stride_fp32(self, duckdb_cursor): con.register('df_view', expected_df) output_df = con.execute("SELECT * FROM df_view;").fetchdf() for col in output_df.columns: - assert(str(output_df[col].dtype) == 'float32') + assert str(output_df[col].dtype) == 'float32' pd.testing.assert_frame_equal(expected_df, output_df) def test_stride_fp64(self, duckdb_cursor): @@ -26,5 +26,5 @@ def test_stride_fp64(self, duckdb_cursor): con.register('df_view', expected_df) output_df = con.execute("SELECT * FROM df_view;").fetchdf() for col in output_df.columns: - assert(str(output_df[col].dtype) == 'float64') + assert str(output_df[col].dtype) == 'float64' pd.testing.assert_frame_equal(expected_df, output_df) diff --git a/tools/pythonpkg/tests/fast/pandas/test_timedelta.py b/tools/pythonpkg/tests/fast/pandas/test_timedelta.py index df1f68ffda3e..4737cfbe9b5f 100644 --- a/tools/pythonpkg/tests/fast/pandas/test_timedelta.py +++ b/tools/pythonpkg/tests/fast/pandas/test_timedelta.py @@ -6,30 +6,29 @@ class TestTimedelta(object): - def test_timedelta_positive(self, duckdb_cursor): - duckdb_interval = duckdb.query("SELECT '2290-01-01 23:59:00'::TIMESTAMP - '2000-01-01 23:59:00'::TIMESTAMP AS '0'").df() + duckdb_interval = duckdb.query( + "SELECT '2290-01-01 23:59:00'::TIMESTAMP - '2000-01-01 23:59:00'::TIMESTAMP AS '0'" + ).df() data = [datetime.timedelta(microseconds=9151574400000000)] - df_in = pd.DataFrame( - {0: pd.Series(data=data, dtype='object')} - ) + df_in = pd.DataFrame({0: pd.Series(data=data, dtype='object')}) df_out = duckdb.query_df(df_in, "df", "select * from df").df() pd.testing.assert_frame_equal(df_out, duckdb_interval) def test_timedelta_coverage(self, duckdb_cursor): - duckdb_interval = duckdb.query("SELECT '2290-08-30 23:53:40'::TIMESTAMP - '2000-02-01 01:56:00'::TIMESTAMP AS '0'").df() + duckdb_interval = duckdb.query( + "SELECT '2290-08-30 23:53:40'::TIMESTAMP - '2000-02-01 01:56:00'::TIMESTAMP AS '0'" + ).df() data = [datetime.timedelta(microseconds=9169797460000000)] - df_in = pd.DataFrame( - {0: pd.Series(data=data, dtype='object')} - ) + df_in = pd.DataFrame({0: pd.Series(data=data, dtype='object')}) df_out = 
duckdb.query_df(df_in, "df", "select * from df").df() pd.testing.assert_frame_equal(df_out, duckdb_interval) def test_timedelta_negative(self, duckdb_cursor): - duckdb_interval = duckdb.query("SELECT '2000-01-01 23:59:00'::TIMESTAMP - '2290-01-01 23:59:00'::TIMESTAMP AS '0'").df() + duckdb_interval = duckdb.query( + "SELECT '2000-01-01 23:59:00'::TIMESTAMP - '2290-01-01 23:59:00'::TIMESTAMP AS '0'" + ).df() data = [datetime.timedelta(microseconds=-9151574400000000)] - df_in = pd.DataFrame( - {0: pd.Series(data=data, dtype='object')} - ) + df_in = pd.DataFrame({0: pd.Series(data=data, dtype='object')}) df_out = duckdb.query_df(df_in, "df", "select * from df").df() pd.testing.assert_frame_equal(df_out, duckdb_interval) diff --git a/tools/pythonpkg/tests/fast/pandas/test_timestamp.py b/tools/pythonpkg/tests/fast/pandas/test_timestamp.py index f37d8d835f04..72b1bb4fdff0 100644 --- a/tools/pythonpkg/tests/fast/pandas/test_timestamp.py +++ b/tools/pythonpkg/tests/fast/pandas/test_timestamp.py @@ -4,20 +4,38 @@ import pytest import pandas as pd + class TestPandasTimestamps(object): def test_timestamp_types_roundtrip(self, duckdb_cursor): - d = {'a': [pd.Timestamp(datetime.datetime.now(), unit='s')], 'b': [pd.Timestamp(datetime.datetime.now(), unit='ms')], 'c': [pd.Timestamp(datetime.datetime.now(), unit='us')], 'd': [pd.Timestamp(datetime.datetime.now(), unit='ns')]} + d = { + 'a': [pd.Timestamp(datetime.datetime.now(), unit='s')], + 'b': [pd.Timestamp(datetime.datetime.now(), unit='ms')], + 'c': [pd.Timestamp(datetime.datetime.now(), unit='us')], + 'd': [pd.Timestamp(datetime.datetime.now(), unit='ns')], + } df = pd.DataFrame(data=d) df_from_duck = duckdb.from_df(df).df() - assert(df_from_duck.equals(df)) + assert df_from_duck.equals(df) def test_timestamp_nulls(self, duckdb_cursor): - d = {'a': [pd.Timestamp(None, unit='s')], 'b': [pd.Timestamp(None, unit='ms')], 'c': [pd.Timestamp(None, unit='us')], 'd': [pd.Timestamp(None, unit='ns')]} + d = { + 'a': [pd.Timestamp(None, unit='s')], + 'b': [pd.Timestamp(None, unit='ms')], + 'c': [pd.Timestamp(None, unit='us')], + 'd': [pd.Timestamp(None, unit='ns')], + } df = pd.DataFrame(data=d) df_from_duck = duckdb.from_df(df).df() - assert (df_from_duck.equals(df)) + assert df_from_duck.equals(df) def test_timestamp_timedelta(self, duckdb_cursor): - df = pd.DataFrame({'a': [pd.Timedelta(1, unit='s')], 'b': [pd.Timedelta(None, unit='s')], 'c': [pd.Timedelta(1, unit='us')] , 'd': [pd.Timedelta(1, unit='ms')]}) + df = pd.DataFrame( + { + 'a': [pd.Timedelta(1, unit='s')], + 'b': [pd.Timedelta(None, unit='s')], + 'c': [pd.Timedelta(1, unit='us')], + 'd': [pd.Timedelta(1, unit='ms')], + } + ) df_from_duck = duckdb.from_df(df).df() - assert (df_from_duck.equals(df)) \ No newline at end of file + assert df_from_duck.equals(df) diff --git a/tools/pythonpkg/tests/fast/relational_api/test_rapi_aggregations.py b/tools/pythonpkg/tests/fast/relational_api/test_rapi_aggregations.py index cbf95f16e2ca..f65d36e437d2 100644 --- a/tools/pythonpkg/tests/fast/relational_api/test_rapi_aggregations.py +++ b/tools/pythonpkg/tests/fast/relational_api/test_rapi_aggregations.py @@ -10,10 +10,12 @@ def setup_and_teardown_of_table(duckdb_cursor): yield duckdb_cursor.execute('drop table bla') + @pytest.fixture() def table(duckdb_cursor): return duckdb_cursor.table('bla') + def munge(cell): try: cell = round(float(cell), 2) @@ -21,25 +23,27 @@ def munge(cell): cell = str(cell) return cell + def munge_compare(left_list, right_list): assert len(left_list) == len(right_list) - for i in 
range (len(left_list)): + for i in range(len(left_list)): tpl_left = left_list[i] tpl_right = right_list[i] assert len(tpl_left) == len(tpl_right) - for j in range (len(tpl_left)): + for j in range(len(tpl_left)): left_cell = munge(tpl_left[j]) right_cell = munge(tpl_right[j]) assert left_cell == right_cell -def aggregation_generic(aggregation_function,assertion_answers): - assert len(assertion_answers) >=2 - # Check single column + +def aggregation_generic(aggregation_function, assertion_answers): + assert len(assertion_answers) >= 2 + # Check single column print(aggregation_function('i').execute().fetchall()) munge_compare(aggregation_function('i').execute().fetchall(), assertion_answers[0]) # Check multi column - print(aggregation_function('i,j').execute().fetchall() ) + print(aggregation_function('i,j').execute().fetchall()) munge_compare(aggregation_function('i,j').execute().fetchall(), assertion_answers[1]) if len(assertion_answers) < 3: @@ -47,44 +51,45 @@ def aggregation_generic(aggregation_function,assertion_answers): with pytest.raises(duckdb.BinderException, match='No function matches the given name'): aggregation_function('k').execute().fetchall() else: - print (aggregation_function('k').execute().fetchall()) - munge_compare( aggregation_function('k').execute().fetchall(), assertion_answers[2]) + print(aggregation_function('k').execute().fetchall()) + munge_compare(aggregation_function('k').execute().fetchall(), assertion_answers[2]) # Check empty with pytest.raises(TypeError, match='incompatible function arguments'): aggregation_function().execute().fetchall() # Check Null with pytest.raises(TypeError, match='incompatible function arguments'): aggregation_function(None).execute().fetchall() - + # Check broken with pytest.raises(duckdb.BinderException, match='Referenced column "nonexistant" not found'): aggregation_function('nonexistant').execute().fetchall() + class TestRAPIAggregations(object): def test_sum(self, table): - aggregation_generic(table.sum,[[(3,)], [(3, Decimal('5.30'))]]) + aggregation_generic(table.sum, [[(3,)], [(3, Decimal('5.30'))]]) def test_count(self, table): - aggregation_generic(table.count,[[(2,)], [(2,2)], [(2,)]]) + aggregation_generic(table.count, [[(2,)], [(2, 2)], [(2,)]]) def test_median(self, table): # is this supposed to accept strings? 
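         # Assumption based on the fixture data: column k holds only the two non-NULL
         # strings 'a' and 'b', and a VARCHAR median cannot be interpolated, so the lower
         # middle value under string ordering is returned, hence the [('a',)] answer below.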
-        aggregation_generic(table.median,[[(1.5,)], [(1.5, Decimal('2.10'))], [('a',)]])
+        aggregation_generic(table.median, [[(1.5,)], [(1.5, Decimal('2.10'))], [('a',)]])

     def test_min(self, table):
-        aggregation_generic(table.min,[[(1,)], [(1, Decimal('2.10'))], [('a',)]])
+        aggregation_generic(table.min, [[(1,)], [(1, Decimal('2.10'))], [('a',)]])

     def test_max(self, table):
-        aggregation_generic(table.max,[[(2,)], [(2, Decimal('3.2'))], [('b',)]])
+        aggregation_generic(table.max, [[(2,)], [(2, Decimal('3.2'))], [('b',)]])

     def test_mean(self, table):
-        aggregation_generic(table.mean,[[(1.5,)], [(1.5, 2.65)]])
+        aggregation_generic(table.mean, [[(1.5,)], [(1.5, 2.65)]])

     def test_var(self, table):
-        aggregation_generic(table.var,[[(0.25,)], [(0.25, 0.30249999999999994)]])
+        aggregation_generic(table.var, [[(0.25,)], [(0.25, 0.30249999999999994)]])

     def test_std(self, table):
-        aggregation_generic(table.std,[[(0.5,)], [(0.5, 0.5499999999999999)]])
+        aggregation_generic(table.std, [[(0.5,)], [(0.5, 0.5499999999999999)]])

     def test_apply(self, table):
-        table.apply('sum', 'i').execute().fetchone() == (3,)
+        assert table.apply('sum', 'i').execute().fetchone() == (3,)

@@ -93,12 +98,12 @@ def test_quantile(self, table):
         extra_param = '0.5'
         aggregation_function = table.quantile
         # Check single column
-        assert aggregation_function(extra_param,'i').execute().fetchone() == (1,)
+        assert aggregation_function(extra_param, 'i').execute().fetchone() == (1,)

         # Check multi column
-        assert aggregation_function(extra_param,'i,j').execute().fetchone() == (1, Decimal('2.10'))
+        assert aggregation_function(extra_param, 'i,j').execute().fetchone() == (1, Decimal('2.10'))

-        assert aggregation_function(extra_param,'k').execute().fetchone() == ('a',)
+        assert aggregation_function(extra_param, 'k').execute().fetchone() == ('a',)

         # Check empty
         with pytest.raises(TypeError, match='incompatible function arguments'):
@@ -106,14 +111,14 @@ def test_quantile(self, table):
         # Check Null
         with pytest.raises(TypeError, match='incompatible function arguments'):
             aggregation_function(None).execute().fetchone()
-
+
         # Check broken
         with pytest.raises(TypeError, match='incompatible function arguments.'):
             aggregation_function('bla').execute().fetchone()

     def test_value_counts(self, duckdb_cursor, table):
         duckdb_cursor.execute("insert into bla values (1,2.1,'a'), (NULL, NULL, NULL)")
-        munge_compare(table.value_counts('i').execute().fetchall(),[(None, 0), (1, 2), (2, 1)])
+        munge_compare(table.value_counts('i').execute().fetchall(), [(None, 0), (1, 2), (2, 1)])

         with pytest.raises(duckdb.InvalidInputException, match='Only one column is accepted'):
             table.value_counts('i,j').execute().fetchall()
@@ -125,24 +130,34 @@ def test_shape(self, table):
         assert table.shape == (3, 3)

     def test_unique(self, table):
-        aggregation_generic(table.unique,[[(1,), (2,), (None,)], [(1, Decimal('2.10')), (2, Decimal('3.20')), (None, None)],[('a',), ('b',), (None,)]])
+        aggregation_generic(
+            table.unique,
+            [
+                [(1,), (2,), (None,)],
+                [(1, Decimal('2.10')), (2, Decimal('3.20')), (None, None)],
+                [('a',), ('b',), (None,)],
+            ],
+        )

     def test_mad(self, table):
-        aggregation_generic(table.mad,[[(0.5,)], [(0.5, Decimal('0.55'))]])
+        aggregation_generic(table.mad, [[(0.5,)], [(0.5, Decimal('0.55'))]])

     def test_mode(self, table):
-        aggregation_generic(table.mode,[[(1,)], [(1, Decimal('2.10'))],[('a',)]])
+        aggregation_generic(table.mode, [[(1,)], [(1, Decimal('2.10'))], [('a',)]])

     def test_abs(self, table):
-        aggregation_generic(table.abs,[[(1,), (2,), (None,)], [(1, Decimal('2.10')), (2, Decimal('3.20')), (None, None)]])
+
aggregation_generic( + table.abs, [[(1,), (2,), (None,)], [(1, Decimal('2.10')), (2, Decimal('3.20')), (None, None)]] + ) def test_prod(self, table): - aggregation_generic(table.prod,[[(2.0,)], [(2.0, 6.720000000000001)]]) + aggregation_generic(table.prod, [[(2.0,)], [(2.0, 6.720000000000001)]]) def test_skew(self, duckdb_cursor, table): - aggregation_generic(table.skew,[[(None,)], [(None, None)]]) + aggregation_generic(table.skew, [[(None,)], [(None, None)]]) duckdb_cursor.execute("create table aggr(k int, v decimal(10,2), v2 decimal(10, 2));") - duckdb_cursor.execute("""insert into aggr values + duckdb_cursor.execute( + """insert into aggr values (1, 10, null), (2, 10, 11), (2, 10, 15), @@ -153,16 +168,20 @@ def test_skew(self, duckdb_cursor, table): (2, 30, 35), (2, 30, 40), (2, 30, 50), - (2, 30, 51);""") + (2, 30, 51);""" + ) rel = duckdb_cursor.table('aggr') - munge_compare(rel.skew('k,v,v2').execute().fetchall(),[(-3.316624790355393, -0.16344366935199223, 0.3654008511025841)]) + munge_compare( + rel.skew('k,v,v2').execute().fetchall(), [(-3.316624790355393, -0.16344366935199223, 0.3654008511025841)] + ) duckdb_cursor.execute("drop table aggr") def test_kurt(self, duckdb_cursor, table): - aggregation_generic(table.kurt,[[(None,)], [(None, None)]]) + aggregation_generic(table.kurt, [[(None,)], [(None, None)]]) duckdb_cursor.execute("create table aggr(k int, v decimal(10,2), v2 decimal(10, 2));") - duckdb_cursor.execute("""insert into aggr values + duckdb_cursor.execute( + """insert into aggr values (1, 10, null), (2, 10, 11), (2, 10, 15), @@ -173,25 +192,46 @@ def test_kurt(self, duckdb_cursor, table): (2, 30, 35), (2, 30, 40), (2, 30, 50), - (2, 30, 51);""") + (2, 30, 51);""" + ) rel = duckdb_cursor.table('aggr') - munge_compare(rel.kurt('k,v,v2').execute().fetchall(),[(10.99999999999836, -1.9614277138467147, -1.445119691585509)]) + munge_compare( + rel.kurt('k,v,v2').execute().fetchall(), [(10.99999999999836, -1.9614277138467147, -1.445119691585509)] + ) duckdb_cursor.execute("drop table aggr") def test_cum_sum(self, table): - aggregation_generic(table.cumsum,[[(1,), (3,), (3,)], [(1, Decimal('2.10')), (3, Decimal('5.30')), (3, Decimal('5.30'))]]) + aggregation_generic( + table.cumsum, [[(1,), (3,), (3,)], [(1, Decimal('2.10')), (3, Decimal('5.30')), (3, Decimal('5.30'))]] + ) def test_cum_prod(self, table): - aggregation_generic(table.cumprod,[[(1.0,), (2.0,), (2.0,)], [(1.0, 2.1), (2.0, 6.720000000000001), (2.0, 6.720000000000001)]]) + aggregation_generic( + table.cumprod, [[(1.0,), (2.0,), (2.0,)], [(1.0, 2.1), (2.0, 6.720000000000001), (2.0, 6.720000000000001)]] + ) def test_cum_max(self, table): - aggregation_generic(table.cummax,[[(1,), (2,), (2,)], [(1, Decimal('2.10')), (2, Decimal('3.20')), (2, Decimal('3.20'))], [('a',), ('b',), ('b',)]]) + aggregation_generic( + table.cummax, + [ + [(1,), (2,), (2,)], + [(1, Decimal('2.10')), (2, Decimal('3.20')), (2, Decimal('3.20'))], + [('a',), ('b',), ('b',)], + ], + ) def test_cum_min(self, table): - aggregation_generic(table.cummin,[[(1,), (1,), (1,)], [(1, Decimal('2.10')), (1, Decimal('2.10')), (1, Decimal('2.10'))], [('a',), ('a',), ('a',)]]) + aggregation_generic( + table.cummin, + [ + [(1,), (1,), (1,)], + [(1, Decimal('2.10')), (1, Decimal('2.10')), (1, Decimal('2.10'))], + [('a',), ('a',), ('a',)], + ], + ) def test_cum_sem(self, table): - aggregation_generic(table.sem,[[(0.35355339059327373,)], [(0.35355339059327373, 0.38890872965260104)]]) + aggregation_generic(table.sem, [[(0.35355339059327373,)], [(0.35355339059327373, 
0.38890872965260104)]]) def test_describe(self, table): assert table.describe().fetchall() is not None diff --git a/tools/pythonpkg/tests/fast/relational_api/test_rapi_close.py b/tools/pythonpkg/tests/fast/relational_api/test_rapi_close.py index 5183bffa31b1..97294e737418 100644 --- a/tools/pythonpkg/tests/fast/relational_api/test_rapi_close.py +++ b/tools/pythonpkg/tests/fast/relational_api/test_rapi_close.py @@ -1,136 +1,137 @@ import duckdb import pytest + # A closed connection should invalidate all relation's methods class TestRAPICloseConnRel(object): - def test_close_conn_rel(self, duckdb_cursor): - con = duckdb.connect() - con.execute("CREATE TABLE items(item VARCHAR, value DECIMAL(10,2), count INTEGER)") - con.execute("INSERT INTO items VALUES ('jeans', 20.0, 1), ('hammer', 42.2, 2)") - rel = con.table("items") - con.close() - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): - print(rel) - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): - len(rel) - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): - rel.filter("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): - rel.project("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): - rel.order("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): - rel.aggregate("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): - rel.sum("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): - rel.count("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): - rel.median("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): - rel.quantile("","") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): - rel.apply("","") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): - rel.min("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): - rel.max("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): - rel.mean("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): - rel.var("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): - rel.std("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): - rel.value_counts("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): - rel.unique("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): - rel.union(rel) - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): - rel.except_(rel) - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): - rel.intersect(rel) - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): - rel.join(rel.set_alias('other'), "a") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): - rel.distinct() - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): - print(rel.limit(1)) - with 
pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): - rel.query("","") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): - rel.execute() - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): - rel.write_csv("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): - rel.insert_into("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): - rel.insert("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): - rel.create("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): - rel.create_view("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): - rel.to_arrow_table() - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): - rel.arrow() - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): - rel.to_df() - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): - rel.df() - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): - rel.fetchone() - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): - rel.fetchall() - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): - rel.map(lambda df : df['col0'].add(42).to_frame()) - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): - rel.mad("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): - rel.mode("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): - rel.abs("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): - rel.prod("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): - rel.skew("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): - rel.kurt("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): - rel.sem("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): - rel.cumsum("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): - rel.cumprod("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): - rel.cummax("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): - rel.cummin("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): - rel.describe() - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): - rel.fetchnumpy() - con = duckdb.connect() - con.execute("CREATE TABLE items(item VARCHAR, value DECIMAL(10,2), count INTEGER)") - con.execute("INSERT INTO items VALUES ('jeans', 20.0, 1), ('hammer', 42.2, 2)") - valid_rel = con.table("items") + def test_close_conn_rel(self, duckdb_cursor): + con = duckdb.connect() + con.execute("CREATE TABLE items(item VARCHAR, value DECIMAL(10,2), count INTEGER)") + con.execute("INSERT INTO items VALUES ('jeans', 20.0, 1), ('hammer', 42.2, 2)") + rel = con.table("items") + con.close() + with 
pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + print(rel) + with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + len(rel) + with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + rel.filter("") + with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + rel.project("") + with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + rel.order("") + with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + rel.aggregate("") + with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + rel.sum("") + with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + rel.count("") + with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + rel.median("") + with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + rel.quantile("", "") + with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + rel.apply("", "") + with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + rel.min("") + with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + rel.max("") + with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + rel.mean("") + with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + rel.var("") + with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + rel.std("") + with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + rel.value_counts("") + with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + rel.unique("") + with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + rel.union(rel) + with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + rel.except_(rel) + with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + rel.intersect(rel) + with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + rel.join(rel.set_alias('other'), "a") + with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + rel.distinct() + with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + print(rel.limit(1)) + with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + rel.query("", "") + with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + rel.execute() + with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + rel.write_csv("") + with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + rel.insert_into("") + with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + rel.insert("") + with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + rel.create("") + with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + rel.create_view("") + with pytest.raises(duckdb.ConnectionException, match='Connection has already 
been closed'): + rel.to_arrow_table() + with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + rel.arrow() + with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + rel.to_df() + with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + rel.df() + with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + rel.fetchone() + with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + rel.fetchall() + with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + rel.map(lambda df: df['col0'].add(42).to_frame()) + with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + rel.mad("") + with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + rel.mode("") + with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + rel.abs("") + with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + rel.prod("") + with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + rel.skew("") + with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + rel.kurt("") + with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + rel.sem("") + with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + rel.cumsum("") + with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + rel.cumprod("") + with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + rel.cummax("") + with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + rel.cummin("") + with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + rel.describe() + with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + rel.fetchnumpy() + con = duckdb.connect() + con.execute("CREATE TABLE items(item VARCHAR, value DECIMAL(10,2), count INTEGER)") + con.execute("INSERT INTO items VALUES ('jeans', 20.0, 1), ('hammer', 42.2, 2)") + valid_rel = con.table("items") - # Test these bad boys when left relation is valid - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): - valid_rel.union(rel) - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): - valid_rel.except_(rel) - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): - valid_rel.intersect(rel) - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): - valid_rel.join(rel.set_alias('rel'), "rel.items = valid_rel.items") + # Test these bad boys when left relation is valid + with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + valid_rel.union(rel) + with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + valid_rel.except_(rel) + with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + valid_rel.intersect(rel) + with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + valid_rel.join(rel.set_alias('rel'), "rel.items = valid_rel.items") - def test_del_conn(self, 
duckdb_cursor): - con = duckdb.connect() - con.execute("CREATE TABLE items(item VARCHAR, value DECIMAL(10,2), count INTEGER)") - con.execute("INSERT INTO items VALUES ('jeans', 20.0, 1), ('hammer', 42.2, 2)") - rel = con.table("items") - del con - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): - print(rel) + def test_del_conn(self, duckdb_cursor): + con = duckdb.connect() + con.execute("CREATE TABLE items(item VARCHAR, value DECIMAL(10,2), count INTEGER)") + con.execute("INSERT INTO items VALUES ('jeans', 20.0, 1), ('hammer', 42.2, 2)") + rel = con.table("items") + del con + with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + print(rel) diff --git a/tools/pythonpkg/tests/fast/relational_api/test_rapi_description.py b/tools/pythonpkg/tests/fast/relational_api/test_rapi_description.py index 91899038be63..395ff2e0c09a 100644 --- a/tools/pythonpkg/tests/fast/relational_api/test_rapi_description.py +++ b/tools/pythonpkg/tests/fast/relational_api/test_rapi_description.py @@ -1,6 +1,7 @@ import duckdb import pytest + class TestRAPIDescription(object): def test_rapi_description(self): res = duckdb.query('select 42::INT AS a, 84::BIGINT AS b') @@ -20,7 +21,9 @@ def test_rapi_describe(self): np.testing.assert_array_equal(duck_describe['b'], [1, 84, float('nan'), 84, 84, 84]) # now with more values - res = duckdb.query('select CASE WHEN i%2=0 THEN i ELSE NULL END AS i, i * 10 AS j, (i * 23 // 27)::DOUBLE AS k FROM range(10000) t(i)') + res = duckdb.query( + 'select CASE WHEN i%2=0 THEN i ELSE NULL END AS i, i * 10 AS j, (i * 23 // 27)::DOUBLE AS k FROM range(10000) t(i)' + ) duck_describe = res.describe().df() np.testing.assert_allclose(duck_describe['i'], [5000.0, 4999.0, 2887.0400066504103, 0.0, 9998.0, 4999.0]) np.testing.assert_allclose(duck_describe['j'], [10000.0, 49995.0, 28868.956799071675, 0.0, 99990.0, 49995.0]) diff --git a/tools/pythonpkg/tests/fast/relational_api/test_rapi_functions.py b/tools/pythonpkg/tests/fast/relational_api/test_rapi_functions.py index 915b0b09dbf8..bb0dc4781b54 100644 --- a/tools/pythonpkg/tests/fast/relational_api/test_rapi_functions.py +++ b/tools/pythonpkg/tests/fast/relational_api/test_rapi_functions.py @@ -1,5 +1,6 @@ import duckdb + class TestRAPIFunctions(object): def test_rapi_str_print(self): res = duckdb.query('select 42::INT AS a, 84::BIGINT AS b') diff --git a/tools/pythonpkg/tests/fast/relational_api/test_rapi_query.py b/tools/pythonpkg/tests/fast/relational_api/test_rapi_query.py index 02ea1d234663..5ad632d74bf3 100644 --- a/tools/pythonpkg/tests/fast/relational_api/test_rapi_query.py +++ b/tools/pythonpkg/tests/fast/relational_api/test_rapi_query.py @@ -1,6 +1,7 @@ import duckdb import pytest + @pytest.fixture() def tbl_table(): con = duckdb.default_connection @@ -9,8 +10,8 @@ def tbl_table(): yield con.execute('drop table tbl') -class TestRAPIQuery(object): +class TestRAPIQuery(object): @pytest.mark.parametrize('steps', [1, 2, 3, 4]) def test_query_chain(self, steps): con = duckdb.default_connection @@ -21,9 +22,9 @@ def test_query_chain(self, steps): amount = amount / 10 rel = rel.query("rel", f"select * from rel limit {amount}") result = rel.execute() - assert(len(result.fetchall()) == amount) + assert len(result.fetchall()) == amount - @pytest.mark.parametrize('input', [[5,4,3],[], [1000]]) + @pytest.mark.parametrize('input', [[5, 4, 3], [], [1000]]) def test_query_table(self, tbl_table, input): con = duckdb.default_connection rel = con.table("tbl") @@ -32,7 
+33,7 @@ def test_query_table(self, tbl_table, input): # Querying a table relation rel = rel.query("x", "select * from x") result = rel.execute() - assert(result.fetchall() == [tuple([x]) for x in input]) + assert result.fetchall() == [tuple([x]) for x in input] def test_query_table_unrelated(self, tbl_table): con = duckdb.default_connection @@ -40,7 +41,7 @@ def test_query_table_unrelated(self, tbl_table): # Querying a table relation rel = rel.query("x", "select 5") result = rel.execute() - assert(result.fetchall() == [(5,)]) + assert result.fetchall() == [(5,)] def test_query_table_qualified(self): con = duckdb.default_connection @@ -48,7 +49,7 @@ def test_query_table_qualified(self): # Create table in fff schema con.execute("create table fff.t2 as select 1 as t") - assert(con.table("fff.t2").fetchall() == [(1,)]) + assert con.table("fff.t2").fetchall() == [(1,)] def test_query_insert_into_relation(self, tbl_table): con = duckdb.default_connection @@ -59,11 +60,11 @@ def test_query_insert_into_relation(self, tbl_table): def test_query_non_select(self): con = duckdb.connect() - rel = con.query("select [1,2,3,4]"); + rel = con.query("select [1,2,3,4]") rel.query("relation", "create table tbl as select * from relation") result = con.execute("select * from tbl").fetchall() - assert result == [([1,2,3,4],)] + assert result == [([1, 2, 3, 4],)] def test_query_non_select_fail(self): con = duckdb.connect() @@ -77,14 +78,13 @@ def test_query_non_select_fail(self): with pytest.raises(duckdb.CatalogException): rel.query("relation", "create table tbl as select * from not_a_valid_view") - def test_query_table_unrelated(self, tbl_table): con = duckdb.default_connection rel = con.table("tbl") # Querying a table relation rel = rel.query("x", "select 5") result = rel.execute() - assert(result.fetchall() == [(5,)]) + assert result.fetchall() == [(5,)] def test_query_non_select_result(self): with pytest.raises(duckdb.ParserException, match="syntax error"): @@ -121,6 +121,7 @@ def test_replacement_scan_recursion(self): con = duckdb.connect() depth_limit = 1000 import sys + if sys.platform.startswith('win'): # With the default we reach a stack overflow in the CI depth_limit = 250 diff --git a/tools/pythonpkg/tests/fast/sqlite/test_types.py b/tools/pythonpkg/tests/fast/sqlite/test_types.py index ef8e806933de..58b404857c15 100644 --- a/tools/pythonpkg/tests/fast/sqlite/test_types.py +++ b/tools/pythonpkg/tests/fast/sqlite/test_types.py @@ -1,4 +1,4 @@ -#-*- coding: iso-8859-1 -*- +# -*- coding: iso-8859-1 -*- # pysqlite2/test/types.py: tests for type conversion and detection # # Copyright (C) 2005 Gerhard Häring @@ -30,6 +30,7 @@ import duckdb import pytest + class DuckDBTypeTests(unittest.TestCase): def setUp(self): self.con = duckdb.connect(":memory:") @@ -91,6 +92,7 @@ def test_CheckDecimalWithExponent(self): def test_CheckNaN(self): import math + val = decimal.Decimal('nan') self.cur.execute("insert into test(f) values (?)", (val,)) self.cur.execute("select f from test") @@ -140,9 +142,7 @@ def test_CheckUnicodeExecute(self): self.assertEqual(row[0], u"Österreich") - class CommonTableExpressionTests(unittest.TestCase): - def setUp(self): self.con = duckdb.connect(":memory:") self.cur = self.con.cursor() @@ -173,6 +173,7 @@ def test_CheckCursorDescriptionCTE(self): self.assertIsNotNone(self.cur.description) self.assertEqual(self.cur.description[0][0], "x") + class DateTimeTests(unittest.TestCase): def setUp(self): self.con = duckdb.connect(":memory:") @@ -239,9 +240,7 @@ class ListTests(unittest.TestCase): 
def setUp(self): self.con = duckdb.connect(":memory:") self.cur = self.con.cursor() - self.cur.execute( - "create table test(single INTEGER[], nested INTEGER[][])" - ) + self.cur.execute("create table test(single INTEGER[], nested INTEGER[][])") def tearDown(self): self.cur.close() @@ -268,7 +267,12 @@ def test_CheckNestedList(self): self.cur.execute("insert into test(nested) values (?)", (val,)) self.assertEqual( self.cur.execute("select * from test").fetchall(), - [(None, val,)], + [ + ( + None, + val, + ) + ], ) def test_CheckNone(self): diff --git a/tools/pythonpkg/tests/fast/test_alex_multithread.py b/tools/pythonpkg/tests/fast/test_alex_multithread.py index 83f2ba1328a5..765b8f0ab2bf 100644 --- a/tools/pythonpkg/tests/fast/test_alex_multithread.py +++ b/tools/pythonpkg/tests/fast/test_alex_multithread.py @@ -4,36 +4,38 @@ import os import pytest + @pytest.fixture(scope="session") def tmp_database(tmp_path_factory): database = tmp_path_factory.mktemp("databases", numbered=True) / "tmp.duckdb" return str(database) + def insert_from_cursor(duckdb_con): # Insert a row with the name of the thread - duckdb_cursor = duckdb_con.cursor() # Make a cursor within the thread + duckdb_cursor = duckdb_con.cursor() # Make a cursor within the thread thread_name = str(current_thread().name) duckdb_cursor.execute("""INSERT INTO my_inserts VALUES (?)""", (thread_name,)) + def insert_from_same_connection(duckdb_cursor): # Insert a row with the name of the thread thread_name = str(current_thread().name) duckdb_cursor.execute("""INSERT INTO my_inserts VALUES (?)""", (thread_name,)) + class TestPythonMultithreading(object): def test_multiple_cursors(self, duckdb_cursor): - duckdb_con = duckdb.connect() # In Memory DuckDB + duckdb_con = duckdb.connect() # In Memory DuckDB duckdb_con.execute("""CREATE OR REPLACE TABLE my_inserts (thread_name varchar)""") thread_count = 3 threads = [] - # Kick off multiple threads (in the same process) + # Kick off multiple threads (in the same process) # Pass in the same connection as an argument, and an object to store the results for i in range(thread_count): - threads.append(Thread(target=insert_from_cursor, - args=(duckdb_con,), - name='my_thread_'+str(i))) + threads.append(Thread(target=insert_from_cursor, args=(duckdb_con,), name='my_thread_' + str(i))) for thread in threads: thread.start() @@ -41,23 +43,25 @@ def test_multiple_cursors(self, duckdb_cursor): for thread in threads: thread.join() - assert duckdb_con.execute("""SELECT * FROM my_inserts order by thread_name""").fetchall() == [('my_thread_0',), ('my_thread_1',), ('my_thread_2',)] + assert duckdb_con.execute("""SELECT * FROM my_inserts order by thread_name""").fetchall() == [ + ('my_thread_0',), + ('my_thread_1',), + ('my_thread_2',), + ] def test_same_connection(self, duckdb_cursor): - duckdb_con = duckdb.connect() # In Memory DuckDB + duckdb_con = duckdb.connect() # In Memory DuckDB duckdb_con.execute("""CREATE OR REPLACE TABLE my_inserts (thread_name varchar)""") thread_count = 3 threads = [] cursors = [] - # Kick off multiple threads (in the same process) + # Kick off multiple threads (in the same process) # Pass in the same connection as an argument, and an object to store the results for i in range(thread_count): cursors.append(duckdb_con.cursor()) - threads.append(Thread(target=insert_from_same_connection, - args=(cursors[i],), - name='my_thread_'+str(i))) + threads.append(Thread(target=insert_from_same_connection, args=(cursors[i],), name='my_thread_' + str(i))) for thread in threads: thread.start() @@ 
-65,7 +69,11 @@ def test_same_connection(self, duckdb_cursor): for thread in threads: thread.join() - assert duckdb_con.execute("""SELECT * FROM my_inserts order by thread_name""").fetchall() == [('my_thread_0',), ('my_thread_1',), ('my_thread_2',)] + assert duckdb_con.execute("""SELECT * FROM my_inserts order by thread_name""").fetchall() == [ + ('my_thread_0',), + ('my_thread_1',), + ('my_thread_2',), + ] def test_multiple_cursors_persisted(self, tmp_database): duckdb_con = duckdb.connect(tmp_database) @@ -74,19 +82,21 @@ def test_multiple_cursors_persisted(self, tmp_database): thread_count = 3 threads = [] - # Kick off multiple threads (in the same process) + # Kick off multiple threads (in the same process) # Pass in the same connection as an argument, and an object to store the results for i in range(thread_count): - threads.append(Thread(target=insert_from_cursor, - args=(duckdb_con,), - name='my_thread_'+str(i))) + threads.append(Thread(target=insert_from_cursor, args=(duckdb_con,), name='my_thread_' + str(i))) for thread in threads: thread.start() for thread in threads: thread.join() - assert duckdb_con.execute("""SELECT * FROM my_inserts order by thread_name""").fetchall() == [('my_thread_0',), ('my_thread_1',), ('my_thread_2',)] + assert duckdb_con.execute("""SELECT * FROM my_inserts order by thread_name""").fetchall() == [ + ('my_thread_0',), + ('my_thread_1',), + ('my_thread_2',), + ] duckdb_con.close() def test_same_connection_persisted(self, tmp_database): @@ -96,17 +106,19 @@ def test_same_connection_persisted(self, tmp_database): thread_count = 3 threads = [] - # Kick off multiple threads (in the same process) + # Kick off multiple threads (in the same process) # Pass in the same connection as an argument, and an object to store the results for i in range(thread_count): - threads.append(Thread(target=insert_from_same_connection, - args=(duckdb_con,), - name='my_thread_'+str(i))) + threads.append(Thread(target=insert_from_same_connection, args=(duckdb_con,), name='my_thread_' + str(i))) for thread in threads: thread.start() for thread in threads: thread.join() - assert duckdb_con.execute("""SELECT * FROM my_inserts order by thread_name""").fetchall() == [('my_thread_0',), ('my_thread_1',), ('my_thread_2',)] + assert duckdb_con.execute("""SELECT * FROM my_inserts order by thread_name""").fetchall() == [ + ('my_thread_0',), + ('my_thread_1',), + ('my_thread_2',), + ] duckdb_con.close() diff --git a/tools/pythonpkg/tests/fast/test_all_types.py b/tools/pythonpkg/tests/fast/test_all_types.py index 3aa9358e6d7f..1e79b3c9fc77 100644 --- a/tools/pythonpkg/tests/fast/test_all_types.py +++ b/tools/pythonpkg/tests/fast/test_all_types.py @@ -6,6 +6,7 @@ from decimal import Decimal from uuid import UUID + def get_all_types(): conn = duckdb.connect() all_types = conn.execute("describe select * from test_all_types()").fetchall() @@ -14,8 +15,10 @@ def get_all_types(): types.append(cur_type[0]) return types + all_types = get_all_types() + # we need to write our own equality function that considers nan==nan for testing purposes def recursive_equality(o1, o2): if o1 == o2: @@ -36,6 +39,7 @@ def recursive_equality(o1, o2): except: return False + class TestAllTypes(object): def test_fetchall(self, duckdb_cursor): conn = duckdb.connect() @@ -52,35 +56,105 @@ def test_fetchall(self, duckdb_cursor): 'timestamptz_array': "[], ['1970-01-01 00:00:00Z'::TIMESTAMPTZ, NULL, '0001-01-01 00:00:00Z'::TIMESTAMPTZ, '9999-12-31 23:59:59.999999Z'::TIMESTAMPTZ,], [NULL::TIMESTAMPTZ,]", } - correct_answer_map = 
{'bool':[(False,), (True,), (None,)] - , 'tinyint':[(-128,), (127,), (None,)], 'smallint': [(-32768,), (32767,), (None,)] - , 'int':[(-2147483648,), (2147483647,), (None,)],'bigint':[(-9223372036854775808,), (9223372036854775807,), (None,)] - , 'hugeint':[(-170141183460469231731687303715884105727,), (170141183460469231731687303715884105727,), (None,)] - , 'utinyint': [(0,), (255,), (None,)], 'usmallint': [(0,), (65535,), (None,)] - , 'uint':[(0,), (4294967295,), (None,)], 'ubigint': [(0,), (18446744073709551615,), (None,)] - , 'time':[(datetime.time(0, 0),), (datetime.time(23, 59, 59, 999999),), (None,)] - , 'float': [(-3.4028234663852886e+38,), (3.4028234663852886e+38,), (None,)], 'double': [(-1.7976931348623157e+308,), (1.7976931348623157e+308,), (None,)] - , 'dec_4_1': [(Decimal('-999.9'),), (Decimal('999.9'),), (None,)], 'dec_9_4': [(Decimal('-99999.9999'),), (Decimal('99999.9999'),), (None,)] - , 'dec_18_6': [(Decimal('-999999999999.999999'),), (Decimal('999999999999.999999'),), (None,)], 'dec38_10':[(Decimal('-9999999999999999999999999999.9999999999'),), (Decimal('9999999999999999999999999999.9999999999'),), (None,)] - , 'uuid': [(UUID('00000000-0000-0000-0000-000000000001'),), (UUID('ffffffff-ffff-ffff-ffff-ffffffffffff'),), (None,)] - , 'varchar': [('🦆🦆🦆🦆🦆🦆',), ('goo\0se',), (None,)], 'json': [('🦆🦆🦆🦆🦆🦆',), ('goose',), (None,)], 'blob': [(b'thisisalongblob\x00withnullbytes',), (b'\x00\x00\x00a',), (None,)], 'bit': [('0010001001011100010101011010111',), ('10101',), (None,)] - , 'small_enum':[('DUCK_DUCK_ENUM',), ('GOOSE',), (None,)], 'medium_enum': [('enum_0',), ('enum_299',), (None,)], 'large_enum': [('enum_0',), ('enum_69999',), (None,)] - , 'date_array': [([], [datetime.date(1970, 1, 1), None, datetime.date.min, datetime.date.max], [None,],)] - , 'timestamp_array': [([], [datetime.datetime(1970, 1, 1), None, datetime.datetime.min, datetime.datetime.max], [None,],),] - , 'timestamptz_array': [([], [datetime.datetime(1970, 1, 1), None, datetime.datetime.min, datetime.datetime.max], [None,],),] - , 'int_array': [([],), ([42, 999, None, None, -42],), (None,)], 'varchar_array': [([],), (['🦆🦆🦆🦆🦆🦆', 'goose', None, ''],), (None,)] - , 'double_array': [([],), ([42.0, float('nan'), float('inf'), float('-inf'), None, -42.0],), (None,)] - , 'nested_int_array': [([],), ([[], [42, 999, None, None, -42], None, [], [42, 999, None, None, -42]],), (None,)], 'struct': [({'a': None, 'b': None},), ({'a': 42, 'b': '🦆🦆🦆🦆🦆🦆'},), (None,)] - , 'struct_of_arrays': [({'a': None, 'b': None},), ({'a': [42, 999, None, None, -42], 'b': ['🦆🦆🦆🦆🦆🦆', 'goose', None, '']},), (None,)] - , 'array_of_structs': [([],), ([{'a': None, 'b': None}, {'a': 42, 'b': '🦆🦆🦆🦆🦆🦆'}, None],), (None,)], 'map':[({'key': [], 'value': []},), ({'key': ['key1', 'key2'], 'value': ['🦆🦆🦆🦆🦆🦆', 'goose']},), (None,)] - , 'time_tz':[(datetime.time(0, 0),), (datetime.time(23, 59, 59, 999999),), (None,)], 'interval': [(datetime.timedelta(0),), (datetime.timedelta(days=30969, seconds=999, microseconds=999999),), (None,)] - , 'timestamp':[(datetime.datetime(1990, 1, 1, 0, 0),)], 'date':[(datetime.date(1990, 1, 1),)], 'timestamp_s':[(datetime.datetime(1990, 1, 1, 0, 0),)] - , 'timestamp_ns':[(datetime.datetime(1990, 1, 1, 0, 0),)], 'timestamp_ms':[(datetime.datetime(1990, 1, 1, 0, 0),)], 'timestamp_tz':[(datetime.datetime(1990, 1, 1, 0, 0),)] - , 'union':[('Frank',),(5,),(None,)],} + correct_answer_map = { + 'bool': [(False,), (True,), (None,)], + 'tinyint': [(-128,), (127,), (None,)], + 'smallint': [(-32768,), (32767,), (None,)], + 'int': 
[(-2147483648,), (2147483647,), (None,)], + 'bigint': [(-9223372036854775808,), (9223372036854775807,), (None,)], + 'hugeint': [ + (-170141183460469231731687303715884105727,), + (170141183460469231731687303715884105727,), + (None,), + ], + 'utinyint': [(0,), (255,), (None,)], + 'usmallint': [(0,), (65535,), (None,)], + 'uint': [(0,), (4294967295,), (None,)], + 'ubigint': [(0,), (18446744073709551615,), (None,)], + 'time': [(datetime.time(0, 0),), (datetime.time(23, 59, 59, 999999),), (None,)], + 'float': [(-3.4028234663852886e38,), (3.4028234663852886e38,), (None,)], + 'double': [(-1.7976931348623157e308,), (1.7976931348623157e308,), (None,)], + 'dec_4_1': [(Decimal('-999.9'),), (Decimal('999.9'),), (None,)], + 'dec_9_4': [(Decimal('-99999.9999'),), (Decimal('99999.9999'),), (None,)], + 'dec_18_6': [(Decimal('-999999999999.999999'),), (Decimal('999999999999.999999'),), (None,)], + 'dec38_10': [ + (Decimal('-9999999999999999999999999999.9999999999'),), + (Decimal('9999999999999999999999999999.9999999999'),), + (None,), + ], + 'uuid': [ + (UUID('00000000-0000-0000-0000-000000000001'),), + (UUID('ffffffff-ffff-ffff-ffff-ffffffffffff'),), + (None,), + ], + 'varchar': [('🦆🦆🦆🦆🦆🦆',), ('goo\0se',), (None,)], + 'json': [('🦆🦆🦆🦆🦆🦆',), ('goose',), (None,)], + 'blob': [(b'thisisalongblob\x00withnullbytes',), (b'\x00\x00\x00a',), (None,)], + 'bit': [('0010001001011100010101011010111',), ('10101',), (None,)], + 'small_enum': [('DUCK_DUCK_ENUM',), ('GOOSE',), (None,)], + 'medium_enum': [('enum_0',), ('enum_299',), (None,)], + 'large_enum': [('enum_0',), ('enum_69999',), (None,)], + 'date_array': [ + ( + [], + [datetime.date(1970, 1, 1), None, datetime.date.min, datetime.date.max], + [ + None, + ], + ) + ], + 'timestamp_array': [ + ( + [], + [datetime.datetime(1970, 1, 1), None, datetime.datetime.min, datetime.datetime.max], + [ + None, + ], + ), + ], + 'timestamptz_array': [ + ( + [], + [datetime.datetime(1970, 1, 1), None, datetime.datetime.min, datetime.datetime.max], + [ + None, + ], + ), + ], + 'int_array': [([],), ([42, 999, None, None, -42],), (None,)], + 'varchar_array': [([],), (['🦆🦆🦆🦆🦆🦆', 'goose', None, ''],), (None,)], + 'double_array': [([],), ([42.0, float('nan'), float('inf'), float('-inf'), None, -42.0],), (None,)], + 'nested_int_array': [ + ([],), + ([[], [42, 999, None, None, -42], None, [], [42, 999, None, None, -42]],), + (None,), + ], + 'struct': [({'a': None, 'b': None},), ({'a': 42, 'b': '🦆🦆🦆🦆🦆🦆'},), (None,)], + 'struct_of_arrays': [ + ({'a': None, 'b': None},), + ({'a': [42, 999, None, None, -42], 'b': ['🦆🦆🦆🦆🦆🦆', 'goose', None, '']},), + (None,), + ], + 'array_of_structs': [([],), ([{'a': None, 'b': None}, {'a': 42, 'b': '🦆🦆🦆🦆🦆🦆'}, None],), (None,)], + 'map': [({'key': [], 'value': []},), ({'key': ['key1', 'key2'], 'value': ['🦆🦆🦆🦆🦆🦆', 'goose']},), (None,)], + 'time_tz': [(datetime.time(0, 0),), (datetime.time(23, 59, 59, 999999),), (None,)], + 'interval': [ + (datetime.timedelta(0),), + (datetime.timedelta(days=30969, seconds=999, microseconds=999999),), + (None,), + ], + 'timestamp': [(datetime.datetime(1990, 1, 1, 0, 0),)], + 'date': [(datetime.date(1990, 1, 1),)], + 'timestamp_s': [(datetime.datetime(1990, 1, 1, 0, 0),)], + 'timestamp_ns': [(datetime.datetime(1990, 1, 1, 0, 0),)], + 'timestamp_ms': [(datetime.datetime(1990, 1, 1, 0, 0),)], + 'timestamp_tz': [(datetime.datetime(1990, 1, 1, 0, 0),)], + 'union': [('Frank',), (5,), (None,)], + } for cur_type in all_types: if cur_type in replacement_values: - result = conn.execute("select "+replacement_values[cur_type]).fetchall() + 
result = conn.execute("select " + replacement_values[cur_type]).fetchall() print(cur_type, result) else: result = conn.execute(f'select "{cur_type}" from test_all_types()').fetchall() @@ -98,7 +172,6 @@ def test_bytearray_with_nulls(self): # Don't truncate the array on the nullbyte assert want == bytearray(got) - def test_fetchnumpy(self, duckdb_cursor): conn = duckdb.connect() @@ -148,12 +221,12 @@ def test_fetchnumpy(self, duckdb_cursor): dtype=np.uint64, ), 'float': np.ma.array( - [-3.4028234663852886e+38, 3.4028234663852886e+38, 42.0], + [-3.4028234663852886e38, 3.4028234663852886e38, 42.0], mask=[0, 0, 1], dtype=np.float32, ), 'double': np.ma.array( - [-1.7976931348623157e+308, 1.7976931348623157e+308, 42.0], + [-1.7976931348623157e308, 1.7976931348623157e308, 42.0], mask=[0, 0, 1], dtype=np.float64, ), @@ -296,11 +369,7 @@ def test_fetchnumpy(self, duckdb_cursor): mask=[0, 0, 1], dtype=object, ), - 'union': np.ma.array( - ['Frank', 5, None], - mask=[0, 0, 1], - dtype=object - ), + 'union': np.ma.array(['Frank', 5, None], mask=[0, 0, 1], dtype=object), } # The following types don't have a numpy equivalent, and are coerced to @@ -337,8 +406,7 @@ def test_fetchnumpy(self, duckdb_cursor): else: # assert_equal compares NaN equal, but also compares masked # elements equal to any unmasked element - if (isinstance(result, np.ma.MaskedArray) - or isinstance(correct_answer, np.ma.MaskedArray)): + if isinstance(result, np.ma.MaskedArray) or isinstance(correct_answer, np.ma.MaskedArray): assert np.all(result.mask == correct_answer.mask) np.testing.assert_equal(result, correct_answer) @@ -354,7 +422,7 @@ def test_arrow(self, duckdb_cursor): conn = duckdb.connect() for cur_type in all_types: if cur_type in replacement_values: - arrow_table = conn.execute("select "+replacement_values[cur_type]).arrow() + arrow_table = conn.execute("select " + replacement_values[cur_type]).arrow() else: arrow_table = conn.execute(f'select "{cur_type}" from test_all_types()').arrow() if cur_type in enum_types: @@ -368,7 +436,8 @@ def test_arrow(self, duckdb_cursor): def test_pandas(self): # We skip those since the extreme ranges are not supported in python. 
- replacement_values = { 'timestamp': "'1990-01-01 00:00:00'::TIMESTAMP", + replacement_values = { + 'timestamp': "'1990-01-01 00:00:00'::TIMESTAMP", 'timestamp_s': "'1990-01-01 00:00:00'::TIMESTAMP_S", 'timestamp_ns': "'1990-01-01 00:00:00'::TIMESTAMP_NS", 'timestamp_ms': "'1990-01-01 00:00:00'::TIMESTAMP_MS", @@ -377,12 +446,12 @@ def test_pandas(self): 'date_array': "[], ['1970-01-01'::DATE, NULL, '0001-01-01'::DATE, '9999-12-31'::DATE,], [NULL::DATE,]", 'timestamp_array': "[], ['1970-01-01'::TIMESTAMP, NULL, '0001-01-01'::TIMESTAMP, '9999-12-31 23:59:59.999999'::TIMESTAMP,], [NULL::TIMESTAMP,]", 'timestamptz_array': "[], ['1970-01-01 00:00:00Z'::TIMESTAMPTZ, NULL, '0001-01-01 00:00:00Z'::TIMESTAMPTZ, '9999-12-31 23:59:59.999999Z'::TIMESTAMPTZ,], [NULL::TIMESTAMPTZ,]", - } + } conn = duckdb.connect() for cur_type in all_types: if cur_type in replacement_values: - dataframe = conn.execute("select "+replacement_values[cur_type]).df() + dataframe = conn.execute("select " + replacement_values[cur_type]).df() else: dataframe = conn.execute(f'select "{cur_type}" from test_all_types()').df() print(cur_type) diff --git a/tools/pythonpkg/tests/fast/test_ambiguous_prepare.py b/tools/pythonpkg/tests/fast/test_ambiguous_prepare.py index b45c8be1cd0f..998367ec186d 100644 --- a/tools/pythonpkg/tests/fast/test_ambiguous_prepare.py +++ b/tools/pythonpkg/tests/fast/test_ambiguous_prepare.py @@ -2,6 +2,7 @@ import pandas as pd import pytest + class TestAmbiguousPrepare(object): def test_bool(self, duckdb_cursor): conn = duckdb.connect() @@ -9,4 +10,3 @@ def test_bool(self, duckdb_cursor): assert res[0][0] == True assert res[0][1] == 42 assert res[0][2] == [1, 2, 3] - diff --git a/tools/pythonpkg/tests/fast/test_case_alias.py b/tools/pythonpkg/tests/fast/test_case_alias.py index 461813c29cd1..4fcbd49ca85c 100644 --- a/tools/pythonpkg/tests/fast/test_case_alias.py +++ b/tools/pythonpkg/tests/fast/test_case_alias.py @@ -5,6 +5,7 @@ import pytest from conftest import NumpyPandas, ArrowPandas + class TestCaseAlias(object): @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_case_alias(self, duckdb_cursor, pandas): @@ -14,7 +15,7 @@ def test_case_alias(self, duckdb_cursor, pandas): con = duckdb.connect(':memory:') - df = pandas.DataFrame([{"COL1": "val1", "CoL2": 1.05},{"COL1": "val3", "CoL2": 17}]) + df = pandas.DataFrame([{"COL1": "val1", "CoL2": 1.05}, {"COL1": "val3", "CoL2": 17}]) r1 = con.from_df(df).query('df', 'select * from df').df() assert r1["COL1"][0] == "val1" diff --git a/tools/pythonpkg/tests/fast/test_context_manager.py b/tools/pythonpkg/tests/fast/test_context_manager.py index a42d0df453c7..2ac451d1a4ae 100644 --- a/tools/pythonpkg/tests/fast/test_context_manager.py +++ b/tools/pythonpkg/tests/fast/test_context_manager.py @@ -1,6 +1,7 @@ import duckdb + class TestContextManager(object): def test_context_manager(self, duckdb_cursor): with duckdb.connect(database=':memory:', read_only=False) as con: - assert con.execute("select 1").fetchall() == [(1,)] \ No newline at end of file + assert con.execute("select 1").fetchall() == [(1,)] diff --git a/tools/pythonpkg/tests/fast/test_filesystem.py b/tools/pythonpkg/tests/fast/test_filesystem.py index 64d2b13d6b77..94aa9afc5616 100644 --- a/tools/pythonpkg/tests/fast/test_filesystem.py +++ b/tools/pythonpkg/tests/fast/test_filesystem.py @@ -163,10 +163,12 @@ def test_database_attach(self, tmp_path: Path, monkeypatch: MonkeyPatch): # setup a database to attach later with duckdb.connect(db_path) as conn: - conn.execute(''' + 
conn.execute( + ''' CREATE TABLE t (id int); INSERT INTO t VALUES (0) - ''') + ''' + ) assert exists(db_path) @@ -181,7 +183,7 @@ def test_database_attach(self, tmp_path: Path, monkeypatch: MonkeyPatch): conn.execute('FROM hello.t') assert conn.fetchall() == [(0,), (1,)] - # duckdb sometimes seems to swallow write errors, so we use this to ensure that + # duckdb sometimes seems to swallow write errors, so we use this to ensure that # isn't happening assert not write_errors @@ -190,13 +192,12 @@ def test_copy_partition(self, duckdb_cursor: DuckDBPyConnection, memory: Abstrac duckdb_cursor.execute("copy (select 1 as a) to 'memory://root' (partition_by (a))") - assert memory.open( - '/root\\a=1\\data_0.csv' - if sys.platform == 'win32' else - '/root/a=1/data_0.csv' - ).read() == b'1\n' + assert ( + memory.open('/root\\a=1\\data_0.csv' if sys.platform == 'win32' else '/root/a=1/data_0.csv').read() + == b'1\n' + ) - def test_read_hive_partition(self, duckdb_cursor: DuckDBPyConnection, memory: AbstractFileSystem): + def test_read_hive_partition(self, duckdb_cursor: DuckDBPyConnection, memory: AbstractFileSystem): duckdb_cursor.register_filesystem(memory) duckdb_cursor.execute("copy (select 2 as a) to 'memory://partition' (partition_by (a))") diff --git a/tools/pythonpkg/tests/fast/test_get_table_names.py b/tools/pythonpkg/tests/fast/test_get_table_names.py index 34355f689017..968a13c49a9f 100644 --- a/tools/pythonpkg/tests/fast/test_get_table_names.py +++ b/tools/pythonpkg/tests/fast/test_get_table_names.py @@ -1,13 +1,13 @@ import duckdb import pytest -class TestGetTableNames(object): +class TestGetTableNames(object): def test_table_success(self, duckdb_cursor): conn = duckdb.connect() table_names = conn.get_table_names("SELECT * FROM my_table1, my_table2, my_table3") assert table_names == {'my_table2', 'my_table3', 'my_table1'} - + def test_table_fail(self, duckdb_cursor): conn = duckdb.connect() conn.close() diff --git a/tools/pythonpkg/tests/fast/test_import_export.py b/tools/pythonpkg/tests/fast/test_import_export.py index 6fc3889e73a8..2fce1636666e 100644 --- a/tools/pythonpkg/tests/fast/test_import_export.py +++ b/tools/pythonpkg/tests/fast/test_import_export.py @@ -5,25 +5,30 @@ import os from pathlib import Path + def export_database(export_location): # Create the db con = duckdb.connect() - con.execute("create table tbl (a integer, b integer);"); - con.execute("insert into tbl values (5,1);"); + con.execute("create table tbl (a integer, b integer);") + con.execute("insert into tbl values (5,1);") # Export the db - con.execute(f"export database '{export_location}';"); + con.execute(f"export database '{export_location}';") print(f"Exported database to {export_location}") + def import_database(import_location): con = duckdb.connect() con.execute(f"import database '{import_location}'") - print(f"Imported database from {import_location}"); + print(f"Imported database from {import_location}") res = con.query("select * from tbl").fetchall() - assert res == [(5,1),] + assert res == [ + (5, 1), + ] print("Successfully queried an imported database that was moved from its original export location!") + def move_database(export_location, import_location): assert path.exists(export_location) assert path.exists(import_location) @@ -31,28 +36,27 @@ def move_database(export_location, import_location): for file in ['schema.sql', 'load.sql', 'tbl.csv']: shutil.move(path.join(export_location, file), import_location) + def export_move_and_import(export_path, import_path): export_database(export_path) 
move_database(export_path, import_path) import_database(import_path) + def export_and_import_empty_db(db_path, _): con = duckdb.connect() # Export the db - con.execute(f"export database '{db_path}';"); + con.execute(f"export database '{db_path}';") print(f"Exported database to {db_path}") - con.close(); + con.close() con = duckdb.connect() con.execute(f"import database '{db_path}'") -class TestDuckDBImportExport(): - @pytest.mark.parametrize('routine', [ - export_move_and_import, - export_and_import_empty_db - ]) +class TestDuckDBImportExport: + @pytest.mark.parametrize('routine', [export_move_and_import, export_and_import_empty_db]) def test_import_and_export(self, routine, tmp_path_factory): export_path = str(tmp_path_factory.mktemp("export_dbs", numbered=True)) import_path = str(tmp_path_factory.mktemp("import_dbs", numbered=True)) diff --git a/tools/pythonpkg/tests/fast/test_import_without_pyarrow_dataset.py b/tools/pythonpkg/tests/fast/test_import_without_pyarrow_dataset.py index dbd55c742472..57188b61e9a8 100644 --- a/tools/pythonpkg/tests/fast/test_import_without_pyarrow_dataset.py +++ b/tools/pythonpkg/tests/fast/test_import_without_pyarrow_dataset.py @@ -3,14 +3,16 @@ pyarrow = pytest.importorskip("pyarrow") + class TestImportWithoutPyArrowDataset: - def test_import(self, monkeypatch: pytest.MonkeyPatch): - monkeypatch.setitem(sys.modules, "pyarrow.dataset", None) - import duckdb - # We should be able to import duckdb even when pyarrow.dataset is missing - con = duckdb.connect() - rel = con.query('select 1') - arrow_record_batch = rel.record_batch() - with pytest.raises(duckdb.InvalidInputException): - # The replacement scan functionality relies on pyarrow.dataset - con.query('select * from arrow_record_batch') + def test_import(self, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setitem(sys.modules, "pyarrow.dataset", None) + import duckdb + + # We should be able to import duckdb even when pyarrow.dataset is missing + con = duckdb.connect() + rel = con.query('select 1') + arrow_record_batch = rel.record_batch() + with pytest.raises(duckdb.InvalidInputException): + # The replacement scan functionality relies on pyarrow.dataset + con.query('select * from arrow_record_batch') diff --git a/tools/pythonpkg/tests/fast/test_insert.py b/tools/pythonpkg/tests/fast/test_insert.py index e8fa8be4d0bf..fb2833d3492d 100644 --- a/tools/pythonpkg/tests/fast/test_insert.py +++ b/tools/pythonpkg/tests/fast/test_insert.py @@ -4,19 +4,19 @@ import pytest from conftest import NumpyPandas, ArrowPandas + class TestInsert(object): - - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) - def test_insert(self, pandas): - test_df = pandas.DataFrame({"i":[1, 2, 3], "j":["one", "two", "three"]}) - # connect to an in-memory temporary database - conn = duckdb.connect() - # get a cursor - cursor = conn.cursor() - conn.execute("CREATE TABLE test (i INTEGER, j STRING)") - rel = conn.table("test") - rel.insert([1,'one']) - rel.insert([2,'two']) - rel.insert([3,'three']) - rel_a3 = cursor.table('test').project('CAST(i as BIGINT)i, j').to_df() - pandas.testing.assert_frame_equal(rel_a3, test_df) + @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + def test_insert(self, pandas): + test_df = pandas.DataFrame({"i": [1, 2, 3], "j": ["one", "two", "three"]}) + # connect to an in-memory temporary database + conn = duckdb.connect() + # get a cursor + cursor = conn.cursor() + conn.execute("CREATE TABLE test (i INTEGER, j STRING)") + rel = conn.table("test") + rel.insert([1, 'one']) + 
rel.insert([2, 'two']) + rel.insert([3, 'three']) + rel_a3 = cursor.table('test').project('CAST(i as BIGINT)i, j').to_df() + pandas.testing.assert_frame_equal(rel_a3, test_df) diff --git a/tools/pythonpkg/tests/fast/test_many_con_same_file.py b/tools/pythonpkg/tests/fast/test_many_con_same_file.py index 26bee31d0c3a..58b57aa48917 100644 --- a/tools/pythonpkg/tests/fast/test_many_con_same_file.py +++ b/tools/pythonpkg/tests/fast/test_many_con_same_file.py @@ -2,6 +2,7 @@ import os import pytest + def get_tables(con): tbls = con.execute("SHOW TABLES").fetchall() tbls = [x[0] for x in tbls] @@ -32,6 +33,7 @@ def test_multiple_writes(): except: pass + def test_multiple_writes_memory(): con1 = duckdb.connect() con2 = duckdb.connect() @@ -48,6 +50,7 @@ def test_multiple_writes_memory(): del con2 del con3 + def test_multiple_writes_named_memory(): con1 = duckdb.connect(":memory:1") con2 = duckdb.connect(":memory:1") @@ -60,10 +63,13 @@ def test_multiple_writes_named_memory(): del con2 del con3 + def test_diff_config(): - con1 = duckdb.connect("test.db",False) - with pytest.raises(duckdb.ConnectionException, match="Can't open a connection to same database file with a different configuration than existing connections"): - con2 = duckdb.connect("test.db",True) + con1 = duckdb.connect("test.db", False) + with pytest.raises( + duckdb.ConnectionException, + match="Can't open a connection to same database file with a different configuration than existing connections", + ): + con2 = duckdb.connect("test.db", True) con1.close() del con1 - diff --git a/tools/pythonpkg/tests/fast/test_map.py b/tools/pythonpkg/tests/fast/test_map.py index edeece2a61af..ee8a5bf787b8 100644 --- a/tools/pythonpkg/tests/fast/test_map.py +++ b/tools/pythonpkg/tests/fast/test_map.py @@ -5,6 +5,7 @@ import re from conftest import NumpyPandas, ArrowPandas + class TestMap(object): @pytest.mark.parametrize('pandas', [NumpyPandas()]) def test_map(self, duckdb_cursor, pandas): @@ -13,9 +14,9 @@ def test_map(self, duckdb_cursor, pandas): conn.execute('CREATE TABLE t (a integer)') empty_rel = conn.table('t') - newdf1 = testrel.map(lambda df : df['col0'].add(42).to_frame()) - newdf2 = testrel.map(lambda df : df['col0'].astype('string').to_frame()) - newdf3 = testrel.map(lambda df : df) + newdf1 = testrel.map(lambda df: df['col0'].add(42).to_frame()) + newdf2 = testrel.map(lambda df: df['col0'].astype('string').to_frame()) + newdf3 = testrel.map(lambda df: df) # column count differs from bind def evil1(df): @@ -33,7 +34,7 @@ def evil2(df): # column name differs from bind def evil3(df): if len(df) == 0: - df = df.rename(columns={"col0" : "col42"}) + df = df.rename(columns={"col0": "col42"}) return df # does not return a df @@ -45,10 +46,10 @@ def evil5(df): this_makes_no_sense() def return_dataframe(df): - return pandas.DataFrame({'A' : [1]}) + return pandas.DataFrame({'A': [1]}) def return_big_dataframe(df): - return pandas.DataFrame({'A' : [1]*5000}) + return pandas.DataFrame({'A': [1] * 5000}) def return_none(df): return None @@ -65,7 +66,9 @@ def return_empty_df(df): with pytest.raises(duckdb.InvalidInputException, match='UDF column name mismatch'): print(testrel.map(evil3).df()) - with pytest.raises(duckdb.InvalidInputException, match="Expected the UDF to return an object of type 'pandas.DataFrame'"): + with pytest.raises( + duckdb.InvalidInputException, match="Expected the UDF to return an object of type 'pandas.DataFrame'" + ): print(testrel.map(evil4).df()) with pytest.raises(duckdb.InvalidInputException): @@ -79,12 +82,14 @@ def 
return_empty_df(df): with pytest.raises(TypeError): print(testrel.map().df()) - testrel.map(return_dataframe).df().equals(pandas.DataFrame({'A' : [1]})) - - with pytest.raises(duckdb.InvalidInputException, match='UDF returned more than 2048 rows, which is not allowed.'): + testrel.map(return_dataframe).df().equals(pandas.DataFrame({'A': [1]})) + + with pytest.raises( + duckdb.InvalidInputException, match='UDF returned more than 2048 rows, which is not allowed.' + ): testrel.map(return_big_dataframe).df() - empty_rel.map(return_dataframe).df().equals(pandas.DataFrame({'A' : []})) + empty_rel.map(return_dataframe).df().equals(pandas.DataFrame({'A': []})) with pytest.raises(duckdb.InvalidInputException, match='No return value from Python function'): testrel.map(return_none).df() @@ -108,17 +113,25 @@ def process(rel): def mapper(x): dates = x['date'].to_numpy("datetime64[us]") days = x['days_to_add'].to_numpy("int") - x["result1"] = pandas.Series([pandas.to_datetime(y[0]).date() + timedelta(days=y[1].item()) for y in zip(dates,days)], dtype='datetime64[us]') - x["result2"] = pandas.Series([pandas.to_datetime(y[0]).date() + timedelta(days=-y[1].item()) for y in zip(dates,days)], dtype='datetime64[us]') + x["result1"] = pandas.Series( + [pandas.to_datetime(y[0]).date() + timedelta(days=y[1].item()) for y in zip(dates, days)], + dtype='datetime64[us]', + ) + x["result2"] = pandas.Series( + [pandas.to_datetime(y[0]).date() + timedelta(days=-y[1].item()) for y in zip(dates, days)], + dtype='datetime64[us]', + ) return x rel = rel.map(mapper) rel = rel.project("*, datediff('day', date, result1) as one") rel = rel.project("*, datediff('day', date, result2) as two") - rel = rel.project("*, IF(ABS(one) > ABS(two), one, two) as three") + rel = rel.project("*, IF(ABS(one) > ABS(two), one, two) as three") return rel - df = pandas.DataFrame({'date': pandas.Series([date(2000,1,1), date(2000,1,2)], dtype="datetime64[us]"), 'days_to_add': [1,2]}) + df = pandas.DataFrame( + {'date': pandas.Series([date(2000, 1, 1), date(2000, 1, 2)], dtype="datetime64[us]"), 'days_to_add': [1, 2]} + ) rel = duckdb.from_df(df) rel = process(rel) x = rel.fetchdf() @@ -143,7 +156,9 @@ def does_nothing(df): rel = con.sql('select i from range(10) tbl(i)') # expects the mapper to return a string column rel = rel.map(does_nothing, schema={'i': str}) - with pytest.raises(duckdb.InvalidInputException, match=re.escape("UDF column type mismatch, expected [VARCHAR], got [BIGINT]")): + with pytest.raises( + duckdb.InvalidInputException, match=re.escape("UDF column type mismatch, expected [VARCHAR], got [BIGINT]") + ): rel.fetchall() @pytest.mark.parametrize('pandas', [NumpyPandas()]) @@ -164,40 +179,50 @@ def no_op(df): con = duckdb.connect() rel = con.sql('select 42') - with pytest.raises(duckdb.InvalidInputException, match=re.escape("Invalid Input Error: 'schema' should be given as a Dict[str, DuckDBType]")): + with pytest.raises( + duckdb.InvalidInputException, + match=re.escape("Invalid Input Error: 'schema' should be given as a Dict[str, DuckDBType]"), + ): rel.map(no_op, schema=[int]) @pytest.mark.parametrize('pandas', [NumpyPandas()]) def test_returns_non_dataframe(self, pandas): def returns_series(df): - return df.loc[:,'i'] + return df.loc[:, 'i'] con = duckdb.connect() rel = con.sql('select i, i as j from range(10) tbl(i)') - with pytest.raises(duckdb.InvalidInputException, match=re.escape("Expected the UDF to return an object of type 'pandas.DataFrame', found '' instead")): + with pytest.raises( + duckdb.InvalidInputException, 
+ match=re.escape( + "Expected the UDF to return an object of type 'pandas.DataFrame', found '' instead" + ), + ): rel = rel.map(returns_series) @pytest.mark.parametrize('pandas', [NumpyPandas()]) def test_explicit_schema_columncount_mismatch(self, pandas): def returns_subset(df): - return pandas.DataFrame({'i': df.loc[:,'i']}) + return pandas.DataFrame({'i': df.loc[:, 'i']}) con = duckdb.connect() rel = con.sql('select i, i as j from range(10) tbl(i)') rel = rel.map(returns_subset, schema={'i': int, 'j': int}) - with pytest.raises(duckdb.InvalidInputException, match='Invalid Input Error: Expected 2 columns from UDF, got 1'): + with pytest.raises( + duckdb.InvalidInputException, match='Invalid Input Error: Expected 2 columns from UDF, got 1' + ): rel.fetchall() @pytest.mark.parametrize('pandas', [NumpyPandas()]) def test_pyarrow_df(self, pandas): # PyArrow backed dataframes only exist on pandas >= 2.0.0 _ = pytest.importorskip("pandas", "2.0.0") - + def basic_function(df): # Create a pyarrow backed dataframe - df = pandas.DataFrame({'a': [5,3,2,1,2]}).convert_dtypes(dtype_backend='pyarrow') + df = pandas.DataFrame({'a': [5, 3, 2, 1, 2]}).convert_dtypes(dtype_backend='pyarrow') return df - + con = duckdb.connect() with pytest.raises(duckdb.InvalidInputException): rel = con.sql('select 42').map(basic_function) diff --git a/tools/pythonpkg/tests/fast/test_memory_leaks.py b/tools/pythonpkg/tests/fast/test_memory_leaks.py index 9cad4bccfe3a..4e4a8d637cf9 100644 --- a/tools/pythonpkg/tests/fast/test_memory_leaks.py +++ b/tools/pythonpkg/tests/fast/test_memory_leaks.py @@ -4,6 +4,7 @@ import os, psutil import pandas as pd + @pytest.fixture def check_leaks(): process = psutil.Process(os.getpid()) @@ -17,10 +18,11 @@ def check_leaks(): # Assert that the amount of used memory does not pass 5mb assert difference <= 5_000_000 + class TestMemoryLeaks(object): def test_fetchmany(self, check_leaks): datetimes = ['1985-01-30T16:41:43' for _ in range(10000)] - df = pd.DataFrame({'time' : pd.Series(data=datetimes)}) + df = pd.DataFrame({'time': pd.Series(data=datetimes)}) for _ in range(100): duckdb.sql('select time::TIMESTAMP from df').fetchmany(10000) diff --git a/tools/pythonpkg/tests/fast/test_multi_statement.py b/tools/pythonpkg/tests/fast/test_multi_statement.py index 7e5d5bd7e67e..db82eaf3e5a3 100644 --- a/tools/pythonpkg/tests/fast/test_multi_statement.py +++ b/tools/pythonpkg/tests/fast/test_multi_statement.py @@ -2,20 +2,24 @@ import os import shutil + class TestMultiStatement(object): def test_multi_statement(self, duckdb_cursor): import duckdb + con = duckdb.connect(':memory:') # test empty statement con.execute('') # run multiple statements in one call to execute - con.execute(''' + con.execute( + ''' CREATE TABLE integers(i integer); insert into integers select * from range(10); select * from integers; - ''') + ''' + ) results = [x[0] for x in con.fetchall()] assert results == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] diff --git a/tools/pythonpkg/tests/fast/test_multithread.py b/tools/pythonpkg/tests/fast/test_multithread.py index 32a4da9a9fd0..9357a3273175 100644 --- a/tools/pythonpkg/tests/fast/test_multithread.py +++ b/tools/pythonpkg/tests/fast/test_multithread.py @@ -5,32 +5,40 @@ import numpy as np from conftest import NumpyPandas, ArrowPandas import os + try: import pyarrow as pa + can_run = True except: can_run = False + def connect_duck(duckdb_conn): out = duckdb_conn.execute('select i from (values (42), (84), (NULL), (128)) tbl(i)').fetchall() assert out == [(42,), (84,), (None,), (128,)] + 
class DuckDBThreaded: - def __init__(self,duckdb_insert_thread_count,thread_function, pandas): + def __init__(self, duckdb_insert_thread_count, thread_function, pandas): self.duckdb_insert_thread_count = duckdb_insert_thread_count self.threads = [] self.thread_function = thread_function self.pandas = pandas - - def multithread_test(self,if_all_true=True): + + def multithread_test(self, if_all_true=True): duckdb_conn = duckdb.connect() queue = Queue.Queue() return_value = False - for i in range(0,self.duckdb_insert_thread_count): - self.threads.append(threading.Thread(target=self.thread_function, args=(duckdb_conn,queue, self.pandas),name='duckdb_thread_'+str(i))) + for i in range(0, self.duckdb_insert_thread_count): + self.threads.append( + threading.Thread( + target=self.thread_function, args=(duckdb_conn, queue, self.pandas), name='duckdb_thread_' + str(i) + ) + ) - for i in range(0,len(self.threads)): + for i in range(0, len(self.threads)): self.threads[i].start() if not if_all_true: if queue.get(): @@ -40,21 +48,21 @@ def multithread_test(self,if_all_true=True): return_value = True elif queue.get() and return_value: return_value = True - - for i in range(0,len(self.threads)): + + for i in range(0, len(self.threads)): self.threads[i].join() - assert (return_value) + assert return_value def execute_query_same_connection(duckdb_conn, queue, pandas): - try: out = duckdb_conn.execute('select i from (values (42), (84), (NULL), (128)) tbl(i)') queue.put(False) except: queue.put(True) + def execute_query(duckdb_conn, queue, pandas): # Get a new connection duckdb_conn = duckdb.connect() @@ -64,6 +72,7 @@ def execute_query(duckdb_conn, queue, pandas): except: queue.put(False) + def insert_runtime_error(duckdb_conn, queue, pandas): # Get a new connection duckdb_conn = duckdb.connect() @@ -71,24 +80,29 @@ def insert_runtime_error(duckdb_conn, queue, pandas): duckdb_conn.execute('insert into T values (42), (84), (NULL), (128)') queue.put(False) except: - queue.put(True) + queue.put(True) + def execute_many_query(duckdb_conn, queue, pandas): # Get a new connection duckdb_conn = duckdb.connect() try: # from python docs - duckdb_conn.execute('''CREATE TABLE stocks - (date text, trans text, symbol text, qty real, price real)''') + duckdb_conn.execute( + '''CREATE TABLE stocks + (date text, trans text, symbol text, qty real, price real)''' + ) # Larger example that inserts many records at a time - purchases = [('2006-03-28', 'BUY', 'IBM', 1000, 45.00), - ('2006-04-05', 'BUY', 'MSFT', 1000, 72.00), - ('2006-04-06', 'SELL', 'IBM', 500, 53.00), - ] + purchases = [ + ('2006-03-28', 'BUY', 'IBM', 1000, 45.00), + ('2006-04-05', 'BUY', 'MSFT', 1000, 72.00), + ('2006-04-06', 'SELL', 'IBM', 500, 53.00), + ] duckdb_conn.executemany('INSERT INTO stocks VALUES (?,?,?,?,?)', purchases) queue.put(True) except: - queue.put(False) + queue.put(False) + def fetchone_query(duckdb_conn, queue, pandas): # Get a new connection @@ -97,7 +111,8 @@ def fetchone_query(duckdb_conn, queue, pandas): duckdb_conn.execute('select i from (values (42), (84), (NULL), (128)) tbl(i)').fetchone() queue.put(True) except: - queue.put(False) + queue.put(False) + def fetchall_query(duckdb_conn, queue, pandas): # Get a new connection @@ -106,7 +121,8 @@ def fetchall_query(duckdb_conn, queue, pandas): duckdb_conn.execute('select i from (values (42), (84), (NULL), (128)) tbl(i)').fetchall() queue.put(True) except: - queue.put(False) + queue.put(False) + def conn_close(duckdb_conn, queue, pandas): # Get a new connection @@ -115,7 +131,8 @@ def 
conn_close(duckdb_conn, queue, pandas): duckdb_conn.close() queue.put(True) except: - queue.put(False) + queue.put(False) + def fetchnp_query(duckdb_conn, queue, pandas): # Get a new connection @@ -124,7 +141,8 @@ def fetchnp_query(duckdb_conn, queue, pandas): duckdb_conn.execute('select i from (values (42), (84), (NULL), (128)) tbl(i)').fetchnumpy() queue.put(True) except: - queue.put(False) + queue.put(False) + def fetchdf_query(duckdb_conn, queue, pandas): # Get a new connection @@ -135,6 +153,7 @@ def fetchdf_query(duckdb_conn, queue, pandas): except: queue.put(False) + def fetchdf_chunk_query(duckdb_conn, queue, pandas): # Get a new connection duckdb_conn = duckdb.connect() @@ -142,7 +161,8 @@ def fetchdf_chunk_query(duckdb_conn, queue, pandas): duckdb_conn.execute('select i from (values (42), (84), (NULL), (128)) tbl(i)').fetch_df_chunk() queue.put(True) except: - queue.put(False) + queue.put(False) + def fetch_arrow_query(duckdb_conn, queue, pandas): # Get a new connection @@ -151,7 +171,7 @@ def fetch_arrow_query(duckdb_conn, queue, pandas): duckdb_conn.execute('select i from (values (42), (84), (NULL), (128)) tbl(i)').fetch_arrow_table() queue.put(True) except: - queue.put(False) + queue.put(False) def fetch_record_batch_query(duckdb_conn, queue, pandas): @@ -161,7 +181,8 @@ def fetch_record_batch_query(duckdb_conn, queue, pandas): duckdb_conn.execute('select i from (values (42), (84), (NULL), (128)) tbl(i)').fetch_record_batch() queue.put(True) except: - queue.put(False) + queue.put(False) + def transaction_query(duckdb_conn, queue, pandas): # Get a new connection @@ -175,50 +196,55 @@ def transaction_query(duckdb_conn, queue, pandas): duckdb_conn.commit() queue.put(True) except: - queue.put(False) + queue.put(False) + def df_append(duckdb_conn, queue, pandas): # Get a new connection duckdb_conn = duckdb.connect() duckdb_conn.execute("CREATE TABLE T ( i INTEGER)") - df = pandas.DataFrame(np.random.randint(0,100,size=15), columns=['A']) + df = pandas.DataFrame(np.random.randint(0, 100, size=15), columns=['A']) try: - duckdb_conn.append('T',df) + duckdb_conn.append('T', df) queue.put(True) except: - queue.put(False) + queue.put(False) + def df_register(duckdb_conn, queue, pandas): # Get a new connection duckdb_conn = duckdb.connect() - df = pandas.DataFrame(np.random.randint(0,100,size=15), columns=['A']) + df = pandas.DataFrame(np.random.randint(0, 100, size=15), columns=['A']) try: - duckdb_conn.register('T',df) + duckdb_conn.register('T', df) queue.put(True) except: - queue.put(False) + queue.put(False) + def df_unregister(duckdb_conn, queue, pandas): # Get a new connection duckdb_conn = duckdb.connect() - df = pandas.DataFrame(np.random.randint(0,100,size=15), columns=['A']) + df = pandas.DataFrame(np.random.randint(0, 100, size=15), columns=['A']) try: - duckdb_conn.register('T',df) + duckdb_conn.register('T', df) duckdb_conn.unregister('T') queue.put(True) except: - queue.put(False) + queue.put(False) + def arrow_register_unregister(duckdb_conn, queue, pandas): # Get a new connection duckdb_conn = duckdb.connect() - arrow_tbl = pa.Table.from_pydict({'my_column':pa.array([1,2,3,4,5],type=pa.int64())}) + arrow_tbl = pa.Table.from_pydict({'my_column': pa.array([1, 2, 3, 4, 5], type=pa.int64())}) try: - duckdb_conn.register('T',arrow_tbl) + duckdb_conn.register('T', arrow_tbl) duckdb_conn.unregister('T') queue.put(True) except: - queue.put(False) + queue.put(False) + def table(duckdb_conn, queue, pandas): # Get a new connection @@ -228,7 +254,8 @@ def table(duckdb_conn, queue, 
pandas): out = duckdb_conn.table('T') queue.put(True) except: - queue.put(False) + queue.put(False) + def view(duckdb_conn, queue, pandas): # Get a new connection @@ -239,7 +266,9 @@ def view(duckdb_conn, queue, pandas): out = duckdb_conn.values([5, 'five']) queue.put(True) except: - queue.put(False) + queue.put(False) + + def values(duckdb_conn, queue, pandas): # Get a new connection duckdb_conn = duckdb.connect() @@ -247,7 +276,7 @@ def values(duckdb_conn, queue, pandas): out = duckdb_conn.values([5, 'five']) queue.put(True) except: - queue.put(False) + queue.put(False) def from_query(duckdb_conn, queue, pandas): @@ -257,48 +286,53 @@ def from_query(duckdb_conn, queue, pandas): out = duckdb_conn.from_query("select i from (values (42), (84), (NULL), (128)) tbl(i)") queue.put(True) except: - queue.put(False) + queue.put(False) + def from_df(duckdb_conn, queue, pandas): # Get a new connection duckdb_conn = duckdb.connect() - df = pandas.DataFrame(['bla', 'blabla']*10, columns=['A']) + df = pandas.DataFrame(['bla', 'blabla'] * 10, columns=['A']) try: out = duckdb_conn.execute("select * from df").fetchall() queue.put(True) except: queue.put(False) + def from_arrow(duckdb_conn, queue, pandas): # Get a new connection duckdb_conn = duckdb.connect() - arrow_tbl = pa.Table.from_pydict({'my_column':pa.array([1,2,3,4,5],type=pa.int64())}) + arrow_tbl = pa.Table.from_pydict({'my_column': pa.array([1, 2, 3, 4, 5], type=pa.int64())}) try: out = duckdb_conn.from_arrow(arrow_tbl) queue.put(True) except: queue.put(False) + def from_csv_auto(duckdb_conn, queue, pandas): # Get a new connection duckdb_conn = duckdb.connect() - filename = os.path.join(os.path.dirname(os.path.realpath(__file__)),'data','integers.csv') + filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'integers.csv') try: out = duckdb_conn.from_csv_auto(filename) queue.put(True) except: - queue.put(False) + queue.put(False) + def from_parquet(duckdb_conn, queue, pandas): # Get a new connection duckdb_conn = duckdb.connect() - filename = os.path.join(os.path.dirname(os.path.realpath(__file__)),'data','binary_string.parquet') + filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'binary_string.parquet') try: out = duckdb_conn.from_parquet(filename) queue.put(True) except: queue.put(False) + def description(duckdb_conn, queue, pandas): # Get a new connection duckdb_conn = duckdb.connect() @@ -310,150 +344,149 @@ def description(duckdb_conn, queue, pandas): rel.description queue.put(True) except: - queue.put(False) + queue.put(False) + def cursor(duckdb_conn, queue, pandas): # Get a new connection - cx = duckdb_conn.cursor() + cx = duckdb_conn.cursor() try: cx.execute('CREATE TABLE test (i bool, j TIME, k VARCHAR)') queue.put(False) except: - queue.put(True) + queue.put(True) -class TestDuckMultithread(object): +class TestDuckMultithread(object): @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_execute(self, duckdb_cursor, pandas): - duck_threads = DuckDBThreaded(10,execute_query, pandas) + duck_threads = DuckDBThreaded(10, execute_query, pandas) duck_threads.multithread_test() @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_execute_many(self, duckdb_cursor, pandas): - duck_threads = DuckDBThreaded(10,execute_many_query, pandas) + duck_threads = DuckDBThreaded(10, execute_many_query, pandas) duck_threads.multithread_test() @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_fetchone(self, duckdb_cursor, pandas): - 
duck_threads = DuckDBThreaded(10,fetchone_query, pandas) + duck_threads = DuckDBThreaded(10, fetchone_query, pandas) duck_threads.multithread_test() @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_fetchall(self, duckdb_cursor, pandas): - duck_threads = DuckDBThreaded(10,fetchall_query, pandas) + duck_threads = DuckDBThreaded(10, fetchall_query, pandas) duck_threads.multithread_test() @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_close(self, duckdb_cursor, pandas): - duck_threads = DuckDBThreaded(10,conn_close, pandas) + duck_threads = DuckDBThreaded(10, conn_close, pandas) duck_threads.multithread_test() @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_fetchnp(self, duckdb_cursor, pandas): - duck_threads = DuckDBThreaded(10,fetchnp_query, pandas) + duck_threads = DuckDBThreaded(10, fetchnp_query, pandas) duck_threads.multithread_test() @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_fetchdf(self, duckdb_cursor, pandas): - duck_threads = DuckDBThreaded(10,fetchdf_query, pandas) + duck_threads = DuckDBThreaded(10, fetchdf_query, pandas) duck_threads.multithread_test() @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_fetchdfchunk(self, duckdb_cursor, pandas): - duck_threads = DuckDBThreaded(10,fetchdf_chunk_query, pandas) + duck_threads = DuckDBThreaded(10, fetchdf_chunk_query, pandas) duck_threads.multithread_test() @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_fetcharrow(self, duckdb_cursor, pandas): if not can_run: return - duck_threads = DuckDBThreaded(10,fetch_arrow_query, pandas) + duck_threads = DuckDBThreaded(10, fetch_arrow_query, pandas) duck_threads.multithread_test() @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_fetch_record_batch(self, duckdb_cursor, pandas): if not can_run: return - duck_threads = DuckDBThreaded(10,fetch_record_batch_query, pandas) + duck_threads = DuckDBThreaded(10, fetch_record_batch_query, pandas) duck_threads.multithread_test() @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_transaction(self, duckdb_cursor, pandas): - duck_threads = DuckDBThreaded(10,transaction_query, pandas) + duck_threads = DuckDBThreaded(10, transaction_query, pandas) duck_threads.multithread_test() @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_df_append(self, duckdb_cursor, pandas): - duck_threads = DuckDBThreaded(10,df_append, pandas) + duck_threads = DuckDBThreaded(10, df_append, pandas) duck_threads.multithread_test() @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_df_register(self, duckdb_cursor, pandas): - duck_threads = DuckDBThreaded(10,df_register, pandas) + duck_threads = DuckDBThreaded(10, df_register, pandas) duck_threads.multithread_test() @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_df_unregister(self, duckdb_cursor, pandas): - duck_threads = DuckDBThreaded(10,df_unregister, pandas) + duck_threads = DuckDBThreaded(10, df_unregister, pandas) duck_threads.multithread_test() @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_arrow_register_unregister(self, duckdb_cursor, pandas): if not can_run: return - duck_threads = DuckDBThreaded(10,arrow_register_unregister, pandas) + duck_threads = DuckDBThreaded(10, arrow_register_unregister, pandas) duck_threads.multithread_test() @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def 
test_table(self, duckdb_cursor, pandas): - duck_threads = DuckDBThreaded(10,table, pandas) + duck_threads = DuckDBThreaded(10, table, pandas) duck_threads.multithread_test() @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_view(self, duckdb_cursor, pandas): - duck_threads = DuckDBThreaded(10,view, pandas) + duck_threads = DuckDBThreaded(10, view, pandas) duck_threads.multithread_test() @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_values(self, duckdb_cursor, pandas): - duck_threads = DuckDBThreaded(10,values, pandas) + duck_threads = DuckDBThreaded(10, values, pandas) duck_threads.multithread_test() - + @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_from_query(self, duckdb_cursor, pandas): - duck_threads = DuckDBThreaded(10,from_query, pandas) + duck_threads = DuckDBThreaded(10, from_query, pandas) duck_threads.multithread_test() @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_from_DF(self, duckdb_cursor, pandas): - duck_threads = DuckDBThreaded(10,from_df, pandas) - duck_threads.multithread_test() + duck_threads = DuckDBThreaded(10, from_df, pandas) + duck_threads.multithread_test() @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_from_arrow(self, duckdb_cursor, pandas): if not can_run: return - duck_threads = DuckDBThreaded(10,from_arrow, pandas) + duck_threads = DuckDBThreaded(10, from_arrow, pandas) duck_threads.multithread_test() - + @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_from_csv_auto(self, duckdb_cursor, pandas): - duck_threads = DuckDBThreaded(10,from_csv_auto, pandas) + duck_threads = DuckDBThreaded(10, from_csv_auto, pandas) duck_threads.multithread_test() @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_from_parquet(self, duckdb_cursor, pandas): - duck_threads = DuckDBThreaded(10,from_parquet, pandas) + duck_threads = DuckDBThreaded(10, from_parquet, pandas) duck_threads.multithread_test() @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_description(self, duckdb_cursor, pandas): - duck_threads = DuckDBThreaded(10,description, pandas) + duck_threads = DuckDBThreaded(10, description, pandas) duck_threads.multithread_test() @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_cursor(self, duckdb_cursor, pandas): - duck_threads = DuckDBThreaded(10,cursor, pandas) + duck_threads = DuckDBThreaded(10, cursor, pandas) duck_threads.multithread_test(False) - - diff --git a/tools/pythonpkg/tests/fast/test_non_default_conn.py b/tools/pythonpkg/tests/fast/test_non_default_conn.py index b446d5891342..b994a6eceb52 100644 --- a/tools/pythonpkg/tests/fast/test_non_default_conn.py +++ b/tools/pythonpkg/tests/fast/test_non_default_conn.py @@ -4,31 +4,31 @@ import os import tempfile -class TestNonDefaultConn(object): +class TestNonDefaultConn(object): def test_values(self, duckdb_cursor): conn = duckdb.connect() conn.execute("create table t (a integer)") - duckdb.values([1],conn).insert_into("t") - assert conn.execute("select count(*) from t").fetchall()[0] == (1,) + duckdb.values([1], conn).insert_into("t") + assert conn.execute("select count(*) from t").fetchall()[0] == (1,) def test_query(self, duckdb_cursor): conn = duckdb.connect() conn.execute("create table t (a integer)") conn.execute("insert into t values (1)") - assert duckdb.query("select count(*) from t",connection=conn).execute().fetchall()[0] == (1,) - assert duckdb.from_query("select 
count(*) from t",connection=conn).execute().fetchall()[0] == (1,) + assert duckdb.query("select count(*) from t", connection=conn).execute().fetchall()[0] == (1,) + assert duckdb.from_query("select count(*) from t", connection=conn).execute().fetchall()[0] == (1,) def test_from_csv(self, duckdb_cursor): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) conn = duckdb.connect() conn.execute("create table t (a integer)") conn.execute("insert into t values (1)") - test_df = pd.DataFrame.from_dict({"i":[1, 2, 3, 4]}) + test_df = pd.DataFrame.from_dict({"i": [1, 2, 3, 4]}) test_df.to_csv(temp_file_name, index=False) rel = duckdb.from_csv_auto(temp_file_name, conn) - assert rel.query('t_2','select count(*) from t inner join t_2 on (a = i)').fetchall()[0] == (1,) - + assert rel.query('t_2', 'select count(*) from t inner join t_2 on (a = i)').fetchall()[0] == (1,) + def test_from_parquet(self, duckdb_cursor): try: import pyarrow as pa @@ -38,21 +38,21 @@ def test_from_parquet(self, duckdb_cursor): conn = duckdb.connect() conn.execute("create table t (a integer)") conn.execute("insert into t values (1)") - test_df = pd.DataFrame.from_dict({"i":[1, 2, 3, 4]}) + test_df = pd.DataFrame.from_dict({"i": [1, 2, 3, 4]}) test_df.to_parquet(temp_file_name, index=False) rel = duckdb.from_parquet(temp_file_name, connection=conn) - assert rel.query('t_2','select count(*) from t inner join t_2 on (a = i)').fetchall()[0] == (1,) - + assert rel.query('t_2', 'select count(*) from t inner join t_2 on (a = i)').fetchall()[0] == (1,) + def test_from_df(self, duckdb_cursor): conn = duckdb.connect() conn.execute("create table t (a integer)") conn.execute("insert into t values (1)") - test_df = pd.DataFrame.from_dict({"i":[1, 2, 3, 4]}) + test_df = pd.DataFrame.from_dict({"i": [1, 2, 3, 4]}) rel = duckdb.df(test_df, connection=conn) - assert rel.query('t_2','select count(*) from t inner join t_2 on (a = i)').fetchall()[0] == (1,) + assert rel.query('t_2', 'select count(*) from t inner join t_2 on (a = i)').fetchall()[0] == (1,) rel = duckdb.from_df(test_df, connection=conn) - assert rel.query('t_2','select count(*) from t inner join t_2 on (a = i)').fetchall()[0] == (1,) - + assert rel.query('t_2', 'select count(*) from t inner join t_2 on (a = i)').fetchall()[0] == (1,) + def test_from_arrow(self, duckdb_cursor): try: import pyarrow as pa @@ -62,65 +62,65 @@ def test_from_arrow(self, duckdb_cursor): conn = duckdb.connect() conn.execute("create table t (a integer)") conn.execute("insert into t values (1)") - test_df = pd.DataFrame.from_dict({"i":[1, 2, 3, 4]}) + test_df = pd.DataFrame.from_dict({"i": [1, 2, 3, 4]}) test_arrow = pa.Table.from_pandas(test_df) rel = duckdb.from_arrow(test_arrow, connection=conn) - assert rel.query('t_2','select count(*) from t inner join t_2 on (a = i)').fetchall()[0] == (1,) + assert rel.query('t_2', 'select count(*) from t inner join t_2 on (a = i)').fetchall()[0] == (1,) rel = duckdb.arrow(test_arrow, connection=conn) - assert rel.query('t_2','select count(*) from t inner join t_2 on (a = i)').fetchall()[0] == (1,) - + assert rel.query('t_2', 'select count(*) from t inner join t_2 on (a = i)').fetchall()[0] == (1,) + def test_filter_df(self, duckdb_cursor): conn = duckdb.connect() conn.execute("create table t (a integer)") conn.execute("insert into t values (1), (4)") - test_df = pd.DataFrame.from_dict({"i":[1, 2, 3, 4]}) - rel = duckdb.filter(test_df,"i < 2", connection=conn) - assert rel.query('t_2','select count(*) from t inner join t_2 on (a = 
i)').fetchall()[0] == (1,) + test_df = pd.DataFrame.from_dict({"i": [1, 2, 3, 4]}) + rel = duckdb.filter(test_df, "i < 2", connection=conn) + assert rel.query('t_2', 'select count(*) from t inner join t_2 on (a = i)').fetchall()[0] == (1,) def test_project_df(self, duckdb_cursor): conn = duckdb.connect() conn.execute("create table t (a integer)") conn.execute("insert into t values (1), (4)") - test_df = pd.DataFrame.from_dict({"i":[1, 2, 3, 4],"j":[1, 2, 3, 4]}) - rel = duckdb.project(test_df,"i", connection=conn) - assert rel.query('t_2','select * from t inner join t_2 on (a = i)').fetchall()[0] == (1, 1) - + test_df = pd.DataFrame.from_dict({"i": [1, 2, 3, 4], "j": [1, 2, 3, 4]}) + rel = duckdb.project(test_df, "i", connection=conn) + assert rel.query('t_2', 'select * from t inner join t_2 on (a = i)').fetchall()[0] == (1, 1) + def test_agg_df(self, duckdb_cursor): conn = duckdb.connect() conn.execute("create table t (a integer)") conn.execute("insert into t values (1), (4)") - test_df = pd.DataFrame.from_dict({"i":[1, 2, 3, 4],"j":[1, 2, 3, 4]}) - rel = duckdb.aggregate(test_df,"count(*) as i", connection=conn) - assert rel.query('t_2','select * from t inner join t_2 on (a = i)').fetchall()[0] == (4, 4) + test_df = pd.DataFrame.from_dict({"i": [1, 2, 3, 4], "j": [1, 2, 3, 4]}) + rel = duckdb.aggregate(test_df, "count(*) as i", connection=conn) + assert rel.query('t_2', 'select * from t inner join t_2 on (a = i)').fetchall()[0] == (4, 4) def test_distinct_df(self, duckdb_cursor): conn = duckdb.connect() conn.execute("create table t (a integer)") conn.execute("insert into t values (1)") - test_df = pd.DataFrame.from_dict({"i":[1,1, 2, 3, 4]}) + test_df = pd.DataFrame.from_dict({"i": [1, 1, 2, 3, 4]}) rel = duckdb.distinct(test_df, connection=conn) - assert rel.query('t_2','select * from t inner join t_2 on (a = i)').fetchall()[0] == (1,1) + assert rel.query('t_2', 'select * from t inner join t_2 on (a = i)').fetchall()[0] == (1, 1) def test_limit_df(self, duckdb_cursor): conn = duckdb.connect() conn.execute("create table t (a integer)") conn.execute("insert into t values (1),(4)") - test_df = pd.DataFrame.from_dict({"i":[1, 2, 3, 4]}) - rel = duckdb.limit(test_df,1, connection=conn) - assert rel.query('t_2','select * from t inner join t_2 on (a = i)').fetchall()[0] == (1,1) - + test_df = pd.DataFrame.from_dict({"i": [1, 2, 3, 4]}) + rel = duckdb.limit(test_df, 1, connection=conn) + assert rel.query('t_2', 'select * from t inner join t_2 on (a = i)').fetchall()[0] == (1, 1) + def test_query_df(self, duckdb_cursor): conn = duckdb.connect() conn.execute("create table t (a integer)") conn.execute("insert into t values (1),(4)") - test_df = pd.DataFrame.from_dict({"i":[1, 2, 3, 4]}) - rel = duckdb.query_df(test_df,'t_2','select * from t inner join t_2 on (a = i)', connection=conn) - assert rel.fetchall()[0] == (1,1) - + test_df = pd.DataFrame.from_dict({"i": [1, 2, 3, 4]}) + rel = duckdb.query_df(test_df, 't_2', 'select * from t inner join t_2 on (a = i)', connection=conn) + assert rel.fetchall()[0] == (1, 1) + def test_query_order(self, duckdb_cursor): conn = duckdb.connect() conn.execute("create table t (a integer)") conn.execute("insert into t values (1),(4)") - test_df = pd.DataFrame.from_dict({"i":[1, 2, 3, 4]}) - rel = duckdb.order(test_df,'i', connection=conn) - assert rel.query('t_2','select * from t inner join t_2 on (a = i)').fetchall()[0] == (1,1) + test_df = pd.DataFrame.from_dict({"i": [1, 2, 3, 4]}) + rel = duckdb.order(test_df, 'i', connection=conn) + assert rel.query('t_2', 
'select * from t inner join t_2 on (a = i)').fetchall()[0] == (1, 1) diff --git a/tools/pythonpkg/tests/fast/test_parameter_list.py b/tools/pythonpkg/tests/fast/test_parameter_list.py index f883c98d6fa7..2421c76e08e1 100644 --- a/tools/pythonpkg/tests/fast/test_parameter_list.py +++ b/tools/pythonpkg/tests/fast/test_parameter_list.py @@ -2,24 +2,29 @@ import pytest from conftest import NumpyPandas, ArrowPandas -class TestParameterList(object): + +class TestParameterList(object): def test_bool(self, duckdb_cursor): conn = duckdb.connect() conn.execute("create table bool_table (a bool)") conn.execute("insert into bool_table values (TRUE)") - res = conn.execute("select count(*) from bool_table where a =?",[True]) + res = conn.execute("select count(*) from bool_table where a =?", [True]) assert res.fetchone()[0] == 1 @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_exception(self, duckdb_cursor, pandas): conn = duckdb.connect() - df_in = pandas.DataFrame({'numbers': [1,2,3,4,5],}) + df_in = pandas.DataFrame( + { + 'numbers': [1, 2, 3, 4, 5], + } + ) conn.execute("create table bool_table (a bool)") conn.execute("insert into bool_table values (TRUE)") with pytest.raises(duckdb.NotImplementedException, match='Unable to transform'): - res = conn.execute("select count(*) from bool_table where a =?",[df_in]) + res = conn.execute("select count(*) from bool_table where a =?", [df_in]) def test_explicit_nan_param(self): con = duckdb.default_connection res = con.execute('select isnan(cast(? as double))', (float("nan"),)) - assert(res.fetchone()[0] == True) + assert res.fetchone()[0] == True diff --git a/tools/pythonpkg/tests/fast/test_parquet.py b/tools/pythonpkg/tests/fast/test_parquet.py index c6e7a3d772a8..498fe53ef3c2 100644 --- a/tools/pythonpkg/tests/fast/test_parquet.py +++ b/tools/pythonpkg/tests/fast/test_parquet.py @@ -7,22 +7,23 @@ VARCHAR = duckdb.typing.VARCHAR BIGINT = duckdb.typing.BIGINT -filename = os.path.join(os.path.dirname(os.path.realpath(__file__)),'data','binary_string.parquet') +filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'binary_string.parquet') + @pytest.fixture(scope="session") def tmp_parquets(tmp_path_factory): tmp_dir = tmp_path_factory.mktemp('parquets', numbered=True) - tmp_parquets = [str(tmp_dir / ('tmp'+str(i)+'.parquet')) for i in range(1, 4)] + tmp_parquets = [str(tmp_dir / ('tmp' + str(i) + '.parquet')) for i in range(1, 4)] return tmp_parquets -class TestParquet(object): +class TestParquet(object): def test_scan_binary(self, duckdb_cursor): conn = duckdb.connect() - res = conn.execute("SELECT typeof(#1) FROM parquet_scan('"+filename+"') limit 1").fetchall() + res = conn.execute("SELECT typeof(#1) FROM parquet_scan('" + filename + "') limit 1").fetchall() assert res[0] == ('BLOB',) - res = conn.execute("SELECT * FROM parquet_scan('"+filename+"')").fetchall() + res = conn.execute("SELECT * FROM parquet_scan('" + filename + "')").fetchall() assert res[0] == (b'foo',) def test_from_parquet_binary(self, duckdb_cursor): @@ -34,88 +35,104 @@ def test_from_parquet_binary(self, duckdb_cursor): def test_scan_binary_as_string(self, duckdb_cursor): conn = duckdb.connect() - res = conn.execute("SELECT typeof(#1) FROM parquet_scan('"+filename+"',binary_as_string=True) limit 1").fetchall() + res = conn.execute( + "SELECT typeof(#1) FROM parquet_scan('" + filename + "',binary_as_string=True) limit 1" + ).fetchall() assert res[0] == ('VARCHAR',) - res = conn.execute("SELECT * FROM 
parquet_scan('"+filename+"',binary_as_string=True)").fetchall() + res = conn.execute("SELECT * FROM parquet_scan('" + filename + "',binary_as_string=True)").fetchall() assert res[0] == ('foo',) def test_from_parquet_binary_as_string(self, duckdb_cursor): - rel = duckdb.from_parquet(filename,True) + rel = duckdb.from_parquet(filename, True) assert rel.types == [VARCHAR] res = rel.execute().fetchall() assert res[0] == ('foo',) def test_from_parquet_file_row_number(self, duckdb_cursor): - rel = duckdb.from_parquet(filename,binary_as_string=True,file_row_number=True) + rel = duckdb.from_parquet(filename, binary_as_string=True, file_row_number=True) assert rel.types == [VARCHAR, BIGINT] res = rel.execute().fetchall() - assert res[0] == ('foo',0,) + assert res[0] == ( + 'foo', + 0, + ) def test_from_parquet_filename(self, duckdb_cursor): - rel = duckdb.from_parquet(filename,binary_as_string=True,filename=True) + rel = duckdb.from_parquet(filename, binary_as_string=True, filename=True) assert rel.types == [VARCHAR, VARCHAR] res = rel.execute().fetchall() - assert res[0] == ('foo',filename,) + assert res[0] == ( + 'foo', + filename, + ) def test_from_parquet_list_binary_as_string(self, duckdb_cursor): - rel = duckdb.from_parquet([filename],binary_as_string=True) + rel = duckdb.from_parquet([filename], binary_as_string=True) assert rel.types == [VARCHAR] res = rel.execute().fetchall() assert res[0] == ('foo',) def test_from_parquet_list_file_row_number(self, duckdb_cursor): - rel = duckdb.from_parquet([filename],binary_as_string=True,file_row_number=True) + rel = duckdb.from_parquet([filename], binary_as_string=True, file_row_number=True) assert rel.types == [VARCHAR, BIGINT] res = rel.execute().fetchall() - assert res[0] == ('foo',0,) + assert res[0] == ( + 'foo', + 0, + ) def test_from_parquet_list_filename(self, duckdb_cursor): - rel = duckdb.from_parquet([filename],binary_as_string=True,filename=True) + rel = duckdb.from_parquet([filename], binary_as_string=True, filename=True) assert rel.types == [VARCHAR, VARCHAR] res = rel.execute().fetchall() - assert res[0] == ('foo',filename,) + assert res[0] == ( + 'foo', + filename, + ) def test_parquet_binary_as_string_pragma(self, duckdb_cursor): conn = duckdb.connect() - res = conn.execute("SELECT typeof(#1) FROM parquet_scan('"+filename+"') limit 1").fetchall() + res = conn.execute("SELECT typeof(#1) FROM parquet_scan('" + filename + "') limit 1").fetchall() assert res[0] == ('BLOB',) - res = conn.execute("SELECT * FROM parquet_scan('"+filename+"')").fetchall() + res = conn.execute("SELECT * FROM parquet_scan('" + filename + "')").fetchall() assert res[0] == (b'foo',) conn.execute("PRAGMA binary_as_string=1") - res = conn.execute("SELECT typeof(#1) FROM parquet_scan('"+filename+"') limit 1").fetchall() + res = conn.execute("SELECT typeof(#1) FROM parquet_scan('" + filename + "') limit 1").fetchall() assert res[0] == ('VARCHAR',) - res = conn.execute("SELECT * FROM parquet_scan('"+filename+"')").fetchall() + res = conn.execute("SELECT * FROM parquet_scan('" + filename + "')").fetchall() assert res[0] == ('foo',) - res = conn.execute("SELECT typeof(#1) FROM parquet_scan('"+filename+"',binary_as_string=False) limit 1").fetchall() + res = conn.execute( + "SELECT typeof(#1) FROM parquet_scan('" + filename + "',binary_as_string=False) limit 1" + ).fetchall() assert res[0] == ('BLOB',) - res = conn.execute("SELECT * FROM parquet_scan('"+filename+"',binary_as_string=False)").fetchall() + res = conn.execute("SELECT * FROM parquet_scan('" + filename + 
"',binary_as_string=False)").fetchall() assert res[0] == (b'foo',) conn.execute("PRAGMA binary_as_string=0") - res = conn.execute("SELECT typeof(#1) FROM parquet_scan('"+filename+"') limit 1").fetchall() + res = conn.execute("SELECT typeof(#1) FROM parquet_scan('" + filename + "') limit 1").fetchall() assert res[0] == ('BLOB',) - res = conn.execute("SELECT * FROM parquet_scan('"+filename+"')").fetchall() + res = conn.execute("SELECT * FROM parquet_scan('" + filename + "')").fetchall() assert res[0] == (b'foo',) def test_from_parquet_binary_as_string_default_conn(self, duckdb_cursor): duckdb.default_connection.execute("PRAGMA binary_as_string=1") - rel = duckdb.from_parquet(filename,True) + rel = duckdb.from_parquet(filename, True) assert rel.types == [VARCHAR] res = rel.execute().fetchall() @@ -124,22 +141,59 @@ def test_from_parquet_binary_as_string_default_conn(self, duckdb_cursor): def test_from_parquet_union_by_name(self, tmp_parquets): conn = duckdb.connect() - conn.execute("copy (from (values (1::bigint), (2::bigint), (9223372036854775807::bigint)) t(a)) to '"+tmp_parquets[0]+"' (format 'parquet');") - - conn.execute("copy (from (values (3::integer, 4::integer), (5::integer, 6::integer)) t(a, b)) to '"+tmp_parquets[1]+"' (format 'parquet');") - - conn.execute("copy (from (values (100::integer, 101::integer), (102::integer, 103::integer)) t(a, c)) to '"+tmp_parquets[2]+"' (format 'parquet');") - - rel = duckdb.from_parquet(tmp_parquets, union_by_name = True).order('a') + conn.execute( + "copy (from (values (1::bigint), (2::bigint), (9223372036854775807::bigint)) t(a)) to '" + + tmp_parquets[0] + + "' (format 'parquet');" + ) + + conn.execute( + "copy (from (values (3::integer, 4::integer), (5::integer, 6::integer)) t(a, b)) to '" + + tmp_parquets[1] + + "' (format 'parquet');" + ) + + conn.execute( + "copy (from (values (100::integer, 101::integer), (102::integer, 103::integer)) t(a, c)) to '" + + tmp_parquets[2] + + "' (format 'parquet');" + ) + + rel = duckdb.from_parquet(tmp_parquets, union_by_name=True).order('a') assert rel.execute().fetchall() == [ - (1, None, None,), - (2, None, None,), - (3, 4, None,), - (5, 6, None,), - (100, None, 101,), - (102, None, 103,), - (9223372036854775807, None, None,), + ( + 1, + None, + None, + ), + ( + 2, + None, + None, + ), + ( + 3, + 4, + None, + ), + ( + 5, + 6, + None, + ), + ( + 100, + None, + 101, + ), + ( + 102, + None, + 103, + ), + ( + 9223372036854775807, + None, + None, + ), ] - - - diff --git a/tools/pythonpkg/tests/fast/test_pytorch.py b/tools/pythonpkg/tests/fast/test_pytorch.py index 411f2c2f8281..65566d3d2b52 100644 --- a/tools/pythonpkg/tests/fast/test_pytorch.py +++ b/tools/pythonpkg/tests/fast/test_pytorch.py @@ -4,36 +4,37 @@ torch = pytest.importorskip('torch') + def test_pytorch(): - con = duckdb.connect() - - con.execute("create table t( a integer, b integer)") - con.execute("insert into t values (1,2), (3,4)") - - # Test from connection - duck_torch = con.execute("select * from t").torch() - duck_numpy = con.sql("select * from t").fetchnumpy() - torch.equal(duck_torch['a'], torch.tensor(duck_numpy['a'])) - torch.equal(duck_torch['b'], torch.tensor(duck_numpy['b'])) - - # Test from relation - duck_torch = con.sql("select * from t").torch() - torch.equal(duck_torch['a'], torch.tensor(duck_numpy['a'])) - torch.equal(duck_torch['b'], torch.tensor(duck_numpy['b'])) - - # Test all Numeric Types - numeric_types = ['TINYINT', 'SMALLINT', 'BIGINT', 'HUGEINT', 'FLOAT', 'DOUBLE', 'DECIMAL(4,1)', 'UTINYINT'] - - for supported_type in 
numeric_types: - con = duckdb.connect() - con.execute(f"create table t( a {supported_type} , b {supported_type})") - con.execute("insert into t values (1,2), (3,4)") - duck_torch = con.sql("select * from t").torch() - duck_numpy = con.sql("select * from t").fetchnumpy() - torch.equal(duck_torch['a'], torch.tensor(duck_numpy['a'])) - torch.equal(duck_torch['b'], torch.tensor(duck_numpy['b'])) - - with pytest.raises(TypeError, match="can't convert"): - con = duckdb.connect() - con.execute(f"create table t( a UINTEGER)") - duck_torch = con.sql("select * from t").torch() \ No newline at end of file + con = duckdb.connect() + + con.execute("create table t( a integer, b integer)") + con.execute("insert into t values (1,2), (3,4)") + + # Test from connection + duck_torch = con.execute("select * from t").torch() + duck_numpy = con.sql("select * from t").fetchnumpy() + torch.equal(duck_torch['a'], torch.tensor(duck_numpy['a'])) + torch.equal(duck_torch['b'], torch.tensor(duck_numpy['b'])) + + # Test from relation + duck_torch = con.sql("select * from t").torch() + torch.equal(duck_torch['a'], torch.tensor(duck_numpy['a'])) + torch.equal(duck_torch['b'], torch.tensor(duck_numpy['b'])) + + # Test all Numeric Types + numeric_types = ['TINYINT', 'SMALLINT', 'BIGINT', 'HUGEINT', 'FLOAT', 'DOUBLE', 'DECIMAL(4,1)', 'UTINYINT'] + + for supported_type in numeric_types: + con = duckdb.connect() + con.execute(f"create table t( a {supported_type} , b {supported_type})") + con.execute("insert into t values (1,2), (3,4)") + duck_torch = con.sql("select * from t").torch() + duck_numpy = con.sql("select * from t").fetchnumpy() + torch.equal(duck_torch['a'], torch.tensor(duck_numpy['a'])) + torch.equal(duck_torch['b'], torch.tensor(duck_numpy['b'])) + + with pytest.raises(TypeError, match="can't convert"): + con = duckdb.connect() + con.execute(f"create table t( a UINTEGER)") + duck_torch = con.sql("select * from t").torch() diff --git a/tools/pythonpkg/tests/fast/test_relation.py b/tools/pythonpkg/tests/fast/test_relation.py index c26116c315d1..b945ec569f31 100644 --- a/tools/pythonpkg/tests/fast/test_relation.py +++ b/tools/pythonpkg/tests/fast/test_relation.py @@ -7,17 +7,19 @@ from duckdb.typing import BIGINT, VARCHAR, TINYINT, BOOLEAN + def get_relation(conn): - test_df = pd.DataFrame.from_dict({"i":[1, 2, 3, 4], "j":["one", "two", "three", "four"]}) + test_df = pd.DataFrame.from_dict({"i": [1, 2, 3, 4], "j": ["one", "two", "three", "four"]}) conn.register("test_df", test_df) return conn.from_df(test_df) + class TestRelation(object): def test_csv_auto(self, duckdb_cursor): conn = duckdb.connect() df_rel = get_relation(conn) temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) - test_df = pd.DataFrame.from_dict({"i":[1, 2, 3, 4], "j":["one", "two", "three", "four"]}) + test_df = pd.DataFrame.from_dict({"i": [1, 2, 3, 4], "j": ["one", "two", "three", "four"]}) test_df.to_csv(temp_file_name, index=False) # now create a relation from it @@ -45,16 +47,16 @@ def test_limit_operator(self, duckdb_cursor): assert rel.limit(2).execute().fetchall() == [(1, 'one'), (2, 'two')] assert rel.limit(2, offset=1).execute().fetchall() == [(2, 'two'), (3, 'three')] - def test_intersect_operator(self,duckdb_cursor): + def test_intersect_operator(self, duckdb_cursor): conn = duckdb.connect() - test_df = pd.DataFrame.from_dict({"i":[1, 2, 3, 4]}) + test_df = pd.DataFrame.from_dict({"i": [1, 2, 3, 4]}) conn.register("test_df", test_df) - test_df_2 = pd.DataFrame.from_dict({"i":[3, 4, 5, 6]}) + 
test_df_2 = pd.DataFrame.from_dict({"i": [3, 4, 5, 6]}) conn.register("test_df", test_df_2) rel = conn.from_df(test_df) rel_2 = conn.from_df(test_df_2) - assert rel.intersect(rel_2).execute().fetchall() == [ (3,), (4,)] + assert rel.intersect(rel_2).execute().fetchall() == [(3,), (4,)] def test_aggregate_operator(self, duckdb_cursor): conn = duckdb.connect() @@ -65,46 +67,70 @@ def test_aggregate_operator(self, duckdb_cursor): def test_distinct_operator(self, duckdb_cursor): conn = duckdb.connect() rel = get_relation(conn) - assert rel.distinct().execute().fetchall() == [(1, 'one'), (2, 'two'), (3, 'three'),(4, 'four')] + assert rel.distinct().execute().fetchall() == [(1, 'one'), (2, 'two'), (3, 'three'), (4, 'four')] def test_union_operator(self, duckdb_cursor): conn = duckdb.connect() rel = get_relation(conn) print(rel.union(rel).execute().fetchall()) - assert rel.union(rel).execute().fetchall() == [(1, 'one'), (2, 'two'), (3, 'three'), (4, 'four'), (1, 'one'), (2, 'two'), (3, 'three'), (4, 'four')] + assert rel.union(rel).execute().fetchall() == [ + (1, 'one'), + (2, 'two'), + (3, 'three'), + (4, 'four'), + (1, 'one'), + (2, 'two'), + (3, 'three'), + (4, 'four'), + ] def test_join_operator(self, duckdb_cursor): # join rel with itself on i conn = duckdb.connect() - test_df = pd.DataFrame.from_dict({"i":[1, 2, 3, 4], "j":["one", "two", "three", "four"]}) + test_df = pd.DataFrame.from_dict({"i": [1, 2, 3, 4], "j": ["one", "two", "three", "four"]}) rel = conn.from_df(test_df) rel2 = conn.from_df(test_df) - assert rel.join(rel2, 'i').execute().fetchall() == [(1, 'one', 'one'), (2, 'two', 'two'), (3, 'three', 'three'), (4, 'four', 'four')] - - def test_except_operator(self,duckdb_cursor): + assert rel.join(rel2, 'i').execute().fetchall() == [ + (1, 'one', 'one'), + (2, 'two', 'two'), + (3, 'three', 'three'), + (4, 'four', 'four'), + ] + + def test_except_operator(self, duckdb_cursor): conn = duckdb.connect() - test_df = pd.DataFrame.from_dict({"i":[1, 2, 3, 4], "j":["one", "two", "three", "four"]}) + test_df = pd.DataFrame.from_dict({"i": [1, 2, 3, 4], "j": ["one", "two", "three", "four"]}) rel = conn.from_df(test_df) rel2 = conn.from_df(test_df) assert rel.except_(rel2).execute().fetchall() == [] - def test_create_operator(self,duckdb_cursor): + def test_create_operator(self, duckdb_cursor): conn = duckdb.connect() - test_df = pd.DataFrame.from_dict({"i":[1, 2, 3, 4], "j":["one", "two", "three", "four"]}) + test_df = pd.DataFrame.from_dict({"i": [1, 2, 3, 4], "j": ["one", "two", "three", "four"]}) rel = conn.from_df(test_df) rel.create("test_df") - assert conn.query("select * from test_df").execute().fetchall() == [(1, 'one'), (2, 'two'), (3, 'three'),(4, 'four')] - - def test_create_view_operator(self,duckdb_cursor): + assert conn.query("select * from test_df").execute().fetchall() == [ + (1, 'one'), + (2, 'two'), + (3, 'three'), + (4, 'four'), + ] + + def test_create_view_operator(self, duckdb_cursor): conn = duckdb.connect() - test_df = pd.DataFrame.from_dict({"i":[1, 2, 3, 4], "j":["one", "two", "three", "four"]}) + test_df = pd.DataFrame.from_dict({"i": [1, 2, 3, 4], "j": ["one", "two", "three", "four"]}) rel = conn.from_df(test_df) rel.create_view("test_df") - assert conn.query("select * from test_df").execute().fetchall() == [(1, 'one'), (2, 'two'), (3, 'three'),(4, 'four')] - - def test_insert_into_operator(self,duckdb_cursor): + assert conn.query("select * from test_df").execute().fetchall() == [ + (1, 'one'), + (2, 'two'), + (3, 'three'), + (4, 'four'), + ] + + def 
test_insert_into_operator(self, duckdb_cursor): conn = duckdb.connect() - test_df = pd.DataFrame.from_dict({"i":[1, 2, 3, 4], "j":["one", "two", "three", "four"]}) + test_df = pd.DataFrame.from_dict({"i": [1, 2, 3, 4], "j": ["one", "two", "three", "four"]}) rel = conn.from_df(test_df) rel.create("test_table2") # insert the relation's data into an existing table @@ -114,11 +140,18 @@ def test_insert_into_operator(self,duckdb_cursor): # Inserting elements into table_3 print(conn.values([5, 'five']).insert_into("test_table3")) rel_3 = conn.table("test_table3") - rel_3.insert([6,'six']) - - assert rel_3.execute().fetchall() == [(1, 'one'), (2, 'two'), (3, 'three'), (4, 'four'), (5, 'five'), (6, 'six')] - - def test_write_csv_operator(self,duckdb_cursor): + rel_3.insert([6, 'six']) + + assert rel_3.execute().fetchall() == [ + (1, 'one'), + (2, 'two'), + (3, 'three'), + (4, 'four'), + (5, 'five'), + (6, 'six'), + ] + + def test_write_csv_operator(self, duckdb_cursor): conn = duckdb.connect() df_rel = get_relation(conn) temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) @@ -127,7 +160,7 @@ def test_write_csv_operator(self,duckdb_cursor): csv_rel = duckdb.from_csv_auto(temp_file_name) assert df_rel.execute().fetchall() == csv_rel.execute().fetchall() - def test_get_attr_operator(self,duckdb_cursor): + def test_get_attr_operator(self, duckdb_cursor): conn = duckdb.connect() conn.execute("CREATE TABLE test (i INTEGER)") rel = conn.table("test") @@ -136,28 +169,29 @@ def test_get_attr_operator(self,duckdb_cursor): assert rel.columns == ['i'] assert rel.types == ['INTEGER'] - def test_query_fail(self,duckdb_cursor): + def test_query_fail(self, duckdb_cursor): conn = duckdb.connect() conn.execute("CREATE TABLE test (i INTEGER)") rel = conn.table("test") with pytest.raises(TypeError, match='incompatible function arguments'): rel.query("select j from test") - def test_execute_fail(self,duckdb_cursor): + def test_execute_fail(self, duckdb_cursor): conn = duckdb.connect() conn.execute("CREATE TABLE test (i INTEGER)") rel = conn.table("test") with pytest.raises(TypeError, match='incompatible function arguments'): rel.execute("select j from test") - def test_df_proj(self,duckdb_cursor): - test_df = pd.DataFrame.from_dict({"i":[1, 2, 3, 4], "j":["one", "two", "three", "four"]}) + def test_df_proj(self, duckdb_cursor): + test_df = pd.DataFrame.from_dict({"i": [1, 2, 3, 4], "j": ["one", "two", "three", "four"]}) rel = duckdb.project(test_df, 'i') assert rel.execute().fetchall() == [(1,), (2,), (3,), (4,)] def test_project_on_types(self, duckdb_cursor): con = duckdb_cursor - con.sql(""" + con.sql( + """ create table tbl( c0 BIGINT, c1 TINYINT, @@ -166,7 +200,8 @@ def test_project_on_types(self, duckdb_cursor): c4 VARCHAR, c5 STRUCT(a VARCHAR, b BIGINT) ) - """) + """ + ) rel = con.table("tbl") # select only the varchar columns projection = rel.select_types(["varchar"]) @@ -179,42 +214,41 @@ def test_project_on_types(self, duckdb_cursor): ## select with empty projection list, not possible with pytest.raises(duckdb.Error): projection = rel.select_types([]) - + # select with type-filter that matches nothing with pytest.raises(duckdb.Error): projection = rel.select_types([BOOLEAN]) - def test_df_alias(self,duckdb_cursor): - test_df = pd.DataFrame.from_dict({"i":[1, 2, 3, 4], "j":["one", "two", "three", "four"]}) + def test_df_alias(self, duckdb_cursor): + test_df = pd.DataFrame.from_dict({"i": [1, 2, 3, 4], "j": ["one", "two", "three", "four"]}) rel = duckdb.alias(test_df, 
'dfzinho') assert rel.alias == "dfzinho" - def test_df_filter(self,duckdb_cursor): - test_df = pd.DataFrame.from_dict({"i":[1, 2, 3, 4], "j":["one", "two", "three", "four"]}) + def test_df_filter(self, duckdb_cursor): + test_df = pd.DataFrame.from_dict({"i": [1, 2, 3, 4], "j": ["one", "two", "three", "four"]}) rel = duckdb.filter(test_df, 'i > 1') assert rel.execute().fetchall() == [(2, 'two'), (3, 'three'), (4, 'four')] - def test_df_order_by(self,duckdb_cursor): - test_df = pd.DataFrame.from_dict({"i":[1, 2, 3, 4], "j":["one", "two", "three", "four"]}) + def test_df_order_by(self, duckdb_cursor): + test_df = pd.DataFrame.from_dict({"i": [1, 2, 3, 4], "j": ["one", "two", "three", "four"]}) rel = duckdb.order(test_df, 'j') assert rel.execute().fetchall() == [(4, 'four'), (1, 'one'), (3, 'three'), (2, 'two')] - def test_df_distinct(self,duckdb_cursor): - test_df = pd.DataFrame.from_dict({"i":[1, 2, 3, 4], "j":["one", "two", "three", "four"]}) + def test_df_distinct(self, duckdb_cursor): + test_df = pd.DataFrame.from_dict({"i": [1, 2, 3, 4], "j": ["one", "two", "three", "four"]}) rel = duckdb.distinct(test_df) - assert rel.execute().fetchall() == [(1, 'one'), (2, 'two'), (3, 'three'),(4, 'four')] + assert rel.execute().fetchall() == [(1, 'one'), (2, 'two'), (3, 'three'), (4, 'four')] - def test_df_write_csv(self,duckdb_cursor): - test_df = pd.DataFrame.from_dict({"i":[1, 2, 3, 4], "j":["one", "two", "three", "four"]}) + def test_df_write_csv(self, duckdb_cursor): + test_df = pd.DataFrame.from_dict({"i": [1, 2, 3, 4], "j": ["one", "two", "three", "four"]}) temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) duckdb.write_csv(test_df, temp_file_name) csv_rel = duckdb.from_csv_auto(temp_file_name) - assert csv_rel.execute().fetchall() == [(1, 'one'), (2, 'two'), (3, 'three'), (4, 'four')] - + assert csv_rel.execute().fetchall() == [(1, 'one'), (2, 'two'), (3, 'three'), (4, 'four')] def test_join_types(self, duckdb_cursor): - test_df1 = pd.DataFrame.from_dict({"i":[1, 2, 3, 4]}) - test_df2 = pd.DataFrame.from_dict({"j":[ 3, 4, 5, 6]}) + test_df1 = pd.DataFrame.from_dict({"i": [1, 2, 3, 4]}) + test_df2 = pd.DataFrame.from_dict({"j": [3, 4, 5, 6]}) rel1 = duckdb_cursor.from_df(test_df1) rel2 = duckdb_cursor.from_df(test_df2) diff --git a/tools/pythonpkg/tests/fast/test_relation_dependency_leak.py b/tools/pythonpkg/tests/fast/test_relation_dependency_leak.py index 739b6618ae41..bb3502eff677 100644 --- a/tools/pythonpkg/tests/fast/test_relation_dependency_leak.py +++ b/tools/pythonpkg/tests/fast/test_relation_dependency_leak.py @@ -2,36 +2,43 @@ import numpy as np import os, psutil import pytest + try: import pyarrow as pa + can_run = True except: can_run = False from conftest import NumpyPandas, ArrowPandas + def check_memory(function_to_check, pandas): process = psutil.Process(os.getpid()) - mem_usage = process.memory_info().rss/(10**9) + mem_usage = process.memory_info().rss / (10**9) for __ in range(100): function_to_check(pandas) - cur_mem_usage = process.memory_info().rss/(10**9) + cur_mem_usage = process.memory_info().rss / (10**9) # This seems a good empirical value - assert cur_mem_usage/3 < mem_usage + assert cur_mem_usage / 3 < mem_usage + def from_df(pandas): df = pandas.DataFrame({"x": np.random.rand(1_000_000)}) return duckdb.from_df(df) + def from_arrow(pandas): data = pa.array(np.random.rand(1_000_000), type=pa.float32()) - arrow_table = pa.Table.from_arrays([data],['a']) + arrow_table = pa.Table.from_arrays([data], ['a']) 
duckdb.from_arrow(arrow_table) + def arrow_replacement(pandas): data = pa.array(np.random.rand(1_000_000), type=pa.float32()) - arrow_table = pa.Table.from_arrays([data],['a']) + arrow_table = pa.Table.from_arrays([data], ['a']) duckdb.query("select sum(a) from arrow_table").fetchall() + def pandas_replacement(pandas): df = pandas.DataFrame({"x": np.random.rand(1_000_000)}) duckdb.query("select sum(x) from df").fetchall() @@ -64,5 +71,3 @@ def test_relation_view_leak(self, duckdb_cursor, pandas): rel.create_view("bla") duckdb.default_connection.unregister("bla") assert rel.query("bla", "select count(*) from bla").fetchone()[0] == 1_000_000 - - diff --git a/tools/pythonpkg/tests/fast/test_replacement_scan.py b/tools/pythonpkg/tests/fast/test_replacement_scan.py index d2ff85a7dcb1..ce01cb4cbbbf 100644 --- a/tools/pythonpkg/tests/fast/test_replacement_scan.py +++ b/tools/pythonpkg/tests/fast/test_replacement_scan.py @@ -2,46 +2,49 @@ import os import pytest + class TestReplacementScan(object): - def test_csv_replacement(self, duckdb_cursor): - filename = os.path.join(os.path.dirname(os.path.realpath(__file__)),'data','integers.csv') - res = duckdb_cursor.execute("select count(*) from '%s'"%(filename)) - assert res.fetchone()[0] == 2 - - def test_parquet_replacement(self, duckdb_cursor): - filename = os.path.join(os.path.dirname(os.path.realpath(__file__)),'data','binary_string.parquet') - res = duckdb_cursor.execute("select count(*) from '%s'"%(filename)) - assert res.fetchone()[0] == 3 - - def test_replacement_scan_relapi(self): - pyrel1 = duckdb.query('from (values (42), (84), (120)) t(i)') - assert isinstance(pyrel1, duckdb.DuckDBPyRelation) - assert (pyrel1.fetchall() == [(42,), (84,), (120,)]) - - pyrel2 = duckdb.query('from pyrel1 limit 2') - assert isinstance(pyrel2, duckdb.DuckDBPyRelation) - assert (pyrel2.fetchall() == [(42,), (84,)]) - - pyrel3 = duckdb.query('select i + 100 from pyrel2') - assert (type(pyrel3) == duckdb.DuckDBPyRelation) - assert (pyrel3.fetchall() == [(142,), (184,)]) - - def test_replacement_scan_alias(self): - pyrel1 = duckdb.query('from (values (1, 2)) t(i, j)') - pyrel2 = duckdb.query('from (values (1, 10)) t(i, k)') - pyrel3 = duckdb.query('from pyrel1 join pyrel2 using(i)') - assert (type(pyrel3) == duckdb.DuckDBPyRelation) - assert (pyrel3.fetchall() == [(1, 2, 10)]) - - def test_replacement_scan_pandas_alias(self): - df1 = duckdb.query('from (values (1, 2)) t(i, j)').df() - df2 = duckdb.query('from (values (1, 10)) t(i, k)').df() - df3 = duckdb.query('from df1 join df2 using(i)') - assert (df3.fetchall() == [(1, 2, 10)]) - - def test_replacement_scan_fail(self, duckdb_cursor): - random_object = "I love salmiak rondos" - con = duckdb.connect() - with pytest.raises(duckdb.InvalidInputException, - match=r'Python Object "random_object" of type "str" found on line .* not suitable for replacement scans.'): - con.execute("select count(*) from random_object").fetchone() + def test_csv_replacement(self, duckdb_cursor): + filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'integers.csv') + res = duckdb_cursor.execute("select count(*) from '%s'" % (filename)) + assert res.fetchone()[0] == 2 + + def test_parquet_replacement(self, duckdb_cursor): + filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'binary_string.parquet') + res = duckdb_cursor.execute("select count(*) from '%s'" % (filename)) + assert res.fetchone()[0] == 3 + + def test_replacement_scan_relapi(self): + pyrel1 = duckdb.query('from (values (42), (84), 
(120)) t(i)') + assert isinstance(pyrel1, duckdb.DuckDBPyRelation) + assert pyrel1.fetchall() == [(42,), (84,), (120,)] + + pyrel2 = duckdb.query('from pyrel1 limit 2') + assert isinstance(pyrel2, duckdb.DuckDBPyRelation) + assert pyrel2.fetchall() == [(42,), (84,)] + + pyrel3 = duckdb.query('select i + 100 from pyrel2') + assert type(pyrel3) == duckdb.DuckDBPyRelation + assert pyrel3.fetchall() == [(142,), (184,)] + + def test_replacement_scan_alias(self): + pyrel1 = duckdb.query('from (values (1, 2)) t(i, j)') + pyrel2 = duckdb.query('from (values (1, 10)) t(i, k)') + pyrel3 = duckdb.query('from pyrel1 join pyrel2 using(i)') + assert type(pyrel3) == duckdb.DuckDBPyRelation + assert pyrel3.fetchall() == [(1, 2, 10)] + + def test_replacement_scan_pandas_alias(self): + df1 = duckdb.query('from (values (1, 2)) t(i, j)').df() + df2 = duckdb.query('from (values (1, 10)) t(i, k)').df() + df3 = duckdb.query('from df1 join df2 using(i)') + assert df3.fetchall() == [(1, 2, 10)] + + def test_replacement_scan_fail(self, duckdb_cursor): + random_object = "I love salmiak rondos" + con = duckdb.connect() + with pytest.raises( + duckdb.InvalidInputException, + match=r'Python Object "random_object" of type "str" found on line .* not suitable for replacement scans.', + ): + con.execute("select count(*) from random_object").fetchone() diff --git a/tools/pythonpkg/tests/fast/test_result.py b/tools/pythonpkg/tests/fast/test_result.py index 2a91d7e01311..34c8e18762bd 100644 --- a/tools/pythonpkg/tests/fast/test_result.py +++ b/tools/pythonpkg/tests/fast/test_result.py @@ -1,8 +1,9 @@ import duckdb import pytest import datetime -class TestPythonResult(object): + +class TestPythonResult(object): def test_result_closed(self, duckdb_cursor): connection = duckdb.connect('') cursor = connection.cursor() @@ -29,16 +30,31 @@ def test_result_describe_types(self, duckdb_cursor): cursor.execute("INSERT INTO test VALUES (TRUE, '01:01:01', 'bla' )") rel = connection.table("test") res = rel.execute() - assert res.description == [('i', 'bool', None, None, None, None, None), ('j', 'Time', None, None, None, None, None), ('k', 'STRING', None, None, None, None, None)] + assert res.description == [ + ('i', 'bool', None, None, None, None, None), + ('j', 'Time', None, None, None, None, None), + ('k', 'STRING', None, None, None, None, None), + ] def test_result_timestamps(self, duckdb_cursor): connection = duckdb.connect('') cursor = connection.cursor() - cursor.execute('CREATE TABLE IF NOT EXISTS timestamps (sec TIMESTAMP_S, milli TIMESTAMP_MS,micro TIMESTAMP_US, nano TIMESTAMP_NS );') - cursor.execute("INSERT INTO timestamps VALUES ('2008-01-01 00:00:11','2008-01-01 00:00:01.794','2008-01-01 00:00:01.98926','2008-01-01 00:00:01.899268321' )") + cursor.execute( + 'CREATE TABLE IF NOT EXISTS timestamps (sec TIMESTAMP_S, milli TIMESTAMP_MS,micro TIMESTAMP_US, nano TIMESTAMP_NS );' + ) + cursor.execute( + "INSERT INTO timestamps VALUES ('2008-01-01 00:00:11','2008-01-01 00:00:01.794','2008-01-01 00:00:01.98926','2008-01-01 00:00:01.899268321' )" + ) rel = connection.table("timestamps") - assert rel.execute().fetchall() == [(datetime.datetime(2008, 1, 1, 0, 0, 11), datetime.datetime(2008, 1, 1, 0, 0, 1, 794000), datetime.datetime(2008, 1, 1, 0, 0, 1, 989260), datetime.datetime(2008, 1, 1, 0, 0, 1, 899268))] + assert rel.execute().fetchall() == [ + ( + datetime.datetime(2008, 1, 1, 0, 0, 11), + datetime.datetime(2008, 1, 1, 0, 0, 1, 794000), + datetime.datetime(2008, 1, 1, 0, 0, 1, 989260), + datetime.datetime(2008, 1, 1, 0, 0, 1, 
899268),
+            )
+        ]
 
     def test_result_interval(self):
         connection = duckdb.connect()
@@ -49,7 +65,11 @@ def test_result_interval(self):
         rel = connection.table("intervals")
         res = rel.execute()
         assert res.description == [('ivals', 'TIMEDELTA', None, None, None, None, None)]
-        assert res.fetchall() == [(datetime.timedelta(days=1.0),), (datetime.timedelta(seconds=2.0),), (datetime.timedelta(microseconds=1.0),)]
+        assert res.fetchall() == [
+            (datetime.timedelta(days=1.0),),
+            (datetime.timedelta(seconds=2.0),),
+            (datetime.timedelta(microseconds=1.0),),
+        ]
 
     def test_description_uuid(self):
         connection = duckdb.connect()
diff --git a/tools/pythonpkg/tests/fast/test_runtime_error.py b/tools/pythonpkg/tests/fast/test_runtime_error.py
index 390a35f550ea..6c6fb459b4ab 100644
--- a/tools/pythonpkg/tests/fast/test_runtime_error.py
+++ b/tools/pythonpkg/tests/fast/test_runtime_error.py
@@ -5,6 +5,7 @@
 closed = lambda: pytest.raises(duckdb.ConnectionException, match='Connection has already been closed')
 no_result_set = lambda: pytest.raises(duckdb.InvalidInputException, match='No open result set')
 
+
 class TestRuntimeError(object):
     def test_fetch_error(self):
         con = duckdb.connect()
@@ -29,7 +30,9 @@ def test_arrow_error(self):
     def test_register_error(self):
         con = duckdb.connect()
         py_obj = "this is a string"
-        with pytest.raises(duckdb.InvalidInputException, match='Python Object str not suitable to be registered as a view'):
+        with pytest.raises(
+            duckdb.InvalidInputException, match='Python Object str not suitable to be registered as a view'
+        ):
             con.register(py_obj, "v")
 
     def test_arrow_fetch_table_error(self):
@@ -54,11 +57,14 @@ def test_arrow_record_batch_reader_error(self):
         with pytest.raises(duckdb.ProgrammingError, match='There is no query result'):
             res.fetch_arrow_reader(1)
 
-
     @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()])
     def test_relation_fetchall_error(self, pandas):
         conn = duckdb.connect()
-        df_in = pandas.DataFrame({'numbers': [1,2,3,4,5],})
+        df_in = pandas.DataFrame(
+            {
+                'numbers': [1, 2, 3, 4, 5],
+            }
+        )
         conn.execute("create view x as select * from df_in")
         rel = conn.query("select * from x")
         del df_in
@@ -68,7 +74,11 @@ def test_relation_fetchall_execute(self, pandas):
         conn = duckdb.connect()
-        df_in = pandas.DataFrame({'numbers': [1,2,3,4,5],})
+        df_in = pandas.DataFrame(
+            {
+                'numbers': [1, 2, 3, 4, 5],
+            }
+        )
         conn.execute("create view x as select * from df_in")
         rel = conn.query("select * from x")
         del df_in
@@ -78,7 +88,11 @@ def test_relation_query_error(self, pandas):
         conn = duckdb.connect()
-        df_in = pandas.DataFrame({'numbers': [1,2,3,4,5],})
+        df_in = pandas.DataFrame(
+            {
+                'numbers': [1, 2, 3, 4, 5],
+            }
+        )
         conn.execute("create view x as select * from df_in")
         rel = conn.query("select * from x")
         del df_in
@@ -88,7 +102,11 @@ def test_conn_broken_statement_error(self, pandas):
         conn = duckdb.connect()
-        df_in = pandas.DataFrame({'numbers': [1,2,3,4,5],})
+        df_in = pandas.DataFrame(
+            {
+                'numbers': [1, 2, 3, 4, 5],
+            }
+        )
         conn.execute("create view x as select * from df_in")
         del df_in
         with pytest.raises(duckdb.InvalidInputException):
@@ -98,16 +116,20 @@ def test_conn_prepared_statement_error(self):
         conn = duckdb.connect()
conn.execute("create table integers (a integer, b integer)") with pytest.raises(duckdb.InvalidInputException, match='Prepared statement needs 2 parameters, 1 given'): - conn.execute("select * from integers where a =? and b=?",[1]) + conn.execute("select * from integers where a =? and b=?", [1]) @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_closed_conn_exceptions(self, pandas): conn = duckdb.connect() conn.close() - df_in = pandas.DataFrame({'numbers': [1,2,3,4,5],}) + df_in = pandas.DataFrame( + { + 'numbers': [1, 2, 3, 4, 5], + } + ) with closed(): - conn.register("bla",df_in) + conn.register("bla", df_in) with closed(): conn.from_query("select 1") diff --git a/tools/pythonpkg/tests/fast/test_tf.py b/tools/pythonpkg/tests/fast/test_tf.py index ebe770456a6f..b65acec6d412 100644 --- a/tools/pythonpkg/tests/fast/test_tf.py +++ b/tools/pythonpkg/tests/fast/test_tf.py @@ -4,31 +4,32 @@ tf = pytest.importorskip('tensorflow') + def test_tf(): - con = duckdb.connect() - - con.execute("create table t( a integer, b integer)") - con.execute("insert into t values (1,2), (3,4)") - - # Test from connection - duck_tf = con.execute("select * from t").tf() - duck_numpy = con.sql("select * from t").fetchnumpy() - tf.math.equal(duck_tf['a'], tf.convert_to_tensor(duck_numpy['a'])) - tf.math.equal(duck_tf['b'], tf.convert_to_tensor(duck_numpy['b'])) - - # Test from relation - duck_tf = con.sql("select * from t").tf() - tf.math.equal(duck_tf['a'], tf.convert_to_tensor(duck_numpy['a'])) - tf.math.equal(duck_tf['b'], tf.convert_to_tensor(duck_numpy['b'])) - - # Test all Numeric Types - numeric_types = ['TINYINT', 'SMALLINT', 'BIGINT', 'HUGEINT', 'FLOAT', 'DOUBLE', 'DECIMAL(4,1)', 'UTINYINT'] - - for supported_type in numeric_types: - con = duckdb.connect() - con.execute(f"create table t( a {supported_type} , b {supported_type})") - con.execute("insert into t values (1,2), (3,4)") - duck_tf = con.sql("select * from t").tf() - duck_numpy = con.sql("select * from t").fetchnumpy() - tf.math.equal(duck_tf['a'], tf.convert_to_tensor(duck_numpy['a'])) - tf.math.equal(duck_tf['b'], tf.convert_to_tensor(duck_numpy['b'])) \ No newline at end of file + con = duckdb.connect() + + con.execute("create table t( a integer, b integer)") + con.execute("insert into t values (1,2), (3,4)") + + # Test from connection + duck_tf = con.execute("select * from t").tf() + duck_numpy = con.sql("select * from t").fetchnumpy() + tf.math.equal(duck_tf['a'], tf.convert_to_tensor(duck_numpy['a'])) + tf.math.equal(duck_tf['b'], tf.convert_to_tensor(duck_numpy['b'])) + + # Test from relation + duck_tf = con.sql("select * from t").tf() + tf.math.equal(duck_tf['a'], tf.convert_to_tensor(duck_numpy['a'])) + tf.math.equal(duck_tf['b'], tf.convert_to_tensor(duck_numpy['b'])) + + # Test all Numeric Types + numeric_types = ['TINYINT', 'SMALLINT', 'BIGINT', 'HUGEINT', 'FLOAT', 'DOUBLE', 'DECIMAL(4,1)', 'UTINYINT'] + + for supported_type in numeric_types: + con = duckdb.connect() + con.execute(f"create table t( a {supported_type} , b {supported_type})") + con.execute("insert into t values (1,2), (3,4)") + duck_tf = con.sql("select * from t").tf() + duck_numpy = con.sql("select * from t").fetchnumpy() + tf.math.equal(duck_tf['a'], tf.convert_to_tensor(duck_numpy['a'])) + tf.math.equal(duck_tf['b'], tf.convert_to_tensor(duck_numpy['b'])) diff --git a/tools/pythonpkg/tests/fast/test_transaction.py b/tools/pythonpkg/tests/fast/test_transaction.py index dbb0f6ca9a74..54deaf8207a9 100644 --- 
a/tools/pythonpkg/tests/fast/test_transaction.py +++ b/tools/pythonpkg/tests/fast/test_transaction.py @@ -1,19 +1,20 @@ import duckdb import pandas as pd + class TestConnectionTransaction(object): def test_transaction(self, duckdb_cursor): con = duckdb.connect() con.execute('create table t (i integer)') - con.execute ('insert into t values (1)') + con.execute('insert into t values (1)') con.begin() - con.execute ('insert into t values (1)') + con.execute('insert into t values (1)') assert con.execute('select count (*) from t').fetchone()[0] == 2 con.rollback() assert con.execute('select count (*) from t').fetchone()[0] == 1 con.begin() - con.execute ('insert into t values (1)') + con.execute('insert into t values (1)') assert con.execute('select count (*) from t').fetchone()[0] == 2 con.commit() assert con.execute('select count (*) from t').fetchone()[0] == 2 diff --git a/tools/pythonpkg/tests/fast/test_type.py b/tools/pythonpkg/tests/fast/test_type.py index 9c60ce52df82..ce8dda592e7b 100644 --- a/tools/pythonpkg/tests/fast/test_type.py +++ b/tools/pythonpkg/tests/fast/test_type.py @@ -4,7 +4,35 @@ import pytest from typing import Union -from duckdb.typing import SQLNULL, BOOLEAN, TINYINT, UTINYINT, SMALLINT, USMALLINT, INTEGER, UINTEGER, BIGINT, UBIGINT, HUGEINT, UUID, FLOAT, DOUBLE, DATE, TIMESTAMP, TIMESTAMP_MS, TIMESTAMP_NS, TIMESTAMP_S, TIME, TIME_TZ, TIMESTAMP_TZ, VARCHAR, BLOB, BIT, INTERVAL +from duckdb.typing import ( + SQLNULL, + BOOLEAN, + TINYINT, + UTINYINT, + SMALLINT, + USMALLINT, + INTEGER, + UINTEGER, + BIGINT, + UBIGINT, + HUGEINT, + UUID, + FLOAT, + DOUBLE, + DATE, + TIMESTAMP, + TIMESTAMP_MS, + TIMESTAMP_NS, + TIMESTAMP_S, + TIME, + TIME_TZ, + TIMESTAMP_TZ, + VARCHAR, + BLOB, + BIT, + INTERVAL, +) + class TestType(object): def test_sqltype(self): @@ -75,23 +103,17 @@ def test_union_type(self): assert str(type) == 'UNION(a BIGINT, b VARCHAR, c TINYINT)' import sys - @pytest.mark.skipif(sys.version_info < (3,9), reason="requires >= python3.9") + + @pytest.mark.skipif(sys.version_info < (3, 9), reason="requires >= python3.9") def test_implicit_convert_from_builtin_type(self): type = duckdb.list_type(list[str]) assert str(type.child) == "VARCHAR[]" - mapping = { - 'VARCHAR': str, - 'BIGINT': int, - 'BLOB': bytes, - 'BLOB': bytearray, - 'BOOLEAN': bool, - 'DOUBLE': float - } + mapping = {'VARCHAR': str, 'BIGINT': int, 'BLOB': bytes, 'BLOB': bytearray, 'BOOLEAN': bool, 'DOUBLE': float} for expected, type in mapping.items(): res = duckdb.list_type(type) assert str(res.child) == expected - + res = duckdb.list_type({'a': str, 'b': int}) assert str(res.child) == 'STRUCT(a VARCHAR, b BIGINT)' @@ -122,7 +144,7 @@ def test_implicit_convert_from_numpy(self, duckdb_cursor): 'uint64': 'UBIGINT', 'float16': 'FLOAT', 'float32': 'FLOAT', - 'float64': 'DOUBLE' + 'float64': 'DOUBLE', } builtins = [] diff --git a/tools/pythonpkg/tests/fast/test_unicode.py b/tools/pythonpkg/tests/fast/test_unicode.py index 371e1bbcd39b..b697f84aa333 100644 --- a/tools/pythonpkg/tests/fast/test_unicode.py +++ b/tools/pythonpkg/tests/fast/test_unicode.py @@ -4,9 +4,10 @@ import duckdb import pandas as pd + class TestUnicode(object): def test_unicode_pandas_scan(self, duckdb_cursor): con = duckdb.connect(database=':memory:', read_only=False) - test_df = pd.DataFrame.from_dict({"i":[1, 2, 3], "j":["a", "c", u"ë"]}) + test_df = pd.DataFrame.from_dict({"i": [1, 2, 3], "j": ["a", "c", u"ë"]}) con.register('test_df_view', test_df) con.execute('SELECT i, j, LENGTH(j) FROM test_df_view').fetchall() diff --git 
a/tools/pythonpkg/tests/fast/test_value.py b/tools/pythonpkg/tests/fast/test_value.py index ee2e97266986..7dabd8cafd44 100644 --- a/tools/pythonpkg/tests/fast/test_value.py +++ b/tools/pythonpkg/tests/fast/test_value.py @@ -29,7 +29,7 @@ TimestampNanosecondValue, TimestampTimeZoneValue, TimeValue, - TimeTimeZoneValue + TimeTimeZoneValue, ) import uuid import datetime @@ -62,12 +62,15 @@ VARCHAR, BLOB, BIT, - INTERVAL + INTERVAL, ) + class TestValue(object): # This excludes timezone aware values, as those are a pain to test - @pytest.mark.parametrize('item', [ + @pytest.mark.parametrize( + 'item', + [ (BOOLEAN, BooleanValue(True), True), (UTINYINT, UnsignedBinaryValue(129), 129), (USMALLINT, UnsignedShortValue(12356), 12356), @@ -80,17 +83,30 @@ class TestValue(object): (HUGEINT, HugeIntegerValue(-1), -1), (FLOAT, FloatValue(1.8349000215530396), 1.8349000215530396), (DOUBLE, DoubleValue(0.23234234234), 0.23234234234), - (duckdb.decimal_type(12, 8), DecimalValue(decimal.Decimal('1234.12345678'), 12, 8), decimal.Decimal('1234.12345678')), + ( + duckdb.decimal_type(12, 8), + DecimalValue(decimal.Decimal('1234.12345678'), 12, 8), + decimal.Decimal('1234.12345678'), + ), (VARCHAR, StringValue('this is a long string'), 'this is a long string'), - (UUID, UUIDValue(uuid.UUID('ffffffff-ffff-ffff-ffff-ffffffffffff')), uuid.UUID('ffffffff-ffff-ffff-ffff-ffffffffffff')), + ( + UUID, + UUIDValue(uuid.UUID('ffffffff-ffff-ffff-ffff-ffffffffffff')), + uuid.UUID('ffffffff-ffff-ffff-ffff-ffffffffffff'), + ), (BIT, BitValue(b'010101010101'), '010101010101'), (BLOB, BlobValue(b'\x00\x00\x00a'), b'\x00\x00\x00a'), (DATE, DateValue(datetime.date(2000, 5, 4)), datetime.date(2000, 5, 4)), (INTERVAL, IntervalValue(datetime.timedelta(days=5)), datetime.timedelta(days=5)), - (TIMESTAMP, TimestampValue(datetime.datetime(1970, 3, 21, 12, 5, 43, 120)), datetime.datetime(1970, 3, 21, 12, 5, 43, 120)), + ( + TIMESTAMP, + TimestampValue(datetime.datetime(1970, 3, 21, 12, 5, 43, 120)), + datetime.datetime(1970, 3, 21, 12, 5, 43, 120), + ), (SQLNULL, NullValue(), None), - (TIME, TimeValue(datetime.time(12, 3, 12, 80)), datetime.time(12, 3, 12, 80)) - ]) + (TIME, TimeValue(datetime.time(12, 3, 12, 80)), datetime.time(12, 3, 12, 80)), + ], + ) def test_value_helpers(self, item): expected_type = item[0] value_object = item[1] @@ -107,56 +123,63 @@ def test_value_helpers(self, item): def test_float_to_decimal_prevention(self): value = DecimalValue(1.2345, 12, 8) - + con = duckdb.connect() with pytest.raises(duckdb.ConversionException, match="Can't losslessly convert"): con.execute('select $1', [value]).fetchall() - @pytest.mark.parametrize('value', [ - TimestampSecondValue(datetime.datetime(1970, 3, 21, 12, 36, 43)), - TimestampMilisecondValue(datetime.datetime(1970, 3, 21, 12, 36, 43)), - TimestampNanosecondValue(datetime.datetime(1970, 3, 21, 12, 36, 43)) - ]) + @pytest.mark.parametrize( + 'value', + [ + TimestampSecondValue(datetime.datetime(1970, 3, 21, 12, 36, 43)), + TimestampMilisecondValue(datetime.datetime(1970, 3, 21, 12, 36, 43)), + TimestampNanosecondValue(datetime.datetime(1970, 3, 21, 12, 36, 43)), + ], + ) def test_timestamp_sec_not_supported(self, value): con = duckdb.connect() - with pytest.raises(duckdb.NotImplementedException, match="Conversion from 'datetime' to type .* is not implemented yet"): + with pytest.raises( + duckdb.NotImplementedException, match="Conversion from 'datetime' to type .* is not implemented yet" + ): con.execute('select $1', [value]).fetchall() - - @pytest.mark.parametrize('test', [ - 
(TINYINT, 0, True),
-        (TINYINT, 255, False),
-        (TINYINT, -128, True),
-        (UTINYINT, 80, True),
-        (UTINYINT, -1, False),
-        (UTINYINT, 255, True),
-        (SMALLINT, 0, True),
-        (SMALLINT, 128, True),
-        (SMALLINT, -255, True),
-        (SMALLINT, -32780, False),
-        (USMALLINT, 0, True),
-        (USMALLINT, -1, False),
-        (USMALLINT, 1337, True),
-        (USMALLINT, 32780, True),
-        (INTEGER, 0, True),
-        (INTEGER, 32780, True),
-        (INTEGER, -32780, True),
-        (INTEGER, -1337, True),
-        (UINTEGER, 0, True),
-        (UINTEGER, -1337, False),
-        (UINTEGER, 65534, True),
-        (BIGINT, 0, True),
-        (BIGINT, -1234567, True),
-        (BIGINT, 9223372036854775808, False),
-        (UBIGINT, 9223372036854775808, True),
-        (UBIGINT, -1, False),
-        (UBIGINT, 18446744073709551615, True),
-        (HUGEINT, -9223372036854775808, True),
-        (HUGEINT, 9223372036854775807, True),
-        (HUGEINT, 0, True),
-        (HUGEINT, -1337, True),
-        (HUGEINT, 12334214123, True)
-    ])
+    @pytest.mark.parametrize(
+        'test',
+        [
+            (TINYINT, 0, True),
+            (TINYINT, 255, False),
+            (TINYINT, -128, True),
+            (UTINYINT, 80, True),
+            (UTINYINT, -1, False),
+            (UTINYINT, 255, True),
+            (SMALLINT, 0, True),
+            (SMALLINT, 128, True),
+            (SMALLINT, -255, True),
+            (SMALLINT, -32780, False),
+            (USMALLINT, 0, True),
+            (USMALLINT, -1, False),
+            (USMALLINT, 1337, True),
+            (USMALLINT, 32780, True),
+            (INTEGER, 0, True),
+            (INTEGER, 32780, True),
+            (INTEGER, -32780, True),
+            (INTEGER, -1337, True),
+            (UINTEGER, 0, True),
+            (UINTEGER, -1337, False),
+            (UINTEGER, 65534, True),
+            (BIGINT, 0, True),
+            (BIGINT, -1234567, True),
+            (BIGINT, 9223372036854775808, False),
+            (UBIGINT, 9223372036854775808, True),
+            (UBIGINT, -1, False),
+            (UBIGINT, 18446744073709551615, True),
+            (HUGEINT, -9223372036854775808, True),
+            (HUGEINT, 9223372036854775807, True),
+            (HUGEINT, 0, True),
+            (HUGEINT, -1337, True),
+            (HUGEINT, 12334214123, True),
+        ],
+    )
     def test_numeric_values(self, test):
         target_type = test[0]
         test_value = test[1]
diff --git a/tools/pythonpkg/tests/fast/test_windows_abs_path.py b/tools/pythonpkg/tests/fast/test_windows_abs_path.py
index fb73e3ed5bdb..bc9f05ec0a39 100644
--- a/tools/pythonpkg/tests/fast/test_windows_abs_path.py
+++ b/tools/pythonpkg/tests/fast/test_windows_abs_path.py
@@ -3,6 +3,7 @@
 import os
 import shutil
 
+
 class TestWindowsAbsPath(object):
     def test_windows_path_accent(self):
         if os.name != 'nt':
@@ -23,7 +24,7 @@ def test_windows_path_accent(self):
         del con
 
         os.chdir('tést')
-        dbpath = os.path.join('..', dbpath) 
+        dbpath = os.path.join('..', dbpath)
         con = duckdb.connect(dbpath)
         res = con.execute("SELECT COUNT(*) FROM int").fetchall()
         assert res[0][0] == 10
diff --git a/tools/pythonpkg/tests/fast/types/test_blob.py b/tools/pythonpkg/tests/fast/types/test_blob.py
index 40c748e77dc0..162859d29060 100644
--- a/tools/pythonpkg/tests/fast/types/test_blob.py
+++ b/tools/pythonpkg/tests/fast/types/test_blob.py
@@ -1,6 +1,7 @@
 import duckdb
 import numpy
 
+
 class TestBlob(object):
     def test_blob(self, duckdb_cursor):
         duckdb_cursor.execute("SELECT BLOB 'hello'")
diff --git a/tools/pythonpkg/tests/fast/types/test_boolean.py b/tools/pythonpkg/tests/fast/types/test_boolean.py
index a4e30dbf170d..8e8d21473d56 100644
--- a/tools/pythonpkg/tests/fast/types/test_boolean.py
+++ b/tools/pythonpkg/tests/fast/types/test_boolean.py
@@ -1,8 +1,9 @@
 import duckdb
 import numpy
 
+
 class TestBoolean(object):
     def test_bool(self, duckdb_cursor):
         duckdb_cursor.execute("SELECT TRUE")
         results = duckdb_cursor.fetchall()
-        assert results[0][0] == True
\ No newline at end of file
+        assert results[0][0] == True
diff --git 
a/tools/pythonpkg/tests/fast/types/test_decimal.py b/tools/pythonpkg/tests/fast/types/test_decimal.py index 881c94f5c920..30cb13e7f9e1 100644 --- a/tools/pythonpkg/tests/fast/types/test_decimal.py +++ b/tools/pythonpkg/tests/fast/types/test_decimal.py @@ -1,21 +1,26 @@ - import numpy import pandas from decimal import * + class TestDecimal(object): def test_decimal(self, duckdb_cursor): - duckdb_cursor.execute('SELECT 1.2::DECIMAL(4,1), 100.3::DECIMAL(9,1), 320938.4298::DECIMAL(18,4), 49082094824.904820482094::DECIMAL(30,12), NULL::DECIMAL') + duckdb_cursor.execute( + 'SELECT 1.2::DECIMAL(4,1), 100.3::DECIMAL(9,1), 320938.4298::DECIMAL(18,4), 49082094824.904820482094::DECIMAL(30,12), NULL::DECIMAL' + ) result = duckdb_cursor.fetchall() - assert result == [(Decimal('1.2'), Decimal('100.3'), Decimal('320938.4298'), Decimal('49082094824.904820482094'), None)] + assert result == [ + (Decimal('1.2'), Decimal('100.3'), Decimal('320938.4298'), Decimal('49082094824.904820482094'), None) + ] def test_decimal_numpy(self, duckdb_cursor): - duckdb_cursor.execute('SELECT 1.2::DECIMAL(4,1) AS a, 100.3::DECIMAL(9,1) AS b, 320938.4298::DECIMAL(18,4) AS c, 49082094824.904820482094::DECIMAL(30,12) AS d') + duckdb_cursor.execute( + 'SELECT 1.2::DECIMAL(4,1) AS a, 100.3::DECIMAL(9,1) AS b, 320938.4298::DECIMAL(18,4) AS c, 49082094824.904820482094::DECIMAL(30,12) AS d' + ) result = duckdb_cursor.fetchnumpy() - assert result == {'a': numpy.array([1.2]), - 'b': numpy.array([100.3]), - 'c': numpy.array([320938.4298]), - 'd': numpy.array([49082094824.904820482094])} - - - + assert result == { + 'a': numpy.array([1.2]), + 'b': numpy.array([100.3]), + 'c': numpy.array([320938.4298]), + 'd': numpy.array([49082094824.904820482094]), + } diff --git a/tools/pythonpkg/tests/fast/types/test_hugeint.py b/tools/pythonpkg/tests/fast/types/test_hugeint.py index 51d735f7aa14..f025438011b4 100644 --- a/tools/pythonpkg/tests/fast/types/test_hugeint.py +++ b/tools/pythonpkg/tests/fast/types/test_hugeint.py @@ -1,4 +1,3 @@ - import numpy import pandas @@ -8,10 +7,8 @@ def test_hugeint(self, duckdb_cursor): duckdb_cursor.execute('SELECT 437894723897234238947043214') result = duckdb_cursor.fetchall() assert result == [(437894723897234238947043214,)] - + def test_hugeint_numpy(self, duckdb_cursor): duckdb_cursor.execute('SELECT 1::HUGEINT AS i') result = duckdb_cursor.fetchnumpy() assert result == {'i': numpy.array([1.0])} - - diff --git a/tools/pythonpkg/tests/fast/types/test_nan.py b/tools/pythonpkg/tests/fast/types/test_nan.py index 3c0957adae43..ab244f7b2187 100644 --- a/tools/pythonpkg/tests/fast/types/test_nan.py +++ b/tools/pythonpkg/tests/fast/types/test_nan.py @@ -4,11 +4,12 @@ import pytest from conftest import NumpyPandas, ArrowPandas + class TestPandasNaN(object): @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_pandas_nan(self, duckdb_cursor, pandas): # create a DataFrame with some basic values - df = pandas.DataFrame([{"col1": "val1", "col2": 1.05},{"col1": "val3", "col2": np.NaN}]) + df = pandas.DataFrame([{"col1": "val1", "col2": 1.05}, {"col1": "val3", "col2": np.NaN}]) # create a new column (newcol1) that includes either NaN or values from col1 df["newcol1"] = np.where(df["col1"] == "val1", np.NaN, df["col1"]) # now create a new column with the current time @@ -16,7 +17,7 @@ def test_pandas_nan(self, duckdb_cursor, pandas): current_time = datetime.datetime.now().replace(microsecond=0) df['datetest'] = current_time # introduce a NaT (Not a Time value) - df.loc[0,'datetest'] = pandas.NaT + 
df.loc[0, 'datetest'] = pandas.NaT
 
         # now pass the DF through duckdb:
         conn = duckdb.connect(':memory:')
diff --git a/tools/pythonpkg/tests/fast/types/test_nested.py b/tools/pythonpkg/tests/fast/types/test_nested.py
index 48de3542ba6b..dcfc7dc71f30 100644
--- a/tools/pythonpkg/tests/fast/types/test_nested.py
+++ b/tools/pythonpkg/tests/fast/types/test_nested.py
@@ -2,7 +2,6 @@
 
 
 class TestNested(object):
-
     def test_lists(self, duckdb_cursor):
         duckdb_conn = duckdb.connect()
         result = duckdb_conn.execute("SELECT LIST_VALUE(1, 2, 3, 4) ").fetchall()
@@ -14,7 +13,6 @@ def test_lists(self, duckdb_cursor):
         result = duckdb_conn.execute("SELECT LIST_VALUE(1, 2, 3, NULL) ").fetchall()
         assert result == [([1, 2, 3, None],)]
 
-
     def test_nested_lists(self, duckdb_cursor):
         duckdb_conn = duckdb.connect()
         result = duckdb_conn.execute("SELECT LIST_VALUE(LIST_VALUE(1, 2, 3, 4), LIST_VALUE(1, 2, 3, 4)) ").fetchall()
@@ -30,24 +28,22 @@ def test_struct(self, duckdb_cursor):
         result = duckdb_conn.execute("SELECT STRUCT_PACK(a := 42, b := NULL)").fetchall()
         assert result == [({'a': 42, 'b': None},)]
 
-
     def test_nested_struct(self, duckdb_cursor):
         duckdb_conn = duckdb.connect()
         result = duckdb_conn.execute("SELECT STRUCT_PACK(a := 42, b := LIST_VALUE(10, 9, 8, 7))").fetchall()
-        assert result == [({'a': 42, 'b': [10,9,8,7]},)]
+        assert result == [({'a': 42, 'b': [10, 9, 8, 7]},)]
 
         result = duckdb_conn.execute("SELECT STRUCT_PACK(a := 42, b := LIST_VALUE(10, 9, 8, NULL))").fetchall()
-        assert result == [({'a': 42, 'b': [10,9,8,None]},)]
+        assert result == [({'a': 42, 'b': [10, 9, 8, None]},)]
 
     def test_map(self, duckdb_cursor):
         duckdb_conn = duckdb.connect()
         result = duckdb_conn.execute("select MAP(LIST_VALUE(1, 2, 3, 4),LIST_VALUE(10, 9, 8, 7))").fetchall()
-        assert result == [({'key': [1,2,3,4], 'value': [10,9,8,7]},)]
+        assert result == [({'key': [1, 2, 3, 4], 'value': [10, 9, 8, 7]},)]
 
         result = duckdb_conn.execute("select MAP(LIST_VALUE(1, 2, 3, 4),LIST_VALUE(10, 9, 8, NULL))").fetchall()
-        assert result == [({'key': [1,2,3,4], 'value': [10,9,8,None]},)]
+        assert result == [({'key': [1, 2, 3, 4], 'value': [10, 9, 8, None]},)]
 
         result = duckdb_conn.execute("SELECT MAP() ").fetchall()
         assert result == [({'key': [], 'value': []},)]
-
\ No newline at end of file
diff --git a/tools/pythonpkg/tests/fast/types/test_null.py b/tools/pythonpkg/tests/fast/types/test_null.py
index e3bf8b8b35a0..fa4105b68199 100644
--- a/tools/pythonpkg/tests/fast/types/test_null.py
+++ b/tools/pythonpkg/tests/fast/types/test_null.py
@@ -1,13 +1,10 @@
 import traceback
 
-class TestNull(object):
+
+class TestNull(object):
     def test_fetchone_null(self, duckdb_cursor):
         duckdb_cursor.execute("CREATE TABLE atable (Value int)")
         duckdb_cursor.execute("INSERT INTO atable VALUES (1)")
         duckdb_cursor.execute("SELECT * FROM atable")
-        assert(duckdb_cursor.fetchone()[0] == 1)
-        assert(duckdb_cursor.fetchone() is None)
-
-
-
\ No newline at end of file
+        assert duckdb_cursor.fetchone()[0] == 1
+        assert duckdb_cursor.fetchone() is None
diff --git a/tools/pythonpkg/tests/fast/types/test_numeric.py b/tools/pythonpkg/tests/fast/types/test_numeric.py
index 07fda588efc4..f25b72b11c0f 100644
--- a/tools/pythonpkg/tests/fast/types/test_numeric.py
+++ b/tools/pythonpkg/tests/fast/types/test_numeric.py
@@ -1,13 +1,15 @@
 import duckdb
 import numpy
 
-def check_result(duckdb_cursor,value, type):
-    duckdb_cursor.execute("SELECT " + str(value)+"::"+ type)
+
+def check_result(duckdb_cursor, value, type):
+    duckdb_cursor.execute("SELECT " + str(value) + "::" + type)
     results = 
duckdb_cursor.fetchall() assert results[0][0] == value + class TestNumeric(object): def test_numeric_results(self, duckdb_cursor): - check_result(duckdb_cursor,1,"TINYINT") - check_result(duckdb_cursor,1,"SMALLINT") - check_result(duckdb_cursor,1,"FLOAT") \ No newline at end of file + check_result(duckdb_cursor, 1, "TINYINT") + check_result(duckdb_cursor, 1, "SMALLINT") + check_result(duckdb_cursor, 1, "FLOAT") diff --git a/tools/pythonpkg/tests/fast/types/test_numpy.py b/tools/pythonpkg/tests/fast/types/test_numpy.py index c7bfbe22cef0..fe23d1fcfa43 100644 --- a/tools/pythonpkg/tests/fast/types/test_numpy.py +++ b/tools/pythonpkg/tests/fast/types/test_numpy.py @@ -3,13 +3,19 @@ import datetime import pytest + class TestNumpyDatetime64(object): def test_numpy_datetime64(self, duckdb_cursor): duckdb_con = duckdb.connect() duckdb_con.execute("create table tbl(col TIMESTAMP)") - duckdb_con.execute("insert into tbl VALUES (CAST(? AS TIMESTAMP WITHOUT TIME ZONE))", parameters=[np.datetime64('2022-02-08T06:01:38.761310')]) - assert [(datetime.datetime(2022, 2, 8, 6, 1, 38, 761310),)] == duckdb_con.execute("select * from tbl").fetchall() + duckdb_con.execute( + "insert into tbl VALUES (CAST(? AS TIMESTAMP WITHOUT TIME ZONE))", + parameters=[np.datetime64('2022-02-08T06:01:38.761310')], + ) + assert [(datetime.datetime(2022, 2, 8, 6, 1, 38, 761310),)] == duckdb_con.execute( + "select * from tbl" + ).fetchall() def test_numpy_datetime_overflow(self): duckdb_con = duckdb.connect() diff --git a/tools/pythonpkg/tests/fast/types/test_object_int.py b/tools/pythonpkg/tests/fast/types/test_object_int.py index 4a8fa633fbd6..d22a63d41e9a 100644 --- a/tools/pythonpkg/tests/fast/types/test_object_int.py +++ b/tools/pythonpkg/tests/fast/types/test_object_int.py @@ -3,22 +3,27 @@ import duckdb import pytest + class TestPandasObjectInteger(object): # Signed Masked Integer types def test_object_integer(self, duckdb_cursor): pd = pytest.importorskip("pandas") - df_in = pd.DataFrame({ + df_in = pd.DataFrame( + { 'int8': pd.Series([None, 1, -1], dtype="Int8"), 'int16': pd.Series([None, 1, -1], dtype="Int16"), 'int32': pd.Series([None, 1, -1], dtype="Int32"), - 'int64': pd.Series([None, 1, -1], dtype="Int64")} + 'int64': pd.Series([None, 1, -1], dtype="Int64"), + } ) # These are float64 because pandas would force these to be float64 even if we set them to int8, int16, int32, int64 respectively - df_expected_res = pd.DataFrame({ - 'int8': np.ma.masked_array([0,1,-1], mask=[True,False,False], dtype='float64'), - 'int16': np.ma.masked_array([0,1,-1], mask=[True,False,False], dtype='float64'), - 'int32': np.ma.masked_array([0,1,-1], mask=[True,False,False], dtype='float64'), - 'int64': np.ma.masked_array([0,1,-1], mask=[True,False,False], dtype='float64'),} + df_expected_res = pd.DataFrame( + { + 'int8': np.ma.masked_array([0, 1, -1], mask=[True, False, False], dtype='float64'), + 'int16': np.ma.masked_array([0, 1, -1], mask=[True, False, False], dtype='float64'), + 'int32': np.ma.masked_array([0, 1, -1], mask=[True, False, False], dtype='float64'), + 'int64': np.ma.masked_array([0, 1, -1], mask=[True, False, False], dtype='float64'), + } ) df_out = duckdb.query_df(df_in, "data", "SELECT * FROM data").df() pd.testing.assert_frame_equal(df_expected_res, df_out) @@ -26,18 +31,22 @@ def test_object_integer(self, duckdb_cursor): # Unsigned Masked Integer types def test_object_uinteger(self, duckdb_cursor): pd = pytest.importorskip("pandas") - df_in = pd.DataFrame({ + df_in = pd.DataFrame( + { 'uint8': pd.Series([None, 1, 
255], dtype="UInt8"), 'uint16': pd.Series([None, 1, 65535], dtype="UInt16"), 'uint32': pd.Series([None, 1, 4294967295], dtype="UInt32"), - 'uint64': pd.Series([None, 1, 18446744073709551615], dtype="UInt64")} + 'uint64': pd.Series([None, 1, 18446744073709551615], dtype="UInt64"), + } ) # These are float64 because pandas would force these to be float64 even if we set them to uint8, uint16, uint32, uint64 respectively - df_expected_res = pd.DataFrame({ - 'uint8': np.ma.masked_array([0,1,255], mask=[True,False,False], dtype='float64'), - 'uint16': np.ma.masked_array([0,1,65535], mask=[True,False,False], dtype='float64'), - 'uint32': np.ma.masked_array([0,1,4294967295], mask=[True,False,False], dtype='float64'), - 'uint64': np.ma.masked_array([0,1,18446744073709551615], mask=[True,False,False], dtype='float64'),} + df_expected_res = pd.DataFrame( + { + 'uint8': np.ma.masked_array([0, 1, 255], mask=[True, False, False], dtype='float64'), + 'uint16': np.ma.masked_array([0, 1, 65535], mask=[True, False, False], dtype='float64'), + 'uint32': np.ma.masked_array([0, 1, 4294967295], mask=[True, False, False], dtype='float64'), + 'uint64': np.ma.masked_array([0, 1, 18446744073709551615], mask=[True, False, False], dtype='float64'), + } ) df_out = duckdb.query_df(df_in, "data", "SELECT * FROM data").df() pd.testing.assert_frame_equal(df_expected_res, df_out) @@ -46,13 +55,17 @@ def test_object_uinteger(self, duckdb_cursor): def test_object_float(self, duckdb_cursor): # Require pandas 1.2.0 >= for this, because Float32|Float64 was not added before this version pd = pytest.importorskip("pandas", '1.2.0') - df_in = pd.DataFrame({ + df_in = pd.DataFrame( + { 'float32': pd.Series([None, 1, 4294967295], dtype="Float32"), - 'float64': pd.Series([None, 1, 18446744073709551615], dtype="Float64")} + 'float64': pd.Series([None, 1, 18446744073709551615], dtype="Float64"), + } ) - df_expected_res = pd.DataFrame({ - 'float32': np.ma.masked_array([0,1,4294967295], mask=[True,False,False], dtype='float32'), - 'float64': np.ma.masked_array([0,1,18446744073709551615], mask=[True,False,False], dtype='float64'),} + df_expected_res = pd.DataFrame( + { + 'float32': np.ma.masked_array([0, 1, 4294967295], mask=[True, False, False], dtype='float32'), + 'float64': np.ma.masked_array([0, 1, 18446744073709551615], mask=[True, False, False], dtype='float64'), + } ) df_out = duckdb.query_df(df_in, "data", "SELECT * FROM data").df() pd.testing.assert_frame_equal(df_expected_res, df_out) diff --git a/tools/pythonpkg/tests/fast/types/test_unsigned.py b/tools/pythonpkg/tests/fast/types/test_unsigned.py index d54601d26c93..6ac50727422b 100644 --- a/tools/pythonpkg/tests/fast/types/test_unsigned.py +++ b/tools/pythonpkg/tests/fast/types/test_unsigned.py @@ -1,10 +1,7 @@ - class TestUnsigned(object): def test_unsigned(self, duckdb_cursor): duckdb_cursor.execute('create table unsigned (a utinyint, b usmallint, c uinteger, d ubigint)') duckdb_cursor.execute('insert into unsigned values (1,1,1,1), (null,null,null,null)') duckdb_cursor.execute('select * from unsigned order by a nulls first') - result = duckdb_cursor.fetchall() + result = duckdb_cursor.fetchall() assert result == [(None, None, None, None), (1, 1, 1, 1)] - - diff --git a/tools/pythonpkg/tests/fast/udf/test_remove_function.py b/tools/pythonpkg/tests/fast/udf/test_remove_function.py index 049d29648045..f73f793aee65 100644 --- a/tools/pythonpkg/tests/fast/udf/test_remove_function.py +++ b/tools/pythonpkg/tests/fast/udf/test_remove_function.py @@ -1,6 +1,7 @@ import duckdb import 
os import pytest + pd = pytest.importorskip("pandas") pa = pytest.importorskip("pyarrow") from typing import Union @@ -12,10 +13,14 @@ from duckdb.typing import * + class TestRemoveFunction(object): def test_not_created(self): con = duckdb.connect() - with pytest.raises(duckdb.InvalidInputException, match="No function by the name of 'not_a_registered_function' was found in the list of registered functions"): + with pytest.raises( + duckdb.InvalidInputException, + match="No function by the name of 'not_a_registered_function' was found in the list of registered functions", + ): con.remove_function('not_a_registered_function') def test_double_remove(self): @@ -26,9 +31,12 @@ def func(x: int) -> int: con.create_function('func', func) con.sql('select func(42)') con.remove_function('func') - with pytest.raises(duckdb.InvalidInputException, match="No function by the name of 'func' was found in the list of registered functions"): + with pytest.raises( + duckdb.InvalidInputException, + match="No function by the name of 'func' was found in the list of registered functions", + ): con.remove_function('func') - + with pytest.raises(duckdb.CatalogException, match='Scalar Function with name func does not exist!'): con.sql('select func(42)') @@ -43,7 +51,9 @@ def func(x: int) -> int: """ Error: Catalog Error: Scalar Function with name func does not exist! """ - with pytest.raises(duckdb.InvalidInputException, match='Attempting to execute an unsuccessful or closed pending query result'): + with pytest.raises( + duckdb.InvalidInputException, match='Attempting to execute an unsuccessful or closed pending query result' + ): res = rel.fetchall() def test_use_after_remove_and_recreation(self): @@ -58,6 +68,7 @@ def func(x: str) -> str: def also_func(x: int) -> int: return x + con.create_function('func', also_func) res = rel1.fetchall() assert res[0][0] == 42 @@ -66,12 +77,15 @@ def also_func(x: int) -> int: Candidate functions: func(BIGINT) -> BIGINT """ - with pytest.raises(duckdb.InvalidInputException, match='Attempting to execute an unsuccessful or closed pending query result'): + with pytest.raises( + duckdb.InvalidInputException, match='Attempting to execute an unsuccessful or closed pending query result' + ): res = rel2.fetchall() def test_overwrite_name(self): def func(x): return x + con = duckdb.connect() # create first version of the function con.create_function('func', func, [BIGINT], BIGINT) @@ -82,12 +96,17 @@ def func(x): def other_func(x): return x - with pytest.raises(duckdb.NotImplementedException, match="A function by the name of 'func' is already created, creating multiple functions with the same name is not supported yet, please remove it first"): + with pytest.raises( + duckdb.NotImplementedException, + match="A function by the name of 'func' is already created, creating multiple functions with the same name is not supported yet, please remove it first", + ): con.create_function('func', other_func, [VARCHAR], VARCHAR) con.remove_function('func') - with pytest.raises(duckdb.InvalidInputException, match='Catalog Error: Scalar Function with name func does not exist!'): + with pytest.raises( + duckdb.InvalidInputException, match='Catalog Error: Scalar Function with name func does not exist!' 
+ ): # Attempted to execute the relation using the 'func' function, but it was deleted rel1.fetchall() diff --git a/tools/pythonpkg/tests/fast/udf/test_scalar.py b/tools/pythonpkg/tests/fast/udf/test_scalar.py index 3ea369732607..124cd48bf82a 100644 --- a/tools/pythonpkg/tests/fast/udf/test_scalar.py +++ b/tools/pythonpkg/tests/fast/udf/test_scalar.py @@ -1,6 +1,7 @@ import duckdb import os import pytest + pd = pytest.importorskip("pandas") pa = pytest.importorskip("pyarrow") from typing import Union @@ -12,54 +13,53 @@ from duckdb.typing import * + def make_annotated_function(type): # Create a function that returns its input def test_base(x): return x import types + test_function = types.FunctionType( - test_base.__code__, - test_base.__globals__, - test_base.__name__, - test_base.__defaults__, - test_base.__closure__ + test_base.__code__, test_base.__globals__, test_base.__name__, test_base.__defaults__, test_base.__closure__ ) # Add annotations for the return type and 'x' - test_function.__annotations__ = { - 'return': type, - 'x': type - } + test_function.__annotations__ = {'return': type, 'x': type} return test_function + class TestScalarUDF(object): - @pytest.mark.parametrize('function_type', [ - 'native', - 'arrow' - ]) - @pytest.mark.parametrize('test_type', [ - (TINYINT, -42), - (SMALLINT, -512), - (INTEGER, -131072), - (BIGINT, -17179869184), - (UTINYINT, 254), - (USMALLINT, 65535), - (UINTEGER, 4294967295), - (UBIGINT, 18446744073709551615), - (HUGEINT, 18446744073709551616), - (VARCHAR, 'long_string_test'), - (UUID, uuid.UUID('ffffffff-ffff-ffff-ffff-ffffffffffff')), - (FLOAT, 0.12246409803628922), - (DOUBLE, 123142.12312416293784721232344), - (DATE, datetime.date(2005, 3, 11)), - (TIMESTAMP, datetime.datetime(2009, 2, 13, 11, 5, 53)), - (TIME, datetime.time(14, 1, 12)), - (BLOB, b'\xF6\x96\xB0\x85'), - (INTERVAL, datetime.timedelta(days=30969, seconds=999, microseconds=999999)), - (BOOLEAN, True), - (duckdb.struct_type(['BIGINT[]','VARCHAR[]']), {'v1': [1, 2, 3], 'v2': ['a', 'non-inlined string', 'duckdb']}), - (duckdb.list_type('VARCHAR'), ['the', 'duck', 'non-inlined string']) - ]) + @pytest.mark.parametrize('function_type', ['native', 'arrow']) + @pytest.mark.parametrize( + 'test_type', + [ + (TINYINT, -42), + (SMALLINT, -512), + (INTEGER, -131072), + (BIGINT, -17179869184), + (UTINYINT, 254), + (USMALLINT, 65535), + (UINTEGER, 4294967295), + (UBIGINT, 18446744073709551615), + (HUGEINT, 18446744073709551616), + (VARCHAR, 'long_string_test'), + (UUID, uuid.UUID('ffffffff-ffff-ffff-ffff-ffffffffffff')), + (FLOAT, 0.12246409803628922), + (DOUBLE, 123142.12312416293784721232344), + (DATE, datetime.date(2005, 3, 11)), + (TIMESTAMP, datetime.datetime(2009, 2, 13, 11, 5, 53)), + (TIME, datetime.time(14, 1, 12)), + (BLOB, b'\xF6\x96\xB0\x85'), + (INTERVAL, datetime.timedelta(days=30969, seconds=999, microseconds=999999)), + (BOOLEAN, True), + ( + duckdb.struct_type(['BIGINT[]', 'VARCHAR[]']), + {'v1': [1, 2, 3], 'v2': ['a', 'non-inlined string', 'duckdb']}, + ), + (duckdb.list_type('VARCHAR'), ['the', 'duck', 'non-inlined string']), + ], + ) def test_type_coverage(self, test_type, function_type): type = test_type[0] value = test_type[1] @@ -80,12 +80,13 @@ def test_type_coverage(self, test_type, function_type): # Multiple chunks size = duckdb.__standard_vector_size__ * 3 res = con.execute(f"select test(x) from repeat(?::{str(type)}, {size}) as tbl(x)", [value]).fetchall() - assert(len(res) == size) + assert len(res) == size # Mixed NULL/NON-NULL size = 
duckdb.__standard_vector_size__ * 3 con.execute("select setseed(0.1337)").fetchall() - actual = con.execute(f""" + actual = con.execute( + f""" select test( case when (x > 0.5) then ?::{str(type)} @@ -93,10 +94,13 @@ def test_type_coverage(self, test_type, function_type): NULL end ) from (select random() as x from range({size})) - """, [value]).fetchall() + """, + [value], + ).fetchall() con.execute("select setseed(0.1337)").fetchall() - expected = con.execute(f""" + expected = con.execute( + f""" select case when (x > 0.5) then ?::{str(type)} @@ -104,7 +108,9 @@ def test_type_coverage(self, test_type, function_type): NULL end from (select random() as x from range({size})) - """, [value]).fetchall() + """, + [value], + ).fetchall() assert expected == actual # Using 'relation.project' @@ -113,14 +119,11 @@ def test_type_coverage(self, test_type, function_type): res = table_rel.project('test(x)').fetchall() assert res[0][0] == value - @pytest.mark.parametrize('udf_type', [ - 'arrow', - 'native' - ]) + @pytest.mark.parametrize('udf_type', ['arrow', 'native']) def test_map_coverage(self, udf_type): def no_op(x): return x - + con = duckdb.connect() map_type = con.map_type('VARCHAR', 'BIGINT') con.create_function('test_map', no_op, [map_type], map_type, type=udf_type) @@ -128,21 +131,23 @@ def no_op(x): res = rel.fetchall() assert res == [({'key': ['non-inlined string', 'test', 'duckdb'], 'value': [42, 1337, 123]},)] - @pytest.mark.parametrize('udf_type', [ - 'arrow', - 'native' - ]) + @pytest.mark.parametrize('udf_type', ['arrow', 'native']) def test_exceptions(self, udf_type): def raises_exception(x): raise AttributeError("error") - + con = duckdb.connect() con.create_function('raises', raises_exception, [BIGINT], BIGINT, type=udf_type) - with pytest.raises(duckdb.InvalidInputException, match=' Python exception occurred while executing the UDF: AttributeError: error'): + with pytest.raises( + duckdb.InvalidInputException, + match=' Python exception occurred while executing the UDF: AttributeError: error', + ): res = con.sql('select raises(3)').fetchall() - + con.remove_function('raises') - con.create_function('raises', raises_exception, [BIGINT], BIGINT, exception_handling='return_null', type=udf_type) + con.create_function( + 'raises', raises_exception, [BIGINT], BIGINT, exception_handling='return_null', type=udf_type + ) res = con.sql('select raises(3) from range(5)').fetchall() assert res == [(None,), (None,), (None,), (None,), (None,)] @@ -164,13 +169,8 @@ def __call__(self, x): assert res == [(5,)] # pyarrow does not support creating an array filled with pd.NA values - @pytest.mark.parametrize('udf_type', [ - 'native' - ]) - @pytest.mark.parametrize('duckdb_type', [ - FLOAT, - DOUBLE - ]) + @pytest.mark.parametrize('udf_type', ['native']) + @pytest.mark.parametrize('duckdb_type', [FLOAT, DOUBLE]) def test_pd_nan(self, duckdb_type, udf_type): def return_pd_nan(): if udf_type == 'native': @@ -184,9 +184,10 @@ def return_pd_nan(): def test_side_effects(self): def count() -> int: - old = count.counter; + old = count.counter count.counter += 1 return old + count.counter = 0 con = duckdb.connect() @@ -200,20 +201,15 @@ def count() -> int: res = con.sql('select my_counter() from range(10)').fetchall() assert res == [(0,), (1,), (2,), (3,), (4,), (5,), (6,), (7,), (8,), (9,)] - @pytest.mark.parametrize('udf_type', [ - 'arrow', - 'native' - ]) - @pytest.mark.parametrize('duckdb_type', [ - FLOAT, - DOUBLE - ]) + @pytest.mark.parametrize('udf_type', ['arrow', 'native']) + 
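# A short sketch of the stateful-UDF pattern from test_side_effects above;
# 'tick' is a hypothetical name, and side_effects is assumed to be the
# create_function keyword that test exercises:
import duckdb
from duckdb.typing import BIGINT

def tick() -> int:
    tick.n += 1
    return tick.n

tick.n = 0
con = duckdb.connect()
# declaring side effects keeps DuckDB from treating the UDF as pure and
# reusing a single computed result across rows
con.create_function('tick', tick, None, BIGINT, side_effects=True)
assert con.sql('select tick() from range(3)').fetchall() == [(1,), (2,), (3,)]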
@pytest.mark.parametrize('duckdb_type', [FLOAT, DOUBLE]) def test_np_nan(self, duckdb_type, udf_type): def return_np_nan(): if udf_type == 'native': return np.nan else: import pyarrow as pa + return pa.chunked_array([[np.nan]], type=pa.float64()) con = duckdb.connect() @@ -222,62 +218,61 @@ def return_np_nan(): res = con.sql('select return_np_nan()').fetchall() assert pd.isnull(res[0][0]) - @pytest.mark.parametrize('udf_type', [ - 'arrow', - 'native' - ]) - @pytest.mark.parametrize('duckdb_type', [ - FLOAT, - DOUBLE - ]) + @pytest.mark.parametrize('udf_type', ['arrow', 'native']) + @pytest.mark.parametrize('duckdb_type', [FLOAT, DOUBLE]) def test_math_nan(self, duckdb_type, udf_type): def return_math_nan(): import cmath + if udf_type == 'native': return cmath.nan else: import pyarrow as pa + return pa.chunked_array([[cmath.nan]], type=pa.float64()) con = duckdb.connect() - con.create_function('return_math_nan', return_math_nan, None, duckdb_type, null_handling='SPECIAL', type=udf_type) + con.create_function( + 'return_math_nan', return_math_nan, None, duckdb_type, null_handling='SPECIAL', type=udf_type + ) res = con.sql('select return_math_nan()').fetchall() assert pd.isnull(res[0][0]) - @pytest.mark.parametrize('udf_type', [ - 'arrow', - 'native' - ]) - @pytest.mark.parametrize('data_type', [ - TINYINT, - SMALLINT, - INTEGER, - BIGINT, - UTINYINT, - USMALLINT, - UINTEGER, - UBIGINT, - HUGEINT, - VARCHAR, - UUID, - FLOAT, - DOUBLE, - DATE, - TIMESTAMP, - TIME, - BLOB, - INTERVAL, - BOOLEAN, - duckdb.struct_type(['BIGINT[]','VARCHAR[]']), - duckdb.list_type('VARCHAR') - ]) + @pytest.mark.parametrize('udf_type', ['arrow', 'native']) + @pytest.mark.parametrize( + 'data_type', + [ + TINYINT, + SMALLINT, + INTEGER, + BIGINT, + UTINYINT, + USMALLINT, + UINTEGER, + UBIGINT, + HUGEINT, + VARCHAR, + UUID, + FLOAT, + DOUBLE, + DATE, + TIMESTAMP, + TIME, + BLOB, + INTERVAL, + BOOLEAN, + duckdb.struct_type(['BIGINT[]', 'VARCHAR[]']), + duckdb.list_type('VARCHAR'), + ], + ) def test_return_null(self, data_type, udf_type): def return_null(): if udf_type == 'native': return None else: import pyarrow as pa + return pa.nulls(1) con = duckdb.connect() @@ -299,7 +294,10 @@ def func(x: int) -> int: # then starting a new result-fetch would cancel the transaction # which would corrupt our internal mechanism used to check if a UDF is already registered # because that isn't transaction-aware - with pytest.raises(duckdb.InvalidInputException, match='This function can not be called with an active transaction!, commit or abort the existing one first'): + with pytest.raises( + duckdb.InvalidInputException, + match='This function can not be called with an active transaction!, commit or abort the existing one first', + ): con.create_function('func', func) # This would cancel the previous transaction, causing the function to no longer exist @@ -307,4 +305,4 @@ def func(x: int) -> int: con.create_function('func', func) res = con.sql('select func(5)').fetchall() - assert res == [(5,)] \ No newline at end of file + assert res == [(5,)] diff --git a/tools/pythonpkg/tests/fast/udf/test_scalar_arrow.py b/tools/pythonpkg/tests/fast/udf/test_scalar_arrow.py index 760dfb040075..c715ccce5d18 100644 --- a/tools/pythonpkg/tests/fast/udf/test_scalar_arrow.py +++ b/tools/pythonpkg/tests/fast/udf/test_scalar_arrow.py @@ -1,6 +1,7 @@ import duckdb import os import pytest + pd = pytest.importorskip("pandas") pa = pytest.importorskip("pyarrow") from typing import Union @@ -10,12 +11,13 @@ from duckdb.typing import * -class 
TestPyArrowUDF(object): +class TestPyArrowUDF(object): def test_basic_use(self): def plus_one(x): table = pa.lib.Table.from_arrays([x], names=['c0']) import pandas as pd + df = pd.DataFrame(x.to_pandas()) df['c0'] = df['c0'] + 1 return pa.lib.Table.from_pandas(df) @@ -45,11 +47,10 @@ def sort_table(x): res = con.sql("select 100-i as original, sort_table(original) from range(100) tbl(i)").fetchall() assert res[0] == (100, 1) - def test_varargs(self): def variable_args(*args): # We return a chunked array here, but internally we convert this into a Table - if (len(args) == 0): + if len(args) == 0: raise ValueError("Expected at least one argument") for item in args: return item @@ -59,13 +60,14 @@ def variable_args(*args): con.create_function('varargs', variable_args, None, BIGINT, type='arrow') res = con.sql("""select varargs(5, '3', '2', 1, 0.12345)""").fetchall() assert res == [(5,)] - + res = con.sql("""select varargs(42, 'test', [5,4,3])""").fetchall() assert res == [(42,)] def test_cast_varchar_to_int(self): def takes_string(col): return col + con = duckdb.connect() # The return type of the function is set to BIGINT, but it takes a VARCHAR con.create_function('pyarrow_string_to_num', takes_string, [VARCHAR], BIGINT, type='arrow') @@ -80,13 +82,17 @@ def takes_string(col): def test_return_multiple_columns(self): def returns_two_columns(col): import pandas as pd + # Return a pyarrow table consisting of two columns - return pa.lib.Table.from_pandas(pd.DataFrame({'a': [5,4,3], 'b': ['test', 'quack', 'duckdb']})) + return pa.lib.Table.from_pandas(pd.DataFrame({'a': [5, 4, 3], 'b': ['test', 'quack', 'duckdb']})) con = duckdb.connect() # Scalar functions only return a single value per tuple con.create_function('two_columns', returns_two_columns, [BIGINT], BIGINT, type='arrow') - with pytest.raises(duckdb.InvalidInputException, match='The returned table from a pyarrow scalar udf should only contain one column, found 2'): + with pytest.raises( + duckdb.InvalidInputException, + match='The returned table from a pyarrow scalar udf should only contain one column, found 2', + ): res = con.sql("""select two_columns(5)""").fetchall() def test_return_none(self): @@ -111,7 +117,7 @@ def return_empty(col): def test_excessive_result(self): def return_too_many(col): # Always returns a table consisting of 5 tuples - return pa.lib.Table.from_arrays([[5,4,3,2,1]], names=['c0']) + return pa.lib.Table.from_arrays([[5, 4, 3, 2, 1]], names=['c0']) con = duckdb.connect() con.create_function('too_many_tuples', return_too_many, [BIGINT], BIGINT, type='arrow') @@ -121,10 +127,12 @@ def return_too_many(col): def test_return_struct(self): def return_struct(col): con = duckdb.connect() - return con.sql(""" + return con.sql( + """ select {'a': 5, 'b': 'test', 'c': [5,3,2]} - """).arrow() - + """ + ).arrow() + con = duckdb.connect() struct_type = con.struct_type({'a': BIGINT, 'b': VARCHAR, 'c': con.list_type(BIGINT)}) con.create_function('return_struct', return_struct, [BIGINT], struct_type, type='arrow') @@ -134,12 +142,14 @@ def return_struct(col): def test_multiple_chunks(self): def return_unmodified(col): return col - + con = duckdb.connect() con.create_function('unmodified', return_unmodified, [BIGINT], BIGINT, type='arrow') - res = con.sql(""" + res = con.sql( + """ select unmodified(i) from range(5000) tbl(i) - """).fetchall() + """ + ).fetchall() assert len(res) == 5000 assert res == con.sql('select * from range(5000)').fetchall() @@ -147,10 +157,11 @@ def return_unmodified(col): def test_inferred(self): def func(x: 
int) -> int: import pandas as pd + df = pd.DataFrame({'c0': x}) df['c0'] = df['c0'] ** 2 return pa.lib.Table.from_pandas(df) - + con = duckdb.connect() con.create_function('inferred', func, type='arrow') res = con.sql('select inferred(42)').fetchall() @@ -159,6 +170,7 @@ def func(x: int) -> int: def test_nulls(self): def return_five(x): import pandas as pd + length = len(x) return pa.lib.Table.from_pandas(pd.DataFrame({'a': [5 for _ in range(length)]})) @@ -173,4 +185,3 @@ def return_five(x): res = con.sql('select return_five(NULL) from range(10)').fetchall() # Because we didn't specify 'special' null handling, these are all NULL assert res == [(None,), (None,), (None,), (None,), (None,), (None,), (None,), (None,), (None,), (None,)] - diff --git a/tools/pythonpkg/tests/fast/udf/test_scalar_native.py b/tools/pythonpkg/tests/fast/udf/test_scalar_native.py index cb29b411694f..5c46521289e9 100644 --- a/tools/pythonpkg/tests/fast/udf/test_scalar_native.py +++ b/tools/pythonpkg/tests/fast/udf/test_scalar_native.py @@ -5,11 +5,12 @@ from duckdb.typing import * + class TestNativeUDF(object): def test_default_conn(self): def passthrough(x): return x - + duckdb.create_function('default_conn_passthrough', passthrough, [BIGINT], BIGINT) res = duckdb.sql('select default_conn_passthrough(5)').fetchall() assert res == [(5,)] @@ -17,7 +18,7 @@ def passthrough(x): def test_basic_use(self): def plus_one(x): if x == None or x > 50: - return x; + return x return x + 1 con = duckdb.connect() @@ -38,7 +39,10 @@ def passthrough(x): con = duckdb.connect() con.create_function('passthrough', passthrough, [BIGINT], BIGINT) - assert con.sql('select passthrough(i) from range(5000) tbl(i)').fetchall() == con.sql('select * from range(5000)').fetchall() + assert ( + con.sql('select passthrough(i) from range(5000) tbl(i)').fetchall() + == con.sql('select * from range(5000)').fetchall() + ) def test_execute(self): def func(x): @@ -68,23 +72,27 @@ def concatenate(a: str, b: str): con = duckdb.connect() con.create_function('py_concatenate', concatenate, None, VARCHAR) - res = con.sql(""" + res = con.sql( + """ select py_concatenate('5','3'); - """).fetchall() + """ + ).fetchall() assert res[0][0] == '53' def test_detected_return_type(self): def add_nums(*args) -> int: - sum = 0; + sum = 0 for arg in args: sum += arg return sum con = duckdb.connect() con.create_function('add_nums', add_nums) - res = con.sql(""" + res = con.sql( + """ select add_nums(5,3,2,1); - """).fetchall() + """ + ).fetchall() assert res[0][0] == 11 def test_varargs(self): @@ -100,48 +108,54 @@ def variable_args(*args): def test_return_incorrectly_typed_object(self): def returns_duckdb() -> int: return 'duckdb' - + con = duckdb.connect() con.create_function('fastest_database_in_the_west', returns_duckdb) - with pytest.raises(duckdb.InvalidInputException, match="Failed to cast value: Could not convert string 'duckdb' to INT64"): + with pytest.raises( + duckdb.InvalidInputException, match="Failed to cast value: Could not convert string 'duckdb' to INT64" + ): res = con.sql('select fastest_database_in_the_west()').fetchall() def test_nulls(self): def five_if_null(x): - if (x == None): + if x == None: return 5 return x + con = duckdb.connect() - con.create_function('null_test', five_if_null, [BIGINT], BIGINT, null_handling = "SPECIAL") + con.create_function('null_test', five_if_null, [BIGINT], BIGINT, null_handling="SPECIAL") res = con.sql('select null_test(NULL)').fetchall() assert res == [(5,)] - @pytest.mark.parametrize('pair', [ - (TINYINT, -129), - 
(TINYINT, 128), - (SMALLINT, -32769), - (SMALLINT, 32768), - (INTEGER, -2147483649), - (INTEGER, 2147483648), - (BIGINT, -9223372036854775815), - (BIGINT, 9223372036854775808), - (UTINYINT, -1), - (UTINYINT, 256), - (USMALLINT, -1), - (USMALLINT, 65536), - (UINTEGER, -1), - (UINTEGER, 4294967296), - (UBIGINT, -1), - (UBIGINT, 18446744073709551616), - (HUGEINT, -170141183460469231731687303715884105729), - (HUGEINT, 170141183460469231731687303715884105728), - ]) + @pytest.mark.parametrize( + 'pair', + [ + (TINYINT, -129), + (TINYINT, 128), + (SMALLINT, -32769), + (SMALLINT, 32768), + (INTEGER, -2147483649), + (INTEGER, 2147483648), + (BIGINT, -9223372036854775815), + (BIGINT, 9223372036854775808), + (UTINYINT, -1), + (UTINYINT, 256), + (USMALLINT, -1), + (USMALLINT, 65536), + (UINTEGER, -1), + (UINTEGER, 4294967296), + (UBIGINT, -1), + (UBIGINT, 18446744073709551616), + (HUGEINT, -170141183460469231731687303715884105729), + (HUGEINT, 170141183460469231731687303715884105728), + ], + ) def test_return_overflow(self, pair): duckdb_type, overflowing_value = pair def return_overflow(): return overflowing_value - + con = duckdb.connect() con.create_function('return_overflow', return_overflow, None, duckdb_type) with pytest.raises(duckdb.InvalidInputException): @@ -158,11 +172,18 @@ def add_extra_column(original): con = duckdb.connect() range_table = con.table_function('range', [5000]) - con.create_function("append_field", add_extra_column, [duckdb.struct_type({'a': BIGINT, 'b': BIGINT})], duckdb.struct_type({'a': BIGINT, 'b': BIGINT, 'c': BIGINT})) - - res = con.sql(""" + con.create_function( + "append_field", + add_extra_column, + [duckdb.struct_type({'a': BIGINT, 'b': BIGINT})], + duckdb.struct_type({'a': BIGINT, 'b': BIGINT, 'c': BIGINT}), + ) + + res = con.sql( + """ select append_field({'a': i::BIGINT, 'b': 3::BIGINT}) from range_table tbl(i) - """) + """ + ) # added extra column to the struct assert len(res.fetchone()[0].keys()) == 3 # FIXME: this is needed, otherwise the old transaction is still active when we try to start a new transaction inside of 'create_function', which means the call would fail @@ -176,8 +197,15 @@ def swap_keys(dict): result[item] = dict[item] return result - con.create_function('swap_keys', swap_keys, [con.struct_type({'a': BIGINT, 'b': VARCHAR})], con.struct_type({'a': VARCHAR, 'b': BIGINT})) - res = con.sql(""" + con.create_function( + 'swap_keys', + swap_keys, + [con.struct_type({'a': BIGINT, 'b': VARCHAR})], + con.struct_type({'a': VARCHAR, 'b': BIGINT}), + ) + res = con.sql( + """ select swap_keys({'a': 42, 'b': 'answer_to_life'}) - """).fetchall() + """ + ).fetchall() assert res == [({'a': 'answer_to_life', 'b': 42},)] diff --git a/tools/pythonpkg/tests/slow/test_h2oai_arrow.py b/tools/pythonpkg/tests/slow/test_h2oai_arrow.py index b373db998ff7..06299c0bb4cc 100644 --- a/tools/pythonpkg/tests/slow/test_h2oai_arrow.py +++ b/tools/pythonpkg/tests/slow/test_h2oai_arrow.py @@ -7,7 +7,8 @@ requests = importorskip('requests') np = importorskip('numpy') -def download_file(url,name): + +def download_file(url, name): r = requests.get(url, allow_redirects=True) open(name, 'wb').write(r.content) @@ -56,7 +57,9 @@ def group_by_q5(con): def group_by_q6(con): - con.execute("CREATE TABLE ans AS SELECT id4, id5, quantile_cont(v3, 0.5) AS median_v3, stddev(v3) AS sd_v3 FROM x GROUP BY id4, id5;") + con.execute( + "CREATE TABLE ans AS SELECT id4, id5, quantile_cont(v3, 0.5) AS median_v3, stddev(v3) AS sd_v3 FROM x GROUP BY id4, id5;" + ) res = con.execute("SELECT COUNT(*), 
sum(median_v3) AS median_v3, sum(sd_v3) AS sd_v3 FROM ans").fetchall() assert res[0][0] == 9216 assert math.floor(res[0][1]) == 460771 @@ -73,7 +76,9 @@ def group_by_q7(con): def group_by_q8(con): - con.execute("CREATE TABLE ans AS SELECT id6, v3 AS largest2_v3 FROM (SELECT id6, v3, row_number() OVER (PARTITION BY id6 ORDER BY v3 DESC) AS order_v3 FROM x WHERE v3 IS NOT NULL) sub_query WHERE order_v3 <= 2") + con.execute( + "CREATE TABLE ans AS SELECT id6, v3 AS largest2_v3 FROM (SELECT id6, v3, row_number() OVER (PARTITION BY id6 ORDER BY v3 DESC) AS order_v3 FROM x WHERE v3 IS NOT NULL) sub_query WHERE order_v3 <= 2" + ) res = con.execute("SELECT count(*), sum(largest2_v3) AS largest2_v3 FROM ans").fetchall() assert res[0][0] == 190002 assert math.floor(res[0][1]) == 18700554 @@ -89,12 +94,15 @@ def group_by_q9(con): def group_by_q10(con): - con.execute("CREATE TABLE ans AS SELECT id1, id2, id3, id4, id5, id6, sum(v3) AS v3, count(*) AS count FROM x GROUP BY id1, id2, id3, id4, id5, id6;") + con.execute( + "CREATE TABLE ans AS SELECT id1, id2, id3, id4, id5, id6, sum(v3) AS v3, count(*) AS count FROM x GROUP BY id1, id2, id3, id4, id5, id6;" + ) res = con.execute("SELECT sum(v3) AS v3, sum(count) AS count FROM ans;").fetchall() assert math.floor(res[0][0]) == 474969574 assert res[0][1] == 10000000 con.execute("DROP TABLE ans") + def join_by_q1(con): con.execute("CREATE TABLE ans AS SELECT x.*, small.id4 AS small_id4, v2 FROM x JOIN small USING (id1);") res = con.execute("SELECT COUNT(*), SUM(v1) AS v1, SUM(v2) AS v2 FROM ans;").fetchall() @@ -103,55 +111,73 @@ def join_by_q1(con): assert math.floor(res[0][2]) == 347720187 con.execute("DROP TABLE ans") + def join_by_q2(con): - con.execute("CREATE TABLE ans AS SELECT x.*, medium.id1 AS medium_id1, medium.id4 AS medium_id4, medium.id5 AS medium_id5, v2 FROM x JOIN medium USING (id2);") + con.execute( + "CREATE TABLE ans AS SELECT x.*, medium.id1 AS medium_id1, medium.id4 AS medium_id4, medium.id5 AS medium_id5, v2 FROM x JOIN medium USING (id2);" + ) res = con.execute("SELECT COUNT(*), SUM(v1) AS v1, SUM(v2) AS v2 FROM ans;").fetchall() assert res[0][0] == 8998412 assert math.floor(res[0][1]) == 449954076 assert math.floor(res[0][2]) == 449999844 con.execute("DROP TABLE ans") + def join_by_q3(con): - con.execute("CREATE TABLE ans AS SELECT x.*, medium.id1 AS medium_id1, medium.id4 AS medium_id4, medium.id5 AS medium_id5, v2 FROM x LEFT JOIN medium USING (id2);") + con.execute( + "CREATE TABLE ans AS SELECT x.*, medium.id1 AS medium_id1, medium.id4 AS medium_id4, medium.id5 AS medium_id5, v2 FROM x LEFT JOIN medium USING (id2);" + ) res = con.execute("SELECT COUNT(*), SUM(v1) AS v1, SUM(v2) AS v2 FROM ans;").fetchall() assert res[0][0] == 10000000 assert math.floor(res[0][1]) == 500043740 assert math.floor(res[0][2]) == 449999844 con.execute("DROP TABLE ans") + def join_by_q4(con): - con.execute("CREATE TABLE ans AS SELECT x.*, medium.id1 AS medium_id1, medium.id2 AS medium_id2, medium.id4 AS medium_id4, v2 FROM x JOIN medium USING (id5);") + con.execute( + "CREATE TABLE ans AS SELECT x.*, medium.id1 AS medium_id1, medium.id2 AS medium_id2, medium.id4 AS medium_id4, v2 FROM x JOIN medium USING (id5);" + ) res = con.execute("SELECT COUNT(*), SUM(v1) AS v1, SUM(v2) AS v2 FROM ans;").fetchall() assert res[0][0] == 8998412 assert math.floor(res[0][1]) == 449954076 assert math.floor(res[0][2]) == 449999844 con.execute("DROP TABLE ans") + def join_by_q5(con): - con.execute("CREATE TABLE ans AS SELECT x.*, big.id1 AS big_id1, big.id2 AS big_id2, 
big.id4 AS big_id4, big.id5 AS big_id5, big.id6 AS big_id6, v2 FROM x JOIN big USING (id3);") + con.execute( + "CREATE TABLE ans AS SELECT x.*, big.id1 AS big_id1, big.id2 AS big_id2, big.id4 AS big_id4, big.id5 AS big_id5, big.id6 AS big_id6, v2 FROM x JOIN big USING (id3);" + ) res = con.execute("SELECT COUNT(*), SUM(v1) AS v1, SUM(v2) AS v2 FROM ans;").fetchall() assert res[0][0] == 9000000 - assert math.floor(res[0][1]) == 450032091 + assert math.floor(res[0][1]) == 450032091 assert math.floor(res[0][2]) == 449860428 con.execute("DROP TABLE ans") + class TestH2OAIArrow(object): - @mark.parametrize('function', [ - group_by_q1, - group_by_q2, - group_by_q3, - group_by_q4, - group_by_q5, - group_by_q6, - group_by_q7, - group_by_q8, - group_by_q9, - group_by_q10, - ]) + @mark.parametrize( + 'function', + [ + group_by_q1, + group_by_q2, + group_by_q3, + group_by_q4, + group_by_q5, + group_by_q6, + group_by_q7, + group_by_q8, + group_by_q9, + group_by_q10, + ], + ) @mark.parametrize('threads', [1, 4]) def test_group_by(self, threads, function): con = duckdb.connect() - download_file('https://github.com/cwida/duckdb-data/releases/download/v1.0/G1_1e7_1e2_5_0.csv.gz','G1_1e7_1e2_5_0.csv.gz') + download_file( + 'https://github.com/cwida/duckdb-data/releases/download/v1.0/G1_1e7_1e2_5_0.csv.gz', 'G1_1e7_1e2_5_0.csv.gz' + ) arrow_table = read_csv('G1_1e7_1e2_5_0.csv.gz') con.register("x", arrow_table) os.remove('G1_1e7_1e2_5_0.csv.gz') @@ -161,13 +187,16 @@ def test_group_by(self, threads, function): function(con) @mark.parametrize('threads', [1, 4]) - @mark.parametrize('function', [ - join_by_q1, - join_by_q2, - join_by_q3, - join_by_q4, - join_by_q5, - ]) + @mark.parametrize( + 'function', + [ + join_by_q1, + join_by_q2, + join_by_q3, + join_by_q4, + join_by_q5, + ], + ) @mark.usefixtures('large_data') def test_join(self, threads, function, large_data): large_data.execute(f"PRAGMA threads={threads}") @@ -177,11 +206,19 @@ def test_join(self, threads, function, large_data): @fixture(scope="module") def large_data(): - download_file('https://github.com/cwida/duckdb-data/releases/download/v1.0/J1_1e7_NA_0_0.csv.gz','J1_1e7_NA_0_0.csv.gz') - download_file('https://github.com/cwida/duckdb-data/releases/download/v1.0/J1_1e7_1e1_0_0.csv.gz','J1_1e7_1e1_0_0.csv.gz') - download_file('https://github.com/cwida/duckdb-data/releases/download/v1.0/J1_1e7_1e4_0_0.csv.gz','J1_1e7_1e4_0_0.csv.gz') - download_file('https://github.com/cwida/duckdb-data/releases/download/v1.0/J1_1e7_1e7_0_0.csv.gz','J1_1e7_1e7_0_0.csv.gz') - + download_file( + 'https://github.com/cwida/duckdb-data/releases/download/v1.0/J1_1e7_NA_0_0.csv.gz', 'J1_1e7_NA_0_0.csv.gz' + ) + download_file( + 'https://github.com/cwida/duckdb-data/releases/download/v1.0/J1_1e7_1e1_0_0.csv.gz', 'J1_1e7_1e1_0_0.csv.gz' + ) + download_file( + 'https://github.com/cwida/duckdb-data/releases/download/v1.0/J1_1e7_1e4_0_0.csv.gz', 'J1_1e7_1e4_0_0.csv.gz' + ) + download_file( + 'https://github.com/cwida/duckdb-data/releases/download/v1.0/J1_1e7_1e7_0_0.csv.gz', 'J1_1e7_1e7_0_0.csv.gz' + ) + con = duckdb.connect() arrow_table = read_csv('J1_1e7_NA_0_0.csv.gz') con.register("x", arrow_table) diff --git a/tools/pythonpkg/tests/stubs/test_stubs.py b/tools/pythonpkg/tests/stubs/test_stubs.py index 36453e0f8ea8..d598bc9f2539 100644 --- a/tools/pythonpkg/tests/stubs/test_stubs.py +++ b/tools/pythonpkg/tests/stubs/test_stubs.py @@ -4,19 +4,20 @@ MYPY_INI_PATH = os.path.join(os.path.dirname(__file__), 'mypy.ini') + def test_generated_stubs(): - skip_stubs_errors = 
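# A compact sketch of the struct round trip exercised by append_field and
# swap_keys above: native UDFs receive and return STRUCT values as plain
# Python dicts ('widen' is a hypothetical name):
import duckdb
from duckdb.typing import BIGINT, VARCHAR

def widen(row: dict) -> dict:
    row['c'] = row['a'] + 1  # add a field; the declared return type must match
    return row

con = duckdb.connect()
con.create_function(
    'widen',
    widen,
    [con.struct_type({'a': BIGINT, 'b': VARCHAR})],
    con.struct_type({'a': BIGINT, 'b': VARCHAR, 'c': BIGINT}),
)
res = con.sql("select widen({'a': 41, 'b': 'answer'})").fetchall()
assert res[0][0]['c'] == 42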
['pybind11', 'git_revision', 'is inconsistent, metaclass differs'] + skip_stubs_errors = ['pybind11', 'git_revision', 'is inconsistent, metaclass differs'] - stubtest.test_stubs(stubtest.parse_options(['duckdb', '--mypy-config-file', MYPY_INI_PATH])) + stubtest.test_stubs(stubtest.parse_options(['duckdb', '--mypy-config-file', MYPY_INI_PATH])) - broken_stubs = [ - error.get_description() - for error in stubtest.test_module('duckdb') - if not any(skip in error.get_description() for skip in skip_stubs_errors) - ] + broken_stubs = [ + error.get_description() + for error in stubtest.test_module('duckdb') + if not any(skip in error.get_description() for skip in skip_stubs_errors) + ] - if broken_stubs: - print("Stubs must be updated, either add them to skip_stubs_errors or update __init__.pyi accordingly") - print(broken_stubs) + if broken_stubs: + print("Stubs must be updated, either add them to skip_stubs_errors or update __init__.pyi accordingly") + print(broken_stubs) - assert not broken_stubs + assert not broken_stubs diff --git a/tools/release-pip.py b/tools/release-pip.py index 867da770affc..cc6db3ba200e 100644 --- a/tools/release-pip.py +++ b/tools/release-pip.py @@ -1,6 +1,6 @@ import urllib.request, ssl, json, tempfile, os, sys, re, subprocess -if (len(sys.argv) < 2): +if len(sys.argv) < 2: print("Usage: [release_tag]") exit(1) @@ -12,17 +12,17 @@ release_rev = None request = urllib.request.Request("https://api.github.com/repos/cwida/duckdb/git/refs/tags/") -with urllib.request.urlopen(request, context=ssl._create_unverified_context()) as url: - data = json.loads(url.read().decode()) +with urllib.request.urlopen(request, context=ssl._create_unverified_context()) as url: + data = json.loads(url.read().decode()) - for ref in data: - ref_name = ref['ref'].replace('refs/tags/','') - if (ref_name == release_name): - release_rev = ref['object']['sha'] + for ref in data: + ref_name = ref['ref'].replace('refs/tags/', '') + if ref_name == release_name: + release_rev = ref['object']['sha'] -if (release_rev is None): - print("Could not find hash for tag %s" % sys.argv[1]) - exit(-2) +if release_rev is None: + print("Could not find hash for tag %s" % sys.argv[1]) + exit(-2) print("Using sha %s for release %s" % (release_rev, release_name)) @@ -34,20 +34,20 @@ upload_files = [] request = urllib.request.Request(binurl) -with urllib.request.urlopen(request, context=ssl._create_unverified_context()) as url: - data = url.read().decode() - f_matches = re.findall(r'href="([^"]+\.(whl|tar\.gz))"', data) - for m in f_matches: - if '.dev' in m[0]: - continue - print("Downloading %s" % m[0]) - url = binurl + '/' + m[0] - local_file = fdir + '/' + m[0] - urllib.request.urlretrieve(url, local_file) - upload_files.append(local_file) - -if (len(upload_files) < 1): - print("Could not find any binaries") - exit(-3) +with urllib.request.urlopen(request, context=ssl._create_unverified_context()) as url: + data = url.read().decode() + f_matches = re.findall(r'href="([^"]+\.(whl|tar\.gz))"', data) + for m in f_matches: + if '.dev' in m[0]: + continue + print("Downloading %s" % m[0]) + url = binurl + '/' + m[0] + local_file = fdir + '/' + m[0] + urllib.request.urlretrieve(url, local_file) + upload_files.append(local_file) + +if len(upload_files) < 1: + print("Could not find any binaries") + exit(-3) subprocess.run(['twine', 'upload', '--skip-existing'] + upload_files) diff --git a/tools/rpkg/rconfigure.py b/tools/rpkg/rconfigure.py index 71b52711198d..55192b531e13 100644 --- a/tools/rpkg/rconfigure.py +++ 
b/tools/rpkg/rconfigure.py @@ -24,13 +24,16 @@ sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', '..', 'scripts')) import package_build + def open_utf8(fpath, flags): import sys + if sys.version_info[0] < 3: return open(fpath, flags) else: return open(fpath, flags, encoding="utf8") + extension_list = "" for ext in extensions: @@ -89,7 +92,7 @@ def open_utf8(fpath, flags): (source_list, include_list, original_sources) = package_build.build_package(target_dir, extensions, linenr, unity_build) # object list, relative paths -script_path = os.path.dirname(os.path.abspath(__file__)).replace('\\','/') +script_path = os.path.dirname(os.path.abspath(__file__)).replace('\\', '/') duckdb_sources = [package_build.get_relative_path(os.path.join(script_path, 'src'), x) for x in source_list] object_list = ' '.join([x.rsplit('.', 1)[0] + '.o' for x in duckdb_sources]) diff --git a/tools/shell/shell-test.py b/tools/shell/shell-test.py index 8b81f68a209e..960ca5d4279f 100644 --- a/tools/shell/shell-test.py +++ b/tools/shell/shell-test.py @@ -5,55 +5,58 @@ import shutil if len(sys.argv) < 2: - raise Exception('need shell binary as parameter') + raise Exception('need shell binary as parameter') + def test_exception(command, input, stdout, stderr, errmsg): - print('--- COMMAND --') - print(' '.join(command)) - print('--- INPUT --') - print(input) - print('--- STDOUT --') - print(stdout) - print('--- STDERR --') - print(stderr) - raise Exception(errmsg) + print('--- COMMAND --') + print(' '.join(command)) + print('--- INPUT --') + print(input) + print('--- STDOUT --') + print(stdout) + print('--- STDERR --') + print(stderr) + raise Exception(errmsg) + def test(cmd, out=None, err=None, extra_commands=None, input_file=None, output_file=None): - command = [sys.argv[1], '--batch', '-init', '/dev/null'] - if extra_commands: - command += extra_commands + command = [sys.argv[1], '--batch', '-init', '/dev/null'] + if extra_commands: + command += extra_commands - if input_file: - command += [cmd] - input_data = open(input_file, 'rb').read() - else: - input_data = bytearray(cmd, 'utf8') - output_pipe = subprocess.PIPE - if output_file: - output_pipe = open(output_file, 'w+') + if input_file: + command += [cmd] + input_data = open(input_file, 'rb').read() + else: + input_data = bytearray(cmd, 'utf8') + output_pipe = subprocess.PIPE + if output_file: + output_pipe = open(output_file, 'w+') - res = subprocess.run(command, input=input_data, stdout=output_pipe, stderr=subprocess.PIPE) - if output_file: - stdout = open(output_file, 'r').read() - else: - stdout = res.stdout.decode('utf8').strip() - stderr = res.stderr.decode('utf8').strip() + res = subprocess.run(command, input=input_data, stdout=output_pipe, stderr=subprocess.PIPE) + if output_file: + stdout = open(output_file, 'r').read() + else: + stdout = res.stdout.decode('utf8').strip() + stderr = res.stderr.decode('utf8').strip() - if out and out not in stdout: - test_exception(command, cmd, stdout, stderr, 'out test failed') + if out and out not in stdout: + test_exception(command, cmd, stdout, stderr, 'out test failed') - if err and err not in stderr: - test_exception(command, cmd, stdout, stderr, f"err test failed, error does not contain: '{err}'") + if err and err not in stderr: + test_exception(command, cmd, stdout, stderr, f"err test failed, error does not contain: '{err}'") - if not err and stderr != '': - test_exception(command, cmd, stdout, stderr, 'got err test failed') + if not err and stderr != '': + test_exception(command, cmd, 
stdout, stderr, 'got err test failed') - if err is None and res.returncode != 0: - test_exception(command, cmd, stdout, stderr, 'process returned non-zero exit code but no error was specified') + if err is None and res.returncode != 0: + test_exception(command, cmd, stdout, stderr, 'process returned non-zero exit code but no error was specified') def tf(): - return tempfile.mktemp().replace('\\','/') + return tempfile.mktemp().replace('\\', '/') + # basic test test('select \'asdf\' as a;', out='asdf') @@ -61,31 +64,41 @@ def tf(): test('select * from range(10000);', out='9999') import_basic_csv_table = tf() -print("col_1,col_2\n1,2\n10,20", file=open(import_basic_csv_table, 'w')) +print("col_1,col_2\n1,2\n10,20", file=open(import_basic_csv_table, 'w')) # test create missing table with import -test(""" +test( + """ .mode csv .import "%s" test_table SELECT * FROM test_table; -""" % import_basic_csv_table, out="col_1,col_2\n1,2\n10,20" +""" + % import_basic_csv_table, + out="col_1,col_2\n1,2\n10,20", ) # test pragma -test(""" +test( + """ .mode csv .headers off .sep | CREATE TABLE t0(c0 INT); PRAGMA table_info('t0'); -""", out='0|c0|INTEGER|false||false') +""", + out='0|c0|INTEGER|false||false', +) datafile = tf() -print("42\n84", file=open(datafile, 'w')) -test(''' +print("42\n84", file=open(datafile, 'w')) +test( + ''' CREATE TABLE a (i INTEGER); .import "%s" a SELECT SUM(i) FROM a; -''' % datafile, out='126') +''' + % datafile, + out='126', +) # system functions test('SELECT 1, current_query() as my_column', out='SELECT 1, current_query() as my_column') @@ -95,68 +108,93 @@ def tf(): test("select STRUCT_PACK(x := 3, y := 3);", out="{'x': 3, 'y': 3}") test("select STRUCT_PACK(x := 3, y := LIST_VALUE(1, 2));", out="{'x': 3, 'y': [1, 2]}") -test(''' +test( + ''' CREATE TABLE a (i STRING); INSERT INTO a VALUES ('XXXX'); SELECT CAST(i AS INTEGER) FROM a; -''' , err='Could not convert') +''', + err='Could not convert', +) test('.auth ON', err='sqlite3_set_authorizer') test('.auth OFF', err='sqlite3_set_authorizer') test('.backup %s' % tf(), err='sqlite3_backup_init') # test newline in value -test('''select 'hello -world' as a;''', out='hello\\nworld') +test( + '''select 'hello +world' as a;''', + out='hello\\nworld', +) # test newline in column name -test('''select 42 as "hello -world";''', out='hello\\nworld') +test( + '''select 42 as "hello +world";''', + out='hello\\nworld', +) -test(''' +test( + ''' .bail on .bail off .binary on SELECT 42; .binary off SELECT 42; -''') +''' +) -test(''' +test( + ''' .cd %s .cd %s -''' % (tempfile.gettempdir().replace('\\','/'), os.getcwd().replace('\\','/'))) +''' + % (tempfile.gettempdir().replace('\\', '/'), os.getcwd().replace('\\', '/')) +) -test(''' +test( + ''' CREATE TABLE a (I INTEGER); .changes on INSERT INTO a VALUES (42); DROP TABLE a; -''', out="total_changes: 1") +''', + out="total_changes: 1", +) -test(''' +test( + ''' CREATE TABLE a (I INTEGER); .changes on INSERT INTO a VALUES (42); INSERT INTO a VALUES (42); INSERT INTO a VALUES (42); DROP TABLE a; -''', out="total_changes: 3") +''', + out="total_changes: 3", +) -test(''' +test( + ''' CREATE TABLE a (I INTEGER); .changes off INSERT INTO a VALUES (42); DROP TABLE a; -''') +''' +) # maybe at some point we can do something meaningful here # test('.dbinfo', err='unable to read database header') -test(''' +test( + ''' .echo on SELECT 42; -''', out="SELECT 42") +''', + out="SELECT 42", +) test('.exit') @@ -164,16 +202,22 @@ def tf(): test('.print asdf', out='asdf') -test(''' +test( + ''' .headers 
on SELECT 42 as wilbur; -''', out="wilbur") +''', + out="wilbur", +) -test(''' +test( + ''' .nullvalue wilbur SELECT NULL; -''', out="wilbur") +''', + out="wilbur", +) test("select 'yo' where 'abc' like 'a%c';", out='yo') @@ -184,43 +228,61 @@ def tf(): test('.load %s' % tf(), err="Error") # error in streaming result -test(''' +test( + ''' SELECT x::INT FROM (SELECT x::VARCHAR x FROM range(10) tbl(x) UNION ALL SELECT 'hello' x) tbl(x); -''', err='Could not convert string') +''', + err='Could not convert string', +) # test explain test('explain select sum(i) from range(1000) tbl(i)', out='RANGE') test('explain analyze select sum(i) from range(1000) tbl(i)', out='RANGE') # test returning insert -test(''' +test( + ''' CREATE TABLE table1 (a INTEGER DEFAULT -1, b INTEGER DEFAULT -2, c INTEGER DEFAULT -3); INSERT INTO table1 VALUES (1, 2, 3) RETURNING *; SELECT COUNT(*) FROM table1; -''', out='1') +''', + out='1', +) # test display of pragmas -test(''' +test( + ''' CREATE TABLE table1 (mylittlecolumn INTEGER); pragma table_info('table1'); -''', out='mylittlecolumn') +''', + out='mylittlecolumn', +) # test display of show -test(''' +test( + ''' CREATE TABLE table1 (mylittlecolumn INTEGER); show table1; -''', out='mylittlecolumn') +''', + out='mylittlecolumn', +) # test display of call -test(''' +test( + ''' CALL range(4); -''', out='3') +''', + out='3', +) # test display of prepare/execute -test(''' +test( + ''' PREPARE v1 AS SELECT ?::INT; EXECUTE v1(42); -''', out='42') +''', + out='42', +) # this should be fixed @@ -239,7 +301,7 @@ def tf(): # FIXME # Parser Error: syntax error at or near "[" # LINE 1: ...concat(quote(s.name) || '.' || quote(f.[from]) || '=?' || fkey_collate_claus... -#test('.lint fkey-indexes') +# test('.lint fkey-indexes') test('.timeout', err='sqlite3_busy_timeout') @@ -257,132 +319,188 @@ def tf(): test('.stats on') test('.stats off') -test(''' +test( + ''' create table test (a int, b varchar); insert into test values (1, 'hello'); .schema test -''', out="CREATE TABLE test(a INTEGER, b VARCHAR);") +''', + out="CREATE TABLE test(a INTEGER, b VARCHAR);", +) -test(''' +test( + ''' create table test (a int, b varchar); insert into test values (1, 'hello'); .schema tes% -''', out="CREATE TABLE test(a INTEGER, b VARCHAR);") +''', + out="CREATE TABLE test(a INTEGER, b VARCHAR);", +) -test(''' +test( + ''' create table test (a int, b varchar); insert into test values (1, 'hello'); .schema tes* -''', out="CREATE TABLE test(a INTEGER, b VARCHAR);") +''', + out="CREATE TABLE test(a INTEGER, b VARCHAR);", +) -test(''' +test( + ''' create table test (a int, b varchar); CREATE TABLE test2(a INTEGER, b VARCHAR); .schema -''', out="CREATE TABLE test2(a INTEGER, b VARCHAR);") +''', + out="CREATE TABLE test2(a INTEGER, b VARCHAR);", +) test('.fullschema', 'No STAT tables available', '') -test(''' +test( + ''' CREATE TABLE asda (i INTEGER); CREATE TABLE bsdf (i INTEGER); CREATE TABLE csda (i INTEGER); .tables -''', out="asda bsdf csda") +''', + out="asda bsdf csda", +) -test(''' +test( + ''' CREATE TABLE asda (i INTEGER); CREATE TABLE bsdf (i INTEGER); CREATE TABLE csda (i INTEGER); .tables %da -''', out="asda csda") +''', + out="asda csda", +) -test('.indexes', out="") +test('.indexes', out="") -test(''' +test( + ''' CREATE TABLE a (i INTEGER); CREATE INDEX a_idx ON a(i); .indexes a% -''', out="a_idx") +''', + out="a_idx", +) # this does not seem to output anything test('.sha3sum') -test(''' +test( + ''' .mode jsonlines SELECT 42,43; -''', out='{"42":42,"43":43}') +''', + 
out='{"42":42,"43":43}', +) -test(''' +test( + ''' .mode csv .separator XX SELECT 42,43; -''', out="42XX43") +''', + out="42XX43", +) -test(''' +test( + ''' .timer on SELECT NULL; -''', out="Run Time (s):") +''', + out="Run Time (s):", +) -test(''' +test( + ''' .scanstats on SELECT NULL; -''', err='scanstats') +''', + err='scanstats', +) test('.trace %s\n; SELECT 42;' % tf(), err='sqlite3_trace_v2') outfile = tf() -test(''' +test( + ''' .mode csv .output %s SELECT 42; -''' % outfile) -outstr = open(outfile,'rb').read() +''' + % outfile +) +outstr = open(outfile, 'rb').read() if b'42' not in outstr: - raise Exception('.output test failed') + raise Exception('.output test failed') # issue 6204 -test(''' +test( + ''' .output foo.txt select * from range(2049); -''') +''' +) outfile = tf() -test(''' +test( + ''' .once %s SELECT 43; -''' % outfile) -outstr = open(outfile,'rb').read() +''' + % outfile +) +outstr = open(outfile, 'rb').read() if b'43' not in outstr: - raise Exception('.once test failed') + raise Exception('.once test failed') # This somehow does not log nor fail. works for me. -test(''' +test( + ''' .log %s SELECT 42; .log off -''' % tf()) +''' + % tf() +) -test(''' +test( + ''' .mode ascii SELECT NULL, 42, 'fourty-two', 42.0; -''', out='fourty-two') +''', + out='fourty-two', +) -test(''' +test( + ''' .mode csv SELECT NULL, 42, 'fourty-two', 42.0; -''', out=',fourty-two,') +''', + out=',fourty-two,', +) -test(''' +test( + ''' .mode column .width 10 10 10 10 SELECT NULL, 42, 'fourty-two', 42.0; -''', out=' fourty-two ') +''', + out=' fourty-two ', +) -test(''' +test( + ''' .mode html SELECT NULL, 42, 'fourty-two', 42.0; -''', out='fourty-two') +''', + out='fourty-two', +) # FIXME sqlite3_column_blob # test(''' @@ -390,15 +508,21 @@ def tf(): # SELECT NULL, 42, 'fourty-two', 42.0; # ''', out='fourty-two') -test(''' +test( + ''' .mode line SELECT NULL, 42, 'fourty-two' x, 42.0; -''', out='x = fourty-two') +''', + out='x = fourty-two', +) -test(''' +test( + ''' .mode list SELECT NULL, 42, 'fourty-two', 42.0; -''', out='|fourty-two|') +''', + out='|fourty-two|', +) # FIXME sqlite3_column_blob and %! 
format specifier # test(''' @@ -406,16 +530,20 @@ def tf(): # SELECT NULL, 42, 'fourty-two', 42.0; # ''', out='fourty-two') -test(''' +test( + ''' .mode tabs SELECT NULL, 42, 'fourty-two', 42.0; -''', out='fourty-two') +''', + out='fourty-two', +) db1 = tf() db2 = tf() -test(''' +test( + ''' .open %s CREATE TABLE t1 (i INTEGER); INSERT INTO t1 VALUES (42); @@ -424,55 +552,80 @@ def tf(): INSERT INTO t2 VALUES (43); .open %s SELECT * FROM t1; -''' % (db1, db2, db1), out='42') +''' + % (db1, db2, db1), + out='42', +) # open file that is not a database duckdb_nonsense_db = 'duckdbtest_nonsensedb.db' with open(duckdb_nonsense_db, 'w+') as f: - f.write('blablabla') + f.write('blablabla') test('', err='not a valid DuckDB database file', extra_commands=[duckdb_nonsense_db]) os.remove(duckdb_nonsense_db) # enable_profiling doesn't result in any output -test(''' +test( + ''' PRAGMA enable_profiling -''', err="") +''', + err="", +) # only when we follow it up by an actual query does something get printed to the terminal -test(''' +test( + ''' PRAGMA enable_profiling; SELECT 42; -''', out="42", err="Query Profiling Information") +''', + out="42", + err="Query Profiling Information", +) # escapes in query profiling -test(""" +test( + """ PRAGMA enable_profiling=json; CREATE TABLE "foo"("hello world" INT); SELECT "hello world", '\r\t\n\b\f\\' FROM "foo"; -""", err="""SELECT \\"hello world\\", '\\r\\t\\n\\b\\f\\\\' FROM \\"foo""") +""", + err="""SELECT \\"hello world\\", '\\r\\t\\n\\b\\f\\\\' FROM \\"foo""", +) test('.system echo 42', out="42") test('.shell echo 42', out="42") # query profiling that includes the optimizer -test(""" +test( + """ PRAGMA enable_profiling=query_tree_optimizer; SELECT 42; -""", out="42", err="Optimizer") +""", + out="42", + err="Optimizer", +) # detailed also includes optimizer -test(""" +test( + """ PRAGMA enable_profiling; PRAGMA profiling_mode=detailed; SELECT 42; -""", out="42", err="Optimizer") +""", + out="42", + err="Optimizer", +) # even in json output mode -test(""" +test( + """ PRAGMA enable_profiling=json; PRAGMA profiling_mode=detailed; SELECT 42; -""", out="42", err="optimizer") +""", + out="42", + err="optimizer", +) # this fails because db_config is missing # test(''' @@ -489,57 +642,72 @@ def tf(): # ''' % tempfile.mktemp()) - test('.databases', out='memory') # .dump test -test(''' +test( + ''' CREATE TABLE a (i INTEGER); .changes off INSERT INTO a VALUES (42); .dump -''', 'CREATE TABLE a(i INTEGER)') +''', + 'CREATE TABLE a(i INTEGER)', +) -test(''' +test( + ''' CREATE TABLE a (i INTEGER); .changes off INSERT INTO a VALUES (42); .dump -''', 'COMMIT') +''', + 'COMMIT', +) # .dump a specific table -test(''' +test( + ''' CREATE TABLE a (i INTEGER); .changes off INSERT INTO a VALUES (42); .dump a -''', 'CREATE TABLE a(i INTEGER);') +''', + 'CREATE TABLE a(i INTEGER);', +) # .dump LIKE -test(''' +test( + ''' CREATE TABLE a (i INTEGER); .changes off INSERT INTO a VALUES (42); .dump a% -''', 'CREATE TABLE a(i INTEGER);') +''', + 'CREATE TABLE a(i INTEGER);', +) # more types, tables and views -test(''' +test( + ''' CREATE TABLE a (d DATE, k FLOAT, t TIMESTAMP); CREATE TABLE b (c INTEGER); .changes off INSERT INTO a VALUES (DATE '1992-01-01', 0.3, NOW()); INSERT INTO b SELECT * FROM range(0,10); .dump -''', 'CREATE TABLE a(d DATE, k FLOAT, t TIMESTAMP);') +''', + 'CREATE TABLE a(d DATE, k FLOAT, t TIMESTAMP);', +) # import/export database target_dir = 'duckdb_shell_test_export_dir' try: - shutil.rmtree(target_dir) + shutil.rmtree(target_dir) except: - pass -test(''' 
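# For reference, the test() harness above reduces to this subprocess pattern
# (a sketch; './duckdb' stands in for sys.argv[1], the shell binary under test):
import subprocess

proc = subprocess.run(
    ['./duckdb', '--batch', '-init', '/dev/null'],
    input=b"select 'asdf' as a;",
    stdout=subprocess.PIPE,
    stderr=subprocess.PIPE,
)
assert 'asdf' in proc.stdout.decode('utf8')  # the same check test(..., out='asdf') makes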
+ pass +test( + ''' .mode csv .changes off CREATE TABLE integers(i INTEGER); @@ -551,7 +719,10 @@ def tf(): DROP TABLE integers2; IMPORT DATABASE '%s'; SELECT SUM(i)*MAX(i) FROM integers JOIN integers2 USING (i); -''' % (target_dir, target_dir), '10197') +''' + % (target_dir, target_dir), + '10197', +) shutil.rmtree(target_dir) @@ -559,29 +730,38 @@ def tf(): duckdb_nonsensecsv = 'duckdbtest_nonsensecsv.csv' with open(duckdb_nonsensecsv, 'wb+') as f: - f.write(b'\xFF\n') -test(''' + f.write(b'\xFF\n') +test( + ''' .nullvalue NULL CREATE TABLE test(i INTEGER); .import duckdbtest_nonsensecsv.csv test SELECT * FROM test; -''', out="NULL") +''', + out="NULL", +) os.remove(duckdb_nonsensecsv) # .mode latex -test(''' +test( + ''' .mode latex CREATE TABLE a (I INTEGER); .changes off INSERT INTO a VALUES (42); SELECT * FROM a; -''', '\\begin{tabular}') +''', + '\\begin{tabular}', +) # .mode trash -test(''' +test( + ''' .mode trash SELECT 1; -''', '') +''', + '', +) # dump blobs: FIXME # test(''' @@ -603,310 +783,417 @@ def tf(): # test that sqlite3_complete works somewhat correctly -test('''/* +test( + '''/* ; */ select 42; -''', out='42') +''', + out='42', +) -test('''-- this is a comment ; +test( + '''-- this is a comment ; select 42; -''', out='42') +''', + out='42', +) -test('''--;;;;;; +test( + '''--;;;;;; select 42; -''', out='42') +''', + out='42', +) test('/* ;;;;;; */ select 42;', out='42') # sqlite udfs -test(''' +test( + ''' SELECT writefile(); -''', err='wrong number of arguments to function writefile') +''', + err='wrong number of arguments to function writefile', +) -test(''' +test( + ''' SELECT writefile('hello'); -''', err='wrong number of arguments to function writefile') +''', + err='wrong number of arguments to function writefile', +) -test(''' +test( + ''' SELECT writefile('duckdbtest_writefile', 'hello'); -''') +''' +) test_writefile = 'duckdbtest_writefile' if not os.path.exists(test_writefile): - raise Exception(f"Failed to write file {test_writefile}"); + raise Exception(f"Failed to write file {test_writefile}") with open(test_writefile, 'r') as f: - text = f.read() + text = f.read() if text != 'hello': - raise Exception("Incorrect contents for test writefile") + raise Exception("Incorrect contents for test writefile") os.remove(test_writefile) -test(''' +test( + ''' SELECT lsmode(1) AS lsmode; -''', out='lsmode') +''', + out='lsmode', +) # test auto-complete -test(""" +test( + """ CALL sql_auto_complete('SEL') -""", out="SELECT" +""", + out="SELECT", ) -test(""" +test( + """ CREATE TABLE my_table(my_column INTEGER); SELECT * FROM sql_auto_complete('SELECT my_') LIMIT 1; -""", out="my_column" +""", + out="my_column", ) -test(""" +test( + """ CREATE TABLE my_table(my_column INTEGER); SELECT * FROM sql_auto_complete('SELECT my_column FROM my_') LIMIT 1; -""", out="my_table" +""", + out="my_table", ) -test(""" +test( + """ CREATE TABLE my_table(my_column INTEGER); SELECT * FROM sql_auto_complete('SELECT my_column FROM my_table WH') LIMIT 1; -""", out="WHERE" +""", + out="WHERE", ) -test(""" +test( + """ CREATE TABLE my_table(my_column INTEGER); SELECT * FROM sql_auto_complete('INS') LIMIT 1; -""", out="INSERT" +""", + out="INSERT", ) -test(""" +test( + """ CREATE TABLE my_table(my_column INTEGER); SELECT * FROM sql_auto_complete('INSERT IN') LIMIT 1; -""", out="INTO" +""", + out="INTO", ) -test(""" +test( + """ CREATE TABLE my_table(my_column INTEGER); SELECT * FROM sql_auto_complete('INSERT INTO my_t') LIMIT 1; -""", out="my_table" +""", + out="my_table", ) -test(""" +test( + 
""" CREATE TABLE my_table(my_column INTEGER); SELECT * FROM sql_auto_complete('INSERT INTO my_table VAL') LIMIT 1; -""", out="VALUES" +""", + out="VALUES", ) -test(""" +test( + """ CREATE TABLE my_table(my_column INTEGER); SELECT * FROM sql_auto_complete('DEL') LIMIT 1; -""", out="DELETE" +""", + out="DELETE", ) -test(""" +test( + """ CREATE TABLE my_table(my_column INTEGER); SELECT * FROM sql_auto_complete('DELETE F') LIMIT 1; -""", out="FROM" +""", + out="FROM", ) -test(""" +test( + """ CREATE TABLE my_table(my_column INTEGER); SELECT * FROM sql_auto_complete('DELETE FROM m') LIMIT 1; -""", out="my_table" +""", + out="my_table", ) -test(""" +test( + """ CREATE TABLE my_table(my_column INTEGER); SELECT * FROM sql_auto_complete('DELETE FROM my_table WHERE m') LIMIT 1; -""", out="my_column" +""", + out="my_column", ) -test(""" +test( + """ CREATE TABLE my_table(my_column INTEGER); SELECT * FROM sql_auto_complete('U') LIMIT 1; -""", out="UPDATE" +""", + out="UPDATE", ) -test(""" +test( + """ CREATE TABLE my_table(my_column INTEGER); SELECT * FROM sql_auto_complete('UPDATE m') LIMIT 1; -""", out="my_table" +""", + out="my_table", ) -test(""" +test( + """ CREATE TABLE my_table(my_column INTEGER); SELECT * FROM sql_auto_complete('UPDATE "m') LIMIT 1; -""", out="my_table" +""", + out="my_table", ) -test(""" +test( + """ CREATE TABLE my_table(my_column INTEGER); SELECT * FROM sql_auto_complete('UPDATE my_table SET m') LIMIT 1; -""", out="my_column" +""", + out="my_column", ) -test(""" +test( + """ CREATE TABLE "Funky Table With Spaces"(my_column INTEGER); SELECT * FROM sql_auto_complete('SELECT * FROM F') LIMIT 1; -""", out="\"Funky Table With Spaces\"" +""", + out="\"Funky Table With Spaces\"", ) -test(""" +test( + """ CREATE TABLE "Funky Table With Spaces"("Funky Column" int); SELECT * FROM sql_auto_complete('select f') LIMIT 1; -""", out="\"Funky Column\"" +""", + out="\"Funky Column\"", ) -test(""" +test( + """ CREATE TABLE "Funky Table With Spaces"("Funky Column" int); SELECT * FROM sql_auto_complete('select "Funky Column" FROM f') LIMIT 1; -""", out="\"Funky Table With Spaces\"" +""", + out="\"Funky Table With Spaces\"", ) # semicolon -test(""" +test( + """ SELECT * FROM sql_auto_complete('SELECT 42; SEL') LIMIT 1; -""", out="SELECT" +""", + out="SELECT", ) # comments -test(""" +test( + """ SELECT * FROM sql_auto_complete('--SELECT * FROM SEL') LIMIT 1; -""", out="SELECT" +""", + out="SELECT", ) # scalar functions -test(""" +test( + """ SELECT * FROM sql_auto_complete('SELECT regexp_m') LIMIT 1; -""", out="regexp_matches" +""", + out="regexp_matches", ) # aggregate functions -test(""" +test( + """ SELECT * FROM sql_auto_complete('SELECT approx_c') LIMIT 1; -""", out="approx_count_distinct" +""", + out="approx_count_distinct", ) # built-in views -test(""" +test( + """ SELECT * FROM sql_auto_complete('SELECT * FROM sqlite_ma') LIMIT 1; -""", out="sqlite_master" +""", + out="sqlite_master", ) # table functions -test(""" +test( + """ SELECT * FROM sql_auto_complete('SELECT * FROM read_csv_a') LIMIT 1; -""", out="read_csv_auto" +""", + out="read_csv_auto", ) -test(""" +test( + """ CREATE TABLE partsupp(ps_suppkey int); CREATE TABLE supplier(s_suppkey int); CREATE TABLE nation(n_nationkey int); SELECT * FROM sql_auto_complete('DROP TABLE na') LIMIT 1; -""", out="nation" +""", + out="nation", ) -test(""" +test( + """ CREATE TABLE partsupp(ps_suppkey int); CREATE TABLE supplier(s_suppkey int); CREATE TABLE nation(n_nationkey int); SELECT * FROM sql_auto_complete('SELECT s_supp') LIMIT 1; -""", 
out="s_suppkey" +""", + out="s_suppkey", ) # joins -test(""" +test( + """ CREATE TABLE partsupp(ps_suppkey int); CREATE TABLE supplier(s_suppkey int); CREATE TABLE nation(n_nationkey int); SELECT * FROM sql_auto_complete('SELECT * FROM partsupp JOIN supp') LIMIT 1; -""", out="supplier" +""", + out="supplier", ) -test(""" +test( + """ CREATE TABLE partsupp(ps_suppkey int); CREATE TABLE supplier(s_suppkey int); CREATE TABLE nation(n_nationkey int); .mode csv SELECT l,l FROM sql_auto_complete('SELECT * FROM partsupp JOIN supplier ON (s_supp') t(l) LIMIT 1; -""", out="s_suppkey,s_suppkey" +""", + out="s_suppkey,s_suppkey", ) -test(""" +test( + """ CREATE TABLE partsupp(ps_suppkey int); CREATE TABLE supplier(s_suppkey int); CREATE TABLE nation(n_nationkey int); SELECT * FROM sql_auto_complete('SELECT * FROM partsupp JOIN supplier USING (ps_') LIMIT 1; -""", out="ps_suppkey" +""", + out="ps_suppkey", ) -test(""" +test( + """ SELECT * FROM sql_auto_complete('SELECT * FR') LIMIT 1; -""", out="FROM" +""", + out="FROM", ) -test(""" +test( + """ CREATE TABLE MyTable(MyColumn Varchar); SELECT * FROM sql_auto_complete('SELECT My') LIMIT 1; -""", out="MyColumn" +""", + out="MyColumn", ) -test(""" +test( + """ CREATE TABLE MyTable(MyColumn Varchar); SELECT * FROM sql_auto_complete('SELECT MyColumn FROM My') LIMIT 1; -""", out="MyTable" +""", + out="MyTable", ) # duckbox renderer displays the number of rows if there are none -test(''' +test( + ''' .mode duckbox select 42 limit 0; -''', out='0 rows') +''', + out='0 rows', +) # #5411 - with maxrows=2, we still display all 4 rows (hiding them would take up more space) -test(''' +test( + ''' .maxrows 2 select * from range(4); -''', out='1') +''', + out='1', +) outfile = tf() -test(''' +test( + ''' .maxrows 2 .output %s SELECT * FROM range(100); -''' % outfile) -outstr = open(outfile,'rb').read().decode('utf8') +''' + % outfile +) +outstr = open(outfile, 'rb').read().decode('utf8') if '50' not in outstr: - raise Exception('.output test failed') + raise Exception('.output test failed') # we always display all columns when outputting to a file columns = ', '.join([str(x) for x in range(100)]) outfile = tf() -test(''' +test( + ''' .output %s SELECT %s -''' % (outfile, columns)) -outstr = open(outfile,'rb').read().decode('utf8') +''' + % (outfile, columns) +) +outstr = open(outfile, 'rb').read().decode('utf8') if '99' not in outstr: - raise Exception('.output test failed') + raise Exception('.output test failed') # columnar mode -test(''' +test( + ''' .col select * from range(4); -''', out='Row 1') +''', + out='Row 1', +) columns = ','.join(["'MyValue" + str(x) + "'" for x in range(100)]) -test(f''' +test( + f''' .col select {columns}; -''', out='MyValue50') +''', + out='MyValue50', +) -test(f''' +test( + f''' .col select {columns} from range(1000) -''', out='100 columns') +''', + out='100 columns', +) # test null-byte rendering test('select varchar from test_all_types();', out='goo\\0se') @@ -920,18 +1207,20 @@ def tf(): temp_file = os.path.join(temp_dir, 'myfile') os.mkdir(temp_dir) with open(temp_file, 'w+') as f: - f.write('hello world') + f.write('hello world') -test(f''' +test( + f''' SET temp_directory='{temp_dir}'; PRAGMA memory_limit='2MB'; CREATE TABLE t1 AS SELECT * FROM range(1000000); -''') +''' +) # make sure the temp directory or existing files are not deleted assert os.path.isdir(temp_dir) with open(temp_file, 'r') as f: - assert f.read() == "hello world" + assert f.read() == "hello world" # all other files are gone assert os.listdir(temp_dir) == 
['myfile'] @@ -940,50 +1229,58 @@ def tf(): os.rmdir(temp_dir) # now use a new temp directory -test(f''' +test( + f''' SET temp_directory='{temp_dir}'; PRAGMA memory_limit='2MB'; CREATE TABLE t1 AS SELECT * FROM range(1000000); -''') +''' +) # make sure the temp directory is deleted assert not os.path.isdir(temp_dir) if os.name != 'nt': - shell_test_dir = 'shell_test_dir' - try: - os.mkdir(shell_test_dir) - except: - pass - try: - os.mkdir(os.path.join(shell_test_dir, 'extra_path')) - except: - pass - - base_files = ['extra.parquet', 'extra.file'] - for fname in base_files: - with open(os.path.join(shell_test_dir, fname), 'w+') as f: - f.write('') - - test(""" + shell_test_dir = 'shell_test_dir' + try: + os.mkdir(shell_test_dir) + except: + pass + try: + os.mkdir(os.path.join(shell_test_dir, 'extra_path')) + except: + pass + + base_files = ['extra.parquet', 'extra.file'] + for fname in base_files: + with open(os.path.join(shell_test_dir, fname), 'w+') as f: + f.write('') + + test( + """ CREATE TABLE MyTable(MyColumn Varchar); SELECT * FROM sql_auto_complete('SELECT * FROM ''shell_test') LIMIT 1; - """, out="shell_test_dir/" - ) + """, + out="shell_test_dir/", + ) - test(""" + test( + """ CREATE TABLE MyTable(MyColumn Varchar); SELECT * FROM sql_auto_complete('SELECT * FROM ''shell_test_dir/extra') LIMIT 1; - """, out="extra_path/" - ) + """, + out="extra_path/", + ) - test(""" + test( + """ CREATE TABLE MyTable(MyColumn Varchar); SELECT * FROM sql_auto_complete('SELECT * FROM ''shell_test_dir/extra.par') LIMIT 1; - """, out="extra.parquet" - ) + """, + out="extra.parquet", + ) - shutil.rmtree(shell_test_dir) + shutil.rmtree(shell_test_dir) # test backwards compatibility test('.open test/storage/bc/db_dev.db', err='older development version') @@ -1001,101 +1298,123 @@ def tf(): test("select sha3('hello world this is a long string');", out='D4') if os.name != 'nt': - test(''' + test( + ''' create table mytable as select * from read_csv('/dev/stdin', columns=STRUCT_PACK(foo := 'INTEGER', bar := 'INTEGER', baz := 'VARCHAR'), AUTO_DETECT='false' ); select * from mytable limit 1;''', - extra_commands=['-csv', ':memory:'], - input_file='test/sql/copy/csv/data/test/test.csv', - out='''foo,bar,baz -0,0," test"''') - - test(''' + extra_commands=['-csv', ':memory:'], + input_file='test/sql/copy/csv/data/test/test.csv', + out='''foo,bar,baz +0,0," test"''', + ) + + test( + ''' create table mytable as select * from read_csv_auto('/dev/stdin'); select * from mytable limit 1; ''', - extra_commands=['-csv', ':memory:'], - input_file='test/sql/copy/csv/data/test/test.csv', - out='''column0,column1,column2 -0,0," test"''') - - test('''create table mytable as select * from + extra_commands=['-csv', ':memory:'], + input_file='test/sql/copy/csv/data/test/test.csv', + out='''column0,column1,column2 +0,0," test"''', + ) + + test( + '''create table mytable as select * from read_csv_auto('/dev/stdin'); select channel,i_brand_id,sum_sales,number_sales from mytable; ''', - extra_commands=['-csv', ':memory:'], - input_file='data/csv/tpcds_14.csv', - out='''web,8006004,844.21,21''') + extra_commands=['-csv', ':memory:'], + input_file='data/csv/tpcds_14.csv', + out='''web,8006004,844.21,21''', + ) - test('''create table mytable as select * from + test( + '''create table mytable as select * from read_ndjson_objects('/dev/stdin'); select * from mytable; ''', - extra_commands=['-list', ':memory:'], - input_file='data/json/example_rn.ndjson', - out='''json + extra_commands=['-list', ':memory:'], + 
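# The two JSON readers exercised above differ only in how much structure they
# impose; a sketch via the Python API (same data file, json support assumed
# to be available in the build):
import duckdb

con = duckdb.connect()
raw = con.sql("select * from read_ndjson_objects('data/json/example_rn.ndjson')")
print(raw.columns)    # ['json'] -- one JSON document per row
typed = con.sql("select * from read_json_auto('data/json/example_rn.ndjson')")
print(typed.columns)  # ['id', 'name'] -- schema inferred from the object keys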
input_file='data/json/example_rn.ndjson', + out='''json {"id":1,"name":"O Brother, Where Art Thou?"} {"id":2,"name":"Home for the Holidays"} {"id":3,"name":"The Firm"} {"id":4,"name":"Broadcast News"} -{"id":5,"name":"Raising Arizona"}''') +{"id":5,"name":"Raising Arizona"}''', + ) - test('''create table mytable as select * from + test( + '''create table mytable as select * from read_ndjson_objects('/dev/stdin'); select * from mytable; ''', - extra_commands=['-list', ':memory:'], - input_file='data/json/example_rn.ndjson', - out='''json + extra_commands=['-list', ':memory:'], + input_file='data/json/example_rn.ndjson', + out='''json {"id":1,"name":"O Brother, Where Art Thou?"} {"id":2,"name":"Home for the Holidays"} {"id":3,"name":"The Firm"} {"id":4,"name":"Broadcast News"} -{"id":5,"name":"Raising Arizona"}''') +{"id":5,"name":"Raising Arizona"}''', + ) - test('''create table mytable as select * from + test( + '''create table mytable as select * from read_json_auto('/dev/stdin'); select * from mytable; ''', - extra_commands=['-list', ':memory:'], - input_file='data/json/example_rn.ndjson', - out='''id|name + extra_commands=['-list', ':memory:'], + input_file='data/json/example_rn.ndjson', + out='''id|name 1|O Brother, Where Art Thou? 2|Home for the Holidays 3|The Firm 4|Broadcast News -5|Raising Arizona''') +5|Raising Arizona''', + ) - test(''' + test( + ''' COPY (SELECT 42) TO '/dev/stdout' WITH (FORMAT 'csv'); ''', - extra_commands=['-csv', ':memory:'], - out='''42''') + extra_commands=['-csv', ':memory:'], + out='''42''', + ) - test(''' + test( + ''' COPY (SELECT 42) TO stdout WITH (FORMAT 'csv'); ''', - extra_commands=['-csv', ':memory:'], - out='''42''') + extra_commands=['-csv', ':memory:'], + out='''42''', + ) - test(''' + test( + ''' COPY (SELECT 42) TO '/dev/stderr' WITH (FORMAT 'csv'); ''', - extra_commands=['-csv', ':memory:'], - err='''42''') + extra_commands=['-csv', ':memory:'], + err='''42''', + ) - test(''' + test( + ''' copy (select 42) to '/dev/stdout' ''', - out='''42''') + out='''42''', + ) - test(''' + test( + ''' select list(concat('thisisalongstring', range::VARCHAR)) i from range(10000) ''', - out='''thisisalongstring''') + out='''thisisalongstring''', + ) - test("copy (select * from range(10000) tbl(i)) to '/dev/stdout' (format csv)", out='9999', output_file=tf()) + test("copy (select * from range(10000) tbl(i)) to '/dev/stdout' (format csv)", out='9999', output_file=tf()) diff --git a/tools/swift/create_package.py b/tools/swift/create_package.py index 7adf198d1d9b..052f52f5d8a0 100644 --- a/tools/swift/create_package.py +++ b/tools/swift/create_package.py @@ -18,23 +18,23 @@ # path to target base_dir = os.path.abspath(sys.argv[1] if len(sys.argv) > 1 else os.getcwd()) -package_dir = os.path.join(base_dir, repo_name) +package_dir = os.path.join(base_dir, repo_name) target_dir = os.path.join(package_dir, 'Sources', swift_target_name) includes_dir = os.path.join(target_dir, 'include') src_dir = os.path.join(target_dir, src_dir_name) # Prepare target directory -Path(target_dir).mkdir(parents=True,exist_ok=True) -Path(includes_dir).mkdir(parents=True,exist_ok=True) +Path(target_dir).mkdir(parents=True, exist_ok=True) +Path(includes_dir).mkdir(parents=True, exist_ok=True) # build package source files os.chdir(base_dir) os.chdir(os.path.join('..', '..')) sys.path.append('scripts') import package_build + # fresh build - copy over all of the files -(source_list, include_list, _) = package_build.build_package( - src_dir, extensions, 32, src_dir_name) +(source_list, 
include_list, _) = package_build.build_package(src_dir, extensions, 32, src_dir_name) # standardise paths source_list = [os.path.relpath(x, target_dir) if os.path.isabs(x) else x for x in source_list] include_list = [os.path.join(src_dir_name, x) for x in include_list] @@ -44,13 +44,13 @@ # copy umbrella header to path SPM expects (auto .modulemap) header_file_src = os.path.join(src_dir, 'src', 'include', 'duckdb.h') -header_file_dest= os.path.join(includes_dir, 'duckdb.h') +header_file_dest = os.path.join(includes_dir, 'duckdb.h') shutil.copyfile(header_file_src, header_file_dest) source_list_strs = ['"' + x + '",' for x in source_list] include_list_strs = ['.headerSearchPath("' + x + '"),' for x in include_list] define_list_strs = ['.define("' + x + '"),' for x in define_list] -src_line_prefix = '\n ' # indents eight spaces +src_line_prefix = '\n ' # indents eight spaces content = { 'source_list': src_line_prefix.join(source_list_strs), @@ -63,4 +63,4 @@ src = Template(f.read()) result = src.substitute(content) with open(package_manifest_path, 'w') as f: - f.write(result) \ No newline at end of file + f.write(result) diff --git a/tools/upload-s3.py b/tools/upload-s3.py index 5253029ece9a..03d9757aed90 100644 --- a/tools/upload-s3.py +++ b/tools/upload-s3.py @@ -2,7 +2,7 @@ import sys import os -if (len(sys.argv) < 3): +if len(sys.argv) < 3: print("Usage: [prefix] [filename1] [filename2] ... ") exit(1) @@ -10,12 +10,13 @@ def git_rev_hash(): return subprocess.check_output(['git', 'rev-parse', 'HEAD']).decode("utf-8").strip() + prefix = sys.argv[1].strip() # Hannes controls this web server # Files are served at https://download.duckdb.org/... -secret_key=os.getenv('DAV_PASSWORD') +secret_key = os.getenv('DAV_PASSWORD') if secret_key is None: print("Can't find DAV_PASSWORD in env ") exit(2) @@ -34,23 +35,31 @@ def git_rev_hash(): folder = 'rev/%s/%s' % (git_hash, prefix) + def curlcmd(cmd, path): - p = subprocess.Popen(['curl','--retry', '10'] + cmd + ['http://duckdb:%s@dav10635776.mywebdav.de/duckdb-download/%s' % (secret_key, path)], stdout=subprocess.PIPE, stderr=subprocess.PIPE) + p = subprocess.Popen( + ['curl', '--retry', '10'] + + cmd + + ['http://duckdb:%s@dav10635776.mywebdav.de/duckdb-download/%s' % (secret_key, path)], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) output, err = p.communicate() rc = p.returncode - if (p.returncode != 0): + if p.returncode != 0: print(err) exit(4) + # create dirs, no recursive create supported by webdav f_p = '' for p in folder.split('/'): f_p += p + '/' - curlcmd(['-X','MKCOL'], f_p) - + curlcmd(['-X', 'MKCOL'], f_p) + for f in files: base = os.path.basename(f) key = '%s/%s' % (folder, base) print("%s\t->\thttps://download.duckdb.org/%s " % (f, key)) - curlcmd(['-T',f], key) + curlcmd(['-T', f], key) From de1d68e503cd80f204a0cb9c5921b9cba7921d68 Mon Sep 17 00:00:00 2001 From: Tishj Date: Mon, 24 Jul 2023 15:15:46 +0200 Subject: [PATCH 05/11] tidy fixes --- tools/pythonpkg/src/pandas/bind.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tools/pythonpkg/src/pandas/bind.cpp b/tools/pythonpkg/src/pandas/bind.cpp index 152c01459339..a23e3949cb26 100644 --- a/tools/pythonpkg/src/pandas/bind.cpp +++ b/tools/pythonpkg/src/pandas/bind.cpp @@ -9,7 +9,8 @@ namespace { struct PandasBindColumn { public: - PandasBindColumn(py::handle name, py::handle type, py::object column) : name(name), type(type), handle(column) { + PandasBindColumn(py::handle name, py::handle type, py::object column) + : name(name), type(type), 
handle(std::move(column)) { } public: @@ -43,7 +44,7 @@ struct PandasDataFrameBind { }; // namespace -static LogicalType BindColumn(PandasBindColumn column_p, PandasColumnBindData &bind_data, +static LogicalType BindColumn(PandasBindColumn &column_p, PandasColumnBindData &bind_data, const ClientContext &context) { LogicalType column_type; auto &column = column_p.handle; From 70b23044a835d358096d4c4948bbda44e199f149 Mon Sep 17 00:00:00 2001 From: Tishj Date: Mon, 24 Jul 2023 15:26:10 +0200 Subject: [PATCH 06/11] extend tests slightly --- tools/pythonpkg/src/pandas/bind.cpp | 3 ++- .../tests/fast/pandas/test_copy_on_write.py | 25 +++++++++++-------- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/tools/pythonpkg/src/pandas/bind.cpp b/tools/pythonpkg/src/pandas/bind.cpp index a23e3949cb26..f0b80dbf9020 100644 --- a/tools/pythonpkg/src/pandas/bind.cpp +++ b/tools/pythonpkg/src/pandas/bind.cpp @@ -133,7 +133,8 @@ void Pandas::Bind(const ClientContext &context, py::handle df_p, vector Date: Mon, 24 Jul 2023 15:45:22 +0200 Subject: [PATCH 07/11] format --- .../tests/fast/pandas/test_copy_on_write.py | 28 +++++++++++++------ 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/tools/pythonpkg/tests/fast/pandas/test_copy_on_write.py b/tools/pythonpkg/tests/fast/pandas/test_copy_on_write.py index e4d31d059810..d6194984a5a0 100644 --- a/tools/pythonpkg/tests/fast/pandas/test_copy_on_write.py +++ b/tools/pythonpkg/tests/fast/pandas/test_copy_on_write.py @@ -3,6 +3,7 @@ import pandas import datetime + # Make sure the variable get's properly reset even in case of error @pytest.fixture(autouse=True) def scoped_copy_on_write_setting(): @@ -13,21 +14,30 @@ def scoped_copy_on_write_setting(): pandas.options.mode.copy_on_write = old_value return + def convert_to_result(col): - return [(x,) for x in col] + return [(x,) for x in col] + class TestCopyOnWrite(object): - @pytest.mark.parametrize('col', [ - ['a', 'b', 'this is a long string'], - [1.2334, None, 234.12], - [123234, -213123, 2324234], - [datetime.date(1990, 12, 7), None, datetime.date(1940, 1, 13)], - [datetime.datetime(2012, 6, 21, 13, 23, 45, 328), None] - ]) + @pytest.mark.parametrize( + 'col', + [ + ['a', 'b', 'this is a long string'], + [1.2334, None, 234.12], + [123234, -213123, 2324234], + [datetime.date(1990, 12, 7), None, datetime.date(1940, 1, 13)], + [datetime.datetime(2012, 6, 21, 13, 23, 45, 328), None], + ], + ) def test_copy_on_write(self, col): assert pandas.options.mode.copy_on_write == True con = duckdb.connect() - df_in = pandas.DataFrame({'numbers': col,}) + df_in = pandas.DataFrame( + { + 'numbers': col, + } + ) rel = con.sql('select * from df_in') res = rel.fetchall() expected = convert_to_result(col) From 1befa7d8cf04f28a8ddfab8e4bb2322ecceb2caa Mon Sep 17 00:00:00 2001 From: Pedro Holanda Date: Mon, 24 Jul 2023 16:21:18 +0200 Subject: [PATCH 08/11] [ADBC] Add support for ingestion modes --- src/common/adbc/adbc.cpp | 22 ++++++++++++++++++---- test/api/adbc/test_adbc.cpp | 10 ++++++++++ 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/src/common/adbc/adbc.cpp b/src/common/adbc/adbc.cpp index b84be532a648..26caec6a2d80 100644 --- a/src/common/adbc/adbc.cpp +++ b/src/common/adbc/adbc.cpp @@ -52,12 +52,14 @@ duckdb_adbc::AdbcStatusCode duckdb_adbc_init(size_t count, struct duckdb_adbc::A namespace duckdb_adbc { +enum IngestionMode { CREATE = 0, APPEND = 1 }; struct DuckDBAdbcStatementWrapper { ::duckdb_connection connection; ::duckdb_arrow result; ::duckdb_prepared_statement statement; char 
*ingestion_table_name; ArrowArrayStream *ingestion_stream; + IngestionMode ingestion_mode = IngestionMode::CREATE; }; static AdbcStatusCode QueryInternal(struct AdbcConnection *connection, struct ArrowArrayStream *out, const char *query, struct AdbcError *error); @@ -428,7 +430,7 @@ void stream_schema(uintptr_t factory_ptr, duckdb::ArrowSchemaWrapper &schema) { } AdbcStatusCode Ingest(duckdb_connection connection, const char *table_name, struct ArrowArrayStream *input, - struct AdbcError *error) { + struct AdbcError *error, IngestionMode ingestion_mode) { auto status = SetErrorMaybe(connection, error, "Invalid connection"); if (status != ADBC_STATUS_OK) { @@ -446,12 +448,11 @@ AdbcStatusCode Ingest(duckdb_connection connection, const char *table_name, stru } auto cconn = (duckdb::Connection *)connection; - auto has_table = cconn->TableInfo(table_name); auto arrow_scan = cconn->TableFunction("arrow_scan", {duckdb::Value::POINTER((uintptr_t)input), duckdb::Value::POINTER((uintptr_t)stream_produce), duckdb::Value::POINTER((uintptr_t)get_schema)}); try { - if (!has_table) { + if (ingestion_mode == IngestionMode::CREATE) { // We create the table based on an Arrow Scanner arrow_scan->Create(table_name); } else { @@ -505,6 +506,7 @@ AdbcStatusCode StatementNew(struct AdbcConnection *connection, struct AdbcStatem statement_wrapper->result = nullptr; statement_wrapper->ingestion_stream = nullptr; statement_wrapper->ingestion_table_name = nullptr; + statement_wrapper->ingestion_mode = IngestionMode::CREATE; return ADBC_STATUS_OK; } @@ -557,7 +559,7 @@ AdbcStatusCode StatementExecuteQuery(struct AdbcStatement *statement, struct Arr if (wrapper->ingestion_stream && wrapper->ingestion_table_name) { auto stream = wrapper->ingestion_stream; wrapper->ingestion_stream = nullptr; - return Ingest(wrapper->connection, wrapper->ingestion_table_name, stream, error); + return Ingest(wrapper->connection, wrapper->ingestion_table_name, stream, error, wrapper->ingestion_mode); } auto res = duckdb_execute_prepared_arrow(wrapper->statement, &wrapper->result); @@ -643,6 +645,18 @@ AdbcStatusCode StatementSetOption(struct AdbcStatement *statement, const char *k wrapper->ingestion_table_name = strdup(value); return ADBC_STATUS_OK; } + if (strcmp(key, ADBC_INGEST_OPTION_MODE) == 0) { + if (strcmp(value, ADBC_INGEST_OPTION_MODE_CREATE) == 0) { + wrapper->ingestion_mode = IngestionMode::CREATE; + return ADBC_STATUS_OK; + } else if (strcmp(value, ADBC_INGEST_OPTION_MODE_APPEND) == 0) { + wrapper->ingestion_mode = IngestionMode::APPEND; + return ADBC_STATUS_OK; + } else { + SetError(error, "Invalid ingestion mode"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + } return ADBC_STATUS_INVALID_ARGUMENT; } diff --git a/test/api/adbc/test_adbc.cpp b/test/api/adbc/test_adbc.cpp index 8a89c939f291..60ed7cea7ed3 100644 --- a/test/api/adbc/test_adbc.cpp +++ b/test/api/adbc/test_adbc.cpp @@ -372,6 +372,9 @@ TEST_CASE("Test ADBC Transactions", "[adbc]") { REQUIRE(SUCCESS(duckdb_adbc::StatementSetOption(&adbc_statement, ADBC_INGEST_OPTION_TARGET_TABLE, table_name.c_str(), &adbc_error))); + REQUIRE(SUCCESS(duckdb_adbc::StatementSetOption(&adbc_statement, ADBC_INGEST_OPTION_MODE, + ADBC_INGEST_OPTION_MODE_APPEND, &adbc_error))); + REQUIRE(SUCCESS(duckdb_adbc::StatementBindStream(&adbc_statement, &input_data, &adbc_error))); REQUIRE(SUCCESS(duckdb_adbc::StatementExecuteQuery(&adbc_statement, nullptr, nullptr, &adbc_error))); @@ -416,6 +419,9 @@ TEST_CASE("Test ADBC Transactions", "[adbc]") { 
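	// Note on the pattern in the hunks below: because the target table already
	// exists at this point in the test, each ingestion statement is switched
	// from the default CREATE mode to ADBC_INGEST_OPTION_MODE_APPEND before
	// the Arrow stream is bound and executed.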
REQUIRE(SUCCESS(duckdb_adbc::StatementSetOption(&adbc_statement, ADBC_INGEST_OPTION_TARGET_TABLE, table_name.c_str(), &adbc_error))); + REQUIRE(SUCCESS(duckdb_adbc::StatementSetOption(&adbc_statement, ADBC_INGEST_OPTION_MODE, + ADBC_INGEST_OPTION_MODE_APPEND, &adbc_error))); + REQUIRE(SUCCESS(duckdb_adbc::StatementBindStream(&adbc_statement, &input_data, &adbc_error))); REQUIRE(SUCCESS(duckdb_adbc::StatementExecuteQuery(&adbc_statement, nullptr, nullptr, &adbc_error))); @@ -462,6 +468,8 @@ TEST_CASE("Test ADBC Transactions", "[adbc]") { REQUIRE(SUCCESS(duckdb_adbc::StatementSetOption(&adbc_statement, ADBC_INGEST_OPTION_TARGET_TABLE, table_name.c_str(), &adbc_error))); + REQUIRE(SUCCESS(duckdb_adbc::StatementSetOption(&adbc_statement, ADBC_INGEST_OPTION_MODE, + ADBC_INGEST_OPTION_MODE_APPEND, &adbc_error))); REQUIRE(SUCCESS(duckdb_adbc::StatementBindStream(&adbc_statement, &input_data, &adbc_error))); @@ -496,6 +504,8 @@ TEST_CASE("Test ADBC Transactions", "[adbc]") { REQUIRE(SUCCESS(duckdb_adbc::StatementSetOption(&adbc_statement, ADBC_INGEST_OPTION_TARGET_TABLE, table_name.c_str(), &adbc_error))); + REQUIRE(SUCCESS(duckdb_adbc::StatementSetOption(&adbc_statement, ADBC_INGEST_OPTION_MODE, + ADBC_INGEST_OPTION_MODE_APPEND, &adbc_error))); REQUIRE(SUCCESS(duckdb_adbc::StatementBindStream(&adbc_statement, &input_data, &adbc_error))); From 97809dc76e3bce29313c812a14d56bf70d76fe92 Mon Sep 17 00:00:00 2001 From: Pedro Holanda Date: Mon, 24 Jul 2023 16:30:50 +0200 Subject: [PATCH 09/11] Thijs' comment --- src/common/adbc/adbc.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/common/adbc/adbc.cpp b/src/common/adbc/adbc.cpp index 26caec6a2d80..690c426e02e7 100644 --- a/src/common/adbc/adbc.cpp +++ b/src/common/adbc/adbc.cpp @@ -52,7 +52,7 @@ duckdb_adbc::AdbcStatusCode duckdb_adbc_init(size_t count, struct duckdb_adbc::A namespace duckdb_adbc { -enum IngestionMode { CREATE = 0, APPEND = 1 }; +enum class IngestionMode { CREATE = 0, APPEND = 1 }; struct DuckDBAdbcStatementWrapper { ::duckdb_connection connection; ::duckdb_arrow result; From d2d8d5e84a1f596dd150106744ff78245f4cdade Mon Sep 17 00:00:00 2001 From: Elliana May Date: Tue, 25 Jul 2023 11:06:39 +0800 Subject: [PATCH 10/11] update format.py to format scripts --- scripts/format.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/scripts/format.py b/scripts/format.py index 2336111159d8..aa39afdc02d7 100644 --- a/scripts/format.py +++ b/scripts/format.py @@ -14,7 +14,7 @@ cpp_format_command = 'clang-format --sort-includes=0 -style=file' cmake_format_command = 'cmake-format' extensions = ['.cpp', '.c', '.hpp', '.h', '.cc', '.hh', 'CMakeLists.txt', '.test', '.test_slow', '.test_coverage', '.benchmark', '.py'] -formatted_directories = ['src', 'benchmark', 'test', 'tools', 'examples', 'extension'] +formatted_directories = ['src', 'benchmark', 'test', 'tools', 'examples', 'extension', 'scripts'] ignored_files = ['tpch_constants.hpp', 'tpcds_constants.hpp', '_generated', 'tpce_flat_input.hpp', 'test_csv_header.hpp', 'duckdb.cpp', 'duckdb.hpp', 'json.hpp', 'sqlite3.h', 'shell.c', 'termcolor.hpp', 'test_insert_invalid.test', 'httplib.hpp', 'os_win.c', 'glob.c', 'printf.c', @@ -258,13 +258,14 @@ def format_file(f, full_path, directory, ext): with open_utf8(full_path, 'r') as f: old_text = f.read() # do not format auto-generated files - if file_is_generated(old_text): + if file_is_generated(old_text) and ext != '.py': return old_lines = old_text.split('\n') new_text = get_formatted_text(f, 
full_path, directory, ext) - new_text = new_text.replace('ARGS &&...args', 'ARGS &&... args') + if ext in ('.cpp', '.hpp'): + new_text = new_text.replace('ARGS &&...args', 'ARGS &&... args') if check_only: new_lines = new_text.split('\n') old_lines = [x for x in old_lines if '...' not in x] @@ -311,12 +312,9 @@ def format_directory(directory): os.system(cmake_format_command.replace("${FILE}", "CMakeLists.txt")) except: pass - format_directory('src') - format_directory('benchmark') - format_directory('test') - format_directory('tools') - format_directory('examples') - format_directory('extension') + + for direct in formatted_directories: + format_directory(direct) else: for full_path in changed_files: From 4cb3d7c6ba8cbafc6378aaafbb3bb53f2a0efe88 Mon Sep 17 00:00:00 2001 From: Elliana May Date: Tue, 25 Jul 2023 11:06:45 +0800 Subject: [PATCH 11/11] fix: format --- scripts/amalgamation.py | 132 ++++-- scripts/asset-upload-gha.py | 157 ++++--- scripts/asset-upload.py | 129 +++--- scripts/check_coverage.py | 29 +- scripts/ci_test.py | 4 +- scripts/create-release-notes.py | 71 ++-- scripts/exported_symbols_check.py | 77 ++-- scripts/format.py | 125 ++++-- scripts/fuzzer_helper.py | 24 +- scripts/generate_benchmarks.py | 29 +- scripts/generate_builtin_types.py | 15 +- scripts/generate_csv_header.py | 85 ++-- scripts/generate_enum_util.py | 41 +- scripts/generate_extensions_function.py | 56 +-- scripts/generate_flex.py | 67 +-- scripts/generate_functions.py | 45 +- scripts/generate_grammar.py | 71 ++-- scripts/generate_plan_storage_version.py | 36 +- scripts/generate_querygraph.py | 34 +- scripts/generate_serialization.py | 172 ++++++-- scripts/generate_storage_version.py | 38 +- scripts/generate_tpcds_results.py | 80 +++- scripts/generate_tpcds_schema.py | 61 +-- scripts/generate_vector_sizes.py | 24 +- scripts/gentpcecode.py | 226 +++++----- scripts/get_test_list.py | 59 +-- scripts/include_analyzer.py | 10 +- scripts/jdbc_maven_deploy.py | 75 ++-- scripts/merge_vcpkg_deps.py | 15 +- scripts/package_build.py | 38 +- scripts/plan_cost_runner.py | 21 +- scripts/pypi_cleanup.py | 81 ++-- scripts/python_helpers.py | 8 +- scripts/reduce_sql.py | 21 +- scripts/regression_check.py | 42 +- scripts/regression_test_python.py | 20 +- scripts/regression_test_runner.py | 43 +- scripts/regression_test_storage_size.py | 14 +- scripts/repeat_until_success.py | 14 +- scripts/run-clang-tidy.py | 501 ++++++++++++----------- scripts/run_fuzzer.py | 29 +- scripts/run_sqlancer.py | 9 +- scripts/run_test_list.py | 67 +-- scripts/run_tests_one_by_one.py | 99 ++--- scripts/runsqlsmith.py | 51 ++- scripts/test_compile.py | 89 ++-- scripts/test_vector_sizes.py | 37 +- scripts/test_zero_initialize.py | 49 ++- scripts/try_timeout.py | 3 + scripts/windows_ci.py | 12 +- 50 files changed, 1938 insertions(+), 1297 deletions(-) diff --git a/scripts/amalgamation.py b/scripts/amalgamation.py index dbb832e1a625..303c5a841a7a 100644 --- a/scripts/amalgamation.py +++ b/scripts/amalgamation.py @@ -18,7 +18,8 @@ include_dir = os.path.join('src', 'include') # files included in the amalgamated "duckdb.hpp" file -main_header_files = [os.path.join(include_dir, 'duckdb.hpp'), +main_header_files = [ + os.path.join(include_dir, 'duckdb.hpp'), os.path.join(include_dir, 'duckdb.h'), os.path.join(include_dir, 'duckdb', 'common', 'types', 'date.hpp'), os.path.join(include_dir, 'duckdb', 'common', 'adbc', 'adbc.h'), @@ -40,52 +41,66 @@ os.path.join(include_dir, 'duckdb', 'function', 'function.hpp'), os.path.join(include_dir, 'duckdb', 'function', 
'table_function.hpp'), os.path.join(include_dir, 'duckdb', 'parser', 'parsed_data', 'create_table_function_info.hpp'), - os.path.join(include_dir, 'duckdb', 'parser', 'parsed_data', 'create_copy_function_info.hpp')] + os.path.join(include_dir, 'duckdb', 'parser', 'parsed_data', 'create_copy_function_info.hpp'), +] extended_amalgamation = False if '--extended' in sys.argv: + def add_include_dir(dirpath): return [os.path.join(dirpath, x) for x in os.listdir(dirpath)] extended_amalgamation = True - main_header_files += [os.path.join(include_dir, x) for x in [ - 'duckdb/planner/expression/bound_constant_expression.hpp', - 'duckdb/planner/expression/bound_function_expression.hpp', - 'duckdb/catalog/catalog_entry/scalar_function_catalog_entry.hpp', - 'duckdb/parser/parsed_data/create_table_info.hpp', - 'duckdb/planner/parsed_data/bound_create_table_info.hpp', - 'duckdb/parser/constraints/not_null_constraint.hpp', - 'duckdb/storage/data_table.hpp', - 'duckdb/function/pragma_function.hpp', - 'duckdb/parser/qualified_name.hpp', - 'duckdb/parser/parser.hpp', - 'duckdb/planner/binder.hpp', - 'duckdb/storage/object_cache.hpp', - 'duckdb/planner/table_filter.hpp', - "duckdb/storage/statistics/base_statistics.hpp", - "duckdb/planner/filter/conjunction_filter.hpp", - "duckdb/planner/filter/constant_filter.hpp", - "duckdb/execution/operator/persistent/buffered_csv_reader.hpp", - "duckdb/common/types/vector_cache.hpp", - "duckdb/common/string_map_set.hpp", - "duckdb/planner/filter/null_filter.hpp", - "duckdb/common/arrow/arrow_wrapper.hpp", - "duckdb/common/hive_partitioning.hpp", - "duckdb/common/union_by_name.hpp", - "duckdb/planner/operator/logical_get.hpp", - "duckdb/common/compressed_file_system.hpp"]] + main_header_files += [ + os.path.join(include_dir, x) + for x in [ + 'duckdb/planner/expression/bound_constant_expression.hpp', + 'duckdb/planner/expression/bound_function_expression.hpp', + 'duckdb/catalog/catalog_entry/scalar_function_catalog_entry.hpp', + 'duckdb/parser/parsed_data/create_table_info.hpp', + 'duckdb/planner/parsed_data/bound_create_table_info.hpp', + 'duckdb/parser/constraints/not_null_constraint.hpp', + 'duckdb/storage/data_table.hpp', + 'duckdb/function/pragma_function.hpp', + 'duckdb/parser/qualified_name.hpp', + 'duckdb/parser/parser.hpp', + 'duckdb/planner/binder.hpp', + 'duckdb/storage/object_cache.hpp', + 'duckdb/planner/table_filter.hpp', + "duckdb/storage/statistics/base_statistics.hpp", + "duckdb/planner/filter/conjunction_filter.hpp", + "duckdb/planner/filter/constant_filter.hpp", + "duckdb/execution/operator/persistent/buffered_csv_reader.hpp", + "duckdb/common/types/vector_cache.hpp", + "duckdb/common/string_map_set.hpp", + "duckdb/planner/filter/null_filter.hpp", + "duckdb/common/arrow/arrow_wrapper.hpp", + "duckdb/common/hive_partitioning.hpp", + "duckdb/common/union_by_name.hpp", + "duckdb/planner/operator/logical_get.hpp", + "duckdb/common/compressed_file_system.hpp", + ] + ] main_header_files += add_include_dir(os.path.join(include_dir, 'duckdb/parser/expression')) main_header_files += add_include_dir(os.path.join(include_dir, 'duckdb/parser/parsed_data')) main_header_files += add_include_dir(os.path.join(include_dir, 'duckdb/parser/tableref')) main_header_files = normalize_path(main_header_files) import package_build + # include paths for where to search for include files during amalgamation include_paths = [include_dir] + package_build.third_party_includes() # paths of where to look for files to compile and include to the final amalgamation compile_directories = 
[src_dir] + package_build.third_party_sources() # files always excluded -always_excluded = normalize_path(['src/amalgamation/duckdb.cpp', 'src/amalgamation/duckdb.hpp', 'src/amalgamation/parquet-amalgamation.cpp', 'src/amalgamation/parquet-amalgamation.hpp']) +always_excluded = normalize_path( + [ + 'src/amalgamation/duckdb.cpp', + 'src/amalgamation/duckdb.hpp', + 'src/amalgamation/parquet-amalgamation.cpp', + 'src/amalgamation/parquet-amalgamation.hpp', + ] +) # files excluded from the amalgamation excluded_files = ['grammar.cpp', 'grammar.hpp', 'symbols.cpp'] # files excluded from individual file compilation during test_compile @@ -93,6 +108,7 @@ def add_include_dir(dirpath): linenumbers = False + def get_includes(fpath, text): # find all the includes referred to in the directory regex_include_statements = re.findall("(^[\t ]*[#][\t ]*include[\t ]+[\"]([^\"]+)[\"])", text, flags=re.MULTILINE) @@ -103,7 +119,12 @@ def get_includes(fpath, text): included_file = x[1] if skip_duckdb_includes and 'duckdb' in included_file: continue - if ('extension_helper.cpp' in fpath and (included_file.endswith('_extension.hpp')) or included_file == 'generated_extension_loader.hpp' or included_file == 'generated_extension_headers.hpp'): + if ( + 'extension_helper.cpp' in fpath + and (included_file.endswith('_extension.hpp')) + or included_file == 'generated_extension_loader.hpp' + or included_file == 'generated_extension_headers.hpp' + ): continue if 'allocator.cpp' in fpath and included_file.endswith('jemalloc_extension.hpp'): continue @@ -122,18 +143,21 @@ def get_includes(fpath, text): raise Exception('Could not find include file "' + included_file + '", included from file "' + fpath + '"') return (include_statements, include_files) + def cleanup_file(text): # remove all "#pragma once" notifications text = re.sub('#pragma once', '', text) return text + # recursively get all includes and write them written_files = {} -#licenses +# licenses licenses = [] -def need_to_write_file(current_file, ignore_excluded = False): + +def need_to_write_file(current_file, ignore_excluded=False): if amal_dir in current_file: return False if current_file in always_excluded: @@ -146,6 +170,7 @@ def need_to_write_file(current_file, ignore_excluded = False): return False return True + def find_license(original_file): global licenses file = original_file @@ -166,7 +191,7 @@ def find_license(original_file): return licenses.index(license) -def write_file(current_file, ignore_excluded = False): +def write_file(current_file, ignore_excluded=False): global linenumbers global written_files if not need_to_write_file(current_file, ignore_excluded): @@ -179,7 +204,12 @@ def write_file(current_file, ignore_excluded = False): if current_file.startswith("third_party") and not current_file.endswith("LICENSE"): lic_idx = find_license(current_file) - text = "\n\n// LICENSE_CHANGE_BEGIN\n// The following code up to LICENSE_CHANGE_END is subject to THIRD PARTY LICENSE #%s\n// See the end of this file for a list\n\n" % str(lic_idx + 1) + text + "\n\n// LICENSE_CHANGE_END\n" + text = ( + "\n\n// LICENSE_CHANGE_BEGIN\n// The following code up to LICENSE_CHANGE_END is subject to THIRD PARTY LICENSE #%s\n// See the end of this file for a list\n\n" + % str(lic_idx + 1) + + text + + "\n\n// LICENSE_CHANGE_END\n" + ) (statements, includes) = get_includes(current_file, text) # find the linenr of the final #include statement we parsed @@ -202,6 +232,7 @@ def write_file(current_file, ignore_excluded = False): # now read the header and write it return 
cleanup_file(text) + def write_dir(dir): files = os.listdir(dir) files.sort() @@ -217,6 +248,7 @@ def write_dir(dir): text += write_file(fpath) return text + def copy_if_different(src, dest): if os.path.isfile(dest): # dest exists, check if the files are different @@ -230,13 +262,15 @@ def copy_if_different(src, dest): # print("Copying " + src + " to " + dest) shutil.copyfile(src, dest) + def git_commit_hash(): - return subprocess.check_output(['git','log','-1','--format=%h']).strip().decode('utf8') + return subprocess.check_output(['git', 'log', '-1', '--format=%h']).strip().decode('utf8') + def git_dev_version(): try: - version = subprocess.check_output(['git','describe','--tags','--abbrev=0']).strip().decode('utf8') - long_version = subprocess.check_output(['git','describe','--tags','--long']).strip().decode('utf8') + version = subprocess.check_output(['git', 'describe', '--tags', '--abbrev=0']).strip().decode('utf8') + long_version = subprocess.check_output(['git', 'describe', '--tags', '--long']).strip().decode('utf8') version_splits = version.split('.') dev_version = long_version.split('-')[1] if int(dev_version) == 0: @@ -249,6 +283,7 @@ def git_dev_version(): except: return "0.0.0" + def generate_duckdb_hpp(header_file): print("-----------------------") print("-- Writing " + header_file + " --") @@ -268,6 +303,7 @@ def generate_duckdb_hpp(header_file): for fpath in main_header_files: hfile.write(write_file(fpath)) + def generate_amalgamation(source_file, header_file): # construct duckdb.hpp from these headers generate_duckdb_hpp(header_file) @@ -291,10 +327,9 @@ def generate_amalgamation(source_file, header_file): for license in licenses: sfile.write("\n\n\n### THIRD PARTY LICENSE #%s ###\n\n" % str(license_idx + 1)) sfile.write(write_file(license)) - license_idx+=1 + license_idx += 1 sfile.write('\n\n*/\n') - copy_if_different(temp_header, header_file) copy_if_different(temp_source, source_file) try: @@ -303,6 +338,7 @@ def generate_amalgamation(source_file, header_file): except: pass + def list_files(dname, file_list): files = os.listdir(dname) files.sort() @@ -316,12 +352,14 @@ def list_files(dname, file_list): if need_to_write_file(fpath): file_list.append(fpath) + def list_sources(): file_list = [] for compile_dir in compile_directories: list_files(compile_dir, file_list) return file_list + def list_include_files_recursive(dname, file_list): files = os.listdir(dname) files.sort() @@ -334,15 +372,18 @@ def list_include_files_recursive(dname, file_list): elif fname.endswith(('.hpp', '.h', '.hh', '.tcc', '.inc')): file_list.append(fpath) + def list_includes_files(include_dirs): file_list = [] for include_dir in include_dirs: list_include_files_recursive(include_dir, file_list) return file_list + def list_includes(): return list_includes_files(include_paths) + def gather_file(current_file, source_files, header_files): global linenumbers global written_files @@ -379,6 +420,7 @@ def gather_file(current_file, source_files, header_files): text = '\n#line 1 "%s"\n' % (current_file,) + text source_files.append(cleanup_file(text)) + def gather_files(dir, source_files, header_files): files = os.listdir(dir) files.sort() @@ -391,9 +433,11 @@ def gather_files(dir, source_files, header_files): elif fname.endswith('.cpp') or fname.endswith('.c') or fname.endswith('.cc'): gather_file(fpath, source_files, header_files) + def write_license(hfile): hfile.write("// See https://raw.githubusercontent.com/duckdb/duckdb/master/LICENSE for licensing information\n\n") + def 
generate_amalgamation_splits(source_file, header_file, nsplits): # construct duckdb.hpp from these headers generate_duckdb_hpp(header_file) @@ -466,11 +510,13 @@ def generate_amalgamation_splits(source_file, header_file, nsplits): with open_utf8(temp_partition_name, 'w+') as f: write_license(f) f.write('#include "%s"\n#include "%s"' % (header_file_name, internal_header_file_name)) - f.write(''' + f.write( + ''' #ifndef DUCKDB_AMALGAMATION #error header mismatch #endif -''') +''' + ) for sfile in partition: f.write(sfile) current_partition += 1 @@ -488,9 +534,12 @@ def generate_amalgamation_splits(source_file, header_file, nsplits): os.remove(p[1]) except: pass + + def list_include_dirs(): return include_paths + if __name__ == "__main__": nsplits = 1 for arg in sys.argv: @@ -528,4 +577,3 @@ def list_include_dirs(): generate_amalgamation_splits(source_file, header_file, nsplits) else: generate_amalgamation(source_file, header_file) - diff --git a/scripts/asset-upload-gha.py b/scripts/asset-upload-gha.py index 2ae0d20b7195..6832d41703af 100644 --- a/scripts/asset-upload-gha.py +++ b/scripts/asset-upload-gha.py @@ -7,26 +7,26 @@ api_url = 'https://api.github.com/repos/duckdb/duckdb/' -if (len(sys.argv) < 2): - print("Usage: [filename1] [filename2] ... ") - exit(1) +if len(sys.argv) < 2: + print("Usage: [filename1] [filename2] ... ") + exit(1) # this essentially should run on release tag builds to fill up release assets and master repo = os.getenv("GITHUB_REPOSITORY", "") if repo != "duckdb/duckdb": - print("Not running on forks. Exiting.") - exit(0) + print("Not running on forks. Exiting.") + exit(0) -ref = os.getenv("GITHUB_REF", '') # this env var is always present just not always used +ref = os.getenv("GITHUB_REF", '') # this env var is always present just not always used if ref == 'refs/heads/master': - print("Not running on master. Exiting.") - exit(0) + print("Not running on master. Exiting.") + exit(0) elif ref.startswith('refs/tags/'): - tag = ref.replace('refs/tags/', '') + tag = ref.replace('refs/tags/', '') else: - print("Not running on branches. Exiting.") - exit(0) + print("Not running on branches. 
Exiting.") + exit(0) print("Running on tag %s" % tag) @@ -34,64 +34,63 @@ token = os.getenv("GH_TOKEN", "") if token == "": - raise ValueError('need a GitHub token in GH_TOKEN') + raise ValueError('need a GitHub token in GH_TOKEN') + def internal_gh_api(suburl, filename='', method='GET'): - url = api_url + suburl - headers = { - "Content-Type": "application/json", - 'Authorization': 'token ' + token - } - - body_data = b'' - raw_resp = None - if len(filename) > 0: - method = 'POST' - body_data = open(filename, 'rb') - headers["Content-Type"] = "binary/octet-stream" - headers["Content-Length"] = os.path.getsize(local_filename) - url = suburl # cough - - req = urllib.request.Request(url, body_data, headers) - req.get_method = lambda: method - print(f'GH API URL: "{url}" Filename: "{filename}" Method: "{method}"') - raw_resp = urllib.request.urlopen(req).read().decode() - - if (method != 'DELETE'): - return json.loads(raw_resp) - else: - return {} + url = api_url + suburl + headers = {"Content-Type": "application/json", 'Authorization': 'token ' + token} + + body_data = b'' + raw_resp = None + if len(filename) > 0: + method = 'POST' + body_data = open(filename, 'rb') + headers["Content-Type"] = "binary/octet-stream" + headers["Content-Length"] = os.path.getsize(local_filename) + url = suburl # cough + + req = urllib.request.Request(url, body_data, headers) + req.get_method = lambda: method + print(f'GH API URL: "{url}" Filename: "{filename}" Method: "{method}"') + raw_resp = urllib.request.urlopen(req).read().decode() + + if method != 'DELETE': + return json.loads(raw_resp) + else: + return {} + def gh_api(suburl, filename='', method='GET'): - timeout = 1 - nretries = 10 - success = False - for i in range(nretries+1): - try: - response = internal_gh_api(suburl, filename, method) - success = True - except urllib.error.HTTPError as e: - print(e.read().decode()) # gah - except Exception as e: - print(e) - if success: - break - print(f"Failed upload, retrying in {timeout} seconds... ({i}/{nretries})") - time.sleep(timeout) - timeout = timeout * 2 - if not success: - raise Exception("Failed to open URL " + suburl) - return response + timeout = 1 + nretries = 10 + success = False + for i in range(nretries + 1): + try: + response = internal_gh_api(suburl, filename, method) + success = True + except urllib.error.HTTPError as e: + print(e.read().decode()) # gah + except Exception as e: + print(e) + if success: + break + print(f"Failed upload, retrying in {timeout} seconds... ({i}/{nretries})") + time.sleep(timeout) + timeout = timeout * 2 + if not success: + raise Exception("Failed to open URL " + suburl) + return response # check if tag exists resp = gh_api('git/ref/tags/%s' % tag) -if 'object' not in resp or 'sha' not in resp['object'] : # or resp['object']['sha'] != sha - raise ValueError('tag %s not found' % tag) +if 'object' not in resp or 'sha' not in resp['object']: # or resp['object']['sha'] != sha + raise ValueError('tag %s not found' % tag) resp = gh_api('releases/tags/%s' % tag) if 'id' not in resp or 'upload_url' not in resp: - raise ValueError('release does not exist for tag ' % tag) + raise ValueError('release does not exist for tag ' % tag) # double-check that release exists and has correct sha @@ -102,26 +101,26 @@ def gh_api(suburl, filename='', method='GET'): # TODO this could be a paged response! 
assets = gh_api('releases/%s/assets' % resp['id']) -upload_url = resp['upload_url'].split('{')[0] # gah +upload_url = resp['upload_url'].split('{')[0] # gah files = sys.argv[1:] for filename in files: - if '=' in filename: - parts = filename.split("=") - asset_filename = parts[0] - paths = glob.glob(parts[1]) - if len(paths) != 1: - raise ValueError("Could not find file for pattern %s" % parts[1]) - local_filename = paths[0] - else: - asset_filename = os.path.basename(filename) - local_filename = filename - - # delete if present - for asset in assets: - if asset['name'] == asset_filename: - gh_api('releases/assets/%s' % asset['id'], method='DELETE') - - resp = gh_api(f'{upload_url}?name={asset_filename}', filename=local_filename) - if 'id' not in resp: - raise ValueError('upload failed :/ ' + str(resp)) - print("%s -> %s" % (local_filename, resp['browser_download_url'])) + if '=' in filename: + parts = filename.split("=") + asset_filename = parts[0] + paths = glob.glob(parts[1]) + if len(paths) != 1: + raise ValueError("Could not find file for pattern %s" % parts[1]) + local_filename = paths[0] + else: + asset_filename = os.path.basename(filename) + local_filename = filename + + # delete if present + for asset in assets: + if asset['name'] == asset_filename: + gh_api('releases/assets/%s' % asset['id'], method='DELETE') + + resp = gh_api(f'{upload_url}?name={asset_filename}', filename=local_filename) + if 'id' not in resp: + raise ValueError('upload failed :/ ' + str(resp)) + print("%s -> %s" % (local_filename, resp['browser_download_url'])) diff --git a/scripts/asset-upload.py b/scripts/asset-upload.py index 801b73f67198..afb0afc16afb 100644 --- a/scripts/asset-upload.py +++ b/scripts/asset-upload.py @@ -7,72 +7,71 @@ api_url = 'https://api.github.com/repos/duckdb/duckdb/' -if (len(sys.argv) < 2): - print("Usage: [filename1] [filename2] ... ") - exit(1) +if len(sys.argv) < 2: + print("Usage: [filename1] [filename2] ... ") + exit(1) # this essentially should run on release tag builds to fill up release assets and master pr = os.getenv("TRAVIS_PULL_REQUEST", "") if pr != "false": - print("Not running on PRs. Exiting.") - exit(0) + print("Not running on PRs. Exiting.") + exit(0) -tag = os.getenv("TRAVIS_TAG", '') # this env var is always present just not always used +tag = os.getenv("TRAVIS_TAG", '') # this env var is always present just not always used if tag == '': - tag = 'master-builds' + tag = 'master-builds' print("Running on tag %s" % tag) if tag == "master-builds" and os.getenv("TRAVIS_BRANCH", "") != "master": - print("Only running on master branch for %s tag. Exiting." % tag) - exit(0) + print("Only running on master branch for %s tag. Exiting." 
% tag) + exit(0) token = os.getenv("GH_TOKEN", "") if token == "": - raise ValueError('need a GitHub token in GH_TOKEN') + raise ValueError('need a GitHub token in GH_TOKEN') + def gh_api(suburl, filename='', method='GET'): - url = api_url + suburl - headers = { - "Content-Type": "application/json", - 'Authorization': 'token ' + token - } - - body_data = b'' - - if len(filename) > 0: - method = 'POST' - body_data = open(filename, 'rb') - - mime_type = mimetypes.guess_type(local_filename)[0] - if mime_type is None: - mime_type = "application/octet-stream" - headers["Content-Type"] = mime_type - headers["Content-Length"] = os.path.getsize(local_filename) - - url = suburl # cough - - req = urllib.request.Request(url, body_data, headers) - req.get_method = lambda: method - try: - raw_resp = urllib.request.urlopen(req).read().decode() - except urllib.error.HTTPError as e: - raw_resp = e.read().decode() # gah - - if (method != 'DELETE'): - return json.loads(raw_resp) - else: - return {} + url = api_url + suburl + headers = {"Content-Type": "application/json", 'Authorization': 'token ' + token} + + body_data = b'' + + if len(filename) > 0: + method = 'POST' + body_data = open(filename, 'rb') + + mime_type = mimetypes.guess_type(local_filename)[0] + if mime_type is None: + mime_type = "application/octet-stream" + headers["Content-Type"] = mime_type + headers["Content-Length"] = os.path.getsize(local_filename) + + url = suburl # cough + + req = urllib.request.Request(url, body_data, headers) + req.get_method = lambda: method + try: + raw_resp = urllib.request.urlopen(req).read().decode() + except urllib.error.HTTPError as e: + raw_resp = e.read().decode() # gah + + if method != 'DELETE': + return json.loads(raw_resp) + else: + return {} + # check if tag exists resp = gh_api('git/ref/tags/%s' % tag) -if 'object' not in resp or 'sha' not in resp['object'] : # or resp['object']['sha'] != sha - raise ValueError('tag %s not found' % tag) +if 'object' not in resp or 'sha' not in resp['object']: # or resp['object']['sha'] != sha + raise ValueError('tag %s not found' % tag) resp = gh_api('releases/tags/%s' % tag) if 'id' not in resp or 'upload_url' not in resp: - raise ValueError('release does not exist for tag ' % tag) + raise ValueError('release does not exist for tag ' % tag) # double-check that release exists and has correct sha # disabled to not spam people watching releases @@ -82,26 +81,26 @@ def gh_api(suburl, filename='', method='GET'): # TODO this could be a paged response! 
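# Same pagination caveat as in asset-upload-gha.py above. Note that the
# 'upload_url' fetched below is an RFC 6570 URI template (it ends in
# '{?name,label}'), which is why the code strips everything from '{' onwards.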
assets = gh_api('releases/%s/assets' % resp['id']) -upload_url = resp['upload_url'].split('{')[0] # gah +upload_url = resp['upload_url'].split('{')[0] # gah files = sys.argv[1:] for filename in files: - if '=' in filename: - parts = filename.split("=") - asset_filename = parts[0] - paths = glob.glob(parts[1]) - if len(paths) != 1: - raise ValueError("Could not find file for pattern %s" % local_filename) - local_filename = paths[0] - else : - asset_filename = os.path.basename(filename) - local_filename = filename - - # delete if present - for asset in assets: - if asset['name'] == asset_filename: - gh_api('releases/assets/%s' % asset['id'], method='DELETE') - - resp = gh_api(upload_url + '?name=%s' % asset_filename, filename=local_filename) - if 'id' not in resp: - raise ValueError('upload failed :/ ' + str(resp)) - print("%s -> %s" % (local_filename, resp['browser_download_url'])) + if '=' in filename: + parts = filename.split("=") + asset_filename = parts[0] + paths = glob.glob(parts[1]) + if len(paths) != 1: + raise ValueError("Could not find file for pattern %s" % local_filename) + local_filename = paths[0] + else: + asset_filename = os.path.basename(filename) + local_filename = filename + + # delete if present + for asset in assets: + if asset['name'] == asset_filename: + gh_api('releases/assets/%s' % asset['id'], method='DELETE') + + resp = gh_api(upload_url + '?name=%s' % asset_filename, filename=local_filename) + if 'id' not in resp: + raise ValueError('upload failed :/ ' + str(resp)) + print("%s -> %s" % (local_filename, resp['browser_download_url'])) diff --git a/scripts/check_coverage.py b/scripts/check_coverage.py index 47201b37def3..eae7175caadf 100644 --- a/scripts/check_coverage.py +++ b/scripts/check_coverage.py @@ -5,25 +5,39 @@ parser = argparse.ArgumentParser(description='Check code coverage results') -parser.add_argument('--uncovered_files', action='store', - help='Set of files that are not 100% covered', default=os.path.join(".github", "config", "uncovered_files.csv")) +parser.add_argument( + '--uncovered_files', + action='store', + help='Set of files that are not 100% covered', + default=os.path.join(".github", "config", "uncovered_files.csv"), +) parser.add_argument('--directory', help='Directory of generated HTML files', action='store', default='coverage_html') parser.add_argument('--fix', help='Fill up the uncovered_files.csv with all files', action='store_true', default=False) args = parser.parse_args() if not os.path.exists(args.directory): - print(f"The provided directory ({args.directory}) does not exist, please create it first") - exit(1) + print(f"The provided directory ({args.directory}) does not exist, please create it first") + exit(1) + +covered_regex = ( + r'[ \t\n]*[ \t\n0-9]+[ \t\n0-9]+:([^<]+)' +) -covered_regex = '[ \t\n]*[ \t\n0-9]+[ \t\n0-9]+:([^<]+)' def get_original_path(path): - return path.replace('.gcov.html', '').replace(os.getcwd(), '').replace('coverage_html' + os.path.sep, '').replace('home/runner/work/duckdb/duckdb/', '') + return ( + path.replace('.gcov.html', '') + .replace(os.getcwd(), '') + .replace('coverage_html' + os.path.sep, '') + .replace('home/runner/work/duckdb/duckdb/', '') + ) + def cleanup_line(line): return line.replace('&', '&').replace('<', '<').replace('>', '>').replace('"', '"') + partial_coverage_dict = {} with open(args.uncovered_files, 'r') as f: for line in f.readlines(): @@ -37,6 +51,7 @@ def cleanup_line(line): total_difference = 0 allowed_difference = 1 + def check_file(path, partial_coverage_dict): global 
any_failed global total_difference @@ -67,7 +82,6 @@ def check_file(path, partial_coverage_dict): uncovered_file.write(f'{original_path}\t{expected_uncovered}\n') return - if len(uncovered_lines) > expected_uncovered_lines: total_difference += len(uncovered_lines) - expected_uncovered_lines @@ -96,6 +110,7 @@ def scan_directory(path): file_list += scan_directory(os.path.join(path, file)) return file_list + files = scan_directory(args.directory) files.sort() diff --git a/scripts/ci_test.py b/scripts/ci_test.py index bf5d25937002..1508fcf34beb 100644 --- a/scripts/ci_test.py +++ b/scripts/ci_test.py @@ -7,6 +7,6 @@ if fail: print("Sorry man") - assert(0) + assert 0 else: - print("Yeah man") \ No newline at end of file + print("Yeah man") diff --git a/scripts/create-release-notes.py b/scripts/create-release-notes.py index afe2b6985e9c..5769ea3250cd 100644 --- a/scripts/create-release-notes.py +++ b/scripts/create-release-notes.py @@ -2,14 +2,15 @@ api_url = 'https://api.github.com/repos/duckdb/duckdb/' -if (len(sys.argv) < 2): - print("Usage: [last_tag] ") - exit(1) +if len(sys.argv) < 2: + print("Usage: [last_tag] ") + exit(1) token = os.getenv("GH_TOKEN", "") if token == "": - raise ValueError('need a GitHub token in GH_TOKEN') + raise ValueError('need a GitHub token in GH_TOKEN') + # amazingly this is the entire code of the pypy package `linkheader-parser` def extract(link_header): @@ -22,34 +23,31 @@ def extract(link_header): rels[group_dict['rel']] = group_dict['url'] return rels -def gh_api(suburl, full_url=''): - if full_url == '': - url = api_url + suburl - else: - url = full_url - headers = { - "Content-Type": "application/json", - 'Authorization': 'token ' + token - } - - req = urllib.request.Request(url, b'', headers) - req.get_method = lambda: 'GET' - next_link = None - try: - resp = urllib.request.urlopen(req) - if (not resp.getheader("Link") is None): - link_data = extract(resp.getheader("Link")) - if ("next" in link_data): - next_link = link_data["next"] - raw_resp = resp.read().decode() - except urllib.error.HTTPError as e: - raw_resp = e.read().decode() # gah - - ret_json = json.loads(raw_resp) - if (next_link is not None): - return ret_json + gh_api('', full_url=next_link) - return ret_json +def gh_api(suburl, full_url=''): + if full_url == '': + url = api_url + suburl + else: + url = full_url + headers = {"Content-Type": "application/json", 'Authorization': 'token ' + token} + + req = urllib.request.Request(url, b'', headers) + req.get_method = lambda: 'GET' + next_link = None + try: + resp = urllib.request.urlopen(req) + if not resp.getheader("Link") is None: + link_data = extract(resp.getheader("Link")) + if "next" in link_data: + next_link = link_data["next"] + raw_resp = resp.read().decode() + except urllib.error.HTTPError as e: + raw_resp = e.read().decode() # gah + + ret_json = json.loads(raw_resp) + if next_link is not None: + return ret_json + gh_api('', full_url=next_link) + return ret_json # get time of tag @@ -58,9 +56,8 @@ def gh_api(suburl, full_url=''): pulls = gh_api('pulls?base=master&state=closed') for p in pulls: - if p["merged_at"] is None: - continue - if p["merged_at"] < old_release["published_at"]: - continue - print(" - #%s: %s" % (p["number"], p["title"])) - + if p["merged_at"] is None: + continue + if p["merged_at"] < old_release["published_at"]: + continue + print(" - #%s: %s" % (p["number"], p["title"])) diff --git a/scripts/exported_symbols_check.py b/scripts/exported_symbols_check.py index 0e4abacb1de2..5deb6b739c16 100644 --- 
a/scripts/exported_symbols_check.py +++ b/scripts/exported_symbols_check.py @@ -3,42 +3,65 @@ import os if len(sys.argv) < 2 or not os.path.isfile(sys.argv[1]): - print("Usage: [libduckdb dynamic library file, release build]") - exit(1) + print("Usage: [libduckdb dynamic library file, release build]") + exit(1) res = subprocess.run('nm -g -C -P'.split(' ') + [sys.argv[1]], check=True, capture_output=True) if res.returncode != 0: - raise ValueError('Failed to run `nm`') + raise ValueError('Failed to run `nm`') culprits = [] -whitelist = ['@@GLIBC', '@@CXXABI', '__gnu_cxx::', 'std::', -'N6duckdb', 'duckdb::', 'duckdb_miniz::', 'duckdb_fmt::', 'duckdb_hll::', 'duckdb_moodycamel::', 'duckdb_', -'RefCounter', 'registerTMCloneTable', 'RegisterClasses', 'Unwind_Resume', '__gmon_start', '_fini', '_init', '_version', '_end', '_edata', '__bss_start', '__udivti3', 'Adbc'] +whitelist = [ + '@@GLIBC', + '@@CXXABI', + '__gnu_cxx::', + 'std::', + 'N6duckdb', + 'duckdb::', + 'duckdb_miniz::', + 'duckdb_fmt::', + 'duckdb_hll::', + 'duckdb_moodycamel::', + 'duckdb_', + 'RefCounter', + 'registerTMCloneTable', + 'RegisterClasses', + 'Unwind_Resume', + '__gmon_start', + '_fini', + '_init', + '_version', + '_end', + '_edata', + '__bss_start', + '__udivti3', + 'Adbc', +] for symbol in res.stdout.decode('utf-8').split('\n'): - if len(symbol.strip()) == 0: - continue - if symbol.endswith(' U'): # undefined because dynamic linker - continue - if symbol.endswith(' U 0 0'): # undefined because dynamic linker - continue - - is_whitelisted = False - for entry in whitelist: - if entry in symbol: - is_whitelisted = True - if is_whitelisted: - continue - - culprits.append(symbol) - + if len(symbol.strip()) == 0: + continue + if symbol.endswith(' U'): # undefined because dynamic linker + continue + if symbol.endswith(' U 0 0'): # undefined because dynamic linker + continue + + is_whitelisted = False + for entry in whitelist: + if entry in symbol: + is_whitelisted = True + if is_whitelisted: + continue + + culprits.append(symbol) + if len(culprits) > 0: - print("Found leaked symbols. Either white-list above or change visibility:") - for symbol in culprits: - print(symbol) - sys.exit(1) + print("Found leaked symbols. 
Either white-list above or change visibility:") + for symbol in culprits: + print(symbol) + sys.exit(1) -sys.exit(0) \ No newline at end of file +sys.exit(0) diff --git a/scripts/format.py b/scripts/format.py index aa39afdc02d7..c6435a4da700 100644 --- a/scripts/format.py +++ b/scripts/format.py @@ -13,34 +13,90 @@ cpp_format_command = 'clang-format --sort-includes=0 -style=file' cmake_format_command = 'cmake-format' -extensions = ['.cpp', '.c', '.hpp', '.h', '.cc', '.hh', 'CMakeLists.txt', '.test', '.test_slow', '.test_coverage', '.benchmark', '.py'] +extensions = [ + '.cpp', + '.c', + '.hpp', + '.h', + '.cc', + '.hh', + 'CMakeLists.txt', + '.test', + '.test_slow', + '.test_coverage', + '.benchmark', + '.py', +] formatted_directories = ['src', 'benchmark', 'test', 'tools', 'examples', 'extension', 'scripts'] -ignored_files = ['tpch_constants.hpp', 'tpcds_constants.hpp', '_generated', 'tpce_flat_input.hpp', - 'test_csv_header.hpp', 'duckdb.cpp', 'duckdb.hpp', 'json.hpp', 'sqlite3.h', 'shell.c', - 'termcolor.hpp', 'test_insert_invalid.test', 'httplib.hpp', 'os_win.c', 'glob.c', 'printf.c', - 'helper.hpp', 'single_thread_ptr.hpp', 'types.hpp', 'default_views.cpp', 'default_functions.cpp', - 'release.h', 'genrand.cpp', 'address.cpp', 'visualizer_constants.hpp', 'icu-collate.cpp', 'icu-collate.hpp', - 'yyjson.cpp', 'yyjson.hpp', 'duckdb_pdqsort.hpp', 'stubdata.cpp', - 'nf_calendar.cpp', 'nf_calendar.h', 'nf_localedata.cpp', 'nf_localedata.h', 'nf_zformat.cpp', - 'nf_zformat.h', 'expr.cc', 'function_list.cpp'] -ignored_directories = ['.eggs', '__pycache__', 'dbgen', os.path.join('tools', 'pythonpkg', 'duckdb'), - os.path.join('tools', 'pythonpkg', 'build'), os.path.join('tools', 'rpkg', 'src', 'duckdb'), - os.path.join('tools', 'rpkg', 'inst', 'include', 'cpp11'), - os.path.join('extension', 'tpcds', 'dsdgen'), os.path.join('extension', 'jemalloc', 'jemalloc'), - os.path.join('extension', 'json', 'yyjson'), os.path.join('extension', 'icu', 'third_party'), - os.path.join('src', 'include', 'duckdb', 'core_functions', 'aggregate'), - os.path.join('src', 'include', 'duckdb', 'core_functions', 'scalar'), - os.path.join('tools', 'nodejs', 'src', 'duckdb')] +ignored_files = [ + 'tpch_constants.hpp', + 'tpcds_constants.hpp', + '_generated', + 'tpce_flat_input.hpp', + 'test_csv_header.hpp', + 'duckdb.cpp', + 'duckdb.hpp', + 'json.hpp', + 'sqlite3.h', + 'shell.c', + 'termcolor.hpp', + 'test_insert_invalid.test', + 'httplib.hpp', + 'os_win.c', + 'glob.c', + 'printf.c', + 'helper.hpp', + 'single_thread_ptr.hpp', + 'types.hpp', + 'default_views.cpp', + 'default_functions.cpp', + 'release.h', + 'genrand.cpp', + 'address.cpp', + 'visualizer_constants.hpp', + 'icu-collate.cpp', + 'icu-collate.hpp', + 'yyjson.cpp', + 'yyjson.hpp', + 'duckdb_pdqsort.hpp', + 'stubdata.cpp', + 'nf_calendar.cpp', + 'nf_calendar.h', + 'nf_localedata.cpp', + 'nf_localedata.h', + 'nf_zformat.cpp', + 'nf_zformat.h', + 'expr.cc', + 'function_list.cpp', +] +ignored_directories = [ + '.eggs', + '__pycache__', + 'dbgen', + os.path.join('tools', 'pythonpkg', 'duckdb'), + os.path.join('tools', 'pythonpkg', 'build'), + os.path.join('tools', 'rpkg', 'src', 'duckdb'), + os.path.join('tools', 'rpkg', 'inst', 'include', 'cpp11'), + os.path.join('extension', 'tpcds', 'dsdgen'), + os.path.join('extension', 'jemalloc', 'jemalloc'), + os.path.join('extension', 'json', 'yyjson'), + os.path.join('extension', 'icu', 'third_party'), + os.path.join('src', 'include', 'duckdb', 'core_functions', 'aggregate'), + os.path.join('src', 'include', 'duckdb', 
'core_functions', 'scalar'), + os.path.join('tools', 'nodejs', 'src', 'duckdb'), +] format_all = False check_only = True confirm = True silent = False + def print_usage(): print("Usage: python scripts/format.py [revision|--all] [--check|--fix]") - print(" [revision] is an optional revision number, all files that changed since that revision will be formatted (default=HEAD)") print( - " if [revision] is set to --all, all files will be formatted") + " [revision] is an optional revision number, all files that changed since that revision will be formatted (default=HEAD)" + ) + print(" if [revision] is set to --all, all files will be formatted") print(" --check only prints differences, --fix also fixes the files (--check is default)") exit(1) @@ -70,6 +126,7 @@ def print_usage(): if revision == '--all': format_all = True + def file_is_ignored(full_path): if os.path.basename(full_path) in ignored_files: return True @@ -80,7 +137,6 @@ def file_is_ignored(full_path): return False - def can_format_file(full_path): global extensions, formatted_directories, ignored_files if not os.path.isfile(full_path): @@ -108,9 +164,9 @@ def can_format_file(full_path): if check_only: action = "Checking" + def get_changed_files(revision): - proc = subprocess.Popen( - ['git', 'diff', '--name-only', revision], stdout=subprocess.PIPE) + proc = subprocess.Popen(['git', 'diff', '--name-only', revision], stdout=subprocess.PIPE) files = proc.stdout.read().decode('utf8').split('\n') changed_files = [] for f in files: @@ -121,6 +177,7 @@ def get_changed_files(revision): changed_files.append(f) return changed_files + if os.path.isfile(revision): print(action + " individual file: " + revision) changed_files = [revision] @@ -173,14 +230,19 @@ def get_changed_files(revision): header_bottom += "//===----------------------------------------------------------------------===//\n\n" base_dir = os.path.join(os.getcwd(), 'src/include') + def get_formatted_text(f, full_path, directory, ext): if not can_format_file(full_path): print("Eek, cannot format file " + full_path + " but attempted to format anyway") exit(1) if f == 'list.hpp': # fill in list file - file_list = [os.path.join(dp, f) for dp, dn, filenames in os.walk( - directory) for f in filenames if os.path.splitext(f)[1] == '.hpp' and not f.endswith("list.hpp")] + file_list = [ + os.path.join(dp, f) + for dp, dn, filenames in os.walk(directory) + for f in filenames + if os.path.splitext(f)[1] == '.hpp' and not f.endswith("list.hpp") + ] file_list = [x.replace('src/include/', '') for x in file_list] file_list.sort() result = "" @@ -211,7 +273,7 @@ def get_formatted_text(f, full_path, directory, ext): found_group = False group_name = full_path.split('/')[-2] new_path_line = '# name: ' + full_path + '\n' - new_group_line = '# group: [' + group_name + ']' + '\n' + new_group_line = '# group: [' + group_name + ']' + '\n' found_diff = False # Find description. found_description = False @@ -230,12 +292,15 @@ def get_formatted_text(f, full_path, directory, ext): lines.pop(0) # Ensure header is prepended. 
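        # The rebuilt header is: the '# name:' line, then an optional
        # '# description:' line, then the '# group:' line, then a blank line.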
header = [new_path_line] - if found_description: header.append(new_description_line) + if found_description: + header.append(new_description_line) header.append(new_group_line) header.append('\n') return ''.join(header + lines) proc_command = format_commands[ext].split(' ') + [full_path] - proc = subprocess.Popen(proc_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, stdin=open(full_path) if ext == '.py' else None) + proc = subprocess.Popen( + proc_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, stdin=open(full_path) if ext == '.py' else None + ) new_text = proc.stdout.read().decode('utf8') stderr = proc.stderr.read().decode('utf8') if len(stderr) > 0: @@ -248,11 +313,13 @@ def get_formatted_text(f, full_path, directory, ext): new_text = re.sub(r'\n*$', '', new_text) return new_text + '\n' + def file_is_generated(text): if '// This file is automatically generated by scripts/' in text: return True return False + def format_file(f, full_path, directory, ext): global difference_files with open_utf8(full_path, 'r') as f: @@ -262,7 +329,6 @@ def format_file(f, full_path, directory, ext): return old_lines = old_text.split('\n') - new_text = get_formatted_text(f, full_path, directory, ext) if ext in ('.cpp', '.hpp'): new_text = new_text.replace('ARGS &&...args', 'ARGS &&... args') @@ -303,8 +369,7 @@ def format_directory(directory): print(full_path) format_directory(full_path) elif can_format_file(full_path): - format_file(f, full_path, directory, '.' + - f.split('.')[-1]) + format_file(f, full_path, directory, '.' + f.split('.')[-1]) if format_all: diff --git a/scripts/fuzzer_helper.py b/scripts/fuzzer_helper.py index 4ca6c91b52d8..e36275e9b0c5 100644 --- a/scripts/fuzzer_helper.py +++ b/scripts/fuzzer_helper.py @@ -41,28 +41,31 @@ footer = ''' ```''' + def get_github_hash(): proc = subprocess.Popen(['git', 'rev-parse', 'HEAD'], stdout=subprocess.PIPE) return proc.stdout.read().decode('utf8').strip() + # github stuff def issue_url(): return 'https://api.github.com/repos/%s/%s/issues' % (REPO_OWNER, REPO_NAME) + def create_session(): # Create an authenticated session to create the issue session = requests.Session() session.headers.update({'Authorization': 'token %s' % (TOKEN,)}) return session + def make_github_issue(title, body): if len(title) > 240: # avoid title is too long error (maximum is 256 characters) title = title[:240] + '...' 
session = create_session() url = issue_url() - issue = {'title': title, - 'body': body} + issue = {'title': title, 'body': body} r = session.post(url, json.dumps(issue)) if r.status_code == 201: print('Successfully created Issue "%s"' % title) @@ -71,6 +74,7 @@ def make_github_issue(title, body): print('Response:', r.content.decode('utf8')) raise Exception("Failed to create issue") + def get_github_issues(): session = create_session() url = issue_url() @@ -81,6 +85,7 @@ def get_github_issues(): raise Exception("Failed to get list of issues") return json.loads(r.content.decode('utf8')) + def close_github_issue(number): session = create_session() url = issue_url() + '/' + str(number) @@ -93,17 +98,19 @@ def close_github_issue(number): print('Response:', r.content.decode('utf8')) raise Exception("Failed to close issue") + def extract_issue(body, nr): try: splits = body.split(middle) sql = splits[0].split(header)[1] - error = splits[1][:-len(footer)] + error = splits[1][: -len(footer)] return (sql, error) except: print(f"Failed to extract SQL/error message from issue {nr}") print(body) return None + def run_shell_command_batch(shell, cmd): command = [shell, '--batch', '-init', '/dev/null'] @@ -112,6 +119,7 @@ def run_shell_command_batch(shell, cmd): stderr = res.stderr.decode('utf8').strip() return (stdout, stderr, res.returncode) + def test_reproducibility(shell, issue, current_errors): extract = extract_issue(issue['body'], issue['number']) if extract is None: @@ -128,6 +136,7 @@ def test_reproducibility(shell, issue, current_errors): current_errors[error] = issue return True + def extract_github_issues(shell): current_errors = dict() issues = get_github_issues() @@ -139,16 +148,23 @@ def extract_github_issues(shell): close_github_issue(int(issue['number'])) return current_errors + def file_issue(cmd, error_msg, fuzzer, seed, hash): # issue is new, file it print("Filing new issue to Github") title = error_msg - body = fuzzer_desc.replace("${FUZZER}", fuzzer).replace("${FULL_HASH}", hash).replace("${SHORT_HASH}", hash[:5]).replace("${SEED}", str(seed)) + body = ( + fuzzer_desc.replace("${FUZZER}", fuzzer) + .replace("${FULL_HASH}", hash) + .replace("${SHORT_HASH}", hash[:5]) + .replace("${SEED}", str(seed)) + ) body += header + cmd + middle + error_msg + footer print(title, body) make_github_issue(title, body) + def is_internal_error(error): if 'differs from original result' in error: return True diff --git a/scripts/generate_benchmarks.py b/scripts/generate_benchmarks.py index afe3f4b1ba2b..089bd0adef03 100644 --- a/scripts/generate_benchmarks.py +++ b/scripts/generate_benchmarks.py @@ -1,22 +1,31 @@ import os from python_helpers import open_utf8 + def format_tpch_queries(target_dir, tpch_in, comment): - with open_utf8(tpch_in, 'r') as f: - text = f.read() + with open_utf8(tpch_in, 'r') as f: + text = f.read() - for i in range(1, 23): - qnr = '%02d' % (i,) - target_file = os.path.join(target_dir, 'q' + qnr + '.benchmark') - new_text = '''# name: %s + for i in range(1, 23): + qnr = '%02d' % (i,) + target_file = os.path.join(target_dir, 'q' + qnr + '.benchmark') + new_text = '''# name: %s # description: Run query %02d from the TPC-H benchmark (%s) # group: [sf1] template %s QUERY_NUMBER=%d -QUERY_NUMBER_PADDED=%02d''' % (target_file, i, comment, tpch_in, i, i) - with open_utf8(target_file, 'w+') as f: - f.write(new_text) +QUERY_NUMBER_PADDED=%02d''' % ( + target_file, + i, + comment, + tpch_in, + i, + i, + ) + with open_utf8(target_file, 'w+') as f: + f.write(new_text) + # generate the TPC-H 
benchmark files single_threaded_dir = os.path.join('benchmark', 'tpch', 'sf1') @@ -25,4 +34,4 @@ def format_tpch_queries(target_dir, tpch_in, comment): parallel_threaded_dir = os.path.join('benchmark', 'tpch', 'sf1-parallel') parallel_threaded_in = os.path.join(parallel_threaded_dir, 'tpch_sf1_parallel.benchmark.in') -format_tpch_queries(parallel_threaded_dir, parallel_threaded_in, '4 threads') \ No newline at end of file +format_tpch_queries(parallel_threaded_dir, parallel_threaded_in, '4 threads') diff --git a/scripts/generate_builtin_types.py b/scripts/generate_builtin_types.py index e0a1027ec6da..63f98c57647d 100644 --- a/scripts/generate_builtin_types.py +++ b/scripts/generate_builtin_types.py @@ -22,32 +22,37 @@ footer = '''} // namespace duckdb ''' + def normalize_path_separators(x): return os.path.sep.join(x.split('/')) + def legal_struct_name(name): return name.isalnum() + def get_struct_name(function_name): return function_name.replace('_', ' ').title().replace(' ', '') + 'Fun' + def sanitize_string(text): return text.replace('"', '\\"') + new_text = header type_entries = [] json_path = normalize_path_separators(f'src/include/duckdb/catalog/default/builtin_types/types.json') with open(json_path, 'r') as f: - parsed_json = json.load(f) + parsed_json = json.load(f) # Extract all the types from the json for type in parsed_json: - names = type['names'] - - type_id = type['id'] + names = type['names'] + + type_id = type['id'] - type_entries += ['\t{' + f'''"{name}", LogicalTypeId::{type_id}''' + '}' for name in names] + type_entries += ['\t{' + f'''"{name}", LogicalTypeId::{type_id}''' + '}' for name in names] TYPE_COUNT = len(type_entries) new_text += ''' diff --git a/scripts/generate_csv_header.py b/scripts/generate_csv_header.py index 486aea3cac8c..54554bfffb54 100644 --- a/scripts/generate_csv_header.py +++ b/scripts/generate_csv_header.py @@ -2,32 +2,35 @@ import os from python_helpers import open_utf8 -def get_csv_text(fpath, add_null_terminator = False): - with open(fpath, 'rb') as f: - text = bytearray(f.read()) - result_text = "" - first = True - for byte in text: - if first: - result_text += str(byte) - else: - result_text += ", " + str(byte) - first = False - if add_null_terminator: - result_text += ", 0" - return result_text + +def get_csv_text(fpath, add_null_terminator=False): + with open(fpath, 'rb') as f: + text = bytearray(f.read()) + result_text = "" + first = True + for byte in text: + if first: + result_text += str(byte) + else: + result_text += ", " + str(byte) + first = False + if add_null_terminator: + result_text += ", 0" + return result_text + def write_dir(dirname, varname): - files = os.listdir(dirname) - files.sort() - result = "" - aggregated_result = "const char *%s[] = {\n" % (varname,) - for fname in files: - file_varname = "%s_%s" % (varname,fname.split('.')[0]) - result += "const uint8_t %s[] = {" % (file_varname,) + get_csv_text(os.path.join(dirname, fname), True) + "};\n" - aggregated_result += "\t(const char*) %s,\n" % (file_varname,) - aggregated_result = aggregated_result[:-2] + "\n};\n" - return result + aggregated_result + files = os.listdir(dirname) + files.sort() + result = "" + aggregated_result = "const char *%s[] = {\n" % (varname,) + for fname in files: + file_varname = "%s_%s" % (varname, fname.split('.')[0]) + result += "const uint8_t %s[] = {" % (file_varname,) + get_csv_text(os.path.join(dirname, fname), True) + "};\n" + aggregated_result += "\t(const char*) %s,\n" % (file_varname,) + aggregated_result = aggregated_result[:-2] + 
"\n};\n" + return result + aggregated_result + # ------------------------------------------- # # ------------------------------------------- # @@ -41,21 +44,23 @@ def write_dir(dirname, varname): tpch_answers_sf1 = os.path.join(tpch_dir, 'answers', 'sf1') tpch_header = os.path.join(tpch_dir, 'include', 'tpch_constants.hpp') + def create_tpch_header(tpch_dir): - result = """/* THIS FILE WAS AUTOMATICALLY GENERATED BY generate_csv_header.py */ + result = """/* THIS FILE WAS AUTOMATICALLY GENERATED BY generate_csv_header.py */ #pragma once const int TPCH_QUERIES_COUNT = 22; """ - # write the queries - result += write_dir(tpch_queries, "TPCH_QUERIES") - result += write_dir(tpch_answers_sf001, "TPCH_ANSWERS_SF0_01") - result += write_dir(tpch_answers_sf01, "TPCH_ANSWERS_SF0_1") - result += write_dir(tpch_answers_sf1, "TPCH_ANSWERS_SF1") + # write the queries + result += write_dir(tpch_queries, "TPCH_QUERIES") + result += write_dir(tpch_answers_sf001, "TPCH_ANSWERS_SF0_01") + result += write_dir(tpch_answers_sf01, "TPCH_ANSWERS_SF0_1") + result += write_dir(tpch_answers_sf1, "TPCH_ANSWERS_SF1") + + with open_utf8(tpch_header, 'w+') as f: + f.write(result) - with open_utf8(tpch_header, 'w+') as f: - f.write(result) print(tpch_header) create_tpch_header(tpch_dir) @@ -71,21 +76,23 @@ def create_tpch_header(tpch_dir): tpcds_answers_sf1 = os.path.join(tpcds_dir, 'answers', 'sf1') tpcds_header = os.path.join(tpcds_dir, 'include', 'tpcds_constants.hpp') + def create_tpcds_header(tpch_dir): - result = """/* THIS FILE WAS AUTOMATICALLY GENERATED BY generate_csv_header.py */ + result = """/* THIS FILE WAS AUTOMATICALLY GENERATED BY generate_csv_header.py */ #pragma once const int TPCDS_QUERIES_COUNT = 99; const int TPCDS_TABLE_COUNT = 24; """ - # write the queries - result += write_dir(tpcds_queries, "TPCDS_QUERIES") - result += write_dir(tpcds_answers_sf001, "TPCDS_ANSWERS_SF0_01") - result += write_dir(tpcds_answers_sf1, "TPCDS_ANSWERS_SF1") + # write the queries + result += write_dir(tpcds_queries, "TPCDS_QUERIES") + result += write_dir(tpcds_answers_sf001, "TPCDS_ANSWERS_SF0_01") + result += write_dir(tpcds_answers_sf1, "TPCDS_ANSWERS_SF1") + + with open_utf8(tpcds_header, 'w+') as f: + f.write(result) - with open_utf8(tpcds_header, 'w+') as f: - f.write(result) print(tpcds_header) create_tpcds_header(tpcds_dir) diff --git a/scripts/generate_enum_util.py b/scripts/generate_enum_util.py index 74d181afee0f..10324cc9bc82 100644 --- a/scripts/generate_enum_util.py +++ b/scripts/generate_enum_util.py @@ -18,26 +18,20 @@ "SQLNULL": "NULL", "TIMESTAMP_TZ": "TIMESTAMP WITH TIME ZONE", "TIME_TZ": "TIME WITH TIME ZONE", - "TIMESTAMP_SEC": "TIMESTAMP_S", - }, - "JoinType": { - "OUTER": "FULL" + "TIMESTAMP_SEC": "TIMESTAMP_S", }, + "JoinType": {"OUTER": "FULL"}, "OrderType": { "ORDER_DEFAULT": ["ORDER_DEFAULT", "DEFAULT"], "DESCENDING": ["DESCENDING", "DESC"], - "ASCENDING": ["ASCENDING", "ASC"] + "ASCENDING": ["ASCENDING", "ASC"], }, "OrderByNullType": { "ORDER_DEFAULT": ["ORDER_DEFAULT", "DEFAULT"], "NULLS_FIRST": ["NULLS_FIRST", "NULLS FIRST"], - "NULLS_LAST": ["NULLS_LAST", "NULLS LAST"] + "NULLS_LAST": ["NULLS_LAST", "NULLS LAST"], }, - "SampleMethod": { - "SYSTEM_SAMPLE": "System", - "BERNOULLI_SAMPLE": "Bernoulli", - "RESERVOIR_SAMPLE": "Reservoir" - } + "SampleMethod": {"SYSTEM_SAMPLE": "System", "BERNOULLI_SAMPLE": "Bernoulli", "RESERVOIR_SAMPLE": "Reservoir"}, } @@ -52,11 +46,13 @@ if file.endswith(".hpp"): hpp_files.append(os.path.join(root, file)) + def remove_prefix(str, prefix): if 
str.startswith(prefix):
-        return str[len(prefix):]
+        return str[len(prefix) :]
     return str
 
+
 # get all the enum classes
 enums = []
 enum_paths = []
@@ -110,7 +106,7 @@ def remove_prefix(str, prefix):
         if not file_path in enum_path_set:
             enum_path_set.add(file_path)
             enum_paths.append(file_path)
-    
+
         enums.append((enum_name, enum_type, enum_members))
 
 enum_paths.sort()
@@ -128,16 +124,16 @@ def remove_prefix(str, prefix):
 
 # Write the enum util header
 with open(enum_util_header_file, "w") as f:
-
     f.write(header)
 
     f.write('#pragma once\n\n')
     f.write('#include <stdint.h>\n')
     f.write('#include "duckdb/common/string.hpp"\n\n')
-    
+
     f.write("namespace duckdb {\n\n")
-    
-    f.write("""struct EnumUtil {
+
+    f.write(
+        """struct EnumUtil {
     // String -> Enum
     template <class T>
     static T FromString(const char *value) = delete;
@@ -151,18 +147,19 @@ def remove_prefix(str, prefix):
     template <class T>
     static string ToString(T value) {
         return string(ToChars<T>(value));
     }
-};\n\n""")
+};\n\n"""
+    )
 
     # Forward declare all enums
     for enum_name, enum_type, _ in enums:
         f.write(f"enum class {enum_name} : {enum_type};\n\n")
     f.write("\n")
-    
+
     # Forward declare all enum serialization functions
     for enum_name, enum_type, _ in enums:
         f.write(f"template<>\nconst char* EnumUtil::ToChars<{enum_name}>({enum_name} value);\n\n")
     f.write("\n")
-    
+
     # Forward declare all enum deserialization functions
     for enum_name, enum_type, _ in enums:
         f.write(f"template<>\n{enum_name} EnumUtil::FromString<{enum_name}>(const char *value);\n\n")
@@ -190,7 +187,9 @@ def remove_prefix(str, prefix):
         for key, strings in enum_members:
             # Always use the first string as the enum string
             f.write(f"\tcase {enum_name}::{key}:\n\t\treturn \"{strings[0]}\";\n")
-        f.write('\tdefault:\n\t\tthrow NotImplementedException(StringUtil::Format("Enum value: \'%d\' not implemented", value));\n')
+        f.write(
+            '\tdefault:\n\t\tthrow NotImplementedException(StringUtil::Format("Enum value: \'%d\' not implemented", value));\n'
+        )
         f.write("\t}\n")
     f.write("}\n\n")
diff --git a/scripts/generate_extensions_function.py b/scripts/generate_extensions_function.py
index 55083edb154a..b456b383389a 100644
--- a/scripts/generate_extensions_function.py
+++ b/scripts/generate_extensions_function.py
@@ -8,8 +8,11 @@
 
 parser = argparse.ArgumentParser(description='Generates/Validates extension_functions.hpp file')
 
-parser.add_argument('--validate', action=argparse.BooleanOptionalAction,
-                    help='If set will validate that extension_entries.hpp is up to date, otherwise it generates the extension_functions.hpp file.')
+parser.add_argument(
+    '--validate',
+    action=argparse.BooleanOptionalAction,
+    help='If set will validate that extension_entries.hpp is up to date, otherwise it generates the extension_functions.hpp file.',
+)
 
 args = parser.parse_args()
 
@@ -17,42 +20,44 @@
 stored_functions = {
     'substrait': ["from_substrait", "get_substrait", "get_substrait_json", "from_substrait_json"],
     'arrow': ["scan_arrow_ipc", "to_arrow_ipc"],
-    'spatial': []
-}
-stored_settings = {
-    'substrait': [],
-    'arrow': [],
-    'spatial': []
+    'spatial': [],
 }
+stored_settings = {'substrait': [], 'arrow': [], 'spatial': []}
 
 functions = {}
 
+
 # Parses the extension config files for which extension names there are to be expected
 def parse_extension_txt():
-    extensions_file = os.path.join("..", "build","extension_configuration","extensions.txt")
+    extensions_file = os.path.join("..", "build", "extension_configuration", "extensions.txt")
    with open(extensions_file) as f:
         return [line.rstrip() for line in f]
 
+
 extension_names = parse_extension_txt()
 
 # Add exception 
for jemalloc as it doesn't produce a loadable extension but is in the config if "jemalloc" in extension_names: extension_names.remove("jemalloc") -ext_hpp = os.path.join("..", "src","include","duckdb", "main", "extension_entries.hpp") +ext_hpp = os.path.join("..", "src", "include", "duckdb", "main", "extension_entries.hpp") get_functions_query = "select distinct function_name from duckdb_functions();" get_settings_query = "select distinct name from duckdb_settings();" -duckdb_path = os.path.join("..",'build', 'release', 'duckdb') +duckdb_path = os.path.join("..", 'build', 'release', 'duckdb') + def get_query(sql_query, load_query): return os.popen(f'{duckdb_path} -csv -unsigned -c "{load_query}{sql_query}" ').read().split("\n")[1:-1] -def get_functions(load = ""): + +def get_functions(load=""): return set(get_query(get_functions_query, load)) -def get_settings(load = ""): + +def get_settings(load=""): return set(get_query(get_settings_query, load)) + base_functions = get_functions() base_settings = get_settings() @@ -64,16 +69,21 @@ def get_settings(load = ""): for filename in glob.iglob('/tmp/' + '**/*.duckdb_extension', recursive=True): extension_path[os.path.splitext(os.path.basename(filename))[0]] = filename + def update_extensions(extension_name, function_list, settings_list): global function_map, settings_map - function_map.update({ - extension_function.lower(): extension_name.lower() - for extension_function in (set(function_list) - base_functions) - }) - settings_map.update({ - extension_setting.lower(): extension_name.lower() - for extension_setting in (set(settings_list) - base_settings) - }) + function_map.update( + { + extension_function.lower(): extension_name.lower() + for extension_function in (set(function_list) - base_functions) + } + ) + settings_map.update( + { + extension_setting.lower(): extension_name.lower() + for extension_setting in (set(settings_list) - base_settings) + } + ) for extension_name in extension_names: @@ -94,7 +104,7 @@ def update_extensions(extension_name, function_list, settings_list): update_extensions(extension_name, extension_functions, extension_settings) if args.validate: - file = open(ext_hpp,'r') + file = open(ext_hpp, 'r') pattern = re.compile("{\"(.*?)\", \"(.*?)\"}[,}\n]") cur_function_map = dict(pattern.findall(file.read())) function_map.update(settings_map) @@ -118,7 +128,7 @@ def update_extensions(extension_name, function_list, settings_list): exit(1) else: # extension_functions - file = open(ext_hpp,'w') + file = open(ext_hpp, 'w') header = """//===----------------------------------------------------------------------===// // DuckDB // diff --git a/scripts/generate_flex.py b/scripts/generate_flex.py index 8714033f726d..2fa6eff1ce1d 100644 --- a/scripts/generate_flex.py +++ b/scripts/generate_flex.py @@ -13,42 +13,49 @@ namespace = 'duckdb_libpgquery' for arg in sys.argv[1:]: - if arg.startswith("--flex="): - flex_bin = arg.replace("--flex=", "") - elif arg.startswith("--custom_dir_prefix"): - pg_path = arg.split("=")[1] + pg_path - elif arg.startswith("--namespace"): - namespace = arg.split("=")[1] - else: - raise Exception("Unrecognized argument: " + arg + ", expected --flex, --custom_dir_prefix, --namespace") + if arg.startswith("--flex="): + flex_bin = arg.replace("--flex=", "") + elif arg.startswith("--custom_dir_prefix"): + pg_path = arg.split("=")[1] + pg_path + elif arg.startswith("--namespace"): + namespace = arg.split("=")[1] + else: + raise Exception("Unrecognized argument: " + arg + ", expected --flex, --custom_dir_prefix, 
--namespace") flex_file_path = os.path.join(pg_path, 'scan.l') target_file = os.path.join(pg_path, 'src_backend_parser_scan.cpp') -proc = subprocess.Popen([flex_bin, '--nounistd', '-o', target_file, flex_file_path], stdout=subprocess.PIPE, stderr=subprocess.PIPE) +proc = subprocess.Popen( + [flex_bin, '--nounistd', '-o', target_file, flex_file_path], stdout=subprocess.PIPE, stderr=subprocess.PIPE +) stdout = proc.stdout.read().decode('utf8') stderr = proc.stderr.read().decode('utf8') if proc.returncode != None or len(stderr) > 0: - print("Flex failed") - print("stdout: ", stdout) - print("stderr: ", stderr) - exit(1) + print("Flex failed") + print("stdout: ", stdout) + print("stderr: ", stderr) + exit(1) with open_utf8(target_file, 'r') as f: - text = f.read() + text = f.read() # convert this from 'int' to 'yy_size_t' to avoid triggering a warning text = text.replace('int yy_buf_size;\n', 'yy_size_t yy_buf_size;\n') # add the libpg_query namespace -text = text.replace(''' +text = text.replace( + ''' #ifndef FLEXINT_H #define FLEXINT_H -''', ''' +''', + ''' #ifndef FLEXINT_H #define FLEXINT_H -namespace ''' + namespace + ''' { -''') +namespace ''' + + namespace + + ''' { +''', +) text = text.replace('register ', '') text = text + "\n} /* " + namespace + " */\n" @@ -60,23 +67,27 @@ file_null = 'NULL' if platform == 'linux' else '[(]FILE [*][)] 0' -text = re.sub(rf'[#]ifdef\s*YY_STDINIT\n\s*yyin = stdin;\n\s*yyout = stdout;\n[#]else\n\s*yyin = {file_null};\n\s*yyout = {file_null};\n[#]endif', ' yyin = (FILE *) 0;\n yyout = (FILE *) 0;', text) +text = re.sub( + rf'[#]ifdef\s*YY_STDINIT\n\s*yyin = stdin;\n\s*yyout = stdout;\n[#]else\n\s*yyin = {file_null};\n\s*yyout = {file_null};\n[#]endif', + ' yyin = (FILE *) 0;\n yyout = (FILE *) 0;', + text, +) if 'stdin;' in text: - print("STDIN not removed!") - # exit(1) + print("STDIN not removed!") + # exit(1) if 'stdout' in text: - print("STDOUT not removed!") - # exit(1) + print("STDOUT not removed!") + # exit(1) if 'fprintf(' in text: - print("PRINTF not removed!") - # exit(1) + print("PRINTF not removed!") + # exit(1) if 'exit(' in text: - print("EXIT not removed!") - # exit(1) + print("EXIT not removed!") + # exit(1) with open_utf8(target_file, 'w+') as f: - f.write(text) + f.write(text) diff --git a/scripts/generate_functions.py b/scripts/generate_functions.py index bba40e32a829..d2dd8cfa3cc8 100644 --- a/scripts/generate_functions.py +++ b/scripts/generate_functions.py @@ -3,7 +3,22 @@ import json aggregate_functions = ['algebraic', 'distributive', 'holistic', 'nested', 'regression'] -scalar_functions = ['bit', 'blob', 'date', 'enum', 'generic', 'list', 'map', 'math', 'operators', 'random', 'string', 'debug', 'struct', 'union'] +scalar_functions = [ + 'bit', + 'blob', + 'date', + 'enum', + 'generic', + 'list', + 'map', + 'math', + 'operators', + 'random', + 'string', + 'debug', + 'struct', + 'union', +] header = '''//===----------------------------------------------------------------------===// // DuckDB @@ -27,18 +42,23 @@ footer = '''} // namespace duckdb ''' + def normalize_path_separators(x): return os.path.sep.join(x.split('/')) + def legal_struct_name(name): return name.isalnum() + def get_struct_name(function_name): return function_name.replace('_', ' ').title().replace(' ', '') + 'Fun' + def sanitize_string(text): return text.replace('\\', '\\\\').replace('"', '\\"') + all_function_types = [] all_function_types += [f'aggregate/{x}' for x in aggregate_functions] all_function_types += [f'scalar/{x}' for x in scalar_functions] @@ -81,7 
+101,8 @@ def sanitize_string(text): if 'extra_functions' in entry: for func_text in entry['extra_functions']: function_text += '\n ' + func_text - new_text += '''struct {STRUCT} { + new_text += ( + '''struct {STRUCT} { static constexpr const char *Name = "{NAME}"; static constexpr const char *Parameters = "{PARAMETERS}"; static constexpr const char *Description = "{DESCRIPTION}"; @@ -90,7 +111,15 @@ def sanitize_string(text): {FUNCTION} }; -'''.replace('{STRUCT}', struct_name).replace('{NAME}', entry['name']).replace('{PARAMETERS}', entry['parameters'] if 'parameters' in entry else '').replace('{DESCRIPTION}', sanitize_string(entry['description'])).replace('{EXAMPLE}', sanitize_string(entry['example'])).replace('{FUNCTION}', function_text) +'''.replace( + '{STRUCT}', struct_name + ) + .replace('{NAME}', entry['name']) + .replace('{PARAMETERS}', entry['parameters'] if 'parameters' in entry else '') + .replace('{DESCRIPTION}', sanitize_string(entry['description'])) + .replace('{EXAMPLE}', sanitize_string(entry['example'])) + .replace('{FUNCTION}', function_text) + ) alias_count = 1 if 'aliases' in entry: for alias in entry['aliases']: @@ -114,13 +143,19 @@ def sanitize_string(text): print("Unknown entry type " + aliased_type + ' for entry ' + struct_name) exit(1) function_type_set[alias_struct_name] = aliased_type - new_text += '''struct {STRUCT} { + new_text += ( + '''struct {STRUCT} { using ALIAS = {ALIAS}; static constexpr const char *Name = "{NAME}"; }; -'''.replace('{STRUCT}', alias_struct_name).replace('{NAME}', alias).replace('{ALIAS}', struct_name) +'''.replace( + '{STRUCT}', alias_struct_name + ) + .replace('{NAME}', alias) + .replace('{ALIAS}', struct_name) + ) new_text += footer with open(header_path, 'w+') as f: f.write(new_text) diff --git a/scripts/generate_grammar.py b/scripts/generate_grammar.py index 18ced9bafacc..0ecce7c3116b 100644 --- a/scripts/generate_grammar.py +++ b/scripts/generate_grammar.py @@ -7,10 +7,10 @@ import sys from python_helpers import open_utf8 -bison_location = "bison" -base_dir = 'third_party/libpg_query/grammar' -pg_dir = 'third_party/libpg_query' -namespace = 'duckdb_libpgquery' +bison_location = "bison" +base_dir = 'third_party/libpg_query/grammar' +pg_dir = 'third_party/libpg_query' +namespace = 'duckdb_libpgquery' counterexamples = False run_update = False @@ -31,25 +31,31 @@ elif arg.startswith("--verbose"): verbose = True else: - raise Exception("Unrecognized argument: " + arg + ", expected --counterexamples, --bison=/loc/to/bison, --custom_dir_prefix, --namespace, --verbose") - -template_file = os.path.join(base_dir, 'grammar.y') -target_file = os.path.join(base_dir, 'grammar.y.tmp') -header_file = os.path.join(base_dir, 'grammar.hpp') -source_file = os.path.join(base_dir, 'grammar.cpp') -type_dir = os.path.join(base_dir, 'types') -rule_dir = os.path.join(base_dir, 'statements') -result_source = os.path.join(base_dir, 'grammar_out.cpp') -result_header = os.path.join(base_dir, 'grammar_out.hpp') -target_source_loc = os.path.join(pg_dir, 'src_backend_parser_gram.cpp') -target_header_loc = os.path.join(pg_dir, 'include/parser/gram.hpp') -kwlist_header = os.path.join(pg_dir, 'include/parser/kwlist.hpp') + raise Exception( + "Unrecognized argument: " + + arg + + ", expected --counterexamples, --bison=/loc/to/bison, --custom_dir_prefix, --namespace, --verbose" + ) + +template_file = os.path.join(base_dir, 'grammar.y') +target_file = os.path.join(base_dir, 'grammar.y.tmp') +header_file = os.path.join(base_dir, 'grammar.hpp') +source_file = 
os.path.join(base_dir, 'grammar.cpp') +type_dir = os.path.join(base_dir, 'types') +rule_dir = os.path.join(base_dir, 'statements') +result_source = os.path.join(base_dir, 'grammar_out.cpp') +result_header = os.path.join(base_dir, 'grammar_out.hpp') +target_source_loc = os.path.join(pg_dir, 'src_backend_parser_gram.cpp') +target_header_loc = os.path.join(pg_dir, 'include/parser/gram.hpp') +kwlist_header = os.path.join(pg_dir, 'include/parser/kwlist.hpp') + # parse the keyword lists def read_list_from_file(fname): with open_utf8(fname, 'r') as f: return [x.strip() for x in f.read().split('\n') if len(x.strip()) > 0] + kwdir = os.path.join(base_dir, 'keywords') unreserved_keywords = read_list_from_file(os.path.join(kwdir, 'unreserved_keywords.list')) colname_keywords = read_list_from_file(os.path.join(kwdir, 'column_name_keywords.list')) @@ -57,12 +63,14 @@ def read_list_from_file(fname): type_name_keywords = read_list_from_file(os.path.join(kwdir, 'type_name_keywords.list')) reserved_keywords = read_list_from_file(os.path.join(kwdir, 'reserved_keywords.list')) + def strip_p(x): if x.endswith("_P"): return x[:-2] else: return x + unreserved_keywords.sort(key=lambda x: strip_p(x)) colname_keywords.sort(key=lambda x: strip_p(x)) func_name_keywords.sort(key=lambda x: strip_p(x)) @@ -97,20 +105,28 @@ def strip_p(x): # now generate kwlist.h # PG_KEYWORD("abort", ABORT_P, UNRESERVED_KEYWORD) -kwtext = """ -namespace """ + namespace + """ { +kwtext = ( + """ +namespace """ + + namespace + + """ { #define PG_KEYWORD(a,b,c) {a,b,c}, const PGScanKeyword ScanKeywords[] = { """ +) for tpl in kwlist: kwtext += 'PG_KEYWORD("%s", %s, %s)\n' % (strip_p(tpl[0]).lower(), tpl[0], tpl[1]) -kwtext += """ +kwtext += ( + """ }; const int NumScanKeywords = lengthof(ScanKeywords); -} // namespace """ + namespace + """ +} // namespace """ + + namespace + + """ """ +) with open_utf8(kwlist_header, 'w+') as f: f.write(kwtext) @@ -123,6 +139,7 @@ def strip_p(x): # now perform a series of replacements in the file to construct the final yacc file + def get_file_contents(fpath, add_line_numbers=False): with open_utf8(fpath, 'r') as f: result = f.read() @@ -171,6 +188,7 @@ def get_file_contents(fpath, add_line_numbers=False): exit(1) unreserved_dict[ur] = True + def add_to_other_keywords(kw, list_name): global unreserved_dict global reserved_dict @@ -183,6 +201,7 @@ def add_to_other_keywords(kw, list_name): exit(1) other_dict[kw] = True + for cr in colname_keywords: add_to_other_keywords(cr, "colname") @@ -213,6 +232,7 @@ def add_to_other_keywords(kw, list_name): kw_definitions += "reserved_keyword: " + " | ".join(reserved_keywords) + "\n" text = text.replace("{{{ KEYWORD_DEFINITIONS }}}", kw_definitions) + # types def concat_dir(dname, extension, add_line_numbers=False): result = "" @@ -226,6 +246,7 @@ def concat_dir(dname, extension, add_line_numbers=False): result += get_file_contents(fpath, add_line_numbers) return result + type_definitions = concat_dir(type_dir, ".yh") # add statement types as well for stmt in statements: @@ -254,7 +275,7 @@ def concat_dir(dname, extension, add_line_numbers=False): cmd += ["-o", result_source, "-d", target_file] print(' '.join(cmd)) proc = subprocess.Popen(cmd, stderr=subprocess.PIPE) -res = proc.wait(timeout=10) # ensure CI does not hang as was seen when running with Bison 3.x release. +res = proc.wait(timeout=10) # ensure CI does not hang as was seen when running with Bison 3.x release. 
if res != 0: text = proc.stderr.read().decode('utf8') @@ -266,7 +287,9 @@ def concat_dir(dname, extension, add_line_numbers=False): print("On a Macbook you can obtain this using \"brew install bison\"") if counterexamples and 'time limit exceeded' in text: print("---------------------------------------------------------------------") - print("The counterexamples time limit was exceeded. This likely means that no useful counterexample was generated.") + print( + "The counterexamples time limit was exceeded. This likely means that no useful counterexample was generated." + ) print("") print("The counterexamples time limit can be increased by setting the TIME_LIMIT environment variable, e.g.:") print("export TIME_LIMIT=100") @@ -283,4 +306,4 @@ def concat_dir(dname, extension, add_line_numbers=False): text = text.replace('yynerrs = 0;', 'yynerrs = 0; (void)yynerrs;') with open_utf8(target_source_loc, 'w+') as f: - f.write(text) \ No newline at end of file + f.write(text) diff --git a/scripts/generate_plan_storage_version.py b/scripts/generate_plan_storage_version.py index fa4a21b3d933..561fb1583a75 100644 --- a/scripts/generate_plan_storage_version.py +++ b/scripts/generate_plan_storage_version.py @@ -10,26 +10,30 @@ shell_proc = os.path.join('build', 'debug', 'test', 'unittest') gen_binary_file = os.path.join('test', 'api', 'serialized_plans', 'serialized_plans.binary') + def try_remove_file(fname): - try: - os.remove(fname) - except: - pass + try: + os.remove(fname) + except: + pass + try_remove_file(gen_binary_file) + def run_test(test): - print(test) - env = os.environ.copy() - env["GEN_PLAN_STORAGE"] = "1" - res = subprocess.run([shell_proc, test ], capture_output=True, env = env) - stdout = res.stdout.decode('utf8').strip() - stderr = res.stderr.decode('utf8').strip() - if res.returncode != 0: - print("Failed to create binary file!") - print("----STDOUT----") - print(stdout) - print("----STDERR----") - print(stderr) + print(test) + env = os.environ.copy() + env["GEN_PLAN_STORAGE"] = "1" + res = subprocess.run([shell_proc, test], capture_output=True, env=env) + stdout = res.stdout.decode('utf8').strip() + stderr = res.stderr.decode('utf8').strip() + if res.returncode != 0: + print("Failed to create binary file!") + print("----STDOUT----") + print(stdout) + print("----STDERR----") + print(stderr) + run_test("Generate serialized plans file") diff --git a/scripts/generate_querygraph.py b/scripts/generate_querygraph.py index 5b685f9ae0ef..9d62443f66af 100644 --- a/scripts/generate_querygraph.py +++ b/scripts/generate_querygraph.py @@ -12,33 +12,33 @@ arguments = sys.argv if len(arguments) <= 1: - print("Usage: python generate_querygraph.py [input.json] [output.html] [open={1,0}]") - exit(1) + print("Usage: python generate_querygraph.py [input.json] [output.html] [open={1,0}]") + exit(1) input = arguments[1] if len(arguments) <= 2: - if ".json" in input: - output = input.replace(".json", ".html") - else: - output = input + ".html" + if ".json" in input: + output = input.replace(".json", ".html") + else: + output = input + ".html" else: - output = arguments[2] + output = arguments[2] open_output = True if len(arguments) >= 4: - open_arg = arguments[3].lower().replace('open=', '') - if open_arg == "1" or open_arg == "true": - open_output = True - elif open_arg == "0" or open_arg == "false": - open_output = False - else: - print("Incorrect input for open_output, expected TRUE or FALSE") - exit(1) + open_arg = arguments[3].lower().replace('open=', '') + if open_arg == "1" or open_arg == "true": + 
open_output = True
+    elif open_arg == "0" or open_arg == "false":
+        open_output = False
+    else:
+        print("Incorrect input for open_output, expected TRUE or FALSE")
+        exit(1)
 
 duckdb_query_graph.generate(input, output)
 
 with open(output, 'r') as f:
-	text = f.read()
+    text = f.read()
 
 if open_output:
-	os.system('open "' + output.replace('"', '\\"') + '"')
+    os.system('open "' + output.replace('"', '\\"') + '"')
diff --git a/scripts/generate_serialization.py b/scripts/generate_serialization.py
index c73758c29647..56cab8a07315 100644
--- a/scripts/generate_serialization.py
+++ b/scripts/generate_serialization.py
@@ -12,7 +12,7 @@
     file_list.append(
         {
             'source': os.path.join(source_base, fname),
-            'target': os.path.join(target_base, 'serialize_' + fname.replace('.json', '.cpp'))
+            'target': os.path.join(target_base, 'serialize_' + fname.replace('.json', '.cpp')),
         }
     )
 
@@ -61,50 +61,70 @@
 
 switch_header = '\tcase ${ENUM_TYPE}::${ENUM_VALUE}:\n'
 
-switch_statement = switch_header + '''\t\tresult = ${CLASS_DESERIALIZE}::FormatDeserialize(deserializer);
+switch_statement = (
+    switch_header
+    + '''\t\tresult = ${CLASS_DESERIALIZE}::FormatDeserialize(deserializer);
 \t\tbreak;
 '''
+)
 
 deserialize_element = '\tauto ${PROPERTY_NAME} = deserializer.ReadProperty<${PROPERTY_TYPE}>("${PROPERTY_KEY}");\n'
 deserialize_element_class = '\tdeserializer.ReadProperty("${PROPERTY_KEY}", result${ASSIGNMENT}${PROPERTY_NAME});\n'
 deserialize_element_class_base = '\tauto ${PROPERTY_NAME} = deserializer.ReadProperty<unique_ptr<${BASE_PROPERTY}>>("${PROPERTY_KEY}");\n\tresult${ASSIGNMENT}${PROPERTY_NAME} = unique_ptr_cast<${BASE_PROPERTY}, ${DERIVED_PROPERTY}>(std::move(${PROPERTY_NAME}));\n'
 
-move_list = [
-    'string', 'ParsedExpression*', 'CommonTableExpressionMap', 'LogicalType', 'ColumnDefinition'
-]
+move_list = ['string', 'ParsedExpression*', 'CommonTableExpressionMap', 'LogicalType', 'ColumnDefinition']
+
+reference_list = ['ClientContext', 'bound_parameter_map_t']
 
-reference_list = [
-    'ClientContext', 'bound_parameter_map_t'
-]
+
 def is_container(type):
     return '<' in type
 
+
 def is_pointer(type):
     return type.endswith('*') or type.startswith('shared_ptr<')
 
+
 def requires_move(type):
     return is_container(type) or is_pointer(type) or type in move_list
 
+
 def replace_pointer(type):
     return re.sub('([a-zA-Z0-9]+)[*]', 'unique_ptr<\\1>', type)
 
+
 def get_serialize_element(property_name, property_key, property_type, is_optional, pointer_type):
     write_method = 'WriteProperty'
     assignment = '.' if pointer_type == 'none' else '->'
     if is_optional:
         write_method = 'WriteOptionalProperty'
-    return serialize_element.replace('${PROPERTY_NAME}', property_name).replace('${PROPERTY_KEY}', property_key).replace('WriteProperty', write_method).replace('${ASSIGNMENT}', assignment)
+    return (
+        serialize_element.replace('${PROPERTY_NAME}', property_name)
+        .replace('${PROPERTY_KEY}', property_key)
+        .replace('WriteProperty', write_method)
+        .replace('${ASSIGNMENT}', assignment)
+    )
+
 
 def get_deserialize_element_template(template, property_name, property_key, property_type, is_optional, pointer_type):
     read_method = 'ReadProperty'
     assignment = '.' 
if pointer_type == 'none' else '->' if is_optional: read_method = 'ReadOptionalProperty' - return template.replace('${PROPERTY_NAME}', property_name).replace('${PROPERTY_KEY}', property_key).replace('ReadProperty', read_method).replace('${PROPERTY_TYPE}', property_type).replace('${ASSIGNMENT}', assignment) + return ( + template.replace('${PROPERTY_NAME}', property_name) + .replace('${PROPERTY_KEY}', property_key) + .replace('ReadProperty', read_method) + .replace('${PROPERTY_TYPE}', property_type) + .replace('${ASSIGNMENT}', assignment) + ) + def get_deserialize_element(property_name, property_key, property_type, is_optional, pointer_type): - return get_deserialize_element_template(deserialize_element, property_name, property_key, property_type, is_optional, pointer_type) + return get_deserialize_element_template( + deserialize_element, property_name, property_key, property_type, is_optional, pointer_type + ) + def get_deserialize_assignment(property_name, property_type, pointer_type): assignment = '.' if pointer_type == 'none' else '->' @@ -113,27 +133,38 @@ def get_deserialize_assignment(property_name, property_type, pointer_type): property = f'std::move({property})' return f'\tresult{assignment}{property_name} = {property};\n' + def get_return_value(pointer_type, class_name): if pointer_type == 'none': return class_name return pointer_return.replace('${POINTER}', pointer_type).replace('${CLASS_NAME}', class_name) + def generate_constructor(pointer_type, class_name, constructor_parameters): if pointer_type == 'none': params = '' if len(constructor_parameters) == 0 else '(' + constructor_parameters + ')' return f'\t{class_name} result{params};\n' return f'\tauto result = duckdb::{pointer_type}<{class_name}>(new {class_name}({constructor_parameters}));\n' + def generate_return(class_entry): if class_entry.base is None: return '\treturn result;' else: return '\treturn std::move(result);' + supported_member_entries = [ - 'name', 'type', 'property', 'serialize_property', 'deserialize_property', 'optional', 'base' + 'name', + 'type', + 'property', + 'serialize_property', + 'deserialize_property', + 'optional', + 'base', ] + class MemberVariable: def __init__(self, entry): self.name = entry['name'] @@ -156,13 +187,27 @@ def __init__(self, entry): self.base = entry['base'] for key in entry.keys(): if key not in supported_member_entries: - print(f"Unsupported key \"{key}\" in member variable, key should be in set {str(supported_member_entries)}") + print( + f"Unsupported key \"{key}\" in member variable, key should be in set {str(supported_member_entries)}" + ) + supported_serialize_entries = [ - 'class', 'class_type', 'pointer_type', 'base', 'enum', 'constructor', 'custom_implementation', 'custom_switch_code', 'members', 'return_type', 'set_parameters', - 'includes' + 'class', + 'class_type', + 'pointer_type', + 'base', + 'enum', + 'constructor', + 'custom_implementation', + 'custom_switch_code', + 'members', + 'return_type', + 'set_parameters', + 'includes', ] + class SerializableClass: def __init__(self, entry): self.name = entry['class'] @@ -215,7 +260,9 @@ def __init__(self, entry): raise Exception(f'Set parameter {set_parameter_name} not found in member list') for key in entry.keys(): if key not in supported_serialize_entries: - print(f"Unsupported key \"{key}\" in member variable, key should be in set {str(supported_serialize_entries)}") + print( + f"Unsupported key \"{key}\" in member variable, key should be in set {str(supported_serialize_entries)}" + ) def inherit(self, base_class): 
self.base_object = base_class @@ -233,14 +280,20 @@ def generate_base_class_code(base_class): if entry.serialize_property == base_class.enum_value: enum_type = entry.type is_optional = entry.optional - base_class_serialize += get_serialize_element(entry.serialize_property, entry.name, type_name, is_optional, base_class.pointer_type) - base_class_deserialize += get_deserialize_element(entry.deserialize_property, entry.name, type_name, is_optional, base_class.pointer_type) + base_class_serialize += get_serialize_element( + entry.serialize_property, entry.name, type_name, is_optional, base_class.pointer_type + ) + base_class_deserialize += get_deserialize_element( + entry.deserialize_property, entry.name, type_name, is_optional, base_class.pointer_type + ) expressions = [x for x in base_class.children.items()] expressions = sorted(expressions, key=lambda x: x[0]) # set parameters for entry in base_class.set_parameters: - base_class_deserialize += set_deserialize_parameter.replace('${PROPERTY_TYPE}', entry.type).replace('${PROPERTY_NAME}', entry.name) + base_class_deserialize += set_deserialize_parameter.replace('${PROPERTY_TYPE}', entry.type).replace( + '${PROPERTY_NAME}', entry.name + ) base_class_deserialize += f'\t{base_class.pointer_type}<{base_class.name}> result;\n' switch_cases = '' @@ -248,11 +301,21 @@ def generate_base_class_code(base_class): enum_value = expr[0] child_data = expr[1] if child_data.custom_switch_code is not None: - switch_cases += switch_header.replace('${ENUM_TYPE}', enum_type).replace('${ENUM_VALUE}', enum_value).replace('${CLASS_DESERIALIZE}', child_data.name) - switch_cases += '\n'.join(['\t\t' + x for x in child_data.custom_switch_code.replace('\\n', '\n').split('\n')]) + switch_cases += ( + switch_header.replace('${ENUM_TYPE}', enum_type) + .replace('${ENUM_VALUE}', enum_value) + .replace('${CLASS_DESERIALIZE}', child_data.name) + ) + switch_cases += '\n'.join( + ['\t\t' + x for x in child_data.custom_switch_code.replace('\\n', '\n').split('\n')] + ) switch_cases += '\n' continue - switch_cases += switch_statement.replace('${ENUM_TYPE}', enum_type).replace('${ENUM_VALUE}', enum_value).replace('${CLASS_DESERIALIZE}', child_data.name) + switch_cases += ( + switch_statement.replace('${ENUM_TYPE}', enum_type) + .replace('${ENUM_VALUE}', enum_value) + .replace('${CLASS_DESERIALIZE}', child_data.name) + ) assign_entries = [] for entry in base_class.members: @@ -267,7 +330,11 @@ def generate_base_class_code(base_class): assign_entries.append(entry) # class switch statement - base_class_deserialize += switch_code.replace('${SWITCH_VARIABLE}', base_class.enum_value).replace('${CASE_STATEMENTS}', switch_cases).replace('${BASE_CLASS}', base_class.name) + base_class_deserialize += ( + switch_code.replace('${SWITCH_VARIABLE}', base_class.enum_value) + .replace('${CASE_STATEMENTS}', switch_cases) + .replace('${BASE_CLASS}', base_class.name) + ) deserialize_return = get_return_value(base_class.pointer_type, base_class.return_type) @@ -279,16 +346,24 @@ def generate_base_class_code(base_class): if entry.type in move_list or is_container(entry.type) or is_pointer(entry.type): move = True if move: - base_class_deserialize+= f'\tresult->{entry.deserialize_property} = std::move({entry.deserialize_property});\n' + base_class_deserialize += ( + f'\tresult->{entry.deserialize_property} = std::move({entry.deserialize_property});\n' + ) else: - base_class_deserialize+= f'\tresult->{entry.deserialize_property} = {entry.deserialize_property};\n' + base_class_deserialize += 
f'\tresult->{entry.deserialize_property} = {entry.deserialize_property};\n' base_class_deserialize += generate_return(base_class) base_class_generation = '' serialization = '' if base_class.base is not None: serialization += base_serialize.replace('${BASE_CLASS_NAME}', base_class.base) - base_class_generation += serialize_base.replace('${CLASS_NAME}', base_class.name).replace('${MEMBERS}', serialization + base_class_serialize) - base_class_generation += deserialize_base.replace('${DESERIALIZE_RETURN}', deserialize_return).replace('${CLASS_NAME}', base_class.name).replace('${MEMBERS}', base_class_deserialize) + base_class_generation += serialize_base.replace('${CLASS_NAME}', base_class.name).replace( + '${MEMBERS}', serialization + base_class_serialize + ) + base_class_generation += ( + deserialize_base.replace('${DESERIALIZE_RETURN}', deserialize_return) + .replace('${CLASS_NAME}', base_class.name) + .replace('${MEMBERS}', base_class_deserialize) + ) return base_class_generation @@ -353,7 +428,9 @@ def generate_class_code(class_entry): type_name = replace_pointer(entry.type) class_deserialize += get_deserialize_element(entry.name, entry.name, type_name, entry.optional, 'unique_ptr') - class_deserialize += generate_constructor(class_entry.pointer_type, class_entry.return_class, constructor_parameters) + class_deserialize += generate_constructor( + class_entry.pointer_type, class_entry.return_class, constructor_parameters + ) if class_entry.members is None: return None for entry_idx in range(len(class_entry.members)): @@ -365,15 +442,28 @@ def generate_class_code(class_entry): if not is_optional: write_property_name = '*' + entry.serialize_property elif is_optional: - raise Exception(f"Optional can only be combined with pointers (in {class_entry.name}, type {entry.type}, member {entry.type})") + raise Exception( + f"Optional can only be combined with pointers (in {class_entry.name}, type {entry.type}, member {entry.type})" + ) deserialize_template_str = deserialize_element_class if entry.base: write_property_name = f"({entry.base} &)" + write_property_name - deserialize_template_str = deserialize_element_class_base.replace('${BASE_PROPERTY}', entry.base.replace('*', '')).replace('${DERIVED_PROPERTY}', entry.type.replace('*', '')) + deserialize_template_str = deserialize_element_class_base.replace( + '${BASE_PROPERTY}', entry.base.replace('*', '') + ).replace('${DERIVED_PROPERTY}', entry.type.replace('*', '')) type_name = replace_pointer(entry.type) - class_serialize += get_serialize_element(write_property_name, property_key, type_name, is_optional, class_entry.pointer_type) + class_serialize += get_serialize_element( + write_property_name, property_key, type_name, is_optional, class_entry.pointer_type + ) if entry_idx > last_constructor_index: - class_deserialize += get_deserialize_element_template(deserialize_template_str, entry.deserialize_property, property_key, type_name, is_optional, class_entry.pointer_type) + class_deserialize += get_deserialize_element_template( + deserialize_template_str, + entry.deserialize_property, + property_key, + type_name, + is_optional, + class_entry.pointer_type, + ) elif entry.name not in constructor_entries: class_deserialize += get_deserialize_assignment(entry.name, entry.type, class_entry.pointer_type) @@ -382,18 +472,24 @@ def generate_class_code(class_entry): class_generation = '' class_generation += serialize_base.replace('${CLASS_NAME}', class_entry.name).replace('${MEMBERS}', class_serialize) - class_generation += 
deserialize_base.replace('${DESERIALIZE_RETURN}', deserialize_return).replace('${CLASS_NAME}', class_entry.name).replace('${MEMBERS}', class_deserialize) + class_generation += ( + deserialize_base.replace('${DESERIALIZE_RETURN}', deserialize_return) + .replace('${CLASS_NAME}', class_entry.name) + .replace('${MEMBERS}', class_deserialize) + ) return class_generation - for entry in file_list: source_path = entry['source'] target_path = entry['target'] with open(source_path, 'r') as f: json_data = json.load(f) - include_list = ['duckdb/common/serializer/format_serializer.hpp', 'duckdb/common/serializer/format_deserializer.hpp'] + include_list = [ + 'duckdb/common/serializer/format_serializer.hpp', + 'duckdb/common/serializer/format_deserializer.hpp', + ] base_classes = [] classes = [] base_class_data = {} @@ -422,7 +518,9 @@ def generate_class_code(class_entry): base_class_object.children[enum_entry] = new_class with open(target_path, 'w+') as f: - f.write(header.replace('${INCLUDE_LIST}', ''.join([include_base.replace('${FILENAME}', x) for x in include_list]))) + f.write( + header.replace('${INCLUDE_LIST}', ''.join([include_base.replace('${FILENAME}', x) for x in include_list])) + ) # generate the base class serialization for base_class in base_classes: @@ -430,7 +528,7 @@ def generate_class_code(class_entry): f.write(base_class_generation) # generate the class serialization - classes = sorted(classes, key = lambda x: x.name) + classes = sorted(classes, key=lambda x: x.name) for class_entry in classes: class_generation = generate_class_code(class_entry) if class_generation is None: diff --git a/scripts/generate_storage_version.py b/scripts/generate_storage_version.py index c963cece3851..b54983ee91d0 100644 --- a/scripts/generate_storage_version.py +++ b/scripts/generate_storage_version.py @@ -10,29 +10,37 @@ gen_storage_script = os.path.join('test', 'sql', 'storage_version', 'generate_storage_version.sql') gen_storage_target = os.path.join('test', 'sql', 'storage_version', 'storage_version.db') + def try_remove_file(fname): - try: - os.remove(fname) - except: - pass + try: + os.remove(fname) + except: + pass + try_remove_file(gen_storage_target) try_remove_file(gen_storage_target + '.wal') + def run_command_in_shell(cmd): - print(cmd) - res = subprocess.run([shell_proc, '--batch', '-init', '/dev/null', gen_storage_target], capture_output=True, input=bytearray(cmd, 'utf8')) - stdout = res.stdout.decode('utf8').strip() - stderr = res.stderr.decode('utf8').strip() - if res.returncode != 0: - print("Failed to create database file!") - print("----STDOUT----") - print(stdout) - print("----STDERR----") - print(stderr) + print(cmd) + res = subprocess.run( + [shell_proc, '--batch', '-init', '/dev/null', gen_storage_target], + capture_output=True, + input=bytearray(cmd, 'utf8'), + ) + stdout = res.stdout.decode('utf8').strip() + stderr = res.stderr.decode('utf8').strip() + if res.returncode != 0: + print("Failed to create database file!") + print("----STDOUT----") + print(stdout) + print("----STDERR----") + print(stderr) + with open_utf8(gen_storage_script, 'r') as f: - cmd = f.read() + cmd = f.read() run_command_in_shell(cmd) run_command_in_shell('select * from integral_values') diff --git a/scripts/generate_tpcds_results.py b/scripts/generate_tpcds_results.py index 79c34dbe7fc0..3bdc88244406 100644 --- a/scripts/generate_tpcds_results.py +++ b/scripts/generate_tpcds_results.py @@ -8,20 +8,42 @@ import multiprocessing.pool parser = argparse.ArgumentParser(description='Generate TPC-DS reference 
results from Postgres.') -parser.add_argument('--sf', dest='sf', - action='store', help='The TPC-DS scale factor reference results to generate', default=1) -parser.add_argument('--query-dir', dest='query_dir', - action='store', help='The directory with queries to run', default='extension/tpcds/dsdgen/queries') -parser.add_argument('--answer-dir', dest='answer_dir', - action='store', help='The directory where to store the answers', default='extension/tpcds/dsdgen/answers/sf${SF}') -parser.add_argument('--duckdb-path', dest='duckdb_path', - action='store', help='The path to the DuckDB executable', default='build/reldebug/duckdb') -parser.add_argument('--skip-load', dest='skip_load', - action='store_const', const=True, help='Whether or not to skip loading', default=False) -parser.add_argument('--query-list', dest='query_list', - action='store', help='The list of queries to run (default = all)', default='') -parser.add_argument('--nthreads', dest='nthreads', - action='store', type=int, help='The number of threads', default=0) +parser.add_argument( + '--sf', dest='sf', action='store', help='The TPC-DS scale factor reference results to generate', default=1 +) +parser.add_argument( + '--query-dir', + dest='query_dir', + action='store', + help='The directory with queries to run', + default='extension/tpcds/dsdgen/queries', +) +parser.add_argument( + '--answer-dir', + dest='answer_dir', + action='store', + help='The directory where to store the answers', + default='extension/tpcds/dsdgen/answers/sf${SF}', +) +parser.add_argument( + '--duckdb-path', + dest='duckdb_path', + action='store', + help='The path to the DuckDB executable', + default='build/reldebug/duckdb', +) +parser.add_argument( + '--skip-load', + dest='skip_load', + action='store_const', + const=True, + help='Whether or not to skip loading', + default=False, +) +parser.add_argument( + '--query-list', dest='query_list', action='store', help='The list of queries to run (default = all)', default='' +) +parser.add_argument('--nthreads', dest='nthreads', action='store', type=int, help='The number of threads', default=0) args = parser.parse_args() @@ -40,7 +62,33 @@ exit(1) # drop the previous tables - tables = ['name','web_site','web_sales','web_returns','web_page','warehouse','time_dim','store_sales','store_returns','store','ship_mode','reason','promotion','item','inventory','income_band','household_demographics','date_dim','customer_demographics','customer_address','customer','catalog_sales','catalog_returns','catalog_page','call_center'] + tables = [ + 'name', + 'web_site', + 'web_sales', + 'web_returns', + 'web_page', + 'warehouse', + 'time_dim', + 'store_sales', + 'store_returns', + 'store', + 'ship_mode', + 'reason', + 'promotion', + 'item', + 'inventory', + 'income_band', + 'household_demographics', + 'date_dim', + 'customer_demographics', + 'customer_address', + 'customer', + 'catalog_sales', + 'catalog_returns', + 'catalog_page', + 'call_center', + ] for table in tables: c.execute(f'DROP TABLE IF EXISTS {table};') @@ -69,6 +117,7 @@ queries = [x for x in queries if x in passing_queries] queries.sort() + def run_query(q): print(q) with open(os.path.join(args.query_dir, q), 'r') as f: @@ -78,6 +127,7 @@ def run_query(q): c.execute(f'CREATE TABLE "query_result{q}" AS ' + sql_query) c.execute(f"COPY \"query_result{q}\" TO '{answer_path}' (FORMAT CSV, DELIMITER '|', HEADER, NULL 'NULL')") + if args.nthreads == 0: for q in queries: run_query(q) diff --git a/scripts/generate_tpcds_schema.py b/scripts/generate_tpcds_schema.py index 
205c4f5006de..9c2f66db286c 100644 --- a/scripts/generate_tpcds_schema.py +++ b/scripts/generate_tpcds_schema.py @@ -39,20 +39,23 @@ select concat('const char *', '$STRUCT_NAME', '::PrimaryKeyColumns[] = {', STRING_AGG('"' || name || '"', ', ') || '};') from pragma_table_info('$NAME') where pk=true; ''' + def run_query(sql): - input_sql = initcode + '\n' + sql - res = subprocess.run(duckdb_program, input=input_sql.encode('utf8'), stdout=subprocess.PIPE, stderr=subprocess.PIPE) + input_sql = initcode + '\n' + sql + res = subprocess.run(duckdb_program, input=input_sql.encode('utf8'), stdout=subprocess.PIPE, stderr=subprocess.PIPE) + + stdout = res.stdout.decode('utf8').strip() + stderr = res.stderr.decode('utf8').strip() + if res.returncode != 0: + print("FAILED TO RUN QUERY") + print(stderr) + exit(1) + return stdout - stdout = res.stdout.decode('utf8').strip() - stderr = res.stderr.decode('utf8').strip() - if res.returncode != 0: - print("FAILED TO RUN QUERY") - print(stderr) - exit(1) - return stdout def prepare_query(sql, table_name, struct_name): - return sql.replace('$NAME', table_name).replace('$STRUCT_NAME', struct_name) + return sql.replace('$NAME', table_name).replace('$STRUCT_NAME', struct_name) + header = ''' #pragma once @@ -86,20 +89,28 @@ def prepare_query(sql, table_name, struct_name): table_list = run_query('show tables') for table_name in table_list.split('\n'): - table_name = table_name.strip() - print(''' + table_name = table_name.strip() + print( + ''' //===--------------------------------------------------------------------===// // $NAME -//===--------------------------------------------------------------------===//'''.replace('$NAME', table_name)) - struct_name = str(table_name.title().replace('_', '')) + 'Info' - column_count = int(run_query(prepare_query(column_count_query, table_name, struct_name)).strip()) - pk_column_count = int(run_query(prepare_query(pk_column_count_query, table_name, struct_name)).strip()) - print(prepare_query(struct_def, table_name, struct_name).replace('$COLUMN_COUNT', str(column_count)).replace('$PK_COLUMN_COUNT', str(pk_column_count))) - - print(run_query(prepare_query(gen_names, table_name, struct_name)).replace('""', '"').strip('"')) - print("") - print(run_query(prepare_query(gen_types, table_name, struct_name)).strip('"')) - print("") - print(run_query(prepare_query(pk_columns, table_name, struct_name)).replace('""', '"').strip('"')) - -print(footer) \ No newline at end of file +//===--------------------------------------------------------------------===//'''.replace( + '$NAME', table_name + ) + ) + struct_name = str(table_name.title().replace('_', '')) + 'Info' + column_count = int(run_query(prepare_query(column_count_query, table_name, struct_name)).strip()) + pk_column_count = int(run_query(prepare_query(pk_column_count_query, table_name, struct_name)).strip()) + print( + prepare_query(struct_def, table_name, struct_name) + .replace('$COLUMN_COUNT', str(column_count)) + .replace('$PK_COLUMN_COUNT', str(pk_column_count)) + ) + + print(run_query(prepare_query(gen_names, table_name, struct_name)).replace('""', '"').strip('"')) + print("") + print(run_query(prepare_query(gen_types, table_name, struct_name)).strip('"')) + print("") + print(run_query(prepare_query(pk_columns, table_name, struct_name)).replace('""', '"').strip('"')) + +print(footer) diff --git a/scripts/generate_vector_sizes.py b/scripts/generate_vector_sizes.py index aa2b5eb0b93b..8b9e88db76c3 100644 --- a/scripts/generate_vector_sizes.py +++ 
b/scripts/generate_vector_sizes.py @@ -2,18 +2,18 @@ result = "" for i in range(len(supported_vector_sizes)): - vsize = supported_vector_sizes[i] - if i == 0: - result += "#if" - else: - result += "#elif" - result += " STANDARD_VECTOR_SIZE == " + str(vsize) + "\n" - result += "const sel_t FlatVector::incremental_vector[] = {" - for idx in range(vsize): - if idx != 0: - result += ", " - result += str(idx) - result += "};\n" + vsize = supported_vector_sizes[i] + if i == 0: + result += "#if" + else: + result += "#elif" + result += " STANDARD_VECTOR_SIZE == " + str(vsize) + "\n" + result += "const sel_t FlatVector::incremental_vector[] = {" + for idx in range(vsize): + if idx != 0: + result += ", " + result += str(idx) + result += "};\n" result += """#else #error Unsupported VECTOR_SIZE! diff --git a/scripts/gentpcecode.py b/scripts/gentpcecode.py index ea41f71d1319..1de2f86248e2 100644 --- a/scripts/gentpcecode.py +++ b/scripts/gentpcecode.py @@ -19,16 +19,19 @@ source = open_utf8(GENERATED_SOURCE, 'w+') for fp in [header, source]: - fp.write(""" + fp.write( + """ //////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////// // THIS FILE IS GENERATED BY gentpcecode.py, DO NOT EDIT MANUALLY // //////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////// -""") +""" + ) -header.write(""" +header.write( + """ #include "duckdb/catalog/catalog.hpp" #include "duckdb/main/appender.hpp" #include "duckdb/main/connection.hpp" @@ -92,9 +95,11 @@ class DuckDBLoaderFactory : public CBaseLoaderFactory { virtual CBaseLoader *CreateZipCodeLoader(); }; -""") +""" +) -source.write(""" +source.write( + """ #include "tpce_generated.hpp" using namespace duckdb; @@ -155,58 +160,62 @@ class DuckDBLoaderFactory : public CBaseLoaderFactory { } }; -""") +""" +) with open(os.path.join(TPCE_DIR, 'include/main/TableRows.h'), 'r') as f: - for line in f: - line = line.strip() - if line.startswith('typedef struct '): - line = line.replace('typedef struct ', '') - current_table = line.split(' ')[0].replace('_ROW', ' ').replace('_', ' ').lower().strip() - tables[current_table] = [] - elif line.startswith('}'): - current_table = None - elif current_table != None: -#row -#get type - splits = line.strip().split(' ') - if len(splits) < 2: - continue - line = splits[0] - name = splits[1].split(';')[0].split('[')[0].lower() - is_single_char = False - if 'TIdent' in line or 'INT64' in line or 'TTrade' in line: - tpe = "TypeId::BIGINT" - sqltpe = "BIGINT" - elif 'double' in line or 'float' in line: - tpe = "TypeId::DECIMAL" - sqltpe = "DECIMAL" - elif 'int' in line: - tpe = "TypeId::INTEGER" - sqltpe = "INTEGER" - elif 'CDateTime' in line: - tpe = "TypeId::TIMESTAMP" - sqltpe = "TIMESTAMP" - elif 'bool' in line: - tpe = 'TypeId::BOOLEAN' - sqltpe = "BOOLEAN" - elif 'char' in line: - if '[' not in splits[1]: - is_single_char = True - tpe = "TypeId::VARCHAR" - sqltpe = "VARCHAR" - else: - continue - tables[current_table].append([name, tpe, is_single_char, sqltpe]) + for line in f: + line = line.strip() + if line.startswith('typedef struct '): + line = line.replace('typedef struct ', '') + current_table = line.split(' ')[0].replace('_ROW', ' ').replace('_', ' ').lower().strip() + tables[current_table] = [] + elif line.startswith('}'): + current_table = None + elif current_table != None: + # row + # get type + splits = line.strip().split(' ') + if len(splits) < 2: + continue + line = 
splits[0] + name = splits[1].split(';')[0].split('[')[0].lower() + is_single_char = False + if 'TIdent' in line or 'INT64' in line or 'TTrade' in line: + tpe = "TypeId::BIGINT" + sqltpe = "BIGINT" + elif 'double' in line or 'float' in line: + tpe = "TypeId::DECIMAL" + sqltpe = "DECIMAL" + elif 'int' in line: + tpe = "TypeId::INTEGER" + sqltpe = "INTEGER" + elif 'CDateTime' in line: + tpe = "TypeId::TIMESTAMP" + sqltpe = "TIMESTAMP" + elif 'bool' in line: + tpe = 'TypeId::BOOLEAN' + sqltpe = "BOOLEAN" + elif 'char' in line: + if '[' not in splits[1]: + is_single_char = True + tpe = "TypeId::VARCHAR" + sqltpe = "VARCHAR" + else: + continue + tables[current_table].append([name, tpe, is_single_char, sqltpe]) + def get_tablename(name): - name = name.title().replace(' ', '') - if name == 'NewsXref': - return 'NewsXRef' - return name + name = name.title().replace(' ', '') + if name == 'NewsXref': + return 'NewsXRef' + return name + for table in tables.keys(): - source.write(""" + source.write( + """ class DuckDB${TABLENAME}Load : public DuckDBBaseLoader<${ROW_TYPE}> { public: DuckDB${TABLENAME}Load(Connection &con, string schema, string table) : @@ -215,48 +224,61 @@ class DuckDB${TABLENAME}Load : public DuckDBBaseLoader<${ROW_TYPE}> { } void WriteNextRecord(const ${ROW_TYPE} &next_record) { - info.appender.BeginRow();""".replace("${TABLENAME}", get_tablename(table)).replace("${ROW_TYPE}", table.upper().replace(' ', '_') + '_ROW')); - source.write("\n") - collist = tables[table] - for i in range(len(collist)): - entry = collist[i] - name = entry[0].upper() - tpe = entry[1] - if tpe == "TypeId::BIGINT": - funcname = "bigint" - elif tpe == "TypeId::DECIMAL": - funcname = "double" - elif tpe == "TypeId::INTEGER": - funcname = "value" - elif tpe == "TypeId::TIMESTAMP": - funcname = "timestamp" - elif tpe == 'TypeId::BOOLEAN': - funcname = "bool" - elif tpe == "TypeId::VARCHAR": - if entry[2]: - funcname = "char" - else: - funcname = "string" - else: - print("Unknown type " + tpe) - exit(1) - source.write("\t\tappend_%s(info, next_record.%s);" % (funcname, name)) - if i != len(collist) - 1: - source.write("\n") - source.write(""" + info.appender.BeginRow();""".replace( + "${TABLENAME}", get_tablename(table) + ).replace( + "${ROW_TYPE}", table.upper().replace(' ', '_') + '_ROW' + ) + ) + source.write("\n") + collist = tables[table] + for i in range(len(collist)): + entry = collist[i] + name = entry[0].upper() + tpe = entry[1] + if tpe == "TypeId::BIGINT": + funcname = "bigint" + elif tpe == "TypeId::DECIMAL": + funcname = "double" + elif tpe == "TypeId::INTEGER": + funcname = "value" + elif tpe == "TypeId::TIMESTAMP": + funcname = "timestamp" + elif tpe == 'TypeId::BOOLEAN': + funcname = "bool" + elif tpe == "TypeId::VARCHAR": + if entry[2]: + funcname = "char" + else: + funcname = "string" + else: + print("Unknown type " + tpe) + exit(1) + source.write("\t\tappend_%s(info, next_record.%s);" % (funcname, name)) + if i != len(collist) - 1: + source.write("\n") + source.write( + """ info.appender.EndRow(); } -};""") +};""" + ) for table in tables.keys(): - source.write(""" + source.write( + """ CBaseLoader<${ROW_TYPE}> * DuckDBLoaderFactory::Create${TABLENAME}Loader() { return new DuckDB${TABLENAME}Load(con, schema, "${TABLEINDB}" + suffix); } -""".replace("${TABLENAME}", get_tablename(table)).replace("${ROW_TYPE}", table.upper().replace(' ', '_') + '_ROW').replace("${TABLEINDB}", table.replace(' ', '_'))) +""".replace( + "${TABLENAME}", get_tablename(table) + ) + .replace("${ROW_TYPE}", 
table.upper().replace(' ', '_') + '_ROW') + .replace("${TABLEINDB}", table.replace(' ', '_')) + ) source.write("\n") @@ -269,20 +291,20 @@ class DuckDB${TABLENAME}Load : public DuckDBBaseLoader<${ROW_TYPE}> { for table in tables.keys(): - tname = table.replace(' ', '_') - str = 'static string ' + table.title().replace(' ', '') + 'Schema(string schema, string suffix) {\n' - str += '\treturn "CREATE TABLE " + schema + ".%s" + suffix + " ("\n' % (tname,) - columns = tables[table] - for i in range(len(columns)): - column = columns[i] - str += '\t "' + column[0] + " " + column[3] - if i == len(columns) - 1: - str += ')";' - else: - str += ',"' - str += "\n" - str += "}\n\n" - source.write(str) + tname = table.replace(' ', '_') + str = 'static string ' + table.title().replace(' ', '') + 'Schema(string schema, string suffix) {\n' + str += '\treturn "CREATE TABLE " + schema + ".%s" + suffix + " ("\n' % (tname,) + columns = tables[table] + for i in range(len(columns)): + column = columns[i] + str += '\t "' + column[0] + " " + column[3] + if i == len(columns) - 1: + str += ')";' + else: + str += ',"' + str += "\n" + str += "}\n\n" + source.write(str) func = 'void CreateTPCESchema(duckdb::DuckDB &db, duckdb::Connection &con, std::string &schema, std::string &suffix)' @@ -293,15 +315,13 @@ class DuckDB${TABLENAME}Load : public DuckDBBaseLoader<${ROW_TYPE}> { # con.Query(RegionSchema(schema, suffix)); for table in tables.keys(): - tname = table.replace(' ', '_') - source.write('\tcon.Query(%sSchema(schema, suffix));\n' % - (table.title().replace(' ', ''))) + tname = table.replace(' ', '_') + source.write('\tcon.Query(%sSchema(schema, suffix));\n' % (table.title().replace(' ', ''))) source.write('}\n\n') - for fp in [header, source]: - fp.write("} /* namespace TPCE */\n") - fp.close() + fp.write("} /* namespace TPCE */\n") + fp.close() diff --git a/scripts/get_test_list.py b/scripts/get_test_list.py index 946abf4addf8..ac42fb2faa5d 100644 --- a/scripts/get_test_list.py +++ b/scripts/get_test_list.py @@ -5,12 +5,21 @@ import os parser = argparse.ArgumentParser(description='Print a list of tests to run.') -parser.add_argument('--file-contains', dest='file_contains', - action='store', help='Filter based on a string contained in the text', default=None) -parser.add_argument('--unittest', dest='unittest', - action='store', help='The path to the unittest program', default='build/release/test/unittest') -parser.add_argument('--list', dest='filter', - action='store', help='The unittest filter to apply', default='') +parser.add_argument( + '--file-contains', + dest='file_contains', + action='store', + help='Filter based on a string contained in the text', + default=None, +) +parser.add_argument( + '--unittest', + dest='unittest', + action='store', + help='The path to the unittest program', + default='build/release/test/unittest', +) +parser.add_argument('--list', dest='filter', action='store', help='The unittest filter to apply', default='') args = parser.parse_args() @@ -23,25 +32,25 @@ stdout = proc.stdout.read().decode('utf8') stderr = proc.stderr.read().decode('utf8') if proc.returncode is not None and proc.returncode != 0: - print("Failed to run program " + unittest_program) - print(proc.returncode) - print(stdout) - print(stderr) - exit(1) + print("Failed to run program " + unittest_program) + print(proc.returncode) + print(stdout) + print(stderr) + exit(1) test_cases = [] for line in stdout.splitlines()[1:]: - if not line.strip(): - continue - splits = line.rsplit('\t', 1) - if file_contains is not None: 
- if not os.path.isfile(splits[0]): - continue - try: - with open(splits[0], 'r') as f: - text = f.read() - except UnicodeDecodeError: - continue - if file_contains not in text: - continue - print(splits[0]) + if not line.strip(): + continue + splits = line.rsplit('\t', 1) + if file_contains is not None: + if not os.path.isfile(splits[0]): + continue + try: + with open(splits[0], 'r') as f: + text = f.read() + except UnicodeDecodeError: + continue + if file_contains not in text: + continue + print(splits[0]) diff --git a/scripts/include_analyzer.py b/scripts/include_analyzer.py index f2cda96140f4..8a94adfb0199 100644 --- a/scripts/include_analyzer.py +++ b/scripts/include_analyzer.py @@ -9,7 +9,8 @@ include_chains = {} cached_includes = {} -def analyze_include_file(fpath, already_included_files, prev_include = ""): + +def analyze_include_file(fpath, already_included_files, prev_include=""): if fpath in already_included_files: return if fpath in amalgamation.always_excluded: @@ -40,6 +41,7 @@ def analyze_include_file(fpath, already_included_files, prev_include = ""): for include in includes: analyze_include_file(include, already_included_files, prev_include) + def analyze_includes(dir): files = os.listdir(dir) files.sort() @@ -52,6 +54,7 @@ def analyze_includes(dir): elif fname.endswith('.cpp') or fname.endswith('.c') or fname.endswith('.cc'): analyze_include_file(fpath, []) + for compile_dir in amalgamation.compile_directories: analyze_includes(compile_dir) @@ -59,7 +62,7 @@ def analyze_includes(dir): for entry in include_counts.keys(): kws.append([entry, include_counts[entry]]) -kws.sort(key = lambda tup: -tup[1]) +kws.sort(key=lambda tup: -tup[1]) for k in range(0, len(kws)): include_file = kws[k][0] include_count = kws[k][1] @@ -70,7 +73,6 @@ def analyze_includes(dir): chainkws = [] for chain in include_chains[include_file]: chainkws.append([chain, include_chains[include_file][chain]]) - chainkws.sort(key = lambda tup: -tup[1]) + chainkws.sort(key=lambda tup: -tup[1]) for l in range(0, min(5, len(chainkws))): print(chainkws[l]) - diff --git a/scripts/jdbc_maven_deploy.py b/scripts/jdbc_maven_deploy.py index e6db6a5d5a85..b54e8567334b 100644 --- a/scripts/jdbc_maven_deploy.py +++ b/scripts/jdbc_maven_deploy.py @@ -2,7 +2,7 @@ # https://issues.sonatype.org/browse/OSSRH-58179 # this is the pgp key we use to sign releases -# if this key should be lost, generate a new one with `gpg --full-generate-key` +# if this key should be lost, generate a new one with `gpg --full-generate-key` # AND upload to keyserver: `gpg --keyserver hkp://keys.openpgp.org --send-keys [...]` # export the keys for GitHub Actions like so: `gpg --export-secret-keys | base64` # -------------------------------- @@ -20,38 +20,40 @@ import zipfile import re + def exec(cmd): - print(cmd) - res = subprocess.run(cmd.split(' '), capture_output=True) - if res.returncode == 0: - return res.stdout - raise ValueError(res.stdout + res.stderr) + print(cmd) + res = subprocess.run(cmd.split(' '), capture_output=True) + if res.returncode == 0: + return res.stdout + raise ValueError(res.stdout + res.stderr) + if len(sys.argv) < 4 or not os.path.isdir(sys.argv[2]) or not os.path.isdir(sys.argv[3]): - print("Usage: [release_tag, format: v1.2.3] [artifact_dir] [jdbc_root_path]") - exit(1) + print("Usage: [release_tag, format: v1.2.3] [artifact_dir] [jdbc_root_path]") + exit(1) version_regex = re.compile(r'^v((\d+)\.(\d+)\.\d+)$') release_tag = sys.argv[1] deploy_url = 'https://oss.sonatype.org/service/local/staging/deploy/maven2/' 
is_release = True -if (release_tag == 'master'): - # for SNAPSHOT builds we increment the minor version and set patch level to zero. - # seemed the most sensible - last_tag = exec('git tag --sort=-committerdate').decode('utf8').split('\n')[0] - re_result = version_regex.search(last_tag) - if re_result is None: - raise ValueError("Could not parse last tag %s" % last_tag) - release_version = "%d.%d.0-SNAPSHOT" % (int(re_result.group(2)), int(re_result.group(3)) + 1) - # orssh uses a different deploy url for snapshots yay - deploy_url = 'https://oss.sonatype.org/content/repositories/snapshots/' - is_release = False +if release_tag == 'master': + # for SNAPSHOT builds we increment the minor version and set patch level to zero. + # seemed the most sensible + last_tag = exec('git tag --sort=-committerdate').decode('utf8').split('\n')[0] + re_result = version_regex.search(last_tag) + if re_result is None: + raise ValueError("Could not parse last tag %s" % last_tag) + release_version = "%d.%d.0-SNAPSHOT" % (int(re_result.group(2)), int(re_result.group(3)) + 1) + # orssh uses a different deploy url for snapshots yay + deploy_url = 'https://oss.sonatype.org/content/repositories/snapshots/' + is_release = False elif version_regex.match(release_tag): - release_version = version_regex.search(release_tag).group(1) + release_version = version_regex.search(release_tag).group(1) else: - print("Not running on %s" % release_tag) - exit(0) + print("Not running on %s" % release_tag) + exit(0) jdbc_artifact_dir = sys.argv[2] jdbc_root_path = sys.argv[3] @@ -131,11 +133,11 @@ def exec(cmd): # fatten up jar to add other binaries, start with first one shutil.copyfile(os.path.join(jdbc_artifact_dir, "java-" + combine_builds[0], "duckdb_jdbc.jar"), binary_jar) for build in combine_builds[1:]: - old_jar = zipfile.ZipFile(os.path.join(jdbc_artifact_dir, "java-" + build, "duckdb_jdbc.jar"), 'r') - for zip_entry in old_jar.namelist(): - if zip_entry.startswith('libduckdb_java.so'): - old_jar.extract(zip_entry, staging_dir) - exec("jar -uf %s -C %s %s" % (binary_jar, staging_dir, zip_entry)) + old_jar = zipfile.ZipFile(os.path.join(jdbc_artifact_dir, "java-" + build, "duckdb_jdbc.jar"), 'r') + for zip_entry in old_jar.namelist(): + if zip_entry.startswith('libduckdb_java.so'): + old_jar.extract(zip_entry, staging_dir) + exec("jar -uf %s -C %s %s" % (binary_jar, staging_dir, zip_entry)) javadoc_stage_dir = tempfile.mkdtemp() @@ -144,8 +146,13 @@ def exec(cmd): exec("jar -cvf %s -C %s/src/main/java org" % (sources_jar, jdbc_root_path)) # make sure all files exist before continuing -if not os.path.exists(javadoc_jar) or not os.path.exists(sources_jar) or not os.path.exists(pom) or not os.path.exists(binary_jar): - raise ValueError('could not create all required files') +if ( + not os.path.exists(javadoc_jar) + or not os.path.exists(sources_jar) + or not os.path.exists(pom) + or not os.path.exists(binary_jar) +): + raise ValueError('could not create all required files') # run basic tests, it should now work on whatever platform this is exec("java -cp %s org.duckdb.test.TestDuckDBJDBC" % binary_jar) @@ -168,11 +175,11 @@ def exec(cmd): results_dir = os.path.join(jdbc_artifact_dir, "results") if not os.path.exists(results_dir): - os.mkdir(results_dir) + os.mkdir(results_dir) for jar in [binary_jar, sources_jar, javadoc_jar]: - shutil.copyfile(jar, os.path.join(results_dir, os.path.basename(jar))) + shutil.copyfile(jar, os.path.join(results_dir, os.path.basename(jar))) print("JARs created, uploading (this can take a while!)") 
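An aside on the jar-fattening step reformatted above: the release script appends each platform's libduckdb_java.so entry to the primary jar by extracting it and shelling out to `jar -uf`. The same merge can be expressed with Python's standard-library zipfile alone; the sketch below is illustrative only (the jar file names are hypothetical) and is not part of the patch:

import shutil
import zipfile

# Hypothetical inputs: a primary jar plus per-platform jars carrying native libraries.
primary_jar = 'duckdb_jdbc.jar'
platform_jars = ['duckdb_jdbc_linux-aarch64.jar', 'duckdb_jdbc_osx.jar']

shutil.copyfile(primary_jar, 'duckdb_jdbc_fat.jar')
# Opening in append mode adds new entries without rewriting the existing ones.
with zipfile.ZipFile('duckdb_jdbc_fat.jar', 'a') as fat_jar:
    for path in platform_jars:
        with zipfile.ZipFile(path, 'r') as platform_jar:
            for entry in platform_jar.namelist():
                # Only the native library entries differ between platform builds.
                if entry.startswith('libduckdb_java.so'):
                    fat_jar.writestr(entry, platform_jar.read(entry))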
deploy_cmd_prefix = 'mvn gpg:sign-and-deploy-file -Durl=%s -DrepositoryId=ossrh' % deploy_url @@ -182,14 +189,14 @@ def exec(cmd): if not is_release: - print("Not a release, not closing repo") - exit(0) + print("Not a release, not closing repo") + exit(0) print("Close/Release steps") # # beautiful os.environ["MAVEN_OPTS"] = '--add-opens=java.base/java.util=ALL-UNNAMED' -#this list has horrid output, lets try to parse. What we want starts with orgduckdb- and then a number +# this list has horrid output, lets try to parse. What we want starts with orgduckdb- and then a number repo_id = re.search(r'(orgduckdb-\d+)', exec("mvn -f %s nexus-staging:rc-list" % (pom)).decode('utf8')).groups()[0] exec("mvn -f %s nexus-staging:rc-close -DstagingRepositoryId=%s" % (pom, repo_id)) exec("mvn -f %s nexus-staging:rc-release -DstagingRepositoryId=%s" % (pom, repo_id)) diff --git a/scripts/merge_vcpkg_deps.py b/scripts/merge_vcpkg_deps.py index e5b6e08cc5a3..6fa25b283bae 100644 --- a/scripts/merge_vcpkg_deps.py +++ b/scripts/merge_vcpkg_deps.py @@ -16,7 +16,7 @@ def prefix_overlay_ports(overlay_ports, path_to_vcpkg_json): def prefix_overlay_port(overlay_port): - vcpkg_prefix_path = path_to_vcpkg_json[0:path_to_vcpkg_json.find("/vcpkg.json")] + vcpkg_prefix_path = path_to_vcpkg_json[0 : path_to_vcpkg_json.find("/vcpkg.json")] return vcpkg_prefix_path + overlay_port return map(prefix_overlay_port, overlay_ports) @@ -57,22 +57,15 @@ def prefix_overlay_port(overlay_port): "description": f"Auto-generated vcpkg.json for combined DuckDB extension build", "builtin-baseline": "501db0f17ef6df184fcdbfbe0f87cde2313b6ab1", "dependencies": final_deduplicated_deps, - "overrides": [ - { - "name": "openssl", - "version": "3.0.8" - } - ] + "overrides": [{"name": "openssl", "version": "3.0.8"}], } if merged_overlay_ports: - data['vcpkg-configuration'] = { - 'overlay-ports': merged_overlay_ports - } + data['vcpkg-configuration'] = {'overlay-ports': merged_overlay_ports} # Print output print("Writing to 'build/extension_configuration/vcpkg.json': ") print(data["dependencies"]) with open('build/extension_configuration/vcpkg.json', 'w', encoding='utf-8') as f: - json.dump(data, f, ensure_ascii=False, indent=4) \ No newline at end of file + json.dump(data, f, ensure_ascii=False, indent=4) diff --git a/scripts/package_build.py b/scripts/package_build.py index 5a964f3132ee..f0d7535ed0d5 100644 --- a/scripts/package_build.py +++ b/scripts/package_build.py @@ -7,6 +7,7 @@ excluded_objects = ['utf8proc_data.cpp'] + def third_party_includes(): includes = [] includes += [os.path.join('third_party', 'fmt', 'include')] @@ -31,6 +32,7 @@ def third_party_includes(): includes += [os.path.join('third_party', 'jaro_winkler', 'details')] return includes + def third_party_sources(): sources = [] sources += [os.path.join('third_party', 'fmt')] @@ -44,6 +46,7 @@ def third_party_sources(): sources += [os.path.join('third_party', 'mbedtls')] return sources + def file_is_lib(fname, libname): libextensions = ['.a', '.lib'] libprefixes = ['', 'lib'] @@ -54,8 +57,10 @@ def file_is_lib(fname, libname): return True return False + def get_libraries(binary_dir, libraries, extensions): result_libs = [] + def find_library_recursive(search_dir, libname): flist = os.listdir(search_dir) for fname in flist: @@ -90,6 +95,7 @@ def find_library(search_dir, libname, result_libs, required=False): return result_libs + def includes(extensions): scripts_dir = os.path.dirname(os.path.abspath(__file__)) # add includes for duckdb and extensions @@ -101,12 +107,15 @@ def 
includes(extensions): includes.append(os.path.join(scripts_dir, '..', 'extension', ext, 'include')) return includes + def include_flags(extensions): return ' ' + ' '.join(['-I' + x for x in includes(extensions)]) + def convert_backslashes(x): return '/'.join(x.split(os.path.sep)) + def get_relative_path(source_dir, target_file): source_dir = convert_backslashes(source_dir) target_file = convert_backslashes(target_file) @@ -116,20 +125,22 @@ def get_relative_path(source_dir, target_file): target_file = target_file.replace(source_dir, "").lstrip('/') return target_file + def git_commit_hash(): if 'SETUPTOOLS_SCM_PRETEND_HASH' in os.environ: return os.environ['SETUPTOOLS_SCM_PRETEND_HASH'] try: - return subprocess.check_output(['git','log','-1','--format=%h']).strip().decode('utf8') + return subprocess.check_output(['git', 'log', '-1', '--format=%h']).strip().decode('utf8') except: return "deadbeeff" + def git_dev_version(): if 'SETUPTOOLS_SCM_PRETEND_VERSION' in os.environ: return os.environ['SETUPTOOLS_SCM_PRETEND_VERSION'] try: - version = subprocess.check_output(['git','describe','--tags','--abbrev=0']).strip().decode('utf8') - long_version = subprocess.check_output(['git','describe','--tags','--long']).strip().decode('utf8') + version = subprocess.check_output(['git', 'describe', '--tags', '--abbrev=0']).strip().decode('utf8') + long_version = subprocess.check_output(['git', 'describe', '--tags', '--long']).strip().decode('utf8') version_splits = version.lstrip('v').split('.') dev_version = long_version.split('-')[1] if int(dev_version) == 0: @@ -142,8 +153,10 @@ def git_dev_version(): except: return "0.0.0" + def include_package(pkg_name, pkg_dir, include_files, include_list, source_list): import amalgamation + original_path = sys.path # append the directory sys.path.append(pkg_dir) @@ -158,7 +171,8 @@ def include_package(pkg_name, pkg_dir, include_files, include_list, source_list) sys.path = original_path -def build_package(target_dir, extensions, linenumbers = False, unity_count = 32, folder_name = 'duckdb'): + +def build_package(target_dir, extensions, linenumbers=False, unity_count=32, folder_name='duckdb'): if not os.path.isdir(target_dir): os.mkdir(target_dir) @@ -270,17 +284,23 @@ def generate_unity_builds(source_list, nsplits, linenumbers): for filename in filenames: scores[filename] = score score += 1 - current_files.sort(key = lambda x: scores[os.path.basename(x)] if os.path.basename(x) in scores else 99999) + current_files.sort( + key=lambda x: scores[os.path.basename(x)] if os.path.basename(x) in scores else 99999 + ) if not unity_build: new_source_files += [os.path.join(folder_name, file) for file in current_files] else: - new_source_files.append(generate_unity_build(current_files, dirname.replace(os.path.sep, '_'), linenumbers)) + new_source_files.append( + generate_unity_build(current_files, dirname.replace(os.path.sep, '_'), linenumbers) + ) return new_source_files original_sources = source_list source_list = generate_unity_builds(source_list, unity_count, linenumbers) os.chdir(prev_wd) - return ([convert_backslashes(x) for x in source_list if not file_is_excluded(x)], - [convert_backslashes(x) for x in include_list], - [convert_backslashes(x) for x in original_sources]) + return ( + [convert_backslashes(x) for x in source_list if not file_is_excluded(x)], + [convert_backslashes(x) for x in include_list], + [convert_backslashes(x) for x in original_sources], + ) diff --git a/scripts/plan_cost_runner.py b/scripts/plan_cost_runner.py index be6454d35c86..7f573816c60f 100644 
--- a/scripts/plan_cost_runner.py +++ b/scripts/plan_cost_runner.py @@ -17,7 +17,9 @@ def print_usage(): - print(f"Expected usage: python3 scripts/{os.path.basename(__file__)} --old=/old/duckdb_cli --new=/new/duckdb_cli --dir=/path/to/benchmark/dir") + print( + f"Expected usage: python3 scripts/{os.path.basename(__file__)} --old=/old/duckdb_cli --new=/new/duckdb_cli --dir=/path/to/benchmark/dir" + ) exit(1) @@ -41,7 +43,9 @@ def parse_args(): def init_db(cli, dbname, benchmark_dir): print(f"INITIALIZING {dbname} ...") - subprocess.run(f"{cli} {dbname} < {benchmark_dir}/init/schema.sql", shell=True, check=True, stdout=subprocess.DEVNULL) + subprocess.run( + f"{cli} {dbname} < {benchmark_dir}/init/schema.sql", shell=True, check=True, stdout=subprocess.DEVNULL + ) subprocess.run(f"{cli} {dbname} < {benchmark_dir}/init/load.sql", shell=True, check=True, stdout=subprocess.DEVNULL) print("INITIALIZATION DONE") @@ -59,7 +63,12 @@ def op_inspect(op): def query_plan_cost(cli, dbname, query): try: - subprocess.run(f"{cli} --readonly {dbname} -c \"{ENABLE_PROFILING};{PROFILE_OUTPUT};{query}\"", shell=True, check=True, capture_output=True) + subprocess.run( + f"{cli} --readonly {dbname} -c \"{ENABLE_PROFILING};{PROFILE_OUTPUT};{query}\"", + shell=True, + check=True, + capture_output=True, + ) except subprocess.CalledProcessError as e: print("-------------------------") print("--------Failure----------") @@ -95,11 +104,13 @@ def print_diffs(diffs): print("Old cost:", old_cost) print("New cost:", new_cost) + def cardinality_is_higher(card_a, card_b): # card_a > card_b? # add 20% threshold before we start caring return card_a > card_b + def main(): old, new, benchmark_dir = parse_args() init_db(old, OLD_DB_NAME, benchmark_dir) @@ -126,7 +137,7 @@ def main(): improvements.append((query_name, old_cost, new_cost)) elif cardinality_is_higher(new_cost, old_cost): regressions.append((query_name, old_cost, new_cost)) - + exit_code = 0 if improvements: print_banner("IMPROVEMENTS DETECTED") @@ -137,7 +148,7 @@ def main(): print_diffs(regressions) if not improvements and not regressions: print_banner("NO DIFFERENCES DETECTED") - + os.remove(OLD_DB_NAME) os.remove(NEW_DB_NAME) os.remove(PROFILE_FILENAME) diff --git a/scripts/pypi_cleanup.py b/scripts/pypi_cleanup.py index 717e87276e02..e17fe791d4db 100644 --- a/scripts/pypi_cleanup.py +++ b/scripts/pypi_cleanup.py @@ -12,8 +12,8 @@ pypi_username = "hfmuehleisen" pypi_password = os.getenv("PYPI_PASSWORD", "") if pypi_password == "": - print(f'need {pypi_username}\' PyPI password in PYPI_PASSWORD env variable') - exit(1) + print(f'need {pypi_username}\' PyPI password in PYPI_PASSWORD env variable') + exit(1) ctx = ssl.create_default_context() ctx.check_hostname = False @@ -22,88 +22,91 @@ url = 'https://pypi.python.org/pypi/duckdb/json' req = urllib.request.urlopen(url, context=ctx) raw_resp = req.read().decode() -resp_json = json.loads(raw_resp) +resp_json = json.loads(raw_resp) last_release = resp_json["info"]["version"] latest_release_v = version.LooseVersion(last_release) latest_prereleases = [] + def parsever(ele): - major = version.LooseVersion('.'.join(ele.split('.')[:3])) - dev = int(ele.split('.')[3].replace('dev','')) - return (major, dev,) + major = version.LooseVersion('.'.join(ele.split('.')[:3])) + dev = int(ele.split('.')[3].replace('dev', '')) + return ( + major, + dev, + ) + # get a list of all pre-releases release_list = resp_json["releases"] for ele in release_list: - if not ".dev" in ele: - continue + if not ".dev" in ele: + continue - (major, 
dev) = parsever(ele) + (major, dev) = parsever(ele) - if major > latest_release_v: - latest_prereleases.append((ele, dev)) + if major > latest_release_v: + latest_prereleases.append((ele, dev)) # sort the pre-releases latest_prereleases = sorted(latest_prereleases, key=lambda x: x[1]) print("List of pre-releases") for prerelease in latest_prereleases: - print(prerelease[0]) + print(prerelease[0]) if len(latest_prereleases) <= retain_count: - print(f"At most {retain_count} pre-releases - nothing to delete") - exit(0) + print(f"At most {retain_count} pre-releases - nothing to delete") + exit(0) -to_delete = latest_prereleases[:len(latest_prereleases) - retain_count] +to_delete = latest_prereleases[: len(latest_prereleases) - retain_count] if len(to_delete) < 1: - raise ValueError("Nothing to delete") + raise ValueError("Nothing to delete") print("List of to-be-deleted releases") for release in to_delete: - print(release[0]) + print(release[0]) to_delete = [x[0] for x in to_delete] print(to_delete) # gah2 cj = http.cookiejar.CookieJar() -opener = urllib.request.build_opener(urllib.request.HTTPCookieProcessor(cj), urllib.request.HTTPSHandler(context=ctx, debuglevel = 0)) +opener = urllib.request.build_opener( + urllib.request.HTTPCookieProcessor(cj), urllib.request.HTTPSHandler(context=ctx, debuglevel=0) +) + def call(url, data=None, headers={}): - return opener.open(urllib.request.Request(url, data, headers)).read().decode() + return opener.open(urllib.request.Request(url, data, headers)).read().decode() csrf_token_re = re.compile(r"name=\"csrf_token\"[^>]+value=\"([^\"]+)\"") + def get_token(url): - return csrf_token_re.search(call(url)).group(1) + return csrf_token_re.search(call(url)).group(1) -login_data = urllib.parse.urlencode({ - "csrf_token" : get_token("https://pypi.org/account/login/"), - "username" : pypi_username, - "password" : pypi_password}).encode() -login_headers = { - "Referer": "https://pypi.org/account/login/"} +login_data = urllib.parse.urlencode( + {"csrf_token": get_token("https://pypi.org/account/login/"), "username": pypi_username, "password": pypi_password} +).encode() +login_headers = {"Referer": "https://pypi.org/account/login/"} # perform login call("https://pypi.org/account/login/", login_data, login_headers) # delete gunk delete_crsf_token = get_token("https://pypi.org/manage/project/duckdb/releases/") -delete_headers = { - "Referer": "https://pypi.org/manage/project/duckdb/releases/"} +delete_headers = {"Referer": "https://pypi.org/manage/project/duckdb/releases/"} for rev in to_delete: - print("Deleting %s" % rev) - - try: - delete_data = urllib.parse.urlencode({ - "confirm_delete_version" : rev, - "csrf_token" : delete_crsf_token - }).encode() - call("https://pypi.org/manage/project/duckdb/release/%s/" % rev, delete_data, delete_headers) - except Exception as e: - print(f"Failed to delete {rev}") - print(e) + print("Deleting %s" % rev) + + try: + delete_data = urllib.parse.urlencode({"confirm_delete_version": rev, "csrf_token": delete_crsf_token}).encode() + call("https://pypi.org/manage/project/duckdb/release/%s/" % rev, delete_data, delete_headers) + except Exception as e: + print(f"Failed to delete {rev}") + print(e) diff --git a/scripts/python_helpers.py b/scripts/python_helpers.py index eb605f833f22..6c0489932674 100644 --- a/scripts/python_helpers.py +++ b/scripts/python_helpers.py @@ -1,13 +1,15 @@ def open_utf8(fpath, flags): import sys + if sys.version_info[0] < 3: return open(fpath, flags) else: return open(fpath, flags, encoding="utf8") + def 
normalize_path(path): import os - + def normalize(p): return os.path.sep.join(p.split('/')) @@ -17,7 +19,5 @@ def normalize(p): if (isinstance, str): return normalize(path) - - raise Exception("Can only be called with a str or list argument") - \ No newline at end of file + raise Exception("Can only be called with a str or list argument") diff --git a/scripts/reduce_sql.py b/scripts/reduce_sql.py index f81669a73391..e266fd1d1aa0 100644 --- a/scripts/reduce_sql.py +++ b/scripts/reduce_sql.py @@ -11,8 +11,9 @@ SELECT * FROM reduce_sql_statement('${QUERY}'); ''' + def sanitize_error(err): - err = re.sub('Error: near line \d+: ', '', err) + err = re.sub(r'Error: near line \d+: ', '', err) err = err.replace(os.getcwd() + '/', '') err = err.replace(os.getcwd(), '') if 'AddressSanitizer' in err: @@ -20,6 +21,7 @@ def sanitize_error(err): err = 'AddressSanitizer error ' + match return err + def run_shell_command(shell, cmd): command = [shell, '-csv', '--batch', '-init', '/dev/null'] @@ -28,6 +30,7 @@ def run_shell_command(shell, cmd): stderr = res.stderr.decode('utf8').strip() return (stdout, stderr, res.returncode) + def get_reduced_sql(shell, sql_query): reduce_query = get_reduced_query.replace('${QUERY}', sql_query.replace("'", "''")) (stdout, stderr, returncode) = run_shell_command(shell, reduce_query) @@ -40,6 +43,7 @@ def get_reduced_sql(shell, sql_query): reduce_candidates.append(line.strip('"').replace('""', '"')) return reduce_candidates[1:] + def reduce(sql_query, data_load, shell, error_msg, max_time_seconds=300): start = time.time() while True: @@ -66,18 +70,22 @@ def reduce(sql_query, data_load, shell, error_msg, max_time_seconds=300): break return sql_query + def is_ddl_query(query): query = query.lower() if 'create' in query or 'insert' in query or 'update' in query or 'delete' in query: return True return False + def initial_cleanup(query_log): query_log = query_log.replace('SELECT * FROM pragma_version()\n', '') return query_log + def run_queries_until_crash_mp(queries, result_file): import duckdb + con = duckdb.connect() sqlite_con = sqlite3.connect(result_file) sqlite_con.execute('CREATE TABLE queries(id INT, text VARCHAR)') @@ -102,7 +110,7 @@ def run_queries_until_crash_mp(queries, result_file): keep_query = True sqlite_con.execute('UPDATE result SET text=?', (exception_error,)) if not keep_query: - sqlite_con.execute('DELETE FROM queries WHERE id=?', (id, )) + sqlite_con.execute('DELETE FROM queries WHERE id=?', (id,)) if is_internal_error: # found internal error: no need to try further queries break @@ -113,6 +121,7 @@ def run_queries_until_crash_mp(queries, result_file): sqlite_con.commit() sqlite_con.close() + def run_queries_until_crash(queries): sqlite_file = 'cleaned_queries.db' if os.path.isfile(sqlite_file): @@ -140,8 +149,10 @@ def cleanup_irrelevant_queries(query_log): queries = [x for x in query_log.split(';\n') if len(x) > 0] return run_queries_until_crash(queries) + # def reduce_internal(start, sql_query, data_load, queries_final, shell, error_msg, max_time_seconds=300): + def reduce_query_log_query(start, shell, queries, query_index, max_time_seconds): new_query_list = queries[:] sql_query = queries[query_index] @@ -173,6 +184,7 @@ def reduce_query_log_query(start, shell, queries, query_index, max_time_seconds) break return sql_query + def reduce_query_log(queries, shell, max_time_seconds=300): start = time.time() current_index = 0 @@ -183,7 +195,7 @@ def reduce_query_log(queries, shell, max_time_seconds=300): if current_time - start > max_time_seconds: break # 
remove the query at "current_index" - new_queries = queries[:current_index] + queries[current_index + 1:] + new_queries = queries[:current_index] + queries[current_index + 1 :] # try to run the queries and check if we still get the same error (new_queries_x, current_error) = run_queries_until_crash(new_queries) if current_error is None: @@ -203,7 +215,6 @@ def reduce_query_log(queries, shell, max_time_seconds=300): return queries - # Example usage: # error_msg = 'INTERNAL Error: Assertion triggered in file "/Users/myth/Programs/duckdb-bugfix/src/common/types/data_chunk.cpp" on line 41: !types.empty()' # shell = 'build/debug/duckdb' @@ -260,4 +271,4 @@ def reduce_query_log(queries, shell, max_time_seconds=300): # limit 88 # ''' # -# print(reduce(sql_query, data_load, shell, error_msg)) \ No newline at end of file +# print(reduce(sql_query, data_load, shell, error_msg)) diff --git a/scripts/regression_check.py b/scripts/regression_check.py index 1fa6ae24c47a..0f319cea8bd2 100644 --- a/scripts/regression_check.py +++ b/scripts/regression_check.py @@ -25,14 +25,18 @@ exit(1) con = duckdb.connect() -old_timings_l = con.execute(f"SELECT name, median(time) FROM read_csv_auto('{old_file}') t(name, nrun, time) GROUP BY ALL ORDER BY ALL").fetchall() -new_timings_l = con.execute(f"SELECT name, median(time) FROM read_csv_auto('{new_file}') t(name, nrun, time) GROUP BY ALL ORDER BY ALL").fetchall() +old_timings_l = con.execute( + f"SELECT name, median(time) FROM read_csv_auto('{old_file}') t(name, nrun, time) GROUP BY ALL ORDER BY ALL" +).fetchall() +new_timings_l = con.execute( + f"SELECT name, median(time) FROM read_csv_auto('{new_file}') t(name, nrun, time) GROUP BY ALL ORDER BY ALL" +).fetchall() old_timings = {} new_timings = {} for entry in old_timings_l: - name = entry[0] + name = entry[0] timing = entry[1] old_timings[name] = timing @@ -55,10 +59,12 @@ return_code = 0 if len(slow_keys) > 0: - print('''==================================================== + print( + '''==================================================== ============== REGRESSIONS DETECTED ============= ==================================================== -''') +''' + ) return_code = 1 for key in slow_keys: new_timing = new_timings[key] @@ -68,28 +74,36 @@ print(f"New timing: {new_timing}") print("") - print('''==================================================== + print( + '''==================================================== ================== New Timings ================== ==================================================== -''') +''' + ) with open(new_file, 'r') as f: print(f.read()) - print('''==================================================== + print( + '''==================================================== ================== Old Timings ================== ==================================================== -''') +''' + ) with open(old_file, 'r') as f: print(f.read()) else: - print('''==================================================== + print( + '''==================================================== ============== NO REGRESSIONS DETECTED ============= ==================================================== -''') +''' + ) -print('''==================================================== +print( + '''==================================================== =================== ALL TIMINGS =================== ==================================================== -''') +''' +) for key in test_keys: new_timing = new_timings[key] old_timing = old_timings[key] @@ -98,4 +112,4 @@ print(f"New timing: {new_timing}") print("") 
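For readers following the regression_check.py hunks above: the script loads per-benchmark timings from two CSV files with columns (name, nrun, time), takes the median per benchmark via DuckDB, and reports slowdowns. The computation of slow_keys itself falls between the hunks shown; a minimal dependency-free sketch of that kind of check, with an assumed 10% threshold that is not taken from the script, could look like:

import csv
from statistics import median

THRESHOLD = 0.1  # assumed value; flag anything more than 10% slower

def median_timings(path):
    # CSV layout matches the script's read_csv_auto call: name, nrun, time
    timings = {}
    with open(path, newline='') as f:
        for name, _nrun, time in csv.reader(f):
            timings.setdefault(name, []).append(float(time))
    return {name: median(times) for name, times in timings.items()}

old_timings = median_timings('old.csv')
new_timings = median_timings('new.csv')
slow_keys = [k for k in new_timings if k in old_timings and new_timings[k] > old_timings[k] * (1 + THRESHOLD)]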
-exit(return_code) \ No newline at end of file +exit(return_code) diff --git a/scripts/regression_test_python.py b/scripts/regression_test_python.py index e5c68f08fdf1..3dc6e9f5723a 100644 --- a/scripts/regression_test_python.py +++ b/scripts/regression_test_python.py @@ -32,16 +32,17 @@ main_con.execute('CALL dbgen(sf=1)') tables = [ - "customer", - "lineitem", - "nation", - "orders", - "part", - "partsupp", - "region", - "supplier", + "customer", + "lineitem", + "nation", + "orders", + "part", + "partsupp", + "region", + "supplier", ] + def open_connection(): con = duckdb.connect() if threads is not None: @@ -60,6 +61,7 @@ def write_result(benchmark_name, nrun, t): else: print(bench_result) + def benchmark_queries(benchmark_name, con, queries): if verbose: print(benchmark_name) @@ -95,6 +97,7 @@ def benchmark_queries(benchmark_name, con, queries): print(f"T{padding}: {t}s") write_result(benchmark_name, nrun, t) + def run_dataload(con, type): benchmark_name = type + "_load_lineitem" if verbose: @@ -117,6 +120,7 @@ def run_dataload(con, type): print(f"T{padding}: {t}s") write_result(benchmark_name, nrun, t) + def run_tpch(con, prefix): benchmark_name = f"{prefix}tpch" queries = [] diff --git a/scripts/regression_test_runner.py b/scripts/regression_test_runner.py index 95fa2dd08fc4..5b86f3f24ed5 100644 --- a/scripts/regression_test_runner.py +++ b/scripts/regression_test_runner.py @@ -30,7 +30,9 @@ threads = int(arg.replace("--threads=", "")) if old_runner is None or new_runner is None or benchmark_file is None: - print("Expected usage: python3 scripts/regression_test_runner.py --old=/old/benchmark_runner --new=/new/benchmark_runner --benchmarks=/benchmark/list.csv") + print( + "Expected usage: python3 scripts/regression_test_runner.py --old=/old/benchmark_runner --new=/new/benchmark_runner --benchmarks=/benchmark/list.csv" + ) exit(1) if not os.path.isfile(old_runner): @@ -41,6 +43,7 @@ print(f"Failed to find new runner {new_runner}") exit(1) + def run_benchmark(runner, benchmark): benchmark_args = [runner, benchmark] if threads is not None: @@ -51,15 +54,19 @@ def run_benchmark(runner, benchmark): proc.wait() if proc.returncode != 0: print("Failed to run benchmark " + benchmark) - print('''==================================================== + print( + '''==================================================== ============== STDERR ============= ==================================================== -''') +''' + ) print(err) - print('''==================================================== + print( + '''==================================================== ============== STDOUT ============= ==================================================== -''') +''' + ) print(out) return 'Failed to run benchmark ' + benchmark if verbose: @@ -83,12 +90,14 @@ def run_benchmark(runner, benchmark): print(err) return 'Failed to run benchmark ' + benchmark + def run_benchmarks(runner, benchmark_list): results = {} for benchmark in benchmark_list: results[benchmark] = run_benchmark(runner, benchmark) return results + # read the initial benchmark list with open(benchmark_file, 'r') as f: benchmark_list = [x.strip() for x in f.read().split('\n') if len(x) > 0] @@ -100,11 +109,13 @@ def run_benchmarks(runner, benchmark_list): regression_list = [] if len(benchmark_list) == 0: break - print(f'''==================================================== + print( + f'''==================================================== ============== ITERATION {i} ============= ============== REMAINING {len(benchmark_list)} ============= 
==================================================== -''') +''' + ) old_results = run_benchmarks(old_runner, benchmark_list) new_results = run_benchmarks(new_runner, benchmark_list) @@ -125,24 +136,30 @@ def run_benchmarks(runner, benchmark_list): regression_list += error_list if len(regression_list) > 0: exit_code = 1 - print('''==================================================== + print( + '''==================================================== ============== REGRESSIONS DETECTED ============= ==================================================== -''') +''' + ) for regression in regression_list: print(f"{regression[0]}") print(f"Old timing: {regression[1]}") print(f"New timing: {regression[2]}") print("") - print('''==================================================== + print( + '''==================================================== ============== OTHER TIMINGS ============= ==================================================== -''') +''' + ) else: - print('''==================================================== + print( + '''==================================================== ============== NO REGRESSIONS DETECTED ============= ==================================================== -''') +''' + ) other_results.sort() for res in other_results: diff --git a/scripts/regression_test_storage_size.py b/scripts/regression_test_storage_size.py index 4a602e86e247..d4d665e4e0a9 100644 --- a/scripts/regression_test_storage_size.py +++ b/scripts/regression_test_storage_size.py @@ -7,10 +7,8 @@ regression_threshold_percentage = 0.05 parser = argparse.ArgumentParser(description='Generate TPC-DS reference results from Postgres.') -parser.add_argument('--old', dest='old_runner', - action='store', help='Path to the old shell executable') -parser.add_argument('--new', dest='new_runner', - action='store', help='Path to the new shell executable') +parser.add_argument('--old', dest='old_runner', action='store', help='Path to the old shell executable') +parser.add_argument('--new', dest='new_runner', action='store', help='Path to the new shell executable') args = parser.parse_args() @@ -26,6 +24,7 @@ print(f"Failed to find new runner {new_runner}") exit(1) + def load_data(shell_path, load_script): with tempfile.NamedTemporaryFile() as f: filename = f.name @@ -38,6 +37,7 @@ def load_data(shell_path, load_script): return None return os.path.getsize(filename) + def run_benchmark(load_script, benchmark_name): print('----------------------------') print(f'Running benchmark {benchmark_name}') @@ -61,14 +61,12 @@ def run_benchmark(load_script, benchmark_name): print('----------------------------') return True + tpch_load = 'CALL dbgen(sf=1);' tpcds_load = 'CALL dsdgen(sf=1);' -benchmarks = [ - [tpch_load, 'TPC-H SF1'], - [tpcds_load, 'TPC-DS SF1'] -] +benchmarks = [[tpch_load, 'TPC-H SF1'], [tpcds_load, 'TPC-DS SF1']] for benchmark in benchmarks: if not run_benchmark(benchmark[0], benchmark[1]): diff --git a/scripts/repeat_until_success.py b/scripts/repeat_until_success.py index abf49698677e..8a9da70c6797 100644 --- a/scripts/repeat_until_success.py +++ b/scripts/repeat_until_success.py @@ -3,18 +3,18 @@ import time if len(sys.argv) <= 1: - print("Expected usage: python3 repeat_until_success.py [command]") - exit(1) + print("Expected usage: python3 repeat_until_success.py [command]") + exit(1) ntries = 10 sleep_duration = 3 cmd = sys.argv[1] for i in range(ntries): - ret = os.system(cmd) - if ret is None or ret == 0: - exit(0) - print("Command {{ " + cmd + " }} failed, retrying (" + str(i + 1) + "/" + str(ntries) + 
")") - time.sleep(sleep_duration) + ret = os.system(cmd) + if ret is None or ret == 0: + exit(0) + print("Command {{ " + cmd + " }} failed, retrying (" + str(i + 1) + "/" + str(ntries) + ")") + time.sleep(sleep_duration) exit(1) diff --git a/scripts/run-clang-tidy.py b/scripts/run-clang-tidy.py index 35e6d714111a..4047af625565 100644 --- a/scripts/run-clang-tidy.py +++ b/scripts/run-clang-tidy.py @@ -1,12 +1,12 @@ #!/usr/bin/env python # -#===- run-clang-tidy.py - Parallel clang-tidy runner ---------*- python -*--===# +# ===- run-clang-tidy.py - Parallel clang-tidy runner ---------*- python -*--===# # # Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # -#===------------------------------------------------------------------------===# +# ===------------------------------------------------------------------------===# # FIXME: Integrate with clang-tidy-diff.py """ @@ -49,9 +49,9 @@ import traceback try: - import yaml + import yaml except ImportError: - yaml = None + yaml = None is_py2 = sys.version[0] == '2' @@ -60,267 +60,288 @@ else: import queue as queue + def find_compilation_database(path): - """Adjusts the directory until a compilation database is found.""" - result = './' - while not os.path.isfile(os.path.join(result, path)): - if os.path.realpath(result) == '/': - print('Error: could not find compilation database.') - sys.exit(1) - result += '../' - return os.path.realpath(result) + """Adjusts the directory until a compilation database is found.""" + result = './' + while not os.path.isfile(os.path.join(result, path)): + if os.path.realpath(result) == '/': + print('Error: could not find compilation database.') + sys.exit(1) + result += '../' + return os.path.realpath(result) def make_absolute(f, directory): - if os.path.isabs(f): - return f - return os.path.normpath(os.path.join(directory, f)) - - -def get_tidy_invocation(f, clang_tidy_binary, checks, tmpdir, build_path, - header_filter, extra_arg, extra_arg_before, quiet, - config): - """Gets a command line for clang-tidy.""" - start = [clang_tidy_binary] - if header_filter is not None: - start.append('-header-filter=' + header_filter) - if checks: - start.append('-checks=' + checks) - if tmpdir is not None: - start.append('-export-fixes') - # Get a temporary file. We immediately close the handle so clang-tidy can - # overwrite it. - (handle, name) = tempfile.mkstemp(suffix='.yaml', dir=tmpdir) - os.close(handle) - start.append(name) - for arg in extra_arg: - start.append('-extra-arg=%s' % arg) - for arg in extra_arg_before: - start.append('-extra-arg-before=%s' % arg) - start.append('-p=' + build_path) - if quiet: - start.append('--quiet') - if config: - start.append('-config=' + config) - start.append(f) - return start + if os.path.isabs(f): + return f + return os.path.normpath(os.path.join(directory, f)) + + +def get_tidy_invocation( + f, clang_tidy_binary, checks, tmpdir, build_path, header_filter, extra_arg, extra_arg_before, quiet, config +): + """Gets a command line for clang-tidy.""" + start = [clang_tidy_binary] + if header_filter is not None: + start.append('-header-filter=' + header_filter) + if checks: + start.append('-checks=' + checks) + if tmpdir is not None: + start.append('-export-fixes') + # Get a temporary file. We immediately close the handle so clang-tidy can + # overwrite it. 
+ (handle, name) = tempfile.mkstemp(suffix='.yaml', dir=tmpdir) + os.close(handle) + start.append(name) + for arg in extra_arg: + start.append('-extra-arg=%s' % arg) + for arg in extra_arg_before: + start.append('-extra-arg-before=%s' % arg) + start.append('-p=' + build_path) + if quiet: + start.append('--quiet') + if config: + start.append('-config=' + config) + start.append(f) + return start def merge_replacement_files(tmpdir, mergefile): - """Merge all replacement files in a directory into a single file""" - # The fixes suggested by clang-tidy >= 4.0.0 are given under - # the top level key 'Diagnostics' in the output yaml files - mergekey="Diagnostics" - merged=[] - for replacefile in glob.iglob(os.path.join(tmpdir, '*.yaml')): - content = yaml.safe_load(open(replacefile, 'r')) - if not content: - continue # Skip empty files. - merged.extend(content.get(mergekey, [])) - - if merged: - # MainSourceFile: The key is required by the definition inside - # include/clang/Tooling/ReplacementsYaml.h, but the value - # is actually never used inside clang-apply-replacements, - # so we set it to '' here. - output = { 'MainSourceFile': '', mergekey: merged } - with open(mergefile, 'w') as out: - yaml.safe_dump(output, out) - else: - # Empty the file: - open(mergefile, 'w').close() + """Merge all replacement files in a directory into a single file""" + # The fixes suggested by clang-tidy >= 4.0.0 are given under + # the top level key 'Diagnostics' in the output yaml files + mergekey = "Diagnostics" + merged = [] + for replacefile in glob.iglob(os.path.join(tmpdir, '*.yaml')): + content = yaml.safe_load(open(replacefile, 'r')) + if not content: + continue # Skip empty files. + merged.extend(content.get(mergekey, [])) + + if merged: + # MainSourceFile: The key is required by the definition inside + # include/clang/Tooling/ReplacementsYaml.h, but the value + # is actually never used inside clang-apply-replacements, + # so we set it to '' here. + output = {'MainSourceFile': '', mergekey: merged} + with open(mergefile, 'w') as out: + yaml.safe_dump(output, out) + else: + # Empty the file: + open(mergefile, 'w').close() def check_clang_apply_replacements_binary(args): - """Checks if invoking supplied clang-apply-replacements binary works.""" - try: - subprocess.check_call([args.clang_apply_replacements_binary, '--version']) - except: - print('Unable to run clang-apply-replacements. Is clang-apply-replacements ' - 'binary correctly specified?', file=sys.stderr) - traceback.print_exc() - sys.exit(1) + """Checks if invoking supplied clang-apply-replacements binary works.""" + try: + subprocess.check_call([args.clang_apply_replacements_binary, '--version']) + except: + print( + 'Unable to run clang-apply-replacements. 
Is clang-apply-replacements ' 'binary correctly specified?', + file=sys.stderr, + ) + traceback.print_exc() + sys.exit(1) def apply_fixes(args, tmpdir): - """Calls clang-apply-fixes on a given directory.""" - invocation = [args.clang_apply_replacements_binary] - if args.format: - invocation.append('-format') - if args.style: - invocation.append('-style=' + args.style) - invocation.append(tmpdir) - subprocess.call(invocation) + """Calls clang-apply-fixes on a given directory.""" + invocation = [args.clang_apply_replacements_binary] + if args.format: + invocation.append('-format') + if args.style: + invocation.append('-style=' + args.style) + invocation.append(tmpdir) + subprocess.call(invocation) def run_tidy(args, tmpdir, build_path, queue, lock, failed_files): - """Takes filenames out of queue and runs clang-tidy on them.""" - while True: - name = queue.get() - invocation = get_tidy_invocation(name, args.clang_tidy_binary, args.checks, - tmpdir, build_path, args.header_filter, - args.extra_arg, args.extra_arg_before, - args.quiet, args.config) - - proc = subprocess.Popen(invocation, stdout=subprocess.PIPE, stderr=subprocess.PIPE) - output, err = proc.communicate() - if proc.returncode != 0: - failed_files.append(name) - with lock: - sys.stdout.write(' '.join(invocation) + '\n' + output.decode('utf-8')) - if len(err) > 0: - sys.stdout.flush() - sys.stderr.write(err.decode('utf-8')) - queue.task_done() + """Takes filenames out of queue and runs clang-tidy on them.""" + while True: + name = queue.get() + invocation = get_tidy_invocation( + name, + args.clang_tidy_binary, + args.checks, + tmpdir, + build_path, + args.header_filter, + args.extra_arg, + args.extra_arg_before, + args.quiet, + args.config, + ) + + proc = subprocess.Popen(invocation, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + output, err = proc.communicate() + if proc.returncode != 0: + failed_files.append(name) + with lock: + sys.stdout.write(' '.join(invocation) + '\n' + output.decode('utf-8')) + if len(err) > 0: + sys.stdout.flush() + sys.stderr.write(err.decode('utf-8')) + queue.task_done() def main(): - parser = argparse.ArgumentParser(description='Runs clang-tidy over all files ' - 'in a compilation database. Requires ' - 'clang-tidy and clang-apply-replacements in ' - '$PATH.') - parser.add_argument('-clang-tidy-binary', metavar='PATH', - default='clang-tidy', - help='path to clang-tidy binary') - parser.add_argument('-clang-apply-replacements-binary', metavar='PATH', - default='clang-apply-replacements', - help='path to clang-apply-replacements binary') - parser.add_argument('-checks', default=None, - help='checks filter, when not specified, use clang-tidy ' - 'default') - parser.add_argument('-config', default=None, - help='Specifies a configuration in YAML/JSON format: ' - ' -config="{Checks: \'*\', ' - ' CheckOptions: [{key: x, ' - ' value: y}]}" ' - 'When the value is empty, clang-tidy will ' - 'attempt to find a file named .clang-tidy for ' - 'each source file in its parent directories.') - parser.add_argument('-header-filter', default=None, - help='regular expression matching the names of the ' - 'headers to output diagnostics from. 
Diagnostics from ' - 'the main file of each translation unit are always ' - 'displayed.') - if yaml: - parser.add_argument('-export-fixes', metavar='filename', dest='export_fixes', - help='Create a yaml file to store suggested fixes in, ' - 'which can be applied with clang-apply-replacements.') - parser.add_argument('-j', type=int, default=0, - help='number of tidy instances to be run in parallel.') - parser.add_argument('files', nargs='*', default=['.*'], - help='files to be processed (regex on path)') - parser.add_argument('-fix', action='store_true', help='apply fix-its') - parser.add_argument('-format', action='store_true', help='Reformat code ' - 'after applying fixes') - parser.add_argument('-style', default='file', help='The style of reformat ' - 'code after applying fixes') - parser.add_argument('-p', dest='build_path', - help='Path used to read a compile command database.') - parser.add_argument('-extra-arg', dest='extra_arg', - action='append', default=[], - help='Additional argument to append to the compiler ' - 'command line.') - parser.add_argument('-extra-arg-before', dest='extra_arg_before', - action='append', default=[], - help='Additional argument to prepend to the compiler ' - 'command line.') - parser.add_argument('-quiet', action='store_true', - help='Run clang-tidy in quiet mode') - args = parser.parse_args() - - db_path = 'compile_commands.json' - - if args.build_path is not None: - build_path = args.build_path - else: - # Find our database - build_path = find_compilation_database(db_path) - - try: - invocation = [args.clang_tidy_binary, '-list-checks'] - invocation.append('-p=' + build_path) - if args.checks: - invocation.append('-checks=' + args.checks) - invocation.append('-') - if args.quiet: - # Even with -quiet we still want to check if we can call clang-tidy. - with open(os.devnull, 'w') as dev_null: - subprocess.check_call(invocation, stdout=dev_null) + parser = argparse.ArgumentParser( + description='Runs clang-tidy over all files ' + 'in a compilation database. Requires ' + 'clang-tidy and clang-apply-replacements in ' + '$PATH.' + ) + parser.add_argument('-clang-tidy-binary', metavar='PATH', default='clang-tidy', help='path to clang-tidy binary') + parser.add_argument( + '-clang-apply-replacements-binary', + metavar='PATH', + default='clang-apply-replacements', + help='path to clang-apply-replacements binary', + ) + parser.add_argument('-checks', default=None, help='checks filter, when not specified, use clang-tidy ' 'default') + parser.add_argument( + '-config', + default=None, + help='Specifies a configuration in YAML/JSON format: ' + ' -config="{Checks: \'*\', ' + ' CheckOptions: [{key: x, ' + ' value: y}]}" ' + 'When the value is empty, clang-tidy will ' + 'attempt to find a file named .clang-tidy for ' + 'each source file in its parent directories.', + ) + parser.add_argument( + '-header-filter', + default=None, + help='regular expression matching the names of the ' + 'headers to output diagnostics from. 
Diagnostics from ' + 'the main file of each translation unit are always ' + 'displayed.', + ) + if yaml: + parser.add_argument( + '-export-fixes', + metavar='filename', + dest='export_fixes', + help='Create a yaml file to store suggested fixes in, ' + 'which can be applied with clang-apply-replacements.', + ) + parser.add_argument('-j', type=int, default=0, help='number of tidy instances to be run in parallel.') + parser.add_argument('files', nargs='*', default=['.*'], help='files to be processed (regex on path)') + parser.add_argument('-fix', action='store_true', help='apply fix-its') + parser.add_argument('-format', action='store_true', help='Reformat code ' 'after applying fixes') + parser.add_argument('-style', default='file', help='The style of reformat ' 'code after applying fixes') + parser.add_argument('-p', dest='build_path', help='Path used to read a compile command database.') + parser.add_argument( + '-extra-arg', + dest='extra_arg', + action='append', + default=[], + help='Additional argument to append to the compiler ' 'command line.', + ) + parser.add_argument( + '-extra-arg-before', + dest='extra_arg_before', + action='append', + default=[], + help='Additional argument to prepend to the compiler ' 'command line.', + ) + parser.add_argument('-quiet', action='store_true', help='Run clang-tidy in quiet mode') + args = parser.parse_args() + + db_path = 'compile_commands.json' + + if args.build_path is not None: + build_path = args.build_path else: - subprocess.check_call(invocation) - except: - print("Unable to run clang-tidy.", file=sys.stderr) - sys.exit(1) - - # Load the database and extract all files. - database = json.load(open(os.path.join(build_path, db_path))) - files = [make_absolute(entry['file'], entry['directory']) - for entry in database] - - max_task = args.j - if max_task == 0: - max_task = multiprocessing.cpu_count() - - tmpdir = None - if args.fix or (yaml and args.export_fixes): - check_clang_apply_replacements_binary(args) - tmpdir = tempfile.mkdtemp() - - # Build up a big regexy filter from all command line arguments. - file_name_re = re.compile('|'.join(args.files)) - - return_code = 0 - try: - # Spin up a bunch of tidy-launching threads. - task_queue = queue.Queue(max_task) - # List of files with a non-zero return code. - failed_files = [] - lock = threading.Lock() - for _ in range(max_task): - t = threading.Thread(target=run_tidy, - args=(args, tmpdir, build_path, task_queue, lock, failed_files)) - t.daemon = True - t.start() - - # Fill the queue with files. - for name in files: - if file_name_re.search(name): - task_queue.put(name) - - # Wait for all threads to be done. - task_queue.join() - if len(failed_files): - return_code = 1 - - except KeyboardInterrupt: - # This is a sad hack. Unfortunately subprocess goes - # bonkers with ctrl-c and we start forking merrily. - print('\nCtrl-C detected, goodbye.') - if tmpdir: - shutil.rmtree(tmpdir) - os.kill(0, 9) + # Find our database + build_path = find_compilation_database(db_path) - if yaml and args.export_fixes: - print('Writing fixes to ' + args.export_fixes + ' ...') try: - merge_replacement_files(tmpdir, args.export_fixes) + invocation = [args.clang_tidy_binary, '-list-checks'] + invocation.append('-p=' + build_path) + if args.checks: + invocation.append('-checks=' + args.checks) + invocation.append('-') + if args.quiet: + # Even with -quiet we still want to check if we can call clang-tidy. 
+ with open(os.devnull, 'w') as dev_null: + subprocess.check_call(invocation, stdout=dev_null) + else: + subprocess.check_call(invocation) except: - print('Error exporting fixes.\n', file=sys.stderr) - traceback.print_exc() - return_code=1 + print("Unable to run clang-tidy.", file=sys.stderr) + sys.exit(1) + + # Load the database and extract all files. + database = json.load(open(os.path.join(build_path, db_path))) + files = [make_absolute(entry['file'], entry['directory']) for entry in database] - if args.fix: - print('Applying fixes ...') + max_task = args.j + if max_task == 0: + max_task = multiprocessing.cpu_count() + + tmpdir = None + if args.fix or (yaml and args.export_fixes): + check_clang_apply_replacements_binary(args) + tmpdir = tempfile.mkdtemp() + + # Build up a big regexy filter from all command line arguments. + file_name_re = re.compile('|'.join(args.files)) + + return_code = 0 try: - apply_fixes(args, tmpdir) - except: - print('Error applying fixes.\n', file=sys.stderr) - traceback.print_exc() - return_code=1 + # Spin up a bunch of tidy-launching threads. + task_queue = queue.Queue(max_task) + # List of files with a non-zero return code. + failed_files = [] + lock = threading.Lock() + for _ in range(max_task): + t = threading.Thread(target=run_tidy, args=(args, tmpdir, build_path, task_queue, lock, failed_files)) + t.daemon = True + t.start() + + # Fill the queue with files. + for name in files: + if file_name_re.search(name): + task_queue.put(name) + + # Wait for all threads to be done. + task_queue.join() + if len(failed_files): + return_code = 1 + + except KeyboardInterrupt: + # This is a sad hack. Unfortunately subprocess goes + # bonkers with ctrl-c and we start forking merrily. + print('\nCtrl-C detected, goodbye.') + if tmpdir: + shutil.rmtree(tmpdir) + os.kill(0, 9) + + if yaml and args.export_fixes: + print('Writing fixes to ' + args.export_fixes + ' ...') + try: + merge_replacement_files(tmpdir, args.export_fixes) + except: + print('Error exporting fixes.\n', file=sys.stderr) + traceback.print_exc() + return_code = 1 + + if args.fix: + print('Applying fixes ...') + try: + apply_fixes(args, tmpdir) + except: + print('Error applying fixes.\n', file=sys.stderr) + traceback.print_exc() + return_code = 1 + + if tmpdir: + shutil.rmtree(tmpdir) + sys.exit(return_code) - if tmpdir: - shutil.rmtree(tmpdir) - sys.exit(return_code) if __name__ == '__main__': - main() \ No newline at end of file + main() diff --git a/scripts/run_fuzzer.py b/scripts/run_fuzzer.py index 1a7130a6b045..86bf65537495 100644 --- a/scripts/run_fuzzer.py +++ b/scripts/run_fuzzer.py @@ -43,6 +43,7 @@ git_hash = fuzzer_helper.get_github_hash() + def create_db_script(db): if db == 'alltypes': return 'create table all_types as select * exclude(small_enum, medium_enum, large_enum) from test_all_types();' @@ -51,6 +52,7 @@ def create_db_script(db): else: raise Exception("Unknown database creation script") + def run_fuzzer_script(fuzzer): if fuzzer == 'sqlsmith': return "call sqlsmith(max_queries=${MAX_QUERIES}, seed=${SEED}, verbose_output=1, log='${LAST_LOG_FILE}', complete_log='${COMPLETE_LOG_FILE}');" @@ -61,6 +63,7 @@ def run_fuzzer_script(fuzzer): else: raise Exception("Unknown fuzzer type") + def get_fuzzer_name(fuzzer): if fuzzer == 'sqlsmith': return 'SQLSmith' @@ -71,6 +74,7 @@ def get_fuzzer_name(fuzzer): else: return 'Unknown' + def run_shell_command(cmd): command = [shell, '--batch', '-init', '/dev/null'] @@ -87,13 +91,21 @@ def run_shell_command(cmd): last_query_log_file = 'sqlsmith.log' 
 complete_log_file = 'sqlsmith.complete.log'
 
-print(f'''==========================================
+print(
+    f'''==========================================
 RUNNING {fuzzer} on {db}
-==========================================''')
+=========================================='''
+)
 
 load_script = create_db_script(db)
 fuzzer_name = get_fuzzer_name(fuzzer)
-fuzzer = run_fuzzer_script(fuzzer).replace('${MAX_QUERIES}', str(max_queries)).replace('${LAST_LOG_FILE}', last_query_log_file).replace('${COMPLETE_LOG_FILE}', complete_log_file).replace('${SEED}', str(seed))
+fuzzer = (
+    run_fuzzer_script(fuzzer)
+    .replace('${MAX_QUERIES}', str(max_queries))
+    .replace('${LAST_LOG_FILE}', last_query_log_file)
+    .replace('${COMPLETE_LOG_FILE}', complete_log_file)
+    .replace('${SEED}', str(seed))
+)
 
 print(load_script)
 print(fuzzer)
@@ -104,9 +116,11 @@ def run_shell_command(cmd):
 
 (stdout, stderr, returncode) = run_shell_command(cmd)
 
-print(f'''==========================================
+print(
+    f'''==========================================
 FINISHED RUNNING
-==========================================''')
+=========================================='''
+)
 print("============== STDOUT ================")
 print(stdout)
 print("============== STDERR =================")
@@ -151,7 +165,10 @@ def run_shell_command(cmd):
     # check if this is a duplicate issue
     if error_msg in current_errors:
         print("Skip filing duplicate issue")
-        print("Issue already exists: https://github.com/duckdb/duckdb-fuzzer/issues/" + str(current_errors[error_msg]['number']))
+        print(
+            "Issue already exists: https://github.com/duckdb/duckdb-fuzzer/issues/"
+            + str(current_errors[error_msg]['number'])
+        )
         exit(0)
 
     print(last_query)
diff --git a/scripts/run_sqlancer.py b/scripts/run_sqlancer.py
index 8ee27b456898..61d4a13052d9 100644
--- a/scripts/run_sqlancer.py
+++ b/scripts/run_sqlancer.py
@@ -39,7 +39,7 @@
     exit(1)
 
 if seed is None:
-    seed = random.randint(0, 2 ** 30)
+    seed = random.randint(0, 2**30)
 
 git_hash = fuzzer_helper.get_github_hash()
 
@@ -141,7 +141,10 @@
     # check if this is a duplicate issue
     if error_msg in current_errors:
         print("Skip filing duplicate issue")
-        print("Issue already exists: https://github.com/duckdb/duckdb-fuzzer/issues/" + str(current_errors[error_msg]['number']))
+        print(
+            "Issue already exists: https://github.com/duckdb/duckdb-fuzzer/issues/"
+            + str(current_errors[error_msg]['number'])
+        )
         exit(0)
 
-fuzzer_helper.file_issue(reduced_test_case, error_msg, "SQLancer", seed, git_hash)
\ No newline at end of file
+fuzzer_helper.file_issue(reduced_test_case, error_msg, "SQLancer", seed, git_hash)
diff --git a/scripts/run_test_list.py b/scripts/run_test_list.py
index 2810bedb72eb..3ba34263e254 100644
--- a/scripts/run_test_list.py
+++ b/scripts/run_test_list.py
@@ -6,55 +6,60 @@
 # wheth
 no_exit = False
 for i in range(len(sys.argv)):
-	if sys.argv[i] == '--no-exit':
-		no_exit = True
-		del sys.argv[i]
-		i-=1
+    if sys.argv[i] == '--no-exit':
+        no_exit = True
+        del sys.argv[i]
+        i -= 1
 
 if len(sys.argv) < 2:
-	print("Expected usage: python3 scripts/run_test_list.py build/debug/test/unittest [--no-exit]")
-	exit(1)
+    print("Expected usage: python3 scripts/run_test_list.py build/debug/test/unittest [--no-exit]")
+    exit(1)
 
 unittest_program = sys.argv[1]
 extra_args = []
 if len(sys.argv) > 2:
-	extra_args = [sys.argv[2]]
+    extra_args = [sys.argv[2]]
 
 test_cases = []
 for line in sys.stdin:
-	if len(line.strip()) == 0:
-		continue
-	splits = line.rsplit('\t', 1)
-	test_cases.append(splits[0])
+    if len(line.strip()) == 0:
+        continue
+    splits = line.rsplit('\t', 1)
+    test_cases.append(splits[0])
 
 test_count = len(test_cases)
 
 return_code = 0
 
 for test_number in range(test_count):
-	sys.stdout.write("[" + str(test_number) + "/" + str(test_count) + "]: " + test_cases[test_number])
-	sys.stdout.flush()
-	res = subprocess.run([unittest_program, test_cases[test_number]], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
-	stdout = res.stdout.decode('utf8')
-	stderr = res.stderr.decode('utf8')
-	if res.returncode is not None and res.returncode != 0:
-		print("FAILURE IN RUNNING TEST")
-		print("""--------------------
-RETURNCODE
---------------------
-""")
-		print(res.returncode)
-		print("""--------------------
-STDOUT
---------------------
-""")
-		print(stdout)
-		print("""--------------------
-STDERR
---------------------
-""")
-		print(stderr)
-		return_code = 1
-		if not no_exit:
-			break
+    sys.stdout.write("[" + str(test_number) + "/" + str(test_count) + "]: " + test_cases[test_number])
+    sys.stdout.flush()
+    res = subprocess.run([unittest_program, test_cases[test_number]], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+    stdout = res.stdout.decode('utf8')
+    stderr = res.stderr.decode('utf8')
+    if res.returncode is not None and res.returncode != 0:
+        print("FAILURE IN RUNNING TEST")
+        print(
+            """--------------------
+RETURNCODE
+--------------------
+"""
+        )
+        print(res.returncode)
+        print(
+            """--------------------
+STDOUT
+--------------------
+"""
+        )
+        print(stdout)
+        print(
+            """--------------------
+STDERR
+--------------------
+"""
+        )
+        print(stderr)
+        return_code = 1
+        if not no_exit:
+            break
 
 exit(return_code)
-
diff --git a/scripts/run_tests_one_by_one.py b/scripts/run_tests_one_by_one.py
index bc57be284406..d7e2e654aa19 100644
--- a/scripts/run_tests_one_by_one.py
+++ b/scripts/run_tests_one_by_one.py
@@ -7,78 +7,83 @@
 no_exit = False
 profile = False
 for i in range(len(sys.argv)):
-	if sys.argv[i] == '--no-exit':
-		no_exit = True
-		del sys.argv[i]
-		i-=1
-	elif sys.argv[i] == '--profile':
-		profile = True
-		del sys.argv[i]
-		i-=1
+    if sys.argv[i] == '--no-exit':
+        no_exit = True
+        del sys.argv[i]
+        i -= 1
+    elif sys.argv[i] == '--profile':
+        profile = True
+        del sys.argv[i]
+        i -= 1
 
 if len(sys.argv) < 2:
-	print("Expected usage: python3 scripts/run_tests_one_by_one.py build/debug/test/unittest [--no-exit] [--profile]")
-	exit(1)
+    print("Expected usage: python3 scripts/run_tests_one_by_one.py build/debug/test/unittest [--no-exit] [--profile]")
+    exit(1)
 
 unittest_program = sys.argv[1]
 extra_args = []
 if len(sys.argv) > 2:
-	extra_args = [sys.argv[2]]
+    extra_args = [sys.argv[2]]
 
 proc = subprocess.Popen([unittest_program, '-l'] + extra_args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
 stdout = proc.stdout.read().decode('utf8')
 stderr = proc.stderr.read().decode('utf8')
 if proc.returncode is not None and proc.returncode != 0:
-	print("Failed to run program " + unittest_program)
-	print(proc.returncode)
-	print(stdout)
-	print(stderr)
-	exit(1)
+    print("Failed to run program " + unittest_program)
+    print(proc.returncode)
+    print(stdout)
+    print(stderr)
+    exit(1)
 
 test_cases = []
 first_line = True
 for line in stdout.splitlines():
-	if first_line:
-		first_line = False
-		continue
-	if len(line.strip()) == 0:
-		continue
-	splits = line.rsplit('\t', 1)
-	test_cases.append(splits[0])
+    if first_line:
+        first_line = False
+        continue
+    if len(line.strip()) == 0:
+        continue
+    splits = line.rsplit('\t', 1)
+    test_cases.append(splits[0])
 
 test_count = len(test_cases)
 
 return_code = 0
 
 for test_number in range(test_count):
-	if not profile:
-		print("[" + str(test_number) + "/" + str(test_count) + "]: " + test_cases[test_number])
-	start = time.time()
-	res = subprocess.run([unittest_program, test_cases[test_number]], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
-	stdout = res.stdout.decode('utf8')
-	stderr = res.stderr.decode('utf8')
-	end = time.time()
-	if profile:
-		print(f'{test_cases[test_number]} {end - start}')
-	if res.returncode is not None and res.returncode != 0:
-		print("FAILURE IN RUNNING TEST")
-		print("""--------------------
-RETURNCODE
---------------------
-""")
-		print(res.returncode)
-		print("""--------------------
-STDOUT
---------------------
-""")
-		print(stdout)
-		print("""--------------------
-STDERR
---------------------
-""")
-		print(stderr)
-		return_code = 1
-		if not no_exit:
-			break
+    if not profile:
+        print("[" + str(test_number) + "/" + str(test_count) + "]: " + test_cases[test_number])
+    start = time.time()
+    res = subprocess.run([unittest_program, test_cases[test_number]], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+    stdout = res.stdout.decode('utf8')
+    stderr = res.stderr.decode('utf8')
+    end = time.time()
+    if profile:
+        print(f'{test_cases[test_number]} {end - start}')
+    if res.returncode is not None and res.returncode != 0:
+        print("FAILURE IN RUNNING TEST")
+        print(
+            """--------------------
+RETURNCODE
+--------------------
+"""
+        )
+        print(res.returncode)
+        print(
+            """--------------------
+STDOUT
+--------------------
+"""
+        )
+        print(stdout)
+        print(
+            """--------------------
+STDERR
+--------------------
+"""
+        )
+        print(stderr)
+        return_code = 1
+        if not no_exit:
+            break
 
 exit(return_code)
-
diff --git a/scripts/runsqlsmith.py b/scripts/runsqlsmith.py
index 161c7247d8f7..9569750349c7 100644
--- a/scripts/runsqlsmith.py
+++ b/scripts/runsqlsmith.py
@@ -1,4 +1,3 @@
-
 # run SQL smith and collect breaking queries
 import os
 import re
@@ -16,38 +15,38 @@
 c = con.cursor()
 
 if len(sys.argv) == 2:
-	if sys.argv[1] == '--export':
-		export_queries = True
-	elif sys.argv[1] == '--reset':
-		c.execute('DROP TABLE IF EXISTS sqlsmith_errors')
-	else:
-		print('Unknown query option ' + sys.argv[1])
-		exit(1)
+    if sys.argv[1] == '--export':
+        export_queries = True
+    elif sys.argv[1] == '--reset':
+        c.execute('DROP TABLE IF EXISTS sqlsmith_errors')
+    else:
+        print('Unknown query option ' + sys.argv[1])
+        exit(1)
 
 if export_queries:
-	c.execute('SELECT query FROM sqlsmith_errors')
-	results = c.fetchall()
-	for fname in os.listdir(sqlsmith_test_dir):
-		os.remove(os.path.join(sqlsmith_test_dir, fname))
+    c.execute('SELECT query FROM sqlsmith_errors')
+    results = c.fetchall()
+    for fname in os.listdir(sqlsmith_test_dir):
+        os.remove(os.path.join(sqlsmith_test_dir, fname))
+
+    for i in range(len(results)):
+        with open(os.path.join(sqlsmith_test_dir, 'sqlsmith-%d.sql' % (i + 1)), 'w+') as f:
+            f.write(results[i][0] + "\n")
+    exit(0)
 
-	for i in range(len(results)):
-		with open(os.path.join(sqlsmith_test_dir, 'sqlsmith-%d.sql' % (i + 1)), 'w+') as f:
-			f.write(results[i][0] + "\n")
-	exit(0)
+
 def run_sqlsmith():
-	subprocess.call(['build/debug/third_party/sqlsmith/sqlsmith', '--duckdb=:memory:'])
+    subprocess.call(['build/debug/third_party/sqlsmith/sqlsmith', '--duckdb=:memory:'])
 
 c.execute('CREATE TABLE IF NOT EXISTS sqlsmith_errors(query VARCHAR)')
 while True:
-	# run SQL smith
-	run_sqlsmith()
-	# get the breaking query
-	with open_utf8('sqlsmith.log', 'r') as f:
-		text = re.sub('[ \t\n]+', ' ', f.read())
-
-	c.execute('INSERT INTO sqlsmith_errors VALUES (?)', (text,))
-	con.commit()
-
+    # run SQL smith
+    run_sqlsmith()
+    # get the breaking query
+    with open_utf8('sqlsmith.log', 'r') as f:
+        text = re.sub('[ \t\n]+', ' ', f.read())
+
+    c.execute('INSERT INTO sqlsmith_errors VALUES (?)', (text,))
+    con.commit()
diff --git a/scripts/test_compile.py b/scripts/test_compile.py
index 51827c6ef391..19a55dc0cf6a 100644
--- a/scripts/test_compile.py
+++ b/scripts/test_compile.py
@@ -16,62 +16,71 @@
 # by default, we resume if the previous test_compile was run on the same commit hash as this one
 resume = RESUME_AUTO
 for arg in sys.argv:
-	if arg == '--resume':
-		resume = RESUME_ALWAYS
-	elif arg == '--restart':
-		cache = RESUME_NEVER
+    if arg == '--resume':
+        resume = RESUME_ALWAYS
+    elif arg == '--restart':
+        cache = RESUME_NEVER
 
 if resume == RESUME_NEVER:
-	try:
-		os.remove(cache_file)
-	except:
-		pass
+    try:
+        os.remove(cache_file)
+    except:
+        pass
+
 
 def get_git_hash():
-	proc = subprocess.Popen(['git', 'rev-parse', 'HEAD'], stdout=subprocess.PIPE)
-	return proc.stdout.read().strip()
+    proc = subprocess.Popen(['git', 'rev-parse', 'HEAD'], stdout=subprocess.PIPE)
+    return proc.stdout.read().strip()
+
 
 current_hash = get_git_hash()
 
 # load the cache, and check the commit hash
 try:
-	with open(cache_file, 'rb') as cf:
-		cache = pickle.load(cf)
-	if resume == RESUME_AUTO:
-		# auto resume, check
-		if cache['commit_hash'] != current_hash:
-			cache = {}
+    with open(cache_file, 'rb') as cf:
+        cache = pickle.load(cf)
+    if resume == RESUME_AUTO:
+        # auto resume, check
+        if cache['commit_hash'] != current_hash:
+            cache = {}
 except:
-	cache = {}
+    cache = {}
 
 cache['commit_hash'] = current_hash
 
+
 def try_compilation(fpath, cache):
-	if fpath in cache:
-		return
-	print(fpath)
-
-	cmd = 'clang++ -std=c++11 -Wno-deprecated -Wno-writable-strings -S -MMD -MF dependencies.d -o deps.s ' + fpath + ' ' + ' '.join(["-I" + x for x in amalgamation.include_paths])
-	ret = os.system(cmd)
-	if ret != 0:
-		raise Exception('Failed compilation of file "' + fpath + '"!\n Command: ' + cmd)
-	cache[fpath] = True
-	with open(cache_file, 'wb') as cf:
-		pickle.dump(cache, cf)
+    if fpath in cache:
+        return
+    print(fpath)
+
+    cmd = (
+        'clang++ -std=c++11 -Wno-deprecated -Wno-writable-strings -S -MMD -MF dependencies.d -o deps.s '
+        + fpath
+        + ' '
+        + ' '.join(["-I" + x for x in amalgamation.include_paths])
+    )
+    ret = os.system(cmd)
+    if ret != 0:
+        raise Exception('Failed compilation of file "' + fpath + '"!\n Command: ' + cmd)
+    cache[fpath] = True
+    with open(cache_file, 'wb') as cf:
+        pickle.dump(cache, cf)
+
 
 def compile_dir(dir, cache):
-	files = os.listdir(dir)
-	files.sort()
-	for fname in files:
-		if fname in amalgamation.excluded_compilation_files or fname in ignored_files:
-			continue
-		fpath = os.path.join(dir, fname)
-		if os.path.isdir(fpath):
-			compile_dir(fpath, cache)
-		elif fname.endswith('.cpp') or fname.endswith('.hpp') or fname.endswith('.c') or fname.endswith('.cc'):
-			try_compilation(fpath, cache)
+    files = os.listdir(dir)
+    files.sort()
+    for fname in files:
+        if fname in amalgamation.excluded_compilation_files or fname in ignored_files:
+            continue
+        fpath = os.path.join(dir, fname)
+        if os.path.isdir(fpath):
+            compile_dir(fpath, cache)
+        elif fname.endswith('.cpp') or fname.endswith('.hpp') or fname.endswith('.c') or fname.endswith('.cc'):
+            try_compilation(fpath, cache)
+
 
 # compile all files in the src directory (including headers!) individually
 for cdir in amalgamation.compile_directories:
-	compile_dir(cdir, cache)
-
+    compile_dir(cdir, cache)
diff --git a/scripts/test_vector_sizes.py b/scripts/test_vector_sizes.py
index eb399e8ca3fe..11d244ce6c8f 100644
--- a/scripts/test_vector_sizes.py
+++ b/scripts/test_vector_sizes.py
@@ -6,23 +6,30 @@
 current_dir = os.getcwd()
 build_dir = os.path.join(os.getcwd(), 'build', 'release')
 
+
 def execute_system_command(cmd):
-	print(cmd)
-	retcode = os.system(cmd)
-	print(retcode)
-	if retcode != 0:
-		raise Exception
+    print(cmd)
+    retcode = os.system(cmd)
+    print(retcode)
+    if retcode != 0:
+        raise Exception
+
 
 def replace_in_file(fname, regex, replace):
-	with open_utf8(fname, 'r') as f:
-		contents = f.read()
-	contents = re.sub(regex, replace, contents)
-	with open_utf8(fname, 'w+') as f:
-		f.write(contents)
+    with open_utf8(fname, 'r') as f:
+        contents = f.read()
+    contents = re.sub(regex, replace, contents)
+    with open_utf8(fname, 'w+') as f:
+        f.write(contents)
+
 
 for vector_size in vector_sizes:
-	print("TESTING STANDARD_VECTOR_SIZE=%d" % (vector_size,))
-	replace_in_file('src/include/duckdb/common/vector_size.hpp', '#define STANDARD_VECTOR_SIZE \d+', '#define STANDARD_VECTOR_SIZE %d' % (vector_size,))
-	execute_system_command('rm -rf build')
-	execute_system_command('make relassert')
-	execute_system_command('python3 scripts/run_tests_one_by_one.py build/relassert/test/unittest')
+    print("TESTING STANDARD_VECTOR_SIZE=%d" % (vector_size,))
+    replace_in_file(
+        'src/include/duckdb/common/vector_size.hpp',
+        r'#define STANDARD_VECTOR_SIZE \d+',
+        '#define STANDARD_VECTOR_SIZE %d' % (vector_size,),
+    )
+    execute_system_command('rm -rf build')
+    execute_system_command('make relassert')
+    execute_system_command('python3 scripts/run_tests_one_by_one.py build/relassert/test/unittest')
diff --git a/scripts/test_zero_initialize.py b/scripts/test_zero_initialize.py
index 250c220b580b..55f8580277f0 100644
--- a/scripts/test_zero_initialize.py
+++ b/scripts/test_zero_initialize.py
@@ -3,11 +3,20 @@
 import subprocess
 import shutil
 
-parser = argparse.ArgumentParser(description='''Runs storage tests both with explicit one-initialization and with explicit zero-initialization, and verifies that the final storage files are the same.
-The purpose of this is to verify all memory is correctly initialized before writing to disk - which prevents leaking of in-memory data in storage files by writing uninitialized memory to disk.''')
+parser = argparse.ArgumentParser(
+    description='''Runs storage tests both with explicit one-initialization and with explicit zero-initialization, and verifies that the final storage files are the same.
+The purpose of this is to verify all memory is correctly initialized before writing to disk - which prevents leaking of in-memory data in storage files by writing uninitialized memory to disk.''' +) parser.add_argument('--unittest', default='build/debug/test/unittest', help='path to unittest', dest='unittest') -parser.add_argument('--zero_init_dir', default='test_zero_init_db', help='directory to write zero-initialized databases to', dest='zero_init_dir') -parser.add_argument('--standard_dir', default='test_standard_db', help='directory to write regular databases to', dest='standard_dir') +parser.add_argument( + '--zero_init_dir', + default='test_zero_init_db', + help='directory to write zero-initialized databases to', + dest='zero_init_dir', +) +parser.add_argument( + '--standard_dir', default='test_standard_db', help='directory to write regular databases to', dest='standard_dir' +) args = parser.parse_args() @@ -17,9 +26,10 @@ 'test/sql/storage/test_store_nulls_strings.test', 'test/sql/storage/test_store_null_updates.test', 'test/sql/storage/test_store_integers.test', - 'test/sql/storage/test_update_delete_string.test' + 'test/sql/storage/test_update_delete_string.test', ] + def run_test(args): res = subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE) stdout = res.stdout.decode('utf8').strip() @@ -35,10 +45,12 @@ def run_test(args): print("---------------------") exit(1) + header_size = 4096 * 3 block_size = 262144 checksum_size = 8 + def handle_error(i, standard_db, zero_init_db, standard_data, zero_data): print("------------------------------------------------------------------") print(f"FAIL - Mismatch between one-initialized and zero-initialized databases at byte position {i}") @@ -50,12 +62,16 @@ def handle_error(i, standard_db, zero_init_db, standard_data, zero_data): else: byte_pos = (i - header_size) % block_size if byte_pos >= checksum_size: - print(f"This byte is in block id {(i - header_size) // block_size} at byte position {byte_pos - checksum_size} (position {byte_pos} including the block checksum)") + print( + f"This byte is in block id {(i - header_size) // block_size} at byte position {byte_pos - checksum_size} (position {byte_pos} including the block checksum)" + ) else: print(f"This byte is in block id {(i - header_size) // block_size} at byte position {byte_pos}") print("This is in the checksum part of the block") print("------------------------------------------------------------------") - print("This error likely means that memory was not correctly zero-initialized in a block before being written out to disk.") + print( + "This error likely means that memory was not correctly zero-initialized in a block before being written out to disk." 
+ ) def compare_database(standard_db, zero_init_db): @@ -64,7 +80,9 @@ def compare_database(standard_db, zero_init_db): with open(zero_init_db, 'rb') as f: zero_data = f.read() if len(standard_data) != len(zero_data): - print(f"FAIL - Length mismatch between database {standard_db} ({str(len(standard_data))}) and {zero_init_db} ({str(len(zero_data))})") + print( + f"FAIL - Length mismatch between database {standard_db} ({str(len(standard_data))}) and {zero_init_db} ({str(len(zero_data))})" + ) return False found_error = None for i in range(len(standard_data)): @@ -85,13 +103,16 @@ def compare_database(standard_db, zero_init_db): print("Success!") return True + def compare_files(standard_dir, zero_init_dir): standard_list = os.listdir(standard_dir) zero_init_list = os.listdir(zero_init_dir) standard_list.sort() zero_init_list.sort() if standard_list != zero_init_list: - print(f"FAIL - Directories contain mismatching files (standard - {str(standard_list)}, zero init - {str(zero_init_list)})") + print( + f"FAIL - Directories contain mismatching files (standard - {str(standard_list)}, zero init - {str(zero_init_list)})" + ) return False if len(standard_list) == 0: print("FAIL - Directory is empty!") @@ -110,6 +131,7 @@ def clear_directories(directories): except FileNotFoundError as e: pass + test_dirs = [args.standard_dir, args.zero_init_dir] success = True @@ -117,7 +139,14 @@ def clear_directories(directories): print(f"Running test {test}") clear_directories(test_dirs) standard_args = [args.unittest, '--test-temp-dir', args.standard_dir, '--one-initialize', '--single-threaded', test] - zero_init_args = [args.unittest, '--test-temp-dir', args.zero_init_dir, '--zero-initialize', '--single-threaded', test] + zero_init_args = [ + args.unittest, + '--test-temp-dir', + args.zero_init_dir, + '--zero-initialize', + '--single-threaded', + test, + ] print(f"Running test in one-initialize mode") run_test(standard_args) print(f"Running test in zero-initialize mode") diff --git a/scripts/try_timeout.py b/scripts/try_timeout.py index e5eddccb7b89..78bd321d9856 100644 --- a/scripts/try_timeout.py +++ b/scripts/try_timeout.py @@ -12,6 +12,7 @@ retries = int(sys.argv[2].replace("--retry=", "")) cmd = sys.argv[3:] + class Command(object): def __init__(self, cmd): self.cmd = cmd @@ -19,6 +20,7 @@ def __init__(self, cmd): def run(self, timeout): self.process = None + def target(): self.process = subprocess.Popen(self.cmd) self.process.communicate() @@ -35,6 +37,7 @@ def target(): return 1 return self.process.returncode + for i in range(retries): print("Attempting to run command \"" + ' '.join(cmd) + '"') command = Command(cmd) diff --git a/scripts/windows_ci.py b/scripts/windows_ci.py index edfb01069734..ace4ea2ed15e 100644 --- a/scripts/windows_ci.py +++ b/scripts/windows_ci.py @@ -1,19 +1,21 @@ - import os common_path = os.path.join('src', 'include', 'duckdb', 'common', 'common.hpp') with open(common_path, 'r') as f: - text = f.read() + text = f.read() -text = text.replace('#pragma once', '''#pragma once +text = text.replace( + '#pragma once', + '''#pragma once #ifdef _WIN32 #ifdef DUCKDB_MAIN_LIBRARY #include "duckdb/common/windows.hpp" #endif #endif -''') +''', +) with open(common_path, 'w+') as f: - f.write(text) + f.write(text)
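
A closing note on scripts/try_timeout.py, whose hunk above only adds blank lines around its Command class: that script runs the child process inside a worker thread, joins the thread with a deadline, and kills the process if the join expires. The following is a minimal self-contained sketch of that thread-join timeout idiom; the class name TimeoutCommand and the example command are illustrative assumptions, not the script's exact code.

import subprocess
import threading


class TimeoutCommand:
    # Illustrative stand-in for the Command class visible in the try_timeout.py hunk above.
    def __init__(self, cmd):
        self.cmd = cmd
        self.process = None

    def run(self, timeout):
        def target():
            # Runs in a worker thread so the main thread can enforce the deadline.
            self.process = subprocess.Popen(self.cmd)
            self.process.communicate()

        thread = threading.Thread(target=target)
        thread.start()
        thread.join(timeout)
        if thread.is_alive():
            # Deadline expired: kill the child (if it was started) and report failure.
            if self.process is not None:
                self.process.terminate()
            thread.join()
            return 1
        return self.process.returncode if self.process is not None else 1


# Usage, mirroring the retry loop at the bottom of try_timeout.py:
for attempt in range(3):
    if TimeoutCommand(['sleep', '1']).run(timeout=5) == 0:
        break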