Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generate an enum for @oneOf input #450

Merged
merged 6 commits into from
May 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions graphql_client/tests/input_object_variables.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ fn input_object_variables_default() {
msg: default_input_object_variables_query::Variables::default_msg(),
};

let out = serde_json::to_value(&variables).unwrap();
let out = serde_json::to_value(variables).unwrap();

let expected_default = serde_json::json!({
"msg":{"content":null,"to":{"category":null,"email":"rosa.luxemburg@example.com","name":null}}
Expand Down Expand Up @@ -130,7 +130,7 @@ pub struct RustNameQuery;
#[test]
fn rust_name_correctly_mapped() {
use rust_name_query::*;
let value = serde_json::to_value(&Variables {
let value = serde_json::to_value(Variables {
extern_: Some("hello".to_owned()),
msg: <_>::default(),
})
Expand Down
30 changes: 30 additions & 0 deletions graphql_client/tests/one_of_input.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
use graphql_client::*;
use serde_json::*;

#[derive(GraphQLQuery)]
#[graphql(
schema_path = "tests/one_of_input/schema.graphql",
query_path = "tests/one_of_input/query.graphql",
variables_derives = "Clone"
)]
pub struct OneOfMutation;

#[test]
fn one_of_input() {
use one_of_mutation::*;

let author = Param::Author(Author { id: 1 });
let _ = Param::Name("Mark Twain".to_string());
let _ = Param::RecursiveDirect(Box::new(author.clone()));
let _ = Param::RecursiveIndirect(Box::new(Recursive {
param: Box::new(author.clone()),
}));
let _ = Param::RequiredInts(vec![1]);
let _ = Param::OptionalInts(vec![Some(1)]);

let query = OneOfMutation::build_query(Variables { param: author });
assert_eq!(
json!({ "param": { "author":{ "id": 1 } } }),
serde_json::to_value(&query.variables).expect("json"),
);
}
3 changes: 3 additions & 0 deletions graphql_client/tests/one_of_input/query.graphql
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
mutation OneOfMutation($param: Param!) {
oneOfMutation(query: $param)
}
24 changes: 24 additions & 0 deletions graphql_client/tests/one_of_input/schema.graphql
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
schema {
mutation: Mutation
}

type Mutation {
oneOfMutation(mutation: Param!): Int
}

input Param @oneOf {
author: Author
name: String
recursiveDirect: Param
recursiveIndirect: Recursive
requiredInts: [Int!]
optionalInts: [Int]
}

input Author {
id: Int!
}

input Recursive {
param: Param!
}
2 changes: 1 addition & 1 deletion graphql_client_cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ fn set_env_logger() {
.init();
}

fn colored_level<'a>(style: &'a mut Style, level: Level) -> StyledValue<'a, &'static str> {
fn colored_level(style: &mut Style, level: Level) -> StyledValue<'_, &'static str> {
match level {
Level::Trace => style.set_color(Color::Magenta).value("TRACE"),
Level::Debug => style.set_color(Color::Blue).value("DEBUG"),
Expand Down
151 changes: 108 additions & 43 deletions graphql_client_codegen/src/codegen/inputs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@ use super::shared::{field_rename_annotation, keyword_replace};
use crate::{
codegen_options::GraphQLClientCodegenOptions,
query::{BoundQuery, UsedTypes},
schema::input_is_recursive_without_indirection,
schema::{input_is_recursive_without_indirection, StoredInputType},
type_qualifiers::GraphqlTypeQualifier,
};
use heck::ToSnakeCase;
use heck::{ToSnakeCase, ToUpperCamelCase};
use proc_macro2::{Ident, Span, TokenStream};
use quote::quote;

Expand All @@ -17,48 +18,112 @@ pub(super) fn generate_input_object_definitions(
all_used_types
.inputs(query.schema)
.map(|(_input_id, input)| {
let normalized_name = options.normalization().input_name(input.name.as_str());
let safe_name = keyword_replace(normalized_name);
let struct_name = Ident::new(safe_name.as_ref(), Span::call_site());

let fields = input.fields.iter().map(|(field_name, field_type)| {
let safe_field_name = keyword_replace(field_name.to_snake_case());
let annotation = field_rename_annotation(field_name, safe_field_name.as_ref());
let name_ident = Ident::new(safe_field_name.as_ref(), Span::call_site());
let normalized_field_type_name = options
.normalization()
.field_type(field_type.id.name(query.schema));
let optional_skip_serializing_none =
if *options.skip_serializing_none() && field_type.is_optional() {
Some(quote!(#[serde(skip_serializing_if = "Option::is_none")]))
} else {
None
};
let type_name = Ident::new(normalized_field_type_name.as_ref(), Span::call_site());
let field_type_tokens = super::decorate_type(&type_name, &field_type.qualifiers);
let field_type = if field_type
.id
.as_input_id()
.map(|input_id| input_is_recursive_without_indirection(input_id, query.schema))
.unwrap_or(false)
{
quote!(Box<#field_type_tokens>)
} else {
field_type_tokens
};

quote!(
#optional_skip_serializing_none
#annotation pub #name_ident: #field_type
)
});

quote! {
#variable_derives
pub struct #struct_name{
#(#fields,)*
}
if input.is_one_of {
generate_enum(input, options, variable_derives, query)
} else {
generate_struct(input, options, variable_derives, query)
}
})
.collect()
}

fn generate_struct(
input: &StoredInputType,
options: &GraphQLClientCodegenOptions,
variable_derives: &impl quote::ToTokens,
query: &BoundQuery<'_>,
) -> TokenStream {
let normalized_name = options.normalization().input_name(input.name.as_str());
let safe_name = keyword_replace(normalized_name);
let struct_name = Ident::new(safe_name.as_ref(), Span::call_site());

let fields = input.fields.iter().map(|(field_name, field_type)| {
let safe_field_name = keyword_replace(field_name.to_snake_case());
let annotation = field_rename_annotation(field_name, safe_field_name.as_ref());
let name_ident = Ident::new(safe_field_name.as_ref(), Span::call_site());
let normalized_field_type_name = options
.normalization()
.field_type(field_type.id.name(query.schema));
let optional_skip_serializing_none =
if *options.skip_serializing_none() && field_type.is_optional() {
Some(quote!(#[serde(skip_serializing_if = "Option::is_none")]))
} else {
None
};
let type_name = Ident::new(normalized_field_type_name.as_ref(), Span::call_site());
let field_type_tokens = super::decorate_type(&type_name, &field_type.qualifiers);
let field_type = if field_type
.id
.as_input_id()
.map(|input_id| input_is_recursive_without_indirection(input_id, query.schema))
.unwrap_or(false)
{
quote!(Box<#field_type_tokens>)
} else {
field_type_tokens
};

quote!(
#optional_skip_serializing_none
#annotation pub #name_ident: #field_type
)
});

quote! {
#variable_derives
pub struct #struct_name{
#(#fields,)*
}
}
}

fn generate_enum(
input: &StoredInputType,
options: &GraphQLClientCodegenOptions,
variable_derives: &impl quote::ToTokens,
query: &BoundQuery<'_>,
) -> TokenStream {
let normalized_name = options.normalization().input_name(input.name.as_str());
let safe_name = keyword_replace(normalized_name);
let enum_name = Ident::new(safe_name.as_ref(), Span::call_site());

let variants = input.fields.iter().map(|(field_name, field_type)| {
let variant_name = field_name.to_upper_camel_case();
let safe_variant_name = keyword_replace(&variant_name);

let annotation = field_rename_annotation(field_name.as_ref(), &variant_name);
let name_ident = Ident::new(safe_variant_name.as_ref(), Span::call_site());

let normalized_field_type_name = options
.normalization()
.field_type(field_type.id.name(query.schema));
let type_name = Ident::new(normalized_field_type_name.as_ref(), Span::call_site());

// Add the required qualifier so that the variant's field isn't wrapped in Option
let mut qualifiers = vec![GraphqlTypeQualifier::Required];
qualifiers.extend(field_type.qualifiers.iter().cloned());
surma marked this conversation as resolved.
Show resolved Hide resolved
surma marked this conversation as resolved.
Show resolved Hide resolved

let field_type_tokens = super::decorate_type(&type_name, &qualifiers);
let field_type = if field_type
.id
.as_input_id()
.map(|input_id| input_is_recursive_without_indirection(input_id, query.schema))
.unwrap_or(false)
{
quote!(Box<#field_type_tokens>)
} else {
field_type_tokens
};

quote!(
#annotation #name_ident(#field_type)
)
});

quote! {
#variable_derives
pub enum #enum_name{
#(#variants,)*
}
}
}
9 changes: 2 additions & 7 deletions graphql_client_codegen/src/deprecation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,17 @@ pub enum DeprecationStatus {
}

/// The available deprecation strategies.
#[derive(Debug, PartialEq, Eq, Clone)]
#[derive(Debug, PartialEq, Eq, Clone, Default)]
pub enum DeprecationStrategy {
/// Allow use of deprecated items in queries, and say nothing.
Allow,
/// Fail compilation if a deprecated item is used.
Deny,
/// Allow use of deprecated items in queries, but warn about them (default).
#[default]
Warn,
}

impl Default for DeprecationStrategy {
fn default() -> Self {
DeprecationStrategy::Warn
}
}

impl std::str::FromStr for DeprecationStrategy {
type Err = ();

Expand Down
1 change: 1 addition & 0 deletions graphql_client_codegen/src/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ impl StoredInputFieldType {
pub(crate) struct StoredInputType {
pub(crate) name: String,
pub(crate) fields: Vec<(String, StoredInputFieldType)>,
pub(crate) is_one_of: bool,
}

/// Intermediate representation for a parsed GraphQL schema used during code generation.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,11 @@ fn ingest_input<'doc, T>(schema: &mut Schema, input: &mut parser::InputObjectTyp
where
T: graphql_parser::query::Text<'doc>,
{
let is_one_of = input
.directives
.iter()
.any(|directive| directive.name.as_ref() == "oneOf");

let input = super::StoredInputType {
name: input.name.as_ref().into(),
fields: input
Expand All @@ -305,6 +310,7 @@ where
)
})
.collect(),
is_one_of,
};

schema.stored_inputs.push(input);
Expand Down
3 changes: 3 additions & 0 deletions graphql_client_codegen/src/schema/json_conversion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,9 @@ fn ingest_input(schema: &mut Schema, input: &mut FullType) {
let input = super::StoredInputType {
fields,
name: input.name.take().expect("Input without a name"),
// The one-of input spec is not stable yet, thus the introspection query does not have
// `isOneOf`, so this is always false.
is_one_of: false,
};

schema.stored_inputs.push(input);
Expand Down