Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: Allow extension callbacks to have non-'static lifetimes #128

Merged
merged 1 commit into from
Oct 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 71 additions & 1 deletion src/custom.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ impl<'a, H: HugrView + 'a> CodegenExtsBuilder<'a, H> {
mut self,
handler: impl 'a
+ for<'c> Fn(
&mut EmitFuncContext<'c, H>,
&mut EmitFuncContext<'c, 'a, H>,
EmitOpArgs<'c, '_, ExtensionOp, H>,
Op,
) -> Result<()>,
Expand Down Expand Up @@ -152,3 +152,73 @@ pub struct CodegenExtsMap<'a, H> {
pub extension_op_handlers: Rc<ExtensionOpMap<'a, H>>,
pub type_converter: Rc<TypeConverter<'a>>,
}

#[cfg(test)]
mod test {
use hugr::{
extension::prelude::{ConstString, PRELUDE_ID, PRINT_OP_ID, STRING_TYPE, STRING_TYPE_NAME},
Hugr,
};
use inkwell::{
context::Context,
types::BasicType,
values::{BasicMetadataValueEnum, BasicValue},
};
use itertools::Itertools as _;

use crate::{emit::libc::emit_libc_printf, CodegenExtsBuilder};

#[test]
fn types_with_lifetimes() {
let n = "name_with_lifetime".to_string();

let cem = CodegenExtsBuilder::<Hugr>::default()
.custom_type((PRELUDE_ID, STRING_TYPE_NAME), |session, _| {
let ctx = session.iw_context();
Ok(ctx
.get_struct_type(n.as_ref())
.unwrap_or_else(|| ctx.opaque_struct_type(n.as_ref()))
.as_basic_type_enum())
})
.finish();

let ctx = Context::create();

let ty = cem
.type_converter
.session(&ctx)
.llvm_type(&STRING_TYPE)
.unwrap()
.into_struct_type();
let ty_n = ty.get_name().unwrap().to_str().unwrap();
assert_eq!(ty_n, n);
}

#[test]
fn custom_const_lifetime_of_context() {
let ctx = Context::create();

let _ = CodegenExtsBuilder::<Hugr>::default()
.custom_const::<ConstString>(|_, konst| {
Ok(ctx
.const_string(konst.value().as_bytes(), true)
.as_basic_value_enum())
})
.finish();
}

#[test]
fn extension_op_lifetime() {
let ctx = Context::create();

let _ = CodegenExtsBuilder::<Hugr>::default()
.extension_op(PRELUDE_ID, PRINT_OP_ID, |context, args| {
let mut print_args: Vec<BasicMetadataValueEnum> =
vec![ctx.const_string("%s".as_bytes(), true).into()];
print_args.extend(args.inputs.into_iter().map_into::<BasicMetadataValueEnum>());
emit_libc_printf(context, &print_args)?;
args.outputs.finish(context.builder(), [])
})
.finish();
}
}
8 changes: 4 additions & 4 deletions src/custom/extension_op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,15 @@ use crate::emit::{EmitFuncContext, EmitOpArgs};
///
/// Callbacks may hold references with lifetimes older than `'a`.
pub trait ExtensionOpFn<'a, H>:
for<'c> Fn(&mut EmitFuncContext<'c, H>, EmitOpArgs<'c, '_, ExtensionOp, H>) -> Result<()> + 'a
for<'c> Fn(&mut EmitFuncContext<'c, 'a, H>, EmitOpArgs<'c, '_, ExtensionOp, H>) -> Result<()> + 'a
{
}

