Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 30 additions & 1 deletion bon-macros/src/builder/builder_gen/finish_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,34 @@ impl super::BuilderGenCtx {
}
});

// Collect `Default` bounds for members with bare `#[builder(default)]`
// (no explicit default value). This allows us to set the required
// default implementations on the builder method, rather than reusing
// the bounds on the struct (which might require over-specifying bounds).

let type_params: Vec<_> = self
.generics
.decl_without_defaults
.iter()
.filter_map(|decl| match decl {
syn::GenericParam::Type(type_param) => Some(&type_param.ident),
_ => None,
})
.collect();

let mut seen_default_generic_types = std::collections::HashSet::new();
let default_bounds: Vec<_> = self.named_members().filter_map(|member| {
if !matches!(member.config.default.as_ref(), Some(default) if default.value.is_none()) {
return None;
}
let ty = &member.ty.norm;
let uses_generic = super::generic_setters::type_uses_generic_params(ty, &type_params);

(uses_generic && seen_default_generic_types.insert(ty)).then(|| {
quote! { #ty: ::core::default::Default }
})
}).collect();

let state_mod = &self.state_mod.ident;

let finish_fn_params = self.finish_fn_args().map(PosFnMember::fn_input_param);
Expand Down Expand Up @@ -181,7 +209,8 @@ impl super::BuilderGenCtx {
#(#finish_fn_params,)*
) #output
where
#state_var: #state_mod::IsComplete
#state_var: #state_mod::IsComplete,
#(#default_bounds,)*
{
#(#members_vars_decls)*
#body
Expand Down
20 changes: 12 additions & 8 deletions bon-macros/src/builder/builder_gen/generic_setters.rs
Original file line number Diff line number Diff line change
Expand Up @@ -365,13 +365,13 @@ fn replace_type_param_in_predicate(
/// Check if a member's type uses a specific generic parameter
fn member_uses_generic_param(member: &super::NamedMember, param_ident: &syn::Ident) -> bool {
let member_ty = member.underlying_norm_ty();
type_uses_generic_param(member_ty, param_ident)
type_uses_generic_params(member_ty, &[param_ident])
}

/// Recursively check if a type uses a specific generic parameter
fn type_uses_generic_param(ty: &syn::Type, param_ident: &syn::Ident) -> bool {
/// Recursively check if a type uses any of the given generic type parameters
pub(super) fn type_uses_generic_params(ty: &syn::Type, param_idents: &[&syn::Ident]) -> bool {
struct GenericParamVisitor<'a> {
param_ident: &'a syn::Ident,
param_idents: &'a [&'a syn::Ident],
found: bool,
}

Expand All @@ -382,8 +382,12 @@ fn type_uses_generic_param(ty: &syn::Type, param_ident: &syn::Ident) -> bool {
return;
}

// Check if the path is the generic parameter we're looking for
if type_path.path.is_ident(self.param_ident) {
// Check if the path is one of the generic parameters we're looking for
if type_path
.path
.get_ident()
.map_or(false, |ident| self.param_idents.contains(&ident))
{
self.found = true;
return;
}
Expand All @@ -396,7 +400,7 @@ fn type_uses_generic_param(ty: &syn::Type, param_ident: &syn::Ident) -> bool {
self.visit_type(&qself.ty);
} else if let Some(segment) = type_path.path.segments.first() {
// For T::Assoc syntax
if segment.ident == *self.param_ident {
if self.param_idents.contains(&&segment.ident) {
self.found = true;
return;
}
Expand All @@ -408,7 +412,7 @@ fn type_uses_generic_param(ty: &syn::Type, param_ident: &syn::Ident) -> bool {
}

let mut visitor = GenericParamVisitor {
param_ident,
param_idents,
found: false,
};
visitor.visit_type(ty);
Expand Down
82 changes: 82 additions & 0 deletions bon/tests/integration/builder/attr_default.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,88 @@ fn fn_generic_default() {
sut::<(), ()>().call();
}

#[test]
fn different_generic_with_default() {
#[derive(Builder)]
struct Sut<A> {
#[builder(default)]
x1: A,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
struct B;

let result = Sut::<B>::builder().build();
assert_eq!(result.x1, B);
}

#[test]
fn default_bound_not_required_when_provided() {
trait MyDefault {
fn my_default() -> Self;
}

#[derive(Debug, PartialEq)]
struct NotDefault(u32);

impl MyDefault for NotDefault {
fn my_default() -> Self {
NotDefault(42)
}
}

#[derive(Builder)]
struct Sut<A: MyDefault> {
#[builder(default = A::my_default())]
x1: A,
}

let result = Sut::<NotDefault>::builder().build();
assert_eq!(result.x1, NotDefault(42));
}

#[test]
fn generic_type_which_always_implements_default_compiles() {
#[derive(Builder)]
struct Sut<A> {
#[builder(default)]
items: Vec<A>,
}

let result = Sut::<String>::builder().build();
assert!(result.items.is_empty());

let result = Sut::<u32>::builder().items(vec![1, 2, 3]).build();
assert_eq!(result.items, vec![1, 2, 3]);
}

#[test]
fn generic_type_only_requires_default_on_fields_with_default() {
#[derive(Builder)]
struct Sut<A, B> {
required: A,

#[builder(default)]
optional2: B,

#[builder(default)]
optional: B,
}

#[derive(Debug, PartialEq)]
struct NotDefault;

#[derive(Debug, PartialEq, Default)]
struct HasDefault;

let result = Sut::<NotDefault, HasDefault>::builder()
.required(NotDefault)
.build();

assert_eq!(result.required, NotDefault);
assert_eq!(result.optional, HasDefault);
}

mod interaction_with_positional_members {
use crate::prelude::*;
use core::fmt;
Expand Down
29 changes: 29 additions & 0 deletions bon/tests/integration/builder/generics_setters.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,3 +168,32 @@ fn test_with_trait_bounds_false_friend() {
.build();
assert_eq!(result2.value, 99u64);
}

#[test]
fn test_with_default_on_generic_field() {
trait MyTrait {}

#[derive(Debug, PartialEq, Default)]
struct MyImplWithDefault;
impl MyTrait for MyImplWithDefault {}

#[derive(Debug, PartialEq, Default)]
struct MyOtherImpl;
impl MyTrait for MyOtherImpl {}

#[derive(Builder)]
#[builder(generics(setters = "conv_{}"))]
struct Sut<A: MyTrait> {
#[builder(default)]
value: A,
}

let result = Sut::<MyImplWithDefault>::builder().build();
assert_eq!(result.value, MyImplWithDefault);

let result = Sut::<MyImplWithDefault>::builder()
.conv_a::<MyOtherImpl>()
.value(MyOtherImpl)
.build();
assert_eq!(result.value, MyOtherImpl);
}