Commit: Fix loss + more tests

Signed-off-by: abigailt <abigailt@il.ibm.com>
abigailgold committed Apr 25, 2022
1 parent 3d311e4 commit a16aa55

Showing 10 changed files with 319 additions and 98 deletions.
94 changes: 19 additions & 75 deletions art/estimators/regression/keras.py
@@ -16,7 +16,7 @@
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""
This module implements the classifier `KerasClassifier` for Keras models.
This module implements the regressor `KerasRegressor` for Keras models.
"""
from __future__ import absolute_import, division, print_function, unicode_literals

@@ -87,8 +87,8 @@ def __init__(
maximum values allowed for features. If floats are provided, these will be used as the range of all
features. If arrays are provided, each value will be considered the bound for a feature, thus
the shape of clip values needs to match the total number of features.
:param preprocessing_defences: Preprocessing defence(s) to be applied by the classifier.
:param postprocessing_defences: Postprocessing defence(s) to be applied by the classifier.
:param preprocessing_defences: Preprocessing defence(s) to be applied by the regressor.
:param postprocessing_defences: Postprocessing defence(s) to be applied by the regressor.
:param preprocessing: Tuple of the form `(subtrahend, divisor)` of floats or `np.ndarray` of values to be
used for data preprocessing. The first value will be subtracted from the input. The input will then
be divided by the second one.
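As a usage illustration of the `(subtrahend, divisor)` contract described above, here is a hedged sketch; the variable names and the standardization statistics are hypothetical, not from this commit:

```python
import numpy as np

# Hypothetical sketch: with preprocessing=(x_mean, x_std), every input is
# transformed to (x - x_mean) / x_std before it reaches the Keras model.
x_train = np.random.rand(100, 10).astype(np.float32)
x_mean, x_std = x_train.mean(axis=0), x_train.std(axis=0)
# regressor = KerasRegressor(model=model, preprocessing=(x_mean, x_std))
```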
@@ -127,7 +127,7 @@ def _initialize_params(
output_layer: int,
):
"""
Initialize most parameters of the classifier. This is a convenience function called by `__init__` and
Initialize most parameters of the regressor. This is a convenience function called by `__init__` and
`__setstate__` to avoid code duplication.
:param model: Keras model
@@ -169,7 +169,7 @@ def _initialize_params(

self._input_shape = k.int_shape(self._input)[1:]
logger.debug(
"Inferred %s as input shape for Keras classifier.",
"Inferred %s as input shape for Keras regressor.",
str(self.input_shape),
)

@@ -180,7 +180,16 @@ def _initialize_params(
else:
self._orig_loss = self._model.loss
if isinstance(self._model.loss, six.string_types):
loss_function = getattr(k, self._model.loss)
if self._model.loss in [
"mean_squared_error",
"mean_absolute_error",
"mean_absolute_percentage_error",
"mean_squared_logarithmic_error",
"cosine_similarity"
]:
loss_function = getattr(keras.losses, self._model.loss)
else:
loss_function = getattr(k, self._model.loss)

elif "__name__" in dir(self._model.loss) and self._model.loss.__name__ in [
"mean_squared_error",
@@ -205,72 +214,7 @@
else:
loss_function = getattr(k, self._model.loss.__name__)

# Check if loss function is an instance of loss function generator, the try is required because some of the
# modules are not available in older Keras versions
# try:
# flag_is_instance = isinstance(
# loss_function,
# (
# keras.losses.CategoricalHinge,
# keras.losses.CategoricalCrossentropy,
# keras.losses.BinaryCrossentropy,
# keras.losses.KLDivergence,
# ),
# )
# except AttributeError: # pragma: no cover
# flag_is_instance = False
#
# # Check if the labels have to be reduced to index labels and create placeholder for labels
# if (
# "__name__" in dir(loss_function)
# and loss_function.__name__
# in [
# "categorical_hinge",
# "categorical_crossentropy",
# "binary_crossentropy",
# "kullback_leibler_divergence",
# ]
# ) or flag_is_instance:
# self._reduce_labels = False
# label_ph = k.placeholder(shape=self._output.shape)
# elif (
# "__name__" in dir(loss_function) and loss_function.__name__ in ["sparse_categorical_crossentropy"]
# ) or isinstance(loss_function, keras.losses.SparseCategoricalCrossentropy):
# self._reduce_labels = True
# label_ph = k.placeholder(
# shape=[
# None,
# ]
# )
# else: # pragma: no cover
# raise ValueError("Loss function not recognised.")

label_ph = k.placeholder(shape=self._output.shape)

# Define the loss using the loss function
# if "__name__" in dir(loss_function,) and loss_function.__name__ in [
# "categorical_crossentropy",
# "sparse_categorical_crossentropy",
# "binary_crossentropy",
# ]:
# loss_ = loss_function(label_ph, self._output, from_logits=self._use_logits)
#
# elif "__name__" in dir(loss_function) and loss_function.__name__ in [
# "categorical_hinge",
# "kullback_leibler_divergence",
# ]:
# loss_ = loss_function(label_ph, self._output)
#
# elif isinstance(
# loss_function,
# (
# keras.losses.CategoricalHinge,
# keras.losses.CategoricalCrossentropy,
# keras.losses.SparseCategoricalCrossentropy,
# keras.losses.KLDivergence,
# keras.losses.BinaryCrossentropy,
# ),
# ):
loss_ = loss_function(label_ph, self._output)

# Define loss gradients
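A minimal standalone sketch of the dispatch the new string-loss branch performs, assuming a Keras 2.x layout where the named regression losses live in `keras.losses` and other string shortcuts still resolve against the backend module:

```python
import keras
import keras.backend as k

REGRESSION_LOSSES = [
    "mean_squared_error",
    "mean_absolute_error",
    "mean_absolute_percentage_error",
    "mean_squared_logarithmic_error",
    "cosine_similarity",
]

def resolve_loss(loss_name: str):
    # The named regression losses resolve against keras.losses; anything else
    # falls back to the keras.backend lookup the old code used unconditionally.
    if loss_name in REGRESSION_LOSSES:
        return getattr(keras.losses, loss_name)
    return getattr(k, loss_name)

mse = resolve_loss("mean_squared_error")  # keras.losses.mean_squared_error
```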
@@ -440,7 +384,7 @@ def predict( # pylint: disable=W0221

def fit(self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: int = 20, **kwargs) -> None:
"""
Fit the classifier on the training set `(x, y)`.
Fit the regressor on the training set `(x, y)`.
:param x: Training data.
:param y: Target values of shape `(nb_samples,)`.
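For orientation, a hedged end-to-end sketch of `fit` with this signature; the tiny model and random data are invented for the example:

```python
import numpy as np
import tensorflow as tf
import keras
from art.estimators.regression.keras import KerasRegressor

tf.compat.v1.disable_eager_execution()  # the tests below do the same

model = keras.models.Sequential([keras.layers.Dense(1, input_shape=(10,))])
model.compile(loss="mean_squared_error", optimizer="adam")

regressor = KerasRegressor(model=model)
x = np.random.rand(64, 10).astype(np.float32)
y = np.random.rand(64, 1).astype(np.float32)
regressor.fit(x, y, batch_size=16, nb_epochs=1)
```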
@@ -612,7 +556,7 @@ def _get_layers(self) -> List[str]:
from keras.engine.topology import InputLayer # pylint: disable=E0611

layer_names = [layer.name for layer in self._model.layers[:-1] if not isinstance(layer, InputLayer)]
logger.info("Inferred %i hidden layers on Keras classifier.", len(layer_names))
logger.info("Inferred %i hidden layers on Keras regressor.", len(layer_names))

return layer_names

@@ -637,7 +581,7 @@ def save(self, filename: str, path: Optional[str] = None) -> None:

def __getstate__(self) -> Dict[str, Any]:
"""
Use to ensure `KerasClassifier` can be pickled.
Use to ensure `KerasRegressor` can be pickled.
:return: State dictionary with instance parameters.
"""
@@ -670,7 +614,7 @@ def __getstate__(self) -> Dict[str, Any]:

def __setstate__(self, state: Dict[str, Any]) -> None:
"""
Use to ensure `KerasClassifier` can be unpickled.
Use to ensure `KerasRegressor` can be unpickled.
:param state: State dictionary with instance parameters to restore.
"""
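The `__getstate__`/`__setstate__` hunks above exist to support pickling; a hedged round-trip sketch, assuming `regressor` is a fitted `KerasRegressor`:

```python
import pickle

# __getstate__ captures the wrapper's parameters (presumably persisting the
# underlying Keras model via save()), and __setstate__ restores them on load.
serialized = pickle.dumps(regressor)
restored = pickle.loads(serialized)
assert restored.input_shape == regressor.input_shape
```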
172 changes: 151 additions & 21 deletions tests/estimators/regression/test_keras.py
@@ -21,54 +21,184 @@
import unittest
import numpy as np

import keras
import tensorflow as tf

from art.estimators.regression.keras import KerasRegressor

from tests.utils import TestBase, master_seed
from tests.utils import TestBase, master_seed, get_tabular_regressor_kr

logger = logging.getLogger(__name__)


class TestScikitlearnDecisionTreeRegressor(TestBase):
class TestKerasRegressor(TestBase):
@classmethod
def setUpClass(cls):
master_seed(seed=1234, set_tensorflow=True)
super().setUpClass()

import tensorflow as tf
tf.compat.v1.disable_eager_execution()

cls.art_model = get_tabular_regressor_kr()

def test_type(self):
with self.assertRaises(TypeError):
KerasRegressor(model="model")

def test_predict(self):
y_predicted = self.art_model.predict(self.x_test_diabetes[:4])
y_expected = np.asarray([[24.9], [52.7], [30.4], [68.1]])
np.testing.assert_array_almost_equal(y_predicted, y_expected, decimal=1)

def test_save(self):
self.art_model.save(filename="test.file", path=None)
self.art_model.save(filename="test.file", path="./")

def test_input_shape(self):
np.testing.assert_equal(self.art_model.input_shape, (10,))

def test_input_layer(self):
np.testing.assert_equal(isinstance(self.art_model.input_layer, int), True)

def test_output_layer(self):
np.testing.assert_equal(isinstance(self.art_model.output_layer, int), True)

def test_compute_loss(self):
test_loss = self.art_model.compute_loss(self.x_test_diabetes[:4], self.y_test_diabetes[:4])
loss_expected = [6089.8, 2746.3, 5306.8, 1554.9]
np.testing.assert_array_almost_equal(test_loss, loss_expected, decimal=1)
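The four expected values above are unreduced, one loss per sample; a quick numpy illustration of that contract with made-up targets, assuming the helper model is compiled with mean squared error (which the magnitudes suggest):

```python
import numpy as np

y_true = np.array([[103.0], [52.0], [30.0], [107.0]])  # hypothetical targets
y_pred = np.array([[24.9], [52.7], [30.4], [68.1]])    # predictions from above
per_sample = np.squeeze((y_true - y_pred) ** 2)        # shape (4,), one value per row
```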

def test_loss_gradient(self):
grad = self.art_model.loss_gradient(self.x_test_diabetes[:4], self.y_test_diabetes[:4])
grad_expected = [-333.9, 586.4, -1190.9, -123.9, -1206.2, -883.7, 295.9, -830.5, -1333.1, -553.8]
np.testing.assert_array_almost_equal(grad[0], grad_expected, decimal=1)

def test_get_activations(self):
act = self.art_model.get_activations(self.x_test_diabetes[:4], 1)
act_expected = [0, 0, 0, 7.8, 8.5, 0, 5.6, 0, 6.6, 5.8]
np.testing.assert_array_almost_equal(act[0], act_expected, decimal=1)


class TestKerasRegressorClass(TestBase):
@classmethod
def setUpClass(cls):
master_seed(seed=1234, set_tensorflow=True)
super().setUpClass()

import tensorflow as tf
import tensorflow.keras as keras
tf.compat.v1.disable_eager_execution()

class TestModel(tf.keras.Model):

def __init__(self):
super().__init__()
self.dense1 = keras.layers.Dense(10, activation=tf.nn.relu)
self.dense2 = keras.layers.Dense(100, activation=tf.nn.relu)
self.dense3 = keras.layers.Dense(10, activation=tf.nn.relu)
self.dense4 = keras.layers.Dense(1)

def call(self, inputs):
x = self.dense1(inputs)
return self.dense4(self.dense3(self.dense2(x)))

cls.keras_model = TestModel()
cls.keras_model.compile(loss=keras.losses.CosineSimilarity(axis=-1, reduction="auto", name="cosine_similarity"),
optimizer=keras.optimizers.Adam(learning_rate=0.01),
metrics=["accuracy"])
cls.keras_model.fit(cls.x_train_diabetes, cls.y_train_diabetes)

cls.art_model = KerasRegressor(model=cls.keras_model)

def test_type(self):
with self.assertRaises(TypeError):
KerasRegressor(model="model")

def test_predict(self):
y_predicted = self.art_model.predict(self.x_test_diabetes[:4])
np.testing.assert_equal(len(np.unique(y_predicted)), 4)

def test_save(self):
self.art_model.save(filename="test.file", path=None)
self.art_model.save(filename="test.file", path="./")

def test_input_shape(self):
np.testing.assert_equal(self.art_model.input_shape, (10,))

def test_input_layer(self):
np.testing.assert_equal(isinstance(self.art_model.input_layer, int), True)

def test_output_layer(self):
np.testing.assert_equal(isinstance(self.art_model.output_layer, int), True)

def test_compute_loss(self):
test_loss = self.art_model.compute_loss(self.x_test_diabetes[:4], self.y_test_diabetes[:4].astype(np.float32))
# cosine similarity works on vectors, so it returns the same value for each sample
np.testing.assert_equal(len(np.unique(test_loss)), 1)

def test_loss_gradient(self):
grad = self.art_model.loss_gradient(self.x_test_diabetes[:4], self.y_test_diabetes[:4])
# cosine similarity works on vectors, so it returns the same value for each sample
np.testing.assert_equal(len(np.unique(grad[0])), 1)
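The two comments above observe that per-sample cosine similarity degenerates for scalar targets; a small numpy check with made-up values shows why: each sample is a length-1 vector, so the (negative) cosine collapses to ±1 regardless of magnitude.

```python
import numpy as np

y_true = np.array([[52.0], [70.0], [31.0]])  # hypothetical scalar targets
y_pred = np.array([[24.9], [52.7], [30.4]])
# Keras' CosineSimilarity loss is the negative cosine similarity along axis=-1;
# for length-1 vectors this reduces to -sign(y_true * y_pred).
loss = -(y_true * y_pred).sum(axis=-1) / (
    np.linalg.norm(y_true, axis=-1) * np.linalg.norm(y_pred, axis=-1)
)
print(loss)  # [-1. -1. -1.] -- identical for every sample
```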


class TestKerasRegressorFunctional(TestBase):
@classmethod
def setUpClass(cls):
master_seed(seed=1234)
master_seed(seed=1234, set_tensorflow=True)
super().setUpClass()

import tensorflow as tf
import keras
from keras.models import Model
tf.compat.v1.disable_eager_execution()

model = keras.models.Sequential()
# model.add(keras.Input(shape=(10,)))
model.add(keras.layers.Dense(10, activation="relu"))
model.add(keras.layers.Dense(100, activation="relu"))
model.add(keras.layers.Dense(10, activation="relu"))
model.add(keras.layers.Dense(1))
def functional():
in_layer = keras.layers.Input(shape=(10,))
layer = keras.layers.Dense(100, activation=tf.nn.relu)(in_layer)
layer = keras.layers.Dense(10, activation=tf.nn.relu)(layer)
out_layer = keras.layers.Dense(1)(layer)

model.compile(loss=keras.losses.mean_squared_error, optimizer=keras.optimizers.Adam(learning_rate=0.01),
metrics=["accuracy"])
model = Model(inputs=[in_layer], outputs=[out_layer])

model.compile(
loss=keras.losses.MeanAbsoluteError(),
optimizer=keras.optimizers.Adam(learning_rate=0.01),
metrics=["accuracy"])

return model

cls.keras_model = functional()
cls.keras_model.fit(cls.x_train_diabetes, cls.y_train_diabetes)

cls.keras_model = model
cls.art_model = KerasRegressor(model=cls.keras_model)
cls.art_model.fit(x=cls.x_train_diabetes, y=cls.y_train_diabetes)

def test_type(self):
self.assertIsInstance(self.art_model, type(KerasRegressor(model=self.keras_model)))
with self.assertRaises(TypeError):
KerasRegressor(model="model")

def test_predict(self):
y_predicted = self.art_model.predict(self.x_test_diabetes[:4])
y_expected = np.asarray([69.0, 81.0, 68.0, 68.0])
# np.testing.assert_array_almost_equal(y_predicted, y_expected, decimal=1)
np.testing.assert_equal(len(np.unique(y_predicted)), 4)

def test_save(self):
self.art_model.save(filename="test.file", path=None)
self.art_model.save(filename="test.file", path="./")

def test_clone_for_refitting(self):
_ = self.art_model.clone_for_refitting()
def test_input_shape(self):
np.testing.assert_equal(self.art_model.input_shape, (10,))

def test_input_layer(self):
np.testing.assert_equal(isinstance(self.art_model.input_layer, int), True)

def test_output_layer(self):
np.testing.assert_equal(isinstance(self.art_model.output_layer, int), True)

def test_compute_loss(self):
test_loss = self.art_model.compute_loss(self.x_test_diabetes[:4], self.y_test_diabetes[:4].astype(np.float32))
np.testing.assert_equal(len(np.unique(test_loss)), 4)

def test_loss_gradient(self):
grad = self.art_model.loss_gradient(self.x_test_diabetes[:4], self.y_test_diabetes[:4])
np.testing.assert_equal(len(np.unique(grad[0])), 10)


if __name__ == "__main__":