|
| 1 | +use std::{any::TypeId, collections::HashSet}; |
| 2 | + |
| 3 | +use anyhow::{anyhow, Result}; |
| 4 | +use hugr::{ |
| 5 | + ops::{constant::CustomConst, CustomOp}, |
| 6 | + std_extensions::arithmetic::{ |
| 7 | + float_ops, |
| 8 | + float_types::{self, ConstF64, FLOAT64_CUSTOM_TYPE}, |
| 9 | + }, |
| 10 | + HugrView, |
| 11 | +}; |
| 12 | +use inkwell::{ |
| 13 | + types::{BasicType, FloatType}, |
| 14 | + values::{BasicValue, BasicValueEnum}, |
| 15 | +}; |
| 16 | + |
| 17 | +use crate::emit::{func::EmitFuncContext, EmitOp, EmitOpArgs, NullEmitLlvm}; |
| 18 | + |
| 19 | +use super::{CodegenExtension, CodegenExtsMap}; |
| 20 | + |
| 21 | +struct FloatTypesCodegenExtension; |
| 22 | + |
| 23 | +impl<'c, H: HugrView> CodegenExtension<'c, H> for FloatTypesCodegenExtension { |
| 24 | + fn extension(&self) -> hugr::extension::ExtensionId { |
| 25 | + float_types::EXTENSION_ID |
| 26 | + } |
| 27 | + |
| 28 | + fn llvm_type( |
| 29 | + &self, |
| 30 | + context: &crate::types::TypingSession<'c, H>, |
| 31 | + hugr_type: &hugr::types::CustomType, |
| 32 | + ) -> anyhow::Result<inkwell::types::BasicTypeEnum<'c>> { |
| 33 | + if hugr_type == &FLOAT64_CUSTOM_TYPE { |
| 34 | + Ok(context.iw_context().f64_type().as_basic_type_enum()) |
| 35 | + } else { |
| 36 | + Err(anyhow!( |
| 37 | + "FloatCodegenExtension: Unsupported type: {}", |
| 38 | + hugr_type |
| 39 | + )) |
| 40 | + } |
| 41 | + } |
| 42 | + |
| 43 | + fn emitter<'a>( |
| 44 | + &self, |
| 45 | + _context: &'a mut crate::emit::func::EmitFuncContext<'c, H>, |
| 46 | + ) -> Box<dyn crate::emit::EmitOp<'c, hugr::ops::CustomOp, H> + 'a> { |
| 47 | + Box::new(NullEmitLlvm) |
| 48 | + } |
| 49 | + |
| 50 | + fn supported_consts(&self) -> HashSet<TypeId> { |
| 51 | + [TypeId::of::<ConstF64>()].into_iter().collect() |
| 52 | + } |
| 53 | + |
| 54 | + fn load_constant( |
| 55 | + &self, |
| 56 | + context: &mut EmitFuncContext<'c, H>, |
| 57 | + konst: &dyn hugr::ops::constant::CustomConst, |
| 58 | + ) -> Result<Option<BasicValueEnum<'c>>> { |
| 59 | + let Some(k) = konst.downcast_ref::<ConstF64>() else { |
| 60 | + return Ok(None); |
| 61 | + }; |
| 62 | + let ty: FloatType<'c> = context.llvm_type(&k.get_type())?.try_into().unwrap(); |
| 63 | + Ok(Some(ty.const_float(k.value()).as_basic_value_enum())) |
| 64 | + } |
| 65 | +} |
| 66 | + |
| 67 | +struct FloatOpsCodegenExtension; |
| 68 | + |
| 69 | +impl<'c, H: HugrView> CodegenExtension<'c, H> for FloatOpsCodegenExtension { |
| 70 | + fn extension(&self) -> hugr::extension::ExtensionId { |
| 71 | + float_ops::EXTENSION_ID |
| 72 | + } |
| 73 | + |
| 74 | + fn llvm_type( |
| 75 | + &self, |
| 76 | + _context: &crate::types::TypingSession<'c, H>, |
| 77 | + hugr_type: &hugr::types::CustomType, |
| 78 | + ) -> anyhow::Result<inkwell::types::BasicTypeEnum<'c>> { |
| 79 | + Err(anyhow!( |
| 80 | + "FloatOpsCodegenExtension: unsupported type: {hugr_type}" |
| 81 | + )) |
| 82 | + } |
| 83 | + |
| 84 | + fn emitter<'a>( |
| 85 | + &self, |
| 86 | + context: &'a mut crate::emit::func::EmitFuncContext<'c, H>, |
| 87 | + ) -> Box<dyn crate::emit::EmitOp<'c, hugr::ops::CustomOp, H> + 'a> { |
| 88 | + Box::new(FloatOpEmitter(context)) |
| 89 | + } |
| 90 | +} |
| 91 | + |
| 92 | +// we allow dead code for now, but once we implement the emitter, we should |
| 93 | +// remove this |
| 94 | +#[allow(dead_code)] |
| 95 | +struct FloatOpEmitter<'c, 'd, H: HugrView>(&'d mut EmitFuncContext<'c, H>); |
| 96 | + |
| 97 | +impl<'c, H: HugrView> EmitOp<'c, CustomOp, H> for FloatOpEmitter<'c, '_, H> { |
| 98 | + fn emit(&mut self, args: EmitOpArgs<'c, CustomOp, H>) -> Result<()> { |
| 99 | + use hugr::ops::NamedOp; |
| 100 | + let name = args.node().name(); |
| 101 | + // This looks strange now, but we will add cases for ops piecemeal, as |
| 102 | + // in the analgous match expression in `IntOpEmitter`. |
| 103 | + #[allow(clippy::match_single_binding)] |
| 104 | + match name.as_str() { |
| 105 | + n => Err(anyhow!("FloatOpEmitter: unknown op: {n}")), |
| 106 | + } |
| 107 | + } |
| 108 | +} |
| 109 | + |
| 110 | +pub fn add_float_extensions<H: HugrView>(cem: CodegenExtsMap<'_, H>) -> CodegenExtsMap<'_, H> { |
| 111 | + cem.add_cge(FloatTypesCodegenExtension) |
| 112 | + .add_cge(FloatOpsCodegenExtension) |
| 113 | +} |
| 114 | + |
| 115 | +impl<H: HugrView> CodegenExtsMap<'_, H> { |
| 116 | + pub fn add_float_extensions(self) -> Self { |
| 117 | + add_float_extensions(self) |
| 118 | + } |
| 119 | +} |
| 120 | + |
| 121 | +#[cfg(test)] |
| 122 | +mod test { |
| 123 | + use hugr::{ |
| 124 | + builder::{Dataflow, DataflowSubContainer}, |
| 125 | + std_extensions::arithmetic::{ |
| 126 | + float_ops::FLOAT_OPS_REGISTRY, |
| 127 | + float_types::{ConstF64, FLOAT64_TYPE}, |
| 128 | + }, |
| 129 | + }; |
| 130 | + use rstest::rstest; |
| 131 | + |
| 132 | + use super::add_float_extensions; |
| 133 | + use crate::{ |
| 134 | + check_emission, |
| 135 | + emit::test::SimpleHugrConfig, |
| 136 | + test::{llvm_ctx, TestContext}, |
| 137 | + }; |
| 138 | + |
| 139 | + #[rstest] |
| 140 | + fn const_float(mut llvm_ctx: TestContext) { |
| 141 | + llvm_ctx.add_extensions(add_float_extensions); |
| 142 | + let hugr = SimpleHugrConfig::new() |
| 143 | + .with_outs(FLOAT64_TYPE) |
| 144 | + .with_extensions(FLOAT_OPS_REGISTRY.to_owned()) |
| 145 | + .finish(|mut builder| { |
| 146 | + let c = builder.add_load_value(ConstF64::new(3.12)); |
| 147 | + builder.finish_with_outputs([c]).unwrap() |
| 148 | + }); |
| 149 | + check_emission!(hugr, llvm_ctx); |
| 150 | + } |
| 151 | +} |
0 commit comments