Skip to content

Commit aef0743

Browse files
[OpenVINO Backend] support slice_update
1 parent 771b001 commit aef0743

File tree

1 file changed

+76
-2
lines changed

1 file changed

+76
-2
lines changed

keras/src/backend/openvino/core.py

Lines changed: 76 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -657,9 +657,83 @@ def slice(inputs, start_indices, shape):
657657

658658

659659
def slice_update(inputs, start_indices, updates):
660-
raise NotImplementedError(
661-
"`slice_update` is not supported with openvino backend"
660+
inputs = get_ov_output(inputs)
661+
if isinstance(start_indices, (list, np.ndarray)):
662+
start_indices = tuple(start_indices)
663+
assert isinstance(start_indices, tuple), (
664+
"`slice_update` is not supported by openvino backend"
665+
" for `start_indices` of type {}".format(type(start_indices))
662666
)
667+
processed_start_indices = []
668+
for idx in start_indices:
669+
val = get_ov_output(idx)
670+
val_type = val.get_element_type()
671+
if not val_type.is_integral():
672+
raise ValueError(
673+
"`slice` is not supported by OpenVINO backend "
674+
"for `start_indices` or `shape` with non-integer types"
675+
)
676+
if val_type != Type.i32:
677+
val = ov_opset.convert(val, Type.i32).output(0)
678+
if len(val.get_partial_shape()) == 0:
679+
val = ov_opset.unsqueeze(
680+
val, ov_opset.constant(0, Type.i32)
681+
).output(0)
682+
processed_start_indices.append(val)
683+
start_indices_tensor = ov_opset.concat(processed_start_indices, axis=0)
684+
685+
rank = len(updates.shape)
686+
ranges = []
687+
for dim in updates.shape:
688+
r = ov_opset.range(
689+
ov_opset.constant(0, Type.i32),
690+
ov_opset.constant(dim, Type.i32),
691+
ov_opset.constant(1, Type.i32),
692+
output_type=Type.i32,
693+
)
694+
ranges.append(r)
695+
696+
broadcasted_ranges = []
697+
for i, r in enumerate(ranges):
698+
shape = [1] * rank
699+
shape[i] = updates.shape[i]
700+
r_reshaped = ov_opset.reshape(
701+
r, ov_opset.constant(shape, Type.i32), special_zero=False
702+
).output(0)
703+
target_shape = ov_opset.constant(list(updates.shape), Type.i32)
704+
r_broadcasted = ov_opset.broadcast(r_reshaped, target_shape).output(0)
705+
broadcasted_ranges.append(r_broadcasted)
706+
707+
indices_stack = ov_opset.concat(broadcasted_ranges, axis=0).output(0)
708+
709+
num_updates = 1
710+
for dim in updates.shape:
711+
num_updates *= dim
712+
new_shape = ov_opset.constant([rank, num_updates], Type.i32)
713+
indices_reshaped = ov_opset.reshape(
714+
indices_stack, new_shape, special_zero=False
715+
).output(0)
716+
absolute_indices = ov_opset.transpose(
717+
indices_reshaped, ov_opset.constant([1, 0], Type.i32)
718+
).output(0)
719+
720+
start_indices_expanded = ov_opset.broadcast(
721+
start_indices_tensor, ov_opset.constant([num_updates, rank], Type.i32)
722+
).output(0)
723+
absolute_indices = ov_opset.add(
724+
absolute_indices, start_indices_expanded
725+
).output(0)
726+
727+
updates_tensor = get_ov_output(updates)
728+
updates_flat = ov_opset.reshape(
729+
updates_tensor,
730+
ov_opset.constant([num_updates], Type.i32),
731+
special_zero=False,
732+
).output(0)
733+
updated = ov_opset.scatter_nd_update(
734+
inputs, absolute_indices, updates_flat
735+
).output(0)
736+
return OpenVINOKerasTensor(updated)
663737

664738

665739
def while_loop(

0 commit comments

Comments
 (0)