Skip to content

Feature: Allow to skip seq access for internally and adjacently tagged enums #2933

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 37 additions & 2 deletions serde/src/private/de.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@ use crate::de::{MapAccess, Unexpected};
#[cfg(any(feature = "std", feature = "alloc"))]
pub use self::content::{
Content, ContentDeserializer, ContentRefDeserializer, EnumDeserializer,
InternallyTaggedUnitVisitor, TagContentOtherField, TagContentOtherFieldVisitor,
TagOrContentField, TagOrContentFieldVisitor, TaggedContentVisitor, UntaggedUnitVisitor,
InternallyTaggedUnitVisitor, NoSeqTaggedContentVisitor, TagContentOtherField,
TagContentOtherFieldVisitor, TagOrContentField, TagOrContentFieldVisitor, TaggedContentVisitor,
UntaggedUnitVisitor,
};

pub use crate::seed::InPlaceSeed;
Expand Down Expand Up @@ -829,6 +830,40 @@ mod content {
}
}

/// Used by generated code to deserialize an internally tagged enum without sequence format.
///
/// Captures map from the original deserializer and searches
/// a tag in it.
///
/// Not public API.
pub struct NoSeqTaggedContentVisitor<T>(TaggedContentVisitor<T>);

impl<T> NoSeqTaggedContentVisitor<T> {
/// Visitor for the content of an internally tagged enum with the given
/// tag name.
pub fn new(name: &'static str, expecting: &'static str) -> Self {
Self(TaggedContentVisitor::new(name, expecting))
}
}

