Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from keras.src.ops.core import slice
from keras.src.ops.core import slice_update
from keras.src.ops.core import stop_gradient
from keras.src.ops.core import switch
from keras.src.ops.core import unstack
from keras.src.ops.core import vectorized_map
from keras.src.ops.core import while_loop
Expand Down
1 change: 1 addition & 0 deletions keras/api/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from keras.src.ops.core import slice
from keras.src.ops.core import slice_update
from keras.src.ops.core import stop_gradient
from keras.src.ops.core import switch
from keras.src.ops.core import unstack
from keras.src.ops.core import vectorized_map
from keras.src.ops.core import while_loop
Expand Down
4 changes: 4 additions & 0 deletions keras/src/backend/jax/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,10 @@ def slice_update(inputs, start_indices, updates):
return jax.lax.dynamic_update_slice(inputs, updates, start_indices)


def switch(index, branches, *operands):
return jax.lax.switch(index, branches, *operands)


def while_loop(
cond,
body,
Expand Down
6 changes: 6 additions & 0 deletions keras/src/backend/numpy/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,12 @@ def slice_update(inputs, start_indices, updates):
return inputs


def switch(index, branches, *operands):
index = convert_to_tensor(index, "int32")
index = np.clip(index, 0, len(branches) - 1)
return branches[index](*operands)


def while_loop(
cond,
body,
Expand Down
13 changes: 13 additions & 0 deletions keras/src/backend/tensorflow/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,19 @@ def slice_update(inputs, start_indices, updates):
return dynamic_update_slice(inputs, updates, start_indices)


def switch(index, branches, *operands):
index = convert_to_tensor(index, "int32")
index = tf.clip_by_value(index, 0, len(branches) - 1)

# Workaround to deal with python closures. More details:
# https://github.com/tensorflow/tensorflow/issues/8776#issuecomment-311383887
def gen_fn(i):
return lambda: branches[i](*operands)

branch_fns = [gen_fn(i) for i in range(len(branches))]
return tf.switch_case(index, branch_fns)


def while_loop(
cond,
body,
Expand Down
6 changes: 6 additions & 0 deletions keras/src/backend/torch/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,12 @@ def slice_update(inputs, start_indices, updates):
return outputs


def switch(index, branches, *operands):
index = convert_to_tensor(index, "int32")
index = torch.clamp(index, 0, len(branches) - 1)
return branches[index](*operands)


def while_loop(
cond,
body,
Expand Down
53 changes: 52 additions & 1 deletion keras/src/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def scan(f, init, xs, length=None):
>>> sum_fn = lambda c, x: (c + x, c + x)
>>> init = keras.ops.array(0)
>>> xs = keras.ops.array([1, 2, 3, 4, 5])
>>> carry, result = ops.scan(sum_fn, init, xs)
>>> carry, result = keras.ops.scan(sum_fn, init, xs)
>>> carry
15
>>> result
Expand Down Expand Up @@ -315,6 +315,57 @@ def slice_update(inputs, start_indices, updates):
return backend.core.slice_update(inputs, start_indices, updates)


class Switch(Operation):
def call(self, index, branches, *operands):
return backend.core.switch(index, branches, *operands)

def compute_output_spec(self, index, branches, *operands):
# We use first branch for output_spec
spec = backend.compute_output_spec(branches[0], *operands)
return spec


@keras_export("keras.ops.switch")
def switch(index, branches, *operands):
"""Apply exactly one of the `branches` given by `index`.

If `index` is out of bounds, it is clamped to within bounds.

The semantics of `switch` are given roughly by this Python implementation:

```python
def switch(index, branches, *operands):
index = clamp(0, index, len(branches) - 1)
return branches[index](*operands)
```

Args:
index: An integer scalar indicating which branch function to apply.
branches: A sequence of functions to be applied based on `index`.
operands: Inputs to whichever branch is applied.

Returns:
The outputs of `branch(*operands)` for the branch that was selected
based on `index`.

Examples:

>>> add_fn = lambda x, y: x + y
>>> substract_fn = lambda x, y: x - y
>>> x = keras.ops.array(2.0)
>>> y = keras.ops.array(0.5)
>>> branches = [add_fn, substract_fn]
>>> keras.ops.switch(0, branches, x, y)
2.5

>>> keras.ops.switch(1, branches, x, y)
1.5
"""
if any_symbolic_tensors(operands):
return Switch().symbolic_call(index, branches, *operands)
return backend.core.switch(index, branches, *operands)


class WhileLoop(Operation):
def __init__(self, cond, body, maximum_iterations):
super().__init__()
Expand Down
46 changes: 46 additions & 0 deletions keras/src/ops/core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,16 @@ def test_slice_update(self):
core.slice_update(inputs, start_indices, updates).shape, (4, 4, 4)
)

def test_switch(self):
def fn(x, y):
return x[:, 0], y[0, :]

index = KerasTensor(())
x = KerasTensor((5, 2))
y = KerasTensor((5, 2))
self.assertEqual(core.switch(index, [fn], x, y)[0].shape, (5,))
self.assertEqual(core.switch(index, [fn], x, y)[1].shape, (2,))

def test_fori_loop(self):
def body_fun(i, x):
return x + i
Expand Down Expand Up @@ -303,6 +313,23 @@ def test_slice_update(self):
outputs = core.slice_update(inputs, start_indices, updates)
self.assertAllClose(outputs[1:3, 1:3, 2:4, 2:4], np.zeros([2, 2, 2, 2]))

def test_switch(self):
def fn1(x, y):
return x + y

def fn2(x, y):
return x - y

x = np.random.rand(2, 3, 4).astype("float32")
y = np.random.rand(2, 3, 4).astype("float32")
branches = [fn1, fn2]
self.assertAllClose(core.switch(0, branches, x, y), x + y)
self.assertAllClose(core.switch(1, branches, x, y), x - y)

# Test out-of-bound index
self.assertAllClose(core.switch(-100, branches, x, y), x + y)
self.assertAllClose(core.switch(100, branches, x, y), x - y)

@parameterized.named_parameters(
[
{
Expand Down Expand Up @@ -801,6 +828,25 @@ def test_slice_update_basic_call(self):
expected_output = np.array([[1, 2, 3], [4, 10, 11], [7, 12, 13]])
self.assertAllClose(core.convert_to_numpy(result), expected_output)

def test_switch_basic_call(self):
def fn1(x, y):
return x + y

def fn2(x, y):
return x - y

x = np.random.rand(2, 3, 4).astype("float32")
y = np.random.rand(2, 3, 4).astype("float32")
branches = [fn1, fn2]
switch_op = core.Switch()
index = 0
outputs = switch_op.call(index, branches, x, y)
self.assertAllClose(outputs, x + y)

index = 1
outputs = switch_op.call(index, branches, x, y)
self.assertAllClose(outputs, x - y)

def test_while_loop_basic_functionality(self):
# Loop condition: continue if i < 5
def cond(i):
Expand Down