Skip to content

Commit 05a4da5

Browse files
committed
Add optional serde serialization support
- Update ci to run serde tests - Add serialization support for Enums except the enum `arrayfire::Scalar` - Structs with serde support added - [x] Array - [x] Dim4 - [x] Seq - [x] RandomEngine - Structs without serde support - Features - currently not possible as `af_features` can't be recreated from individual `af_arrays` with current upstream API - Indexer - not possible with current API. Also, any subarray when fetched to host for serialization results in separate owned copy this making serde support for this unnecessary. - Callback - Event - Window
1 parent 97d097b commit 05a4da5

File tree

8 files changed

+275
-4
lines changed

8 files changed

+275
-4
lines changed

.github/workflows/ci.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,8 @@ jobs:
4848
export AF_PATH=${GITHUB_WORKSPACE}/afbin
4949
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:${AF_PATH}/lib64
5050
echo "Using cargo version: $(cargo --version)"
51-
cargo build --all
52-
cargo test --no-fail-fast
51+
cargo build --all --all-features
52+
cargo test --no-fail-fast --all-features
5353
5454
format:
5555
name: Format Check

Cargo.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,15 +46,19 @@ statistics = []
4646
vision = []
4747
default = ["algorithm", "arithmetic", "blas", "data", "indexing", "graphics", "image", "lapack",
4848
"ml", "macros", "random", "signal", "sparse", "statistics", "vision"]
49+
afserde = ["serde"]
4950

5051
[dependencies]
5152
libc = "0.2"
5253
num = "0.2"
5354
lazy_static = "1.0"
5455
half = "1.5.0"
56+
serde = { version = "1.0", features = ["derive"], optional = true }
5557

5658
[dev-dependencies]
5759
half = "1.5.0"
60+
serde_json = "1.0"
61+
bincode = "1.3"
5862

5963
[build-dependencies]
6064
serde_json = "1.0"

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ Only, Major(M) & Minor(m) version numbers need to match. *p1* and *p2* are patch
1616

1717
## Supported platforms
1818

19-
Linux, Windows and OSX. Rust 1.15.1 or higher is required.
19+
Linux, Windows and OSX. Rust 1.31 or newer is required.
2020

2121
## Use from Crates.io [![][6]][7] [![][8]][9]
2222

src/core/array.rs

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -851,6 +851,73 @@ pub fn is_eval_manual() -> bool {
851851
}
852852
}
853853

854+
#[cfg(feature = "afserde")]
855+
mod afserde {
856+
// Reimport required from super scope
857+
use super::{Array, DType, Dim4, HasAfEnum};
858+
859+
use serde::de::{Deserializer, Error, Unexpected};
860+
use serde::ser::Serializer;
861+
use serde::{Deserialize, Serialize};
862+
863+
#[derive(Debug, Serialize, Deserialize)]
864+
struct ArrayOnHost<T: HasAfEnum + std::fmt::Debug> {
865+
dtype: DType,
866+
shape: Dim4,
867+
data: Vec<T>,
868+
}
869+
870+
/// Serialize Implementation of Array
871+
impl<T> Serialize for Array<T>
872+
where
873+
T: std::default::Default + std::clone::Clone + Serialize + HasAfEnum + std::fmt::Debug,
874+
{
875+
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
876+
where
877+
S: Serializer,
878+
{
879+
let mut vec = vec![T::default(); self.elements()];
880+
self.host(&mut vec);
881+
let arr_on_host = ArrayOnHost {
882+
dtype: self.get_type(),
883+
shape: self.dims().clone(),
884+
data: vec,
885+
};
886+
arr_on_host.serialize(serializer)
887+
}
888+
}
889+
890+
/// Deserialize Implementation of Array
891+
impl<'de, T> Deserialize<'de> for Array<T>
892+
where
893+
T: Deserialize<'de> + HasAfEnum + std::fmt::Debug,
894+
{
895+
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
896+
where
897+
D: Deserializer<'de>,
898+
{
899+
match ArrayOnHost::<T>::deserialize(deserializer) {
900+
Ok(arr_on_host) => {
901+
let read_dtype = arr_on_host.dtype;
902+
let expected_dtype = T::get_af_dtype();
903+
if expected_dtype != read_dtype {
904+
let error_msg = format!(
905+
"data type is {:?}, deserialized type is {:?}",
906+
expected_dtype, read_dtype
907+
);
908+
return Err(Error::invalid_value(Unexpected::Enum, &error_msg.as_str()));
909+
}
910+
Ok(Array::<T>::new(
911+
&arr_on_host.data,
912+
arr_on_host.shape.clone(),
913+
))
914+
}
915+
Err(err) => Err(err),
916+
}
917+
}
918+
}
919+
}
920+
854921
#[cfg(test)]
855922
mod tests {
856923
use super::super::array::print;
@@ -1082,4 +1149,37 @@ mod tests {
10821149
// 8.0000 8.0000 8.0000
10831150
// ANCHOR_END: accum_using_channel
10841151
}
1152+
1153+
#[cfg(feature = "afserde")]
1154+
mod serde_tests {
1155+
use super::super::Array;
1156+
use crate::algorithm::sum_all;
1157+
use crate::randu;
1158+
1159+
#[test]
1160+
fn array_serde_json() {
1161+
let input = randu!(u8; 2, 2);
1162+
let serd = match serde_json::to_string(&input) {
1163+
Ok(serialized_str) => serialized_str,
1164+
Err(e) => e.to_string(),
1165+
};
1166+
1167+
let deserd: Array<u8> = serde_json::from_str(&serd).unwrap();
1168+
1169+
assert_eq!(sum_all(&(input - deserd)), (0u32, 0u32));
1170+
}
1171+
1172+
#[test]
1173+
fn array_serde_bincode() {
1174+
let input = randu!(u8; 2, 2);
1175+
let encoded = match bincode::serialize(&input) {
1176+
Ok(encoded) => encoded,
1177+
Err(_) => vec![],
1178+
};
1179+
1180+
let decoded: Array<u8> = bincode::deserialize(&encoded).unwrap();
1181+
1182+
assert_eq!(sum_all(&(input - decoded)), (0u32, 0u32));
1183+
}
1184+
}
10851185
}

0 commit comments

Comments
 (0)