Skip to content

Commit e3af2a6

Browse files
authored
Small ops fixes for Torch unit tests (#316)
* Add PyTorch numpy functionality * Add dtype conversion * Partial fix for PyTorch numpy tests * small logic fix * Revert numpy_test * Add tensor conversion to numpy * Fix some arithmetic tests * Fix some torch functions for numpy compatibility * Fix pytorch ops for numpy compatibility, add TODOs * Fix formatting * Implement nits and fix dtype standardization * Add pytest skipif decorator and fix nits * Fix formatting and rename dtypes map * Split tests by backend * Merge space * Fix dtype issues from new type checking * Implement torch.full and torch.full_like numpy compatible * Implements logspace and linspace with tensor support for start and stop * Replace len of shape with ndim * Fix formatting * Implement torch.trace * Implement eye k diagonal arg * Implement torch.tri * Fix formatting issues * Fix torch.take dimensionality * Add split functionality * Revert torch.eye implementation to prevent conflict * Implement all padding modes * Adds torch image resizing and torchvision dependency. * Fix conditional syntax * Make torchvision import optional * Partial implementation of torch RNN * Duplicate torch demo file * Small ops fixes for torch unit tests * delete nonfunctional gpu test file * Revert rnn and formatting fixes * Revert progbar * Fix formatting
1 parent 983635f commit e3af2a6

File tree

7 files changed

+14
-2
lines changed

7 files changed

+14
-2
lines changed

keras_core/activations/activations.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -416,6 +416,7 @@ def mish(x):
416416
417417
- [Misra, 2019](https://arxiv.org/abs/1908.08681)
418418
"""
419+
x = backend.convert_to_tensor(x)
419420
return Mish.static_call(x)
420421

421422

keras_core/backend/torch/math.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def logsumexp(x, axis=None, keepdims=False):
5858
max_x = torch.max(x)
5959
return torch.log(torch.sum(torch.exp(x - max_x))) + max_x
6060

61-
max_x = torch.max(x, dim=axis, keepdim=True).values
61+
max_x = torch.amax(x, dim=axis, keepdim=True)
6262
result = (
6363
torch.log(torch.sum(torch.exp(x - max_x), dim=axis, keepdim=True))
6464
+ max_x

keras_core/backend/torch/numpy.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,8 @@ def append(
134134

135135
def arange(start, stop=None, step=1, dtype=None):
136136
dtype = to_torch_dtype(dtype)
137+
if stop is None:
138+
return torch.arange(end=start, dtype=dtype)
137139
return torch.arange(start, stop, step=step, dtype=dtype)
138140

139141

keras_core/constraints/constraints_test.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import numpy as np
22

3+
from keras_core import backend
34
from keras_core import constraints
45
from keras_core import testing
56

@@ -35,12 +36,14 @@ def test_non_neg(self):
3536
def test_unit_norm(self):
3637
constraint_fn = constraints.UnitNorm()
3738
output = constraint_fn(get_example_array())
39+
output = backend.convert_to_numpy(output)
3840
l2 = np.sqrt(np.sum(np.square(output), axis=0))
3941
self.assertAllClose(l2, 1.0)
4042

4143
def test_min_max_norm(self):
4244
constraint_fn = constraints.MinMaxNorm(min_value=0.2, max_value=0.5)
4345
output = constraint_fn(get_example_array())
46+
output = backend.convert_to_numpy(output)
4447
l2 = np.sqrt(np.sum(np.square(output), axis=0))
4548
self.assertFalse(l2[l2 < 0.2])
4649
self.assertFalse(l2[l2 > 0.5 + 1e-6])

keras_core/layers/preprocessing/normalization_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ def test_normalization_adapt(self, input_type):
6868
layer.adapt(data)
6969
self.assertTrue(layer.built)
7070
output = layer(x)
71+
output = backend.convert_to_numpy(output)
7172
self.assertAllClose(np.var(output, axis=0), 1.0, atol=1e-5)
7273
self.assertAllClose(np.mean(output, axis=0), 0.0, atol=1e-5)
7374

@@ -84,6 +85,7 @@ def test_normalization_adapt(self, input_type):
8485
layer.adapt(data)
8586
self.assertTrue(layer.built)
8687
output = layer(x)
88+
output = backend.convert_to_numpy(output)
8789
self.assertAllClose(np.var(output, axis=(0, 3)), 1.0, atol=1e-5)
8890
self.assertAllClose(np.mean(output, axis=(0, 3)), 0.0, atol=1e-5)
8991

keras_core/layers/preprocessing/random_brightness_test.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy as np
22
import tensorflow as tf
33

4+
from keras_core import backend
45
from keras_core import layers
56
from keras_core import testing
67

@@ -36,6 +37,7 @@ def test_output(self):
3637
inputs = np.random.randint(0, 255, size=(224, 224, 3))
3738
output = layer(inputs)
3839
diff = output - inputs
40+
diff = backend.convert_to_numpy(diff)
3941
self.assertTrue(np.amin(diff) >= 0)
4042
self.assertTrue(np.mean(diff) > 0)
4143

@@ -45,6 +47,7 @@ def test_output(self):
4547
inputs = np.random.randint(0, 255, size=(224, 224, 3))
4648
output = layer(inputs)
4749
diff = output - inputs
50+
diff = backend.convert_to_numpy(diff)
4851
self.assertTrue(np.amax(diff) <= 0)
4952
self.assertTrue(np.mean(diff) < 0)
5053

keras_core/operations/operation_test.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99

1010
class OpWithMultipleInputs(operation.Operation):
1111
def call(self, x, y, z=None):
12-
return x + 2 * y + 3 * z
12+
return 3 * z + x + 2 * y
13+
# Order of operations issue with torch backend
1314

1415
def compute_output_spec(self, x, y, z=None):
1516
return keras_tensor.KerasTensor(x.shape, x.dtype)

0 commit comments

Comments
 (0)