Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions crates/bindings-macro/src/sats.rs
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ pub(crate) fn derive_deserialize(ty: &SatsType<'_>) -> TokenStream {
de_generics.params.insert(0, de_lt_param.into());
let (de_impl_generics, _, de_where_clause) = de_generics.split_for_impl();

let (iter_n, iter_n2, iter_n3) = (0usize.., 0usize.., 0usize..);
let (iter_n, iter_n2, iter_n3, iter_n4) = (0usize.., 0usize.., 0usize.., 0usize..);

match &ty.data {
SatsTypeData::Product(fields) => {
Expand Down Expand Up @@ -443,8 +443,8 @@ pub(crate) fn derive_deserialize(ty: &SatsType<'_>) -> TokenStream {
impl #de_impl_generics #spacetimedb_lib::de::FieldNameVisitor<'de> for __ProductVisitor #ty_generics #de_where_clause {
type Output = __ProductFieldIdent;

fn field_names(&self, names: &mut dyn #spacetimedb_lib::de::ValidNames) {
names.extend::<&[&str]>(&[#(#field_strings),*])
fn field_names(&self) -> impl '_ + Iterator<Item = Option<&str>> {
[#(#field_strings),*].into_iter().map(Some)
}

fn visit<__E: #spacetimedb_lib::de::Error>(self, name: &str) -> Result<Self::Output, __E> {
Expand All @@ -453,6 +453,13 @@ pub(crate) fn derive_deserialize(ty: &SatsType<'_>) -> TokenStream {
_ => Err(#spacetimedb_lib::de::Error::unknown_field_name(name, &self)),
}
}

fn visit_seq<__E: #spacetimedb_lib::de::Error>(self, index: usize, name: &str) -> Result<Self::Output, __E> {
match index {
#(#iter_n4 => Ok(__ProductFieldIdent::#field_names),)*
_ => Err(#spacetimedb_lib::de::Error::unknown_field_name(name, &self)),
}
}
}

#[allow(non_camel_case_types)]
Expand Down Expand Up @@ -516,11 +523,11 @@ pub(crate) fn derive_deserialize(ty: &SatsType<'_>) -> TokenStream {
#(#variant_idents,)*
}

impl #de_impl_generics #spacetimedb_lib::de::VariantVisitor for __SumVisitor #ty_generics #de_where_clause {
impl #de_impl_generics #spacetimedb_lib::de::VariantVisitor<'de> for __SumVisitor #ty_generics #de_where_clause {
type Output = __Variant;

fn variant_names(&self, names: &mut dyn #spacetimedb_lib::de::ValidNames) {
names.extend::<&[&str]>(&[#(#variant_names,)*])
fn variant_names(&self) -> impl '_ + Iterator<Item = &str> {
[#(#variant_names,)*].into_iter()
}

fn visit_tag<E: #spacetimedb_lib::de::Error>(self, __tag: u8) -> Result<Self::Output, E> {
Expand Down
6 changes: 3 additions & 3 deletions crates/sats/src/algebraic_value/de.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,12 +173,12 @@ impl SumAccess {
}
}

impl de::SumAccess<'_> for SumAccess {
impl<'de> de::SumAccess<'de> for SumAccess {
type Error = ValueDeserializeError;

type Variant = ValueDeserializer;

fn variant<V: de::VariantVisitor>(self, visitor: V) -> Result<(V::Output, Self::Variant), Self::Error> {
fn variant<V: de::VariantVisitor<'de>>(self, visitor: V) -> Result<(V::Output, Self::Variant), Self::Error> {
let tag = visitor.visit_tag(self.sum.tag)?;
let val = *self.sum.value;
Ok((tag, ValueDeserializer { val }))
Expand Down Expand Up @@ -313,7 +313,7 @@ impl<'de> de::SumAccess<'de> for &'de SumAccess {

type Variant = &'de ValueDeserializer;

fn variant<V: de::VariantVisitor>(self, visitor: V) -> Result<(V::Output, Self::Variant), Self::Error> {
fn variant<V: de::VariantVisitor<'de>>(self, visitor: V) -> Result<(V::Output, Self::Variant), Self::Error> {
let tag = visitor.visit_tag(self.sum.tag)?;
Ok((tag, ValueDeserializer::from_ref(&self.sum.value)))
}
Expand Down
2 changes: 1 addition & 1 deletion crates/sats/src/bsatn/de.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ impl<'de, R: BufReader<'de>> SumAccess<'de> for Deserializer<'_, R> {
type Error = DecodeError;
type Variant = Self;

fn variant<V: de::VariantVisitor>(self, visitor: V) -> Result<(V::Output, Self::Variant), Self::Error> {
fn variant<V: de::VariantVisitor<'de>>(self, visitor: V) -> Result<(V::Output, Self::Variant), Self::Error> {
let tag = self.reader.get_u8()?;
visitor.visit_tag(tag).map(|variant| (variant, self))
}
Expand Down
140 changes: 45 additions & 95 deletions crates/sats/src/de.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ pub trait Error: Sized {
ProductKind::Normal => "field",
ProductKind::ReducerArgs => "reducer argument",
};
if let Some(one_of) = one_of_names(|n| expected.field_names(n)) {
if let Some(one_of) = one_of_names(|| expected.field_names()) {
Self::custom(format_args!("unknown {el_ty} `{field_name}`, expected {one_of}"))
} else {
Self::custom(format_args!("unknown {el_ty} `{field_name}`, there are no {el_ty}s"))
Expand All @@ -200,8 +200,8 @@ pub trait Error: Sized {
}

/// The `name` is not that of a variant of the sum type.
fn unknown_variant_name<T: VariantVisitor>(name: &str, expected: &T) -> Self {
if let Some(one_of) = one_of_names(|n| expected.variant_names(n)) {
fn unknown_variant_name<'de, T: VariantVisitor<'de>>(name: &str, expected: &T) -> Self {
if let Some(one_of) = one_of_names(|| expected.variant_names().map(Some)) {
Self::custom(format_args!("unknown variant `{name}`, expected {one_of}",))
} else {
Self::custom(format_args!("unknown variant `{name}`, there are no variants"))
Expand Down Expand Up @@ -358,39 +358,18 @@ pub trait FieldNameVisitor<'de> {
ProductKind::Normal
}

/// Provides the visitor the chance to add valid names into `names`.
fn field_names(&self, names: &mut dyn ValidNames);
/// Provides a list of valid field names.
///
/// Where `None` is yielded, this indicates a nameless field.
fn field_names(&self) -> impl '_ + Iterator<Item = Option<&str>>;

/// Deserializes the name of a field using `name`.
fn visit<E: Error>(self, name: &str) -> Result<Self::Output, E>;
}

/// A trait for types storing a set of valid names.
pub trait ValidNames {
/// Adds the name `s` to the set.
fn push(&mut self, s: &str);

/// Runs the function `names` provided with `self` as the store
/// and then returns back `self`.
/// This method exists for convenience.
fn run(mut self, names: &impl Fn(&mut dyn ValidNames)) -> Self
where
Self: Sized,
{
names(&mut self);
self
}
}

impl dyn ValidNames + '_ {
/// Adds the names in `iter` to the set.
pub fn extend<I: IntoIterator>(&mut self, iter: I)
where
I::Item: AsRef<str>,
{
for name in iter {
self.push(name.as_ref())
}
}
/// Deserializes the name of a field using `index`.
///
/// The `name` is provided for error messages.
fn visit_seq<E: Error>(self, index: usize, name: &str) -> Result<Self::Output, E>;
}

/// A visitor walking through a [`Deserializer`] for sums.
Expand Down Expand Up @@ -442,17 +421,17 @@ pub trait SumAccess<'de> {
/// The `visitor` is provided by the [`Deserializer`].
/// This method is typically called from [`SumVisitor::visit_sum`]
/// which will provide the [`V: VariantVisitor`](VariantVisitor).
fn variant<V: VariantVisitor>(self, visitor: V) -> Result<(V::Output, Self::Variant), Self::Error>;
fn variant<V: VariantVisitor<'de>>(self, visitor: V) -> Result<(V::Output, Self::Variant), Self::Error>;
}

/// A visitor passed from [`SumVisitor`] to [`SumAccess::variant`]
/// which the latter uses to decide what variant to deserialize.
pub trait VariantVisitor {
pub trait VariantVisitor<'de> {
/// The result of identifying a variant, e.g., some index type.
type Output;

/// Provides the visitor the chance to add valid names into `names`.
fn variant_names(&self, names: &mut dyn ValidNames);
/// Provides a list of variant names.
fn variant_names(&self) -> impl '_ + Iterator<Item = &str>;

/// Identify the variant based on `tag`.
fn visit_tag<E: Error>(self, tag: u8) -> Result<Self::Output, E>;
Expand Down Expand Up @@ -669,71 +648,42 @@ impl<'de, T, const N: usize> ArrayVisitor<'de, T> for BasicArrayVisitor<N> {
}
}

/// Provided a function `names` that is allowed to store a name into a valid set,
/// Provided a list of names,
/// returns a human readable list of all the names,
/// or `None` in the case of an empty list of names.
fn one_of_names(names: impl Fn(&mut dyn ValidNames)) -> Option<impl fmt::Display> {
/// An implementation of `ValidNames` that just counts how many valid names are pushed into it.
struct CountNames(usize);

impl ValidNames for CountNames {
fn push(&mut self, _: &str) {
self.0 += 1
}
}
fn one_of_names<'a, I: Iterator<Item = Option<&'a str>>>(names: impl Fn() -> I) -> Option<impl fmt::Display> {
// Count how many names there are.
let count = names().count();

/// An implementation of `ValidNames` that provides a human friendly enumeration of names.
struct OneOfNames<'a, 'b> {
/// A `.push(_)` counter.
index: usize,
/// How many names there were.
count: usize,
/// Result of formatting thus far.
f: Result<&'a mut fmt::Formatter<'b>, fmt::Error>,
}

impl<'a, 'b> OneOfNames<'a, 'b> {
fn new(count: usize, f: &'a mut fmt::Formatter<'b>) -> Self {
Self {
index: 0,
count,
f: Ok(f),
}
}
}

impl ValidNames for OneOfNames<'_, '_> {
fn push(&mut self, name: &str) {
// This will give us, after all `.push()`es have been made, the following:
// There was at least one name; render those names.
(count != 0).then(move || {
fmt_fn(move |f| {
let mut anon_name = 0;
// An example of what happens for names "foo", "bar", and "baz":
//
// count = 1 -> "`foo`"
// = 2 -> "`foo` or `bar`"
// > 2 -> "one of `foo`, `bar`, or `baz`"

let Ok(f) = &mut self.f else {
return;
};

self.index += 1;

if let Err(e) = match (self.count, self.index) {
(1, _) => write!(f, "`{name}`"),
(2, 1) => write!(f, "`{name}`"),
(2, 2) => write!(f, "`or `{name}`"),
(_, 1) => write!(f, "one of `{name}`"),
(c, i) if i < c => write!(f, ", `{name}`"),
(_, _) => write!(f, ", `, or {name}`"),
} {
self.f = Err(e);
for (index, mut name) in names().enumerate() {
let mut name_buf: String = String::new();
let name = name.get_or_insert_with(|| {
name_buf = format!("{anon_name}");
anon_name += 1;
&name_buf
});
match (count, index) {
(1, _) => write!(f, "`{name}`"),
(2, 1) => write!(f, "`{name}`"),
(2, 2) => write!(f, "`or `{name}`"),
(_, 1) => write!(f, "one of `{name}`"),
(c, i) if i < c => write!(f, ", `{name}`"),
(_, _) => write!(f, ", `, or {name}`"),
}?;
}
}
}

// Count how many names have been pushed.
let count = CountNames(0).run(&names).0;

// There was at least one name; render those names.
(count != 0).then(|| fmt_fn(move |fmt| OneOfNames::new(count, fmt).run(&names).f.map(drop)))
Ok(())
})
})
}

/// Deserializes `none` variant of an optional value.
Expand All @@ -752,11 +702,11 @@ impl<E: Error> Default for NoneAccess<E> {
}
}

impl<E: Error> SumAccess<'_> for NoneAccess<E> {
impl<'de, E: Error> SumAccess<'de> for NoneAccess<E> {
type Error = E;
type Variant = Self;

fn variant<V: VariantVisitor>(self, visitor: V) -> Result<(V::Output, Self::Variant), Self::Error> {
fn variant<V: VariantVisitor<'de>>(self, visitor: V) -> Result<(V::Output, Self::Variant), Self::Error> {
visitor.visit_name("none").map(|var| (var, self))
}
}
Expand Down
36 changes: 22 additions & 14 deletions crates/sats/src/de/impls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,11 +211,11 @@ impl<'de, T: Deserialize<'de>> SumVisitor<'de> for OptionVisitor<T> {
}
}

impl<'de, T: Deserialize<'de>> VariantVisitor for OptionVisitor<T> {
impl<'de, T: Deserialize<'de>> VariantVisitor<'de> for OptionVisitor<T> {
type Output = bool;

fn variant_names(&self, names: &mut dyn super::ValidNames) {
names.extend(["some", "none"])
fn variant_names(&self) -> impl '_ + Iterator<Item = &str> {
["some", "none"].into_iter()
}

fn visit_tag<E: Error>(self, tag: u8) -> Result<Self::Output, E> {
Expand Down Expand Up @@ -268,11 +268,11 @@ impl<'de, T: Deserialize<'de>, E: Deserialize<'de>> SumVisitor<'de> for ResultVi
}
}

impl<'de, T: Deserialize<'de>, U: Deserialize<'de>> VariantVisitor for ResultVisitor<T, U> {
impl<'de, T: Deserialize<'de>, U: Deserialize<'de>> VariantVisitor<'de> for ResultVisitor<T, U> {
type Output = ResultVariant;

fn variant_names(&self, names: &mut dyn super::ValidNames) {
names.extend(["ok", "err"])
fn variant_names(&self) -> impl '_ + Iterator<Item = &str> {
["ok", "err"].into_iter()
}

fn visit_tag<E: Error>(self, tag: u8) -> Result<Self::Output, E> {
Expand Down Expand Up @@ -335,11 +335,11 @@ impl<'de, S: Copy + DeserializeSeed<'de>> SumVisitor<'de> for BoundVisitor<S> {
}
}

impl<'de, T: Copy + DeserializeSeed<'de>> VariantVisitor for BoundVisitor<T> {
impl<'de, T: Copy + DeserializeSeed<'de>> VariantVisitor<'de> for BoundVisitor<T> {
type Output = BoundVariant;

fn variant_names(&self, names: &mut dyn super::ValidNames) {
names.extend(["included", "excluded", "unbounded"])
fn variant_names(&self) -> impl '_ + Iterator<Item = &str> {
["included", "excluded", "unbounded"].into_iter()
}

fn visit_tag<E: Error>(self, tag: u8) -> Result<Self::Output, E> {
Expand Down Expand Up @@ -420,12 +420,12 @@ impl<'de> SumVisitor<'de> for WithTypespace<'_, SumType> {
}
}

impl VariantVisitor for WithTypespace<'_, SumType> {
impl VariantVisitor<'_> for WithTypespace<'_, SumType> {
type Output = u8;

fn variant_names(&self, names: &mut dyn super::ValidNames) {
fn variant_names(&self) -> impl '_ + Iterator<Item = &str> {
// Provide the names known from the `SumType`.
names.extend(self.ty().variants.iter().filter_map(|v| v.name()))
self.ty().variants.iter().filter_map(|v| v.name())
}

fn visit_tag<E: Error>(self, tag: u8) -> Result<Self::Output, E> {
Expand Down Expand Up @@ -643,8 +643,8 @@ impl FieldNameVisitor<'_> for TupleNameVisitor<'_> {
// The index of the field name.
type Output = usize;

fn field_names(&self, names: &mut dyn super::ValidNames) {
names.extend(self.elems.iter().filter_map(|f| f.name()))
fn field_names(&self) -> impl '_ + Iterator<Item = Option<&str>> {
self.elems.iter().map(|f| f.name())
}

fn kind(&self) -> ProductKind {
Expand All @@ -658,6 +658,14 @@ impl FieldNameVisitor<'_> for TupleNameVisitor<'_> {
.position(|f| f.has_name(name))
.ok_or_else(|| Error::unknown_field_name(name, &self))
}

fn visit_seq<E: Error>(self, index: usize, name: &str) -> Result<Self::Output, E> {
self.elems
.get(index)
.ok_or_else(|| Error::unknown_field_name(name, &self))?;

Ok(index)
}
}

impl_deserialize!([] spacetimedb_primitives::TableId, de => u32::deserialize(de).map(Self));
Expand Down
Loading
Loading