Skip to content

Commit

Permalink
Making distributions comparable by deriving PartialEq. Tests included
Browse files Browse the repository at this point in the history
  • Loading branch information
Will Springer committed Feb 11, 2022
1 parent a407bdf commit 9f20df0
Show file tree
Hide file tree
Showing 22 changed files with 182 additions and 39 deletions.
7 changes: 6 additions & 1 deletion rand_distr/src/binomial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ use num_traits::Float;
/// let v = bin.sample(&mut rand::thread_rng());
/// println!("{} is from a binomial distribution", v);
/// ```
#[derive(Clone, Copy, Debug)]
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct Binomial {
/// Number of trials.
Expand Down Expand Up @@ -347,4 +347,9 @@ mod test {
fn test_binomial_invalid_lambda_neg() {
Binomial::new(20, -10.0).unwrap();
}

#[test]
fn binomial_distributions_can_be_compared() {
assert_eq!(Binomial::new(1, 1.0), Binomial::new(1, 1.0));
}
}
7 changes: 6 additions & 1 deletion rand_distr/src/cauchy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ use core::fmt;
/// let v = cau.sample(&mut rand::thread_rng());
/// println!("{} is from a Cauchy(2, 5) distribution", v);
/// ```
#[derive(Clone, Copy, Debug)]
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct Cauchy<F>
where F: Float + FloatConst, Standard: Distribution<F>
Expand Down Expand Up @@ -164,4 +164,9 @@ mod test {
assert_almost_eq!(*a, *b, 1e-5);
}
}

#[test]
fn cauchy_distributions_can_be_compared() {
assert_eq!(Cauchy::new(1.0, 2.0), Cauchy::new(1.0, 2.0));
}
}
7 changes: 6 additions & 1 deletion rand_distr/src/dirichlet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ use alloc::{boxed::Box, vec, vec::Vec};
/// println!("{:?} is from a Dirichlet([1.0, 2.0, 3.0]) distribution", samples);
/// ```
#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
#[derive(Clone, Debug)]
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct Dirichlet<F>
where
Expand Down Expand Up @@ -183,4 +183,9 @@ mod test {
fn test_dirichlet_invalid_alpha() {
Dirichlet::new_with_size(0.0f64, 2).unwrap();
}

#[test]
fn dirichlet_distributions_can_be_compared() {
assert_eq!(Dirichlet::new(&[1.0, 2.0]), Dirichlet::new(&[1.0, 2.0]));
}
}
7 changes: 6 additions & 1 deletion rand_distr/src/exponential.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ impl Distribution<f64> for Exp1 {
/// let v = exp.sample(&mut rand::thread_rng());
/// println!("{} is from a Exp(2) distribution", v);
/// ```
#[derive(Clone, Copy, Debug)]
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct Exp<F>
where F: Float, Exp1: Distribution<F>
Expand Down Expand Up @@ -178,4 +178,9 @@ mod test {
fn test_exp_invalid_lambda_nan() {
Exp::new(f64::nan()).unwrap();
}

#[test]
fn exponential_distributions_can_be_compared() {
assert_eq!(Exp::new(1.0), Exp::new(1.0));
}
}
7 changes: 6 additions & 1 deletion rand_distr/src/frechet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use rand::Rng;
/// let val: f64 = thread_rng().sample(Frechet::new(0.0, 1.0, 1.0).unwrap());
/// println!("{}", val);
/// ```
#[derive(Clone, Copy, Debug)]
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct Frechet<F>
where
Expand Down Expand Up @@ -182,4 +182,9 @@ mod tests {
.zip(&probabilities)
.all(|(p_hat, p)| (p_hat - p).abs() < 0.003))
}

