From caf0dc033192eb7225e75db93b5300cf00681787 Mon Sep 17 00:00:00 2001 From: Wodann Date: Thu, 5 Mar 2020 15:48:07 +0100 Subject: [PATCH 1/6] feat(abi): add memory kind to struct information --- crates/mun_abi/c | 2 +- crates/mun_abi/src/autogen.rs | 14 ++++++++- crates/mun_abi/src/autogen_impl.rs | 28 +++++++++++++++-- crates/mun_abi/src/lib.rs | 32 +++++++++++++++++++- crates/mun_codegen/src/code_gen/abi_types.rs | 1 + crates/mun_codegen/src/code_gen/symbols.rs | 5 +++ crates/mun_runtime_capi/ffi | 2 +- crates/mun_syntax/Cargo.toml | 1 + crates/mun_syntax/src/ast.rs | 3 +- crates/mun_syntax/src/ast/extensions.rs | 17 +---------- crates/tools/src/abi.rs | 4 ++- 11 files changed, 85 insertions(+), 24 deletions(-) diff --git a/crates/mun_abi/c b/crates/mun_abi/c index 75d941290..1380c1959 160000 --- a/crates/mun_abi/c +++ b/crates/mun_abi/c @@ -1 +1 @@ -Subproject commit 75d9412905ed5cc2f9842148d5ce6d4da9493b91 +Subproject commit 1380c1959532df7c889ad994e8ff6338e0105cab diff --git a/crates/mun_abi/src/autogen.rs b/crates/mun_abi/src/autogen.rs index bd4dd33a1..8eeb2b693 100644 --- a/crates/mun_abi/src/autogen.rs +++ b/crates/mun_abi/src/autogen.rs @@ -3,7 +3,7 @@ /* automatically generated by rust-bindgen */ #![allow(non_snake_case, non_camel_case_types, non_upper_case_globals)] -use crate::{Privacy, TypeGroup}; +use crate::{Privacy, StructMemoryKind, TypeGroup}; #[doc = " Represents a globally unique identifier (GUID)."] #[doc = ""] @@ -236,6 +236,8 @@ pub struct StructInfo { pub field_sizes: *const u16, #[doc = " Number of fields"] pub num_fields: u16, + #[doc = " Struct memory kind"] + pub memory_kind: StructMemoryKind, } #[test] fn bindgen_test_layout_StructInfo() { @@ -309,6 +311,16 @@ fn bindgen_test_layout_StructInfo() { stringify!(num_fields) ) ); + assert_eq!( + unsafe { &(*(::std::ptr::null::())).memory_kind as *const _ as usize }, + 42usize, + concat!( + "Offset of field: ", + stringify!(StructInfo), + "::", + stringify!(memory_kind) + ) + ); } #[doc = " Represents a module declaration."] #[doc = ""] diff --git a/crates/mun_abi/src/autogen_impl.rs b/crates/mun_abi/src/autogen_impl.rs index e63b5b4ca..f6c7b08c2 100644 --- a/crates/mun_abi/src/autogen_impl.rs +++ b/crates/mun_abi/src/autogen_impl.rs @@ -468,6 +468,7 @@ mod tests { field_types: &[&TypeInfo], field_offsets: &[u16], field_sizes: &[u16], + memory_kind: StructMemoryKind, ) -> StructInfo { assert!(field_names.len() == field_types.len()); assert!(field_types.len() == field_offsets.len()); @@ -480,6 +481,7 @@ mod tests { field_offsets: field_offsets.as_ptr(), field_sizes: field_sizes.as_ptr(), num_fields: field_names.len() as u16, + memory_kind, } } @@ -488,7 +490,7 @@ mod tests { #[test] fn test_struct_info_name() { let struct_name = CString::new(FAKE_STRUCT_NAME).expect("Invalid fake struct name."); - let struct_info = fake_struct_info(&struct_name, &[], &[], &[], &[]); + let struct_info = fake_struct_info(&struct_name, &[], &[], &[], &[], Default::default()); assert_eq!(struct_info.name(), FAKE_STRUCT_NAME); } @@ -506,6 +508,7 @@ mod tests { field_types, field_offsets, field_sizes, + Default::default(), ); assert_eq!(struct_info.field_names().count(), 0); @@ -531,6 +534,7 @@ mod tests { field_types, field_offsets, field_sizes, + Default::default(), ); for (lhs, rhs) in struct_info.field_names().zip([FAKE_FIELD_NAME].iter()) { @@ -555,6 +559,26 @@ mod tests { } } + #[test] + fn test_struct_info_memory_kind_gc() { + let struct_name = CString::new(FAKE_STRUCT_NAME).expect("Invalid fake struct name."); + let struct_memory_kind = StructMemoryKind::GC; + let struct_info = + fake_struct_info(&struct_name, &[], &[], &[], &[], struct_memory_kind.clone()); + + assert_eq!(struct_info.memory_kind, struct_memory_kind); + } + + #[test] + fn test_struct_info_memory_kind_value() { + let struct_name = CString::new(FAKE_STRUCT_NAME).expect("Invalid fake struct name."); + let struct_memory_kind = StructMemoryKind::Value; + let struct_info = + fake_struct_info(&struct_name, &[], &[], &[], &[], struct_memory_kind.clone()); + + assert_eq!(struct_info.memory_kind, struct_memory_kind); + } + const FAKE_MODULE_PATH: &str = "path::to::module"; #[test] @@ -592,7 +616,7 @@ mod tests { let functions = &[fn_info]; let struct_name = CString::new(FAKE_STRUCT_NAME).expect("Invalid fake struct name"); - let struct_info = fake_struct_info(&struct_name, &[], &[], &[], &[]); + let struct_info = fake_struct_info(&struct_name, &[], &[], &[], &[], Default::default()); let struct_type_info = fake_struct_type_info(&struct_name, struct_info); let types = &[unsafe { mem::transmute(&struct_type_info) }]; diff --git a/crates/mun_abi/src/lib.rs b/crates/mun_abi/src/lib.rs index 6d6d4e289..050fc2b3d 100644 --- a/crates/mun_abi/src/lib.rs +++ b/crates/mun_abi/src/lib.rs @@ -15,7 +15,37 @@ pub use autogen::*; /// The *prelude* contains imports that are used almost every time. pub mod prelude { pub use crate::autogen::*; - pub use crate::{Privacy, TypeGroup}; + pub use crate::{Privacy, StructMemoryKind, TypeGroup}; +} + +/// Represents the kind of memory management a struct uses. +#[repr(u8)] +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum StructMemoryKind { + /// A garbage collected struct is allocated on the heap and uses reference semantics when passed + /// around. + GC, + + /// A value struct is allocated on the stack and uses value semantics when passed around. + /// + /// NOTE: When a value struct is used in an external API, a wrapper is created that _pins_ the + /// value on the heap. The heap-allocated value needs to be *manually deallocated*! + Value, +} + +impl Default for StructMemoryKind { + fn default() -> Self { + StructMemoryKind::GC + } +} + +impl From for u64 { + fn from(kind: StructMemoryKind) -> Self { + match kind { + StructMemoryKind::GC => 0, + StructMemoryKind::Value => 1, + } + } } /// Represents the privacy level of modules, functions, or variables. diff --git a/crates/mun_codegen/src/code_gen/abi_types.rs b/crates/mun_codegen/src/code_gen/abi_types.rs index 380191c4e..4ed140cb1 100644 --- a/crates/mun_codegen/src/code_gen/abi_types.rs +++ b/crates/mun_codegen/src/code_gen/abi_types.rs @@ -78,6 +78,7 @@ pub(super) fn gen_abi_types(context: ContextRef) -> AbiTypes { context.i16_type().ptr_type(AddressSpace::Const).into(), // field_offsets context.i16_type().ptr_type(AddressSpace::Const).into(), // field_sizes context.i16_type().into(), // num_fields + context.i8_type().into(), // memory_kind ], false, ); diff --git a/crates/mun_codegen/src/code_gen/symbols.rs b/crates/mun_codegen/src/code_gen/symbols.rs index 07e189598..ffa2f382f 100644 --- a/crates/mun_codegen/src/code_gen/symbols.rs +++ b/crates/mun_codegen/src/code_gen/symbols.rs @@ -337,6 +337,11 @@ fn gen_struct_info( .i16_type() .const_int(num_fields as u64, false) .into(), + module + .get_context() + .i8_type() + .const_int(s.data(db).memory_kind.clone().into(), false) + .into(), ]) } diff --git a/crates/mun_runtime_capi/ffi b/crates/mun_runtime_capi/ffi index dc1d3ef3f..75e6f0fec 160000 --- a/crates/mun_runtime_capi/ffi +++ b/crates/mun_runtime_capi/ffi @@ -1 +1 @@ -Subproject commit dc1d3ef3fa135b7461e4ef871b3abb5e15b1d3e2 +Subproject commit 75e6f0fec3b34fd6926517d2a4e99d991a7f1db3 diff --git a/crates/mun_syntax/Cargo.toml b/crates/mun_syntax/Cargo.toml index dd1c78fb3..e38f74697 100644 --- a/crates/mun_syntax/Cargo.toml +++ b/crates/mun_syntax/Cargo.toml @@ -9,6 +9,7 @@ license = "MIT OR Apache-2.0" description = "Parsing functionality for the Mun programming language" [dependencies] +abi = { path = "../mun_abi", package = "mun_abi" } rowan = "0.6.1" text_unit = { version = "0.1.6", features = ["serde"] } smol_str = { version = "0.1.12", features = ["serde"] } diff --git a/crates/mun_syntax/src/ast.rs b/crates/mun_syntax/src/ast.rs index 47fc03d8a..e3cbe062a 100644 --- a/crates/mun_syntax/src/ast.rs +++ b/crates/mun_syntax/src/ast.rs @@ -9,11 +9,12 @@ use crate::{syntax_node::SyntaxNodeChildren, SmolStr, SyntaxKind, SyntaxNode, Sy pub use self::{ expr_extensions::*, - extensions::{PathSegmentKind, StructKind, StructMemoryKind}, + extensions::{PathSegmentKind, StructKind}, generated::*, tokens::*, traits::*, }; +pub use abi::StructMemoryKind; use std::marker::PhantomData; diff --git a/crates/mun_syntax/src/ast/extensions.rs b/crates/mun_syntax/src/ast/extensions.rs index 6e4670c76..ce5251b0e 100644 --- a/crates/mun_syntax/src/ast/extensions.rs +++ b/crates/mun_syntax/src/ast/extensions.rs @@ -3,6 +3,7 @@ use crate::{ SyntaxKind, T, }; use crate::{SmolStr, SyntaxNode}; +use abi::StructMemoryKind; use text_unit::TextRange; impl ast::Name { @@ -118,22 +119,6 @@ impl StructKind { } } -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum StructMemoryKind { - /// A garbage collected struct is allocated on the heap and uses reference semantics when passed - /// around. - GC, - - /// A value struct is allocated on the stack and uses value semantics when passed around. - Value, -} - -impl Default for StructMemoryKind { - fn default() -> Self { - StructMemoryKind::GC - } -} - impl ast::MemoryTypeSpecifier { pub fn kind(&self) -> StructMemoryKind { if self.is_value() { diff --git a/crates/tools/src/abi.rs b/crates/tools/src/abi.rs index 06e6ec6b0..184934358 100644 --- a/crates/tools/src/abi.rs +++ b/crates/tools/src/abi.rs @@ -28,6 +28,8 @@ impl ParseCallbacks for RemoveVendorName { Some("Privacy".to_string()) } else if original_item_name == "MunTypeGroup_t" { Some("TypeGroup".to_string()) + } else if original_item_name == "MunStructMemoryKind_t" { + Some("StructMemoryKind".to_string()) } else { Some(original_item_name.trim_start_matches("Mun").to_string()) } @@ -55,7 +57,7 @@ pub fn generate(mode: Mode) -> Result<()> { .derive_copy(false) .derive_debug(false) .raw_line("#![allow(non_snake_case, non_camel_case_types, non_upper_case_globals)]") - .raw_line("use crate::{Privacy, TypeGroup};") + .raw_line("use crate::{Privacy, StructMemoryKind, TypeGroup};") .generate() .map_err(|_| format_err!("Unable to generate bindings from 'mun_abi.h'"))?; From b209c383bad049e011ec966990e1548f96c377d8 Mon Sep 17 00:00:00 2001 From: Wodann Date: Thu, 5 Mar 2020 15:58:41 +0100 Subject: [PATCH 2/6] fix(code_gen): invalid dispatch table indices --- crates/mun_codegen/src/ir/dispatch_table.rs | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/crates/mun_codegen/src/ir/dispatch_table.rs b/crates/mun_codegen/src/ir/dispatch_table.rs index 0e6af3387..33f924f34 100644 --- a/crates/mun_codegen/src/ir/dispatch_table.rs +++ b/crates/mun_codegen/src/ir/dispatch_table.rs @@ -242,6 +242,7 @@ impl<'a, D: IrDatabase> DispatchTableBuilder<'a, D> { arg_types, ret_type, }; + let index = self.entries.len(); self.entries.push(TypedDispatchableFunction { function: DispatchableFunction { prototype: prototype.clone(), @@ -249,10 +250,8 @@ impl<'a, D: IrDatabase> DispatchTableBuilder<'a, D> { }, ir_type, }); - self.prototype_to_idx - .insert(prototype, self.function_to_idx.len()); - self.function_to_idx - .insert(function, self.function_to_idx.len()); + self.prototype_to_idx.insert(prototype, index); + self.function_to_idx.insert(function, index); } } @@ -264,6 +263,7 @@ impl<'a, D: IrDatabase> DispatchTableBuilder<'a, D> { // If the function is not yet contained in the table add it let prototype = intrinsic.prototype(); if !self.prototype_to_idx.contains_key(&prototype) { + let index = self.entries.len(); self.entries.push(TypedDispatchableFunction { function: DispatchableFunction { prototype: prototype.clone(), @@ -272,8 +272,7 @@ impl<'a, D: IrDatabase> DispatchTableBuilder<'a, D> { ir_type: intrinsic.ir_type(&self.module.get_context()), }); - self.prototype_to_idx - .insert(prototype, self.function_to_idx.len()); + self.prototype_to_idx.insert(prototype, index); } } From ef0253d8fb130db95da214c718116c86f19a7e3d Mon Sep 17 00:00:00 2001 From: Wodann Date: Sat, 7 Mar 2020 10:49:43 +0100 Subject: [PATCH 3/6] improvement(runtime): add function to generate a managed FunctionInfo --- crates/mun_runtime/src/function.rs | 81 ++++++++++++++++++++++++++++++ crates/mun_runtime/src/lib.rs | 74 +++++---------------------- 2 files changed, 94 insertions(+), 61 deletions(-) create mode 100644 crates/mun_runtime/src/function.rs diff --git a/crates/mun_runtime/src/function.rs b/crates/mun_runtime/src/function.rs new file mode 100644 index 000000000..478e9aa8e --- /dev/null +++ b/crates/mun_runtime/src/function.rs @@ -0,0 +1,81 @@ +use std::ffi::CString; +use std::ptr; + +use abi::{FunctionInfo, FunctionSignature, Guid, Privacy, TypeGroup, TypeInfo}; + +pub struct FunctionInfoStorage { + _name: CString, + _type_names: Vec, + _type_infos: Vec>, +} + +impl FunctionInfoStorage { + pub fn new_function( + name: &str, + args: &[String], + ret: Option, + privacy: Privacy, + fn_ptr: *const std::ffi::c_void, + ) -> (FunctionInfo, FunctionInfoStorage) { + let name = CString::new(name).unwrap(); + let (mut type_names, mut type_infos): (Vec, Vec>) = args + .iter() + .cloned() + .map(|name| { + let name = CString::new(name).unwrap(); + let type_info = Box::new(TypeInfo { + guid: Guid { + b: md5::compute(name.as_bytes()).0, + }, + name: name.as_ptr(), + group: TypeGroup::FundamentalTypes, + }); + (name, type_info) + }) + .unzip(); + + let ret = ret.map(|name| { + let name = CString::new(name).unwrap(); + let type_info = Box::new(TypeInfo { + guid: Guid { + b: md5::compute(name.as_bytes()).0, + }, + name: name.as_ptr(), + group: TypeGroup::FundamentalTypes, + }); + (name, type_info) + }); + + let num_arg_types = type_infos.len() as u16; + let return_type = if let Some((type_name, type_info)) = ret { + type_names.push(type_name); + + let ptr = Box::into_raw(type_info); + let type_info = unsafe { Box::from_raw(ptr) }; + type_infos.push(type_info); + + ptr + } else { + ptr::null() + }; + + let fn_info = FunctionInfo { + signature: FunctionSignature { + name: name.as_ptr(), + arg_types: type_infos.as_ptr() as *const *const _, + return_type, + num_arg_types, + privacy, + }, + fn_ptr, + }; + + let fn_storage = FunctionInfoStorage { + _name: name, + _type_names: type_names, + _type_infos: type_infos, + }; + + (fn_info, fn_storage) + } +} diff --git a/crates/mun_runtime/src/lib.rs b/crates/mun_runtime/src/lib.rs index c22559d6d..14bea37ac 100644 --- a/crates/mun_runtime/src/lib.rs +++ b/crates/mun_runtime/src/lib.rs @@ -5,6 +5,7 @@ #![warn(missing_docs)] mod assembly; +mod function; #[macro_use] mod macros; mod marshal; @@ -16,14 +17,14 @@ mod test; use std::alloc::Layout; use std::collections::HashMap; -use std::ffi::CString; use std::io; use std::path::{Path, PathBuf}; use std::sync::mpsc::{channel, Receiver}; use std::time::Duration; -use abi::{FunctionInfo, FunctionSignature, Guid, Privacy, TypeGroup, TypeInfo}; +use abi::{FunctionInfo, Privacy}; use failure::Error; +use function::FunctionInfoStorage; use notify::{DebouncedEvent, RecommendedWatcher, RecursiveMode, Watcher}; pub use crate::marshal::MarshalInto; @@ -106,11 +107,7 @@ pub struct Runtime { watcher: RecommendedWatcher, watcher_rx: Receiver, - _name: CString, - _u64_type: CString, - _ptr_mut_u8_type: CString, - _arg_types: Vec<*mut abi::TypeInfo>, - _ret_type: Box, + _local_fn_storage: Vec, } extern "C" fn malloc(size: u64, alignment: u64) -> *mut u8 { @@ -126,48 +123,16 @@ impl Runtime { pub fn new(options: RuntimeOptions) -> Result { let (tx, rx) = channel(); - let name = CString::new("malloc").unwrap(); - let u64_type = CString::new("core::u64").unwrap(); - let ptr_mut_u8_type = CString::new("core::u8*").unwrap(); - - let arg_types = vec![ - Box::into_raw(Box::new(TypeInfo { - guid: Guid { - b: md5::compute("core::u64").0, - }, - name: u64_type.as_ptr(), - group: TypeGroup::FundamentalTypes, - })), - Box::into_raw(Box::new(TypeInfo { - guid: Guid { - b: md5::compute("core::u64").0, - }, - name: u64_type.as_ptr(), - group: TypeGroup::FundamentalTypes, - })), - ]; - - let ret_type = Box::new(TypeInfo { - guid: Guid { - b: md5::compute("*mut core::u8").0, - }, - name: ptr_mut_u8_type.as_ptr(), - group: TypeGroup::FundamentalTypes, - }); - - let fn_info = FunctionInfo { - signature: FunctionSignature { - name: name.as_ptr(), - arg_types: arg_types.as_ptr() as *const *const _, - return_type: ret_type.as_ref(), - num_arg_types: 2, - privacy: Privacy::Public, - }, - fn_ptr: malloc as *const std::ffi::c_void, - }; + let (malloc_info, malloc_storage) = FunctionInfoStorage::new_function( + "malloc", + &["core::u8".to_string(), "core::u64".to_string()], + Some("*mut core::u8".to_string()), + Privacy::Public, + malloc as *const std::ffi::c_void, + ); let mut dispatch_table = DispatchTable::default(); - dispatch_table.insert_fn("malloc", fn_info); + dispatch_table.insert_fn("malloc", malloc_info); let watcher: RecommendedWatcher = Watcher::new(tx, options.delay)?; let mut runtime = Runtime { @@ -176,11 +141,7 @@ impl Runtime { watcher, watcher_rx: rx, - _name: name, - _u64_type: u64_type, - _ptr_mut_u8_type: ptr_mut_u8_type, - _arg_types: arg_types, - _ret_type: ret_type, + _local_fn_storage: vec![malloc_storage], }; runtime.add_assembly(&options.library_path)?; @@ -242,15 +203,6 @@ impl Runtime { } } -impl Drop for Runtime { - fn drop(&mut self) { - for raw_arg_type in self._arg_types.iter() { - // Drop arg type memory - let _arg_type = unsafe { Box::from_raw(*raw_arg_type) }; - } - } -} - /// Extends a result object with functions that allow retrying of an action. pub trait RetryResultExt: Sized { /// Output type on success From d539b40995445bbf96f1cd53a61f526efeee9959 Mon Sep 17 00:00:00 2001 From: Wodann Date: Sat, 7 Mar 2020 14:04:32 +0100 Subject: [PATCH 4/6] fix: allow desired usage of clippy::vec_box --- crates/mun_runtime/src/function.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/crates/mun_runtime/src/function.rs b/crates/mun_runtime/src/function.rs index 478e9aa8e..572238249 100644 --- a/crates/mun_runtime/src/function.rs +++ b/crates/mun_runtime/src/function.rs @@ -6,6 +6,10 @@ use abi::{FunctionInfo, FunctionSignature, Guid, Privacy, TypeGroup, TypeInfo}; pub struct FunctionInfoStorage { _name: CString, _type_names: Vec, + + // Clippy warns: `Vec` is already on the heap, the boxing is unnecessary. + // However, in this case we explicitly want to have a Vec of pointers. + #[allow(clippy::vec_box)] _type_infos: Vec>, } From 9dc084e443b1723a234e9e1f104df2071f9ee7b3 Mon Sep 17 00:00:00 2001 From: Wodann Date: Sun, 1 Mar 2020 14:52:57 +0100 Subject: [PATCH 5/6] feat(code_gen): create marshallable wrapper for unmarshallable functions A function cannot be marshalled when one of its parameters or its return type are a value struct --- crates/mun_codegen/src/code_gen/symbols.rs | 9 +- crates/mun_codegen/src/db.rs | 6 +- crates/mun_codegen/src/ir/adt.rs | 4 +- crates/mun_codegen/src/ir/body.rs | 172 ++++++++++++++------ crates/mun_codegen/src/ir/dispatch_table.rs | 12 +- crates/mun_codegen/src/ir/function.rs | 31 +++- crates/mun_codegen/src/ir/module.rs | 48 +++++- crates/mun_codegen/src/ir/ty.rs | 26 ++- crates/mun_codegen/src/lib.rs | 7 + crates/mun_hir/src/ty.rs | 13 +- 10 files changed, 248 insertions(+), 80 deletions(-) diff --git a/crates/mun_codegen/src/code_gen/symbols.rs b/crates/mun_codegen/src/code_gen/symbols.rs index ffa2f382f..d8954c0ba 100644 --- a/crates/mun_codegen/src/code_gen/symbols.rs +++ b/crates/mun_codegen/src/code_gen/symbols.rs @@ -5,7 +5,7 @@ use crate::ir::{ }; use crate::type_info::{TypeGroup, TypeInfo}; use crate::values::{BasicValue, GlobalValue}; -use crate::IrDatabase; +use crate::{CodeGenParams, IrDatabase}; use hir::Ty; use inkwell::{ attributes::Attribute, @@ -262,7 +262,6 @@ fn gen_function_info_array<'a, D: IrDatabase>( functions: impl Iterator, ) -> GlobalArrayValue { let function_infos: Vec = functions - .filter(|(f, _)| f.visibility(db) == hir::Visibility::Public) .map(|(f, value)| { // Get the function from the cloned module and modify the linkage of the function. let value = module @@ -321,9 +320,9 @@ fn gen_struct_info( (0..fields.len()).map(|idx| target_data.offset_of_element(&t, idx as u32).unwrap()); let (field_offsets, _) = gen_u16_array(module, field_offsets); - let field_sizes = fields - .iter() - .map(|field| target_data.get_store_size(&db.type_ir(field.ty(db)))); + let field_sizes = fields.iter().map(|field| { + target_data.get_store_size(&db.type_ir(field.ty(db), CodeGenParams { is_extern: false })) + }); let (field_sizes, _) = gen_u16_array(module, field_sizes); types.struct_info_type.const_named_struct(&[ diff --git a/crates/mun_codegen/src/db.rs b/crates/mun_codegen/src/db.rs index 5b911aebd..332c4ee73 100644 --- a/crates/mun_codegen/src/db.rs +++ b/crates/mun_codegen/src/db.rs @@ -1,6 +1,6 @@ #![allow(clippy::type_repetition_in_bounds)] -use crate::{ir::module::ModuleIR, type_info::TypeInfo, Context}; +use crate::{ir::module::ModuleIR, type_info::TypeInfo, CodeGenParams, Context}; use inkwell::types::StructType; use inkwell::{types::AnyTypeEnum, OptimizationLevel}; use mun_target::spec::Target; @@ -22,9 +22,9 @@ pub trait IrDatabase: hir::HirDatabase { #[salsa::input] fn target(&self) -> Target; - /// Given a type, return the corresponding IR type. + /// Given a type and code generation parameters, return the corresponding IR type. #[salsa::invoke(crate::ir::ty::ir_query)] - fn type_ir(&self, ty: hir::Ty) -> AnyTypeEnum; + fn type_ir(&self, ty: hir::Ty, params: CodeGenParams) -> AnyTypeEnum; /// Given a struct, return the corresponding IR type. #[salsa::invoke(crate::ir::ty::struct_ty_query)] diff --git a/crates/mun_codegen/src/ir/adt.rs b/crates/mun_codegen/src/ir/adt.rs index b8c00b271..66105d831 100644 --- a/crates/mun_codegen/src/ir/adt.rs +++ b/crates/mun_codegen/src/ir/adt.rs @@ -1,6 +1,6 @@ //use crate::ir::module::Types; use crate::ir::try_convert_any_to_basic; -use crate::IrDatabase; +use crate::{CodeGenParams, IrDatabase}; use inkwell::types::{BasicTypeEnum, StructType}; pub(super) fn gen_struct_decl(db: &impl IrDatabase, s: hir::Struct) -> StructType { @@ -11,7 +11,7 @@ pub(super) fn gen_struct_decl(db: &impl IrDatabase, s: hir::Struct) -> StructTyp .iter() .map(|field| { let field_type = field.ty(db); - try_convert_any_to_basic(db.type_ir(field_type)) + try_convert_any_to_basic(db.type_ir(field_type, CodeGenParams { is_extern: false })) .expect("could not convert field type") }) .collect(); diff --git a/crates/mun_codegen/src/ir/body.rs b/crates/mun_codegen/src/ir/body.rs index 3bef260ce..741fa0521 100644 --- a/crates/mun_codegen/src/ir/body.rs +++ b/crates/mun_codegen/src/ir/body.rs @@ -1,5 +1,7 @@ use crate::intrinsics; -use crate::{ir::dispatch_table::DispatchTable, ir::try_convert_any_to_basic, IrDatabase}; +use crate::{ + ir::dispatch_table::DispatchTable, ir::try_convert_any_to_basic, CodeGenParams, IrDatabase, +}; use hir::{ ArenaId, ArithOp, BinaryOp, Body, CmpOp, Expr, ExprId, HirDisplay, InferenceResult, Literal, Name, Ordering, Pat, PatId, Path, Resolution, Resolver, Statement, TypeCtor, @@ -7,7 +9,7 @@ use hir::{ use inkwell::{ builder::Builder, module::Module, - values::{BasicValueEnum, CallSiteValue, FloatValue, FunctionValue, IntValue}, + values::{BasicValueEnum, CallSiteValue, FloatValue, FunctionValue, IntValue, StructValue}, AddressSpace, FloatPredicate, IntPredicate, }; use std::{collections::HashMap, mem, sync::Arc}; @@ -37,6 +39,7 @@ pub(crate) struct BodyIrGenerator<'a, 'b, D: IrDatabase> { dispatch_table: &'b DispatchTable, active_loop: Option, hir_function: hir::Function, + params: CodeGenParams, } impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> { @@ -47,6 +50,7 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> { ir_function: FunctionValue, function_map: &'a HashMap, dispatch_table: &'b DispatchTable, + params: CodeGenParams, ) -> Self { // Get the type information from the `hir::Function` let body = hir_function.body(db); @@ -72,6 +76,7 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> { dispatch_table, active_loop: None, hir_function, + params, } } @@ -127,6 +132,50 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> { } } + pub fn gen_fn_wrapper(&mut self) { + let fn_sig = self.hir_function.ty(self.db).callable_sig(self.db).unwrap(); + let args: Vec = fn_sig + .params() + .iter() + .enumerate() + .map(|(idx, ty)| { + let param = self.fn_value.get_nth_param(idx as u32).unwrap(); + self.opt_deref_value(ty.clone(), param) + }) + .collect(); + + let ret_value = self + .gen_call(self.hir_function, &args) + .try_as_basic_value() + .left(); + + let call_return_type = &self.infer[self.body.body_expr()]; + if !call_return_type.is_never() { + let fn_ret_type = self + .hir_function + .ty(self.db) + .callable_sig(self.db) + .unwrap() + .ret() + .clone(); + + if fn_ret_type.is_empty() { + self.builder.build_return(None); + } else if let Some(value) = ret_value { + let ret_value = if let Some(hir_struct) = fn_ret_type.as_struct() { + if hir_struct.data(self.db).memory_kind == hir::StructMemoryKind::Value { + self.gen_struct_alloc_on_heap(hir_struct, value.into_struct_value()) + } else { + value + } + } else { + value + }; + self.builder.build_return(Some(&ret_value)); + } + } + } + /// Generates IR for the specified expression. Dependending on the type of expression an IR /// value is returned. fn gen_expr(&mut self, expr: ExprId) -> Option { @@ -152,6 +201,12 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> { // Get the callable definition from the map match self.infer[*callee].as_callable_def() { Some(hir::CallableDef::Function(def)) => { + // Get all the arguments + let args: Vec = args + .iter() + .map(|expr| self.gen_expr(*expr).expect("expected a value")) + .collect(); + self.gen_call(def, &args).try_as_basic_value().left() } Some(hir::CallableDef::Struct(_)) => Some(self.gen_named_tuple_lit(expr, args)), @@ -235,37 +290,45 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> { hir::StructMemoryKind::Value => struct_lit.into(), hir::StructMemoryKind::GC => { // TODO: Root memory in GC - let struct_ir_ty = self.db.struct_ty(hir_struct); - let malloc_fn_ptr = self - .dispatch_table - .gen_intrinsic_lookup(&self.builder, &intrinsics::malloc); - let mem_ptr = self - .builder - .build_call( - malloc_fn_ptr, - &[ - struct_ir_ty.size_of().unwrap().into(), - struct_ir_ty.get_alignment().into(), - ], - "malloc", - ) - .try_as_basic_value() - .left() - .unwrap(); - let struct_ptr = self - .builder - .build_bitcast( - mem_ptr, - struct_ir_ty.ptr_type(AddressSpace::Generic), - &hir_struct.name(self.db).to_string(), - ) - .into_pointer_value(); - self.builder.build_store(struct_ptr, struct_lit); - struct_ptr.into() + self.gen_struct_alloc_on_heap(hir_struct, struct_lit) } } } + fn gen_struct_alloc_on_heap( + &mut self, + hir_struct: hir::Struct, + struct_lit: StructValue, + ) -> BasicValueEnum { + let struct_ir_ty = self.db.struct_ty(hir_struct); + let malloc_fn_ptr = self + .dispatch_table + .gen_intrinsic_lookup(&self.builder, &intrinsics::malloc); + let mem_ptr = self + .builder + .build_call( + malloc_fn_ptr, + &[ + struct_ir_ty.size_of().unwrap().into(), + struct_ir_ty.get_alignment().into(), + ], + "malloc", + ) + .try_as_basic_value() + .left() + .unwrap(); + let struct_ptr = self + .builder + .build_bitcast( + mem_ptr, + struct_ir_ty.ptr_type(AddressSpace::Generic), + &hir_struct.name(self.db).to_string(), + ) + .into_pointer_value(); + self.builder.build_store(struct_ptr, struct_lit); + struct_ptr.into() + } + /// Generates IR for a record literal, e.g. `Foo { a: 1.23, b: 4 }` fn gen_record_lit( &mut self, @@ -349,8 +412,11 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> { Pat::Bind { name } => { let builder = self.new_alloca_builder(); let pat_ty = self.infer[pat].clone(); - let ty = try_convert_any_to_basic(self.db.type_ir(pat_ty.clone())) - .expect("expected basic type"); + let ty = try_convert_any_to_basic( + self.db + .type_ir(pat_ty.clone(), CodeGenParams { is_extern: false }), + ) + .expect("expected basic type"); let ptr = builder.build_alloca(ty, &name.to_string()); self.pat_to_local.insert(pat, ptr); self.pat_to_name.insert(pat, name.to_string()); @@ -394,8 +460,8 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> { } /// Given an expression and the type of the expression, optionally dereference the value. - fn opt_deref_value(&mut self, expr: ExprId, value: BasicValueEnum) -> BasicValueEnum { - match &self.infer[expr] { + fn opt_deref_value(&mut self, ty: hir::Ty, value: BasicValueEnum) -> BasicValueEnum { + match ty { hir::Ty::Apply(hir::ApplicationTy { ctor: hir::TypeCtor::Struct(s), .. @@ -403,7 +469,13 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> { hir::StructMemoryKind::GC => { self.builder.build_load(value.into_pointer_value(), "deref") } - hir::StructMemoryKind::Value => value, + hir::StructMemoryKind::Value => { + if self.params.is_extern { + self.builder.build_load(value.into_pointer_value(), "deref") + } else { + value + } + } }, _ => value, } @@ -460,12 +532,12 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> { ) -> Option { let lhs = self .gen_expr(lhs_expr) - .map(|value| self.opt_deref_value(lhs_expr, value)) + .map(|value| self.opt_deref_value(self.infer[lhs_expr].clone(), value)) .expect("no lhs value") .into_float_value(); let rhs = self .gen_expr(rhs_expr) - .map(|value| self.opt_deref_value(rhs_expr, value)) + .map(|value| self.opt_deref_value(self.infer[rhs_expr].clone(), value)) .expect("no rhs value") .into_float_value(); match op { @@ -519,12 +591,12 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> { ) -> Option { let lhs = self .gen_expr(lhs_expr) - .map(|value| self.opt_deref_value(lhs_expr, value)) + .map(|value| self.opt_deref_value(self.infer[lhs_expr].clone(), value)) .expect("no lhs value") .into_int_value(); let rhs = self .gen_expr(rhs_expr) - .map(|value| self.opt_deref_value(lhs_expr, value)) + .map(|value| self.opt_deref_value(self.infer[lhs_expr].clone(), value)) .expect("no rhs value") .into_int_value(); match op { @@ -609,19 +681,13 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> { } } - // TODO: Implement me! fn should_use_dispatch_table(&self) -> bool { - true + // FIXME: When we use the dispatch table, generated wrappers have infinite recursion + !self.params.is_extern } /// Generates IR for a function call. - fn gen_call(&mut self, function: hir::Function, args: &[ExprId]) -> CallSiteValue { - // Get all the arguments - let args: Vec = args - .iter() - .map(|expr| self.gen_expr(*expr).expect("expected a value")) - .collect(); - + fn gen_call(&mut self, function: hir::Function, args: &[BasicValueEnum]) -> CallSiteValue { if self.should_use_dispatch_table() { let ptr_value = self.dispatch_table @@ -649,7 +715,7 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> { // Generate IR for the condition let condition_ir = self .gen_expr(condition) - .map(|value| self.opt_deref_value(condition, value))? + .map(|value| self.opt_deref_value(self.infer[condition].clone(), value))? .into_int_value(); // Generate the code blocks to branch to @@ -787,7 +853,7 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> { self.builder.position_at_end(&cond_block); let condition_ir = self .gen_expr(condition_expr) - .map(|value| self.opt_deref_value(condition_expr, value)); + .map(|value| self.opt_deref_value(self.infer[condition_expr].clone(), value)); if let Some(condition_ir) = condition_ir { self.builder.build_conditional_branch( condition_ir.into_int_value(), @@ -844,11 +910,11 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> { } fn gen_field(&mut self, _expr: ExprId, receiver_expr: ExprId, name: &Name) -> PointerValue { - let receiver_ty = &self.infer[receiver_expr] + let hir_struct = self.infer[receiver_expr] .as_struct() .expect("expected a struct"); - let field_idx = receiver_ty + let field_idx = hir_struct .field(self.db, name) .expect("expected a struct field") .id() @@ -857,13 +923,13 @@ impl<'a, 'b, D: IrDatabase> BodyIrGenerator<'a, 'b, D> { let receiver_ptr = self.gen_place_expr(receiver_expr); let receiver_ptr = self - .opt_deref_value(receiver_expr, receiver_ptr.into()) + .opt_deref_value(self.infer[receiver_expr].clone(), receiver_ptr.into()) .into_pointer_value(); unsafe { self.builder.build_struct_gep( receiver_ptr, field_idx, - &format!("{}.{}", receiver_ty.name(self.db), name), + &format!("{}.{}", hir_struct.name(self.db), name), ) } } diff --git a/crates/mun_codegen/src/ir/dispatch_table.rs b/crates/mun_codegen/src/ir/dispatch_table.rs index 33f924f34..baaef9f7f 100644 --- a/crates/mun_codegen/src/ir/dispatch_table.rs +++ b/crates/mun_codegen/src/ir/dispatch_table.rs @@ -1,6 +1,6 @@ use crate::intrinsics; use crate::values::FunctionValue; -use crate::IrDatabase; +use crate::{CodeGenParams, IrDatabase}; use inkwell::module::Module; use inkwell::types::{BasicTypeEnum, FunctionType}; use inkwell::values::{BasicValueEnum, PointerValue}; @@ -225,7 +225,10 @@ impl<'a, D: IrDatabase> DispatchTableBuilder<'a, D> { let name = function.name(self.db).to_string(); let hir_type = function.ty(self.db); let sig = hir_type.callable_sig(self.db).unwrap(); - let ir_type = self.db.type_ir(hir_type).into_function_type(); + let ir_type = self + .db + .type_ir(hir_type, CodeGenParams { is_extern: false }) + .into_function_type(); let arg_types = sig .params() .iter() @@ -282,6 +285,11 @@ impl<'a, D: IrDatabase> DispatchTableBuilder<'a, D> { self.collect_expr(body.body_expr(), body, infer); } + /// Collect the call expression from the body of a wrapper for the specified function. + pub fn collect_wrapper_body(&mut self, _function: hir::Function) { + self.collect_intrinsic(&intrinsics::malloc) + } + /// This creates the final DispatchTable with all *called* functions from within the module /// # Parameters /// * **functions**: Mapping of *defined* Mun functions to their respective IR values. diff --git a/crates/mun_codegen/src/ir/function.rs b/crates/mun_codegen/src/ir/function.rs index 357d0b92d..c3f7b0bc9 100644 --- a/crates/mun_codegen/src/ir/function.rs +++ b/crates/mun_codegen/src/ir/function.rs @@ -1,7 +1,7 @@ use crate::ir::body::BodyIrGenerator; use crate::ir::dispatch_table::DispatchTable; use crate::values::FunctionValue; -use crate::{IrDatabase, Module, OptimizationLevel}; +use crate::{CodeGenParams, IrDatabase, Module, OptimizationLevel}; use inkwell::passes::{PassManager, PassManagerBuilder}; use inkwell::types::AnyTypeEnum; @@ -30,9 +30,10 @@ pub(crate) fn gen_signature( db: &impl IrDatabase, f: hir::Function, module: &Module, + params: CodeGenParams, ) -> FunctionValue { let name = f.name(db).to_string(); - if let AnyTypeEnum::FunctionType(ty) = db.type_ir(f.ty(db)) { + if let AnyTypeEnum::FunctionType(ty) = db.type_ir(f.ty(db), params) { module.add_function(&name, ty, None) } else { panic!("not a function type") @@ -55,9 +56,35 @@ pub(crate) fn gen_body<'a, 'b, D: IrDatabase>( llvm_function, llvm_functions, dispatch_table, + CodeGenParams { is_extern: false }, ); code_gen.gen_fn_body(); llvm_function } + +/// Generates the body of a wrapper around `hir::Function` for its associated +/// `FunctionValue` +pub(crate) fn gen_wrapper_body<'a, 'b, D: IrDatabase>( + db: &'a D, + hir_function: hir::Function, + llvm_function: FunctionValue, + module: &'a Module, + llvm_functions: &'a HashMap, + dispatch_table: &'b DispatchTable, +) -> FunctionValue { + let mut code_gen = BodyIrGenerator::new( + db, + module, + hir_function, + llvm_function, + llvm_functions, + dispatch_table, + CodeGenParams { is_extern: true }, + ); + + code_gen.gen_fn_wrapper(); + + llvm_function +} diff --git a/crates/mun_codegen/src/ir/module.rs b/crates/mun_codegen/src/ir/module.rs index d3d54578a..0a303e992 100644 --- a/crates/mun_codegen/src/ir/module.rs +++ b/crates/mun_codegen/src/ir/module.rs @@ -2,7 +2,7 @@ use super::adt; use crate::ir::dispatch_table::{DispatchTable, DispatchTableBuilder}; use crate::ir::function; use crate::type_info::TypeInfo; -use crate::IrDatabase; +use crate::{CodeGenParams, IrDatabase}; use hir::{FileId, ModuleDef}; use inkwell::{module::Module, values::FunctionValue}; use std::collections::{HashMap, HashSet}; @@ -47,6 +47,7 @@ pub(crate) fn ir_query(db: &impl IrDatabase, file_id: FileId) -> Arc { // Generate all the function signatures let mut functions = HashMap::new(); + let mut wrappers = HashMap::new(); let mut dispatch_table_builder = DispatchTableBuilder::new(db, &llvm_module); for def in db.module_data(file_id).definitions() { // TODO: Remove once we have more ModuleDef variants @@ -65,13 +66,31 @@ pub(crate) fn ir_query(db: &impl IrDatabase, file_id: FileId) -> Arc { } // Construct the function signature - let fun = function::gen_signature(db, *f, &llvm_module); + let fun = function::gen_signature( + db, + *f, + &llvm_module, + CodeGenParams { is_extern: false }, + ); functions.insert(*f, fun); // Add calls to the dispatch table let body = f.body(db); let infer = f.infer(db); dispatch_table_builder.collect_body(&body, &infer); + + if f.data(db).visibility() != hir::Visibility::Private && !fn_sig.marshallable(db) { + let wrapper_fun = function::gen_signature( + db, + *f, + &llvm_module, + CodeGenParams { is_extern: true }, + ); + wrappers.insert(*f, wrapper_fun); + + // Add calls from the function's wrapper to the dispatch table + dispatch_table_builder.collect_wrapper_body(*f); + } } _ => {} } @@ -94,6 +113,18 @@ pub(crate) fn ir_query(db: &impl IrDatabase, file_id: FileId) -> Arc { fn_pass_manager.run_on(llvm_function); } + for (hir_function, llvm_function) in wrappers.iter() { + function::gen_wrapper_body( + db, + *hir_function, + *llvm_function, + &llvm_module, + &functions, + &dispatch_table, + ); + fn_pass_manager.run_on(llvm_function); + } + // Dispatch entries can include previously unchecked intrinsics for entry in dispatch_table.entries().iter() { // Collect argument types @@ -106,10 +137,21 @@ pub(crate) fn ir_query(db: &impl IrDatabase, file_id: FileId) -> Arc { } } + // Filter private methods + let mut api: HashMap = functions + .into_iter() + .filter(|(f, _)| f.visibility(db) != hir::Visibility::Private) + .collect(); + + // Replace non-marshallable functions with their marshallable wrappers + for (hir_function, llvm_function) in wrappers { + api.insert(hir_function, llvm_function); + } + Arc::new(ModuleIR { file_id, llvm_module, - functions, + functions: api, types, dispatch_table, }) diff --git a/crates/mun_codegen/src/ir/ty.rs b/crates/mun_codegen/src/ir/ty.rs index 280fe0595..165b907c0 100644 --- a/crates/mun_codegen/src/ir/ty.rs +++ b/crates/mun_codegen/src/ir/ty.rs @@ -1,14 +1,14 @@ use super::try_convert_any_to_basic; use crate::{ type_info::{TypeGroup, TypeInfo}, - IrDatabase, + CodeGenParams, IrDatabase, }; use hir::{ApplicationTy, CallableDef, Ty, TypeCtor}; use inkwell::types::{AnyTypeEnum, BasicType, BasicTypeEnum, StructType}; use inkwell::AddressSpace; /// Given a mun type, construct an LLVM IR type -pub(crate) fn ir_query(db: &impl IrDatabase, ty: Ty) -> AnyTypeEnum { +pub(crate) fn ir_query(db: &impl IrDatabase, ty: Ty, params: CodeGenParams) -> AnyTypeEnum { let context = db.context(); match ty { Ty::Empty => AnyTypeEnum::StructType(context.struct_type(&[], false)), @@ -18,17 +18,19 @@ pub(crate) fn ir_query(db: &impl IrDatabase, ty: Ty) -> AnyTypeEnum { TypeCtor::Bool => AnyTypeEnum::IntType(context.bool_type()), TypeCtor::FnDef(def @ CallableDef::Function(_)) => { let ty = db.callable_sig(def); - let params: Vec = ty + let param_tys: Vec = ty .params() .iter() - .map(|p| try_convert_any_to_basic(db.type_ir(p.clone())).unwrap()) + .map(|p| { + try_convert_any_to_basic(db.type_ir(p.clone(), params.clone())).unwrap() + }) .collect(); let fn_type = match ty.ret() { - Ty::Empty => context.void_type().fn_type(¶ms, false), - ty => try_convert_any_to_basic(db.type_ir(ty.clone())) + Ty::Empty => context.void_type().fn_type(¶m_tys, false), + ty => try_convert_any_to_basic(db.type_ir(ty.clone(), params)) .expect("could not convert return value") - .fn_type(¶ms, false), + .fn_type(¶m_tys, false), }; AnyTypeEnum::FunctionType(fn_type) @@ -37,7 +39,13 @@ pub(crate) fn ir_query(db: &impl IrDatabase, ty: Ty) -> AnyTypeEnum { let struct_ty = db.struct_ty(s); match s.data(db).memory_kind { hir::StructMemoryKind::GC => struct_ty.ptr_type(AddressSpace::Generic).into(), - hir::StructMemoryKind::Value => struct_ty.into(), + hir::StructMemoryKind::Value => { + if params.is_extern { + struct_ty.ptr_type(AddressSpace::Generic).into() + } else { + struct_ty.into() + } + } } } _ => unreachable!(), @@ -51,7 +59,7 @@ pub fn struct_ty_query(db: &impl IrDatabase, s: hir::Struct) -> StructType { let name = s.name(db).to_string(); for field in s.fields(db).iter() { // Ensure that salsa's cached value incorporates the struct fields - let _field_type_ir = db.type_ir(field.ty(db)); + let _field_type_ir = db.type_ir(field.ty(db), CodeGenParams { is_extern: false }); } db.context().opaque_struct_type(&name) diff --git a/crates/mun_codegen/src/lib.rs b/crates/mun_codegen/src/lib.rs index f674be161..8544fb805 100644 --- a/crates/mun_codegen/src/lib.rs +++ b/crates/mun_codegen/src/lib.rs @@ -19,3 +19,10 @@ pub use crate::{ code_gen::write_module_shared_object, db::{IrDatabase, IrDatabaseStorage}, }; + +#[derive(Clone, Debug, Default, Eq, Hash, PartialEq)] +pub struct CodeGenParams { + /// Whether generated code should support extern function calls. + /// This allows function parameters with `struct(value)` types to be marshalled. + is_extern: bool, +} diff --git a/crates/mun_hir/src/ty.rs b/crates/mun_hir/src/ty.rs index 6f3f84cf3..0c253455a 100644 --- a/crates/mun_hir/src/ty.rs +++ b/crates/mun_hir/src/ty.rs @@ -5,7 +5,7 @@ mod op; use crate::display::{HirDisplay, HirFormatter}; use crate::ty::infer::TypeVarId; use crate::ty::lower::fn_sig_for_struct_constructor; -use crate::{HirDatabase, Struct}; +use crate::{HirDatabase, Struct, StructMemoryKind}; pub(crate) use infer::infer_query; pub use infer::InferenceResult; pub(crate) use lower::{callable_item_sig, fn_sig_for_fn, type_for_def, CallableDef, TypableDef}; @@ -172,6 +172,17 @@ impl FnSig { pub fn ret(&self) -> &Ty { &self.params_and_return[self.params_and_return.len() - 1] } + + pub fn marshallable(&self, db: &impl HirDatabase) -> bool { + for ty in self.params_and_return.iter() { + if let Some(s) = ty.as_struct() { + if s.data(db).memory_kind == StructMemoryKind::Value { + return false; + } + } + } + true + } } impl HirDisplay for Ty { From 3a716add8fe7dd31ba0c4aceaf4a4085362d7f5a Mon Sep 17 00:00:00 2001 From: Wodann Date: Thu, 5 Mar 2020 16:09:30 +0100 Subject: [PATCH 6/6] feat(runtime): add marshalling of value structs --- crates/mun/src/main.rs | 7 +- crates/mun_codegen/src/intrinsics.rs | 3 + crates/mun_codegen/src/intrinsics/macros.rs | 2 +- crates/mun_codegen/src/ir.rs | 8 + crates/mun_codegen/src/ir/dispatch_table.rs | 9 + .../src/snapshots/test__field_crash.snap | 4 +- .../src/snapshots/test__field_expr.snap | 6 +- .../src/snapshots/test__gc_struct.snap | 4 +- .../src/snapshots/test__struct_test.snap | 3 + crates/mun_codegen/src/type_info.rs | 7 + crates/mun_runtime/examples/hot_reloading.rs | 7 +- crates/mun_runtime/src/lib.rs | 30 ++- crates/mun_runtime/src/macros.rs | 73 +++---- crates/mun_runtime/src/marshal.rs | 38 +++- crates/mun_runtime/src/reflection.rs | 62 +++++- crates/mun_runtime/src/struct.rs | 125 +++++++---- crates/mun_runtime/src/test.rs | 205 +++++++++++------- 17 files changed, 407 insertions(+), 186 deletions(-) diff --git a/crates/mun/src/main.rs b/crates/mun/src/main.rs index 3266db3dc..e8343b005 100644 --- a/crates/mun/src/main.rs +++ b/crates/mun/src/main.rs @@ -1,6 +1,8 @@ #[macro_use] extern crate failure; +use std::cell::RefCell; +use std::rc::Rc; use std::time::Duration; use clap::{App, AppSettings, Arg, ArgMatches, SubCommand}; @@ -84,10 +86,11 @@ fn build(matches: &ArgMatches) -> Result<(), failure::Error> { /// Starts the runtime with the specified library and invokes function `entry`. fn start(matches: &ArgMatches) -> Result<(), failure::Error> { - let mut runtime = runtime(matches)?; + let runtime = Rc::new(RefCell::new(runtime(matches)?)); + let borrowed = runtime.borrow(); let entry_point = matches.value_of("entry").unwrap_or("main"); - let fn_info = runtime.get_function_info(entry_point).ok_or_else(|| { + let fn_info = borrowed.get_function_info(entry_point).ok_or_else(|| { std::io::Error::new( std::io::ErrorKind::InvalidInput, format!("Failed to obtain entry point '{}'", entry_point), diff --git a/crates/mun_codegen/src/intrinsics.rs b/crates/mun_codegen/src/intrinsics.rs index bbe2d6131..d4106231d 100644 --- a/crates/mun_codegen/src/intrinsics.rs +++ b/crates/mun_codegen/src/intrinsics.rs @@ -1,4 +1,5 @@ use crate::ir::dispatch_table::FunctionPrototype; +use crate::type_info::TypeInfo; use inkwell::context::Context; use inkwell::types::FunctionType; @@ -18,4 +19,6 @@ pub trait Intrinsic: Sync { intrinsics! { /// Allocates memory from the runtime to use in code. pub fn malloc(size: u64, alignment: u64) -> *mut u8; + /// Allocates memory for and clones the specified type located at `src` into it. + pub fn clone(src: *const u8, ty: *const TypeInfo) -> *mut u8; } diff --git a/crates/mun_codegen/src/intrinsics/macros.rs b/crates/mun_codegen/src/intrinsics/macros.rs index 72839a69a..dc49d3ac0 100644 --- a/crates/mun_codegen/src/intrinsics/macros.rs +++ b/crates/mun_codegen/src/intrinsics/macros.rs @@ -1,5 +1,5 @@ macro_rules! intrinsics{ - ($($(#[$attr:meta])* pub fn $name:ident($($arg_name:ident:$arg:ty),*) -> $ret:ty;);*) => { + ($($(#[$attr:meta])* pub fn $name:ident($($arg_name:ident:$arg:ty),+) -> $ret:ty;)+) => { $( paste::item! { pub struct []; diff --git a/crates/mun_codegen/src/ir.rs b/crates/mun_codegen/src/ir.rs index 6ff1b9284..fab6ed00e 100644 --- a/crates/mun_codegen/src/ir.rs +++ b/crates/mun_codegen/src/ir.rs @@ -1,3 +1,4 @@ +use crate::type_info::TypeInfo; use inkwell::context::Context; use inkwell::types::{ AnyType, AnyTypeEnum, BasicType, BasicTypeEnum, FunctionType, IntType, PointerType, @@ -183,6 +184,13 @@ impl> IsPointerType for *const T { } } +// HACK: Manually add `*const TypeInfo` +impl IsPointerType for *const TypeInfo { + fn ir_type(context: &Context) -> PointerType { + context.i8_type().ptr_type(AddressSpace::Const) + } +} + impl> IsPointerType for *mut T { fn ir_type(context: &Context) -> PointerType { T::ir_type(context).ptr_type(AddressSpace::Generic) diff --git a/crates/mun_codegen/src/ir/dispatch_table.rs b/crates/mun_codegen/src/ir/dispatch_table.rs index baaef9f7f..d01f76389 100644 --- a/crates/mun_codegen/src/ir/dispatch_table.rs +++ b/crates/mun_codegen/src/ir/dispatch_table.rs @@ -181,6 +181,9 @@ impl<'a, D: IrDatabase> DispatchTableBuilder<'a, D> { match infer[*callee].as_callable_def() { Some(hir::CallableDef::Function(def)) => self.collect_fn_def(def), Some(hir::CallableDef::Struct(s)) => { + // self.collect_intrinsic(&intrinsics::new); + self.collect_intrinsic(&intrinsics::clone); + // self.collect_intrinsic(&intrinsics::drop); if s.data(self.db).memory_kind == hir::StructMemoryKind::GC { self.collect_intrinsic(&intrinsics::malloc) } @@ -192,6 +195,9 @@ impl<'a, D: IrDatabase> DispatchTableBuilder<'a, D> { if let Expr::RecordLit { .. } = expr { let struct_ty = infer[expr_id].clone(); let hir_struct = struct_ty.as_struct().unwrap(); // Can only really get here if the type is a struct + // self.collect_intrinsic(&intrinsics::new); + self.collect_intrinsic(&intrinsics::clone); + // self.collect_intrinsic(&intrinsics::drop); if hir_struct.data(self.db).memory_kind == hir::StructMemoryKind::GC { self.collect_intrinsic(&intrinsics::malloc) } @@ -205,6 +211,9 @@ impl<'a, D: IrDatabase> DispatchTableBuilder<'a, D> { .expect("unknown path"); if let hir::Resolution::Def(hir::ModuleDef::Struct(s)) = resolution { + // self.collect_intrinsic(&intrinsics::new); + self.collect_intrinsic(&intrinsics::clone); + // self.collect_intrinsic(&intrinsics::drop); if s.data(self.db).memory_kind == hir::StructMemoryKind::GC { self.collect_intrinsic(&intrinsics::malloc) } diff --git a/crates/mun_codegen/src/snapshots/test__field_crash.snap b/crates/mun_codegen/src/snapshots/test__field_crash.snap index fa9b982cc..5cf383f7e 100644 --- a/crates/mun_codegen/src/snapshots/test__field_crash.snap +++ b/crates/mun_codegen/src/snapshots/test__field_crash.snap @@ -5,7 +5,7 @@ expression: "struct(gc) Foo { a: int };\n\nfn main(c:int):int {\n let b = Foo ; ModuleID = 'main.mun' source_filename = "main.mun" -%DispatchTable = type { i8* (i64, i64)* } +%DispatchTable = type { i8* (i8 addrspace(4)*, i8 addrspace(4)*)*, i8* (i64, i64)* } %Foo = type { i64 } @dispatchTable = global %DispatchTable zeroinitializer @@ -18,7 +18,7 @@ body: %c1 = load i64, i64* %c %add = add i64 %c1, 5 %init = insertvalue %Foo undef, i64 %add, 0 - %malloc_ptr = load i8* (i64, i64)*, i8* (i64, i64)** getelementptr inbounds (%DispatchTable, %DispatchTable* @dispatchTable, i32 0, i32 0) + %malloc_ptr = load i8* (i64, i64)*, i8* (i64, i64)** getelementptr inbounds (%DispatchTable, %DispatchTable* @dispatchTable, i32 0, i32 1) %malloc = call i8* %malloc_ptr(i64 ptrtoint (i64* getelementptr (i64, i64* null, i32 1) to i64), i64 ptrtoint (i64* getelementptr ({ i1, i64 }, { i1, i64 }* null, i64 0, i32 1) to i64)) %Foo = bitcast i8* %malloc to %Foo* store %Foo %init, %Foo* %Foo diff --git a/crates/mun_codegen/src/snapshots/test__field_expr.snap b/crates/mun_codegen/src/snapshots/test__field_expr.snap index 8668bcf45..9f8ea630d 100644 --- a/crates/mun_codegen/src/snapshots/test__field_expr.snap +++ b/crates/mun_codegen/src/snapshots/test__field_expr.snap @@ -1,15 +1,15 @@ --- source: crates/mun_codegen/src/test.rs -expression: "struct Bar(float, Foo);\nstruct Foo { a: int };\n\nfn bar_0(bar: Bar): float {\n bar.0\n}\n\nfn bar_1(bar: Bar): Foo {\n bar.1\n}\n\nfn bar_1_a(bar: Bar): int {\n bar.1.a\n}\n\nfn foo_a(foo: Foo): int {\n foo.a\n}\n\nfn bar_1_foo_a(bar: Bar): int {\n foo_a(bar_1(bar))\n}\n\nfn main(): int {\n let a: Foo = Foo { a: 5 };\n let b: Bar = Bar(1.23, a);\n let aa_lhs = a.a + 2;\n let aa_rhs = 2 + a.a;\n aa_lhs + aa_rhs\n}" +expression: "struct(value) Bar(float, Foo);\nstruct(value) Foo { a: int };\n\nfn bar_0(bar: Bar): float {\n bar.0\n}\n\nfn bar_1(bar: Bar): Foo {\n bar.1\n}\n\nfn bar_1_a(bar: Bar): int {\n bar.1.a\n}\n\nfn foo_a(foo: Foo): int {\n foo.a\n}\n\nfn bar_1_foo_a(bar: Bar): int {\n foo_a(bar_1(bar))\n}\n\nfn main(): int {\n let a: Foo = Foo { a: 5 };\n let b: Bar = Bar(1.23, a);\n let aa_lhs = a.a + 2;\n let aa_rhs = 2 + a.a;\n aa_lhs + aa_rhs\n}" --- ; ModuleID = 'main.mun' source_filename = "main.mun" -%DispatchTable = type { i64 (%Foo)*, %Foo (%Bar)* } +%DispatchTable = type { i64 (%Foo)*, %Foo (%Bar)*, i8* (i8 addrspace(4)*, i8 addrspace(4)*)* } %Foo = type { i64 } %Bar = type { double, %Foo } -@dispatchTable = global %DispatchTable { i64 (%Foo)* @foo_a, %Foo (%Bar)* @bar_1 } +@dispatchTable = global %DispatchTable { i64 (%Foo)* @foo_a, %Foo (%Bar)* @bar_1, i8* (i8 addrspace(4)*, i8 addrspace(4)*)* null } define double @bar_0(%Bar) { body: diff --git a/crates/mun_codegen/src/snapshots/test__gc_struct.snap b/crates/mun_codegen/src/snapshots/test__gc_struct.snap index 20846593c..0806e5f06 100644 --- a/crates/mun_codegen/src/snapshots/test__gc_struct.snap +++ b/crates/mun_codegen/src/snapshots/test__gc_struct.snap @@ -5,7 +5,7 @@ expression: "struct(gc) Foo { a: int, b: int };\n\nfn foo() {\n let a = Foo { ; ModuleID = 'main.mun' source_filename = "main.mun" -%DispatchTable = type { i8* (i64, i64)* } +%DispatchTable = type { i8* (i8 addrspace(4)*, i8 addrspace(4)*)*, i8* (i64, i64)* } %Foo = type { i64, i64 } @dispatchTable = global %DispatchTable zeroinitializer @@ -14,7 +14,7 @@ define void @foo() { body: %b4 = alloca %Foo* %a = alloca %Foo* - %malloc_ptr = load i8* (i64, i64)*, i8* (i64, i64)** getelementptr inbounds (%DispatchTable, %DispatchTable* @dispatchTable, i32 0, i32 0) + %malloc_ptr = load i8* (i64, i64)*, i8* (i64, i64)** getelementptr inbounds (%DispatchTable, %DispatchTable* @dispatchTable, i32 0, i32 1) %malloc = call i8* %malloc_ptr(i64 mul nuw (i64 ptrtoint (i64* getelementptr (i64, i64* null, i32 1) to i64), i64 2), i64 ptrtoint (i64* getelementptr ({ i1, i64 }, { i1, i64 }* null, i64 0, i32 1) to i64)) %Foo = bitcast i8* %malloc to %Foo* store %Foo { i64 3, i64 4 }, %Foo* %Foo diff --git a/crates/mun_codegen/src/snapshots/test__struct_test.snap b/crates/mun_codegen/src/snapshots/test__struct_test.snap index ec0552775..7c12ba243 100644 --- a/crates/mun_codegen/src/snapshots/test__struct_test.snap +++ b/crates/mun_codegen/src/snapshots/test__struct_test.snap @@ -5,10 +5,13 @@ expression: "struct(value) Bar(float, int, bool, Foo);\nstruct(value) Foo { a: i ; ModuleID = 'main.mun' source_filename = "main.mun" +%DispatchTable = type { i8* (i8 addrspace(4)*, i8 addrspace(4)*)* } %Baz = type {} %Bar = type { double, i64, i1, %Foo } %Foo = type { i64 } +@dispatchTable = global %DispatchTable zeroinitializer + define void @foo() { body: %c = alloca %Baz diff --git a/crates/mun_codegen/src/type_info.rs b/crates/mun_codegen/src/type_info.rs index b56db2521..f088559d2 100644 --- a/crates/mun_codegen/src/type_info.rs +++ b/crates/mun_codegen/src/type_info.rs @@ -116,6 +116,13 @@ impl HasStaticTypeInfo for *const T { } } +// HACK: Manually add `*const TypeInfo` +impl HasStaticTypeInfo for *const TypeInfo { + fn type_info() -> TypeInfo { + TypeInfo::new("*const TypeInfo", TypeGroup::FundamentalTypes) + } +} + /// A trait that statically defines that a type can be used as a return type for a function. pub trait HasStaticReturnTypeInfo { fn return_type_info() -> Option; diff --git a/crates/mun_runtime/examples/hot_reloading.rs b/crates/mun_runtime/examples/hot_reloading.rs index 7f544d14e..cc0c63e7a 100644 --- a/crates/mun_runtime/examples/hot_reloading.rs +++ b/crates/mun_runtime/examples/hot_reloading.rs @@ -1,5 +1,7 @@ use mun_runtime::{invoke_fn, RetryResultExt, RuntimeBuilder}; +use std::cell::RefCell; use std::env; +use std::rc::Rc; // How to run? // 1. On the CLI, navigate to the `crates/mun_runtime/examples` directory. @@ -9,14 +11,15 @@ fn main() { let lib_dir = env::args().nth(1).expect("Expected path to a Mun library."); println!("lib: {}", lib_dir); - let mut runtime = RuntimeBuilder::new(lib_dir) + let runtime = RuntimeBuilder::new(lib_dir) .spawn() .expect("Failed to spawn Runtime"); + let runtime = Rc::new(RefCell::new(runtime)); loop { let n: i64 = invoke_fn!(runtime, "nth").wait(); let result: i64 = invoke_fn!(runtime, "fibonacci", n).wait(); println!("fibonacci({}) = {}", n, result); - runtime.update(); + runtime.borrow_mut().update(); } } diff --git a/crates/mun_runtime/src/lib.rs b/crates/mun_runtime/src/lib.rs index 14bea37ac..564909f93 100644 --- a/crates/mun_runtime/src/lib.rs +++ b/crates/mun_runtime/src/lib.rs @@ -19,19 +19,20 @@ use std::alloc::Layout; use std::collections::HashMap; use std::io; use std::path::{Path, PathBuf}; +use std::ptr; use std::sync::mpsc::{channel, Receiver}; use std::time::Duration; -use abi::{FunctionInfo, Privacy}; +use abi::{FunctionInfo, Privacy, TypeInfo}; use failure::Error; use function::FunctionInfoStorage; use notify::{DebouncedEvent, RecommendedWatcher, RecursiveMode, Watcher}; -pub use crate::marshal::MarshalInto; +pub use crate::marshal::Marshal; pub use crate::reflection::{ArgumentReflection, ReturnTypeReflection}; pub use crate::assembly::Assembly; -pub use crate::r#struct::Struct; +pub use crate::r#struct::StructRef; /// Options for the construction of a [`Runtime`]. #[derive(Clone, Debug)] @@ -116,6 +117,18 @@ extern "C" fn malloc(size: u64, alignment: u64) -> *mut u8 { } } +extern "C" fn clone(src: *const u8, ty: *const TypeInfo) -> *mut u8 { + let type_info = unsafe { ty.as_ref().unwrap() }; + let struct_info = type_info.as_struct().unwrap(); + let size = struct_info.field_offsets().last().cloned().unwrap_or(0) + + struct_info.field_sizes().last().cloned().unwrap_or(0); + let alignment = 8; + + let dest = malloc(size as u64, alignment); + unsafe { ptr::copy_nonoverlapping(src, dest, size as usize) }; + dest +} + impl Runtime { /// Constructs a new `Runtime` that loads the library at `library_path` and its /// dependencies. The `Runtime` contains a file watcher that is triggered with an interval @@ -131,8 +144,17 @@ impl Runtime { malloc as *const std::ffi::c_void, ); + let (clone_info, clone_storage) = FunctionInfoStorage::new_function( + "clone", + &["*const core::u8".to_string(), "*const TypeInfo".to_string()], + Some("*mut core::u8".to_string()), + Privacy::Public, + clone as *const std::ffi::c_void, + ); + let mut dispatch_table = DispatchTable::default(); dispatch_table.insert_fn("malloc", malloc_info); + dispatch_table.insert_fn("clone", clone_info); let watcher: RecommendedWatcher = Watcher::new(tx, options.delay)?; let mut runtime = Runtime { @@ -141,7 +163,7 @@ impl Runtime { watcher, watcher_rx: rx, - _local_fn_storage: vec![malloc_storage], + _local_fn_storage: vec![malloc_storage, clone_storage], }; runtime.add_assembly(&options.library_path)?; diff --git a/crates/mun_runtime/src/macros.rs b/crates/mun_runtime/src/macros.rs index d286dfa37..41d9c5deb 100644 --- a/crates/mun_runtime/src/macros.rs +++ b/crates/mun_runtime/src/macros.rs @@ -16,36 +16,36 @@ macro_rules! invoke_fn_impl { /// An invocation error that contains the function name, a mutable reference to the /// runtime, passed arguments, and the output type. This allows the caller to retry /// the function invocation using the `Retriable` trait. - pub struct $ErrName<'r, 's, $($T: ArgumentReflection,)* Output:ReturnTypeReflection> { + pub struct $ErrName<'s, $($T: ArgumentReflection,)* Output: ReturnTypeReflection> { msg: String, - runtime: &'r mut Runtime, + runtime: std::rc::Rc>, function_name: &'s str, $($Arg: $T,)* output: core::marker::PhantomData, } - impl<'r, 's, $($T: ArgumentReflection,)* Output: ReturnTypeReflection> core::fmt::Debug for $ErrName<'r, 's, $($T,)* Output> { + impl<'s, $($T: ArgumentReflection,)* Output: ReturnTypeReflection> core::fmt::Debug for $ErrName<'s, $($T,)* Output> { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { write!(f, "{}", &self.msg) } } - impl<'r, 's, $($T: ArgumentReflection,)* Output: ReturnTypeReflection> core::fmt::Display for $ErrName<'r, 's, $($T,)* Output> { + impl<'s, $($T: ArgumentReflection,)* Output: ReturnTypeReflection> core::fmt::Display for $ErrName<'s, $($T,)* Output> { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { write!(f, "{}", &self.msg) } } - impl<'r, 's, $($T: ArgumentReflection,)* Output: ReturnTypeReflection> std::error::Error for $ErrName<'r, 's, $($T,)* Output> { + impl<'s, $($T: ArgumentReflection,)* Output: ReturnTypeReflection> std::error::Error for $ErrName<'s, $($T,)* Output> { fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { None } } - impl<'r, 's, $($T: ArgumentReflection,)* Output: ReturnTypeReflection> $ErrName<'r, 's, $($T,)* Output> { + impl<'s, $($T: ArgumentReflection,)* Output: ReturnTypeReflection> $ErrName<'s, $($T,)* Output> { /// Constructs a new invocation error. #[allow(clippy::too_many_arguments)] - pub fn new(err_msg: String, runtime: &'r mut Runtime, function_name: &'s str, $($Arg: $T),*) -> Self { + pub fn new(err_msg: String, runtime: std::rc::Rc>, function_name: &'s str, $($Arg: $T),*) -> Self { Self { msg: err_msg, runtime, @@ -56,7 +56,7 @@ macro_rules! invoke_fn_impl { } } - impl<'r, 's, $($T: ArgumentReflection,)* Output: ReturnTypeReflection> $crate::RetryResultExt for core::result::Result> { + impl<'s, $($T: ArgumentReflection,)* Output: ReturnTypeReflection> $crate::RetryResultExt for core::result::Result> { type Output = Output; fn retry(self) -> Self { @@ -64,10 +64,10 @@ macro_rules! invoke_fn_impl { Ok(output) => Ok(output), Err(err) => { eprintln!("{}", err.msg); - while !err.runtime.update() { + while !err.runtime.borrow_mut().update() { // Wait until there has been an update that might fix the error } - $crate::Runtime::$FnName(err.runtime, err.function_name, $(err.$Arg,)*) + $crate::Runtime::$FnName(&err.runtime, err.function_name, $(err.$Arg,)*) } } } @@ -76,8 +76,9 @@ macro_rules! invoke_fn_impl { loop { if let Ok(output) = self { return output; + } else { + self = self.retry(); } - self = self.retry(); } } } @@ -89,12 +90,13 @@ macro_rules! invoke_fn_impl { /// If an error occurs when invoking the method, an error message is logged. The /// runtime continues looping until the cause of the error has been resolved. #[allow(clippy::too_many_arguments, unused_assignments)] - pub fn $FnName<'r, 's, $($T: ArgumentReflection,)* Output: ReturnTypeReflection>( - runtime: &'r mut Runtime, + pub fn $FnName<'s, $($T: ArgumentReflection,)* Output: ReturnTypeReflection>( + runtime: &std::rc::Rc>, function_name: &'s str, $($Arg: $T,)* - ) -> core::result::Result> { + ) -> core::result::Result> { match runtime + .borrow() .get_function_info(function_name) .ok_or(format!("Failed to obtain function '{}'", function_name)) .and_then(|function_info| { @@ -148,9 +150,9 @@ macro_rules! invoke_fn_impl { let result = function($($Arg.marshal()),*); // Marshall the result - Ok(result.marshal_into(function_info.signature.return_type())) + return Ok(result.marshal_value(runtime.clone(), function_info.signature.return_type())) } - Err(e) => Err($ErrName::new(e, runtime, function_name, $($Arg),*)) + Err(e) => Err($ErrName::new(e, runtime.clone(), function_name, $($Arg),*)) } } } @@ -169,59 +171,46 @@ macro_rules! invoke_fn_impl { #[macro_export] macro_rules! invoke_fn { ($Runtime:expr, $FnName:expr) => { - $crate::Runtime::invoke_fn0(&mut $Runtime, $FnName) + $crate::Runtime::invoke_fn0(&$Runtime, $FnName) }; ($Runtime:expr, $FnName:expr, $A:expr) => { - $crate::Runtime::invoke_fn1(&mut $Runtime, $FnName, $A) + $crate::Runtime::invoke_fn1(&$Runtime, $FnName, $A) }; ($Runtime:expr, $FnName:expr, $A:expr, $B:expr) => { - $crate::Runtime::invoke_fn2(&mut $Runtime, $FnName, $A, $B) + $crate::Runtime::invoke_fn2(&$Runtime, $FnName, $A, $B) }; ($Runtime:expr, $FnName:expr, $A:expr, $B:expr, $C:expr) => { - $crate::Runtime::invoke_fn3(&mut $Runtime, $FnName, $A, $B, $C) + $crate::Runtime::invoke_fn3(&$Runtime, $FnName, $A, $B, $C) }; ($Runtime:expr, $FnName:expr, $A:expr, $B:expr, $C:expr, $D:expr) => { - $crate::Runtime::invoke_fn4(&mut $Runtime, $FnName, $A, $B, $C, $D) + $crate::Runtime::invoke_fn4(&$Runtime, $FnName, $A, $B, $C, $D) }; ($Runtime:expr, $FnName:expr, $A:expr, $B:expr, $C:expr, $D:expr, $E:expr) => { - $crate::Runtime::invoke_fn5(&mut $Runtime, $FnName, $A, $B, $C, $D, $E) + $crate::Runtime::invoke_fn5(&$Runtime, $FnName, $A, $B, $C, $D, $E) }; ($Runtime:expr, $FnName:expr, $A:expr, $B:expr, $C:expr, $D:expr, $E:expr, $F:expr) => { - $crate::Runtime::invoke_fn6(&mut $Runtime, $FnName, $A, $B, $C, $D, $E, $F) + $crate::Runtime::invoke_fn6(&$Runtime, $FnName, $A, $B, $C, $D, $E, $F) }; ($Runtime:expr, $FnName:expr, $A:expr, $B:expr, $C:expr, $D:expr, $E:expr, $F:expr, $G:expr) => { - $crate::Runtime::invoke_fn7(&mut $Runtime, $FnName, $A, $B, $C, $D, $E, $F, $G) + $crate::Runtime::invoke_fn7(&$Runtime, $FnName, $A, $B, $C, $D, $E, $F, $G) }; ($Runtime:expr, $FnName:expr, $A:expr, $B:expr, $C:expr, $D:expr, $E:expr, $F:expr, $G:expr, $H:expr) => { - $crate::Runtime::invoke_fn8(&mut $Runtime, $FnName, $A, $B, $C, $D, $E, $F, $G, $H) + $crate::Runtime::invoke_fn8(&$Runtime, $FnName, $A, $B, $C, $D, $E, $F, $G, $H) }; ($Runtime:expr, $FnName:expr, $A:expr, $B:expr, $C:expr, $D:expr, $E:expr, $F:expr, $G:expr, $H:expr, $I:expr) => { - $crate::Runtime::invoke_fn9(&mut $Runtime, $FnName, $A, $B, $C, $D, $E, $F, $G, $H, $I) + $crate::Runtime::invoke_fn9(&$Runtime, $FnName, $A, $B, $C, $D, $E, $F, $G, $H, $I) }; ($Runtime:expr, $FnName:expr, $A:expr, $B:expr, $C:expr, $D:expr, $E:expr, $F:expr, $G:expr, $H:expr, $I:expr, $J:expr) => { - $crate::Runtime::invoke_fn10( - &mut $Runtime, - $FnName, - $A, - $B, - $C, - $D, - $E, - $F, - $G, - $H, - $I, - $J, - ) + $crate::Runtime::invoke_fn10(&$Runtime, $FnName, $A, $B, $C, $D, $E, $F, $G, $H, $I, $J) }; ($Runtime:expr, $FnName:expr, $A:expr, $B:expr, $C:expr, $D:expr, $E:expr, $F:expr, $G:expr, $H:expr, $I:expr, $J:expr, $K:expr) => { $crate::Runtime::invoke_fn11( - $Runtime, $FnName, $A, $B, $C, $D, $E, $F, $G, $H, $I, $J, $K, + &$Runtime, $FnName, $A, $B, $C, $D, $E, $F, $G, $H, $I, $J, $K, ) }; ($Runtime:expr, $FnName:expr, $A:expr, $B:expr, $C:expr, $D:expr, $E:expr, $F:expr, $G:expr, $H:expr, $I:expr, $J:expr, $K:expr, $L:expr) => { $crate::Runtime::invoke_fn12( - $Runtime, $FnName, $A, $B, $C, $D, $E, $F, $G, $H, $I, $J, $K, $L, + &$Runtime, $FnName, $A, $B, $C, $D, $E, $F, $G, $H, $I, $J, $K, $L, ) }; } diff --git a/crates/mun_runtime/src/marshal.rs b/crates/mun_runtime/src/marshal.rs index 889fa42cd..15b2f5c5f 100644 --- a/crates/mun_runtime/src/marshal.rs +++ b/crates/mun_runtime/src/marshal.rs @@ -1,16 +1,44 @@ +use crate::Runtime; use abi::TypeInfo; +use std::cell::RefCell; +use std::ptr::NonNull; +use std::rc::Rc; /// Used to do value-to-value conversions that require runtime type information while consuming the /// input value. /// /// If no `TypeInfo` is provided, the type is `()`. -pub trait MarshalInto: Sized { - /// Performs the conversion. - fn marshal_into(self, type_info: Option<&TypeInfo>) -> T; +pub trait Marshal: Sized { + /// Marshals itself into a `T`. + fn marshal_value(self, runtime: Rc>, type_info: Option<&TypeInfo>) -> T; + + /// Marshals the value at memory location `ptr` into a `T`. + fn marshal_from_ptr( + ptr: NonNull, + runtime: Rc>, + type_info: Option<&TypeInfo>, + ) -> T; + + /// Marshals `value` to memory location `ptr`. + fn marshal_to_ptr(value: Self, ptr: NonNull, type_info: Option<&TypeInfo>); } -impl MarshalInto for T { - fn marshal_into(self, _type_info: Option<&TypeInfo>) -> T { +impl Marshal for T { + fn marshal_value(self, _runtime: Rc>, _type_info: Option<&TypeInfo>) -> T { self } + + fn marshal_from_ptr<'r>( + ptr: NonNull, + _runtime: Rc>, + _type_info: Option<&TypeInfo>, + ) -> T { + // TODO: Avoid unsafe `read` fn by using adding `Clone` trait to T. + // This also requires changes to the `impl Struct` + unsafe { ptr.as_ptr().read() } + } + + fn marshal_to_ptr(value: T, mut ptr: NonNull, _type_info: Option<&TypeInfo>) { + unsafe { *ptr.as_mut() = value }; + } } diff --git a/crates/mun_runtime/src/reflection.rs b/crates/mun_runtime/src/reflection.rs index b5d1a8e25..1402c819a 100644 --- a/crates/mun_runtime/src/reflection.rs +++ b/crates/mun_runtime/src/reflection.rs @@ -1,4 +1,4 @@ -use crate::{marshal::MarshalInto, Struct}; +use crate::{marshal::Marshal, StructRef}; use abi::{Guid, TypeInfo}; use md5; @@ -25,7 +25,7 @@ pub fn equals_return_type( } } abi::TypeGroup::StructTypes => { - if ::type_guid() != T::type_guid() { + if ::type_guid() != T::type_guid() { return Err(("struct", T::type_name())); } } @@ -34,9 +34,9 @@ pub fn equals_return_type( } /// A type to emulate dynamic typing across compilation units for static types. -pub trait ReturnTypeReflection: Sized + 'static { +pub trait ReturnTypeReflection: Sized { /// The resulting type after marshaling. - type Marshalled: MarshalInto; + type Marshalled: Marshal; /// Retrieves the type's `Guid`. fn type_guid() -> Guid { @@ -52,7 +52,7 @@ pub trait ReturnTypeReflection: Sized + 'static { /// A type to emulate dynamic typing across compilation units for statically typed values. pub trait ArgumentReflection: Sized { /// The resulting type after dereferencing. - type Marshalled: MarshalInto; + type Marshalled: Marshal; /// Retrieves the `Guid` of the value's type. fn type_guid(&self) -> Guid { @@ -116,6 +116,42 @@ impl ArgumentReflection for () { } } +impl ArgumentReflection for *const u8 { + type Marshalled = Self; + + fn type_name(&self) -> &str { + ::type_name() + } + + fn marshal(self) -> Self::Marshalled { + self + } +} + +impl ArgumentReflection for *mut u8 { + type Marshalled = Self; + + fn type_name(&self) -> &str { + ::type_name() + } + + fn marshal(self) -> Self::Marshalled { + self + } +} + +impl ArgumentReflection for *const TypeInfo { + type Marshalled = Self; + + fn type_name(&self) -> &str { + "*const TypeInfo" + } + + fn marshal(self) -> Self::Marshalled { + self + } +} + impl ReturnTypeReflection for f64 { type Marshalled = f64; @@ -147,3 +183,19 @@ impl ReturnTypeReflection for () { "core::empty" } } + +impl ReturnTypeReflection for *const u8 { + type Marshalled = Self; + + fn type_name() -> &'static str { + "*const core::u8" + } +} + +impl ReturnTypeReflection for *mut u8 { + type Marshalled = Self; + + fn type_name() -> &'static str { + "*mut core::u8" + } +} diff --git a/crates/mun_runtime/src/struct.rs b/crates/mun_runtime/src/struct.rs index 33fafb6b2..dbdfd1fe8 100644 --- a/crates/mun_runtime/src/struct.rs +++ b/crates/mun_runtime/src/struct.rs @@ -1,11 +1,14 @@ use crate::{ - marshal::MarshalInto, + marshal::Marshal, reflection::{ equals_argument_type, equals_return_type, ArgumentReflection, ReturnTypeReflection, }, + Runtime, }; -use abi::{StructInfo, TypeInfo}; -use std::mem; +use abi::{StructInfo, StructMemoryKind, TypeInfo}; +use std::cell::RefCell; +use std::ptr::{self, NonNull}; +use std::rc::Rc; /// Represents a Mun struct pointer. /// @@ -16,20 +19,21 @@ pub struct RawStruct(*mut u8); /// Type-agnostic wrapper for interoperability with a Mun struct. /// TODO: Handle destruction of `struct(value)` -#[derive(Clone)] -pub struct Struct { +pub struct StructRef { + runtime: Rc>, raw: RawStruct, info: StructInfo, } -impl Struct { +impl StructRef { /// Creates a struct that wraps a raw Mun struct. /// /// The provided [`TypeInfo`] must be for a struct type. - fn new(type_info: &TypeInfo, raw: RawStruct) -> Self { + fn new(runtime: Rc>, type_info: &TypeInfo, raw: RawStruct) -> StructRef { assert!(type_info.group.is_struct()); Self { + runtime, raw, info: type_info.as_struct().unwrap().clone(), } @@ -40,6 +44,22 @@ impl Struct { self.raw } + /// Retrieves its struct information. + pub fn info(&self) -> &StructInfo { + &self.info + } + + /// + /// + /// # Safety + /// + /// + unsafe fn offset_unchecked(&self, field_idx: usize) -> NonNull { + let offset = *self.info.field_offsets().get_unchecked(field_idx); + // self.raw is never null + NonNull::new_unchecked(self.raw.0.add(offset as usize)).cast::() + } + /// Retrieves the value of the field corresponding to the specified `field_name`. pub fn get(&self, field_name: &str) -> Result { let field_idx = StructInfo::find_field_index(&self.info, field_name)?; @@ -54,20 +74,13 @@ impl Struct { ) })?; - let field_value = unsafe { - // If we found the `field_idx`, we are guaranteed to also have the `field_offset` - let offset = *self.info.field_offsets().get_unchecked(field_idx); - // self.ptr is never null - // TODO: The unsafe `read` fn could be avoided by adding the `Clone` bound on - // `T::Marshalled`, but its only available on nightly: - // `ReturnTypeReflection` - self.raw - .0 - .add(offset as usize) - .cast::() - .read() - }; - Ok(field_value.marshal_into(Some(*field_type))) + // If we found the `field_idx`, we are guaranteed to also have the `field_offset` + let field_ptr = unsafe { self.offset_unchecked::(field_idx) }; + Ok(Marshal::marshal_from_ptr( + field_ptr, + self.runtime.clone(), + Some(*field_type), + )) } /// Replaces the value of the field corresponding to the specified `field_name` and returns the @@ -89,15 +102,10 @@ impl Struct { ) })?; - let mut marshalled: T::Marshalled = value.marshal(); - let ptr = unsafe { - // If we found the `field_idx`, we are guaranteed to also have the `field_offset` - let offset = *self.info.field_offsets().get_unchecked(field_idx); - // self.ptr is never null - &mut *self.raw.0.add(offset as usize).cast::() - }; - mem::swap(&mut marshalled, ptr); - Ok(marshalled.marshal_into(Some(*field_type))) + let field_ptr = unsafe { self.offset_unchecked::(field_idx) }; + let old = Marshal::marshal_from_ptr(field_ptr, self.runtime.clone(), Some(*field_type)); + Marshal::marshal_to_ptr(value.marshal(), field_ptr, Some(*field_type)); + Ok(old) } /// Sets the value of the field corresponding to the specified `field_name`. @@ -114,17 +122,13 @@ impl Struct { ) })?; - unsafe { - // If we found the `field_idx`, we are guaranteed to also have the `field_offset` - let offset = *self.info.field_offsets().get_unchecked(field_idx); - // self.ptr is never null - *self.raw.0.add(offset as usize).cast::() = value.marshal(); - } + let field_ptr = unsafe { self.offset_unchecked::(field_idx) }; + Marshal::marshal_to_ptr(value.marshal(), field_ptr, Some(*field_type)); Ok(()) } } -impl ArgumentReflection for Struct { +impl ArgumentReflection for StructRef { type Marshalled = RawStruct; fn type_name(&self) -> &str { @@ -136,7 +140,7 @@ impl ArgumentReflection for Struct { } } -impl ReturnTypeReflection for Struct { +impl ReturnTypeReflection for StructRef { type Marshalled = RawStruct; fn type_name() -> &'static str { @@ -144,9 +148,48 @@ impl ReturnTypeReflection for Struct { } } -impl MarshalInto for RawStruct { - fn marshal_into(self, type_info: Option<&TypeInfo>) -> Struct { +impl Marshal for RawStruct { + fn marshal_value( + self, + runtime: Rc>, + type_info: Option<&TypeInfo>, + ) -> StructRef { + // `type_info` is only `None` for the `()` type + StructRef::new(runtime, type_info.unwrap(), self) + } + + fn marshal_from_ptr( + ptr: NonNull, + runtime: Rc>, + type_info: Option<&TypeInfo>, + ) -> StructRef { // `type_info` is only `None` for the `()` type - Struct::new(type_info.unwrap(), self) + let type_info = type_info.unwrap(); + + let struct_info = type_info.as_struct().unwrap(); + let ptr = if struct_info.memory_kind == StructMemoryKind::Value { + ptr.cast::().as_ptr() as *const _ + } else { + unsafe { ptr.as_ref() }.0 as *const _ + }; + + // Clone the struct using the runtime's intrinsic + let cloned_ptr = invoke_fn!(runtime.clone(), "clone", ptr, type_info as *const _).unwrap(); + StructRef::new(runtime, type_info, RawStruct(cloned_ptr)) + } + + fn marshal_to_ptr(value: RawStruct, mut ptr: NonNull, type_info: Option<&TypeInfo>) { + // `type_info` is only `None` for the `()` type + let type_info = type_info.unwrap(); + + let struct_info = type_info.as_struct().unwrap(); + if struct_info.memory_kind == StructMemoryKind::Value { + let dest = ptr.cast::().as_ptr(); + let size = struct_info.field_offsets().last().cloned().unwrap_or(0) + + struct_info.field_sizes().last().cloned().unwrap_or(0); + unsafe { ptr::copy_nonoverlapping(value.0, dest, size as usize) }; + } else { + unsafe { *ptr.as_mut() = value }; + } } } diff --git a/crates/mun_runtime/src/test.rs b/crates/mun_runtime/src/test.rs index 2366d1f11..28218a5c1 100644 --- a/crates/mun_runtime/src/test.rs +++ b/crates/mun_runtime/src/test.rs @@ -1,6 +1,8 @@ -use crate::{Runtime, RuntimeBuilder, Struct}; +use crate::{ArgumentReflection, ReturnTypeReflection, Runtime, RuntimeBuilder, StructRef}; use mun_compiler::{ColorChoice, Config, Driver, FileId, PathOrInline, RelativePathBuf}; +use std::cell::RefCell; use std::path::PathBuf; +use std::rc::Rc; use std::thread::sleep; use std::time::Duration; @@ -11,7 +13,7 @@ struct TestDriver { out_path: PathBuf, file_id: FileId, driver: Driver, - runtime: Runtime, + runtime: Rc>, } impl TestDriver { @@ -38,7 +40,7 @@ impl TestDriver { driver, out_path, file_id, - runtime, + runtime: Rc::new(RefCell::new(runtime)), } } @@ -51,7 +53,7 @@ impl TestDriver { "recompiling did not result in the same assembly" ); let start_time = std::time::Instant::now(); - while !self.runtime.update() { + while !self.runtime.borrow_mut().update() { let now = std::time::Instant::now(); if now - start_time > std::time::Duration::from_secs(10) { panic!("runtime did not update after recompilation within 10secs"); @@ -60,23 +62,18 @@ impl TestDriver { } } } - - /// Returns the `Runtime` used by this instance - fn runtime_mut(&mut self) -> &mut Runtime { - &mut self.runtime - } } macro_rules! assert_invoke_eq { ($ExpectedType:ty, $ExpectedResult:expr, $Driver:expr, $($Arg:tt)+) => { - let result: $ExpectedType = invoke_fn!($Driver.runtime_mut(), $($Arg)*).unwrap(); + let result: $ExpectedType = invoke_fn!($Driver.runtime, $($Arg)*).unwrap(); assert_eq!(result, $ExpectedResult, "{} == {:?}", stringify!(invoke_fn!($Driver.runtime_mut(), $($Arg)*).unwrap()), $ExpectedResult); } } #[test] fn compile_and_run() { - let mut driver = TestDriver::new( + let driver = TestDriver::new( r" pub fn main() {} ", @@ -86,7 +83,7 @@ fn compile_and_run() { #[test] fn return_value() { - let mut driver = TestDriver::new( + let driver = TestDriver::new( r" pub fn main():int { 3 } ", @@ -96,7 +93,7 @@ fn return_value() { #[test] fn arguments() { - let mut driver = TestDriver::new( + let driver = TestDriver::new( r" pub fn main(a:int, b:int):int { a+b } ", @@ -108,7 +105,7 @@ fn arguments() { #[test] fn dispatch_table() { - let mut driver = TestDriver::new( + let driver = TestDriver::new( r" pub fn add(a:int, b:int):int { a+b } pub fn main(a:int, b:int):int { add(a,b) } @@ -126,7 +123,7 @@ fn dispatch_table() { #[test] fn booleans() { - let mut driver = TestDriver::new( + let driver = TestDriver::new( r#" pub fn equal(a:int, b:int):bool { a==b } pub fn equalf(a:float, b:float):bool { a==b } @@ -170,7 +167,7 @@ fn booleans() { #[test] fn fibonacci() { - let mut driver = TestDriver::new( + let driver = TestDriver::new( r#" pub fn fibonacci(n:int):int { if n <= 1 { @@ -189,7 +186,7 @@ fn fibonacci() { #[test] fn fibonacci_loop() { - let mut driver = TestDriver::new( + let driver = TestDriver::new( r#" pub fn fibonacci(n:int):int { let a = 0; @@ -216,7 +213,7 @@ fn fibonacci_loop() { #[test] fn fibonacci_loop_break() { - let mut driver = TestDriver::new( + let driver = TestDriver::new( r#" pub fn fibonacci(n:int):int { let a = 0; @@ -243,7 +240,7 @@ fn fibonacci_loop_break() { #[test] fn fibonacci_while() { - let mut driver = TestDriver::new( + let driver = TestDriver::new( r#" pub fn fibonacci(n:int):int { let a = 0; @@ -268,7 +265,7 @@ fn fibonacci_while() { #[test] fn true_is_true() { - let mut driver = TestDriver::new( + let driver = TestDriver::new( r#" pub fn test_true():bool { true @@ -314,7 +311,8 @@ fn compiler_valid_utf8() { "#, ); - let foo_func = driver.runtime.get_function_info("foo").unwrap(); + let borrowed = driver.runtime.borrow(); + let foo_func = borrowed.get_function_info("foo").unwrap(); assert_eq!( unsafe { CStr::from_ptr(foo_func.signature.name) } .to_str() @@ -352,7 +350,7 @@ fn compiler_valid_utf8() { #[test] fn fields() { - let mut driver = TestDriver::new( + let driver = TestDriver::new( r#" struct(gc) Foo { a:int, b:int }; pub fn main(foo:int):bool { @@ -369,7 +367,7 @@ fn fields() { #[test] fn field_crash() { - let mut driver = TestDriver::new( + let driver = TestDriver::new( r#" struct(gc) Foo { a: int }; @@ -384,79 +382,132 @@ fn field_crash() { #[test] fn marshal_struct() { - let mut driver = TestDriver::new( + let driver = TestDriver::new( r#" - struct(gc) Foo { a: int, b: bool, c: float, }; - struct Bar(Foo); + struct(value) Foo { a: int, b: bool }; + struct Bar(int, bool); + struct(value) Baz(Foo); + struct(gc) Qux(Bar); - pub fn foo_new(a: int, b: bool, c: float): Foo { - Foo { a, b, c, } + pub fn foo_new(a: int, b: bool): Foo { + Foo { a, b, } } - pub fn bar_new(foo: Foo): Bar { - Bar(foo) + pub fn bar_new(a: int, b: bool): Bar { + Bar(a, b) + } + pub fn baz_new(foo: Foo): Baz { + Baz(foo) + } + pub fn qux_new(bar: Bar): Qux { + Qux(bar) } - - pub fn foo_a(foo: Foo):int { foo.a } - pub fn foo_b(foo: Foo):bool { foo.b } - pub fn foo_c(foo: Foo):float { foo.c } "#, ); - let a = 3i64; - let b = true; - let c = 1.23f64; - let mut foo: Struct = invoke_fn!(driver.runtime, "foo_new", a, b, c).unwrap(); - assert_eq!(Ok(a), foo.get::("a")); - assert_eq!(Ok(b), foo.get::("b")); - assert_eq!(Ok(c), foo.get::("c")); - - let d = 6i64; - let e = false; - let f = 4.56f64; - foo.set("a", d).unwrap(); - foo.set("b", e).unwrap(); - foo.set("c", f).unwrap(); - - assert_eq!(Ok(d), foo.get::("a")); - assert_eq!(Ok(e), foo.get::("b")); - assert_eq!(Ok(f), foo.get::("c")); - - assert_eq!(Ok(d), foo.replace("a", a)); - assert_eq!(Ok(e), foo.replace("b", b)); - assert_eq!(Ok(f), foo.replace("c", c)); - - assert_eq!(Ok(a), foo.get::("a")); - assert_eq!(Ok(b), foo.get::("b")); - assert_eq!(Ok(c), foo.get::("c")); - - assert_invoke_eq!(i64, a, driver, "foo_a", foo.clone()); - assert_invoke_eq!(bool, b, driver, "foo_b", foo.clone()); - assert_invoke_eq!(f64, c, driver, "foo_c", foo.clone()); - - let mut bar: Struct = invoke_fn!(driver.runtime, "bar_new", foo.clone()).unwrap(); - let foo2 = bar.get::("0").unwrap(); - assert_eq!(Ok(a), foo2.get::("a")); - assert_eq!(foo2.get::("b"), foo.get::("b")); - assert_eq!(foo2.get::("c"), foo.get::("c")); + struct TestData(T, T); + + fn test_field< + T: Copy + std::fmt::Debug + PartialEq + ArgumentReflection + ReturnTypeReflection, + >( + s: &mut StructRef, + data: &TestData, + field_name: &str, + ) { + assert_eq!(Ok(data.0), s.get::(field_name)); + s.set(field_name, data.1).unwrap(); + assert_eq!(Ok(data.1), s.replace(field_name, data.0)); + assert_eq!(Ok(data.0), s.get::(field_name)); + } + + let int_data = TestData(3i64, 6i64); + let bool_data = TestData(true, false); + + // Verify that struct marshalling works for fundamental types + let mut foo: StructRef = + invoke_fn!(driver.runtime, "foo_new", int_data.0, bool_data.0).unwrap(); + test_field(&mut foo, &int_data, "a"); + test_field(&mut foo, &bool_data, "b"); + + let mut bar: StructRef = + invoke_fn!(driver.runtime, "bar_new", int_data.0, bool_data.0).unwrap(); + test_field(&mut bar, &int_data, "0"); + test_field(&mut bar, &bool_data, "1"); + + fn test_struct(s: &mut StructRef, c1: StructRef, c2: StructRef) { + let field_names: Vec = c1.info().field_names().map(|n| n.to_string()).collect(); + + let int_value = c2.get::(&field_names[0]); + let bool_value = c2.get::(&field_names[1]); + s.set("0", c2).unwrap(); + + let c2 = s.get::("0").unwrap(); + assert_eq!(c2.get::(&field_names[0]), int_value); + assert_eq!(c2.get::(&field_names[1]), bool_value); + + let int_value = c1.get::(&field_names[0]); + let bool_value = c1.get::(&field_names[1]); + s.replace("0", c1).unwrap(); + + let c1 = s.get::("0").unwrap(); + assert_eq!(c1.get::(&field_names[0]), int_value); + assert_eq!(c1.get::(&field_names[1]), bool_value); + } + + // Verify that struct marshalling works for struct types + let mut baz: StructRef = invoke_fn!(driver.runtime, "baz_new", foo).unwrap(); + let c1: StructRef = invoke_fn!(driver.runtime, "foo_new", int_data.0, bool_data.0).unwrap(); + let c2: StructRef = invoke_fn!(driver.runtime, "foo_new", int_data.1, bool_data.1).unwrap(); + test_struct(&mut baz, c1, c2); + + let mut qux: StructRef = invoke_fn!(driver.runtime, "qux_new", bar).unwrap(); + let c1: StructRef = invoke_fn!(driver.runtime, "bar_new", int_data.0, bool_data.0).unwrap(); + let c2: StructRef = invoke_fn!(driver.runtime, "bar_new", int_data.1, bool_data.1).unwrap(); + test_struct(&mut qux, c1, c2); + + fn test_shallow_copy< + T: Copy + std::fmt::Debug + PartialEq + ArgumentReflection + ReturnTypeReflection, + >( + s1: &mut StructRef, + s2: &StructRef, + data: &TestData, + field_name: &str, + ) { + assert_eq!(s1.get::(field_name), s2.get::(field_name)); + s1.set(field_name, data.1).unwrap(); + assert_ne!(s1.get::(field_name), s2.get::(field_name)); + s1.replace(field_name, data.0).unwrap(); + assert_eq!(s1.get::(field_name), s2.get::(field_name)); + } + + // Verify that StructRef::get makes a shallow copy of a struct + let mut foo = baz.get::("0").unwrap(); + let foo2 = baz.get::("0").unwrap(); + test_shallow_copy(&mut foo, &foo2, &int_data, "a"); + test_shallow_copy(&mut foo, &foo2, &bool_data, "b"); + + let mut bar = qux.get::("0").unwrap(); + let bar2 = qux.get::("0").unwrap(); + test_shallow_copy(&mut bar, &bar2, &int_data, "0"); + test_shallow_copy(&mut bar, &bar2, &bool_data, "1"); // Specify invalid return type - let bar_err = bar.get::("0"); + let bar_err = bar.get::("0"); assert!(bar_err.is_err()); // Specify invalid argument type - let bar_err = bar.replace("0", 1i64); + let bar_err = bar.replace("0", 1f64); assert!(bar_err.is_err()); // Specify invalid argument type - let bar_err = bar.set("0", 1i64); + let bar_err = bar.set("0", 1f64); assert!(bar_err.is_err()); // Specify invalid return type - let bar_err: Result = invoke_fn!(driver.runtime, "bar_new", foo); + let bar_err: Result = invoke_fn!(driver.runtime, "baz_new", foo); assert!(bar_err.is_err()); // Pass invalid struct type - let bar_err: Result = invoke_fn!(driver.runtime, "bar_new", bar); + let bar_err: Result = invoke_fn!(driver.runtime, "baz_new", bar); assert!(bar_err.is_err()); } @@ -464,7 +515,7 @@ fn marshal_struct() { fn hotreload_struct_decl() { let mut driver = TestDriver::new( r#" - struct(value) Args { + struct(gc) Args { n: int, foo: Bar, } @@ -480,7 +531,7 @@ fn hotreload_struct_decl() { ); driver.update( r#" - struct(value) Args { + struct(gc) Args { n: int, foo: Bar, }