Skip to content

Commit 00067c8

Browse files
committed
Add serde support for Scalar enum
1 parent 0819e73 commit 00067c8

File tree

1 file changed

+44
-13
lines changed

1 file changed

+44
-13
lines changed

src/core/defines.rs

Lines changed: 44 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -447,16 +447,27 @@ pub const MERSENNE: RandomEngineType = RandomEngineType::MERSENNE_GP11213;
447447
/// Default RandomEngine that defaults to [PHILOX](./constant.PHILOX.html)
448448
pub const DEFAULT_RANDOM_ENGINE: RandomEngineType = PHILOX;
449449

450+
#[cfg(feature = "afserde")]
451+
#[derive(Serialize, Deserialize)]
452+
#[serde(remote = "Complex")]
453+
struct ComplexDef<T> {
454+
re: T,
455+
im: T,
456+
}
457+
450458
/// Scalar value types
451459
#[derive(Clone, Copy, Debug, PartialEq)]
460+
#[cfg_attr(feature = "afserde", derive(Serialize, Deserialize))]
452461
pub enum Scalar {
453462
/// 32 bit float
454463
F32(f32),
455464
/// 32 bit complex float
465+
#[cfg_attr(feature = "afserde", serde(with = "ComplexDef"))]
456466
C32(Complex<f32>),
457467
/// 64 bit float
458468
F64(f64),
459469
/// 64 bit complex float
470+
#[cfg_attr(feature = "afserde", serde(with = "ComplexDef"))]
460471
C64(Complex<f64>),
461472
/// 8 bit boolean
462473
B8(bool),
@@ -592,18 +603,38 @@ pub enum CublasMathMode {
592603
#[cfg(test)]
593604
mod tests {
594605
#[cfg(feature = "afserde")]
595-
#[test]
596-
fn test_enum_serde() {
597-
use super::AfError;
598-
599-
let err_code = AfError::ERR_NO_MEM;
600-
let serd = match serde_json::to_string(&err_code) {
601-
Ok(serialized_str) => serialized_str,
602-
Err(e) => e.to_string(),
603-
};
604-
assert_eq!(serd, "\"ERR_NO_MEM\"");
605-
606-
let deserd: AfError = serde_json::from_str(&serd).unwrap();
607-
assert_eq!(deserd, AfError::ERR_NO_MEM);
606+
mod serde_tests {
607+
#[test]
608+
fn test_enum_serde() {
609+
use super::super::AfError;
610+
611+
let err_code = AfError::ERR_NO_MEM;
612+
let serd = match serde_json::to_string(&err_code) {
613+
Ok(serialized_str) => serialized_str,
614+
Err(e) => e.to_string(),
615+
};
616+
assert_eq!(serd, "\"ERR_NO_MEM\"");
617+
618+
let deserd: AfError = serde_json::from_str(&serd).unwrap();
619+
assert_eq!(deserd, AfError::ERR_NO_MEM);
620+
}
621+
622+
#[test]
623+
fn test_scalar_serde() {
624+
use super::super::Scalar;
625+
use num::Complex;
626+
627+
let scalar = Scalar::C32(Complex {
628+
re: 1.0f32,
629+
im: 1.0f32,
630+
});
631+
let serd = match serde_json::to_string(&scalar) {
632+
Ok(serialized_str) => serialized_str,
633+
Err(e) => e.to_string(),
634+
};
635+
636+
let deserd: Scalar = serde_json::from_str(&serd).unwrap();
637+
assert_eq!(deserd, scalar);
638+
}
608639
}
609640
}

0 commit comments

Comments
 (0)