Skip to content

Commit 5793f4a

Browse files
committed
interpolate between points
1 parent f905e76 commit 5793f4a

File tree

1 file changed

+14
-5
lines changed

1 file changed

+14
-5
lines changed

wordllama/algorithms/find_local_minima.pyx

+14-5
Original file line numberDiff line numberDiff line change
@@ -122,15 +122,24 @@ cdef tuple _find_local_minima_impl(DTYPE_t[:] y, int window_size, int poly_order
122122
cdef list minima_values = []
123123
cdef DTYPE_t[:] dy_view = dy
124124
cdef DTYPE_t[:] ddy_view = ddy
125-
125+
126126
cdef int i
127+
cdef DTYPE_t interp_weight
127128

128129
# Identify minima by checking first derivative change and positive second derivative
129130
for i in range(1, n - 1):
130-
if dy_view[i-1] < 0 < dy_view[i] and ddy_view[i] > 0:
131-
minima_indices.append(i)
132-
minima_values.append(y[i])
133-
131+
if dy_view[i] < 0 < dy_view[i + 1] and ddy_view[i] > 0:
132+
# Calculate the weight of the zero crossing between i and i+1
133+
interp_weight = -dy_view[i] / (dy_view[i + 1] - dy_view[i])
134+
135+
# Determine if the zero crossing is closer to i or i+1
136+
if interp_weight < 0.5:
137+
minima_indices.append(i)
138+
minima_values.append(y[i])
139+
else:
140+
minima_indices.append(i + 1)
141+
minima_values.append(y[i + 1])
142+
134143
# Convert the lists to NumPy arrays
135144
cdef np.ndarray[np.int32_t, ndim=1] minima_idx_np = np.array(minima_indices, dtype=np.int32)
136145
cdef np.ndarray[DTYPE_t, ndim=1] minima_values_np = np.array(minima_values, dtype=np.float32)

0 commit comments

Comments
 (0)