Skip to content

Commit f0131f0

Browse files
[OpenVINO backend] support top_k
1 parent e233825 commit f0131f0

File tree

3 files changed

+57
-2
lines changed

3 files changed

+57
-2
lines changed

keras/src/backend/openvino/excluded_concrete_tests.txt

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,3 +186,52 @@ CoreOpsCorrectnessTest::test_slice_update
186186
CoreOpsCorrectnessTest::test_switch
187187
CoreOpsCorrectnessTest::test_unstack
188188
CoreOpsCorrectnessTest::test_vectorized_map
189+
ExtractSequencesOpTest::test_extract_sequences_call
190+
InTopKTest::test_in_top_k_call
191+
MathOpsCorrectnessTest::test_erfinv_operation_basic
192+
MathOpsCorrectnessTest::test_erfinv_operation_dtype
193+
MathOpsCorrectnessTest::test_erfinv_operation_edge_cases
194+
MathOpsCorrectnessTest::test_extract_sequences
195+
MathOpsCorrectnessTest::test_fft
196+
MathOpsCorrectnessTest::test_fft2
197+
MathOpsCorrectnessTest::test_ifft2
198+
MathOpsCorrectnessTest::test_in_top_k
199+
MathOpsCorrectnessTest::test_irfft0
200+
MathOpsCorrectnessTest::test_irfft1
201+
MathOpsCorrectnessTest::test_irfft2
202+
MathOpsCorrectnessTest::test_istft0
203+
MathOpsCorrectnessTest::test_istft1
204+
MathOpsCorrectnessTest::test_istft2
205+
MathOpsCorrectnessTest::test_istft3
206+
MathOpsCorrectnessTest::test_istft4
207+
MathOpsCorrectnessTest::test_istft5
208+
MathOpsCorrectnessTest::test_istft6
209+
MathOpsCorrectnessTest::test_logdet
210+
MathOpsCorrectnessTest::test_logsumexp
211+
MathOpsCorrectnessTest::test_rfft0
212+
MathOpsCorrectnessTest::test_rfft1
213+
MathOpsCorrectnessTest::test_rfft2
214+
MathOpsCorrectnessTest::test_segment_reduce0
215+
MathOpsCorrectnessTest::test_segment_reduce1
216+
MathOpsCorrectnessTest::test_segment_reduce2
217+
MathOpsCorrectnessTest::test_segment_reduce3
218+
MathOpsCorrectnessTest::test_segment_reduce_explicit_num_segments0
219+
MathOpsCorrectnessTest::test_segment_reduce_explicit_num_segments1
220+
MathOpsCorrectnessTest::test_segment_reduce_explicit_num_segments2
221+
MathOpsCorrectnessTest::test_segment_reduce_explicit_num_segments3
222+
MathOpsCorrectnessTest::test_segment_reduce_explicit_num_segments4
223+
MathOpsCorrectnessTest::test_segment_reduce_explicit_num_segments5
224+
MathOpsCorrectnessTest::test_stft0
225+
MathOpsCorrectnessTest::test_stft1
226+
MathOpsCorrectnessTest::test_stft2
227+
MathOpsCorrectnessTest::test_stft3
228+
MathOpsCorrectnessTest::test_stft4
229+
MathOpsCorrectnessTest::test_stft5
230+
MathOpsCorrectnessTest::test_stft6
231+
SegmentSumTest::test_segment_sum_call
232+
SegmentMaxTest::test_segment_max_call
233+
TestMathErrors::test_invalid_fft_length
234+
TestMathErrors::test_istft_invalid_window_shape_2D_inputs
235+
TestMathErrors::test_stft_invalid_input_type
236+
TestMathErrors::test_stft_invalid_window
237+
TestMathErrors::test_stft_invalid_window_shape

keras/src/backend/openvino/excluded_tests.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ keras/src/metrics
2929
keras/src/models
3030
keras/src/ops/image_test.py
3131
keras/src/ops/linalg_test.py
32-
keras/src/ops/math_test.py
3332
keras/src/ops/nn_test.py
3433
keras/src/optimizers
3534
keras/src/quantizers

keras/src/backend/openvino/math.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,14 @@ def segment_max(data, segment_ids, num_segments=None, sorted=False):
1818

1919

2020
def top_k(x, k, sorted=True):
21-
raise NotImplementedError("`top_k` is not supported with openvino backend")
21+
x = get_ov_output(x)
22+
k_tensor = ov_opset.constant(k, dtype=Type.i32)
23+
axis = -1
24+
sort_type = "value" if sorted else "none"
25+
topk_node = ov_opset.topk(x, k_tensor, axis, "max", sort_type)
26+
values = topk_node.output(0)
27+
indices = topk_node.output(1)
28+
return OpenVINOKerasTensor(values), OpenVINOKerasTensor(indices)
2229

2330

2431
def in_top_k(targets, predictions, k):

0 commit comments

Comments
 (0)