Skip to content

Commit d71a7fe

Browse files
authored
Implement bartlett function in keras.ops (#21214)
* Add bartlett for ops * Update excluded_concrete_tests.txt
1 parent c0017df commit d71a7fe

File tree

11 files changed

+92
-0
lines changed

11 files changed

+92
-0
lines changed

keras/api/_tf_keras/keras/ops/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@
135135
from keras.src.ops.numpy import argsort as argsort
136136
from keras.src.ops.numpy import array as array
137137
from keras.src.ops.numpy import average as average
138+
from keras.src.ops.numpy import bartlett as bartlett
138139
from keras.src.ops.numpy import bincount as bincount
139140
from keras.src.ops.numpy import bitwise_and as bitwise_and
140141
from keras.src.ops.numpy import bitwise_invert as bitwise_invert

keras/api/_tf_keras/keras/ops/numpy/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from keras.src.ops.numpy import argsort as argsort
2828
from keras.src.ops.numpy import array as array
2929
from keras.src.ops.numpy import average as average
30+
from keras.src.ops.numpy import bartlett as bartlett
3031
from keras.src.ops.numpy import bincount as bincount
3132
from keras.src.ops.numpy import bitwise_and as bitwise_and
3233
from keras.src.ops.numpy import bitwise_invert as bitwise_invert

keras/api/ops/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@
135135
from keras.src.ops.numpy import argsort as argsort
136136
from keras.src.ops.numpy import array as array
137137
from keras.src.ops.numpy import average as average
138+
from keras.src.ops.numpy import bartlett as bartlett
138139
from keras.src.ops.numpy import bincount as bincount
139140
from keras.src.ops.numpy import bitwise_and as bitwise_and
140141
from keras.src.ops.numpy import bitwise_invert as bitwise_invert

keras/api/ops/numpy/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from keras.src.ops.numpy import argsort as argsort
2828
from keras.src.ops.numpy import array as array
2929
from keras.src.ops.numpy import average as average
30+
from keras.src.ops.numpy import bartlett as bartlett
3031
from keras.src.ops.numpy import bincount as bincount
3132
from keras.src.ops.numpy import bitwise_and as bitwise_and
3233
from keras.src.ops.numpy import bitwise_invert as bitwise_invert

keras/src/backend/jax/numpy.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,11 @@ def add(x1, x2):
3737
return jnp.add(x1, x2)
3838

3939

40+
def bartlett(x):
41+
x = convert_to_tensor(x)
42+
return jnp.bartlett(x)
43+
44+
4045
def bincount(x, weights=None, minlength=0, sparse=False):
4146
# Note: bincount is never tracable / jittable because the output shape
4247
# depends on the values in x.

keras/src/backend/numpy/numpy.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,11 @@ def average(x, axis=None, weights=None):
305305
return np.average(x, weights=weights, axis=axis)
306306

307307

308+
def bartlett(x):
309+
x = convert_to_tensor(x)
310+
return np.bartlett(x).astype(config.floatx())
311+
312+
308313
def bincount(x, weights=None, minlength=0, sparse=False):
309314
if sparse:
310315
raise ValueError("Unsupported value `sparse=True` with numpy backend")

keras/src/backend/openvino/excluded_concrete_tests.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ NumpyDtypeTest::test_angle
88
NumpyDtypeTest::test_any
99
NumpyDtypeTest::test_argpartition
1010
NumpyDtypeTest::test_array
11+
NumpyDtypeTest::test_bartlett
1112
NumpyDtypeTest::test_bitwise
1213
NumpyDtypeTest::test_ceil
1314
NumpyDtypeTest::test_concatenate
@@ -75,6 +76,7 @@ NumpyOneInputOpsCorrectnessTest::test_angle
7576
NumpyOneInputOpsCorrectnessTest::test_any
7677
NumpyOneInputOpsCorrectnessTest::test_argpartition
7778
NumpyOneInputOpsCorrectnessTest::test_array
79+
NumpyOneInputOpsCorrectnessTest::test_bartlett
7880
NumpyOneInputOpsCorrectnessTest::test_bitwise_invert
7981
NumpyOneInputOpsCorrectnessTest::test_conj
8082
NumpyOneInputOpsCorrectnessTest::test_correlate
@@ -151,4 +153,5 @@ NumpyTwoInputOpsCorrectnessTest::test_tensordot
151153
NumpyTwoInputOpsCorrectnessTest::test_vdot
152154
NumpyTwoInputOpsCorrectnessTest::test_where
153155
NumpyOneInputOpsDynamicShapeTest::test_angle
156+
NumpyOneInputOpsDynamicShapeTest::test_bartlett
154157
NumpyOneInputOpsStaticShapeTest::test_angle

keras/src/backend/tensorflow/numpy.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,21 @@ def add(x1, x2):
131131
return tf.add(x1, x2)
132132

133133

134+
def bartlett(x):
135+
x = convert_to_tensor(x, dtype=config.floatx())
136+
if x == 0:
137+
return tf.constant([])
138+
if x == 1:
139+
return tf.ones([1])
140+
141+
n = tf.range(x)
142+
half = (x - 1) / 2
143+
144+
window = tf.where(n <= half, 2.0 * n / (x - 1), 2.0 - 2.0 * n / (x - 1))
145+
146+
return window
147+
148+
134149
def bincount(x, weights=None, minlength=0, sparse=False):
135150
x = convert_to_tensor(x)
136151
dtypes_to_resolve = [x.dtype]

keras/src/backend/torch/numpy.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -430,6 +430,11 @@ def average(x, axis=None, weights=None):
430430
return torch.mean(x, axis)
431431

432432

433+
def bartlett(x):
434+
x = convert_to_tensor(x)
435+
return torch.signal.windows.bartlett(x)
436+
437+
433438
def bincount(x, weights=None, minlength=0, sparse=False):
434439
if sparse:
435440
raise ValueError("Unsupported value `sparse=True` with torch backend")

keras/src/ops/numpy.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1210,6 +1210,35 @@ def average(x, axis=None, weights=None):
12101210
return backend.numpy.average(x, weights=weights, axis=axis)
12111211

12121212

1213+
class Bartlett(Operation):
1214+
def call(self, x):
1215+
return backend.numpy.bartlett(x)
1216+
1217+
def compute_output_spec(self, x):
1218+
return KerasTensor(x.shape, dtype=backend.floatx())
1219+
1220+
1221+
@keras_export(["keras.ops.bartlett", "keras.ops.numpy.bartlett"])
1222+
def bartlett(x):
1223+
"""Bartlett window function.
1224+
The Bartlett window is a triangular window that rises then falls linearly.
1225+
1226+
Args:
1227+
x: Scalar or 1D Tensor. Window length.
1228+
1229+
Returns:
1230+
A 1D tensor containing the Bartlett window values.
1231+
1232+
Example:
1233+
>>> x = keras.ops.convert_to_tensor(5)
1234+
>>> keras.ops.bartlett(x)
1235+
array([0. , 0.5, 1. , 0.5, 0. ], dtype=float32)
1236+
"""
1237+
if any_symbolic_tensors((x,)):
1238+
return Bartlett().symbolic_call(x)
1239+
return backend.numpy.bartlett(x)
1240+
1241+
12131242
class Bincount(Operation):
12141243
def __init__(self, weights=None, minlength=0, sparse=False):
12151244
super().__init__()

0 commit comments

Comments
 (0)