Skip to content
This repository was archived by the owner on Mar 5, 2025. It is now read-only.

Commit fa37181

Browse files
authored
feat: Add initial float extension (#31)
1 parent 8d86755 commit fa37181

File tree

4 files changed

+186
-0
lines changed

4 files changed

+186
-0
lines changed

src/custom.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ use crate::{
2121

2222
use super::emit::EmitOp;
2323

24+
pub mod float;
2425
pub mod int;
2526
pub mod prelude;
2627

src/custom/float.rs

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
use std::{any::TypeId, collections::HashSet};
2+
3+
use anyhow::{anyhow, Result};
4+
use hugr::{
5+
ops::{constant::CustomConst, CustomOp},
6+
std_extensions::arithmetic::{
7+
float_ops,
8+
float_types::{self, ConstF64, FLOAT64_CUSTOM_TYPE},
9+
},
10+
HugrView,
11+
};
12+
use inkwell::{
13+
types::{BasicType, FloatType},
14+
values::{BasicValue, BasicValueEnum},
15+
};
16+
17+
use crate::emit::{func::EmitFuncContext, EmitOp, EmitOpArgs, NullEmitLlvm};
18+
19+
use super::{CodegenExtension, CodegenExtsMap};
20+
21+
struct FloatTypesCodegenExtension;
22+
23+
impl<'c, H: HugrView> CodegenExtension<'c, H> for FloatTypesCodegenExtension {
24+
fn extension(&self) -> hugr::extension::ExtensionId {
25+
float_types::EXTENSION_ID
26+
}
27+
28+
fn llvm_type(
29+
&self,
30+
context: &crate::types::TypingSession<'c, H>,
31+
hugr_type: &hugr::types::CustomType,
32+
) -> anyhow::Result<inkwell::types::BasicTypeEnum<'c>> {
33+
if hugr_type == &FLOAT64_CUSTOM_TYPE {
34+
Ok(context.iw_context().f64_type().as_basic_type_enum())
35+
} else {
36+
Err(anyhow!(
37+
"FloatCodegenExtension: Unsupported type: {}",
38+
hugr_type
39+
))
40+
}
41+
}
42+
43+
fn emitter<'a>(
44+
&self,
45+
_context: &'a mut crate::emit::func::EmitFuncContext<'c, H>,
46+
) -> Box<dyn crate::emit::EmitOp<'c, hugr::ops::CustomOp, H> + 'a> {
47+
Box::new(NullEmitLlvm)
48+
}
49+
50+
fn supported_consts(&self) -> HashSet<TypeId> {
51+
[TypeId::of::<ConstF64>()].into_iter().collect()
52+
}
53+
54+
fn load_constant(
55+
&self,
56+
context: &mut EmitFuncContext<'c, H>,
57+
konst: &dyn hugr::ops::constant::CustomConst,
58+
) -> Result<Option<BasicValueEnum<'c>>> {
59+
let Some(k) = konst.downcast_ref::<ConstF64>() else {
60+
return Ok(None);
61+
};
62+
let ty: FloatType<'c> = context.llvm_type(&k.get_type())?.try_into().unwrap();
63+
Ok(Some(ty.const_float(k.value()).as_basic_value_enum()))
64+
}
65+
}
66+
67+
struct FloatOpsCodegenExtension;
68+
69+
impl<'c, H: HugrView> CodegenExtension<'c, H> for FloatOpsCodegenExtension {
70+
fn extension(&self) -> hugr::extension::ExtensionId {
71+
float_ops::EXTENSION_ID
72+
}
73+
74+
fn llvm_type(
75+
&self,
76+
_context: &crate::types::TypingSession<'c, H>,
77+
hugr_type: &hugr::types::CustomType,
78+
) -> anyhow::Result<inkwell::types::BasicTypeEnum<'c>> {
79+
Err(anyhow!(
80+
"FloatOpsCodegenExtension: unsupported type: {hugr_type}"
81+
))
82+
}
83+
84+
fn emitter<'a>(
85+
&self,
86+
context: &'a mut crate::emit::func::EmitFuncContext<'c, H>,
87+
) -> Box<dyn crate::emit::EmitOp<'c, hugr::ops::CustomOp, H> + 'a> {
88+
Box::new(FloatOpEmitter(context))
89+
}
90+
}
91+
92+
// we allow dead code for now, but once we implement the emitter, we should
93+
// remove this
94+
#[allow(dead_code)]
95+
struct FloatOpEmitter<'c, 'd, H: HugrView>(&'d mut EmitFuncContext<'c, H>);
96+
97+
impl<'c, H: HugrView> EmitOp<'c, CustomOp, H> for FloatOpEmitter<'c, '_, H> {
98+
fn emit(&mut self, args: EmitOpArgs<'c, CustomOp, H>) -> Result<()> {
99+
use hugr::ops::NamedOp;
100+
let name = args.node().name();
101+
// This looks strange now, but we will add cases for ops piecemeal, as
102+
// in the analgous match expression in `IntOpEmitter`.
103+
#[allow(clippy::match_single_binding)]
104+
match name.as_str() {
105+
n => Err(anyhow!("FloatOpEmitter: unknown op: {n}")),
106+
}
107+
}
108+
}
109+
110+
pub fn add_float_extensions<H: HugrView>(cem: CodegenExtsMap<'_, H>) -> CodegenExtsMap<'_, H> {
111+
cem.add_cge(FloatTypesCodegenExtension)
112+
.add_cge(FloatOpsCodegenExtension)
113+
}
114+
115+
impl<H: HugrView> CodegenExtsMap<'_, H> {
116+
pub fn add_float_extensions(self) -> Self {
117+
add_float_extensions(self)
118+
}
119+
}
120+
121+
#[cfg(test)]
122+
mod test {
123+
use hugr::{
124+
builder::{Dataflow, DataflowSubContainer},
125+
std_extensions::arithmetic::{
126+
float_ops::FLOAT_OPS_REGISTRY,
127+
float_types::{ConstF64, FLOAT64_TYPE},
128+
},
129+
};
130+
use rstest::rstest;
131+
132+
use super::add_float_extensions;
133+
use crate::{
134+
check_emission,
135+
emit::test::SimpleHugrConfig,
136+
test::{llvm_ctx, TestContext},
137+
};
138+
139+
#[rstest]
140+
fn const_float(mut llvm_ctx: TestContext) {
141+
llvm_ctx.add_extensions(add_float_extensions);
142+
let hugr = SimpleHugrConfig::new()
143+
.with_outs(FLOAT64_TYPE)
144+
.with_extensions(FLOAT_OPS_REGISTRY.to_owned())
145+
.finish(|mut builder| {
146+
let c = builder.add_load_value(ConstF64::new(3.12));
147+
builder.finish_with_outputs([c]).unwrap()
148+
});
149+
check_emission!(hugr, llvm_ctx);
150+
}
151+
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
---
2+
source: src/custom/float.rs
3+
expression: module.to_string()
4+
---
5+
; ModuleID = 'test_context'
6+
source_filename = "test_context"
7+
8+
define double @_hl.main.1() {
9+
alloca_block:
10+
br label %entry_block
11+
12+
entry_block: ; preds = %alloca_block
13+
ret double 3.120000e+00
14+
}
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
---
2+
source: src/custom/float.rs
3+
expression: module.to_string()
4+
---
5+
; ModuleID = 'test_context'
6+
source_filename = "test_context"
7+
8+
define double @_hl.main.1() {
9+
alloca_block:
10+
%"0" = alloca double, align 8
11+
%"5_0" = alloca double, align 8
12+
br label %entry_block
13+
14+
entry_block: ; preds = %alloca_block
15+
store double 3.120000e+00, double* %"5_0", align 8
16+
%"5_01" = load double, double* %"5_0", align 8
17+
store double %"5_01", double* %"0", align 8
18+
%"02" = load double, double* %"0", align 8
19+
ret double %"02"
20+
}

0 commit comments

Comments
 (0)