#[test]
fn frechet_distributions_can_be_compared() {
assert_eq!(Frechet::new(1.0, 2.0, 3.0), Frechet::new(1.0, 2.0, 3.0));
}
}
49 changes: 37 additions & 12 deletions rand_distr/src/gamma.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ use serde::{Serialize, Deserialize};
/// Generating Gamma Variables" *ACM Trans. Math. Softw.* 26, 3
/// (September 2000), 363-372.
/// DOI:[10.1145/358407.358414](https://doi.acm.org/10.1145/358407.358414)
#[derive(Clone, Copy, Debug)]
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct Gamma<F>
where
Expand Down Expand Up @@ -91,7 +91,7 @@ impl fmt::Display for Error {
#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))]
impl std::error::Error for Error {}

#[derive(Clone, Copy, Debug)]
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
enum GammaRepr<F>
where
Expand Down Expand Up @@ -119,7 +119,7 @@ where
///
/// See `Gamma` for sampling from a Gamma distribution with general
/// shape parameters.
#[derive(Clone, Copy, Debug)]
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
struct GammaSmallShape<F>
where
Expand All @@ -135,7 +135,7 @@ where
///
/// See `Gamma` for sampling from a Gamma distribution with general
/// shape parameters.
#[derive(Clone, Copy, Debug)]
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
struct GammaLargeShape<F>
where
Expand Down Expand Up @@ -280,7 +280,7 @@ where
/// let v = chi.sample(&mut rand::thread_rng());
/// println!("{} is from a χ²(11) distribution", v)
/// ```
#[derive(Clone, Copy, Debug)]
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct ChiSquared<F>
where
Expand Down Expand Up @@ -314,7 +314,7 @@ impl fmt::Display for ChiSquaredError {
#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))]
impl std::error::Error for ChiSquaredError {}

#[derive(Clone, Copy, Debug)]
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
enum ChiSquaredRepr<F>
where
Expand Down Expand Up @@ -385,7 +385,7 @@ where
/// let v = f.sample(&mut rand::thread_rng());
/// println!("{} is from an F(2, 32) distribution", v)
/// ```
#[derive(Clone, Copy, Debug)]
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct FisherF<F>
where
Expand Down Expand Up @@ -472,7 +472,7 @@ where
/// let v = t.sample(&mut rand::thread_rng());
/// println!("{} is from a t(11) distribution", v)
/// ```
#[derive(Clone, Copy, Debug)]
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct StudentT<F>
where
Expand Down Expand Up @@ -522,15 +522,15 @@ where
/// Generating beta variates with nonintegral shape parameters.
/// Communications of the ACM 21, 317-322.
/// https://doi.org/10.1145/359460.359482
#[derive(Clone, Copy, Debug)]
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
enum BetaAlgorithm<N> {
BB(BB<N>),
BC(BC<N>),
}

/// Algorithm BB for `min(alpha, beta) > 1`.
#[derive(Clone, Copy, Debug)]
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
struct BB<N> {
alpha: N,
Expand All @@ -539,7 +539,7 @@ struct BB<N> {
}

/// Algorithm BC for `min(alpha, beta) <= 1`.
#[derive(Clone, Copy, Debug)]
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
struct BC<N> {
alpha: N,
Expand All @@ -560,7 +560,7 @@ struct BC<N> {
/// let v = beta.sample(&mut rand::thread_rng());
/// println!("{} is from a Beta(2, 5) distribution", v);
/// ```
#[derive(Clone, Copy, Debug)]
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct Beta<F>
where
Expand Down Expand Up @@ -811,4 +811,29 @@ mod test {
assert!(!beta.sample(&mut rng).is_nan(), "failed at i={}", i);
}
}

#[test]
fn gamma_distributions_can_be_compared() {
assert_eq!(Gamma::new(1.0, 2.0), Gamma::new(1.0, 2.0));
}

#[test]
fn beta_distributions_can_be_compared() {
assert_eq!(Beta::new(1.0, 2.0), Beta::new(1.0, 2.0));
}

#[test]
fn chi_squared_distributions_can_be_compared() {
assert_eq!(ChiSquared::new(1.0), ChiSquared::new(1.0));
}

#[test]
fn fisher_f_distributions_can_be_compared() {
assert_eq!(FisherF::new(1.0, 2.0), FisherF::new(1.0, 2.0));
}

#[test]
fn student_t_distributions_can_be_compared() {
assert_eq!(StudentT::new(1.0), StudentT::new(1.0));
}
}
7 changes: 6 additions & 1 deletion rand_distr/src/geometric.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use num_traits::Float;
/// let v = geo.sample(&mut rand::thread_rng());
/// println!("{} is from a Geometric(0.25) distribution", v);
/// ```
#[derive(Copy, Clone, Debug)]
#[derive(Copy, Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct Geometric
{
Expand Down Expand Up @@ -235,4 +235,9 @@ mod test {
results.iter().map(|x| (x - mean) * (x - mean)).sum::<f64>() / results.len() as f64;
assert!((variance - expected_variance).abs() < expected_variance / 10.0);
}

#[test]
fn geometric_distributions_can_be_compared() {
assert_eq!(Geometric::new(1.0), Geometric::new(1.0));
}
}
7 changes: 6 additions & 1 deletion rand_distr/src/gumbel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use rand::Rng;
/// let val: f64 = thread_rng().sample(Gumbel::new(0.0, 1.0).unwrap());
/// println!("{}", val);
/// ```
#[derive(Clone, Copy, Debug)]
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct Gumbel<F>
where
Expand Down Expand Up @@ -152,4 +152,9 @@ mod tests {
.zip(&probabilities)
.all(|(p_hat, p)| (p_hat - p).abs() < 0.003))
}

#[test]
fn gumbel_distributions_can_be_compared() {
assert_eq!(Gumbel::new(1.0, 2.0), Gumbel::new(1.0, 2.0));
}
}
9 changes: 7 additions & 2 deletions rand_distr/src/hypergeometric.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use core::fmt;
#[allow(unused_imports)]
use num_traits::Float;

#[derive(Clone, Copy, Debug)]
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
enum SamplingMethod {
InverseTransform{ initial_p: f64, initial_x: i64 },
Expand Down Expand Up @@ -45,7 +45,7 @@ enum SamplingMethod {
/// let v = hypergeo.sample(&mut rand::thread_rng());
/// println!("{} is from a hypergeometric distribution", v);
/// ```
#[derive(Copy, Clone, Debug)]
#[derive(Copy, Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct Hypergeometric {
n1: u64,
Expand Down Expand Up @@ -419,4 +419,9 @@ mod test {
test_hypergeometric_mean_and_variance(10100, 10000, 1000, &mut rng);
test_hypergeometric_mean_and_variance(100100, 100, 10000, &mut rng);
}

#[test]
fn hypergeometric_distributions_can_be_compared() {
assert_eq!(Hypergeometric::new(1, 2, 3), Hypergeometric::new(1, 2, 3));
}
}
7 changes: 6 additions & 1 deletion rand_distr/src/inverse_gaussian.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ impl fmt::Display for Error {
impl std::error::Error for Error {}

/// The [inverse Gaussian distribution](https://en.wikipedia.org/wiki/Inverse_Gaussian_distribution)
#[derive(Debug, Clone, Copy)]
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct InverseGaussian<F>
where
Expand Down Expand Up @@ -109,4 +109,9 @@ mod tests {
assert!(InverseGaussian::new(1.0, -1.0).is_err());
assert!(InverseGaussian::new(1.0, 1.0).is_ok());
}

#[test]
fn inverse_gaussian_distributions_can_be_compared() {
assert_eq!(InverseGaussian::new(1.0, 2.0), InverseGaussian::new(1.0, 2.0));
}
}
14 changes: 12 additions & 2 deletions rand_distr/src/normal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ impl Distribution<f64> for StandardNormal {
/// ```
///
/// [`StandardNormal`]: crate::StandardNormal
#[derive(Clone, Copy, Debug)]
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct Normal<F>
where F: Float, StandardNormal: Distribution<F>
Expand Down Expand Up @@ -227,7 +227,7 @@ where F: Float, StandardNormal: Distribution<F>
/// let v = log_normal.sample(&mut rand::thread_rng());
/// println!("{} is from an ln N(2, 9) distribution", v)
/// ```
#[derive(Clone, Copy, Debug)]
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct LogNormal<F>
where F: Float, StandardNormal: Distribution<F>
Expand Down Expand Up @@ -368,4 +368,14 @@ mod tests {
assert!(LogNormal::from_mean_cv(0.0, 1.0).is_err());
assert!(LogNormal::from_mean_cv(1.0, -1.0).is_err());
}

#[test]
fn normal_distributions_can_be_compared() {
assert_eq!(Normal::new(1.0, 2.0), Normal::new(1.0, 2.0));
}

#[test]
fn log_normal_distributions_can_be_compared() {
assert_eq!(LogNormal::new(1.0, 2.0), LogNormal::new(1.0, 2.0));
}
}
7 changes: 6 additions & 1 deletion rand_distr/src/normal_inverse_gaussian.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ impl fmt::Display for Error {
impl std::error::Error for Error {}

/// The [normal-inverse Gaussian distribution](https://en.wikipedia.org/wiki/Normal-inverse_Gaussian_distribution)
#[derive(Debug, Clone, Copy)]
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct NormalInverseGaussian<F>
where
Expand Down Expand Up @@ -104,4 +104,9 @@ mod tests {
assert!(NormalInverseGaussian::new(1.0, 2.0).is_err());
assert!(NormalInverseGaussian::new(2.0, 1.0).is_ok());
}

#[test]
fn normal_inverse_gaussian_distributions_can_be_compared() {
assert_eq!(NormalInverseGaussian::new(1.0, 2.0), NormalInverseGaussian::new(1.0, 2.0));
}
}
7 changes: 6 additions & 1 deletion rand_distr/src/pareto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use core::fmt;
/// let val: f64 = thread_rng().sample(Pareto::new(1., 2.).unwrap());
/// println!("{}", val);
/// ```
#[derive(Clone, Copy, Debug)]
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct Pareto<F>
where F: Float, OpenClosed01: Distribution<F>
Expand Down Expand Up @@ -131,4 +131,9 @@ mod tests {
105.8826669383772,
]);
}

#[test]
fn pareto_distributions_can_be_compared() {
assert_eq!(Pareto::new(1.0, 2.0), Pareto::new(1.0, 2.0));
}
}
Loading

0 comments on commit 9f20df0

Please sign in to comment.