Remove references to outputs in contrastive model. (tensorflow#311)
* Remove references to outputs in contrastive model.

We use the inputs and outputs to support saving the contrastive model
using the Keras API; however, we override the train and test steps as
well as predict. This means we don't currently support multiple output
heads on the embedding output. This PR removes all references to
multi-headed outputs and explicitly sets the indexer to use the
predictor output.
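
As context for the saving path, a hedged sketch of the save/reload round
trip (the directory is illustrative, `model` is assumed to be a trained
ContrastiveModel, and the custom_objects mirror the updated test below):

import tensorflow as tf

from tensorflow_similarity.layers import ActivationStdLoggingLayer
from tensorflow_similarity.models import ContrastiveModel

# Assumes `model` is a trained ContrastiveModel instance (see the sketches below).
model.save("/tmp/contrastive_model")
loaded = tf.keras.models.load_model(
    "/tmp/contrastive_model",
    custom_objects={
        "ContrastiveModel": ContrastiveModel,
        "ActivationStdLoggingLayer": ActivationStdLoggingLayer,
    },
)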

* Provide default contrastive projector and predictor.

Users previously had to provide their own MLP models for the projector
and predictor, which required understanding more about the underlying
algorithms. This change adds default projector and predictor models
based on the original papers.
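
With the defaults in place, a SimSiam-style model can now be built from
just a backbone. A minimal sketch, assuming a toy encoder (the backbone
below is illustrative, not part of this PR):

import tensorflow as tf

from tensorflow_similarity.models import create_contrastive_model

# Toy backbone for illustration; any encoder producing a flat embedding works.
backbone_input = tf.keras.layers.Input(shape=(32,))
backbone_output = tf.keras.layers.Dense(128)(backbone_input)
backbone = tf.keras.Model(backbone_input, backbone_output)

# projector and predictor are now optional; defaults are created from the
# backbone's output dimension, following the original papers.
model = create_contrastive_model(backbone=backbone, algorithm="simsiam")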

* Update unsupervised colab.

* Comment out the projector and predictor create-model functions. We now
  automatically create the MLP models for users, but the commented code
  is left in case the user wants to customize them (see the sketch after
  this list).
* Verify that the model trains and reloads.
* Loss and performance are slightly better than before.
* Update the create_contrastive_model function to pass a list of outputs
  to better track them. The model still overrides the predict function,
  though, as we need to apply the L2 norm at the output.
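
A hedged sketch of the customization path kept in the commented-out colab
cells, reusing the backbone from the sketch above (the import path assumes
get_projector/get_predictor live in contrastive_model.py as in the diff
below; whether they are re-exported elsewhere is not shown here):

from tensorflow_similarity.models.contrastive_model import get_predictor, get_projector

# Build the MLPs explicitly instead of relying on the defaults; these calls
# mirror what create_contrastive_model now does automatically.
projector = get_projector(input_dim=backbone.output_shape[-1], num_layers=2)
predictor = get_predictor(input_dim=projector.output_shape[-1])

model = create_contrastive_model(
    backbone=backbone,
    projector=projector,
    predictor=predictor,
    algorithm="simsiam",
)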

* Fix mypy error.

* Update output var name and use the epsilon constant.
owenvallis authored Feb 15, 2023
1 parent c4a8675 commit 7d0f528
Showing 4 changed files with 1,900 additions and 307 deletions.
2,042 changes: 1,812 additions & 230 deletions examples/unsupervised_hello_world.ipynb

Large diffs are not rendered by default.

1 change: 0 additions & 1 deletion tensorflow_similarity/indexer.py
@@ -483,7 +483,6 @@ def evaluate_classification(
        # convert_to_tensor is called on a list.
        query_labels = tf.convert_to_tensor(np.array(target_labels))

        # TODO(ovallis): The float type should be derived from the model.
        lookup_distances = unpack_lookup_distances(lookups, dtype=tf.keras.backend.floatx())
        lookup_labels = unpack_lookup_labels(lookups, dtype=query_labels.dtype)
        thresholds: FloatTensor = tf.cast(
134 changes: 79 additions & 55 deletions tensorflow_similarity/models/contrastive_model.py
@@ -35,6 +35,7 @@
from tensorflow_similarity.distances import Distance, distance_canonicalizer
from tensorflow_similarity.evaluators.evaluator import Evaluator
from tensorflow_similarity.indexer import Indexer
from tensorflow_similarity.layers import ActivationStdLoggingLayer
from tensorflow_similarity.losses import MetricLoss
from tensorflow_similarity.matchers import ClassificationMatch
from tensorflow_similarity.retrieval_metrics import RetrievalMetric
@@ -50,26 +51,94 @@
    Tensor,
)

# Value based on implementation from original papers.
BN_EPSILON = 1.001e-5


def get_projector(input_dim, dim=512, activation="relu", num_layers: int = 3):
    inputs = tf.keras.layers.Input((input_dim,), name="projector_input")
    x = inputs

    for i in range(num_layers - 1):
        x = tf.keras.layers.Dense(
            dim,
            use_bias=False,
            kernel_initializer=tf.keras.initializers.LecunUniform(),
            name=f"projector_layer_{i}",
        )(x)
        x = tf.keras.layers.BatchNormalization(epsilon=BN_EPSILON, name=f"batch_normalization_{i}")(x)
        x = tf.keras.layers.Activation(activation, name=f"{activation}_activation_{i}")(x)
    x = tf.keras.layers.Dense(
        dim,
        use_bias=False,
        kernel_initializer=tf.keras.initializers.LecunUniform(),
        name="projector_output",
    )(x)
    x = tf.keras.layers.BatchNormalization(
        epsilon=BN_EPSILON,
        center=False,  # Page:5, Paragraph:2 of SimSiam paper
        scale=False,  # Page:5, Paragraph:2 of SimSiam paper
        name="batch_normalization_ouput",
    )(x)
    # Metric logging layer. Monitors the std of the layer activations.
    # Degenerate solutions collapse to 0 while valid solutions will move
    # towards something like 0.0220. The actual number will depend on the layer size.
    outputs = ActivationStdLoggingLayer(name="proj_std")(x)
    projector = tf.keras.Model(inputs, outputs, name="projector")
    return projector


def get_predictor(input_dim, hidden_dim=512, activation="relu"):
    inputs = tf.keras.layers.Input(shape=(input_dim,), name="predictor_input")
    x = inputs

    x = tf.keras.layers.Dense(
        hidden_dim,
        use_bias=False,
        kernel_initializer=tf.keras.initializers.LecunUniform(),
        name="predictor_layer_0",
    )(x)
    x = tf.keras.layers.BatchNormalization(epsilon=BN_EPSILON, name="batch_normalization_0")(x)
    x = tf.keras.layers.Activation(activation, name=f"{activation}_activation_0")(x)

    x = tf.keras.layers.Dense(
        input_dim,
        kernel_initializer=tf.keras.initializers.LecunUniform(),
        name="predictor_output",
    )(x)
    # Metric logging layer. Monitors the std of the layer activations.
    # Degenerate solutions collapse to 0 while valid solutions will move
    # towards something like 0.0220. The actual number will depend on the layer size.
    outputs = ActivationStdLoggingLayer(name="pred_std")(x)
    predictor = tf.keras.Model(inputs, outputs, name="predictor")
    return predictor


def create_contrastive_model(
    *args,
    backbone: tf.keras.Model,
-    projector: tf.keras.Model,
+    projector: tf.keras.Model | None = None,
    predictor: tf.keras.Model | None = None,
    algorithm: str = "simsiam",
    **kwargs,
) -> ContrastiveModel:
    """Create a contrastive model."""
+    if projector is None:
+        projector = get_projector(input_dim=backbone.output_shape[-1], num_layers=2)

    input_shape = backbone.input_shape[1:]
    inputs = tf.keras.layers.Input(shape=input_shape, name="main_model_input")
+    projector_features = projector(backbone(inputs))
    if algorithm == "simsiam":
        if predictor is None:
-            raise ValueError("The predictor should be specified when using the simsiam algorithm.")
-        outputs = predictor(projector(backbone(inputs)))
-    elif algorithm in ("simclr", "barlow"):
-        outputs = projector(backbone(inputs))
+            predictor = get_predictor(input_dim=projector.output_shape[-1])
+        predictor_features = predictor(projector_features)
    else:
-        raise ValueError(f"Unknown algorithm: {algorithm}")
+        predictor_features = None

+    outputs = [projector_features]
+    if predictor_features is not None:
+        outputs.append(predictor_features)

    return ContrastiveModel(
        *args,
@@ -99,13 +168,11 @@ def __init__(
        self.projector = projector
        self.predictor = predictor

-        self.outputs = [self.backbone.output]
-        self.output_names = ["backbone_output"]
        self.algorithm = algorithm

        self._create_loss_trackers()

-        self.supported_algorithms = ("simsiam", "simclr", "barlow")
+        self.supported_algorithms = ("simsiam", "simclr", "barlow", "vicreg")

        if self.algorithm not in self.supported_algorithms:
            raise ValueError(
@@ -123,7 +190,6 @@ def compile(
        run_eagerly: bool = False,
        steps_per_execution: int = 1,
        distance: Distance | str = "cosine",
-        embedding_output: int | None = None,
        kv_store: Store | str = "memory",
        search: Search | str = "nmslib",
        evaluator: Evaluator | str = "memory",
@@ -160,15 +226,6 @@
                See [Evaluation Metrics](../eval_metrics.md) for a list of available
                metrics.
-                For multi-output models you can specify different metrics for
-                different outputs by passing a dictionary, such as
-                `metrics={'similarity': 'min_neg_gap', 'other': ['accuracy',
-                'mse']}`. You can also pass a list (len = len(outputs)) of lists of
-                metrics such as `metrics=[['min_neg_gap'], ['accuracy', 'mse']]` or
-                `metrics=['min_neg_gap', ['accuracy', 'mse']]`. For outputs which
-                are not related to metrics learning, you can use any of the standard
-                `tf.keras.metrics`.
            loss_weights: Optional list or dictionary specifying scalar
                coefficients (Python floats) to weight the loss contributions of
                different model outputs. The loss value that will be minimized by
@@ -207,12 +264,6 @@
            evaluator: What type of `Evaluator()` to use to evaluate index
                performance. Defaults to in-memory one.
-            embedding_output: Which model output head predicts the embeddings
-                that should be indexed. Defaults to None which is for single output
-                model. For multi-head model, the callee, usually the
-                `SimilarityModel()` class is responsible for passing the correct
-                one.
            stat_buffer_size: Size of the sliding windows buffer used to compute
                index performance. Defaults to 1000.
@@ -228,7 +279,6 @@
            search=search,
            kv_store=kv_store,
            evaluator=evaluator,
-            embedding_output=embedding_output,
            stat_buffer_size=stat_buffer_size,
        )

@@ -358,7 +408,7 @@ def _forward_pass(self, view1, view2, training):
            l2 = self.compiled_loss(tf.stop_gradient(z2), p1)
            loss = l1 + l2
            pred1, pred2 = p1, p2
-        elif self.algorithm in ["simclr", "barlow"]:
+        elif self.algorithm in ("simclr", "barlow", "vicreg"):
            loss = self.compiled_loss(z1, z2)
            pred1, pred2 = z1, z2
@@ -414,7 +464,6 @@ def predict(
            workers,
            use_multiprocessing,
        )
-
        x = self.projector.predict(
            x,
            batch_size,
@@ -436,7 +485,6 @@ def create_index(
        search: Search | str = "nmslib",
        kv_store: Store | str = "memory",
        evaluator: Evaluator | str = "memory",
-        embedding_output: int | None = None,
        stat_buffer_size: int = 1000,
    ) -> None:
        """Create the model index to make embeddings searchable via KNN.
@@ -460,43 +508,19 @@
            evaluator: What type of `Evaluator()` to use to evaluate index
                performance. Defaults to in-memory one.
-            embedding_output: Which model output head predicts the embeddings
-                that should be indexed. Defaults to None which is for single output
-                model. For multi-head model, the callee, usually the
-                `SimilarityModel()` class is responsible for passing the correct
-                one.
            stat_buffer_size: Size of the sliding windows buffer used to compute
                index performance. Defaults to 1000.

        Raises:
            ValueError: Invalid search framework or key value store.
        """
-        # check if we we need to set the embedding head
-        num_outputs = len(self.output_names)
-        if embedding_output is not None and embedding_output > num_outputs:
-            raise ValueError("Embedding_output value exceed number of model outputs")
-
-        if embedding_output is None and num_outputs > 1:
-            print(
-                "Embedding output set to be model output 0. ",
-                "Use the embedding_output arg to override this.",
-            )
-            embedding_output = 0
-
-        # fetch embedding size as some ANN libs requires it for init
-        if num_outputs > 1 and embedding_output is not None:
-            self.embedding_size = self.outputs[embedding_output].shape[1]
-        else:
-            self.embedding_size = self.outputs[0].shape[1]
-
        self._index = Indexer(
-            embedding_size=self.embedding_size,
+            embedding_size=self.projector.output_shape[-1],
            distance=distance,
            search=search,
            kv_store=kv_store,
            evaluator=evaluator,
-            embedding_output=embedding_output,
+            embedding_output=None,
            stat_buffer_size=stat_buffer_size,
        )

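Taken together, a typical post-training flow under these changes looks like
the hedged sketch below (the keyword arguments are the documented defaults
from the diff above; data handling is omitted and `x` is illustrative):

# create_index no longer takes embedding_output; the embedding size is now
# derived internally from projector.output_shape[-1].
model.create_index(distance="cosine", search="nmslib", kv_store="memory")

# predict runs backbone -> projector and applies the L2 norm to the output,
# which is why the model still overrides the predict function.
embeddings = model.predict(x)  # x: a batch of inputs, illustrative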
30 changes: 9 additions & 21 deletions tests/models/test_contrastive_model.py
@@ -13,6 +13,7 @@
# limitations under the License.
import tensorflow as tf

from tensorflow_similarity.layers import ActivationStdLoggingLayer
from tensorflow_similarity.losses import SimSiamLoss
from tensorflow_similarity.models import ContrastiveModel, create_contrastive_model

@@ -32,24 +33,8 @@ def test_save_and_reload(self):
            outputs=backbone_output,
        )

-        projector_input = tf.keras.layers.Input(shape=(4,))
-        projector_output = tf.keras.layers.Dense(4)(projector_input)
-        projector = tf.keras.Model(
-            inputs=projector_input,
-            outputs=projector_output,
-        )
-
-        predictor_input = tf.keras.layers.Input(shape=(4,))
-        predictor_output = tf.keras.layers.Dense(4)(predictor_input)
-        predictor = tf.keras.Model(
-            inputs=predictor_input,
-            outputs=predictor_output,
-        )
-
        model = create_contrastive_model(
            backbone=backbone,
-            projector=projector,
-            predictor=predictor,
            algorithm="simsiam",
        )
        opt = tf.keras.optimizers.RMSprop(learning_rate=0.5)
@@ -72,7 +57,10 @@
        # with tf.distribute.MirroredStrategy().scope():
        loaded_model = tf.keras.models.load_model(
            out_dir,
-            custom_objects={"ContrastiveModel": ContrastiveModel},
+            custom_objects={
+                "ContrastiveModel": ContrastiveModel,
+                "ActivationStdLoggingLayer": ActivationStdLoggingLayer,
+            },
        )

        pred = loaded_model.predict(x)
@@ -81,11 +69,11 @@
        self.assertEqual(loaded_model.optimizer.lr, 0.5)
        self.assertAllEqual(loaded_model.backbone.input_shape, (None, 3))
        self.assertAllEqual(loaded_model.backbone.output_shape, (None, 4))
-        self.assertAllEqual(loaded_model.predictor.input_shape, (None, 4))
-        self.assertAllEqual(loaded_model.predictor.output_shape, (None, 4))
        self.assertAllEqual(loaded_model.projector.input_shape, (None, 4))
-        self.assertAllEqual(loaded_model.projector.output_shape, (None, 4))
-        self.assertAllEqual(pred.shape, (2, 4))
+        self.assertAllEqual(loaded_model.projector.output_shape, (None, 512))
+        self.assertAllEqual(loaded_model.predictor.input_shape, (None, 512))
+        self.assertAllEqual(loaded_model.predictor.output_shape, (None, 512))
+        self.assertAllEqual(pred.shape, (2, 512))
        self.assertAllEqual(model.predict(x), pred)


