diff --git a/graphql_client/tests/input_object_variables.rs b/graphql_client/tests/input_object_variables.rs index 1fdf5f3b..621ca824 100644 --- a/graphql_client/tests/input_object_variables.rs +++ b/graphql_client/tests/input_object_variables.rs @@ -39,7 +39,7 @@ fn input_object_variables_default() { msg: default_input_object_variables_query::Variables::default_msg(), }; - let out = serde_json::to_value(&variables).unwrap(); + let out = serde_json::to_value(variables).unwrap(); let expected_default = serde_json::json!({ "msg":{"content":null,"to":{"category":null,"email":"rosa.luxemburg@example.com","name":null}} @@ -130,7 +130,7 @@ pub struct RustNameQuery; #[test] fn rust_name_correctly_mapped() { use rust_name_query::*; - let value = serde_json::to_value(&Variables { + let value = serde_json::to_value(Variables { extern_: Some("hello".to_owned()), msg: <_>::default(), }) diff --git a/graphql_client/tests/one_of_input.rs b/graphql_client/tests/one_of_input.rs new file mode 100644 index 00000000..80ebde3c --- /dev/null +++ b/graphql_client/tests/one_of_input.rs @@ -0,0 +1,30 @@ +use graphql_client::*; +use serde_json::*; + +#[derive(GraphQLQuery)] +#[graphql( + schema_path = "tests/one_of_input/schema.graphql", + query_path = "tests/one_of_input/query.graphql", + variables_derives = "Clone" +)] +pub struct OneOfMutation; + +#[test] +fn one_of_input() { + use one_of_mutation::*; + + let author = Param::Author(Author { id: 1 }); + let _ = Param::Name("Mark Twain".to_string()); + let _ = Param::RecursiveDirect(Box::new(author.clone())); + let _ = Param::RecursiveIndirect(Box::new(Recursive { + param: Box::new(author.clone()), + })); + let _ = Param::RequiredInts(vec![1]); + let _ = Param::OptionalInts(vec![Some(1)]); + + let query = OneOfMutation::build_query(Variables { param: author }); + assert_eq!( + json!({ "param": { "author":{ "id": 1 } } }), + serde_json::to_value(&query.variables).expect("json"), + ); +} diff --git a/graphql_client/tests/one_of_input/query.graphql b/graphql_client/tests/one_of_input/query.graphql new file mode 100644 index 00000000..52f167ee --- /dev/null +++ b/graphql_client/tests/one_of_input/query.graphql @@ -0,0 +1,3 @@ +mutation OneOfMutation($param: Param!) { + oneOfMutation(query: $param) +} diff --git a/graphql_client/tests/one_of_input/schema.graphql b/graphql_client/tests/one_of_input/schema.graphql new file mode 100644 index 00000000..d2a0e875 --- /dev/null +++ b/graphql_client/tests/one_of_input/schema.graphql @@ -0,0 +1,24 @@ +schema { + mutation: Mutation +} + +type Mutation { + oneOfMutation(mutation: Param!): Int +} + +input Param @oneOf { + author: Author + name: String + recursiveDirect: Param + recursiveIndirect: Recursive + requiredInts: [Int!] + optionalInts: [Int] +} + +input Author { + id: Int! +} + +input Recursive { + param: Param! +} diff --git a/graphql_client_cli/src/main.rs b/graphql_client_cli/src/main.rs index 09acab66..d74e5451 100644 --- a/graphql_client_cli/src/main.rs +++ b/graphql_client_cli/src/main.rs @@ -153,7 +153,7 @@ fn set_env_logger() { .init(); } -fn colored_level<'a>(style: &'a mut Style, level: Level) -> StyledValue<'a, &'static str> { +fn colored_level(style: &mut Style, level: Level) -> StyledValue<'_, &'static str> { match level { Level::Trace => style.set_color(Color::Magenta).value("TRACE"), Level::Debug => style.set_color(Color::Blue).value("DEBUG"), diff --git a/graphql_client_codegen/src/codegen/inputs.rs b/graphql_client_codegen/src/codegen/inputs.rs index bc2ee5cc..475ad171 100644 --- a/graphql_client_codegen/src/codegen/inputs.rs +++ b/graphql_client_codegen/src/codegen/inputs.rs @@ -2,9 +2,10 @@ use super::shared::{field_rename_annotation, keyword_replace}; use crate::{ codegen_options::GraphQLClientCodegenOptions, query::{BoundQuery, UsedTypes}, - schema::input_is_recursive_without_indirection, + schema::{input_is_recursive_without_indirection, StoredInputType}, + type_qualifiers::GraphqlTypeQualifier, }; -use heck::ToSnakeCase; +use heck::{ToSnakeCase, ToUpperCamelCase}; use proc_macro2::{Ident, Span, TokenStream}; use quote::quote; @@ -17,48 +18,112 @@ pub(super) fn generate_input_object_definitions( all_used_types .inputs(query.schema) .map(|(_input_id, input)| { - let normalized_name = options.normalization().input_name(input.name.as_str()); - let safe_name = keyword_replace(normalized_name); - let struct_name = Ident::new(safe_name.as_ref(), Span::call_site()); - - let fields = input.fields.iter().map(|(field_name, field_type)| { - let safe_field_name = keyword_replace(field_name.to_snake_case()); - let annotation = field_rename_annotation(field_name, safe_field_name.as_ref()); - let name_ident = Ident::new(safe_field_name.as_ref(), Span::call_site()); - let normalized_field_type_name = options - .normalization() - .field_type(field_type.id.name(query.schema)); - let optional_skip_serializing_none = - if *options.skip_serializing_none() && field_type.is_optional() { - Some(quote!(#[serde(skip_serializing_if = "Option::is_none")])) - } else { - None - }; - let type_name = Ident::new(normalized_field_type_name.as_ref(), Span::call_site()); - let field_type_tokens = super::decorate_type(&type_name, &field_type.qualifiers); - let field_type = if field_type - .id - .as_input_id() - .map(|input_id| input_is_recursive_without_indirection(input_id, query.schema)) - .unwrap_or(false) - { - quote!(Box<#field_type_tokens>) - } else { - field_type_tokens - }; - - quote!( - #optional_skip_serializing_none - #annotation pub #name_ident: #field_type - ) - }); - - quote! { - #variable_derives - pub struct #struct_name{ - #(#fields,)* - } + if input.is_one_of { + generate_enum(input, options, variable_derives, query) + } else { + generate_struct(input, options, variable_derives, query) } }) .collect() } + +fn generate_struct( + input: &StoredInputType, + options: &GraphQLClientCodegenOptions, + variable_derives: &impl quote::ToTokens, + query: &BoundQuery<'_>, +) -> TokenStream { + let normalized_name = options.normalization().input_name(input.name.as_str()); + let safe_name = keyword_replace(normalized_name); + let struct_name = Ident::new(safe_name.as_ref(), Span::call_site()); + + let fields = input.fields.iter().map(|(field_name, field_type)| { + let safe_field_name = keyword_replace(field_name.to_snake_case()); + let annotation = field_rename_annotation(field_name, safe_field_name.as_ref()); + let name_ident = Ident::new(safe_field_name.as_ref(), Span::call_site()); + let normalized_field_type_name = options + .normalization() + .field_type(field_type.id.name(query.schema)); + let optional_skip_serializing_none = + if *options.skip_serializing_none() && field_type.is_optional() { + Some(quote!(#[serde(skip_serializing_if = "Option::is_none")])) + } else { + None + }; + let type_name = Ident::new(normalized_field_type_name.as_ref(), Span::call_site()); + let field_type_tokens = super::decorate_type(&type_name, &field_type.qualifiers); + let field_type = if field_type + .id + .as_input_id() + .map(|input_id| input_is_recursive_without_indirection(input_id, query.schema)) + .unwrap_or(false) + { + quote!(Box<#field_type_tokens>) + } else { + field_type_tokens + }; + + quote!( + #optional_skip_serializing_none + #annotation pub #name_ident: #field_type + ) + }); + + quote! { + #variable_derives + pub struct #struct_name{ + #(#fields,)* + } + } +} + +fn generate_enum( + input: &StoredInputType, + options: &GraphQLClientCodegenOptions, + variable_derives: &impl quote::ToTokens, + query: &BoundQuery<'_>, +) -> TokenStream { + let normalized_name = options.normalization().input_name(input.name.as_str()); + let safe_name = keyword_replace(normalized_name); + let enum_name = Ident::new(safe_name.as_ref(), Span::call_site()); + + let variants = input.fields.iter().map(|(field_name, field_type)| { + let variant_name = field_name.to_upper_camel_case(); + let safe_variant_name = keyword_replace(&variant_name); + + let annotation = field_rename_annotation(field_name.as_ref(), &variant_name); + let name_ident = Ident::new(safe_variant_name.as_ref(), Span::call_site()); + + let normalized_field_type_name = options + .normalization() + .field_type(field_type.id.name(query.schema)); + let type_name = Ident::new(normalized_field_type_name.as_ref(), Span::call_site()); + + // Add the required qualifier so that the variant's field isn't wrapped in Option + let mut qualifiers = vec![GraphqlTypeQualifier::Required]; + qualifiers.extend(field_type.qualifiers.iter().cloned()); + + let field_type_tokens = super::decorate_type(&type_name, &qualifiers); + let field_type = if field_type + .id + .as_input_id() + .map(|input_id| input_is_recursive_without_indirection(input_id, query.schema)) + .unwrap_or(false) + { + quote!(Box<#field_type_tokens>) + } else { + field_type_tokens + }; + + quote!( + #annotation #name_ident(#field_type) + ) + }); + + quote! { + #variable_derives + pub enum #enum_name{ + #(#variants,)* + } + } +} diff --git a/graphql_client_codegen/src/deprecation.rs b/graphql_client_codegen/src/deprecation.rs index c33f9260..4c5566c2 100644 --- a/graphql_client_codegen/src/deprecation.rs +++ b/graphql_client_codegen/src/deprecation.rs @@ -8,22 +8,17 @@ pub enum DeprecationStatus { } /// The available deprecation strategies. -#[derive(Debug, PartialEq, Eq, Clone)] +#[derive(Debug, PartialEq, Eq, Clone, Default)] pub enum DeprecationStrategy { /// Allow use of deprecated items in queries, and say nothing. Allow, /// Fail compilation if a deprecated item is used. Deny, /// Allow use of deprecated items in queries, but warn about them (default). + #[default] Warn, } -impl Default for DeprecationStrategy { - fn default() -> Self { - DeprecationStrategy::Warn - } -} - impl std::str::FromStr for DeprecationStrategy { type Err = (); diff --git a/graphql_client_codegen/src/schema.rs b/graphql_client_codegen/src/schema.rs index 2be24c3e..fc3b5094 100644 --- a/graphql_client_codegen/src/schema.rs +++ b/graphql_client_codegen/src/schema.rs @@ -210,6 +210,7 @@ impl StoredInputFieldType { pub(crate) struct StoredInputType { pub(crate) name: String, pub(crate) fields: Vec<(String, StoredInputFieldType)>, + pub(crate) is_one_of: bool, } /// Intermediate representation for a parsed GraphQL schema used during code generation. diff --git a/graphql_client_codegen/src/schema/graphql_parser_conversion.rs b/graphql_client_codegen/src/schema/graphql_parser_conversion.rs index a2e5b6ed..fadd9f46 100644 --- a/graphql_client_codegen/src/schema/graphql_parser_conversion.rs +++ b/graphql_client_codegen/src/schema/graphql_parser_conversion.rs @@ -289,6 +289,11 @@ fn ingest_input<'doc, T>(schema: &mut Schema, input: &mut parser::InputObjectTyp where T: graphql_parser::query::Text<'doc>, { + let is_one_of = input + .directives + .iter() + .any(|directive| directive.name.as_ref() == "oneOf"); + let input = super::StoredInputType { name: input.name.as_ref().into(), fields: input @@ -305,6 +310,7 @@ where ) }) .collect(), + is_one_of, }; schema.stored_inputs.push(input); diff --git a/graphql_client_codegen/src/schema/json_conversion.rs b/graphql_client_codegen/src/schema/json_conversion.rs index 52a92cc9..2018207c 100644 --- a/graphql_client_codegen/src/schema/json_conversion.rs +++ b/graphql_client_codegen/src/schema/json_conversion.rs @@ -296,6 +296,9 @@ fn ingest_input(schema: &mut Schema, input: &mut FullType) { let input = super::StoredInputType { fields, name: input.name.take().expect("Input without a name"), + // The one-of input spec is not stable yet, thus the introspection query does not have + // `isOneOf`, so this is always false. + is_one_of: false, }; schema.stored_inputs.push(input);