@@ -657,9 +657,83 @@ def slice(inputs, start_indices, shape):
657
657
658
658
659
659
def slice_update (inputs , start_indices , updates ):
660
- raise NotImplementedError (
661
- "`slice_update` is not supported with openvino backend"
660
+ inputs = get_ov_output (inputs )
661
+ if isinstance (start_indices , (list , np .ndarray )):
662
+ start_indices = tuple (start_indices )
663
+ assert isinstance (start_indices , tuple ), (
664
+ "`slice_update` is not supported by openvino backend"
665
+ " for `start_indices` of type {}" .format (type (start_indices ))
662
666
)
667
+ processed_start_indices = []
668
+ for idx in start_indices :
669
+ val = get_ov_output (idx )
670
+ val_type = val .get_element_type ()
671
+ if not val_type .is_integral ():
672
+ raise ValueError (
673
+ "`slice` is not supported by OpenVINO backend "
674
+ "for `start_indices` or `shape` with non-integer types"
675
+ )
676
+ if val_type != Type .i32 :
677
+ val = ov_opset .convert (val , Type .i32 ).output (0 )
678
+ if len (val .get_partial_shape ()) == 0 :
679
+ val = ov_opset .unsqueeze (
680
+ val , ov_opset .constant (0 , Type .i32 )
681
+ ).output (0 )
682
+ processed_start_indices .append (val )
683
+ start_indices_tensor = ov_opset .concat (processed_start_indices , axis = 0 )
684
+
685
+ rank = len (updates .shape )
686
+ ranges = []
687
+ for dim in updates .shape :
688
+ r = ov_opset .range (
689
+ ov_opset .constant (0 , Type .i32 ),
690
+ ov_opset .constant (dim , Type .i32 ),
691
+ ov_opset .constant (1 , Type .i32 ),
692
+ output_type = Type .i32 ,
693
+ )
694
+ ranges .append (r )
695
+
696
+ broadcasted_ranges = []
697
+ for i , r in enumerate (ranges ):
698
+ shape = [1 ] * rank
699
+ shape [i ] = updates .shape [i ]
700
+ r_reshaped = ov_opset .reshape (
701
+ r , ov_opset .constant (shape , Type .i32 ), special_zero = False
702
+ ).output (0 )
703
+ target_shape = ov_opset .constant (list (updates .shape ), Type .i32 )
704
+ r_broadcasted = ov_opset .broadcast (r_reshaped , target_shape ).output (0 )
705
+ broadcasted_ranges .append (r_broadcasted )
706
+
707
+ indices_stack = ov_opset .concat (broadcasted_ranges , axis = 0 ).output (0 )
708
+
709
+ num_updates = 1
710
+ for dim in updates .shape :
711
+ num_updates *= dim
712
+ new_shape = ov_opset .constant ([rank , num_updates ], Type .i32 )
713
+ indices_reshaped = ov_opset .reshape (
714
+ indices_stack , new_shape , special_zero = False
715
+ ).output (0 )
716
+ absolute_indices = ov_opset .transpose (
717
+ indices_reshaped , ov_opset .constant ([1 , 0 ], Type .i32 )
718
+ ).output (0 )
719
+
720
+ start_indices_expanded = ov_opset .broadcast (
721
+ start_indices_tensor , ov_opset .constant ([num_updates , rank ], Type .i32 )
722
+ ).output (0 )
723
+ absolute_indices = ov_opset .add (
724
+ absolute_indices , start_indices_expanded
725
+ ).output (0 )
726
+
727
+ updates_tensor = get_ov_output (updates )
728
+ updates_flat = ov_opset .reshape (
729
+ updates_tensor ,
730
+ ov_opset .constant ([num_updates ], Type .i32 ),
731
+ special_zero = False ,
732
+ ).output (0 )
733
+ updated = ov_opset .scatter_nd_update (
734
+ inputs , absolute_indices , updates_flat
735
+ ).output (0 )
736
+ return OpenVINOKerasTensor (updated )
663
737
664
738
665
739
def while_loop (
0 commit comments