Skip to content

Commit

Permalink
add ConstExternSymbol
Browse files Browse the repository at this point in the history
  • Loading branch information
doug-q committed Jun 17, 2024
1 parent adc6dc4 commit 2cdac37
Show file tree
Hide file tree
Showing 7 changed files with 197 additions and 17 deletions.
63 changes: 50 additions & 13 deletions src/custom/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::{any::TypeId, collections::HashSet};

use anyhow::{anyhow, Result};
use hugr::{
extension::prelude::{self, ConstUsize, QB_T, USIZE_T},
extension::prelude::{self, ConstExternalSymbol, ConstUsize, QB_T, USIZE_T},
ops::constant::CustomConst,
types::TypeEnum,
HugrView,
Expand Down Expand Up @@ -102,22 +102,36 @@ impl<'c, H: HugrView, PCG: PreludeCodegen> CodegenExtension<'c, H>
}

fn supported_consts(&self) -> HashSet<TypeId> {
[TypeId::of::<ConstUsize>()].into_iter().collect()
[
TypeId::of::<ConstUsize>(),
TypeId::of::<ConstExternalSymbol>(),
]
.into_iter()
.collect()
}

fn load_constant(
&self,
context: &mut EmitFuncContext<'c, H>,
konst: &dyn CustomConst,
) -> Result<Option<BasicValueEnum<'c>>> {
let Some(k) = konst.downcast_ref::<ConstUsize>() else {
return Ok(None);
};
let ty: IntType<'c> = context
.llvm_type(&k.get_type())?
.try_into()
.map_err(|_| anyhow!("Failed to get ConstUsize as IntType"))?;
Ok(Some(ty.const_int(k.value(), false).into()))
if let Some(k) = konst.downcast_ref::<ConstUsize>() {
let ty: IntType<'c> = context
.llvm_type(&k.get_type())?
.try_into()
.map_err(|_| anyhow!("Failed to get ConstUsize as IntType"))?;
Ok(Some(ty.const_int(k.value(), false).into()))
} else if let Some(k) = konst.downcast_ref::<ConstExternalSymbol>() {
let llvm_type = context.llvm_type(&k.get_type())?;
let global = context.get_global(&k.symbol, llvm_type, k.constant)?;
Ok(Some(
context
.builder()
.build_load(global.as_pointer_value(), &k.symbol)?,
))
} else {
Ok(None)
}
}
}

