Skip to content

Commit

Permalink
Update contrast_mask to be able to handle timesteps
Browse files Browse the repository at this point in the history
  • Loading branch information
WilliamJudge94 committed Aug 12, 2019
1 parent 59e4d9a commit 3d21a63
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 8 deletions.
9 changes: 5 additions & 4 deletions tests/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,11 +480,12 @@ def test_extract_signals_pca(self):
def test_contrast_mask(self):

# The first time idx - User selected in the xp.XanesFrameset.frames() method
frames = self.coins()[0]
frames = self.coins()

print(np.shape(frames))
# Set up the initial numpy frames
mean_frames_image = np.mean(frames, axis=0)
single_frame_image = frames[10]
mean_frames_image = np.mean(frames, axis=(0, 1))
single_frame_image = frames[0][10]

# Check difference input values
sensitivity_vals = [1, 0.5, 1.8]
Expand Down Expand Up @@ -520,7 +521,7 @@ def test_contrast_mask(self):
single_frames_check.append(contrast_mask(frames=frames,
sensitivity=sensitivity,
min_size=min_size,
frame_idx=10))
frame_idx=(0, 10)))

# Check all the values
np.testing.assert_equal(all_masks[0], mean_frames_check[0])
Expand Down
9 changes: 5 additions & 4 deletions xanespy/xanes_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -1114,7 +1114,7 @@ def contrast_mask(frames: np.ndarray, sensitivity: float = 1, min_size=0, frame_
min_size : float, optional
Objects below this size (in pixels) will be
removed. Passing zero (default) will result in no effect.
frame_idx : str, int, optional
frame_idx : str, tuple(time_step_index, energy_index), optional
Allows the user to select which image to
Returns:
Expand All @@ -1129,10 +1129,11 @@ def contrast_mask(frames: np.ndarray, sensitivity: float = 1, min_size=0, frame_

# Obtain the correct image set
if frame_idx == 'mean':
image = np.mean(frames, axis=0)
image = np.mean(frames, axis=(0, 1))

elif isinstance(frame_idx, int):
image = frames[frame_idx]
elif isinstance(frame_idx, tuple):
time_idx, energy_idx = frame_idx
image = frames[time_idx][energy_idx]

# Determining threshold
img_bottom = image.min()
Expand Down

0 comments on commit 3d21a63

Please sign in to comment.