Skip to content

Commit 29390b0

Browse files
committed
Support encoding any enum type whose value is supported
Previously we only supported encoding and decoding enums with integer or string values. We now support encoding enums with any value that is also a supported type. The restriction on decoding only integer or string values remains.
1 parent 9d69034 commit 29390b0

File tree

4 files changed

+37
-66
lines changed

4 files changed

+37
-66
lines changed

docs/source/supported-types.rst

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1059,9 +1059,13 @@ other than a field for one of these types.
10591059
------------------------------------
10601060

10611061
Enum types (`enum.Enum`, `enum.IntEnum`, `enum.StrEnum`, ...) encode as their
1062-
member *values* in all protocols. Only enums composed of all string or all
1063-
integer values are supported. An error is raised during decoding if the value
1064-
isn't the proper type, or doesn't match any valid member.
1062+
member *values* in all protocols.
1063+
1064+
Any enum whose *value* is a supported type may be encoded, but only enums
1065+
composed of all string or all integer values may be decoded.
1066+
1067+
An error is raised during decoding if the value isn't the proper type, or
1068+
doesn't match any valid member.
10651069

10661070
.. code-block:: python
10671071

msgspec/_core.c

Lines changed: 10 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -12944,22 +12944,9 @@ mpack_encode_enum(EncoderState *self, PyObject *obj)
1294412944
if (PyUnicode_Check(obj))
1294512945
return mpack_encode_str(self, obj);
1294612946

12947-
int status;
1294812947
PyObject *value = PyObject_GetAttr(obj, self->mod->str__value_);
1294912948
if (value == NULL) return -1;
12950-
if (PyLong_CheckExact(value)) {
12951-
status = mpack_encode_long(self, value);
12952-
}
12953-
else if (PyUnicode_CheckExact(value)) {
12954-
status = mpack_encode_str(self, value);
12955-
}
12956-
else {
12957-
PyErr_SetString(
12958-
self->mod->EncodeError,
12959-
"Only enums with int or str values are supported"
12960-
);
12961-
status = -1;
12962-
}
12949+
int status = mpack_encode(self, value);
1296312950
Py_DECREF(value);
1296412951
return status;
1296512952
}
@@ -13590,40 +13577,25 @@ json_encode_raw(EncoderState *self, PyObject *obj)
1359013577
return ms_write(self, raw->buf, raw->len);
1359113578
}
1359213579

13580+
static int json_encode_dict_key_noinline(EncoderState *, PyObject *);
13581+
1359313582
static int
1359413583
json_encode_enum(EncoderState *self, PyObject *obj, bool is_key)
1359513584
{
1359613585
if (PyLong_Check(obj)) {
13597-
if (MS_UNLIKELY(is_key)) {
13598-
return json_encode_long_as_str(self, obj);
13599-
}
13600-
return json_encode_long(self, obj);
13586+
return is_key ? json_encode_long_as_str(self, obj) : json_encode_long(self, obj);
1360113587
}
1360213588
if (PyUnicode_Check(obj)) {
1360313589
return json_encode_str(self, obj);
1360413590
}
1360513591

13606-
int status;
1360713592
PyObject *value = PyObject_GetAttr(obj, self->mod->str__value_);
1360813593
if (value == NULL) return -1;
13609-
if (PyLong_CheckExact(value)) {
13610-
if (MS_UNLIKELY(is_key)) {
13611-
status = json_encode_long_as_str(self, value);
13612-
}
13613-
else {
13614-
status = json_encode_long(self, value);
13615-
}
13616-
}
13617-
else if (PyUnicode_CheckExact(value)) {
13618-
status = json_encode_str(self, value);
13619-
}
13620-
else {
13621-
PyErr_SetString(
13622-
self->mod->EncodeError,
13623-
"Only enums with int or str values are supported"
13624-
);
13625-
status = -1;
13626-
}
13594+
13595+
int status = (
13596+
is_key ? json_encode_dict_key_noinline(self, value) : json_encode(self, value)
13597+
);
13598+
1362713599
Py_DECREF(value);
1362813600
return status;
1362913601
}
@@ -13792,8 +13764,6 @@ json_encode_set(EncoderState *self, PyObject *obj)
1379213764
return status;
1379313765
}
1379413766

