# In this file, doctests which differ in the printed Float32 values won't fail
```@meta
DocTestFilters = r"[0-9\.]+f0"
```
"""
mae(ŷ, y; agg = mean)
Return the loss corresponding to mean absolute error:
agg(abs.(ŷ .- y))
# Example
```jldoctest
julia> y_model = [1.1, 1.9, 3.1];
julia> Flux.mae(y_model, 1:3)
0.10000000000000009
```
"""
function mae(ŷ, y; agg = mean)
_check_sizes(ŷ, y)
agg(abs.(ŷ .- y))
end
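# A minimal sketch of how `agg` interacts with the formula above: passing `agg = sum`
# instead of the default `mean` returns the summed absolute error. Values are arbitrary.
let ŷ = [1.1, 1.9, 3.1], y = 1:3
    @assert mae(ŷ, y; agg = sum) ≈ sum(abs.(ŷ .- y))
end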
"""
mse(ŷ, y; agg = mean)
Return the loss corresponding to mean square error:
agg((ŷ .- y) .^ 2)
See also: [`mae`](@ref), [`msle`](@ref), [`crossentropy`](@ref).
# Example
```jldoctest
julia> y_model = [1.1, 1.9, 3.1];
julia> y_true = 1:3;
julia> Flux.mse(y_model, y_true)
0.010000000000000018
```
"""
function mse(ŷ, y; agg = mean)
_check_sizes(ŷ, y)
agg(abs2.(ŷ .- y))
end
"""
msle(ŷ, y; agg = mean, eps = eps(eltype(ŷ)))
Return the loss corresponding to the mean squared logarithmic error, calculated as
agg((log.(ŷ .+ ϵ) .- log.(y .+ ϵ)) .^ 2)
The `ϵ == eps` term provides numerical stability.
Penalizes an under-estimation more than an over-estimation.
# Example
```jldoctest
julia> Flux.msle(Float32[1.1, 2.2, 3.3], 1:3)
0.009084041f0
julia> Flux.msle(Float32[0.9, 1.8, 2.7], 1:3)
0.011100831f0
```
"""
function msle(ŷ, y; agg = mean, eps::Real = epseltype(ŷ), ϵ = nothing)
ϵ = _greek_ascii_depwarn(ϵ => eps, :msle, "ϵ" => "eps")
_check_sizes(ŷ, y)
agg((log.((ŷ .+ ϵ) ./ (y .+ ϵ))) .^2 )
end
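# `_huber_metric` computes the pointwise Huber value for a single absolute error:
# 0.5 * e^2 while e < δ (quadratic branch), and δ * (e - 0.5δ) otherwise (linear branch).
# The boolean `temp` acts as a branch-free selector so the expression broadcasts cleanly.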
function _huber_metric(abs_error, δ)
#TODO: remove ignore_derivatives when Zygote can handle this function with CuArrays
temp = Zygote.ignore_derivatives(abs_error .< δ)
x = ofeltype(abs_error, 0.5)
((abs_error * abs_error) * temp) * x + δ * (abs_error - x * δ) * (1 - temp)
end
"""
huber_loss(ŷ, y; delta = 1, agg = mean)
Return the mean of the [Huber loss](https://en.wikipedia.org/wiki/Huber_loss)
given the prediction `ŷ` and true values `y`.
             | 0.5 * |ŷ - y|^2,            for |ŷ - y| <= δ
Huber loss = |
             | δ * (|ŷ - y| - 0.5 * δ),    otherwise
# Example
```jldoctest
julia> ŷ = [1.1, 2.1, 3.1];
julia> Flux.huber_loss(ŷ, 1:3) # default δ = 1 > |ŷ - y|
0.005000000000000009
julia> Flux.huber_loss(ŷ, 1:3, delta=0.05) # changes behaviour as |ŷ - y| > δ
0.003750000000000005
```
"""
function huber_loss(ŷ, y; agg = mean, delta::Real = 1, δ = nothing)
delta_tmp = _greek_ascii_depwarn(δ => delta, :huber_loss, "δ" => "delta")
δ = ofeltype(ŷ, delta_tmp)
_check_sizes(ŷ, y)
abs_error = abs.(ŷ .- y)
agg(_huber_metric.(abs_error, δ))
end
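# A quick sanity sketch for the two branches, matching the docstring example above:
# every |ŷ - y| here is 0.1, so δ = 1 selects the quadratic branch and δ = 0.05 the linear one.
let ŷ = [1.1, 2.1, 3.1], y = 1:3
    @assert huber_loss(ŷ, y) ≈ 0.5 * 0.1^2                              # 0.005
    @assert huber_loss(ŷ, y; delta = 0.05) ≈ 0.05 * (0.1 - 0.5 * 0.05)  # 0.00375
end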
"""
label_smoothing(y::Union{Number, AbstractArray}, α; dims::Int=1)
Returns smoothed labels, meaning the confidence on label values is relaxed.
When `y` is given as a one-hot vector or a batch of one-hot vectors, it is calculated as
y .* (1 - α) .+ α / size(y, dims)
When `y` is given as a number or a batch of numbers for binary classification,
it is calculated as
y .* (1 - α) .+ α / 2
in which case the labels are squeezed towards `0.5`.
α is a number in the interval (0, 1) called the smoothing factor. The higher the
value of α, the larger the smoothing of `y`.
`dims` denotes the one-hot dimension, unless `dims=0` which denotes the application
of label smoothing to binary distributions encoded in a single number.
# Example
```jldoctest
julia> y = Flux.onehotbatch([1, 1, 1, 0, 1, 0], 0:1)
2×6 OneHotMatrix(::Vector{UInt32}) with eltype Bool:
⋅ ⋅ ⋅ 1 ⋅ 1
1 1 1 ⋅ 1 ⋅
julia> y_smoothed = Flux.label_smoothing(y, 0.2f0)
2×6 Matrix{Float32}:
0.1 0.1 0.1 0.9 0.1 0.9
0.9 0.9 0.9 0.1 0.9 0.1
julia> y_sim = softmax(y .* log(2f0))
2×6 Matrix{Float32}:
0.333333 0.333333 0.333333 0.666667 0.333333 0.666667
0.666667 0.666667 0.666667 0.333333 0.666667 0.333333
julia> y_dis = vcat(y_sim[2,:]', y_sim[1,:]')
2×6 Matrix{Float32}:
0.666667 0.666667 0.666667 0.333333 0.666667 0.333333
0.333333 0.333333 0.333333 0.666667 0.333333 0.666667
julia> Flux.crossentropy(y_sim, y) < Flux.crossentropy(y_sim, y_smoothed)
true
julia> Flux.crossentropy(y_dis, y) > Flux.crossentropy(y_dis, y_smoothed)
true
```
"""
function label_smoothing(y::Union{AbstractArray,Number}, α::Number; dims::Int = 1)
if !(0 < α < 1)
throw(ArgumentError("α must be between 0 and 1"))
end
if dims == 0
y_smoothed = y .* (1 - α) .+ α*1//2
elseif dims == 1
y_smoothed = y .* (1 - α) .+ α* 1 // size(y, 1)
else
throw(ArgumentError("`dims` should be either 0 or 1"))
end
return y_smoothed
end
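# A minimal sketch of the `dims = 0` (binary) branch described in the docstring:
# each scalar label is pulled towards 0.5 by α/2, so a hard 1 becomes 1 - α/2 and a hard 0 becomes α/2.
let α = 0.2
    @assert label_smoothing(1.0, α; dims = 0) ≈ 1 - α / 2
    @assert label_smoothing(0.0, α; dims = 0) ≈ α / 2
end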
"""
crossentropy(ŷ, y; dims = 1, eps = eps(eltype(ŷ)), agg = mean)
Return the cross entropy between the given probability distributions;
calculated as
agg(-sum(y .* log.(ŷ .+ ϵ); dims))
Cross entropy is typically used as a loss in multi-class classification,
in which case the labels `y` are given in a one-hot format.
`dims` specifies the dimension (or the dimensions) containing the class probabilities.
The prediction `ŷ` is supposed to sum to one across `dims`,
as would be the case with the output of a [softmax](@ref Softmax) operation.
For numerical stability, it is recommended to use [`logitcrossentropy`](@ref)
rather than `softmax` followed by `crossentropy`.
Use [`label_smoothing`](@ref) to smooth the true labels as preprocessing before
computing the loss.
See also: [`logitcrossentropy`](@ref), [`binarycrossentropy`](@ref), [`logitbinarycrossentropy`](@ref).
# Example
```jldoctest
julia> y_label = Flux.onehotbatch([0, 1, 2, 1, 0], 0:2)
3×5 OneHotMatrix(::Vector{UInt32}) with eltype Bool:
1 ⋅ ⋅ ⋅ 1
⋅ 1 ⋅ 1 ⋅
⋅ ⋅ 1 ⋅ ⋅
julia> y_model = softmax(reshape(-7:7, 3, 5) .* 1f0)
3×5 Matrix{Float32}:
0.0900306 0.0900306 0.0900306 0.0900306 0.0900306
0.244728 0.244728 0.244728 0.244728 0.244728
0.665241 0.665241 0.665241 0.665241 0.665241
julia> sum(y_model; dims=1)
1×5 Matrix{Float32}:
1.0 1.0 1.0 1.0 1.0
julia> Flux.crossentropy(y_model, y_label)
1.6076053f0
julia> 5 * ans ≈ Flux.crossentropy(y_model, y_label; agg=sum)
true
julia> y_smooth = Flux.label_smoothing(y_label, 0.15f0)
3×5 Matrix{Float32}:
0.9 0.05 0.05 0.05 0.9
0.05 0.9 0.05 0.9 0.05
0.05 0.05 0.9 0.05 0.05
julia> Flux.crossentropy(y_model, y_smooth)
1.5776052f0
```
"""
function crossentropy(ŷ, y; dims = 1, agg = mean, eps::Real = epseltype(ŷ), ϵ = nothing)
ϵ = _greek_ascii_depwarn(ϵ => eps, :crossentropy, "ϵ" => "eps")
_check_sizes(ŷ, y)
agg(.-sum(xlogy.(y, ŷ .+ ϵ); dims = dims))
end
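# A small sketch of the formula above: with `eps = 0` and strictly positive `ŷ`,
# `crossentropy` is the per-column sum of -y .* log.(ŷ), aggregated by `agg`.
let ŷ = [0.1 0.3; 0.9 0.7], y = [0 1; 1 0]
    @assert crossentropy(ŷ, y; eps = 0) ≈ mean(-sum(y .* log.(ŷ); dims = 1))
end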
"""
logitcrossentropy(ŷ, y; dims = 1, agg = mean)
Return the cross entropy calculated by
agg(-sum(y .* logsoftmax(ŷ; dims); dims))
This is mathematically equivalent to `crossentropy(softmax(ŷ), y)`,
but is more numerically stable than using functions [`crossentropy`](@ref)
and [softmax](@ref Softmax) separately.
See also: [`binarycrossentropy`](@ref), [`logitbinarycrossentropy`](@ref), [`label_smoothing`](@ref).
# Example
```jldoctest
julia> y_label = Flux.onehotbatch(collect("abcabaa"), 'a':'c')
3×7 OneHotMatrix(::Vector{UInt32}) with eltype Bool:
1 ⋅ ⋅ 1 ⋅ 1 1
⋅ 1 ⋅ ⋅ 1 ⋅ ⋅
⋅ ⋅ 1 ⋅ ⋅ ⋅ ⋅
julia> y_model = reshape(vcat(-9:0, 0:9, 7.5f0), 3, 7)
3×7 Matrix{Float32}:
-9.0 -6.0 -3.0 0.0 2.0 5.0 8.0
-8.0 -5.0 -2.0 0.0 3.0 6.0 9.0
-7.0 -4.0 -1.0 1.0 4.0 7.0 7.5
julia> Flux.logitcrossentropy(y_model, y_label)
1.5791205f0
julia> Flux.crossentropy(softmax(y_model), y_label)
1.5791197f0
```
"""
function logitcrossentropy(ŷ, y; dims = 1, agg = mean)
_check_sizes(ŷ, y)
agg(.-sum(y .* logsoftmax(ŷ; dims = dims); dims = dims))
end
"""
binarycrossentropy(ŷ, y; agg = mean, eps = eps(eltype(ŷ)))
Return the binary cross-entropy loss, computed as
agg(@.(-y * log(ŷ + ϵ) - (1 - y) * log(1 - ŷ + ϵ)))
Where typically, the prediction `ŷ` is given by the output of a [sigmoid](@ref man-activations) activation.
The `ϵ == eps` term is included to avoid infinity. Using [`logitbinarycrossentropy`](@ref) is recommended
over `binarycrossentropy` for numerical stability.
Use [`label_smoothing`](@ref) to smooth the `y` value as preprocessing before
computing the loss.
See also: [`crossentropy`](@ref), [`logitcrossentropy`](@ref).
# Examples
```jldoctest
julia> y_bin = Bool[1,0,1]
3-element Vector{Bool}:
1
0
1
julia> y_prob = softmax(reshape(vcat(1:3, 3:5), 2, 3) .* 1f0)
2×3 Matrix{Float32}:
0.268941 0.5 0.268941
0.731059 0.5 0.731059
julia> Flux.binarycrossentropy(y_prob[2,:], y_bin)
0.43989f0
julia> all(p -> 0 < p < 1, y_prob[2,:]) # else DomainError
true
julia> y_hot = Flux.onehotbatch(y_bin, 0:1)
2×3 OneHotMatrix(::Vector{UInt32}) with eltype Bool:
⋅ 1 ⋅
1 ⋅ 1
julia> Flux.crossentropy(y_prob, y_hot)
0.43989f0
```
"""
function binarycrossentropy(ŷ, y; agg = mean, eps::Real = epseltype(ŷ), ϵ = nothing)
ϵ = _greek_ascii_depwarn(ϵ => eps, :binarycrossentropy, "ϵ" => "eps")
_check_sizes(ŷ, y)
agg(@.(-xlogy(y, ŷ + ϵ) - xlogy(1 - y, 1 - ŷ + ϵ)))
end
"""
logitbinarycrossentropy(ŷ, y; agg = mean)
Mathematically equivalent to
[`binarycrossentropy(σ(ŷ), y)`](@ref) but is more numerically stable.
See also: [`crossentropy`](@ref), [`logitcrossentropy`](@ref).
# Examples
```jldoctest
julia> y_bin = Bool[1,0,1];
julia> y_model = Float32[2, -1, pi]
3-element Vector{Float32}:
2.0
-1.0
3.1415927
julia> Flux.logitbinarycrossentropy(y_model, y_bin)
0.160832f0
julia> Flux.binarycrossentropy(sigmoid.(y_model), y_bin)
0.16083185f0
```
"""
function logitbinarycrossentropy(ŷ, y; agg = mean)
_check_sizes(ŷ, y)
agg(@.((1 - y) * ŷ - logσ(ŷ)))
end
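# The closed form above relies on the identity
#     -y * logσ(x) - (1 - y) * log(1 - σ(x)) == (1 - y) * x - logσ(x)
# which follows from log(1 - σ(x)) == logσ(x) - x, so the sigmoid is never formed explicitly.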
"""
kldivergence(ŷ, y; agg = mean, eps = eps(eltype(ŷ)))
Return the
[Kullback-Leibler divergence](https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence)
between the given probability distributions.
The KL divergence is a measure of how different one probability distribution is
from another. It is always non-negative, and zero only when the two distributions are equal.
# Example
```jldoctest
julia> p1 = [1 0; 0 1]
2×2 Matrix{Int64}:
1 0
0 1
julia> p2 = fill(0.5, 2, 2)
2×2 Matrix{Float64}:
0.5 0.5
0.5 0.5
julia> Flux.kldivergence(p2, p1) ≈ log(2)
true
julia> Flux.kldivergence(p2, p1; agg = sum) ≈ 2log(2)
true
julia> Flux.kldivergence(p2, p2; eps = 0) # about -2e-16 with the regulator
0.0
julia> Flux.kldivergence(p1, p2; eps = 0) # about 17.3 with the regulator
Inf
```
"""
function kldivergence(ŷ, y; dims = 1, agg = mean, eps::Real = epseltype(ŷ), ϵ = nothing)
ϵ = _greek_ascii_depwarn(ϵ => eps, :kldivergence, "ϵ" => "eps")
_check_sizes(ŷ, y)
entropy = agg(sum(xlogx.(y); dims = dims))
cross_entropy = crossentropy(ŷ, y; dims, agg, eps=ϵ)
return entropy + cross_entropy
end
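# A minimal sketch of the decomposition used above: the KL divergence is cross entropy minus
# the entropy of `y`. With `eps = 0` and strictly positive distributions,
# kldivergence(ŷ, y) ≈ crossentropy(ŷ, y) - crossentropy(y, y).
let y = [0.3 0.6; 0.7 0.4], ŷ = [0.5 0.5; 0.5 0.5]
    @assert kldivergence(ŷ, y; eps = 0) ≈ crossentropy(ŷ, y; eps = 0) - crossentropy(y, y; eps = 0)
end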
"""
poisson_loss(ŷ, y; agg = mean)
Return how much the predicted distribution `ŷ` diverges from the expected Poisson
distribution `y`; calculated as
agg(ŷ .- y .* log.(ŷ))
[More information.](https://peltarion.com/knowledge-center/documentation/modeling-view/build-an-ai-model/loss-functions/poisson).
# Example
```jldoctest
julia> y_model = [1, 3, 3]; # data should only take integral values
julia> Flux.poisson_loss(y_model, 1:3)
0.5023128522198171
```
"""
function poisson_loss(ŷ, y; agg = mean)
_check_sizes(ŷ, y)
agg(ŷ .- xlogy.(y, ŷ))
end
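# A quick sketch of the formula above for strictly positive predictions; `xlogy` only
# differs from `y .* log.(ŷ)` in that it also handles y = 0 gracefully.
let ŷ = [1.0, 3.0, 3.0], y = 1:3
    @assert poisson_loss(ŷ, y) ≈ mean(ŷ .- y .* log.(ŷ))
end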
"""
hinge_loss(ŷ, y; agg = mean)
Return the [hinge_loss](https://en.wikipedia.org/wiki/Hinge_loss) given the
prediction `ŷ` and true labels `y` (containing 1 or -1); calculated as
agg(max.(0, 1 .- ŷ .* y))
Usually used with classifiers like Support Vector Machines.
See also: [`squared_hinge_loss`](@ref)
# Example
```jldoctest
julia> y_true = [1, -1, 1, 1];
julia> y_pred = [0.1, 0.3, 1, 1.5];
julia> Flux.hinge_loss(y_pred, y_true)
0.55
julia> Flux.hinge_loss(y_pred[1], y_true[1]) != 0 # same sign but |ŷ| < 1
true
julia> Flux.hinge_loss(y_pred[end], y_true[end]) == 0 # same sign but |ŷ| >= 1
true
julia> Flux.hinge_loss(y_pred[2], y_true[2]) != 0 # opposite signs
true
```
"""
function hinge_loss(ŷ, y; agg = mean)
_check_sizes(ŷ, y)
agg(max.(0, 1 .- ŷ .* y))
end
"""
squared_hinge_loss(ŷ, y)
Return the squared hinge loss given the prediction `ŷ` and true labels `y`
(containing 1 or -1); calculated as
agg((max.(0, 1 .- ŷ .* y)) .^ 2)
Usually used with classifiers like Support Vector Machines.
See also: [`hinge_loss`](@ref)
# Example
```jldoctest
julia> y_true = [1, -1, 1, 1];
julia> y_pred = [0.1, 0.3, 1, 1.5];
julia> Flux.squared_hinge_loss(y_pred, y_true)
0.625
julia> Flux.squared_hinge_loss(y_pred[1], y_true[1]) != 0
true
julia> Flux.squared_hinge_loss(y_pred[end], y_true[end]) == 0
true
julia> Flux.squared_hinge_loss(y_pred[2], y_true[2]) != 0
true
```
"""
function squared_hinge_loss(ŷ, y; agg = mean)
_check_sizes(ŷ, y)
agg((max.(0, 1 .- ŷ .* y)) .^ 2)
end
"""
dice_coeff_loss(ŷ, y; smooth = 1)
Return a loss based on the dice coefficient.
Used in the [V-Net](https://arxiv.org/abs/1606.04797) image segmentation
architecture.
The dice coefficient is similar to the F1 score. The loss is calculated as:
1 - (2*sum(y .* ŷ) + smooth) / (sum(ŷ .^ 2) + sum(y .^ 2) + smooth)
# Example
```jldoctest
julia> y_pred = [1.1, 2.1, 3.1];
julia> Flux.dice_coeff_loss(y_pred, 1:3)
0.000992391663909964
julia> 1 - Flux.dice_coeff_loss(y_pred, 1:3) # ~ F1 score for image segmentation
0.99900760833609
```
"""
function dice_coeff_loss(ŷ, y; smooth = 1)
s = ofeltype(ŷ, smooth)
_check_sizes(ŷ, y)
# TODO add agg
1 - (2 * sum(y .* ŷ) + s) / (sum(y .^ 2) + sum(ŷ .^ 2) + s)
end
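# A minimal sketch: for a perfect prediction the numerator equals the denominator,
# so the loss is exactly zero regardless of `smooth`.
let y = [0.0, 1.0, 1.0, 0.0]
    @assert dice_coeff_loss(y, y) == 0
end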
"""
tversky_loss(ŷ, y; beta = 0.7)
Return the [Tversky loss](https://arxiv.org/abs/1706.05721).
Used with imbalanced data to give more weight to false negatives.
A larger `β == beta` weighs recall more than precision (by placing more emphasis on false negatives).
Calculated as:
1 - (sum(y .* ŷ) + 1) / (sum(y .* ŷ + (1 - β)*(1 .- y) .* ŷ + β*y .* (1 .- ŷ)) + 1)
"""
function tversky_loss(ŷ, y; beta::Real = 0.7, β = nothing)
beta_temp = _greek_ascii_depwarn(β => beta, :tversky_loss, "β" => "beta")
β = ofeltype(ŷ, beta_temp)
_check_sizes(ŷ, y)
#TODO add agg
num = sum(y .* ŷ) + 1
den = sum(y .* ŷ + β * (1 .- y) .* ŷ + (1 - β) * y .* (1 .- ŷ)) + 1
1 - num / den
end
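# An illustrative sketch with toy values, assuming the usual segmentation-style targets in {0, 1}:
# better-matching predictions give a smaller Tversky loss.
let y = [0, 1, 1, 0], good = [0.05, 0.95, 0.9, 0.1], bad = [0.9, 0.1, 0.2, 0.8]
    @assert tversky_loss(good, y) < tversky_loss(bad, y)
end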
"""
binary_focal_loss(ŷ, y; agg=mean, gamma=2, eps=eps(eltype(ŷ)))
Return the [binary_focal_loss](https://arxiv.org/pdf/1708.02002.pdf)
The input, `ŷ`, is expected to be normalized (i.e. [softmax](@ref Softmax) output).
For `gamma = 0`, the loss is mathematically equivalent to [`Losses.binarycrossentropy`](@ref).
See also: [`Losses.focal_loss`](@ref) for the multi-class setting.
# Example
```jldoctest
julia> y = [0 1 0
1 0 1]
2×3 Matrix{Int64}:
0 1 0
1 0 1
julia> ŷ = [0.268941 0.5 0.268941
0.731059 0.5 0.731059]
2×3 Matrix{Float64}:
0.268941 0.5 0.268941
0.731059 0.5 0.731059
julia> Flux.binary_focal_loss(ŷ, y) ≈ 0.0728675615927385
true
```
"""
function binary_focal_loss(ŷ, y; agg=mean, gamma=2, eps::Real=epseltype(ŷ), ϵ = nothing, γ = nothing)
ϵ = _greek_ascii_depwarn(ϵ => eps, :binary_focal_loss, "ϵ" => "eps")
gamma_temp = _greek_ascii_depwarn(γ => gamma, :binary_focal_loss, "γ" => "gamma")
γ = gamma_temp isa Integer ? gamma_temp : ofeltype(ŷ, gamma_temp)
_check_sizes(ŷ, y)
ŷϵ = ŷ .+ ϵ
p_t = y .* ŷϵ + (1 .- y) .* (1 .- ŷϵ)
ce = .-log.(p_t)
weight = (1 .- p_t) .^ γ
loss = weight .* ce
agg(loss)
end
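# A minimal sketch of the `gamma = 0` case mentioned above: the modulating factor
# (1 - p_t)^γ is identically 1, so the loss reduces to binarycrossentropy
# (using `eps = 0` in both so they see exactly the same probabilities).
let ŷ = [0.2, 0.8, 0.9], y = [0, 1, 1]
    @assert binary_focal_loss(ŷ, y; gamma = 0, eps = 0) ≈ binarycrossentropy(ŷ, y; eps = 0)
end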
"""
focal_loss(ŷ, y; dims=1, agg=mean, gamma=2, eps=eps(eltype(ŷ)))
Return the [focal_loss](https://arxiv.org/pdf/1708.02002.pdf)
which can be used in classification tasks with highly imbalanced classes.
It down-weights well-classified examples and focuses on hard examples.
The input, `ŷ`, is expected to be normalized (i.e. [softmax](@ref Softmax) output).
The modulating factor, `γ == gamma`, controls the down-weighting strength.
For `γ == 0`, the loss is mathematically equivalent to [`Losses.crossentropy`](@ref).
# Example
```jldoctest
julia> y = [1 0 0 0 1
0 1 0 1 0
0 0 1 0 0]
3×5 Matrix{Int64}:
1 0 0 0 1
0 1 0 1 0
0 0 1 0 0
julia> ŷ = softmax(reshape(-7:7, 3, 5) .* 1f0)
3×5 Matrix{Float32}:
0.0900306 0.0900306 0.0900306 0.0900306 0.0900306
0.244728 0.244728 0.244728 0.244728 0.244728
0.665241 0.665241 0.665241 0.665241 0.665241
julia> Flux.focal_loss(ŷ, y) ≈ 1.1277571935622628
true
```
See also: [`Losses.binary_focal_loss`](@ref) for binary (not one-hot) labels
"""
function focal_loss(ŷ, y; dims=1, agg=mean, gamma=2, eps::Real=epseltype(ŷ), ϵ=nothing, γ=nothing)
ϵ = _greek_ascii_depwarn(ϵ => eps, :focal_loss, "ϵ" => "eps")
gamma_temp = _greek_ascii_depwarn(γ => gamma, :focal_loss, "γ" => "gamma")
γ = gamma_temp isa Integer ? gamma_temp : ofeltype(ŷ, gamma_temp)
_check_sizes(ŷ, y)
ŷϵ = ŷ .+ ϵ
agg(sum(@. -y * (1 - ŷϵ)^γ * log(ŷϵ); dims))
end
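# A minimal sketch of the `γ == 0` case mentioned above: the weight (1 - ŷ)^γ is
# identically 1, so the loss reduces to crossentropy (again comparing with `eps = 0`).
let ŷ = [0.1 0.3; 0.9 0.7], y = [0 1; 1 0]
    @assert focal_loss(ŷ, y; gamma = 0, eps = 0) ≈ crossentropy(ŷ, y; eps = 0)
end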
"""
siamese_contrastive_loss(ŷ, y; margin = 1, agg = mean)
Return the [contrastive loss](http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf)
which can be useful for training Siamese Networks. It is given by
agg(@. (1 - y) * ŷ^2 + y * max(0, margin - ŷ)^2)
Specify `margin` to set the baseline for distance at which pairs are dissimilar.
# Example
```jldoctest
julia> ŷ = [0.5, 1.5, 2.5];
julia> Flux.siamese_contrastive_loss(ŷ, 1:3)
-4.833333333333333
julia> Flux.siamese_contrastive_loss(ŷ, 1:3, margin = 2)
-4.0
```
"""
function siamese_contrastive_loss(ŷ, y; agg = mean, margin::Real = 1)
_check_sizes(ŷ, y)
margin < 0 && throw(DomainError(margin, "Margin must be non-negative"))
return agg(@. (1 - y) * ŷ^2 + y * max(0, margin - ŷ)^2)
end
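# An illustrative sketch, assuming the common convention that `ŷ` is a pairwise distance
# and `y` is 0 for similar pairs and 1 for dissimilar ones: similar pairs are penalised
# by ŷ^2, dissimilar pairs only while their distance is below `margin`.
let ŷ = [0.2, 1.5], y = [0, 1]
    @assert siamese_contrastive_loss(ŷ, y) ≈ (0.2^2 + max(0, 1 - 1.5)^2) / 2
end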
```@meta
DocTestFilters = nothing
```