diff --git a/src/common/enum_util.cpp b/src/common/enum_util.cpp index b2db02f412c9..12a94a2b0f1c 100644 --- a/src/common/enum_util.cpp +++ b/src/common/enum_util.cpp @@ -3802,14 +3802,10 @@ const char* EnumUtil::ToChars(MapInvalidReason value) { switch(value) { case MapInvalidReason::VALID: return "VALID"; - case MapInvalidReason::NULL_KEY_LIST: - return "NULL_KEY_LIST"; case MapInvalidReason::NULL_KEY: return "NULL_KEY"; case MapInvalidReason::DUPLICATE_KEY: return "DUPLICATE_KEY"; - case MapInvalidReason::NULL_VALUE_LIST: - return "NULL_VALUE_LIST"; case MapInvalidReason::NOT_ALIGNED: return "NOT_ALIGNED"; case MapInvalidReason::INVALID_PARAMS: @@ -3824,18 +3820,12 @@ MapInvalidReason EnumUtil::FromString(const char *value) { if (StringUtil::Equals(value, "VALID")) { return MapInvalidReason::VALID; } - if (StringUtil::Equals(value, "NULL_KEY_LIST")) { - return MapInvalidReason::NULL_KEY_LIST; - } if (StringUtil::Equals(value, "NULL_KEY")) { return MapInvalidReason::NULL_KEY; } if (StringUtil::Equals(value, "DUPLICATE_KEY")) { return MapInvalidReason::DUPLICATE_KEY; } - if (StringUtil::Equals(value, "NULL_VALUE_LIST")) { - return MapInvalidReason::NULL_VALUE_LIST; - } if (StringUtil::Equals(value, "NOT_ALIGNED")) { return MapInvalidReason::NOT_ALIGNED; } diff --git a/src/common/types/vector.cpp b/src/common/types/vector.cpp index e61ad46dd259..112dc0de96ab 100644 --- a/src/common/types/vector.cpp +++ b/src/common/types/vector.cpp @@ -2089,10 +2089,6 @@ void MapVector::EvalMapInvalidReason(MapInvalidReason reason) { throw InvalidInputException("Map keys must be unique."); case MapInvalidReason::NULL_KEY: throw InvalidInputException("Map keys can not be NULL."); - case MapInvalidReason::NULL_KEY_LIST: - throw InvalidInputException("The list of map keys must not be NULL."); - case MapInvalidReason::NULL_VALUE_LIST: - throw InvalidInputException("The list of map values must not be NULL."); case MapInvalidReason::NOT_ALIGNED: throw InvalidInputException("The map key list does not align with the map value list."); case MapInvalidReason::INVALID_PARAMS: diff --git a/src/core_functions/scalar/map/map.cpp b/src/core_functions/scalar/map/map.cpp index e27fe3fd6503..ab67475d151b 100644 --- a/src/core_functions/scalar/map/map.cpp +++ b/src/core_functions/scalar/map/map.cpp @@ -21,14 +21,38 @@ static void MapFunctionEmptyInput(Vector &result, const idx_t row_count) { result.Verify(row_count); } +static bool MapIsNull(DataChunk &chunk) { + if (chunk.data.empty()) { + return false; + } + D_ASSERT(chunk.data.size() == 2); + auto &keys = chunk.data[0]; + auto &values = chunk.data[1]; + + if (keys.GetType().id() == LogicalTypeId::SQLNULL) { + return true; + } + if (values.GetType().id() == LogicalTypeId::SQLNULL) { + return true; + } + return false; +} + static void MapFunction(DataChunk &args, ExpressionState &, Vector &result) { // internal MAP representation // - LIST-vector that contains STRUCTs as child entries // - STRUCTs have exactly two fields, a key-field, and a value-field // - key names are unique - D_ASSERT(result.GetType().id() == LogicalTypeId::MAP); + + if (MapIsNull(args)) { + auto &validity = FlatVector::Validity(result); + validity.SetInvalid(0); + result.SetVectorType(VectorType::CONSTANT_VECTOR); + return; + } + auto row_count = args.size(); // early-out, if no data @@ -63,13 +87,15 @@ static void MapFunction(DataChunk &args, ExpressionState &, Vector &result) { UnifiedVectorFormat result_data; result.ToUnifiedFormat(row_count, result_data); auto result_entries = UnifiedVectorFormat::GetDataNoConst(result_data); - result_data.validity.SetAllValid(row_count); + + auto &result_validity = FlatVector::Validity(result); // get the resulting size of the key/value child lists idx_t result_child_size = 0; for (idx_t row_idx = 0; row_idx < row_count; row_idx++) { auto keys_idx = keys_data.sel->get_index(row_idx); - if (!keys_data.validity.RowIsValid(keys_idx)) { + auto values_idx = values_data.sel->get_index(row_idx); + if (!keys_data.validity.RowIsValid(keys_idx) || !values_data.validity.RowIsValid(values_idx)) { continue; } auto keys_entry = keys_entries[keys_idx]; @@ -87,22 +113,15 @@ static void MapFunction(DataChunk &args, ExpressionState &, Vector &result) { auto values_idx = values_data.sel->get_index(row_idx); auto result_idx = result_data.sel->get_index(row_idx); - // empty map - if (!keys_data.validity.RowIsValid(keys_idx) && !values_data.validity.RowIsValid(values_idx)) { - result_entries[result_idx] = list_entry_t(); + // NULL MAP + if (!keys_data.validity.RowIsValid(keys_idx) || !values_data.validity.RowIsValid(values_idx)) { + result_validity.SetInvalid(row_idx); continue; } auto keys_entry = keys_entries[keys_idx]; auto values_entry = values_entries[values_idx]; - // validity checks - if (!keys_data.validity.RowIsValid(keys_idx)) { - MapVector::EvalMapInvalidReason(MapInvalidReason::NULL_KEY_LIST); - } - if (!values_data.validity.RowIsValid(values_idx)) { - MapVector::EvalMapInvalidReason(MapInvalidReason::NULL_VALUE_LIST); - } if (keys_entry.length != values_entry.length) { MapVector::EvalMapInvalidReason(MapInvalidReason::NOT_ALIGNED); } @@ -160,8 +179,19 @@ static unique_ptr MapBind(ClientContext &, ScalarFunction &bound_f MapVector::EvalMapInvalidReason(MapInvalidReason::INVALID_PARAMS); } - // bind an empty MAP + bool is_null = false; if (arguments.empty()) { + is_null = true; + } + if (!is_null) { + auto key_id = arguments[0]->return_type.id(); + auto value_id = arguments[1]->return_type.id(); + if (key_id == LogicalTypeId::SQLNULL || value_id == LogicalTypeId::SQLNULL) { + is_null = true; + } + } + + if (is_null) { bound_function.return_type = LogicalType::MAP(LogicalTypeId::SQLNULL, LogicalTypeId::SQLNULL); return make_uniq(bound_function.return_type); } diff --git a/src/function/table/arrow_conversion.cpp b/src/function/table/arrow_conversion.cpp index 78d4cca859f6..c1759ef8484c 100644 --- a/src/function/table/arrow_conversion.cpp +++ b/src/function/table/arrow_conversion.cpp @@ -317,9 +317,6 @@ static void ArrowToDuckDBMapVerify(Vector &vector, idx_t count) { case MapInvalidReason::NULL_KEY: { throw InvalidInputException("Arrow map contains NULL as map key, which isn't supported by DuckDB map type"); } - case MapInvalidReason::NULL_KEY_LIST: { - throw InvalidInputException("Arrow map contains NULL as key list, which isn't supported by DuckDB map type"); - } default: { throw InternalException("MapInvalidReason not implemented"); } diff --git a/src/include/duckdb/common/types/vector.hpp b/src/include/duckdb/common/types/vector.hpp index b0786597a662..49cb9111c464 100644 --- a/src/include/duckdb/common/types/vector.hpp +++ b/src/include/duckdb/common/types/vector.hpp @@ -464,15 +464,7 @@ struct FSSTVector { DUCKDB_API static idx_t GetCount(Vector &vector); }; -enum class MapInvalidReason : uint8_t { - VALID, - NULL_KEY_LIST, - NULL_KEY, - DUPLICATE_KEY, - NULL_VALUE_LIST, - NOT_ALIGNED, - INVALID_PARAMS -}; +enum class MapInvalidReason : uint8_t { VALID, NULL_KEY, DUPLICATE_KEY, NOT_ALIGNED, INVALID_PARAMS }; struct MapVector { DUCKDB_API static const Vector &GetKeys(const Vector &vector); diff --git a/test/sql/types/map/map_null.test b/test/sql/types/map/map_null.test new file mode 100644 index 000000000000..68dc58be808a --- /dev/null +++ b/test/sql/types/map/map_null.test @@ -0,0 +1,69 @@ +# name: test/sql/types/map/map_null.test +# group: [map] + +statement ok +pragma enable_verification; + +query I +select map(NULL::INT[], [1,2,3]) +---- +NULL + +query I +select map(NULL, [1,2,3]) +---- +NULL + +query I +select map(NULL, NULL) +---- +NULL + +query I +select map(NULL, [1,2,3]) IS NULL +---- +true + +query I +select map([1,2,3], NULL) +---- +NULL + +query I +select map([1,2,3], NULL::INT[]) +---- +NULL + +query I +SELECT * FROM ( VALUES + (MAP(NULL, NULL)), + (MAP(NULL::INT[], NULL::INT[])), + (MAP([1,2,3], [1,2,3])) +) +---- +NULL +NULL +{1=1, 2=2, 3=3} + +query I +select MAP(a, b) FROM ( VALUES + (NULL, ['b', 'c']), + (NULL::INT[], NULL), + (NULL::INT[], NULL::VARCHAR[]), + (NULL::INT[], ['a', 'b', 'c']), + (NULL, ['longer string than inlined', 'smol']), + (NULL, NULL), + ([1,2,3], NULL), + ([1,2,3], ['z', 'y', 'x']), + ([1,2,3], NULL::VARCHAR[]), +) t(a, b) +---- +NULL +NULL +NULL +NULL +NULL +NULL +NULL +{1=z, 2=y, 3=x} +NULL diff --git a/test/sql/types/nested/map/map_error.test b/test/sql/types/nested/map/map_error.test index 315a1620ea7a..f67c19d4ead2 100644 --- a/test/sql/types/nested/map/map_error.test +++ b/test/sql/types/nested/map/map_error.test @@ -75,10 +75,11 @@ CREATE TABLE null_keys_list (k INT[], v INT[]); statement ok INSERT INTO null_keys_list VALUES ([1], [2]), (NULL, [4]); -statement error +query I SELECT MAP(k, v) FROM null_keys_list; ---- -The list of map keys must not be NULL. +{1=2} +NULL statement ok CREATE TABLE null_values_list (k INT[], v INT[]); @@ -86,7 +87,8 @@ CREATE TABLE null_values_list (k INT[], v INT[]); statement ok INSERT INTO null_values_list VALUES ([1], [2]), ([4], NULL); -statement error +query I SELECT MAP(k, v) FROM null_values_list; ---- -The list of map values must not be NULL. \ No newline at end of file +{1=2} +NULL diff --git a/test/sql/types/nested/map/test_map_subscript.test b/test/sql/types/nested/map/test_map_subscript.test index f75482857dad..8ad48d29e48b 100644 --- a/test/sql/types/nested/map/test_map_subscript.test +++ b/test/sql/types/nested/map/test_map_subscript.test @@ -2,6 +2,9 @@ # description: Test cardinality function for maps # group: [map] +statement ok +pragma enable_verification + # Single element on map query I select m[1] from (select MAP(LIST_VALUE(1, 2, 3, 4),LIST_VALUE(10, 9, 8, 7)) as m) as T diff --git a/tools/pythonpkg/src/native/python_conversion.cpp b/tools/pythonpkg/src/native/python_conversion.cpp index cbdbb9f57f00..133d3fb768f1 100644 --- a/tools/pythonpkg/src/native/python_conversion.cpp +++ b/tools/pythonpkg/src/native/python_conversion.cpp @@ -37,6 +37,20 @@ vector TransformStructKeys(py::handle keys, idx_t size, const LogicalTyp return res; } +static bool IsValidMapComponent(const py::handle &component) { + // The component is either NULL + if (py::none().is(component)) { + return true; + } + if (!py::hasattr(component, "__getitem__")) { + return false; + } + if (!py::hasattr(component, "__len__")) { + return false; + } + return true; +} + bool DictionaryHasMapFormat(const PyDictionary &dict) { if (dict.len != 2) { return false; @@ -51,13 +65,19 @@ bool DictionaryHasMapFormat(const PyDictionary &dict) { return false; } - // Dont check for 'py::list' to allow ducktyping - if (!py::hasattr(keys, "__getitem__") || !py::hasattr(keys, "__len__")) { + if (!IsValidMapComponent(keys)) { return false; } - if (!py::hasattr(values, "__getitem__") || !py::hasattr(values, "__len__")) { + if (!IsValidMapComponent(values)) { return false; } + + // If either of the components is NULL, return early + if (py::none().is(keys) || py::none().is(values)) { + return true; + } + + // Verify that both the keys and values are of the same length auto size = py::len(keys); if (size != py::len(values)) { return false; @@ -91,6 +111,11 @@ Value TransformStructFormatDictionaryToMap(const PyDictionary &dict, const Logic if (target_type.id() != LogicalTypeId::MAP) { throw InvalidInputException("Please provide a valid target type for transform from Python to Value"); } + + if (py::none().is(dict.keys) || py::none().is(dict.values)) { + return Value(LogicalType::MAP(LogicalTypeId::SQLNULL, LogicalTypeId::SQLNULL)); + } + auto size = py::len(dict.keys); D_ASSERT(size == py::len(dict.values)); @@ -130,12 +155,18 @@ Value TransformDictionaryToMap(const PyDictionary &dict, const LogicalType &targ auto keys = dict.values.attr("__getitem__")(0); auto values = dict.values.attr("__getitem__")(1); + if (py::none().is(keys) || py::none().is(values)) { + // Either 'key' or 'value' is None, return early with a NULL value + return Value(LogicalType::MAP(LogicalTypeId::SQLNULL, LogicalTypeId::SQLNULL)); + } + auto key_size = py::len(keys); D_ASSERT(key_size == py::len(values)); if (key_size == 0) { // dict == { 'key': [], 'value': [] } return EmptyMapValue(); } + // dict == { 'key': [ ... ], 'value' : [ ... ] } LogicalType key_target = LogicalTypeId::UNKNOWN; LogicalType value_target = LogicalTypeId::UNKNOWN; diff --git a/tools/pythonpkg/src/numpy/numpy_scan.cpp b/tools/pythonpkg/src/numpy/numpy_scan.cpp index 032d3b97f014..b4b1d3dbe276 100644 --- a/tools/pythonpkg/src/numpy/numpy_scan.cpp +++ b/tools/pythonpkg/src/numpy/numpy_scan.cpp @@ -153,8 +153,6 @@ static void VerifyMapConstraints(Vector &vec, idx_t count) { return; case MapInvalidReason::DUPLICATE_KEY: throw InvalidInputException("Dict->Map conversion failed because 'key' list contains duplicates"); - case MapInvalidReason::NULL_KEY_LIST: - throw InvalidInputException("Dict->Map conversion failed because 'key' list is None"); case MapInvalidReason::NULL_KEY: throw InvalidInputException("Dict->Map conversion failed because 'key' list contains None"); default: diff --git a/tools/pythonpkg/src/pandas/analyzer.cpp b/tools/pythonpkg/src/pandas/analyzer.cpp index 508270894403..660d1fb2b3d2 100644 --- a/tools/pythonpkg/src/pandas/analyzer.cpp +++ b/tools/pythonpkg/src/pandas/analyzer.cpp @@ -331,6 +331,10 @@ LogicalType PandasAnalyzer::DictToMap(const PyDictionary &dict, bool &can_conver auto keys = dict.values.attr("__getitem__")(0); auto values = dict.values.attr("__getitem__")(1); + if (py::none().is(keys) || py::none().is(values)) { + return LogicalType::MAP(LogicalTypeId::SQLNULL, LogicalTypeId::SQLNULL); + } + auto key_type = GetListType(keys, can_convert); if (!can_convert) { return EmptyMap(); diff --git a/tools/pythonpkg/tests/fast/arrow/test_nested_arrow.py b/tools/pythonpkg/tests/fast/arrow/test_nested_arrow.py index 592778146835..9c6ceb06b4fe 100644 --- a/tools/pythonpkg/tests/fast/arrow/test_nested_arrow.py +++ b/tools/pythonpkg/tests/fast/arrow/test_nested_arrow.py @@ -183,6 +183,15 @@ def test_map_arrow_to_duckdb(self, duckdb_cursor): ): rel = duckdb.from_arrow(arrow_table).fetchall() + def test_null_map_arrow_to_duckdb(self, duckdb_cursor): + if not can_run: + return + map_type = pa.map_(pa.int32(), pa.int32()) + values = [None, [(5, 42)]] + arrow_table = pa.table({'detail': pa.array(values, map_type)}) + res = duckdb_cursor.sql("select * from arrow_table").fetchall() + assert res == [(None,), ({'key': [5], 'value': [42]},)] + def test_map_arrow_to_pandas(self, duckdb_cursor): if not can_run: return diff --git a/tools/pythonpkg/tests/fast/pandas/test_df_object_resolution.py b/tools/pythonpkg/tests/fast/pandas/test_df_object_resolution.py index 0f20f9fe0309..1a07e47fc8f4 100644 --- a/tools/pythonpkg/tests/fast/pandas/test_df_object_resolution.py +++ b/tools/pythonpkg/tests/fast/pandas/test_df_object_resolution.py @@ -324,7 +324,7 @@ def test_map_duplicate(self, pandas, duckdb_cursor): with pytest.raises( duckdb.InvalidInputException, match="Dict->Map conversion failed because 'key' list contains duplicates" ): - converted_col = duckdb_cursor.sql("select * from x").df() + duckdb_cursor.sql("select * from x").show() @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_map_nullkey(self, pandas, duckdb_cursor): @@ -337,9 +337,8 @@ def test_map_nullkey(self, pandas, duckdb_cursor): @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_map_nullkeylist(self, pandas, duckdb_cursor): x = pandas.DataFrame([[{'key': None, 'value': None}]]) - # Isn't actually converted to MAP because isinstance(None, list) != True converted_col = duckdb_cursor.sql("select * from x").df() - duckdb_col = duckdb_cursor.sql("SELECT {key: NULL, value: NULL} as '0'").df() + duckdb_col = duckdb_cursor.sql("SELECT MAP(NULL, NULL) as '0'").df() pandas.testing.assert_frame_equal(duckdb_col, converted_col) @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()])