1- use crate :: tflite_generated:: { tflite:: TensorType , TensorType as BufferTensorType } ;
1+ use crate :: tflite_generated:: tflite:: { TensorType , TensorType as BufferTensorType } ;
22#[ cfg( feature = "complex" ) ]
33use num_complex:: Complex ;
44use tract_hir:: internal:: * ;
@@ -8,15 +8,15 @@ impl TryFrom<BufferTensorType> for DatumType {
88 fn try_from ( t : BufferTensorType ) -> TractResult < DatumType > {
99 Ok ( match t {
1010 BufferTensorType :: FLOAT32 => DatumType :: F32 ,
11- BUfferTensorType :: FLOAT16 => DatumType :: F16 ,
11+ BufferTensorType :: FLOAT16 => DatumType :: F16 ,
1212 BufferTensorType :: INT32 => DatumType :: I32 ,
1313 BufferTensorType :: UINT8 => DatumType :: U8 ,
1414 BufferTensorType :: INT64 => DatumType :: I64 ,
1515 BufferTensorType :: STRING => DatumType :: String ,
1616 BufferTensorType :: BOOL => DatumType :: Bool ,
1717 BufferTensorType :: INT16 => DatumType :: I16 ,
1818 #[ cfg( feature = "complex" ) ]
19- BufferTensorType :: COMPLEX64 => DatumType :: ComplexF64 ,
19+ BufferTensorType :: COMPLEX64 => DatumType :: ComplexF64 , // TODO check this
2020 TensorType :: INT8 => DatumType :: I8 ,
2121 TensorType :: FLOAT64 => DatumType :: F64 ,
2222 //TensorType::COMPLEX128 => DatumType::ComplexF64,
@@ -32,7 +32,7 @@ impl TryFrom<BufferTensorType> for DatumType {
3232 }
3333}
3434
35- fn create_tensor ( shape : Vec < usize > , dt : DatumType , data : & [ u8 ] ) -> TractResult < Tensor > {
35+ fn create_tensor ( dt : DatumType , shape : & [ usize ] , data : & [ u8 ] ) -> TractResult < Tensor > {
3636 unsafe {
3737 match dt {
3838 DatumType :: U8 => Tensor :: from_raw :: < u8 > ( & shape, data) ,
@@ -47,7 +47,7 @@ fn create_tensor(shape: Vec<usize>, dt: DatumType, data: &[u8]) -> TractResult<T
4747 DatumType :: F32 => Tensor :: from_raw :: < f32 > ( & shape, data) ,
4848 DatumType :: F64 => Tensor :: from_raw :: < f64 > ( & shape, data) ,
4949 #[ cfg( feature = "complex" ) ]
50- DatumType :: ComplexF64 => Tensor :: from_raw :: < Complex < f64 > > ( & shape, data) ,
50+ DatumType :: ComplexF64 => Tensor :: from_raw :: < Complex < f64 > > ( & shape, data) , // TODO check this
5151 DatumType :: Bool => Ok ( Tensor :: from_raw :: < u8 > ( & shape, data) ?
5252 . into_array :: < u8 > ( ) ?
5353 . mapv ( |x| x != 0 )
0 commit comments