Skip to content

Commit 8d22f21

Browse files
authored
Add openvino backend support for numpy.hstack (#21257)
* Add openvino backend support for numpy.hstack * Handle non-tuple * Handle non-tuple 2 * Handle non-tuple 3 * Add support for numpy.hstack
1 parent 5f2793e commit 8d22f21

File tree

2 files changed

+13
-3
lines changed

2 files changed

+13
-3
lines changed

keras/src/backend/openvino/excluded_concrete_tests.txt

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ NumpyDtypeTest::test_exp2
2424
NumpyDtypeTest::test_eye
2525
NumpyDtypeTest::test_flip
2626
NumpyDtypeTest::test_floor
27-
NumpyDtypeTest::test_hstack
2827
NumpyDtypeTest::test_inner
2928
NumpyDtypeTest::test_isfinite
3029
NumpyDtypeTest::test_isinf
@@ -87,7 +86,6 @@ NumpyOneInputOpsCorrectnessTest::test_diagonal
8786
NumpyOneInputOpsCorrectnessTest::test_exp2
8887
NumpyOneInputOpsCorrectnessTest::test_flip
8988
NumpyOneInputOpsCorrectnessTest::test_floor_divide
90-
NumpyOneInputOpsCorrectnessTest::test_hstack
9189
NumpyOneInputOpsCorrectnessTest::test_imag
9290
NumpyOneInputOpsCorrectnessTest::test_isfinite
9391
NumpyOneInputOpsCorrectnessTest::test_isinf

keras/src/backend/openvino/numpy.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from keras.src.backend.openvino.core import (
1313
align_operand_types as _align_operand_types,
1414
)
15+
from keras.src.backend.openvino.core import convert_to_tensor
1516
from keras.src.backend.openvino.core import get_ov_output
1617
from keras.src.backend.openvino.core import ov_to_keras_type
1718

@@ -846,7 +847,18 @@ def greater_equal(x1, x2):
846847

847848

848849
def hstack(xs):
849-
raise NotImplementedError("`hstack` is not supported with openvino backend")
850+
if not isinstance(xs, (list, tuple)):
851+
xs = (xs,)
852+
elems = [convert_to_tensor(elem) for elem in xs]
853+
element_type = elems[0].output.get_element_type()
854+
elems = [get_ov_output(elem, element_type) for elem in elems]
855+
is_1d = elems and len(elems[0].get_partial_shape().to_shape()) == 1
856+
axis = 0 if is_1d else 1
857+
for i in range(1, len(elems)):
858+
elems[0], elems[i] = _align_operand_types(
859+
elems[0], elems[i], "hstack()"
860+
)
861+
return OpenVINOKerasTensor(ov_opset.concat(elems, axis).output(0))
850862

851863

852864
def identity(n, dtype=None):

0 commit comments

Comments
 (0)