Skip to content

Commit

Permalink
feat: Add initial float extension
Browse files Browse the repository at this point in the history
Note that this is only a skeleton, with not even a single op lowering
implemented. It does add support for f64 types and constants though.
  • Loading branch information
doug-q committed Jun 18, 2024
1 parent 4b51a4c commit a346843
Show file tree
Hide file tree
Showing 4 changed files with 186 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/custom.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ use crate::{

use super::emit::EmitOp;

pub mod float;
pub mod int;
pub mod prelude;
// pub mod float_ops;
// pub mod logic_ops;

/// The extension point for lowering HUGR Extensions to LLVM.
Expand Down
151 changes: 151 additions & 0 deletions src/custom/float.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
use std::{any::TypeId, collections::HashSet};

use anyhow::{anyhow, Result};
use hugr::{
ops::{constant::CustomConst, CustomOp},
std_extensions::arithmetic::{
float_ops,
float_types::{self, ConstF64, FLOAT64_CUSTOM_TYPE},
},
HugrView,
};
use inkwell::{
types::{BasicType, FloatType},
values::{BasicValue, BasicValueEnum},
};

use crate::emit::{func::EmitFuncContext, EmitOp, EmitOpArgs, NullEmitLlvm};

use super::{CodegenExtension, CodegenExtsMap};

struct FloatTypesCodegenExtension;

impl<'c, H: HugrView> CodegenExtension<'c, H> for FloatTypesCodegenExtension {
fn extension(&self) -> hugr::extension::ExtensionId {
return float_types::EXTENSION_ID;
}

fn llvm_type(
&self,
context: &crate::types::TypingSession<'c, H>,
hugr_type: &hugr::types::CustomType,
) -> anyhow::Result<inkwell::types::BasicTypeEnum<'c>> {
if hugr_type == &FLOAT64_CUSTOM_TYPE {
Ok(context.iw_context().f64_type().as_basic_type_enum())
} else {
Err(anyhow!(
"FloatCodegenExtension: Unsupported type: {}",
hugr_type
))
}
}

fn emitter<'a>(
&self,
_context: &'a mut crate::emit::func::EmitFuncContext<'c, H>,
) -> Box<dyn crate::emit::EmitOp<'c, hugr::ops::CustomOp, H> + 'a> {
Box::new(NullEmitLlvm)
}

fn supported_consts(&self) -> HashSet<TypeId> {
[TypeId::of::<ConstF64>()].into_iter().collect()
}

fn load_constant(
&self,
context: &mut EmitFuncContext<'c, H>,
konst: &dyn hugr::ops::constant::CustomConst,
) -> Result<Option<BasicValueEnum<'c>>> {
let Some(k) = konst.downcast_ref::<ConstF64>() else {
return Ok(None);
};
let ty: FloatType<'c> = context
.llvm_type(&k.get_type())?
.try_into()
.map_err(|_| anyhow!("Failed to get type of ConstF64 as FloatType"))?;
Ok(Some(ty.const_float(k.value()).as_basic_value_enum()))
}
}

struct FloatOpsCodegenExtension;

impl<'c, H: HugrView> CodegenExtension<'c, H> for FloatOpsCodegenExtension {
fn extension(&self) -> hugr::extension::ExtensionId {
return float_ops::EXTENSION_ID;
}

fn llvm_type(
&self,
_context: &crate::types::TypingSession<'c, H>,
hugr_type: &hugr::types::CustomType,
) -> anyhow::Result<inkwell::types::BasicTypeEnum<'c>> {
Err(anyhow!(
"FloatOpsCodegenExtension: unsupported type: {hugr_type}"
))
}

fn emitter<'a>(
&self,
context: &'a mut crate::emit::func::EmitFuncContext<'c, H>,
) -> Box<dyn crate::emit::EmitOp<'c, hugr::ops::CustomOp, H> + 'a> {
Box::new(FloatOpEmitter(context))
}
}

// we allow dead code for now, but once we implement the emitter, we should
// remove this
#[allow(dead_code)]
struct FloatOpEmitter<'c, 'd, H: HugrView>(&'d mut EmitFuncContext<'c, H>);

impl<'c, H: HugrView> EmitOp<'c, CustomOp, H> for FloatOpEmitter<'c, '_, H> {
fn emit(&mut self, args: EmitOpArgs<'c, CustomOp, H>) -> Result<()> {
use hugr::ops::NamedOp;
let name = args.node().name();
match name.as_str() {
n => Err(anyhow!("FloatOpEmitter: unknown op: {}", n)),
}
}
}

pub fn add_float_extensions<H: HugrView>(cem: CodegenExtsMap<'_, H>) -> CodegenExtsMap<'_, H> {
cem.add_cge(FloatTypesCodegenExtension)
.add_cge(FloatOpsCodegenExtension)
}

impl<H: HugrView> CodegenExtsMap<'_, H> {
pub fn add_float_extensions(self) -> Self {
add_float_extensions(self)
}
}

#[cfg(test)]
mod test {
use hugr::{
builder::{Dataflow, DataflowSubContainer},
std_extensions::arithmetic::{
float_ops::FLOAT_OPS_REGISTRY,
float_types::{ConstF64, FLOAT64_TYPE},
},
};
use rstest::rstest;

use super::add_float_extensions;
use crate::{
check_emission,
emit::test::SimpleHugrConfig,
test::{llvm_ctx, TestContext},
};

#[rstest]
fn const_float(mut llvm_ctx: TestContext) {
llvm_ctx.add_extensions(add_float_extensions);
let hugr = SimpleHugrConfig::new()
.with_outs(FLOAT64_TYPE)
.with_extensions(FLOAT_OPS_REGISTRY.to_owned())
.finish(|mut builder| {
let c = builder.add_load_value(ConstF64::new(3.14));
builder.finish_with_outputs([c]).unwrap()
});
check_emission!(hugr, llvm_ctx);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
---
source: src/custom/float.rs
expression: module.to_string()
---
; ModuleID = 'test_context'
source_filename = "test_context"

define double @_hl.main.1() {
alloca_block:
br label %entry_block

entry_block: ; preds = %alloca_block
ret double 3.140000e+00
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
---
source: src/custom/float.rs
expression: module.to_string()
---
; ModuleID = 'test_context'
source_filename = "test_context"

define double @_hl.main.1() {
alloca_block:
%"0" = alloca double, align 8
%"5_0" = alloca double, align 8
br label %entry_block

entry_block: ; preds = %alloca_block
store double 3.140000e+00, double* %"5_0", align 8
%"5_01" = load double, double* %"5_0", align 8
store double %"5_01", double* %"0", align 8
%"02" = load double, double* %"0", align 8
ret double %"02"
}

0 comments on commit a346843

Please sign in to comment.