Skip to content

[Draft] Skip null values when serializing in rust #149

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions example_with_targets/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,34 @@ fn test_target_b() -> Result<()> {
assert_eq!(result, expected);
Ok(())
}

#[test]
fn test_null_field_skipping() -> Result<()> {
let test_function_none =
|_input: crate::schema::target_a::Input| -> Result<crate::schema::FunctionTargetAResult> {
Ok(crate::schema::FunctionTargetAResult {
status: None, // This should not appear in serialized output
})
};

let test_function_some =
|_input: crate::schema::target_a::Input| -> Result<crate::schema::FunctionTargetAResult> {
Ok(crate::schema::FunctionTargetAResult {
status: Some(200), // This should appear in serialized output
})
};

let test_input = r#"{
"id": "gid://shopify/Order/1234567890",
"num": 123,
"name": "test"
}"#;

let result_none = run_function_with_input(test_function_none, test_input)?;
let result_some = run_function_with_input(test_function_some, test_input)?;

assert_eq!(result_none.status, None);
assert_eq!(result_some.status, Some(200));

Ok(())
}
14 changes: 6 additions & 8 deletions integration_tests/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ fn build_example(name: &str) -> Result<()> {
}

static FUNCTION_RUNNER_PATH: LazyLock<anyhow::Result<PathBuf>> = LazyLock::new(|| {
let path = workspace_root().join(format!("tmp/function-runner-{}", FUNCTION_RUNNER_VERSION));
let path = workspace_root().join(format!("tmp/function-runner-{FUNCTION_RUNNER_VERSION}"));

if !path.exists() {
std::fs::create_dir_all(workspace_root().join("tmp"))?;
Expand All @@ -44,7 +44,7 @@ static FUNCTION_RUNNER_PATH: LazyLock<anyhow::Result<PathBuf>> = LazyLock::new(|
});

static TRAMPOLINE_PATH: LazyLock<anyhow::Result<PathBuf>> = LazyLock::new(|| {
let path = workspace_root().join(format!("tmp/trampoline-{}", TRAMPOLINE_VERSION));
let path = workspace_root().join(format!("tmp/trampoline-{TRAMPOLINE_VERSION}"));
if !path.exists() {
std::fs::create_dir_all(workspace_root().join("tmp"))?;
download_trampoline(&path)?;
Expand All @@ -56,8 +56,7 @@ fn download_function_runner(destination: &PathBuf) -> Result<()> {
download_from_github(
|target_arch, target_os| {
format!(
"https://github.com/Shopify/function-runner/releases/download/v{}/function-runner-{}-{}-v{}.gz",
FUNCTION_RUNNER_VERSION, target_arch, target_os, FUNCTION_RUNNER_VERSION,
"https://github.com/Shopify/function-runner/releases/download/v{FUNCTION_RUNNER_VERSION}/function-runner-{target_arch}-{target_os}-v{FUNCTION_RUNNER_VERSION}.gz"
)
},
destination,
Expand All @@ -68,8 +67,7 @@ fn download_trampoline(destination: &PathBuf) -> Result<()> {
download_from_github(
|target_arch, target_os| {
format!(
"https://github.com/Shopify/shopify-function-wasm-api/releases/download/shopify_function_trampoline/v{}/shopify-function-trampoline-{}-{}-v{}.gz",
TRAMPOLINE_VERSION, target_arch, target_os, TRAMPOLINE_VERSION,
"https://github.com/Shopify/shopify-function-wasm-api/releases/download/shopify_function_trampoline/v{TRAMPOLINE_VERSION}/shopify-function-trampoline-{target_arch}-{target_os}-v{TRAMPOLINE_VERSION}.gz"
)
},
destination,
Expand Down Expand Up @@ -127,10 +125,10 @@ pub fn prepare_example(name: &str) -> Result<PathBuf> {
build_example(name)?;
let wasm_path = workspace_root()
.join("target/wasm32-wasip1/release")
.join(format!("{}.wasm", name));
.join(format!("{name}.wasm"));
let trampolined_path = workspace_root()
.join("target/wasm32-wasip1/release")
.join(format!("{}-trampolined.wasm", name));
.join(format!("{name}-trampolined.wasm"));
let trampoline_path = TRAMPOLINE_PATH
.as_ref()
.map_err(|e| anyhow::anyhow!("Failed to download trampoline: {}", e))?;
Expand Down
71 changes: 62 additions & 9 deletions shopify_function_macro/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,18 @@ pub fn typegen(
module.to_token_stream().into()
}

/// Helper function to determine if a GraphQL input field type is nullable
/// Uses conservative detection to identify Optional fields from GraphQL schema
fn is_input_field_nullable(ivd: &impl InputValueDefinition) -> bool {
// Use std::any::type_name to get type information as a string
let type_name = std::any::type_name_of_val(&ivd.r#type());
Copy link
Contributor Author

@mssalemi mssalemi Jun 27, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not super familiar with bluejay, but I think if we have something like:

ivd.r#type().is_nullable() 

Then we could avoid using this approach. We need to get the type information so we can match correctly. I ran into issues when trying to use matching because some types were required (aka not optional), and some were optional.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah we should be using bluejay here for the type information


// only treat fields that are explicitly nullable as Option types
// This prevents incorrectly wrapping required fields in Option<T>
type_name.contains("Option")
|| (type_name.contains("Nullable") && !type_name.contains("NonNull"))
}

struct ShopifyFunctionCodeGenerator;

impl CodeGenerator for ShopifyFunctionCodeGenerator {
Expand Down Expand Up @@ -496,35 +508,76 @@ impl CodeGenerator for ShopifyFunctionCodeGenerator {
) -> Vec<syn::ItemImpl> {
let name_ident = names::type_ident(input_object_type_definition.name());

// Conditionally serialize fields based on GraphQL schema nullability
// Nullable fields (Option<T>) are only serialized if Some(_)
// Required fields are always serialized

let field_statements: Vec<syn::Stmt> = input_object_type_definition
.input_field_definitions()
.iter()
.flat_map(|ivd| {
let field_name_ident = names::field_ident(ivd.name());
let field_name_lit_str = syn::LitStr::new(ivd.name(), Span::mixed_site());

vec![
// Check if this field is nullable in the GraphQL schema
if is_input_field_nullable(ivd) {
// For nullable fields, only serialize if Some(_)
vec![parse_quote! {
if let ::std::option::Option::Some(ref value) = self.#field_name_ident {
context.write_utf8_str(#field_name_lit_str)?;
value.serialize(context)?;
}
}]
} else {
// For required fields, always serialize
vec![
parse_quote! {
context.write_utf8_str(#field_name_lit_str)?;
},
parse_quote! {
self.#field_name_ident.serialize(context)?;
},
]
}
})
.collect();

// Generate field counting statements for dynamic field count calculation
let field_count_statements: Vec<syn::Stmt> = input_object_type_definition
.input_field_definitions()
.iter()
.map(|ivd| {
let field_name_ident = names::field_ident(ivd.name());

if is_input_field_nullable(ivd) {
// For nullable fields, count only if Some(_)
parse_quote! {
context.write_utf8_str(#field_name_lit_str)?;
},
if let ::std::option::Option::Some(_) = self.#field_name_ident {
field_count += 1;
}
}
} else {
// For required fields, always count
parse_quote! {
self.#field_name_ident.serialize(context)?;
},
]
field_count += 1;
}
}
})
.collect();
Comment on lines +545 to 566
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we just consolidate this into one line where we pre-compute the number of non-nullable fields and add the rest based on value, e.g. we might end up with:

let field_count: usize = 5 + self.foo.is_some().into() + self.bar.is_some().into();


let num_fields = input_object_type_definition.input_field_definitions().len();

let serialize_impl = parse_quote! {
impl shopify_function::wasm_api::Serialize for #name_ident {
fn serialize(&self, context: &mut shopify_function::wasm_api::Context) -> ::std::result::Result<(), shopify_function::wasm_api::write::Error> {
// Calculate dynamic field count based on non-null fields
let mut field_count = 0usize;
#(#field_count_statements)*

context.write_object(
|context| {
#(#field_statements)*
::std::result::Result::Ok(())
},
#num_fields,
field_count,
)
}
}
Expand Down