Skip to content

Commit

Permalink
Add debug_print(msg) util (#510)
Browse files Browse the repository at this point in the history
* update melior to 0.17

* add debug print  util

* fix
  • Loading branch information
edg-l authored Apr 16, 2024
1 parent 2264ce4 commit db96ffc
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 2 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ educe = "0.5.11"
id-arena = "2.2"
itertools = "0.12"
lazy_static = "1.4"
libc = "0.2.147"
libc = "0.2.153"
llvm-sys = "170.0.0"
melior = { version = "0.17.0", features = ["ods-dialects"] }
mlir-sys = "0.2.1"
Expand Down
130 changes: 129 additions & 1 deletion src/metadata/debug_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,11 @@

use crate::error::Result;
use melior::{
dialect::{arith, func, llvm},
dialect::{
arith, func,
llvm::{self, r#type::opaque_pointer},
ods,
},
ir::{
attribute::{FlatSymbolRefAttribute, IntegerAttribute, StringAttribute, TypeAttribute},
operation::OperationBuilder,
Expand All @@ -100,6 +104,7 @@ use std::collections::HashSet;
#[derive(Clone, Copy, Debug, Hash, Eq, PartialEq)]
enum DebugBinding {
BreakpointMarker,
DebugPrint,
PrintI1,
PrintI8,
PrintI32,
Expand Down Expand Up @@ -150,6 +155,108 @@ impl DebugUtils {
Ok(())
}

/// Prints the given &str.
pub fn debug_print<'c, 'a>(
&mut self,
context: &'c Context,
module: &Module,
block: &'a Block<'c>,
message: &str,
location: Location<'c>,
) -> Result<()>
where
'c: 'a,
{
if self.active_map.insert(DebugBinding::DebugPrint) {
module.body().append_operation(func::func(
context,
StringAttribute::new(context, "__debug__debug_print_impl"),
TypeAttribute::new(
FunctionType::new(
context,
&[
opaque_pointer(context),
IntegerType::new(context, 64).into(),
],
&[],
)
.into(),
),
Region::new(),
&[(
Identifier::new(context, "sym_visibility"),
StringAttribute::new(context, "private").into(),
)],
Location::unknown(context),
));
}

let ty = llvm::r#type::array(
IntegerType::new(context, 8).into(),
message.len().try_into().unwrap(),
);

let k1 = block
.append_operation(arith::constant(
context,
IntegerAttribute::new(IntegerType::new(context, 64).into(), 1).into(),
location,
))
.result(0)?
.into();

let ptr = block
.append_operation(
{
let mut op = ods::llvm::alloca(context, opaque_pointer(context), k1, location);
op.set_elem_type(TypeAttribute::new(ty));
op
}
.into(),
)
.result(0)?
.into();

let msg = block
.append_operation(
ods::llvm::mlir_constant(
context,
llvm::r#type::array(
IntegerType::new(context, 8).into(),
message.len().try_into().unwrap(),
),
StringAttribute::new(context, message).into(),
location,
)
.into(),
)
.result(0)?
.into();
block.append_operation(ods::llvm::store(context, msg, ptr, location).into());
let len = block
.append_operation(arith::constant(
context,
IntegerAttribute::new(
IntegerType::new(context, 64).into(),
message.len().try_into().unwrap(),
)
.into(),
location,
))
.result(0)?
.into();

block.append_operation(func::call(
context,
FlatSymbolRefAttribute::new(context, "__debug__debug_print_impl"),
&[ptr, len],
&[],
location,
));

Ok(())
}

pub fn debug_breakpoint_trap<'c, 'a>(
&mut self,
block: &'a Block<'c>,
Expand Down Expand Up @@ -534,6 +641,15 @@ impl DebugUtils {
}
}

if self.active_map.contains(&DebugBinding::DebugPrint) {
unsafe {
engine.register_symbol(
"__debug__debug_print_impl",
debug_print_impl as *const fn(*const std::ffi::c_char) -> () as *mut (),
);
}
}

if self.active_map.contains(&DebugBinding::PrintI1) {
unsafe {
engine.register_symbol(
Expand Down Expand Up @@ -603,6 +719,18 @@ extern "C" fn breakpoint_marker_impl() {
println!("[DEBUG] Breakpoint marker.");
}

extern "C" fn debug_print_impl(message: *const std::ffi::c_char, len: u64) {
// llvm constant strings are not zero terminated
let slice = unsafe { std::slice::from_raw_parts(message as *const u8, len as usize) };
let message = std::str::from_utf8(slice);

if let Ok(message) = message {
println!("[DEBUG] Message: {}", message);
} else {
println!("[DEBUG] Message: {:?}", message);
}
}

extern "C" fn print_i1_impl(value: bool) {
println!("[DEBUG] {value}");
}
Expand Down

0 comments on commit db96ffc

Please sign in to comment.