Skip to content
This repository was archived by the owner on Mar 5, 2025. It is now read-only.

feat: add get_extern_func #28

Merged
merged 2 commits into from
Jun 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 44 additions & 14 deletions src/emit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@ use anyhow::{anyhow, Result};
use delegate::delegate;
use hugr::{
ops::{FuncDecl, FuncDefn, NamedOp as _, OpType},
types::PolyFuncType,
HugrView, Node, NodeIndex,
};
use inkwell::{
context::Context,
module::Module,
module::{Linkage, Module},
types::{BasicTypeEnum, FunctionType},
values::{BasicValueEnum, FunctionValue},
};
Expand Down Expand Up @@ -205,39 +206,68 @@ impl<'c, H: HugrView> EmitModuleContext<'c, H> {
fn get_func_impl(
&self,
name: impl AsRef<str>,
node: Node,
func_ty: &hugr::types::PolyFuncType,
func_ty: FunctionType<'c>,
linkage: Option<Linkage>,
) -> Result<FunctionValue<'c>> {
let sig = (func_ty.params().is_empty())
.then_some(func_ty.body())
.ok_or(anyhow!("function has type params"))?;
let llvm_func_ty = self.llvm_func_type(sig)?;
let name = self.name_func(name, node);
let func = self
.module()
.get_function(&name)
.unwrap_or_else(|| self.module.add_function(&name, llvm_func_ty, None));
if func.get_type() != llvm_func_ty {
.get_function(name.as_ref())
.unwrap_or_else(|| self.module.add_function(name.as_ref(), func_ty, linkage));
if func.get_type() != func_ty {
Err(anyhow!(
"Function '{name}' has wrong type: hugr: {func_ty} expected: {llvm_func_ty} actual: {}",
"Function '{}' has wrong type: expected: {func_ty} actual: {}",
name.as_ref(),
func.get_type()
))?
}
Ok(func)
}

fn get_hugr_func_impl(
&self,
name: impl AsRef<str>,
node: Node,
func_ty: &PolyFuncType,
) -> Result<FunctionValue<'c>> {
let func_ty = (func_ty.params().is_empty())
.then_some(func_ty.body())
.ok_or(anyhow!("function has type params"))?;
let llvm_func_ty = self.llvm_func_type(func_ty)?;
let name = self.name_func(name, node);
self.get_func_impl(name, llvm_func_ty, None)
}

/// Adds or gets the [FunctionValue] in the [Module] corresponding to the given [FuncDefn].
///
/// The name of the result is mangled by [EmitModuleContext::name_func].
pub fn get_func_defn(&self, node: FatNode<'c, FuncDefn, H>) -> Result<FunctionValue<'c>> {
self.get_func_impl(&node.name, node.node(), &node.signature)
self.get_hugr_func_impl(&node.name, node.node(), &node.signature)
}

/// Adds or gets the [FunctionValue] in the [Module] corresponding to the given [FuncDecl].
///
/// The name of the result is mangled by [EmitModuleContext::name_func].
pub fn get_func_decl(&self, node: FatNode<'c, FuncDecl, H>) -> Result<FunctionValue<'c>> {
self.get_func_impl(&node.name, node.node(), &node.signature)
self.get_hugr_func_impl(&node.name, node.node(), &node.signature)
}

/// Adds or get the [FunctionValue] in the [Module] with the given symbol
/// and function type.
///
/// The name undergoes no mangling. The [FunctionValue] will have
/// [Linkage::External].
///
/// If this function is called multiple times with the same arguments it
/// will return the same [FunctionValue].
///
/// If a function with the given name exists but the type does not match
/// then an Error is returned.
pub fn get_extern_func(
&self,
symbol: impl AsRef<str>,
typ: FunctionType<'c>,
) -> Result<FunctionValue<'c>> {
self.get_func_impl(symbol, typ, Some(Linkage::External))
}

/// Consumes the `EmitModuleContext` and returns the internal [Module].
Expand Down
13 changes: 13 additions & 0 deletions src/emit/func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ pub struct EmitFuncContext<'c, H: HugrView> {
impl<'c, H: HugrView> EmitFuncContext<'c, H> {
delegate! {
to self.emit_context {
/// Returns the inkwell [Context].
fn iw_context(&self) -> &'c Context;
/// Returns the internal [CodegenExtsMap] .
pub fn extensions(&self) -> Rc<CodegenExtsMap<'c,H>>;
Expand All @@ -78,6 +79,18 @@ impl<'c, H: HugrView> EmitFuncContext<'c, H> {
///
/// The name of the result may have been mangled.
pub fn get_func_decl(&self, node: FatNode<'c, FuncDecl, H>) -> Result<FunctionValue<'c>>;
/// Adds or get the [FunctionValue] in the [inkwell::module::Module] with the given symbol
/// and function type.
///
/// The name undergoes no mangling. The [FunctionValue] will have
/// [inkwell::module::Linkage::External].
///
/// If this function is called multiple times with the same arguments it
/// will return the same [FunctionValue].
///
/// If a function with the given name exists but the type does not match
/// then an Error is returned.
pub fn get_extern_func(&self, symbol: impl AsRef<str>, typ: FunctionType<'c>,) -> Result<FunctionValue<'c>>;
}
}

Expand Down
13 changes: 13 additions & 0 deletions src/emit/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -267,3 +267,16 @@ fn emit_hugr_custom_op(#[with(-1, add_int_extensions)] llvm_ctx: TestContext) {
});
check_emission!(hugr, llvm_ctx);
}

#[rstest]
fn get_external_func(llvm_ctx: TestContext) {
llvm_ctx.with_emit_module_context(|emc| {
let func_type1 = emc.iw_context().i32_type().fn_type(&[], false);
let func_type2 = emc.iw_context().f64_type().fn_type(&[], false);
let foo1 = emc.get_extern_func("foo", func_type1).unwrap();
assert_eq!(foo1.get_name().to_str().unwrap(), "foo");
let foo2 = emc.get_extern_func("foo", func_type1).unwrap();
assert_eq!(foo1, foo2);
assert!(emc.get_extern_func("foo", func_type2).is_err());
});
}
17 changes: 16 additions & 1 deletion src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use rstest::fixture;

use crate::{
custom::CodegenExtsMap,
emit::EmitHugr,
emit::{EmitHugr, EmitModuleContext, Namer},
types::{TypeConverter, TypingSession},
};

Expand Down Expand Up @@ -120,6 +120,21 @@ impl TestContext {
(r, ectx.finish())
})
}

pub fn with_emit_module_context<'c, T>(
&'c self,
f: impl FnOnce(EmitModuleContext<'c, THugrView>) -> T,
) -> T {
self.with_context(|ctx| {
let m = ctx.create_module("test_module");
f(EmitModuleContext::new(
m,
Namer::default().into(),
self.extensions(),
TypeConverter::new(ctx),
))
})
}
}

#[fixture]
Expand Down