Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael-F-Bryan committed May 3, 2022
1 parent 8df6ce0 commit 8518d3b
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 33 deletions.
98 changes: 68 additions & 30 deletions rune/modulo/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@ use std::fmt::Display;

use crate::{
hotg_proc_blocks::BufferExt,
proc_block_v1::{BadArgumentReason, GraphError, InvalidArgument, KernelError},
proc_block_v1::{
BadArgumentReason, GraphError, InvalidArgument, KernelError,
},
runtime_v1::{
ArgumentMetadata, Dimensions, ElementType, GraphContext, KernelContext, Metadata,
TensorMetadata, TensorParam, TensorResult,
ArgumentMetadata, Dimensions, ElementType, GraphContext, KernelContext,
Metadata, TensorMetadata, TensorParam, TensorResult,
},
};
use num_traits::{FromPrimitive, ToPrimitive};
Expand All @@ -19,14 +21,16 @@ pub struct ProcBlockV1;

impl proc_block_v1::ProcBlockV1 for ProcBlockV1 {
fn register_metadata() {
let metadata = Metadata::new(env!("CARGO_PKG_NAME"), env!("CARGO_PKG_VERSION"));
let metadata =
Metadata::new(env!("CARGO_PKG_NAME"), env!("CARGO_PKG_VERSION"));
metadata.set_description(env!("CARGO_PKG_DESCRIPTION"));

let modulo = ArgumentMetadata::new("modulo");
modulo.add_hint(&runtime_v1::non_negative_number());
metadata.add_argument(&modulo);
let element_type = ArgumentMetadata::new("element_type");
element_type.set_description("The type of tensor this proc-block will accept");
element_type
.set_description("The type of tensor this proc-block will accept");
element_type.set_default_value("f64");
element_type.add_hint(&runtime_v1::interpret_as_string_in_enum(&[
"u8", "i8", "u16", "i16", "u32", "i32", "f32", "u64", "i64", "f64",
Expand All @@ -43,11 +47,13 @@ impl proc_block_v1::ProcBlockV1 for ProcBlockV1 {
}

fn graph(node_id: String) -> Result<(), GraphError> {
let ctx = GraphContext::for_node(&node_id)
.ok_or_else(|| GraphError::Other("Unable to load the graph context".to_string()))?;
let ctx = GraphContext::for_node(&node_id).ok_or_else(|| {
GraphError::Other("Unable to load the graph context".to_string())
})?;

// make sure the modulus is valid
let _ = get_modulus(|n| ctx.get_argument(n)).map_err(GraphError::InvalidArgument)?;
let _ = get_modulus(|n| ctx.get_argument(n))
.map_err(GraphError::InvalidArgument)?;

let element_type = match ctx.get_argument("element_type").as_deref() {
Some("u8") => ElementType::U8,
Expand All @@ -56,16 +62,18 @@ impl proc_block_v1::ProcBlockV1 for ProcBlockV1 {
Some("i16") => ElementType::I16,
Some("u32") => ElementType::U32,
Some("i32") => ElementType::I32,
Some("f32") => ElementType::F32,
Some("f32") => ElementType::Float32,
Some("u64") => ElementType::U64,
Some("i64") => ElementType::I64,
Some("f64") | None => ElementType::F64,
Some("f64") | None => ElementType::Float64,
Some(_) => {
return Err(GraphError::InvalidArgument(InvalidArgument {
name: "element_type".to_string(),
reason: BadArgumentReason::InvalidValue("Unsupported element type".to_string()),
reason: BadArgumentReason::InvalidValue(
"Unsupported element type".to_string(),
),
}))
}
},
};

ctx.add_input_tensor("input", element_type, Dimensions::Dynamic);
Expand All @@ -75,10 +83,12 @@ impl proc_block_v1::ProcBlockV1 for ProcBlockV1 {
}

fn kernel(node_id: String) -> Result<(), KernelError> {
let ctx = KernelContext::for_node(&node_id)
.ok_or_else(|| KernelError::Other("Unable to load the kernel context".to_string()))?;
let ctx = KernelContext::for_node(&node_id).ok_or_else(|| {
KernelError::Other("Unable to load the kernel context".to_string())
})?;

let modulus = get_modulus(|n| ctx.get_argument(n)).map_err(KernelError::InvalidArgument)?;
let modulus = get_modulus(|n| ctx.get_argument(n))
.map_err(KernelError::InvalidArgument)?;

let TensorResult {
dimensions,
Expand All @@ -93,21 +103,41 @@ impl proc_block_v1::ProcBlockV1 for ProcBlockV1 {
// data variant that gets used.

match element_type {
ElementType::U8 => modulus_in_place(buffer.elements_mut::<u8>(), modulus)?,
ElementType::I8 => modulus_in_place(buffer.elements_mut::<i8>(), modulus)?,
ElementType::U16 => modulus_in_place(buffer.elements_mut::<u16>(), modulus)?,
ElementType::I16 => modulus_in_place(buffer.elements_mut::<i16>(), modulus)?,
ElementType::U32 => modulus_in_place(buffer.elements_mut::<u32>(), modulus)?,
ElementType::I32 => modulus_in_place(buffer.elements_mut::<i32>(), modulus)?,
ElementType::F32 => modulus_in_place(buffer.elements_mut::<f32>(), modulus)?,
ElementType::U64 => modulus_in_place(buffer.elements_mut::<u64>(), modulus)?,
ElementType::I64 => modulus_in_place(buffer.elements_mut::<i64>(), modulus)?,
ElementType::F64 => modulus_in_place(buffer.elements_mut::<f64>(), modulus)?,
ElementType::U8 => {
modulus_in_place(buffer.elements_mut::<u8>(), modulus)?
},
ElementType::I8 => {
modulus_in_place(buffer.elements_mut::<i8>(), modulus)?
},
ElementType::U16 => {
modulus_in_place(buffer.elements_mut::<u16>(), modulus)?
},
ElementType::I16 => {
modulus_in_place(buffer.elements_mut::<i16>(), modulus)?
},
ElementType::U32 => {
modulus_in_place(buffer.elements_mut::<u32>(), modulus)?
},
ElementType::I32 => {
modulus_in_place(buffer.elements_mut::<i32>(), modulus)?
},
ElementType::Float32 => {
modulus_in_place(buffer.elements_mut::<f32>(), modulus)?
},
ElementType::U64 => {
modulus_in_place(buffer.elements_mut::<u64>(), modulus)?
},
ElementType::I64 => {
modulus_in_place(buffer.elements_mut::<i64>(), modulus)?
},
ElementType::Float64 => {
modulus_in_place(buffer.elements_mut::<f64>(), modulus)?
},
ElementType::Utf8 => {
return Err(KernelError::Other(
"String tensors aren't supported".to_string(),
))
}
},
}

ctx.set_output_tensor(
Expand All @@ -123,7 +153,10 @@ impl proc_block_v1::ProcBlockV1 for ProcBlockV1 {
}
}

fn modulus_in_place<T>(values: &mut [T], modulus: f64) -> Result<(), KernelError>
fn modulus_in_place<T>(
values: &mut [T],
modulus: f64,
) -> Result<(), KernelError>
where
T: ToPrimitive + FromPrimitive + Copy + Display,
{
Expand All @@ -137,18 +170,23 @@ where
}

fn error(value: impl Display) -> KernelError {
KernelError::Other(format!("Unable to convert `{}` to/from a double", value))
KernelError::Other(format!(
"Unable to convert `{}` to/from a double",
value
))
}

fn get_modulus(get_argument: impl FnOnce(&str) -> Option<String>) -> Result<f64, InvalidArgument> {
fn get_modulus(
get_argument: impl FnOnce(&str) -> Option<String>,
) -> Result<f64, InvalidArgument> {
let value = match get_argument("modulus") {
Some(s) => s,
None => {
return Err(InvalidArgument {
name: "modulus".to_string(),
reason: BadArgumentReason::NotFound,
})
}
},
};

let value = value.parse::<f64>().map_err(|e| InvalidArgument {
Expand Down
6 changes: 3 additions & 3 deletions rune/runtime-v1.wit
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,10 @@ enum element-type {
%i16,
%u32,
%i32,
%f32,
%float32,
%u64,
%i64,
%f64,
%float64,
/// A string as UTF-8 encoded bytes.
utf8,
}
Expand Down Expand Up @@ -215,7 +215,7 @@ variant log-value {
null,
boolean(bool),
integer(s64),
float(f64),
float(float64),
%string(string),
}

Expand Down

0 comments on commit 8518d3b

Please sign in to comment.