Expand Down Expand Up @@ -155,6 +169,8 @@ impl<'c, H: HugrView> CodegenExtsMap<'c, H> {
#[cfg(test)]
mod test {
use hugr::builder::{Dataflow, DataflowSubContainer};
use hugr::type_row;
use hugr::types::Type;
use rstest::rstest;

use crate::check_emission;
Expand All @@ -178,7 +194,7 @@ mod test {
}

#[rstest]
fn test_prelude_extension(llvm_ctx: TestContext) {
fn prelude_extension_types(llvm_ctx: TestContext) {
let ctx = llvm_ctx.iw_context();
let ext: PreludeCodegenExtension<TestPreludeCodegen> = TestPreludeCodegen.into();
let tc = llvm_ctx.get_typing_session();
Expand All @@ -201,7 +217,7 @@ mod test {
}

#[rstest]
fn test_prelude_extension_in_test_context(mut llvm_ctx: TestContext) {
fn prelude_extension_types_in_test_context(mut llvm_ctx: TestContext) {
llvm_ctx.add_extensions(|x| x.add_prelude_extensions(TestPreludeCodegen));
let tc = llvm_ctx.get_typing_session();
assert_eq!(
Expand All @@ -215,7 +231,7 @@ mod test {
}

#[rstest]
fn test_prelude_extension_const_usize(mut llvm_ctx: TestContext) {
fn prelude_const_usize(mut llvm_ctx: TestContext) {
llvm_ctx.add_extensions(add_default_prelude_extensions);

let hugr = SimpleHugrConfig::new()
Expand All @@ -227,4 +243,25 @@ mod test {
});
check_emission!(hugr, llvm_ctx);
}

#[rstest]
fn prelude_const_external_symbol(mut llvm_ctx: TestContext) {
llvm_ctx.add_extensions(add_default_prelude_extensions);
let konst1 = ConstExternalSymbol::new("sym1", USIZE_T, true);
let konst2 = ConstExternalSymbol::new(
"sym2",
Type::new_sum([type_row![USIZE_T, Type::new_unit_sum(3)], type_row![]]),
false,
);

let hugr = SimpleHugrConfig::new()
.with_outs(vec![konst1.get_type(), konst2.get_type()])
.with_extensions(prelude::PRELUDE_REGISTRY.to_owned())
.finish(|mut builder| {
let k1 = builder.add_load_value(konst1);
let k2 = builder.add_load_value(konst2);
builder.finish_with_outputs([k1, k2]).unwrap()
});
check_emission!(hugr, llvm_ctx);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
---
source: src/custom/prelude.rs
expression: module.to_string()
---
; ModuleID = 'test_context'
source_filename = "test_context"

@sym2 = external global { i32, { i64, { i32, {}, {}, {} } }, {} }
@sym1 = external constant i64

define { i64, { i32, { i64, { i32, {}, {}, {} } }, {} } } @_hl.main.1() {
alloca_block:
br label %entry_block

entry_block: ; preds = %alloca_block
%sym2 = load { i32, { i64, { i32, {}, {}, {} } }, {} }, { i32, { i64, { i32, {}, {}, {} } }, {} }* @sym2, align 4
%sym1 = load i64, i64* @sym1, align 4
%mrv = insertvalue { i64, { i32, { i64, { i32, {}, {}, {} } }, {} } } undef, i64 %sym1, 0
%mrv5 = insertvalue { i64, { i32, { i64, { i32, {}, {}, {} } }, {} } } %mrv, { i32, { i64, { i32, {}, {}, {} } }, {} } %sym2, 1
ret { i64, { i32, { i64, { i32, {}, {}, {} } }, {} } } %mrv5
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
---
source: src/custom/prelude.rs
expression: module.to_string()
---
; ModuleID = 'test_context'
source_filename = "test_context"

@sym2 = external global { i32, { i64, { i32, {}, {}, {} } }, {} }
@sym1 = external constant i64

define { i64, { i32, { i64, { i32, {}, {}, {} } }, {} } } @_hl.main.1() {
alloca_block:
%"0" = alloca i64, align 8
%"1" = alloca { i32, { i64, { i32, {}, {}, {} } }, {} }, align 8
%"7_0" = alloca { i32, { i64, { i32, {}, {}, {} } }, {} }, align 8
%"5_0" = alloca i64, align 8
br label %entry_block

entry_block: ; preds = %alloca_block
%sym2 = load { i32, { i64, { i32, {}, {}, {} } }, {} }, { i32, { i64, { i32, {}, {}, {} } }, {} }* @sym2, align 4
store { i32, { i64, { i32, {}, {}, {} } }, {} } %sym2, { i32, { i64, { i32, {}, {}, {} } }, {} }* %"7_0", align 4
%sym1 = load i64, i64* @sym1, align 4
store i64 %sym1, i64* %"5_0", align 4
%"5_01" = load i64, i64* %"5_0", align 4
%"7_02" = load { i32, { i64, { i32, {}, {}, {} } }, {} }, { i32, { i64, { i32, {}, {}, {} } }, {} }* %"7_0", align 4
store i64 %"5_01", i64* %"0", align 4
store { i32, { i64, { i32, {}, {}, {} } }, {} } %"7_02", { i32, { i64, { i32, {}, {}, {} } }, {} }* %"1", align 4
%"03" = load i64, i64* %"0", align 4
%"14" = load { i32, { i64, { i32, {}, {}, {} } }, {} }, { i32, { i64, { i32, {}, {}, {} } }, {} }* %"1", align 4
%mrv = insertvalue { i64, { i32, { i64, { i32, {}, {}, {} } }, {} } } undef, i64 %"03", 0
%mrv5 = insertvalue { i64, { i32, { i64, { i32, {}, {}, {} } }, {} } } %mrv, { i32, { i64, { i32, {}, {}, {} } }, {} } %"14", 1
ret { i64, { i32, { i64, { i32, {}, {}, {} } }, {} } } %mrv5
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
---
source: src/custom/prelude.rs
expression: module.to_string()
---
; ModuleID = 'test_context'
source_filename = "test_context"

define i64 @_hl.main.1() {
alloca_block:
br label %entry_block

entry_block: ; preds = %alloca_block
ret i64 17
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
---
source: src/custom/prelude.rs
expression: module.to_string()
---
; ModuleID = 'test_context'
source_filename = "test_context"

define i64 @_hl.main.1() {
alloca_block:
%"0" = alloca i64, align 8
%"5_0" = alloca i64, align 8
br label %entry_block

entry_block: ; preds = %alloca_block
store i64 17, i64* %"5_0", align 4
%"5_01" = load i64, i64* %"5_0", align 4
store i64 %"5_01", i64* %"0", align 4
%"02" = load i64, i64* %"0", align 4
ret i64 %"02"
}
51 changes: 49 additions & 2 deletions src/emit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ use hugr::{
use inkwell::{
context::Context,
module::Module,
types::{BasicTypeEnum, FunctionType},
values::{BasicValueEnum, FunctionValue},
types::{AnyType, BasicType, BasicTypeEnum, FunctionType},
values::{BasicValueEnum, FunctionValue, GlobalValue},
};
use std::{collections::HashSet, hash::Hash, rc::Rc};

Expand Down Expand Up @@ -239,6 +239,53 @@ impl<'c, H: HugrView> EmitModuleContext<'c, H> {
self.get_func_impl(&node.name, node.node(), &node.signature)
}

/// Adds or gets the [GlobalValue] in the [Module] corresponding to the
/// given symbol and LLVM type.
///
/// The name will not be mangled.
///
/// If a global with the given name exists but the type or constant-ness
/// does not match then an error will be returned.
pub fn get_global(
&self,
symbol: impl AsRef<str>,
typ: impl BasicType<'c>,
constant: bool,
) -> Result<GlobalValue<'c>> {
let symbol = symbol.as_ref();
let typ = typ.as_basic_type_enum();
if let Some(global) = self.module().get_global(symbol) {
let global_type = {
// TODO This is exposed as `get_value_type` on the master branch
// of inkwell, will be in the next release. When it's released
// use `get_value_type`.
use inkwell::types::AnyTypeEnum;
use inkwell::values::AsValueRef;
unsafe {
AnyTypeEnum::new(llvm_sys_140::core::LLVMGlobalGetValueType(
global.as_value_ref(),
))
}
};
if global_type != typ.as_any_type_enum() {
Err(anyhow!(
"Global '{symbol}' has wrong type: expected: {typ} actual: {global_type}"
))?
}
if global.is_constant() != constant {
Err(anyhow!(
"Global '{symbol}' has wrong constant-ness: expected: {constant} actual: {}",
global.is_constant()
))?
}
Ok(global)
} else {
let global = self.module().add_global(typ, None, symbol.as_ref());
global.set_constant(constant);
Ok(global)
}
}

/// Consumes the `EmitModuleContext` and returns the internal [Module].
pub fn finish(self) -> Module<'c> {
self.module
Expand Down
12 changes: 10 additions & 2 deletions src/emit/func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ use inkwell::{
basic_block::BasicBlock,
builder::Builder,
context::Context,
types::{BasicTypeEnum, FunctionType},
values::FunctionValue,
types::{BasicType, BasicTypeEnum, FunctionType},
values::{FunctionValue, GlobalValue},
};
use itertools::zip_eq;

Expand Down Expand Up @@ -78,6 +78,14 @@ impl<'c, H: HugrView> EmitFuncContext<'c, H> {
///
/// The name of the result may have been mangled.
pub fn get_func_decl(&self, node: FatNode<'c, FuncDecl, H>) -> Result<FunctionValue<'c>>;
/// Adds or gets the [GlobalValue] in the [inkwell::module::Module] corresponding to the
/// given symbol and LLVM type.
///
/// The name will not be mangled.
///
/// If a global with the given name exists but the type or constant-ness
/// does not match then an error will be returned.
pub fn get_global(&self, symbol: impl AsRef<str>, typ: impl BasicType<'c>, constant: bool) -> Result<GlobalValue<'c>>;
}
}

Expand Down

0 comments on commit 2cdac37

Please sign in to comment.