Skip to content

Commit

Permalink
Fix typo "constrastive"
Browse files Browse the repository at this point in the history
  • Loading branch information
jondo authored Mar 24, 2023
1 parent 24d8429 commit 41f4d3d
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions examples/vision/siamese_contrastive.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@

epochs = 10
batch_size = 16
margin = 1 # Margin for constrastive loss.
margin = 1 # Margin for contrastive loss.

"""
## Load the MNIST dataset
Expand Down Expand Up @@ -301,33 +301,33 @@ def euclidean_distance(vects):


"""
## Define the constrastive Loss
## Define the contrastive Loss
"""


def loss(margin=1):
"""Provides 'constrastive_loss' an enclosing scope with variable 'margin'.
"""Provides 'contrastive_loss' an enclosing scope with variable 'margin'.
Arguments:
margin: Integer, defines the baseline for distance for which pairs
should be classified as dissimilar. - (default is 1).
Returns:
'constrastive_loss' function with data ('margin') attached.
'contrastive_loss' function with data ('margin') attached.
"""

# Contrastive loss = mean( (1-true_value) * square(prediction) +
# true_value * square( max(margin-prediction, 0) ))
def contrastive_loss(y_true, y_pred):
"""Calculates the constrastive loss.
"""Calculates the contrastive loss.
Arguments:
y_true: List of labels, each label is of type float32.
y_pred: List of predictions of same length as of y_true,
each label is of type float32.
Returns:
A tensor containing constrastive loss as floating point value.
A tensor containing contrastive loss as floating point value.
"""

square_pred = tf.math.square(y_pred)
Expand Down Expand Up @@ -389,8 +389,8 @@ def plt_metric(history, metric, title, has_valid=True):
# Plot the accuracy
plt_metric(history=history.history, metric="accuracy", title="Model accuracy")

# Plot the constrastive loss
plt_metric(history=history.history, metric="loss", title="Constrastive Loss")
# Plot the contrastive loss
plt_metric(history=history.history, metric="loss", title="Contrastive Loss")

"""
## Evaluate the model
Expand Down

0 comments on commit 41f4d3d

Please sign in to comment.