File tree Expand file tree Collapse file tree 2 files changed +13
-3
lines changed
keras/src/backend/openvino Expand file tree Collapse file tree 2 files changed +13
-3
lines changed Original file line number Diff line number Diff line change @@ -24,7 +24,6 @@ NumpyDtypeTest::test_exp2
24
24
NumpyDtypeTest::test_eye
25
25
NumpyDtypeTest::test_flip
26
26
NumpyDtypeTest::test_floor
27
- NumpyDtypeTest::test_hstack
28
27
NumpyDtypeTest::test_inner
29
28
NumpyDtypeTest::test_isfinite
30
29
NumpyDtypeTest::test_isinf
@@ -87,7 +86,6 @@ NumpyOneInputOpsCorrectnessTest::test_diagonal
87
86
NumpyOneInputOpsCorrectnessTest::test_exp2
88
87
NumpyOneInputOpsCorrectnessTest::test_flip
89
88
NumpyOneInputOpsCorrectnessTest::test_floor_divide
90
- NumpyOneInputOpsCorrectnessTest::test_hstack
91
89
NumpyOneInputOpsCorrectnessTest::test_imag
92
90
NumpyOneInputOpsCorrectnessTest::test_isfinite
93
91
NumpyOneInputOpsCorrectnessTest::test_isinf
Original file line number Diff line number Diff line change 12
12
from keras .src .backend .openvino .core import (
13
13
align_operand_types as _align_operand_types ,
14
14
)
15
+ from keras .src .backend .openvino .core import convert_to_tensor
15
16
from keras .src .backend .openvino .core import get_ov_output
16
17
from keras .src .backend .openvino .core import ov_to_keras_type
17
18
@@ -846,7 +847,18 @@ def greater_equal(x1, x2):
846
847
847
848
848
849
def hstack (xs ):
849
- raise NotImplementedError ("`hstack` is not supported with openvino backend" )
850
+ if not isinstance (xs , (list , tuple )):
851
+ xs = (xs ,)
852
+ elems = [convert_to_tensor (elem ) for elem in xs ]
853
+ element_type = elems [0 ].output .get_element_type ()
854
+ elems = [get_ov_output (elem , element_type ) for elem in elems ]
855
+ is_1d = elems and len (elems [0 ].get_partial_shape ().to_shape ()) == 1
856
+ axis = 0 if is_1d else 1
857
+ for i in range (1 , len (elems )):
858
+ elems [0 ], elems [i ] = _align_operand_types (
859
+ elems [0 ], elems [i ], "hstack()"
860
+ )
861
+ return OpenVINOKerasTensor (ov_opset .concat (elems , axis ).output (0 ))
850
862
851
863
852
864
def identity (n , dtype = None ):
You can’t perform that action at this time.
0 commit comments