|
37 | 37 |
|
38 | 38 | epochs = 10
|
39 | 39 | batch_size = 16
|
40 |
| -margin = 1 # Margin for constrastive loss. |
| 40 | +margin = 1 # Margin for contrastive loss. |
41 | 41 |
|
42 | 42 | """
|
43 | 43 | ## Load the MNIST dataset
|
@@ -301,33 +301,33 @@ def euclidean_distance(vects):
|
301 | 301 |
|
302 | 302 |
|
303 | 303 | """
|
304 |
| -## Define the constrastive Loss |
| 304 | +## Define the contrastive Loss |
305 | 305 | """
|
306 | 306 |
|
307 | 307 |
|
308 | 308 | def loss(margin=1):
|
309 |
| - """Provides 'constrastive_loss' an enclosing scope with variable 'margin'. |
| 309 | + """Provides 'contrastive_loss' an enclosing scope with variable 'margin'. |
310 | 310 |
|
311 | 311 | Arguments:
|
312 | 312 | margin: Integer, defines the baseline for distance for which pairs
|
313 | 313 | should be classified as dissimilar. - (default is 1).
|
314 | 314 |
|
315 | 315 | Returns:
|
316 |
| - 'constrastive_loss' function with data ('margin') attached. |
| 316 | + 'contrastive_loss' function with data ('margin') attached. |
317 | 317 | """
|
318 | 318 |
|
319 | 319 | # Contrastive loss = mean( (1-true_value) * square(prediction) +
|
320 | 320 | # true_value * square( max(margin-prediction, 0) ))
|
321 | 321 | def contrastive_loss(y_true, y_pred):
|
322 |
| - """Calculates the constrastive loss. |
| 322 | + """Calculates the contrastive loss. |
323 | 323 |
|
324 | 324 | Arguments:
|
325 | 325 | y_true: List of labels, each label is of type float32.
|
326 | 326 | y_pred: List of predictions of same length as of y_true,
|
327 | 327 | each label is of type float32.
|
328 | 328 |
|
329 | 329 | Returns:
|
330 |
| - A tensor containing constrastive loss as floating point value. |
| 330 | + A tensor containing contrastive loss as floating point value. |
331 | 331 | """
|
332 | 332 |
|
333 | 333 | square_pred = tf.math.square(y_pred)
|
@@ -389,8 +389,8 @@ def plt_metric(history, metric, title, has_valid=True):
|
389 | 389 | # Plot the accuracy
|
390 | 390 | plt_metric(history=history.history, metric="accuracy", title="Model accuracy")
|
391 | 391 |
|
392 |
| -# Plot the constrastive loss |
393 |
| -plt_metric(history=history.history, metric="loss", title="Constrastive Loss") |
| 392 | +# Plot the contrastive loss |
| 393 | +plt_metric(history=history.history, metric="loss", title="Contrastive Loss") |
394 | 394 |
|
395 | 395 | """
|
396 | 396 | ## Evaluate the model
|
|
0 commit comments