[BREAKING CHANGE] deel-lip upgrade to Keras 3.0 #91

Open · wants to merge 44 commits into base: keras3

Commits (44)
1838266
fix(Keras3): layer.input.shape instead of layer.input_shape
cofri Aug 1, 2024
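A minimal sketch of this change (illustrative code, not taken from this PR): Keras 3 removed the `input_shape` property, so the shape is read from the layer's input tensor.

```python
import keras

layer = keras.layers.Dense(4)
layer(keras.Input(shape=(3,)))  # build the layer in a functional graph

# Keras 2.x exposed `layer.input_shape`; Keras 3 reads it off the input tensor:
print(layer.input.shape)  # (None, 3)
```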
7bd62e3
fix(Keras 3): argument order changed in Layer.add_weight()
cofri Aug 1, 2024
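The pitfall, sketched with a hypothetical layer (not from this PR): positional calls to `Layer.add_weight()` that worked in Keras 2 break in Keras 3, so keyword arguments are the portable form.

```python
import keras

class MyLayer(keras.layers.Layer):  # hypothetical example layer
    def build(self, input_shape):
        # Keras 2 accepted add_weight(name, shape, ...); Keras 3 reordered
        # the parameters, so keywords avoid silent breakage:
        self.w = self.add_weight(
            shape=(input_shape[-1], 4),
            initializer="glorot_uniform",
            trainable=True,
            name="w",
        )

    def call(self, inputs):
        return keras.ops.matmul(inputs, self.w)
```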
c980436
fix (Keras 3): tf.Variable.read_value() was removed
cofri Aug 1, 2024
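A sketch of the replacement (illustrative): a Keras 3 variable converts to a tensor directly, and its current tensor is exposed as `.value`.

```python
import keras

v = keras.Variable(keras.ops.ones((2, 2)))

# Keras 2 era: v.read_value(); in Keras 3 the variable converts directly:
t = keras.ops.convert_to_tensor(v)
print(v.value)
```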
8cae0fc
fix (Keras 3): Reduction API removed, replaced with string
cofri Aug 1, 2024
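Before/after in one line (illustrative):

```python
import keras

# Keras 2.x: reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE
# Keras 3: the Reduction enum is gone; a plain string is used instead:
loss_fn = keras.losses.MeanSquaredError(reduction="sum_over_batch_size")
```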
72aca2f
fix(Keras 3): argument order changed in Loss.__init__()
cofri Aug 1, 2024
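A hypothetical custom loss showing the safe pattern (not this PR's code): Keras 3 reordered the `Loss.__init__()` parameters, so the base class is best called with keywords only.

```python
import keras

class HingeMargin(keras.losses.Loss):  # hypothetical; only __init__ matters here
    def __init__(self, margin=1.0, name="hinge_margin",
                 reduction="sum_over_batch_size"):
        # Positional (reduction, name) from Keras 2 no longer lines up:
        super().__init__(name=name, reduction=reduction)
        self.margin = margin

    def call(self, y_true, y_pred):
        return keras.ops.maximum(self.margin - y_true * y_pred, 0.0)
```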
85bd915
fix(Keras 3): Input layer must have a shape as a tuple
cofri Aug 1, 2024
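Illustration: Keras 3 validates that `shape` is a tuple, where Keras 2 tolerated looser inputs.

```python
import keras

inputs = keras.Input(shape=(28, 28, 3))  # a tuple is required in Keras 3
```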
813b77d
fix(Keras 3): model.save() does not accept TF SavedModel format
cofri Aug 1, 2024
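A sketch of the new saving behaviour (illustrative; `export()` assumes the TensorFlow backend): `save()` now only accepts the native `.keras` or `.h5` formats, and a SavedModel is produced via `export()` instead.

```python
import keras

model = keras.Sequential([keras.Input(shape=(3,)), keras.layers.Dense(1)])

model.save("model.keras")        # native Keras format (".h5" also accepted)
model.export("saved_model_dir")  # SavedModel for serving (TF backend)
```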
cef9657
fix(Keras 3): model.save(path) does not create path if not exist
cofri Aug 2, 2024
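Illustrative workaround (not necessarily this PR's exact code): parent directories must now be created before saving.

```python
import os
import keras

model = keras.Sequential([keras.Input(shape=(3,)), keras.layers.Dense(1)])

# Keras 3 no longer creates missing parent directories on save:
path = "checkpoints/run_1/model.keras"  # hypothetical path
os.makedirs(os.path.dirname(path), exist_ok=True)
model.save(path)
```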
235a5ac
fix(Keras 3): Adam optimizer does not support `lr` argument anymore
cofri Aug 1, 2024
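Before/after (illustrative):

```python
import keras

# Keras 2 accepted the `lr` alias; Keras 3 only knows `learning_rate`:
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
```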
7faac08
fix(Keras 3): Conv2DTranspose has no arg `output_padding` anymore
cofri Aug 1, 2024
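Illustration: the argument no longer exists in Keras 3, so output sizes are governed by stride and padding alone.

```python
import keras

# Keras 2: Conv2DTranspose(..., output_padding=(1, 1)) tweaked the output size.
# Keras 3 dropped the argument:
layer = keras.layers.Conv2DTranspose(16, kernel_size=3, strides=2, padding="same")
```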
f166742
fix(Keras 3): argument order in Sequential
cofri Sep 9, 2024
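Sketch of the trap (illustrative): Keras 3 inserted `trainable` between `layers` and `name` in `Sequential.__init__()`, so a positional second argument no longer means the model name.

```python
import keras

# Safe in both Keras 2 and 3: pass the name as a keyword.
model = keras.Sequential([keras.layers.Dense(4)], name="my_model")
```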
ef426b8
feat(callbacks): upgrade to Keras 3
cofri Aug 6, 2024
51e8f28
feat(compute_layer_sv): upgrade to Keras 3
cofri Aug 6, 2024
21b02c6
feat(constraints): upgrade to Keras 3
cofri Aug 6, 2024
aecf96e
feat(initializers): upgrade to Keras 3
cofri Aug 6, 2024
730533d
feat(losses): upgrade to Keras 3
cofri Aug 6, 2024
a70137b
feat(metrics): upgrade to Keras 3
cofri Aug 6, 2024
7a9af98
feat(model): upgrade to Keras 3
cofri Aug 6, 2024
42bf583
feat(normalizers): upgrade to Keras 3
cofri Aug 6, 2024
f32b392
feat(regularizers): upgrade to Keras 3
cofri Aug 6, 2024
fbde10d
feat(utils): upgrade to Keras 3
cofri Aug 6, 2024
36bd039
feat(activations): upgrade to Keras 3
cofri Aug 6, 2024
168ed33
feat(unconstrained): upgrade to Keras 3
cofri Aug 6, 2024
7a461fe
feat(pooling): upgrade to Keras 3
cofri Aug 6, 2024
d105cfa
feat(dense): upgrade to Keras 3
cofri Aug 7, 2024
e48ec06
feat(convolutional): upgrade to Keras 3
cofri Aug 7, 2024
e504dd8
feat(init): upgrade to Keras 3
cofri Aug 7, 2024
2f7b8d6
feat(test_activations): upgrade to Keras 3
cofri Aug 7, 2024
8b3a403
feat(test_compute_layer_sv): upgrade to Keras 3
cofri Aug 8, 2024
c096ce1
feat(test_condense): upgrade to Keras 3
cofri Aug 8, 2024
0f5ce8d
feat(test_initializers): upgrade to Keras 3
cofri Aug 8, 2024
708e2b0
feat(test_losses): upgrade to Keras 3
cofri Aug 8, 2024
4c73c9d
feat(test_metrics): upgrade to Keras 3
cofri Aug 8, 2024
9e8bbb1
feat(test_models): upgrade to Keras 3
cofri Aug 8, 2024
c8d3d15
feat(test_normalizers): upgrade to Keras 3
cofri Aug 8, 2024
ef6e0c1
feat(test_regularizers): upgrade to Keras 3
cofri Aug 8, 2024
99b2f96
feat(test_unconstrained_layers): upgrade to Keras 3
cofri Aug 8, 2024
cb764ed
feat(test_layers): upgrade to Keras 3
cofri Aug 8, 2024
0ce0a71
feat(layers): save/load own variables in dense and conv
cofri Aug 8, 2024
f0cd52d
feat(notebooks): upgrade to Keras 3
cofri Sep 6, 2024
e697bbb
chore: enforce TF>=2.16 and Keras 3
cofri Aug 8, 2024
b462401
chore: bump to deel-lip version 2.0.0
cofri Aug 8, 2024
ce88cc1
chore: clean github actions to latest Python and TF versions
cofri Sep 6, 2024
ed92ee3
fix(callbacks): Keras SVD op is not as expected
cofri Sep 6, 2024
Files changed
2 changes: 1 addition & 1 deletion .github/workflows/python-linters.yml
@@ -13,7 +13,7 @@ jobs:
strategy:
max-parallel: 4
matrix:
python-version: [3.7, "3.10"]
python-version: ["3.10"]

steps:
- uses: actions/checkout@v4
8 changes: 1 addition & 7 deletions .github/workflows/python-tests.yml
@@ -16,14 +16,8 @@ jobs:
max-parallel: 4
matrix:
include:
- python-version: 3.7
tf-version: 2.3
- python-version: 3.9
tf-version: 2.7
- python-version: "3.10"
tf-version: 2.11
- python-version: "3.10"
tf-version: 2.15
tf-version: 2.17

steps:
- uses: actions/checkout@v4
2 changes: 1 addition & 1 deletion deel/lip/VERSION
@@ -1 +1 @@
1.5.0
2.0.0
2 changes: 1 addition & 1 deletion deel/lip/__init__.py
@@ -25,7 +25,7 @@
--------

DEEL-LIP provides a simple interface to build and train Lipschitz-constrained neural
networks based on TensorFlow/Keras framework.
networks based on Keras framework.
"""
from os import path

24 changes: 15 additions & 9 deletions deel/lip/callbacks.py
@@ -6,11 +6,12 @@
This module contains callbacks that can be added to keras training process.
"""
import os
from typing import Optional, Dict, Iterable
from typing import Dict, Iterable, Optional

import tensorflow as tf
from tensorflow.keras.callbacks import Callback
import keras.ops as K
import numpy as np
from keras.callbacks import Callback

from .layers import Condensable


@@ -91,6 +92,8 @@ def __init__(
assert what in {"max", "all"}
self.what = what
self.logdir = logdir
import tensorflow as tf

self.file_writer = tf.summary.create_file_writer(
os.path.join(logdir, "metrics")
)
@@ -101,6 +104,8 @@ def __init__(
super().__init__()

def _monitor(self, step):
import tensorflow as tf

step = self.params["steps"] * self.epochs + step
for layer_name in self.monitored_layers:
layer = self.model.get_layer(layer_name)
@@ -113,11 +118,12 @@ def _monitor(self, step):
elif hasattr(layer, self.target):
kernel = getattr(layer, self.target)
w_shape = kernel.shape.as_list()
sigmas = tf.linalg.svd(
tf.keras.backend.reshape(kernel, [-1, w_shape[-1]]),
# TODO: compute_uv=False in next Keras version (3.6.0)
sigmas = K.svd(
K.reshape(kernel, [-1, w_shape[-1]]),
full_matrices=False,
compute_uv=False,
).numpy()
compute_uv=True,
)[1].numpy()
sig = sigmas[0]
else:
raise RuntimeWarning(
@@ -176,7 +182,7 @@ def __init__(self, param_name, fp, xp, step=0):

Args:
param_name (str): name of the parameter of the loss to tune. Must be a
tf.Variable.
keras.Variable.
fp (list): values of the loss parameter as steps given by the xp.
xp (list): step where the parameter equals fp.
step (int): step value, for serialization/deserialization purposes.
@@ -215,7 +221,7 @@ def __init__(self, param_name, rate=1):

def on_epoch_end(self, epoch: int, logs=None):
if epoch % self.rate == 0:
tf.print(
print(
"\n",
self.model.loss.name,
self.param_name,
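To unpack the SVD workaround in the callbacks.py hunk above (a sketch, not part of the diff): `keras.ops.svd` did not yet support `compute_uv=False` (the TODO targets Keras 3.6.0), so the singular values are taken from index 1 of the `(u, s, vh)` triple; they come back sorted in descending order, hence `sigmas[0]` is the largest.

```python
import keras.ops as K

w = K.convert_to_tensor([[3.0, 0.0], [0.0, 2.0]])

# Index 1 of the (u, s, vh) triple holds the singular values, largest first:
sigmas = K.svd(w, full_matrices=False, compute_uv=True)[1]
print(sigmas)  # [3., 2.]
```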
44 changes: 21 additions & 23 deletions deel/lip/compute_layer_sv.py
@@ -14,8 +14,8 @@
It returns a dictionary indicating for each layer name a tuple (min sv, max sv).
"""

import keras
import numpy as np
import tensorflow as tf

from .layers import Condensable, GroupSort, MaxMin
from .layers.unconstrained import PadConv2D
@@ -27,7 +27,7 @@ def _compute_sv_dense(layer, input_sizes=None):
The singular values are computed using the SVD decomposition of the weight matrix.

Args:
layer (tf.keras.Layer): the Dense layer.
layer (keras.Layer): the Dense layer.
input_sizes (tuple, optional): unused here.

Returns:
@@ -46,16 +46,14 @@ def _generate_conv_matrix(layer, input_sizes):
dirac input.

Args:
layer (tf.keras.Layer): the convolutional layer to convert to dense.
layer (keras.Layer): the convolutional layer to convert to dense.
input_sizes (tuple): the input shape of the layer (with batch dimension as first
element).

Returns:
np.array: the equivalent matrix of the convolutional layer.
"""
single_layer_model = tf.keras.models.Sequential(
[tf.keras.layers.Input(input_sizes[1:]), layer]
)
single_layer_model = keras.Sequential([keras.Input(input_sizes[1:]), layer])
dirac_inp = np.zeros((input_sizes[2],) + input_sizes[1:]) # Line by line generation
in_size = input_sizes[1] * input_sizes[2]
channel_in = input_sizes[-1]
@@ -69,8 +67,8 @@
w_eqmatrix = np.zeros(
(in_size * channel_in, np.prod(out_pred.shape[1:]))
)
w_eqmatrix[start_index : (start_index + input_sizes[2]), :] = tf.reshape(
out_pred, (input_sizes[2], -1)
w_eqmatrix[start_index : (start_index + input_sizes[2]), :] = (
keras.ops.reshape(out_pred, (input_sizes[2], -1))
)
dirac_inp = 0.0 * dirac_inp
start_index += input_sizes[2]
@@ -86,7 +84,7 @@ def _compute_sv_conv2d_layer(layer, input_sizes):
the weight matrix.

Args:
layer (tf.keras.Layer): the convolutional layer.
layer (keras.Layer): the convolutional layer.
input_sizes (tuple): the input shape of the layer (with batch dimension as first
element).

@@ -103,14 +101,14 @@ def _compute_sv_activation(layer, input_sizes=None):

Warning: This is not singular values for non-linear functions but gradient norm.
"""
if isinstance(layer, tf.keras.layers.Activation):
function2SV = {tf.keras.activations.relu: (0, 1)}
if isinstance(layer, keras.layers.Activation):
function2SV = {keras.activations.relu: (0, 1)}
if layer.activation in function2SV.keys():
return function2SV[layer.activation]
else:
return (None, None)
layer2SV = {
tf.keras.layers.ReLU: (0, 1),
keras.layers.ReLU: (0, 1),
GroupSort: (1, 1),
MaxMin: (1, 1),
}
@@ -145,25 +143,25 @@ def compute_layer_sv(layer, supplementary_type2sv={}):
ReLU, Activation, and deel-lip layers)

Args:
layer (tf.keras.layers.Layer): a single tf.keras.layer
layer (keras.layers.Layer): a single keras.layer
supplementary_type2sv (dict, optional): a dictionary linking new layer type with
user-defined function to compute the singular values. Defaults to {}.
Returns:
tuple: a 2-tuple with lowest and largest singular values.
"""
default_type2sv = {
tf.keras.layers.Conv2D: _compute_sv_conv2d_layer,
tf.keras.layers.Conv2DTranspose: _compute_sv_conv2d_layer,
keras.layers.Conv2D: _compute_sv_conv2d_layer,
keras.layers.Conv2DTranspose: _compute_sv_conv2d_layer,
PadConv2D: _compute_sv_conv2d_layer,
tf.keras.layers.Dense: _compute_sv_dense,
tf.keras.layers.ReLU: _compute_sv_activation,
tf.keras.layers.Activation: _compute_sv_activation,
keras.layers.Dense: _compute_sv_dense,
keras.layers.ReLU: _compute_sv_activation,
keras.layers.Activation: _compute_sv_activation,
GroupSort: _compute_sv_activation,
MaxMin: _compute_sv_activation,
tf.keras.layers.Add: _compute_sv_add,
tf.keras.layers.BatchNormalization: _compute_sv_bn,
keras.layers.Add: _compute_sv_add,
keras.layers.BatchNormalization: _compute_sv_bn,
}
input_shape = layer.input_shape
input_shape = layer.input.shape if hasattr(layer.input, "shape") else None
if isinstance(layer, Condensable):
layer.condense()
layer = layer.vanilla_export()
@@ -179,7 +177,7 @@ def compute_model_sv(model, supplementary_type2sv={}):
"""Compute the largest and lowest singular values of all layers in a model.

Args:
model (tf.keras.Model): a tf.keras Model or Sequential.
model (keras.Model): a keras Model or Sequential.
supplementary_type2sv (dict, optional): a dictionary linking new layer type
with user defined function to compute the min and max singular values.

@@ -188,7 +186,7 @@
"""
list_sv = []
for layer in model.layers:
if isinstance(layer, tf.keras.Model):
if isinstance(layer, keras.Model):
list_sv.append((layer.name, (None, None)))
list_sv += compute_model_sv(layer, supplementary_type2sv)
else:
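A usage sketch of the upgraded module (assuming this PR's branch of deel-lip is installed; the model is illustrative):

```python
import keras
from deel.lip.compute_layer_sv import compute_model_sv

model = keras.Sequential([keras.Input(shape=(8,)), keras.layers.Dense(4)])

# Pairs each layer name with its (lowest, largest) singular values:
print(compute_model_sv(model))
```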
19 changes: 9 additions & 10 deletions deel/lip/constraints.py
@@ -6,16 +6,15 @@
This module contains extra constraint objects. These objects can be added as params to
regular layers.
"""
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.constraints import Constraint
import keras.ops as K
from keras.constraints import Constraint
from .normalizers import (
reshaped_kernel_orthogonalization,
DEFAULT_EPS_SPECTRAL,
DEFAULT_EPS_BJORCK,
DEFAULT_BETA_BJORCK,
)
from tensorflow.keras.utils import register_keras_serializable
from keras.saving import register_keras_serializable


@register_keras_serializable("deel-lip", "WeightClipConstraint")
@@ -49,8 +48,8 @@ def __init__(self, scale=1):
self.scale = scale

def __call__(self, w):
c = 1 / (tf.sqrt(tf.cast(tf.size(w), dtype=w.dtype)) * self.scale)
return tf.clip_by_value(w, -c, c)
c = 1 / (K.sqrt(K.cast(K.size(w), dtype=w.dtype)) * self.scale)
return K.clip(w, -c, c)

def get_config(self):
return {"scale": self.scale}
@@ -67,7 +66,7 @@ def __init__(self, eps=1e-7):
self.eps = eps

def __call__(self, w):
return w / (tf.sqrt(tf.reduce_sum(tf.square(w), keepdims=False)) + self.eps)
return w / (K.sqrt(K.sum(K.square(w), keepdims=False)) + self.eps)

def get_config(self):
return {"eps": self.eps}
@@ -95,15 +94,15 @@ def __init__(
eps_spectral (float): stopping criterion for the iterative power algorithm.
eps_bjorck (float): stopping criterion Bjorck algorithm.
beta_bjorck (float): beta parameter in bjorck algorithm.
u (tf.Tensor): vector used for iterated power method, can be set to None
u (Tensor): vector used for iterated power method, can be set to None
(used for serialization/deserialization purposes).
"""
self.eps_spectral = eps_spectral
self.eps_bjorck = eps_bjorck
self.beta_bjorck = beta_bjorck
self.k_coef_lip = k_coef_lip
if not (isinstance(u, tf.Tensor) or (u is None)):
u = tf.convert_to_tensor(u)
if not (K.is_tensor(u) or (u is None)):
u = K.convert_to_tensor(u)
self.u = u
super(SpectralConstraint, self).__init__()

12 changes: 6 additions & 6 deletions deel/lip/initializers.py
@@ -7,19 +7,19 @@
matrix initialization.
They can be used as kernel initializers in any Keras layer.
"""
from tensorflow.keras.initializers import Initializer
from tensorflow.keras import initializers
import keras
from keras.saving import register_keras_serializable

from .normalizers import (
reshaped_kernel_orthogonalization,
DEFAULT_EPS_SPECTRAL,
DEFAULT_EPS_BJORCK,
DEFAULT_BETA_BJORCK,
)
from tensorflow.keras.utils import register_keras_serializable


@register_keras_serializable("deel-lip", "SpectralInitializer")
class SpectralInitializer(Initializer):
class SpectralInitializer(keras.Initializer):
def __init__(
self,
eps_spectral=DEFAULT_EPS_SPECTRAL,
@@ -44,7 +44,7 @@ def __init__(
self.eps_bjorck = eps_bjorck
self.beta_bjorck = beta_bjorck
self.k_coef_lip = k_coef_lip
self.base_initializer = initializers.get(base_initializer)
self.base_initializer = keras.initializers.get(base_initializer)
super(SpectralInitializer, self).__init__()

def __call__(self, shape, dtype=None, partition_info=None):
@@ -65,5 +65,5 @@ def get_config(self):
"eps_bjorck": self.eps_bjorck,
"beta_bjorck": self.beta_bjorck,
"k_coef_lip": self.k_coef_lip,
"base_initializer": initializers.serialize(self.base_initializer),
"base_initializer": keras.initializers.serialize(self.base_initializer),
}