impl<'de, T> Visitor<'de> for NoSeqTaggedContentVisitor<T>
where
T: Deserialize<'de>,
{
type Value = (T, Content<'de>);

fn expecting(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
self.0.expecting(fmt)
}

fn visit_map<M>(self, map: M) -> Result<Self::Value, M::Error>
where
M: MapAccess<'de>,
{
self.0.visit_map(map)
}
}

/// Used by generated code to deserialize an internally tagged enum.
///
/// Captures map or sequence from the original deserializer and searches
Expand Down
90 changes: 54 additions & 36 deletions serde_derive/src/de.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1229,12 +1229,14 @@ fn deserialize_homogeneous_enum(
) -> Fragment {
match cattrs.tag() {
attr::TagType::External => deserialize_externally_tagged_enum(params, variants, cattrs),
attr::TagType::Internal { tag } => {
deserialize_internally_tagged_enum(params, variants, cattrs, tag)
}
attr::TagType::Adjacent { tag, content } => {
deserialize_adjacently_tagged_enum(params, variants, cattrs, tag, content)
}
attr::TagType::Internal { tag, use_seq } => {
deserialize_internally_tagged_enum(params, variants, cattrs, tag, *use_seq)
}
attr::TagType::Adjacent {
tag,
content,
use_seq,
} => deserialize_adjacently_tagged_enum(params, variants, cattrs, tag, content, *use_seq),
attr::TagType::None => deserialize_untagged_enum(params, variants, cattrs),
}
}
Expand Down Expand Up @@ -1380,6 +1382,7 @@ fn deserialize_internally_tagged_enum(
variants: &[Variant],
cattrs: &attr::Container,
tag: &str,
use_seq: bool,
) -> Fragment {
let (variants_stmt, variant_visitor) = prepare_enum_variant_enum(variants);

Expand All @@ -1406,14 +1409,20 @@ fn deserialize_internally_tagged_enum(
let expecting = format!("internally tagged enum {}", params.type_name());
let expecting = cattrs.expecting().unwrap_or(&expecting);

let tagged_content_visitor = if use_seq {
quote!(_serde::__private::de::TaggedContentVisitor)
} else {
quote!(_serde::__private::de::NoSeqTaggedContentVisitor)
};

quote_block! {
#variant_visitor

#variants_stmt

let (__tag, __content) = _serde::Deserializer::deserialize_any(
__deserializer,
_serde::__private::de::TaggedContentVisitor::<__Field>::new(#tag, #expecting))?;
#tagged_content_visitor::<__Field>::new(#tag, #expecting))?;
let __deserializer = _serde::__private::de::ContentDeserializer::<__D::Error>::new(__content);

match __tag {
Expand All @@ -1428,6 +1437,7 @@ fn deserialize_adjacently_tagged_enum(
cattrs: &attr::Container,
tag: &str,
content: &str,
use_seq: bool,
) -> Fragment {
let this_type = &params.this_type;
let this_value = &params.this_value;
Expand Down Expand Up @@ -1595,6 +1605,42 @@ fn deserialize_adjacently_tagged_enum(
}
};

let visit_seq = if use_seq {
quote! {
fn visit_seq<__A>(self, mut __seq: __A) -> _serde::__private::Result<Self::Value, __A::Error>
where
__A: _serde::de::SeqAccess<#delife>,
{
// Visit the first element - the tag.
match _serde::de::SeqAccess::next_element(&mut __seq)? {
_serde::__private::Some(__field) => {
// Visit the second element - the content.
match _serde::de::SeqAccess::next_element_seed(
&mut __seq,
__Seed {
field: __field,
marker: _serde::__private::PhantomData,
lifetime: _serde::__private::PhantomData,
},
)? {
_serde::__private::Some(__ret) => _serde::__private::Ok(__ret),
// There is no second element.
_serde::__private::None => {
_serde::__private::Err(_serde::de::Error::invalid_length(1, &self))
}
}
}
// There is no first element.
_serde::__private::None => {
_serde::__private::Err(_serde::de::Error::invalid_length(0, &self))
}
}
}
}
} else {
quote! {}
};

quote_block! {
#variant_visitor

Expand Down Expand Up @@ -1694,35 +1740,7 @@ fn deserialize_adjacently_tagged_enum(
}
}

fn visit_seq<__A>(self, mut __seq: __A) -> _serde::__private::Result<Self::Value, __A::Error>
where
__A: _serde::de::SeqAccess<#delife>,
{
// Visit the first element - the tag.
match _serde::de::SeqAccess::next_element(&mut __seq)? {
_serde::__private::Some(__field) => {
// Visit the second element - the content.
match _serde::de::SeqAccess::next_element_seed(
&mut __seq,
__Seed {
field: __field,
marker: _serde::__private::PhantomData,
lifetime: _serde::__private::PhantomData,
},
)? {
_serde::__private::Some(__ret) => _serde::__private::Ok(__ret),
// There is no second element.
_serde::__private::None => {
_serde::__private::Err(_serde::de::Error::invalid_length(1, &self))
}
}
}
// There is no first element.
_serde::__private::None => {
_serde::__private::Err(_serde::de::Error::invalid_length(0, &self))
}
}
}
#visit_seq
}

#[doc(hidden)]
Expand Down
72 changes: 60 additions & 12 deletions serde_derive/src/internals/attr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -189,14 +189,18 @@ pub enum TagType {
/// ```json
/// {"type": "variant1", "key1": "value1", "key2": "value2"}
/// ```
Internal { tag: String },
Internal { tag: String, use_seq: bool },

/// `#[serde(tag = "t", content = "c")]`
///
/// ```json
/// {"t": "variant1", "c": {"key1": "value1", "key2": "value2"}}
/// ```
Adjacent { tag: String, content: String },
Adjacent {
tag: String,
content: String,
use_seq: bool,
},

/// `#[serde(untagged)]`
///
Expand Down Expand Up @@ -248,6 +252,7 @@ impl Container {
let mut ser_bound = Attr::none(cx, BOUND);
let mut de_bound = Attr::none(cx, BOUND);
let mut untagged = BoolAttr::none(cx, UNTAGGED);
let mut use_seq = Attr::none(cx, USE_SEQ);
let mut internal_tag = Attr::none(cx, TAG);
let mut content = Attr::none(cx, CONTENT);
let mut type_from = Attr::none(cx, FROM);
Expand Down Expand Up @@ -434,6 +439,12 @@ impl Container {
}
}
}
} else if meta.path == USE_SEQ {
// #[serde(seq = false)]
if let Some(s) = get_lit_bool(cx, CONTENT, &meta)? {
// Should we introduce there a compile time checks? If so, which ones?
use_seq.set(&meta.path, s.value());
}
} else if meta.path == CONTENT {
// #[serde(content = "c")]
if let Some(s) = get_lit_str(cx, CONTENT, &meta)? {
Expand Down Expand Up @@ -532,7 +543,7 @@ impl Container {
},
ser_bound: ser_bound.get(),
de_bound: de_bound.get(),
tag: decide_tag(cx, item, untagged, internal_tag, content),
tag: decide_tag(cx, item, untagged, internal_tag, content, use_seq),
type_from: type_from.get(),
type_try_from: type_try_from.get(),
type_into: type_into.get(),
Expand Down Expand Up @@ -631,15 +642,17 @@ fn decide_tag(
untagged: BoolAttr,
internal_tag: Attr<String>,
content: Attr<String>,
use_seq: Attr<bool>,
) -> TagType {
match (
untagged.0.get_with_tokens(),
internal_tag.get_with_tokens(),
content.get_with_tokens(),
use_seq.get(),
) {
(None, None, None) => TagType::External,
(Some(_), None, None) => TagType::None,
(None, Some((_, tag)), None) => {
(None, None, None, _) => TagType::External,
(Some(_), None, None, _) => TagType::None,
(None, Some((_, tag)), None, use_seq) => {
// Check that there are no tuple variants.
if let syn::Data::Enum(data) = &item.data {
for variant in &data.variants {
Expand All @@ -656,27 +669,34 @@ fn decide_tag(
}
}
}
TagType::Internal { tag }
TagType::Internal {
tag,
use_seq: use_seq.unwrap_or(true),
}
}
(Some((untagged_tokens, ())), Some((tag_tokens, _)), None) => {
(Some((untagged_tokens, ())), Some((tag_tokens, _)), None, _) => {
let msg = "enum cannot be both untagged and internally tagged";
cx.error_spanned_by(untagged_tokens, msg);
cx.error_spanned_by(tag_tokens, msg);
TagType::External // doesn't matter, will error
}
(None, None, Some((content_tokens, _))) => {
(None, None, Some((content_tokens, _)), _) => {
let msg = "#[serde(tag = \"...\", content = \"...\")] must be used together";
cx.error_spanned_by(content_tokens, msg);
TagType::External
}
(Some((untagged_tokens, ())), None, Some((content_tokens, _))) => {
(Some((untagged_tokens, ())), None, Some((content_tokens, _)), _) => {
let msg = "untagged enum cannot have #[serde(content = \"...\")]";
cx.error_spanned_by(untagged_tokens, msg);
cx.error_spanned_by(content_tokens, msg);
TagType::External
}
(None, Some((_, tag)), Some((_, content))) => TagType::Adjacent { tag, content },
(Some((untagged_tokens, ())), Some((tag_tokens, _)), Some((content_tokens, _))) => {
(None, Some((_, tag)), Some((_, content)), use_seq) => TagType::Adjacent {
tag,
content,
use_seq: use_seq.unwrap_or(true),
},
(Some((untagged_tokens, ())), Some((tag_tokens, _)), Some((content_tokens, _)), _) => {
let msg = "untagged enum cannot have #[serde(tag = \"...\", content = \"...\")]";
cx.error_spanned_by(untagged_tokens, msg);
cx.error_spanned_by(tag_tokens, msg);
Expand Down Expand Up @@ -1415,6 +1435,34 @@ fn get_where_predicates(
Ok((ser.at_most_one(), de.at_most_one()))
}

fn get_lit_bool(
cx: &Ctxt,
attr_name: Symbol,
meta: &ParseNestedMeta,
) -> syn::Result<Option<syn::LitBool>> {
let expr: syn::Expr = meta.value()?.parse()?;
let mut value = &expr;
while let syn::Expr::Group(e) = value {
value = &e.expr;
}
if let syn::Expr::Lit(syn::ExprLit {
lit: syn::Lit::Bool(lit),
..
}) = value
{
Ok(Some(lit.clone()))
} else {
cx.error_spanned_by(
expr,
format!(
"expected serde {} attribute to be a bool: `{} = false/true`",
attr_name, attr_name
),
);
Ok(None)
}
}

fn get_lit_str(
cx: &Ctxt,
attr_name: Symbol,
Expand Down
4 changes: 2 additions & 2 deletions serde_derive/src/internals/check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ fn check_internal_tag_field_name_conflict(cx: &Ctxt, cont: &Container) {
};

let tag = match cont.attrs.tag() {
TagType::Internal { tag } => tag.as_str(),
TagType::Internal { tag, .. } => tag.as_str(),
TagType::External | TagType::Adjacent { .. } | TagType::None => return,
};

Expand Down Expand Up @@ -351,7 +351,7 @@ fn check_internal_tag_field_name_conflict(cx: &Ctxt, cont: &Container) {
// differ, for the same reason.
fn check_adjacent_tag_conflict(cx: &Ctxt, cont: &Container) {
let (type_tag, content_tag) = match cont.attrs.tag() {
TagType::Adjacent { tag, content } => (tag, content),
TagType::Adjacent { tag, content, .. } => (tag, content),
TagType::Internal { .. } | TagType::External | TagType::None => return,
};

Expand Down
1 change: 1 addition & 0 deletions serde_derive/src/internals/symbol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ pub const TAG: Symbol = Symbol("tag");
pub const TRANSPARENT: Symbol = Symbol("transparent");
pub const TRY_FROM: Symbol = Symbol("try_from");
pub const UNTAGGED: Symbol = Symbol("untagged");
pub const USE_SEQ: Symbol = Symbol("seq_form");
pub const VARIANT_IDENTIFIER: Symbol = Symbol("variant_identifier");
pub const WITH: Symbol = Symbol("with");

Expand Down
6 changes: 3 additions & 3 deletions serde_derive/src/ser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ fn serialize_struct(params: &Parameters, fields: &[Field], cattrs: &attr::Contai

fn serialize_struct_tag_field(cattrs: &attr::Container, struct_trait: &StructTrait) -> TokenStream {
match cattrs.tag() {
attr::TagType::Internal { tag } => {
attr::TagType::Internal { tag, .. } => {
let type_name = cattrs.name().serialize_name();
let func = struct_trait.serialize_field(Span::call_site());
quote! {
Expand Down Expand Up @@ -473,10 +473,10 @@ fn serialize_variant(
(attr::TagType::External, false) => {
serialize_externally_tagged_variant(params, variant, variant_index, cattrs)
}
(attr::TagType::Internal { tag }, false) => {
(attr::TagType::Internal { tag, .. }, false) => {
serialize_internally_tagged_variant(params, variant, cattrs, tag)
}
(attr::TagType::Adjacent { tag, content }, false) => {
(attr::TagType::Adjacent { tag, content, .. }, false) => {
serialize_adjacently_tagged_variant(
params,
variant,
Expand Down
Loading