impl<
'a,
H,
F: for<'c> Fn(
&mut EmitFuncContext<'c, H>,
&mut EmitFuncContext<'c, 'a, H>,
EmitOpArgs<'c, '_, ExtensionOp, H>,
) -> Result<()>
+ ?Sized
Expand Down Expand Up @@ -76,7 +76,7 @@ impl<'a, H: HugrView> ExtensionOpMap<'a, H> {
&mut self,
handler: impl 'a
+ for<'c> Fn(
&mut EmitFuncContext<'c, H>,
&mut EmitFuncContext<'c, 'a, H>,
EmitOpArgs<'c, '_, ExtensionOp, H>,
Op,
) -> Result<()>,
Expand All @@ -96,7 +96,7 @@ impl<'a, H: HugrView> ExtensionOpMap<'a, H> {
/// If no handler is registered for the op an error will be returned.
pub fn emit_extension_op<'c>(
&self,
context: &mut EmitFuncContext<'c, H>,
context: &mut EmitFuncContext<'c, 'a, H>,
args: EmitOpArgs<'c, '_, ExtensionOp, H>,
) -> Result<()> {
let node = args.node();
Expand Down
8 changes: 5 additions & 3 deletions src/custom/load_constant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,17 @@ use crate::emit::EmitFuncContext;
///
/// Callbacks may hold references with lifetimes older than `'a`.
pub trait LoadConstantFn<'a, H: ?Sized, CC: CustomConst + ?Sized>:
for<'c> Fn(&mut EmitFuncContext<'c, H>, &CC) -> Result<BasicValueEnum<'c>> + 'a
for<'c> Fn(&mut EmitFuncContext<'c, 'a, H>, &CC) -> Result<BasicValueEnum<'c>> + 'a
{
}

impl<
'a,
H: ?Sized,
CC: ?Sized + CustomConst,
F: 'a + ?Sized + for<'c> Fn(&mut EmitFuncContext<'c, H>, &CC) -> Result<BasicValueEnum<'c>>,
F: 'a
+ ?Sized
+ for<'c> Fn(&mut EmitFuncContext<'c, 'a, H>, &CC) -> Result<BasicValueEnum<'c>>,
> LoadConstantFn<'a, H, CC> for F
{
}
Expand Down Expand Up @@ -59,7 +61,7 @@ impl<'a, H: HugrView> LoadConstantsMap<'a, H> {
/// appropriate inner callbacks.
pub fn emit_load_constant<'c>(
&self,
context: &mut EmitFuncContext<'c, H>,
context: &mut EmitFuncContext<'c, 'a, H>,
konst: &dyn CustomConst,
) -> Result<BasicValueEnum<'c>> {
let type_id = konst.type_id();
Expand Down
16 changes: 9 additions & 7 deletions src/custom/types.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::marker::PhantomData;

use itertools::Itertools as _;

use hugr::types::CustomType;
Expand All @@ -14,22 +16,22 @@ use crate::{
};

pub trait LLVMCustomTypeFn<'a>:
for<'c> Fn(TypingSession<'c>, &CustomType) -> Result<BasicTypeEnum<'c>> + 'a
for<'c> Fn(TypingSession<'c, 'a>, &CustomType) -> Result<BasicTypeEnum<'c>> + 'a
{
}

impl<
'a,
F: for<'c> Fn(TypingSession<'c>, &CustomType) -> Result<BasicTypeEnum<'c>> + 'a + ?Sized,
F: for<'c> Fn(TypingSession<'c, 'a>, &CustomType) -> Result<BasicTypeEnum<'c>> + 'a + ?Sized,
> LLVMCustomTypeFn<'a> for F
{
}

#[derive(Default, Clone)]
pub struct LLVMTypeMapping;
pub struct LLVMTypeMapping<'a>(PhantomData<&'a ()>);

