Skip to content

Commit a32bb40

Browse files
committed
some compilation fixes
1 parent ab2eda3 commit a32bb40

File tree

3 files changed

+12
-12
lines changed

3 files changed

+12
-12
lines changed

tflite/Cargo.toml

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,6 @@ description = "Tiny, no-nonsense, self contained, TensorFlow and ONNX inference"
77
repository = "https://github.com/snipsco/tract"
88
edition = "2021"
99

10-
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
11-
1210
[dependencies]
1311
flatbuffers = {version="23.1.21"}
1412
tract-hir = { version = "=0.20.5-pre", path = "../hir" }
15-
16-
[lib]
17-
name="tflite"
18-
path="src/lib.rs"

tflite/src/lib.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
11
mod ops;
22
mod tensors;
3+
4+
#[allow(unused_imports)]
35
mod tflite_generated;
4-
use crate::tflite_generated::tflite::Model as ModelBuffer;
6+
pub use tflite_generated::tflite;
57

8+
9+
/*
10+
use crate::tflite_generated::tflite::Model as ModelBuffer;
611
impl ModelBuffer {
712
pub fn from_file(path: P) -> Result<ModelBuffer, Error> {
813
let model_file = &*fs::read(model_file_path)?;
@@ -20,3 +25,4 @@ pub struct TFLiteModel<'model> {
2025
}
2126
2227
impl TFLite<'_> {}
28+
*/

tflite/src/tensors.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use crate::tflite_generated::{tflite::TensorType, TensorType as BufferTensorType};
1+
use crate::tflite_generated::tflite::{TensorType, TensorType as BufferTensorType};
22
#[cfg(feature = "complex")]
33
use num_complex::Complex;
44
use tract_hir::internal::*;
@@ -8,15 +8,15 @@ impl TryFrom<BufferTensorType> for DatumType {
88
fn try_from(t: BufferTensorType) -> TractResult<DatumType> {
99
Ok(match t {
1010
BufferTensorType::FLOAT32 => DatumType::F32,
11-
BUfferTensorType::FLOAT16 => DatumType::F16,
11+
BufferTensorType::FLOAT16 => DatumType::F16,
1212
BufferTensorType::INT32 => DatumType::I32,
1313
BufferTensorType::UINT8 => DatumType::U8,
1414
BufferTensorType::INT64 => DatumType::I64,
1515
BufferTensorType::STRING => DatumType::String,
1616
BufferTensorType::BOOL => DatumType::Bool,
1717
BufferTensorType::INT16 => DatumType::I16,
1818
#[cfg(feature = "complex")]
19-
BufferTensorType::COMPLEX64 => DatumType::ComplexF64,
19+
BufferTensorType::COMPLEX64 => DatumType::ComplexF64, // TODO check this
2020
TensorType::INT8 => DatumType::I8,
2121
TensorType::FLOAT64 => DatumType::F64,
2222
//TensorType::COMPLEX128 => DatumType::ComplexF64,
@@ -32,7 +32,7 @@ impl TryFrom<BufferTensorType> for DatumType {
3232
}
3333
}
3434

35-
fn create_tensor(shape: Vec<usize>, dt: DatumType, data: &[u8]) -> TractResult<Tensor> {
35+
fn create_tensor(dt: DatumType, shape: &[usize], data: &[u8]) -> TractResult<Tensor> {
3636
unsafe {
3737
match dt {
3838
DatumType::U8 => Tensor::from_raw::<u8>(&shape, data),
@@ -47,7 +47,7 @@ fn create_tensor(shape: Vec<usize>, dt: DatumType, data: &[u8]) -> TractResult<T
4747
DatumType::F32 => Tensor::from_raw::<f32>(&shape, data),
4848
DatumType::F64 => Tensor::from_raw::<f64>(&shape, data),
4949
#[cfg(feature = "complex")]
50-
DatumType::ComplexF64 => Tensor::from_raw::<Complex<f64>>(&shape, data),
50+
DatumType::ComplexF64 => Tensor::from_raw::<Complex<f64>>(&shape, data), // TODO check this
5151
DatumType::Bool => Ok(Tensor::from_raw::<u8>(&shape, data)?
5252
.into_array::<u8>()?
5353
.mapv(|x| x != 0)

0 commit comments

Comments
 (0)