Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support constant expressions as index attribute variant #551

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion derive/src/decode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ pub fn quote(

let recurse = data_variants().enumerate().map(|(i, v)| {
let name = &v.ident;
let index = utils::variant_index(v, i);
let index = utils::variant_index(v, i, &mut Vec::new());

let create = create_instance(
quote! { #type_name #type_generics :: #name },
Expand Down
38 changes: 33 additions & 5 deletions derive/src/encode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -314,9 +314,11 @@ fn impl_encode(data: &Data, type_name: &Ident, crate_path: &syn::Path) -> TokenS
return quote!()
}

let recurse = data_variants().enumerate().map(|(i, f)| {
let mut collected_const_indices: Vec<TokenStream> = Vec::new();

let recurse: Vec<[TokenStream; 2]> = data_variants().enumerate().map(|(i, f)| {
let name = &f.ident;
let index = utils::variant_index(f, i);
let index = utils::variant_index(f, i, &mut collected_const_indices);

match f.fields {
Fields::Named(ref fields) => {
Expand Down Expand Up @@ -397,10 +399,34 @@ fn impl_encode(data: &Data, type_name: &Ident, crate_path: &syn::Path) -> TokenS
[hinting, encoding]
},
}
});
}).collect();

if let Some((duplicate, token_str)) =
utils::find_const_duplicate(&collected_const_indices)
{
let error_message =
format!("index value `{}` is assigned more than once", token_str);
return syn::Error::new_spanned(&duplicate, error_message).to_compile_error();
}

let recurse_hinting = recurse.clone().map(|[hinting, _]| hinting);
let recurse_encoding = recurse.clone().map(|[_, encoding]| encoding);
let recurse_hinting = recurse.iter().map(|[hinting, _]| hinting.clone());
let recurse_encoding = recurse.iter().map(|[_, encoding]| encoding.clone());

// Runtime check to ensure index attribute variant is within u8 range.
let runtime_checks: Vec<_> = collected_const_indices
.iter()
.enumerate()
.map(|(idx, expr)| {
let check_const =
syn::Ident::new(&format!("CHECK_{}", idx), proc_macro2::Span::call_site());
quote! {
const #check_const: u8 = #expr;
if #check_const as u32 > 255 {
panic!("Index attribute variant must be in 0..255, found {}", #check_const);
}
}
})
.collect();

let hinting = quote! {
// The variant index uses 1 byte.
Expand All @@ -411,6 +437,8 @@ fn impl_encode(data: &Data, type_name: &Ident, crate_path: &syn::Path) -> TokenS
};

let encoding = quote! {
#( #runtime_checks )*

match *#self_ {
#( #recurse_encoding )*,
_ => (),
Expand Down
157 changes: 110 additions & 47 deletions derive/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use proc_macro2::TokenStream;
use quote::{ToTokens, quote};
use syn::{
parse::Parse, punctuated::Punctuated, spanned::Spanned, token, Attribute, Data, DeriveInput,
Field, Fields, FieldsNamed, FieldsUnnamed, Lit, Meta, MetaNameValue, NestedMeta, Path, Variant,
Field, Fields, FieldsNamed, FieldsUnnamed, Lit, Meta, MetaNameValue, NestedMeta, Path, Variant, Expr,
};

fn find_meta_item<'a, F, R, I, M>(mut itr: I, mut pred: F) -> Option<R>
Expand All @@ -37,32 +37,41 @@ where
})
}

/// Look for a `#[scale(index = $int)]` attribute on a variant. If no attribute
/// Look for a `#[codec(index = $int)]` attribute on a variant. If no attribute
/// is found, fall back to the discriminant or just the variant index.
pub fn variant_index(v: &Variant, i: usize) -> TokenStream {
// first look for an attribute
let index = find_meta_item(v.attrs.iter(), |meta| {
if let NestedMeta::Meta(Meta::NameValue(ref nv)) = meta {
if nv.path.is_ident("index") {
if let Lit::Int(ref v) = nv.lit {
let byte = v
.base10_parse::<u8>()
.expect("Internal error, index attribute must have been checked");
return Some(byte)
}
}
}

None
});

// then fallback to discriminant or just index
index.map(|i| quote! { #i }).unwrap_or_else(|| {
v.discriminant
.as_ref()
.map(|(_, expr)| quote! { #expr })
.unwrap_or_else(|| quote! { #i })
})
pub fn variant_index(
variant: &Variant,
i: usize,
collected_const_indices: &mut Vec<TokenStream>,
) -> TokenStream {
let mut index_option: Option<TokenStream> = None;

for attr in variant.attrs.iter().filter(|attr| attr.path.is_ident("codec")) {
if let Ok(codec_variants) = attr.parse_args::<CodecVariants>() {
if let Some(codec_index) = codec_variants.index {
match codec_index {
CodecIndex::U8(value) => {
index_option = Some(quote! { #value });
break;
},
CodecIndex::ExprConst(expr) => {
collected_const_indices.push(quote! { #expr });
index_option = Some(quote! { #expr });
break;
}
}
}
}
}

// Fallback to discriminant or index
index_option.unwrap_or_else(|| {
variant
.discriminant
.as_ref()
.map(|(_, expr)| quote! { #expr })
.unwrap_or_else(|| quote! { #i })
})
}

/// Look for a `#[codec(encoded_as = "SomeType")]` outer attribute on the given
Expand Down Expand Up @@ -369,34 +378,75 @@ fn check_field_attribute(attr: &Attribute) -> syn::Result<()> {
}
}

pub enum CodecIndex {
U8(u8),
ExprConst(Expr),
}

struct CodecVariants {
skip: bool,
index: Option<CodecIndex>,
}

const INDEX_RANGE_ERROR: &str = "Index attribute variant must be in 0..255";
const INDEX_TYPE_ERROR: &str =
"Only u8 indices are accepted for attribute variant `#[codec(index = $u8)]`";
const ATTRIBUTE_ERROR: &str =
"Invalid attribute on variant, only `#[codec(skip)]` and `#[codec(index = $u8)]` are accepted.";

impl Parse for CodecVariants {
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
let mut skip = false;
let mut index = None;

while !input.is_empty() {
let lookahead = input.lookahead1();
if lookahead.peek(syn::Ident) {
let ident: syn::Ident = input.parse()?;
if ident == "skip" {
skip = true;
} else if ident == "index" {
input.parse::<Token![=]>()?;
if let Ok(lit) = input.parse::<syn::LitInt>() {
let parsed_index = lit
.base10_parse::<u8>()
.map_err(|_| syn::Error::new(lit.span(), INDEX_RANGE_ERROR));
index = Some(CodecIndex::U8(parsed_index?));
} else {
let expr = input
.parse::<syn::Expr>()
.map_err(|_| syn::Error::new_spanned(ident, INDEX_TYPE_ERROR))?;
index = Some(CodecIndex::ExprConst(expr));
}
} else {
return Err(syn::Error::new_spanned(ident, ATTRIBUTE_ERROR));
}
} else {
return Err(lookahead.error());
}
}

Ok(CodecVariants { skip, index })
}
}

// Ensure a field is decorated only with the following attributes:
// * `#[codec(skip)]`
// * `#[codec(index = $int)]`
fn check_variant_attribute(attr: &Attribute) -> syn::Result<()> {
let variant_error = "Invalid attribute on variant, only `#[codec(skip)]` and \
`#[codec(index = $u8)]` are accepted.";

if attr.path.is_ident("codec") {
match attr.parse_meta()? {
Meta::List(ref meta_list) if meta_list.nested.len() == 1 => {
match meta_list.nested.first().expect("Just checked that there is one item; qed") {
NestedMeta::Meta(Meta::Path(path))
if path.get_ident().map_or(false, |i| i == "skip") =>
Ok(()),

NestedMeta::Meta(Meta::NameValue(MetaNameValue {
path,
lit: Lit::Int(lit_int),
..
})) if path.get_ident().map_or(false, |i| i == "index") => lit_int
.base10_parse::<u8>()
.map(|_| ())
.map_err(|_| syn::Error::new(lit_int.span(), "Index must be in 0..255")),

elt => Err(syn::Error::new(elt.span(), variant_error)),
match attr.parse_args::<CodecVariants>() {
Ok(codec_variants) => {
if codec_variants.skip || codec_variants.index.is_some() {
Ok(())
} else {
Err(syn::Error::new_spanned(attr, ATTRIBUTE_ERROR))
}
},
meta => Err(syn::Error::new(meta.span(), variant_error)),
Err(e) => Err(syn::Error::new_spanned(
attr,
format!("Error checking variant attribute: {}", e),
)),
}
} else {
Ok(())
Expand Down Expand Up @@ -451,3 +501,16 @@ pub fn is_transparent(attrs: &[syn::Attribute]) -> bool {
// TODO: When migrating to syn 2 the `"(transparent)"` needs to be changed into `"transparent"`.
check_repr(attrs, "(transparent)")
}

/// Find a duplicate `TokenStream` in a list of indices.
/// Each `TokenStream` is a constant expression expected to represent an u8 index for an enum variant.
pub fn find_const_duplicate(indices: &[TokenStream]) -> Option<(TokenStream, String)> {
let mut seen = std::collections::HashSet::new();
for index in indices {
let token_str = index.to_token_stream().to_string();
if !seen.insert(token_str.clone()) {
return Some((index.clone(), token_str));
}
}
None
}
12 changes: 12 additions & 0 deletions tests/scale_codec_ui/duplicate_const_expr.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#[derive(::parity_scale_codec::Encode)]
#[codec(crate = ::parity_scale_codec)]
pub enum Enum {
#[codec(index = MY_CONST_INDEX)]
Variant1,
#[codec(index = MY_CONST_INDEX)]
Variant2,
}

const MY_CONST_INDEX: u8 = 1;

fn main() {}
5 changes: 5 additions & 0 deletions tests/scale_codec_ui/duplicate_const_expr.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
error: index value `MY_CONST_INDEX` is assigned more than once
--> tests/scale_codec_ui/duplicate_const_expr.rs:6:21
|
6 | #[codec(index = MY_CONST_INDEX)]
| ^^^^^^^^^^^^^^
8 changes: 8 additions & 0 deletions tests/scale_codec_ui/invalid_attr_name.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#[derive(::parity_scale_codec::Encode)]
#[codec(crate = ::parity_scale_codec)]
pub enum Enum {
#[codec(scale = 1)]
Variant1,
}

fn main() {}
5 changes: 5 additions & 0 deletions tests/scale_codec_ui/invalid_attr_name.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
error: Error checking variant attribute: Invalid attribute on variant, only `#[codec(skip)]` and `#[codec(index = $u8)]` are accepted.
--> tests/scale_codec_ui/invalid_attr_name.rs:4:5
|
4 | #[codec(scale = 1)]
| ^^^^^^^^^^^^^^^^^^^
8 changes: 8 additions & 0 deletions tests/scale_codec_ui/invalid_attr_type.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#[derive(::parity_scale_codec::Encode)]
#[codec(crate = ::parity_scale_codec)]
pub enum Enum {
#[codec(index = "invalid")]
Variant1,
}

fn main() {}
5 changes: 5 additions & 0 deletions tests/scale_codec_ui/invalid_attr_type.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
error: Error checking variant attribute: Only u8 indices are accepted for attribute variant `#[codec(index = $u8)]`
--> tests/scale_codec_ui/invalid_attr_type.rs:4:5
|
4 | #[codec(index = "invalid")]
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^
8 changes: 8 additions & 0 deletions tests/scale_codec_ui/overflowing_index_value.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#[derive(::parity_scale_codec::Encode)]
#[codec(crate = ::parity_scale_codec)]
pub enum Enum {
#[codec(index = 256)]
Variant1,
}

fn main() {}
5 changes: 5 additions & 0 deletions tests/scale_codec_ui/overflowing_index_value.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
error: Error checking variant attribute: Index attribute variant must be in 0..255
--> tests/scale_codec_ui/overflowing_index_value.rs:4:5
|
4 | #[codec(index = 256)]
| ^^^^^^^^^^^^^^^^^^^^^
35 changes: 35 additions & 0 deletions tests/variant_number.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,38 @@ fn index_attr_variant_counted_and_reused_in_default_index() {
assert_eq!(T::A.encode(), vec![1]);
assert_eq!(T::B.encode(), vec![1]);
}

#[test]
fn different_const_expr_in_index_attr_variant() {
const MY_CONST_INDEX: u8 = 1;
const ANOTHER_CONST_INDEX: u8 = 2;

#[derive(DeriveEncode)]
enum T {
#[codec(index = MY_CONST_INDEX)]
A,
B,
#[codec(index = ANOTHER_CONST_INDEX)]
C,
#[codec(index = 3)]
D,
}

assert_eq!(T::A.encode(), vec![1]);
assert_eq!(T::B.encode(), vec![1]);
assert_eq!(T::C.encode(), vec![2]);
assert_eq!(T::D.encode(), vec![3]);
}

#[test]
fn complex_const_expr_in_index_attr_variant() {
const MY_CONST_INDEX: u8 = 1;

#[derive(DeriveEncode)]
enum T {
#[codec(index = MY_CONST_INDEX + 1_u8)]
A,
}

assert_eq!(T::A.encode(), vec![2]);
}