Skip to content

Commit

Permalink
Merge pull request duckdb#11730 from Tishj/map_null_behavior_rework
Browse files Browse the repository at this point in the history
[Map] Rework `MAP` creation method behavior when input is NULL
  • Loading branch information
Mytherin authored Apr 20, 2024
2 parents dcdb408 + d54e152 commit bfb8f48
Show file tree
Hide file tree
Showing 13 changed files with 172 additions and 52 deletions.
10 changes: 0 additions & 10 deletions src/common/enum_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3802,14 +3802,10 @@ const char* EnumUtil::ToChars<MapInvalidReason>(MapInvalidReason value) {
switch(value) {
case MapInvalidReason::VALID:
return "VALID";
case MapInvalidReason::NULL_KEY_LIST:
return "NULL_KEY_LIST";
case MapInvalidReason::NULL_KEY:
return "NULL_KEY";
case MapInvalidReason::DUPLICATE_KEY:
return "DUPLICATE_KEY";
case MapInvalidReason::NULL_VALUE_LIST:
return "NULL_VALUE_LIST";
case MapInvalidReason::NOT_ALIGNED:
return "NOT_ALIGNED";
case MapInvalidReason::INVALID_PARAMS:
Expand All @@ -3824,18 +3820,12 @@ MapInvalidReason EnumUtil::FromString<MapInvalidReason>(const char *value) {
if (StringUtil::Equals(value, "VALID")) {
return MapInvalidReason::VALID;
}
if (StringUtil::Equals(value, "NULL_KEY_LIST")) {
return MapInvalidReason::NULL_KEY_LIST;
}
if (StringUtil::Equals(value, "NULL_KEY")) {
return MapInvalidReason::NULL_KEY;
}
if (StringUtil::Equals(value, "DUPLICATE_KEY")) {
return MapInvalidReason::DUPLICATE_KEY;
}
if (StringUtil::Equals(value, "NULL_VALUE_LIST")) {
return MapInvalidReason::NULL_VALUE_LIST;
}
if (StringUtil::Equals(value, "NOT_ALIGNED")) {
return MapInvalidReason::NOT_ALIGNED;
}
Expand Down
4 changes: 0 additions & 4 deletions src/common/types/vector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2089,10 +2089,6 @@ void MapVector::EvalMapInvalidReason(MapInvalidReason reason) {
throw InvalidInputException("Map keys must be unique.");
case MapInvalidReason::NULL_KEY:
throw InvalidInputException("Map keys can not be NULL.");
case MapInvalidReason::NULL_KEY_LIST:
throw InvalidInputException("The list of map keys must not be NULL.");
case MapInvalidReason::NULL_VALUE_LIST:
throw InvalidInputException("The list of map values must not be NULL.");
case MapInvalidReason::NOT_ALIGNED:
throw InvalidInputException("The map key list does not align with the map value list.");
case MapInvalidReason::INVALID_PARAMS:
Expand Down
58 changes: 44 additions & 14 deletions src/core_functions/scalar/map/map.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,38 @@ static void MapFunctionEmptyInput(Vector &result, const idx_t row_count) {
result.Verify(row_count);
}

//! Decides from the argument *types* alone whether MAP(keys, values) must yield NULL.
//! - No arguments at all: not a NULL map, returns false.
//! - Otherwise exactly two columns are expected (keys, values); returns true as soon as
//!   either of them has the logical type SQLNULL, i.e. the argument is an untyped NULL.
//! Only column types are inspected here; per-row validity is not looked at.
static bool MapIsNull(DataChunk &chunk) {
	auto &columns = chunk.data;
	if (columns.empty()) {
		return false;
	}
	D_ASSERT(columns.size() == 2);

	// column 0 holds the key lists, column 1 the value lists
	const bool keys_are_sqlnull = columns[0].GetType().id() == LogicalTypeId::SQLNULL;
	const bool values_are_sqlnull = columns[1].GetType().id() == LogicalTypeId::SQLNULL;
	return keys_are_sqlnull || values_are_sqlnull;
}

static void MapFunction(DataChunk &args, ExpressionState &, Vector &result) {

// internal MAP representation
// - LIST-vector that contains STRUCTs as child entries
// - STRUCTs have exactly two fields, a key-field, and a value-field
// - key names are unique

D_ASSERT(result.GetType().id() == LogicalTypeId::MAP);

if (MapIsNull(args)) {
auto &validity = FlatVector::Validity(result);
validity.SetInvalid(0);
result.SetVectorType(VectorType::CONSTANT_VECTOR);
return;
}

auto row_count = args.size();

// early-out, if no data
Expand Down Expand Up @@ -63,13 +87,15 @@ static void MapFunction(DataChunk &args, ExpressionState &, Vector &result) {
UnifiedVectorFormat result_data;
result.ToUnifiedFormat(row_count, result_data);
auto result_entries = UnifiedVectorFormat::GetDataNoConst<list_entry_t>(result_data);
result_data.validity.SetAllValid(row_count);

auto &result_validity = FlatVector::Validity(result);

// get the resulting size of the key/value child lists
idx_t result_child_size = 0;
for (idx_t row_idx = 0; row_idx < row_count; row_idx++) {
auto keys_idx = keys_data.sel->get_index(row_idx);
if (!keys_data.validity.RowIsValid(keys_idx)) {
auto values_idx = values_data.sel->get_index(row_idx);
if (!keys_data.validity.RowIsValid(keys_idx) || !values_data.validity.RowIsValid(values_idx)) {
continue;
}
auto keys_entry = keys_entries[keys_idx];
Expand All @@ -87,22 +113,15 @@ static void MapFunction(DataChunk &args, ExpressionState &, Vector &result) {
auto values_idx = values_data.sel->get_index(row_idx);
auto result_idx = result_data.sel->get_index(row_idx);

// empty map
if (!keys_data.validity.RowIsValid(keys_idx) && !values_data.validity.RowIsValid(values_idx)) {
result_entries[result_idx] = list_entry_t();
// NULL MAP
if (!keys_data.validity.RowIsValid(keys_idx) || !values_data.validity.RowIsValid(values_idx)) {
result_validity.SetInvalid(row_idx);
continue;
}

auto keys_entry = keys_entries[keys_idx];
auto values_entry = values_entries[values_idx];

// validity checks
if (!keys_data.validity.RowIsValid(keys_idx)) {
MapVector::EvalMapInvalidReason(MapInvalidReason::NULL_KEY_LIST);
}
if (!values_data.validity.RowIsValid(values_idx)) {
MapVector::EvalMapInvalidReason(MapInvalidReason::NULL_VALUE_LIST);
}
if (keys_entry.length != values_entry.length) {
MapVector::EvalMapInvalidReason(MapInvalidReason::NOT_ALIGNED);
}
Expand Down Expand Up @@ -160,8 +179,19 @@ static unique_ptr<FunctionData> MapBind(ClientContext &, ScalarFunction &bound_f
MapVector::EvalMapInvalidReason(MapInvalidReason::INVALID_PARAMS);
}

// bind an empty MAP
bool is_null = false;
if (arguments.empty()) {
is_null = true;
}
if (!is_null) {
auto key_id = arguments[0]->return_type.id();
auto value_id = arguments[1]->return_type.id();
if (key_id == LogicalTypeId::SQLNULL || value_id == LogicalTypeId::SQLNULL) {
is_null = true;
}
}

if (is_null) {
bound_function.return_type = LogicalType::MAP(LogicalTypeId::SQLNULL, LogicalTypeId::SQLNULL);
return make_uniq<VariableReturnBindData>(bound_function.return_type);
}
Expand Down
3 changes: 0 additions & 3 deletions src/function/table/arrow_conversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -317,9 +317,6 @@ static void ArrowToDuckDBMapVerify(Vector &vector, idx_t count) {
case MapInvalidReason::NULL_KEY: {
throw InvalidInputException("Arrow map contains NULL as map key, which isn't supported by DuckDB map type");
}
case MapInvalidReason::NULL_KEY_LIST: {
throw InvalidInputException("Arrow map contains NULL as key list, which isn't supported by DuckDB map type");
}
default: {
throw InternalException("MapInvalidReason not implemented");
}
Expand Down
10 changes: 1 addition & 9 deletions src/include/duckdb/common/types/vector.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -464,15 +464,7 @@ struct FSSTVector {
DUCKDB_API static idx_t GetCount(Vector &vector);
};

enum class MapInvalidReason : uint8_t {
VALID,
NULL_KEY_LIST,
NULL_KEY,
DUPLICATE_KEY,
NULL_VALUE_LIST,
NOT_ALIGNED,
INVALID_PARAMS
};
//! Result of validating a MAP; values other than VALID name why the map is rejected.
//! - NULL_KEY: a key inside the map is NULL ("Map keys can not be NULL.")
//! - DUPLICATE_KEY: the key list contains duplicates
//! - NOT_ALIGNED: the key list and the value list do not line up
//! - INVALID_PARAMS: raised while binding MAP(...) - presumably a wrong argument count/shape; confirm in MapBind
//! MapVector::EvalMapInvalidReason turns these reasons into InvalidInputExceptions.
enum class MapInvalidReason : uint8_t { VALID, NULL_KEY, DUPLICATE_KEY, NOT_ALIGNED, INVALID_PARAMS };

struct MapVector {
DUCKDB_API static const Vector &GetKeys(const Vector &vector);
Expand Down
69 changes: 69 additions & 0 deletions test/sql/types/map/map_null.test
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# name: test/sql/types/map/map_null.test
# group: [map]

statement ok
pragma enable_verification;

query I
select map(NULL::INT[], [1,2,3])
----
NULL

query I
select map(NULL, [1,2,3])
----
NULL

query I
select map(NULL, NULL)
----
NULL

query I
select map(NULL, [1,2,3]) IS NULL
----
true

query I
select map([1,2,3], NULL)
----
NULL

query I
select map([1,2,3], NULL::INT[])
----
NULL

query I
SELECT * FROM ( VALUES
(MAP(NULL, NULL)),
(MAP(NULL::INT[], NULL::INT[])),
(MAP([1,2,3], [1,2,3]))
)
----
NULL
NULL
{1=1, 2=2, 3=3}

query I
select MAP(a, b) FROM ( VALUES
(NULL, ['b', 'c']),
(NULL::INT[], NULL),
(NULL::INT[], NULL::VARCHAR[]),
(NULL::INT[], ['a', 'b', 'c']),
(NULL, ['longer string than inlined', 'smol']),
(NULL, NULL),
([1,2,3], NULL),
([1,2,3], ['z', 'y', 'x']),
([1,2,3], NULL::VARCHAR[]),
) t(a, b)
----
NULL
NULL
NULL
NULL
NULL
NULL
NULL
{1=z, 2=y, 3=x}
NULL
10 changes: 6 additions & 4 deletions test/sql/types/nested/map/map_error.test
Original file line number Diff line number Diff line change
Expand Up @@ -75,18 +75,20 @@ CREATE TABLE null_keys_list (k INT[], v INT[]);
statement ok
INSERT INTO null_keys_list VALUES ([1], [2]), (NULL, [4]);

statement error
query I
SELECT MAP(k, v) FROM null_keys_list;
----
The list of map keys must not be NULL.
{1=2}
NULL

statement ok
CREATE TABLE null_values_list (k INT[], v INT[]);

statement ok
INSERT INTO null_values_list VALUES ([1], [2]), ([4], NULL);

statement error
query I
SELECT MAP(k, v) FROM null_values_list;
----
The list of map values must not be NULL.
{1=2}
NULL
3 changes: 3 additions & 0 deletions test/sql/types/nested/map/test_map_subscript.test
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
# description: Test subscript operator for maps
# group: [map]

statement ok
pragma enable_verification

# Single element on map
query I
select m[1] from (select MAP(LIST_VALUE(1, 2, 3, 4),LIST_VALUE(10, 9, 8, 7)) as m) as T
Expand Down
37 changes: 34 additions & 3 deletions tools/pythonpkg/src/native/python_conversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,20 @@ vector<string> TransformStructKeys(py::handle keys, idx_t size, const LogicalTyp
return res;
}

//! Checks whether one half of a map-formatted dictionary (the keys or the values) is usable.
//! A component is accepted when it is None, or when it is list-like, meaning it exposes
//! both '__getitem__' and '__len__'. Attributes are probed instead of testing for 'py::list'.
static bool IsValidMapComponent(const py::handle &component) {
	// None is accepted as-is
	if (py::none().is(component)) {
		return true;
	}
	// Otherwise it has to support indexing and report a length
	return py::hasattr(component, "__getitem__") && py::hasattr(component, "__len__");
}

bool DictionaryHasMapFormat(const PyDictionary &dict) {
if (dict.len != 2) {
return false;
Expand All @@ -51,13 +65,19 @@ bool DictionaryHasMapFormat(const PyDictionary &dict) {
return false;
}

// Don't check for 'py::list', to allow duck typing
if (!py::hasattr(keys, "__getitem__") || !py::hasattr(keys, "__len__")) {
if (!IsValidMapComponent(keys)) {
return false;
}
if (!py::hasattr(values, "__getitem__") || !py::hasattr(values, "__len__")) {
if (!IsValidMapComponent(values)) {
return false;
}

// If either of the components is NULL, return early
if (py::none().is(keys) || py::none().is(values)) {
return true;
}

// Verify that both the keys and values are of the same length
auto size = py::len(keys);
if (size != py::len(values)) {
return false;
Expand Down Expand Up @@ -91,6 +111,11 @@ Value TransformStructFormatDictionaryToMap(const PyDictionary &dict, const Logic
if (target_type.id() != LogicalTypeId::MAP) {
throw InvalidInputException("Please provide a valid target type for transform from Python to Value");
}

if (py::none().is(dict.keys) || py::none().is(dict.values)) {
return Value(LogicalType::MAP(LogicalTypeId::SQLNULL, LogicalTypeId::SQLNULL));
}

auto size = py::len(dict.keys);
D_ASSERT(size == py::len(dict.values));

Expand Down Expand Up @@ -130,12 +155,18 @@ Value TransformDictionaryToMap(const PyDictionary &dict, const LogicalType &targ
auto keys = dict.values.attr("__getitem__")(0);
auto values = dict.values.attr("__getitem__")(1);

if (py::none().is(keys) || py::none().is(values)) {
// Either 'key' or 'value' is None, return early with a NULL value
return Value(LogicalType::MAP(LogicalTypeId::SQLNULL, LogicalTypeId::SQLNULL));
}

auto key_size = py::len(keys);
D_ASSERT(key_size == py::len(values));
if (key_size == 0) {
// dict == { 'key': [], 'value': [] }
return EmptyMapValue();
}

// dict == { 'key': [ ... ], 'value' : [ ... ] }
LogicalType key_target = LogicalTypeId::UNKNOWN;
LogicalType value_target = LogicalTypeId::UNKNOWN;
Expand Down
2 changes: 0 additions & 2 deletions tools/pythonpkg/src/numpy/numpy_scan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,6 @@ static void VerifyMapConstraints(Vector &vec, idx_t count) {
return;
case MapInvalidReason::DUPLICATE_KEY:
throw InvalidInputException("Dict->Map conversion failed because 'key' list contains duplicates");
case MapInvalidReason::NULL_KEY_LIST:
throw InvalidInputException("Dict->Map conversion failed because 'key' list is None");
case MapInvalidReason::NULL_KEY:
throw InvalidInputException("Dict->Map conversion failed because 'key' list contains None");
default:
Expand Down
4 changes: 4 additions & 0 deletions tools/pythonpkg/src/pandas/analyzer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,10 @@ LogicalType PandasAnalyzer::DictToMap(const PyDictionary &dict, bool &can_conver
auto keys = dict.values.attr("__getitem__")(0);
auto values = dict.values.attr("__getitem__")(1);

if (py::none().is(keys) || py::none().is(values)) {
return LogicalType::MAP(LogicalTypeId::SQLNULL, LogicalTypeId::SQLNULL);
}

auto key_type = GetListType(keys, can_convert);
if (!can_convert) {
return EmptyMap();
Expand Down
9 changes: 9 additions & 0 deletions tools/pythonpkg/tests/fast/arrow/test_nested_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,15 @@ def test_map_arrow_to_duckdb(self, duckdb_cursor):
):
rel = duckdb.from_arrow(arrow_table).fetchall()

def test_null_map_arrow_to_duckdb(self, duckdb_cursor):
    """A None entry in an Arrow map column must be returned by DuckDB as None."""
    if not can_run:
        return
    int_to_int_map = pa.map_(pa.int32(), pa.int32())
    # First row is a NULL map, second row is a map with the single pair 5 -> 42.
    rows = [None, [(5, 42)]]
    # Keep the name `arrow_table`: the SQL text below refers to the table by that name.
    arrow_table = pa.table({'detail': pa.array(rows, int_to_int_map)})
    fetched = duckdb_cursor.sql("select * from arrow_table").fetchall()
    expected = [(None,), ({'key': [5], 'value': [42]},)]
    assert fetched == expected

def test_map_arrow_to_pandas(self, duckdb_cursor):
if not can_run:
return
Expand Down
Loading

0 comments on commit bfb8f48

Please sign in to comment.