Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ members = [
"examples/keras-tract-tf2",
"examples/nnef-dump-mobilenet-v2",
"examples/nnef-mobilenet-v2",
"examples/nnef-mobilenet-v2-api",
"examples/onnx-mobilenet-v2",
"examples/pytorch-albert-v2",
"examples/pytorch-resnet",
Expand Down
147 changes: 128 additions & 19 deletions api/ffi/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ use anyhow::{Context, Result};
use std::cell::RefCell;
use std::ffi::{CStr, CString, c_char, c_void};
use tract_api::{
AsFact, DatumType, InferenceModelInterface, ModelInterface, NnefInterface, OnnxInterface,
RunnableInterface, StateInterface, ValueInterface,
AsFact, DatumType, DimInterface, FactInterface, InferenceModelInterface, ModelInterface,
NnefInterface, OnnxInterface, RunnableInterface, StateInterface, ValueInterface,
};
use tract_rs::{State, Value};

Expand Down Expand Up @@ -116,23 +116,6 @@ pub unsafe extern "C" fn tract_nnef_create(nnef: *mut *mut TractNnef) -> TRACT_R
})
}

#[unsafe(no_mangle)]
pub unsafe extern "C" fn tract_nnef_transform_model(
nnef: *const TractNnef,
model: *mut TractModel,
transform_spec: *const i8,
) -> TRACT_RESULT {
wrap(|| unsafe {
check_not_null!(nnef, model, transform_spec);
let transform_spec = CStr::from_ptr(transform_spec as _).to_str()?;
(*nnef)
.0
.transform_model(&mut (*model).0, transform_spec)
.with_context(|| format!("performing transform {transform_spec:?}"))?;
Ok(())
})
}

#[unsafe(no_mangle)]
pub unsafe extern "C" fn tract_nnef_enable_tract_core(nnef: *mut TractNnef) -> TRACT_RESULT {
wrap(|| unsafe {
Expand Down Expand Up @@ -213,6 +196,27 @@ pub unsafe extern "C" fn tract_nnef_load(
})
}

/// Parse and load an NNEF buffer as a tract TypedModel.
///
/// `data` is a buffer pointer
/// `len` ise the buffer len
#[unsafe(no_mangle)]
pub unsafe extern "C" fn tract_nnef_load_buffer(
nnef: *const TractNnef,
data: *const c_void,
len: usize,
model: *mut *mut TractModel,
) -> TRACT_RESULT {
wrap(|| unsafe {
check_not_null!(nnef, model, data);
*model = std::ptr::null_mut();
let slice = std::slice::from_raw_parts(data as *const u8, len);
let m = Box::new(TractModel((*nnef).0.load_buffer(slice)?));
*model = Box::into_raw(m);
Ok(())
})
}

/// Dump a TypedModel as a NNEF tar file.
///
/// `path` is a null-terminated utf-8 string pointer to the `.tar` file to be created.
Expand Down Expand Up @@ -309,6 +313,27 @@ pub unsafe extern "C" fn tract_onnx_load(
})
}

/// Parse and load an ONNX buffer as a tract InferenceModel.
///
/// `data` is a buffer pointer
/// `len` ise the buffer len
#[unsafe(no_mangle)]
pub unsafe extern "C" fn tract_onnx_load_buffer(
onnx: *const TractOnnx,
data: *const c_void,
len: usize,
model: *mut *mut TractInferenceModel,
) -> TRACT_RESULT {
wrap(|| unsafe {
check_not_null!(onnx, model, data);
*model = std::ptr::null_mut();
let slice = std::slice::from_raw_parts(data as *const u8, len);
let m = Box::new(TractInferenceModel((*onnx).0.load_buffer(slice)?));
*model = Box::into_raw(m);
Ok(())
})
}

// INFERENCE MODEL
pub struct TractInferenceModel(tract_rs::InferenceModel);

Expand Down Expand Up @@ -1103,6 +1128,31 @@ pub unsafe extern "C" fn tract_fact_parse(
})
}

/// Gets the rank (aka number of axes/dimensions) of a fact.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn tract_fact_rank(fact: *const TractFact, rank: *mut usize) -> TRACT_RESULT {
wrap(|| unsafe {
check_not_null!(fact, rank);
*rank = (*fact).0.rank()?;
Ok(())
})
}

/// Extract the dimension from one dimension of the fact.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn tract_fact_dim(
fact: *const TractFact,
axis: usize,
dim: *mut *mut TractDim,
) -> TRACT_RESULT {
wrap(|| unsafe {
check_not_null!(fact, dim);
let d = (*fact).0.dim(axis)?;
*dim = Box::into_raw(Box::new(TractDim(d)));
Ok(())
})
}

/// Write a fact as its specification string.
///
/// The returned string must be freed by the caller using tract_free_cstring.
Expand Down Expand Up @@ -1181,6 +1231,65 @@ pub unsafe extern "C" fn tract_inference_fact_destroy(
release!(fact)
}

/// Dim
pub struct TractDim(tract_rs::Dim);

