Skip to content

Commit d561583

Browse files
committed
Add radd and rmul back and add tests
1 parent 3f577d7 commit d561583

File tree

2 files changed

+25
-7
lines changed

2 files changed

+25
-7
lines changed

src/diffpy/utils/diffraction_objects.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,8 @@ def __add__(self, other):
223223
raise TypeError(invalid_add_type_emsg)
224224
return summed_do
225225

226+
__radd__ = __add__
227+
226228
def __sub__(self, other):
227229
subtracted = deepcopy(self)
228230
if isinstance(other, int) or isinstance(other, float) or isinstance(other, np.ndarray):
@@ -265,6 +267,18 @@ def __mul__(self, other):
265267
multiplied.on_q[1] = self.on_q[1] * other.on_q[1]
266268
return multiplied
267269

270+
def __rmul__(self, other):
271+
multiplied = deepcopy(self)
272+
if isinstance(other, int) or isinstance(other, float) or isinstance(other, np.ndarray):
273+
multiplied.on_tth[1] = other * self.on_tth[1]
274+
multiplied.on_q[1] = other * self.on_q[1]
275+
elif self.on_tth[0].all() != other.on_tth[0].all():
276+
raise RuntimeError(x_grid_length_mismatch_emsg)
277+
else:
278+
multiplied.on_tth[1] = self.on_tth[1] * other.on_tth[1]
279+
multiplied.on_q[1] = self.on_q[1] * other.on_q[1]
280+
return multiplied
281+
268282
def __truediv__(self, other):
269283
divided = deepcopy(self)
270284
if isinstance(other, int) or isinstance(other, float) or isinstance(other, np.ndarray):

tests/test_diffraction_objects.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -705,7 +705,7 @@ def test_copy_object(do_minimal):
705705

706706

707707
@pytest.mark.parametrize(
708-
"starting_all_arrays, scalar_value, expected_all_arrays",
708+
"starting_all_arrays, scalar_to_add, expected_all_arrays",
709709
[
710710
# Test scalar addition to xarray values (q, tth, d) and expect no change to yarray values
711711
( # C1: Add integer of 5, expect xarray to increase by by 5
@@ -720,11 +720,13 @@ def test_copy_object(do_minimal):
720720
),
721721
],
722722
)
723-
def test_addition_operator_by_scalar(starting_all_arrays, scalar_value, expected_all_arrays, do_minimal_tth):
723+
def test_addition_operator_by_scalar(starting_all_arrays, scalar_to_add, expected_all_arrays, do_minimal_tth):
724724
do = do_minimal_tth
725725
assert np.allclose(do.all_arrays, starting_all_arrays)
726-
do_sum = do + scalar_value
727-
assert np.allclose(do_sum.all_arrays, expected_all_arrays)
726+
do_sum_RHS = do + scalar_to_add
727+
do_sum_LHS = scalar_to_add + do
728+
assert np.allclose(do_sum_RHS.all_arrays, expected_all_arrays)
729+
assert np.allclose(do_sum_LHS.all_arrays, expected_all_arrays)
728730

729731

730732
@pytest.mark.parametrize(
@@ -750,10 +752,12 @@ def test_addition_operator_by_another_do(LHS_all_arrays, RHS_all_arrays, expecte
750752

751753
def test_addition_operator_invalid_type(do_minimal_tth, invalid_add_type_error_msg):
752754
# Add a string to a DO object, expect TypeError, only scalar (int, float) allowed for addition
753-
do_LHS = do_minimal_tth
755+
do = do_minimal_tth
754756
with pytest.raises(TypeError, match=re.escape(invalid_add_type_error_msg)):
755-
do_LHS + "string_value"
756-
757+
do + "string_value"
758+
with pytest.raises(TypeError, match=re.escape(invalid_add_type_error_msg)):
759+
"string_value" + do
760+
757761

758762
def test_addition_operator_invalid_xarray_length(do_minimal, do_minimal_tth, x_grid_size_mismatch_error_msg):
759763
# Combine two DO objects, one with empty xarrays (do_minimal) and the other with non-empty xarrays

0 commit comments

Comments
 (0)