Skip to content

Commit

Permalink
Also patched decimal multiplication
Browse files Browse the repository at this point in the history
  • Loading branch information
lorinkoz committed Jun 27, 2024
1 parent 2110f9a commit 844eb37
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 23 deletions.
30 changes: 24 additions & 6 deletions pytests/test_money.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,12 @@ def test_advanced_negation(decimal):
assert str(Money(decimal).amount) == "-0"


@pytest.mark.parametrize("left", ["0", "-0"])
@pytest.mark.parametrize("right", ["0", "-0"])
def test_add_with_negative_zeros(left, right):
OPERANDS = ["0", "-0", "1", "-1"]


@pytest.mark.parametrize("left", OPERANDS)
@pytest.mark.parametrize("right", OPERANDS)
def test_add_with_operands(left, right):
expected = Decimal(left) + Decimal(right)

money_left = Money(left)
Expand All @@ -152,9 +155,9 @@ def test_add_with_negative_zeros(left, right):
assert str((money_left + money_right).amount) == str(expected)


@pytest.mark.parametrize("left", ["0", "-0"])
@pytest.mark.parametrize("right", ["0", "-0"])
def test_sub_with_negative_zeros(left, right):
@pytest.mark.parametrize("left", OPERANDS)
@pytest.mark.parametrize("right", OPERANDS)
def test_sub_with_operands(left, right):
expected = Decimal(left) - Decimal(right)

money_left = Money(left)
Expand All @@ -168,6 +171,21 @@ def test_sub_with_negative_zeros(left, right):
assert str((money_left - money_right).amount) == str(expected)


@pytest.mark.parametrize("left", OPERANDS)
@pytest.mark.parametrize("right", OPERANDS)
def test_mult_with_operands(left, right):
expected = Decimal(left) * Decimal(right)

money_left = Money(left)
money_right = Money(right)

assert str(money_left.amount) == str(left)
assert str(money_right.amount) == str(right)

assert str((money_left * Decimal(right)).amount) == str(expected)
assert str((Decimal(left) * money_right).amount) == str(expected)


def test_equality_to_other_types():
x = Money(0)
assert x != None # noqa: E711
Expand Down
15 changes: 15 additions & 0 deletions src/decimals.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,21 @@ pub fn decimal_add(left: Decimal, right: Decimal) -> Decimal {
}
}

// Multiplies decimals the way of Python
pub fn decimal_mult(left: Decimal, right: Decimal) -> Decimal {
let zero = Decimal::new(0, 0);

if left.abs() == zero || right.abs() == zero {
if left.is_sign_negative() == right.is_sign_negative() {
zero
} else {
-zero
}
} else {
left * right
}
}

// Rounds decimals the way of Python
pub fn round(value: Decimal, scale: i32, round_up: bool) -> Decimal {
let strategy = if round_up {
Expand Down
11 changes: 3 additions & 8 deletions src/money.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,14 +133,9 @@ impl Money {

fn __mul__(&self, other: Bound<PyAny>) -> PyResult<Self> {
if let Ok(other_decimal) = get_decimal(other) {
if other_decimal == Decimal::new(-1, 0) {
// Hack for minus zero
Ok(self.__neg__())
} else {
Ok(Self {
amount: self.amount * other_decimal,
})
}
Ok(Self {
amount: decimal_mult(self.amount, other_decimal),
})
} else {
Err(pyo3::exceptions::PyTypeError::new_err(
"Unsupported operand",
Expand Down
12 changes: 6 additions & 6 deletions src/money_vat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::str::FromStr;

use crate::decimals::{decimal_add, get_decimal};
use crate::decimals::{decimal_add, decimal_mult, get_decimal};
use crate::money::{Money, MONEY_PRECISION};
use crate::money_vat_ratio::MoneyWithVATRatio;

Expand Down Expand Up @@ -75,7 +75,7 @@ impl MoneyWithVAT {
}

for rate in Self::known_vat_rates() {
let vat = rate * self.net.amount;
let vat = decimal_mult(rate, self.net.amount);
let vat_diff = (decimal_add(vat, -self.tax.amount)).abs();
if vat_diff < boundary {
return rate;
Expand Down Expand Up @@ -218,23 +218,23 @@ impl MoneyWithVAT {

fn __mul__(&self, other: Bound<PyAny>) -> PyResult<Self> {
if let Ok(other_ratio) = other.extract::<MoneyWithVATRatio>() {
let net_value = other_ratio.net_ratio * self.net.amount;
let net_value = decimal_mult(other_ratio.net_ratio, self.net.amount);
Ok(Self {
net: Money { amount: net_value },
tax: Money {
amount: decimal_add(
other_ratio.gross_ratio * self.get_gross().amount,
decimal_mult(other_ratio.gross_ratio, self.get_gross().amount),
-net_value,
),
},
})
} else if let Ok(other_decimal) = get_decimal(other) {
Ok(Self {
net: Money {
amount: self.net.amount * other_decimal,
amount: decimal_mult(self.net.amount, other_decimal),
},
tax: Money {
amount: self.tax.amount * other_decimal,
amount: decimal_mult(self.tax.amount, other_decimal),
},
})
} else {
Expand Down
6 changes: 3 additions & 3 deletions src/money_vat_ratio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use pyo3::types::{PyCFunction, PyDict, PyTuple};
use rust_decimal::prelude::FromPrimitive;
use rust_decimal::Decimal;

use crate::decimals::{decimal_add, get_decimal};
use crate::decimals::{decimal_add, decimal_mult, get_decimal};

#[pyclass]
#[derive(Debug, Clone)]
Expand Down Expand Up @@ -81,8 +81,8 @@ impl MoneyWithVATRatio {
let other_decimal = Decimal::from_f64(other).unwrap();

Self {
net_ratio: self.net_ratio * other_decimal,
gross_ratio: self.gross_ratio * other_decimal,
net_ratio: decimal_mult(self.net_ratio, other_decimal),
gross_ratio: decimal_mult(self.gross_ratio, other_decimal),
}
}

Expand Down

0 comments on commit 844eb37

Please sign in to comment.