impl TypeMapping for LLVMTypeMapping {
type InV<'c> = TypingSession<'c>;
impl<'a> TypeMapping for LLVMTypeMapping<'a> {
type InV<'c> = TypingSession<'c, 'a>;

type OutV<'c> = BasicTypeEnum<'c>;

Expand All @@ -48,7 +50,7 @@ impl TypeMapping for LLVMTypeMapping {
fn map_sum_type<'c>(
&self,
sum_type: &HugrSumType,
context: TypingSession<'c>,
context: TypingSession<'c, 'a>,
variants: impl IntoIterator<Item = Vec<Self::OutV<'c>>>,
) -> Result<Self::SumOutV<'c>> {
LLVMSumType::try_new2(
Expand All @@ -61,7 +63,7 @@ impl TypeMapping for LLVMTypeMapping {
fn map_function_type<'c>(
&self,
_: &HugrFuncType,
context: TypingSession<'c>,
context: TypingSession<'c, 'a>,
inputs: impl IntoIterator<Item = Self::OutV<'c>>,
outputs: impl IntoIterator<Item = Self::OutV<'c>>,
) -> Result<Self::FuncOutV<'c>> {
Expand Down
30 changes: 18 additions & 12 deletions src/emit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,19 @@ pub use ops::emit_value;
/// This includes the module itself, a set of extensions for lowering custom
/// elements, and policy for naming various HUGR elements.
///
/// `'c` names the lifetime of the LLVM context.
// TODO add another lifetime parameter for `extensions` below.
pub struct EmitModuleContext<'c, H> {
/// `'c` names the lifetime of the LLVM context, while `'a` names the lifetime
/// of other internal references.
pub struct EmitModuleContext<'c, 'a, H>
where
'a: 'c,
{
iw_context: &'c Context,
module: Module<'c>,
extensions: Rc<CodegenExtsMap<'static, H>>,
extensions: Rc<CodegenExtsMap<'a, H>>,
namer: Rc<Namer>,
}

impl<'c, H> EmitModuleContext<'c, H> {
impl<'c, 'a, H> EmitModuleContext<'c, 'a, H> {
delegate! {
to self.typing_session() {
/// Convert a [HugrType] into an LLVM [Type](BasicTypeEnum).
Expand All @@ -70,7 +73,7 @@ impl<'c, H> EmitModuleContext<'c, H> {
iw_context: &'c Context,
module: Module<'c>,
namer: Rc<Namer>,
extensions: Rc<CodegenExtsMap<'static, H>>,
extensions: Rc<CodegenExtsMap<'a, H>>,
) -> Self {
Self {
iw_context,
Expand All @@ -88,12 +91,12 @@ impl<'c, H> EmitModuleContext<'c, H> {
}

/// Returns a reference to the inner [CodegenExtsMap].
pub fn extensions(&self) -> Rc<CodegenExtsMap<'static, H>> {
pub fn extensions(&self) -> Rc<CodegenExtsMap<'a, H>> {
self.extensions.clone()
}

/// Returns a [TypingSession] constructed from it's members.
pub fn typing_session(&self) -> TypingSession<'c> {
pub fn typing_session(&self) -> TypingSession<'c, 'a> {
self.extensions
.type_converter
.clone()
Expand Down Expand Up @@ -235,12 +238,15 @@ impl<'c, H> EmitModuleContext<'c, H> {
type EmissionSet = HashSet<Node>;

/// Emits [HugrView]s into an LLVM [Module].
pub struct EmitHugr<'c, H> {
pub struct EmitHugr<'c, 'a, H>
where
'a: 'c,
{
emitted: EmissionSet,
module_context: EmitModuleContext<'c, H>,
module_context: EmitModuleContext<'c, 'a, H>,
}

impl<'c, H: HugrView> EmitHugr<'c, H> {
impl<'c, 'a, H: HugrView> EmitHugr<'c, 'a, H> {
delegate! {
to self.module_context {
/// Returns a reference to the inner [Context].
Expand All @@ -257,7 +263,7 @@ impl<'c, H: HugrView> EmitHugr<'c, H> {
iw_context: &'c Context,
module: Module<'c>,
namer: Rc<Namer>,
extensions: Rc<CodegenExtsMap<'static, H>>,
extensions: Rc<CodegenExtsMap<'a, H>>,
) -> Self {
assert_eq!(iw_context, &module.get_context());
Self {
Expand Down
25 changes: 13 additions & 12 deletions src/emit/func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,11 @@ pub use mailbox::{RowMailBox, RowPromise};
/// [MailBox](RowMailBox)es are stack allocations that are `alloca`ed in the
/// first basic block of the function, read from to get the input values of each
/// node, and written to with the output values of each node.
///
// TODO add another lifetime parameter which `emit_context` will need.
pub struct EmitFuncContext<'c, H> {
emit_context: EmitModuleContext<'c, H>,
pub struct EmitFuncContext<'c, 'a, H>
where
'a: 'c,
{
emit_context: EmitModuleContext<'c, 'a, H>,
todo: EmissionSet,
func: FunctionValue<'c>,
env: HashMap<Wire, ValueMailBox<'c>>,
Expand All @@ -55,15 +56,15 @@ pub struct EmitFuncContext<'c, H> {
launch_bb: BasicBlock<'c>,
}

impl<'c, H: HugrView> EmitFuncContext<'c, H> {
impl<'c, 'a, H: HugrView> EmitFuncContext<'c, 'a, H> {
delegate! {
to self.emit_context {
/// Returns the inkwell [Context].
pub fn iw_context(&self) -> &'c Context;
/// Returns the internal [CodegenExtsMap] .
pub fn extensions(&self) -> Rc<CodegenExtsMap<'static,H>>;
pub fn extensions(&self) -> Rc<CodegenExtsMap<'a,H>>;
/// Returns a new [TypingSession].
pub fn typing_session(&self) -> TypingSession<'c>;
pub fn typing_session(&self) -> TypingSession<'c, 'a>;
/// Convert hugr [HugrType] into an LLVM [Type](BasicTypeEnum).
pub fn llvm_type(&self, hugr_type: &HugrType) -> Result<BasicTypeEnum<'c> >;
/// Convert a [HugrFuncType] into an LLVM [FunctionType].
Expand Down Expand Up @@ -143,9 +144,9 @@ impl<'c, H: HugrView> EmitFuncContext<'c, H> {
///
/// TODO on failure return `emit_context`
pub fn new(
emit_context: EmitModuleContext<'c, H>,
emit_context: EmitModuleContext<'c, 'a, H>,
func: FunctionValue<'c>,
) -> Result<EmitFuncContext<'c, H>> {
) -> Result<EmitFuncContext<'c, 'a, H>> {
if func.get_first_basic_block().is_some() {
Err(anyhow!(
"EmitContext::new: Function already has a basic block: {:?}",
Expand Down Expand Up @@ -180,9 +181,9 @@ impl<'c, H: HugrView> EmitFuncContext<'c, H> {
/// Create a new anonymous [RowMailBox]. This mailbox is not mapped to any
/// [Wire]s, and so will not interact with any mailboxes returned from
/// [EmitFuncContext::node_ins_rmb] or [EmitFuncContext::node_outs_rmb].
pub fn new_row_mail_box<'a>(
pub fn new_row_mail_box<'t>(
&mut self,
ts: impl IntoIterator<Item = &'a Type>,
ts: impl IntoIterator<Item = &'t Type>,
name: impl AsRef<str>,
) -> Result<RowMailBox<'c>> {
Ok(RowMailBox::new(
Expand Down Expand Up @@ -307,7 +308,7 @@ impl<'c, H: HugrView> EmitFuncContext<'c, H> {

/// Consumes the `EmitFuncContext` and returns both the inner
/// [EmitModuleContext] and the scoped [FuncDefn]s that were encountered.
pub fn finish(self) -> Result<(EmitModuleContext<'c, H>, EmissionSet)> {
pub fn finish(self) -> Result<(EmitModuleContext<'c, 'a, H>, EmissionSet)> {
self.builder.position_at_end(self.prologue_bb);
self.builder.build_unconditional_branch(self.launch_bb)?;
Ok((self.emit_context, self.todo))
Expand Down
Loading