Skip to content
This repository was archived by the owner on Mar 5, 2025. It is now read-only.

feat: Add initial float extension #31

Merged
merged 3 commits into from
Jun 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/custom.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use crate::{

use super::emit::EmitOp;

pub mod float;
pub mod int;
pub mod prelude;

Expand Down
151 changes: 151 additions & 0 deletions src/custom/float.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
use std::{any::TypeId, collections::HashSet};

use anyhow::{anyhow, Result};
use hugr::{
ops::{constant::CustomConst, CustomOp},
std_extensions::arithmetic::{
float_ops,
float_types::{self, ConstF64, FLOAT64_CUSTOM_TYPE},
},
HugrView,
};
use inkwell::{
types::{BasicType, FloatType},
values::{BasicValue, BasicValueEnum},
};

use crate::emit::{func::EmitFuncContext, EmitOp, EmitOpArgs, NullEmitLlvm};

use super::{CodegenExtension, CodegenExtsMap};

struct FloatTypesCodegenExtension;

impl<'c, H: HugrView> CodegenExtension<'c, H> for FloatTypesCodegenExtension {
fn extension(&self) -> hugr::extension::ExtensionId {
float_types::EXTENSION_ID
}

fn llvm_type(
&self,
context: &crate::types::TypingSession<'c, H>,
hugr_type: &hugr::types::CustomType,
) -> anyhow::Result<inkwell::types::BasicTypeEnum<'c>> {
if hugr_type == &FLOAT64_CUSTOM_TYPE {
Ok(context.iw_context().f64_type().as_basic_type_enum())
} else {
Err(anyhow!(
"FloatCodegenExtension: Unsupported type: {}",
hugr_type
))
}
}

fn emitter<'a>(
&self,
_context: &'a mut crate::emit::func::EmitFuncContext<'c, H>,
) -> Box<dyn crate::emit::EmitOp<'c, hugr::ops::CustomOp, H> + 'a> {
Box::new(NullEmitLlvm)
}

fn supported_consts(&self) -> HashSet<TypeId> {
[TypeId::of::<ConstF64>()].into_iter().collect()
}

fn load_constant(
&self,
context: &mut EmitFuncContext<'c, H>,
konst: &dyn hugr::ops::constant::CustomConst,
) -> Result<Option<BasicValueEnum<'c>>> {
let Some(k) = konst.downcast_ref::<ConstF64>() else {
return Ok(None);
};
let ty: FloatType<'c> = context.llvm_type(&k.get_type())?.try_into().unwrap();
Ok(Some(ty.const_float(k.value()).as_basic_value_enum()))
}
}

struct FloatOpsCodegenExtension;

impl<'c, H: HugrView> CodegenExtension<'c, H> for FloatOpsCodegenExtension {
fn extension(&self) -> hugr::extension::ExtensionId {
float_ops::EXTENSION_ID
}

fn llvm_type(
&self,
_context: &crate::types::TypingSession<'c, H>,
hugr_type: &hugr::types::CustomType,
) -> anyhow::Result<inkwell::types::BasicTypeEnum<'c>> {
Err(anyhow!(
"FloatOpsCodegenExtension: unsupported type: {hugr_type}"
))
}

fn emitter<'a>(
&self,
context: &'a mut crate::emit::func::EmitFuncContext<'c, H>,
) -> Box<dyn crate::emit::EmitOp<'c, hugr::ops::CustomOp, H> + 'a> {
Box::new(FloatOpEmitter(context))
}
}

// we allow dead code for now, but once we implement the emitter, we should
// remove this
#[allow(dead_code)]
struct FloatOpEmitter<'c, 'd, H: HugrView>(&'d mut EmitFuncContext<'c, H>);

impl<'c, H: HugrView> EmitOp<'c, CustomOp, H> for FloatOpEmitter<'c, '_, H> {
fn emit(&mut self, args: EmitOpArgs<'c, CustomOp, H>) -> Result<()> {
use hugr::ops::NamedOp;
let name = args.node().name();
// This looks strange now, but we will add cases for ops piecemeal, as
// in the analgous match expression in `IntOpEmitter`.
#[allow(clippy::match_single_binding)]
match name.as_str() {
n => Err(anyhow!("FloatOpEmitter: unknown op: {n}")),
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a curious way to say anyhow!("FloatOpEmitter: unknown op: {}", name.as_str())

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes indeed it is! the intention is that we can insert handlers for ops piecemeal above it, as in int.rs.

}
}
}

pub fn add_float_extensions<H: HugrView>(cem: CodegenExtsMap<'_, H>) -> CodegenExtsMap<'_, H> {
cem.add_cge(FloatTypesCodegenExtension)
.add_cge(FloatOpsCodegenExtension)
}

impl<H: HugrView> CodegenExtsMap<'_, H> {
pub fn add_float_extensions(self) -> Self {
add_float_extensions(self)
}
}

#[cfg(test)]
mod test {
use hugr::{
builder::{Dataflow, DataflowSubContainer},
std_extensions::arithmetic::{
float_ops::FLOAT_OPS_REGISTRY,
float_types::{ConstF64, FLOAT64_TYPE},
},
};
use rstest::rstest;

use super::add_float_extensions;
use crate::{
check_emission,
emit::test::SimpleHugrConfig,
test::{llvm_ctx, TestContext},
};

#[rstest]
fn const_float(mut llvm_ctx: TestContext) {
llvm_ctx.add_extensions(add_float_extensions);
let hugr = SimpleHugrConfig::new()
.with_outs(FLOAT64_TYPE)
.with_extensions(FLOAT_OPS_REGISTRY.to_owned())
.finish(|mut builder| {
let c = builder.add_load_value(ConstF64::new(3.12));
builder.finish_with_outputs([c]).unwrap()
});
check_emission!(hugr, llvm_ctx);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
---
source: src/custom/float.rs
expression: module.to_string()
---
; ModuleID = 'test_context'
source_filename = "test_context"

define double @_hl.main.1() {
alloca_block:
br label %entry_block

entry_block: ; preds = %alloca_block
ret double 3.120000e+00
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
---
source: src/custom/float.rs
expression: module.to_string()
---
; ModuleID = 'test_context'
source_filename = "test_context"

define double @_hl.main.1() {
alloca_block:
%"0" = alloca double, align 8
%"5_0" = alloca double, align 8
br label %entry_block

entry_block: ; preds = %alloca_block
store double 3.120000e+00, double* %"5_0", align 8
%"5_01" = load double, double* %"5_0", align 8
store double %"5_01", double* %"0", align 8
%"02" = load double, double* %"0", align 8
ret double %"02"
}