#[unsafe(no_mangle)]
pub unsafe extern "C" fn tract_dim_eval(
dim: *const TractDim,
nb_symbols: usize,
symbols: *const *const i8,
values: *const i64,
result: *mut *mut TractDim,
) -> TRACT_RESULT {
wrap(|| unsafe {
check_not_null!(dim, symbols, values, result);
let mut table = vec![];
for i in 0..nb_symbols {
let name = CStr::from_ptr(*symbols.add(i) as _)
.to_str()
.with_context(|| {
format!("failed to parse symbol name for {i}th symbol (not utf8)")
})?
.to_owned();
table.push((name, *values.add(i)));
}
let r = (*dim).0.eval(table)?;
*result = Box::into_raw(Box::new(TractDim(r)));
Ok(())
})
}

#[unsafe(no_mangle)]
pub unsafe extern "C" fn tract_dim_to_int64(fact: *const TractDim, i: *mut i64) -> TRACT_RESULT {
wrap(|| unsafe {
check_not_null!(fact, i);
*i = (*fact).0.to_int64()?;
Ok(())
})
}

/// Write a dim as its specification string.
///
/// The returned string must be freed by the caller using tract_free_cstring.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn tract_dim_dump(
dim: *const TractDim,
spec: *mut *mut c_char,
) -> TRACT_RESULT {
wrap(|| unsafe {
check_not_null!(dim, spec);
*spec = CString::new((*dim).0.to_string())?.into_raw();
Ok(())
})
}

/// Destroy a dim.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn tract_dim_destroy(dim: *mut *mut TractDim) -> TRACT_RESULT {
release!(dim)
}

// MISC

// HELPERS
Expand Down
71 changes: 67 additions & 4 deletions api/proxy/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,10 @@ impl NnefInterface for Nnef {
Ok(Model(model))
}

fn transform_model(&self, model: &mut Self::Model, transform_spec: &str) -> Result<()> {
let t = CString::new(transform_spec)?;
check!(sys::tract_nnef_transform_model(self.0, model.0, t.as_ptr()))
fn load_buffer(&self, data: &[u8]) -> Result<Model> {
let mut model = null_mut();
check!(sys::tract_nnef_load_buffer(self.0, data.as_ptr() as _, data.len(), &mut model))?;
Ok(Model(model))
}

fn enable_tract_core(&mut self) -> Result<()> {
Expand Down Expand Up @@ -136,6 +137,12 @@ impl OnnxInterface for Onnx {
check!(sys::tract_onnx_load(self.0, path.as_ptr(), &mut model))?;
Ok(InferenceModel(model))
}

fn load_buffer(&self, data: &[u8]) -> Result<InferenceModel> {
let mut model = null_mut();
check!(sys::tract_onnx_load_buffer(self.0, data.as_ptr() as _, data.len(), &mut model))?;
Ok(InferenceModel(model))
}
}

// INFERENCE MODEL
Expand Down Expand Up @@ -600,7 +607,21 @@ impl Fact {
}
}

impl FactInterface for Fact {}
impl FactInterface for Fact {
type Dim = Dim;

fn rank(&self) -> Result<usize> {
let mut rank = 0;
check!(sys::tract_fact_rank(self.0, &mut rank))?;
Ok(rank)
}

fn dim(&self, axis: usize) -> Result<Self::Dim> {
let mut ptr = null_mut();
check!(sys::tract_fact_dim(self.0, axis, &mut ptr))?;
Ok(Dim(ptr))
}
}

impl std::fmt::Display for Fact {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
Expand Down Expand Up @@ -658,3 +679,45 @@ impl std::fmt::Display for InferenceFact {

as_inference_fact_impl!(InferenceModel, InferenceFact);
as_fact_impl!(Model, Fact);

// Dim
wrapper!(Dim, TractDim, tract_dim_destroy);

impl Dim {
fn dump(&self) -> Result<String> {
let mut ptr = null_mut();
check!(sys::tract_dim_dump(self.0, &mut ptr))?;
unsafe {
let s = CStr::from_ptr(ptr).to_owned();
sys::tract_free_cstring(ptr);
Ok(s.to_str()?.to_owned())
}
}
}

impl DimInterface for Dim {
fn eval(&self, values: impl IntoIterator<Item = (impl AsRef<str>, i64)>) -> Result<Self> {
let (names, values): (Vec<_>, Vec<_>) = values.into_iter().unzip();
let c_strings: Vec<CString> =
names.into_iter().map(|a| Ok(CString::new(a.as_ref())?)).collect::<Result<_>>()?;
let ptrs: Vec<_> = c_strings.iter().map(|cs| cs.as_ptr()).collect();
let mut ptr = null_mut();
check!(sys::tract_dim_eval(self.0, ptrs.len(), ptrs.as_ptr(), values.as_ptr(), &mut ptr))?;
Ok(Dim(ptr))
}

fn to_int64(&self) -> Result<i64> {
let mut i = 0;
check!(sys::tract_dim_to_int64(self.0, &mut i))?;
Ok(i)
}
}

impl std::fmt::Display for Dim {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self.dump() {
Ok(s) => f.write_str(&s),
Err(_) => Err(std::fmt::Error),
}
}
}
1 change: 1 addition & 0 deletions api/proxy/sys/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::path::PathBuf;

fn main() {
println!("cargo:rerun-if-env-changed=TRACT_DYLIB_SEARCH_PATH");
println!("cargo:rerun-if-env-changed=tract.h");
if let Ok(path) = std::env::var("TRACT_DYLIB_SEARCH_PATH") {
println!("cargo:rustc-link-search={path}");
}
Expand Down
Loading
Loading