Skip to content

Commit

Permalink
return properly typed NULL value, fix up python tests and behavior
Browse files Browse the repository at this point in the history
  • Loading branch information
Tishj committed Apr 19, 2024
1 parent 738b301 commit 58c44f6
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 15 deletions.
29 changes: 20 additions & 9 deletions src/core_functions/scalar/map/map.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,21 +21,28 @@ static void MapFunctionEmptyInput(Vector &result, const idx_t row_count) {
result.Verify(row_count);
}

static bool MapIsNull(const LogicalType &map) {
D_ASSERT(map.id() == LogicalTypeId::MAP);
auto &key = MapType::KeyType(map);
auto &value = MapType::ValueType(map);
return (key.id() == LogicalTypeId::SQLNULL && value.id() == LogicalTypeId::SQLNULL);
}

static void MapFunction(DataChunk &args, ExpressionState &, Vector &result) {

// internal MAP representation
// - LIST-vector that contains STRUCTs as child entries
// - STRUCTs have exactly two fields, a key-field, and a value-field
// - key names are unique
D_ASSERT(result.GetType().id() == LogicalTypeId::MAP);

if (result.GetType().id() == LogicalTypeId::SQLNULL) {
if (MapIsNull(result.GetType())) {
auto &validity = FlatVector::Validity(result);
validity.SetInvalid(0);
result.SetVectorType(VectorType::CONSTANT_VECTOR);
return;
}

D_ASSERT(result.GetType().id() == LogicalTypeId::MAP);
auto row_count = args.size();

// early-out, if no data
Expand Down Expand Up @@ -162,16 +169,20 @@ static unique_ptr<FunctionData> MapBind(ClientContext &, ScalarFunction &bound_f
MapVector::EvalMapInvalidReason(MapInvalidReason::INVALID_PARAMS);
}

// bind an empty MAP
bool is_null = false;
if (arguments.empty()) {
bound_function.return_type = LogicalType::MAP(LogicalTypeId::SQLNULL, LogicalTypeId::SQLNULL);
return make_uniq<VariableReturnBindData>(bound_function.return_type);
is_null = true;
}
if (!is_null) {
auto key_id = arguments[0]->return_type.id();
auto value_id = arguments[1]->return_type.id();
if (key_id == LogicalTypeId::SQLNULL || value_id == LogicalTypeId::SQLNULL) {
is_null = true;
}
}

auto key_id = arguments[0]->return_type.id();
auto value_id = arguments[1]->return_type.id();
if (key_id == LogicalTypeId::SQLNULL || value_id == LogicalTypeId::SQLNULL) {
bound_function.return_type = LogicalTypeId::SQLNULL;
if (is_null) {
bound_function.return_type = LogicalType::MAP(LogicalTypeId::SQLNULL, LogicalTypeId::SQLNULL);
return make_uniq<VariableReturnBindData>(bound_function.return_type);
}

Expand Down
37 changes: 34 additions & 3 deletions tools/pythonpkg/src/native/python_conversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,20 @@ vector<string> TransformStructKeys(py::handle keys, idx_t size, const LogicalTyp
return res;
}

static bool IsValidMapComponent(const py::handle &component) {
// The component is either NULL
if (py::none().is(component)) {
return true;
}
if (!py::hasattr(component, "__getitem__")) {
return false;
}
if (!py::hasattr(component, "__len__")) {
return false;
}
return true;
}

bool DictionaryHasMapFormat(const PyDictionary &dict) {
if (dict.len != 2) {
return false;
Expand All @@ -51,13 +65,19 @@ bool DictionaryHasMapFormat(const PyDictionary &dict) {
return false;
}

// Dont check for 'py::list' to allow ducktyping
if (!py::hasattr(keys, "__getitem__") || !py::hasattr(keys, "__len__")) {
if (!IsValidMapComponent(keys)) {
return false;
}
if (!py::hasattr(values, "__getitem__") || !py::hasattr(values, "__len__")) {
if (!IsValidMapComponent(values)) {
return false;
}

// If either of the components is NULL, return early
if (py::none().is(keys) || py::none().is(values)) {
return true;
}

// Verify that both the keys and values are of the same length
auto size = py::len(keys);
if (size != py::len(values)) {
return false;
Expand Down Expand Up @@ -91,6 +111,11 @@ Value TransformStructFormatDictionaryToMap(const PyDictionary &dict, const Logic
if (target_type.id() != LogicalTypeId::MAP) {
throw InvalidInputException("Please provide a valid target type for transform from Python to Value");
}

if (py::none().is(dict.keys) || py::none().is(dict.values)) {
return Value(LogicalType::MAP(LogicalTypeId::SQLNULL, LogicalTypeId::SQLNULL));
}

auto size = py::len(dict.keys);
D_ASSERT(size == py::len(dict.values));

Expand Down Expand Up @@ -130,12 +155,18 @@ Value TransformDictionaryToMap(const PyDictionary &dict, const LogicalType &targ
auto keys = dict.values.attr("__getitem__")(0);
auto values = dict.values.attr("__getitem__")(1);

if (py::none().is(keys) || py::none().is(values)) {
// Either 'key' or 'value' is None, return early with a NULL value
return Value(LogicalType::MAP(LogicalTypeId::SQLNULL, LogicalTypeId::SQLNULL));
}

auto key_size = py::len(keys);
D_ASSERT(key_size == py::len(values));
if (key_size == 0) {
// dict == { 'key': [], 'value': [] }
return EmptyMapValue();
}

// dict == { 'key': [ ... ], 'value' : [ ... ] }
LogicalType key_target = LogicalTypeId::UNKNOWN;
LogicalType value_target = LogicalTypeId::UNKNOWN;
Expand Down
4 changes: 4 additions & 0 deletions tools/pythonpkg/src/pandas/analyzer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,10 @@ LogicalType PandasAnalyzer::DictToMap(const PyDictionary &dict, bool &can_conver
auto keys = dict.values.attr("__getitem__")(0);
auto values = dict.values.attr("__getitem__")(1);

if (py::none().is(keys) || py::none().is(values)) {
return LogicalType::MAP(LogicalTypeId::SQLNULL, LogicalTypeId::SQLNULL);
}

auto key_type = GetListType(keys, can_convert);
if (!can_convert) {
return EmptyMap();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ def test_map_duplicate(self, pandas, duckdb_cursor):
with pytest.raises(
duckdb.InvalidInputException, match="Dict->Map conversion failed because 'key' list contains duplicates"
):
converted_col = duckdb_cursor.sql("select * from x").df()
duckdb_cursor.sql("select * from x").show()

@pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()])
def test_map_nullkey(self, pandas, duckdb_cursor):
Expand All @@ -337,9 +337,8 @@ def test_map_nullkey(self, pandas, duckdb_cursor):
@pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()])
def test_map_nullkeylist(self, pandas, duckdb_cursor):
x = pandas.DataFrame([[{'key': None, 'value': None}]])
# Isn't actually converted to MAP because isinstance(None, list) != True
converted_col = duckdb_cursor.sql("select * from x").df()
duckdb_col = duckdb_cursor.sql("SELECT {key: NULL, value: NULL} as '0'").df()
duckdb_col = duckdb_cursor.sql("SELECT MAP(NULL, NULL) as '0'").df()
pandas.testing.assert_frame_equal(duckdb_col, converted_col)

@pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()])
Expand Down

0 comments on commit 58c44f6

Please sign in to comment.