Skip to content

Commit e8fe5e2

Browse files
authored
[sml] optimize preprocessing by eliminating unnecessary where function (#608)
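# Pull Request

## What problem does this PR solve?

Small optimization in sml/preprocessing: replace `jnp.where(expression, 1, 0)` with the boolean `expression` itself, which eliminates the unnecessary `where` call. Where the result is returned to the caller, the boolean array is cast with `.astype(jnp.int_)`; where it is only passed to `jnp.sum`, the boolean array is summed directly.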
1 parent 60780e5 commit e8fe5e2

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

sml/preprocessing/preprocessing.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@ def label_binarize(y, *, classes, n_classes, neg_label=0, pos_label=1):
4242
ndarray of shape (n_samples, n_classes)
4343
Shape will be (n_samples, 1) for binary problems.
4444
"""
45-
eq_func = lambda x: jnp.where(classes == x, 1, 0)
46-
result = jax.vmap(eq_func)(y)
45+
eq_func = lambda x: classes == x
46+
result = jax.vmap(eq_func)(y).astype(jnp.int_)
4747

4848
if neg_label != 0 or pos_label != 1:
4949
result = jnp.where(result, pos_label, neg_label)
@@ -203,7 +203,7 @@ def binarize(X, *, threshold=0.0):
203203
Feature values below or equal to this are replaced by 0, above it by 1.
204204
205205
"""
206-
return jnp.where(X > threshold, 1, 0)
206+
return (X > threshold).astype(jnp.int_)
207207

208208

209209
class Binarizer:
@@ -626,7 +626,7 @@ def _weighted_percentile(x, q, w):
626626
adjusted_percentile = q / 100 * weight_cdf[-1]
627627

628628
def searchsorted_element(x_inner):
629-
encoding = jnp.where(x_inner >= weight_cdf[0:-1, 0], 1, 0)
629+
encoding = x_inner >= weight_cdf[0:-1, 0]
630630
return jnp.sum(encoding)
631631

632632
percentile_idx = jax.vmap(searchsorted_element)(adjusted_percentile)
@@ -1112,7 +1112,7 @@ def transform(self, X):
11121112

11131113
def compute_row(bin, x, c):
11141114
def compute_element(x):
1115-
encoding = jnp.where(x >= bin[1:-1], 1, 0)
1115+
encoding = x >= bin[1:-1]
11161116
return jnp.clip(jnp.sum(encoding), 0, c - 2)
11171117

11181118
return jax.vmap(compute_element)(x)
@@ -1125,7 +1125,7 @@ def compute_element(x):
11251125

11261126
def compute_row(bin, x):
11271127
def compute_element(x):
1128-
encoding = jnp.where(x >= bin[1:-1], 1, 0)
1128+
encoding = x >= bin[1:-1]
11291129
return jnp.sum(encoding)
11301130

11311131
return jax.vmap(compute_element)(x)

0 commit comments

Comments
 (0)