13795-
static int json_encode_dict_key_noinline(EncoderState *, PyObject *);
13796-
1379713767
static MS_INLINE int
1379813768
json_encode_dict_key(EncoderState *self, PyObject *key) {
1379913769
if (MS_LIKELY(PyUnicode_Check(key))) {
@@ -19414,17 +19384,7 @@ static PyObject * to_builtins(ToBuiltinsState *, PyObject *, bool);
1941419384
static PyObject *
1941519385
to_builtins_enum(ToBuiltinsState *self, PyObject *obj)
1941619386
{
19417-
PyObject *value = PyObject_GetAttr(obj, self->mod->str__value_);
19418-
if (value == NULL) return NULL;
19419-
if (PyLong_CheckExact(value) || PyUnicode_CheckExact(value)) {
19420-
return value;
19421-
}
19422-
Py_DECREF(value);
19423-
PyErr_SetString(
19424-
self->mod->EncodeError,
19425-
"Only enums with int or str values are supported"
19426-
);
19427-
return NULL;
19387+
return PyObject_GetAttr(obj, self->mod->str__value_);
1942819388
}
1942919389

1943019390
static PyObject *

tests/test_common.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -601,20 +601,27 @@ class Empty(enum.Enum):
601601
with pytest.raises(TypeError, match="Enum types must have at least one item"):
602602
proto.Decoder(Empty)
603603

604-
def test_unsupported_type_errors(self, proto):
605-
class Bad(enum.Enum):
604+
def test_encode_complex(self, proto):
605+
class Complex(enum.Enum):
606606
A = 1.5
607607

608-
with pytest.raises(
609-
msgspec.EncodeError, match="Only enums with int or str values are supported"
610-
):
611-
proto.encode(Bad.A)
608+
res = proto.encode(Complex.A)
609+
sol = proto.encode(1.5)
610+
assert res == sol
611+
612+
res = proto.encode({Complex.A: 1})
613+
sol = proto.encode({1.5: 1})
614+
assert res == sol
615+
616+
def test_decode_complex_errors(self, proto):
617+
class Complex(enum.Enum):
618+
A = 1.5
612619

613620
with pytest.raises(TypeError) as rec:
614-
proto.Decoder(Bad)
621+
proto.Decoder(Complex)
615622

616623
assert "Enums must contain either all str or all int values" in str(rec.value)
617-
assert repr(Bad) in str(rec.value)
624+
assert repr(Complex) in str(rec.value)
618625

619626
@pytest.mark.parametrize(
620627
"values",

tests/test_to_builtins.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
import pytest
1313

14-
from msgspec import UNSET, EncodeError, Struct, UnsetType, to_builtins, defstruct
14+
from msgspec import UNSET, Struct, UnsetType, to_builtins, defstruct
1515

1616
PY310 = sys.version_info[:2] >= (3, 10)
1717
PY311 = sys.version_info[:2] >= (3, 11)
@@ -216,12 +216,12 @@ def test_enum(self):
216216
assert res == "apple"
217217
assert type(res) is str
218218

219-
def test_enum_invalid(self):
220-
class Bad(enum.Enum):
219+
def test_enum_complex(self):
220+
class Complex(enum.Enum):
221221
x = (1, 2)
222222

223-
with pytest.raises(EncodeError, match="Only enums with int or str"):
224-
to_builtins(Bad.x)
223+
res = to_builtins(Complex.x)
224+
assert res is Complex.x.value
225225

226226
@pytest.mark.parametrize(
227227
"in_type, out_type",

0 commit comments

Comments
 (0)