Skip to content

Commit ddeeb26

Browse files
committed
feat: implement __mul__ for DiffracitonObject wih test funcs
1 parent 630f00e commit ddeeb26

File tree

2 files changed

+31
-13
lines changed

2 files changed

+31
-13
lines changed

src/diffpy/utils/diffraction_objects.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -229,18 +229,13 @@ def __sub__(self, other):
229229
__rsub__ = __sub__
230230

231231
def __mul__(self, other):
232-
multiplied = deepcopy(self)
233-
if isinstance(other, int) or isinstance(other, float) or isinstance(other, np.ndarray):
234-
multiplied.on_tth[1] = other * self.on_tth[1]
235-
multiplied.on_q[1] = other * self.on_q[1]
236-
elif not isinstance(other, DiffractionObject):
237-
raise TypeError("I only know how to multiply two Scattering_object objects")
238-
elif self.on_tth[0].all() != other.on_tth[0].all():
239-
raise RuntimeError(y_grid_length_mismatch_emsg)
240-
else:
241-
multiplied.on_tth[1] = self.on_tth[1] * other.on_tth[1]
242-
multiplied.on_q[1] = self.on_q[1] * other.on_q[1]
243-
return multiplied
232+
self._check_operation_compatibility(other)
233+
multiplied_do = deepcopy(self)
234+
if isinstance(other, (int, float)):
235+
multiplied_do._all_arrays[:, 0] *= other
236+
if isinstance(other, DiffractionObject):
237+
multiplied_do._all_arrays[:, 0] *= other.all_arrays[:, 0]
238+
return multiplied_do
244239

245240
__rmul__ = __mul__
246241

tests/test_diffraction_objects.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -741,18 +741,33 @@ def test_copy_object(do_minimal):
741741
0.5,
742742
np.array([[0.5, 0.51763809, 30.0, 12.13818192], [1.5, 1.0, 60.0, 6.28318531]]),
743743
),
744+
# C2. Test scalar multiplication to yarray values (intensity), expect no change to xarrays (q, tth, d)
745+
( # 1. Multipliy by integer 2
746+
"mul",
747+
np.array([[1.0, 0.51763809, 30.0, 12.13818192], [2.0, 1.0, 60.0, 6.28318531]]),
748+
2,
749+
np.array([[2.0, 0.51763809, 30.0, 12.13818192], [4.0, 1.0, 60.0, 6.28318531]]),
750+
),
751+
( # 2. Multipliy by float 0.5
752+
"mul",
753+
np.array([[1.0, 0.51763809, 30.0, 12.13818192], [2.0, 1.0, 60.0, 6.28318531]]),
754+
2.5,
755+
np.array([[2.5, 0.51763809, 30.0, 12.13818192], [5.0, 1.0, 60.0, 6.28318531]]),
756+
),
744757
],
745758
)
746759
def test_scalar_operations(operation, starting_all_arrays, scalar_value, expected_all_arrays, do_minimal_tth):
747760
do = do_minimal_tth
748761
assert np.allclose(do.all_arrays, starting_all_arrays)
749-
750762
if operation == "add":
751763
result_right = do + scalar_value
752764
result_left = scalar_value + do
753765
elif operation == "sub":
754766
result_right = do - scalar_value
755767
result_left = scalar_value - do
768+
elif operation == "mul":
769+
result_right = do * scalar_value
770+
result_left = scalar_value * do
756771

757772
assert np.allclose(result_right.all_arrays, expected_all_arrays)
758773
assert np.allclose(result_left.all_arrays, expected_all_arrays)
@@ -773,6 +788,11 @@ def test_scalar_operations(operation, starting_all_arrays, scalar_value, expecte
773788
np.array([[0.0, 0.51763809, 30.0, 12.13818192], [0.0, 1.0, 60.0, 6.28318531]]),
774789
np.array([[0.0, 6.28318531, 100.70777771, 1], [0.0, 3.14159265, 45.28748053, 2.0]]),
775790
),
791+
(
792+
"mul",
793+
np.array([[1.0, 0.51763809, 30.0, 12.13818192], [4.0, 1.0, 60.0, 6.28318531]]),
794+
np.array([[1.0, 6.28318531, 100.70777771, 1], [4.0, 3.14159265, 45.28748053, 2.0]]),
795+
),
776796
],
777797
)
778798
def test_binary_operator_on_do(
@@ -797,6 +817,9 @@ def test_binary_operator_on_do(
797817
elif operation == "sub":
798818
do_1_y_modified = do_1 - do_2
799819
do_2_y_modified = do_2 - do_1
820+
elif operation == "mul":
821+
do_1_y_modified = do_1 * do_2
822+
do_2_y_modified = do_2 * do_1
800823

801824
assert np.allclose(do_1_y_modified.all_arrays, expected_do_1_all_arrays_with_y_modified)
802825
assert np.allclose(do_2_y_modified.all_arrays, expected_do_2_all_arrays_with_y_modified)

0 commit comments

Comments
 (0)