Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Infer naming convention when converting objects to structs #636

Merged
merged 1 commit into from
Jan 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 44 additions & 9 deletions msgspec/_core.c
Original file line number Diff line number Diff line change
Expand Up @@ -21263,11 +21263,51 @@ convert_object_to_struct(
bool is_gc = MS_TYPE_IS_GC(struct_type);
bool should_untrack = is_gc;

/* If no fields are renamed we only have one fields tuple to choose from */
PyObject *fields = NULL;
if (struct_type->struct_fields == struct_type->struct_encode_fields) {
fields = struct_type->struct_fields;
}

for (Py_ssize_t i = 0; i < nfields; i++) {
PyObject *field = PyTuple_GET_ITEM(struct_type->struct_encode_fields, i);
PyObject *attr = getter(obj, field);
PyObject *val;
if (attr == NULL) {
PyObject *field, *attr, *val;

if (MS_LIKELY(fields != NULL)) {
/* fields tuple already determined, just get the next field name */
field = PyTuple_GET_ITEM(fields, i);
attr = getter(obj, field);
}
else {
/* fields tuple undetermined. Try the attribute name first */
PyObject *encode_field;
field = PyTuple_GET_ITEM(struct_type->struct_fields, i);
encode_field = PyTuple_GET_ITEM(struct_type->struct_encode_fields, i);
attr = getter(obj, field);
if (field != encode_field) {
/* The field _was_ renamed */
if (attr != NULL) {
/* Got a match, lock-in using attribute names */
fields = struct_type->struct_fields;
}
else {
/* No match. Try using the renamed name */
PyErr_Clear();
attr = getter(obj, encode_field);
if (attr != NULL) {
/* Got a match, lock-in using renamed names */
field = encode_field;
fields = struct_type->struct_encode_fields;
}
}
}
}

if (attr != NULL) {
PathNode field_path = {path, PATH_STR, field};
val = convert(self, attr, info->types[i], &field_path);
Py_DECREF(attr);
}
else {
PyErr_Clear();
PyObject *default_val = NULL;
if (MS_LIKELY(i >= (nfields - ndefaults))) {
Expand All @@ -21284,11 +21324,6 @@ convert_object_to_struct(
}
val = get_default(default_val);
}
else {
PathNode field_path = {path, PATH_STR, field};
val = convert(self, attr, info->types[i], &field_path);
Py_DECREF(attr);
}
if (val == NULL) goto error;
Struct_set_index(out, i, val);
if (should_untrack) {
Expand Down
62 changes: 54 additions & 8 deletions tests/test_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -1701,16 +1701,16 @@ class Ex2(Struct):


class TestStruct:
class Account(Struct):
class Account(Struct, kw_only=True):
first: str
last: str
age: int
verified: bool = False
age: int

@mapcls_and_from_attributes
def test_struct(self, mapcls, from_attributes):
msg = mapcls(first="alice", last="munro", age=91, verified=True)
sol = self.Account("alice", "munro", 91, True)
sol = self.Account(first="alice", last="munro", verified=True, age=91)
res = convert(msg, self.Account, from_attributes=from_attributes)
assert res == sol

Expand Down Expand Up @@ -1754,12 +1754,58 @@ class Ex(Struct):

assert convert(msg, Ex, from_attributes=True) == Ex(1)

def test_from_attributes_option_uses_renamed_fields(self):
@pytest.mark.parametrize("mapcls", [GetAttrObj, GetItemObj])
def test_object_to_struct_with_renamed_fields(self, mapcls):
class Ex(Struct, rename="camel"):
field_one: int
fa: int
f_b: int
fc: int
f_d: int

sol = Ex(1, 2, 3, 4)

# Use attribute names
msg = mapcls(fa=1, f_b=2, fc=3, f_d=4)
assert convert(msg, Ex, from_attributes=True) == sol

# Use renamed names
msg = mapcls(fa=1, fB=2, fc=3, fD=4)
assert convert(msg, Ex, from_attributes=True) == sol

msg = GetAttrObj(fieldOne=2)
assert convert(msg, Ex, from_attributes=True) == Ex(2)
# Priority to attribute names
msg = mapcls(fa=1, f_b=2, fB=100, fc=3, f_d=4, fD=100)
assert convert(msg, Ex, from_attributes=True) == sol

# Don't allow renamed names if determined to be attributes
msg = mapcls(fa=1, f_b=2, fc=3, fD=4)
with pytest.raises(ValidationError, match="missing required field `f_d`"):
convert(msg, Ex, from_attributes=True)

# Don't allow attributes if determined to be renamed names
msg = mapcls(fa=1, fB=2, fc=3, f_d=4)
with pytest.raises(ValidationError, match="missing required field `fD`"):
convert(msg, Ex, from_attributes=True)

# Errors use attribute name if using attributes
msg = mapcls(fa=1, f_b=2, fc=3, f_d="bad")
with pytest.raises(
ValidationError, match=r"Expected `int`, got `str` - at `\$.f_d`"
):
convert(msg, Ex, from_attributes=True)

# Errors use renamed name if using renamed names
msg = mapcls(fa=1, fB=2, fc=3, fD="bad")
with pytest.raises(
ValidationError, match=r"Expected `int`, got `str` - at `\$.fD`"
):
convert(msg, Ex, from_attributes=True)

# Errors use attribute name if undecided
msg = mapcls(fa="bad")
with pytest.raises(
ValidationError, match=r"Expected `int`, got `str` - at `\$.fa`"
):
convert(msg, Ex, from_attributes=True)

@pytest.mark.parametrize("forbid_unknown_fields", [False, True])
@mapcls_and_from_attributes
Expand All @@ -1780,7 +1826,7 @@ class Ex(Struct, forbid_unknown_fields=forbid_unknown_fields):
def test_struct_defaults_missing_fields(self, mapcls, from_attributes):
msg = mapcls(first="alice", last="munro", age=91)
res = convert(msg, self.Account, from_attributes=from_attributes)
assert res == self.Account("alice", "munro", 91)
assert res == self.Account(first="alice", last="munro", age=91)

@mapcls_from_attributes_and_array_like
def test_struct_gc_maybe_untracked_on_decode(
Expand Down
Loading