Skip to content

doctest update metrics #2144

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
261 commits
Select commit Hold shift + click to select a range
8c963b1
Update multilabel_confusion_matrix.py
nataliyah123 Sep 2, 2020
d370fed
Update matthews_correlation_coefficient.py
nataliyah123 Sep 2, 2020
8d3789f
Update README.md (#2148)
seanpmorgan Sep 3, 2020
d466cb8
#2066 doctest update losses (#2138)
Harsh188 Sep 4, 2020
e5d641d
Update hamming.py
nataliyah123 Sep 4, 2020
c68e59e
Update hamming.py
nataliyah123 Sep 4, 2020
4f5f770
Update r_square.py
nataliyah123 Sep 4, 2020
fbb1551
Update matthews_correlation_coefficient.py
nataliyah123 Sep 4, 2020
79e155d
Update multilabel_confusion_matrix.py
nataliyah123 Sep 4, 2020
b9404dd
Update r_square.py
nataliyah123 Sep 4, 2020
1192360
Update matthews_correlation_coefficient.py
nataliyah123 Sep 4, 2020
83be56f
Update multilabel_confusion_matrix.py
nataliyah123 Sep 4, 2020
6d97289
Update r_square.py
nataliyah123 Sep 4, 2020
6441acd
Update r_square.py
nataliyah123 Sep 4, 2020
eae4f29
keras Api example
nataliyah123 Sep 5, 2020
cc0f689
Update cohens_kappa.py
nataliyah123 Sep 5, 2020
c137b16
Update r_square.py
nataliyah123 Sep 5, 2020
27e3147
Update r_square.py
nataliyah123 Sep 5, 2020
aa67440
Update matthews_correlation_coefficient.py
nataliyah123 Sep 5, 2020
91cdf73
uncommented update_state
nataliyah123 Sep 6, 2020
9fb5b53
requested_change
nataliyah123 Sep 6, 2020
04f68e6
Update cohens_kappa.py
nataliyah123 Sep 6, 2020
87d4711
Update r_square.py
nataliyah123 Sep 6, 2020
134f9ef
Update multilabel_confusion_matrix.py
nataliyah123 Sep 6, 2020
3424dc1
Update multilabel_confusion_matrix.py
nataliyah123 Sep 6, 2020
0d568af
Update matthews_correlation_coefficient.py
nataliyah123 Sep 6, 2020
db472cb
Update hamming.py
nataliyah123 Sep 6, 2020
61ccc42
Update geometric_mean.py
nataliyah123 Sep 6, 2020
20a74ff
Update cohens_kappa.py
nataliyah123 Sep 6, 2020
1240bd3
Update hamming.py
nataliyah123 Sep 6, 2020
facd2c9
Update multilabel_confusion_matrix.py
nataliyah123 Sep 6, 2020
82ac7ce
Update hamming.py
nataliyah123 Sep 6, 2020
d049d58
Add doctest section (#2151)
WindQAQ Sep 8, 2020
fc2c796
nbfmt tutorial notebooks (#2155)
lamberta Sep 9, 2020
fc50c34
Support fill_mode for transform (#2153)
WindQAQ Sep 10, 2020
3d37870
Speedup gaussian kernel generation (#2149)
WindQAQ Sep 10, 2020
41133a3
Correct a typo in average_optimizers_callback.ipynb (#2159)
anuragarnab Sep 10, 2020
5bdaf40
Avoid unnecessary reshapes for instance norm (#2158)
kaixih Sep 11, 2020
7f6bcf7
beam search decoding procedure added to seq2seq_nmt tutorial (#2140)
abhishek-niranjan Sep 11, 2020
e3e0853
Setup notebook testing (#2160)
WindQAQ Sep 11, 2020
4eeb2b2
Update hamming.py
nataliyah123 Sep 14, 2020
62fb9d5
y_true, y_pred are not accepted in hamming.py
nataliyah123 Sep 14, 2020
62e1eb8
Update hamming.py
nataliyah123 Sep 14, 2020
eb2ffb8
Update hamming.py
nataliyah123 Sep 14, 2020
58eae9f
assigned metric.update_Status to result
nataliyah123 Sep 14, 2020
a57a7f3
Update matthews_correlation_coefficient.py
nataliyah123 Sep 14, 2020
388e9e1
assigned m.update_status (yactual,ypred) to result
nataliyah123 Sep 14, 2020
8dbeaaa
Update r_square.py
nataliyah123 Sep 14, 2020
e818d2e
Update r_square.py
nataliyah123 Sep 14, 2020
07365ac
Update multilabel_confusion_matrix.py
nataliyah123 Sep 14, 2020
557cc3b
Update matthews_correlation_coefficient.py
nataliyah123 Sep 14, 2020
13466d7
Update geometric_mean.py
nataliyah123 Sep 14, 2020
1118061
Update cohens_kappa.py
nataliyah123 Sep 14, 2020
3897e2f
Discriminative Layer Training (#969)
hyang0129 Sep 14, 2020
816aa23
Update geometric_mean.py
nataliyah123 Sep 14, 2020
ba10725
Update hamming.py
nataliyah123 Sep 14, 2020
ea66760
Update cohens_kappa.py
nataliyah123 Sep 14, 2020
d99502f
Update matthews_correlation_coefficient.py
nataliyah123 Sep 14, 2020
27a975e
Update multilabel_confusion_matrix.py
nataliyah123 Sep 14, 2020
adc06a7
fixed typo in multioptimizer class name and added code owners (#2164)
hyang0129 Sep 14, 2020
4fc2ced
Update multilabel_confusion_matrix.py
nataliyah123 Sep 14, 2020
f6575a7
Update matthews_correlation_coefficient.py
nataliyah123 Sep 14, 2020
36cd439
Update hamming.py
nataliyah123 Sep 14, 2020
51f2b5f
Update hamming.py
nataliyah123 Sep 14, 2020
37f40a7
Update cohens_kappa.py
nataliyah123 Sep 14, 2020
84a3ce6
Update hamming.py
nataliyah123 Sep 14, 2020
84da148
Update cohens_kappa.py
nataliyah123 Sep 14, 2020
1b79cc4
Update hamming.py
nataliyah123 Sep 14, 2020
bcc75fa
Update multilabel_confusion_matrix.py
nataliyah123 Sep 14, 2020
e642583
Update cohens_kappa.py
nataliyah123 Sep 14, 2020
7b8c9bc
Update multilabel_confusion_matrix.py
nataliyah123 Sep 14, 2020
c06c27b
Update cohens_kappa.py
nataliyah123 Sep 14, 2020
dbd3261
Update multilabel_confusion_matrix.py
nataliyah123 Sep 14, 2020
2ddb74d
Update multilabel_confusion_matrix.py
nataliyah123 Sep 15, 2020
abe8561
unified diff
nataliyah123 Sep 15, 2020
cab0837
Update multilabel_confusion_matrix.py
nataliyah123 Sep 15, 2020
8320f16
Rnn testable doctests (#2147)
Harsh188 Sep 15, 2020
1c3c072
Added support for noisy dense layers. (#2099)
LeonShams Sep 15, 2020
dbcd5aa
Added stochastic depth layer (#2154)
MHStadler Sep 16, 2020
5fd46e9
fixed example
nataliyah123 Sep 17, 2020
e2006c3
Added filtered_input and constrained_decoding (#2166)
napsternxg Sep 17, 2020
13e40e6
Moved build_docs.py and BUILD into /tools/docs/ (#2167)
hp77-creator Sep 18, 2020
3aa7c61
added support for blankline
nataliyah123 Sep 20, 2020
1661c61
'space'
nataliyah123 Sep 20, 2020
ba99b55
update metrics cohen kappa
nataliyah123 Sep 2, 2020
6cc7a54
Update cohens_kappa.py
nataliyah123 Sep 2, 2020
1b1b602
Update cohens_kappa.py
nataliyah123 Sep 2, 2020
e549e9d
Update cohens_kappa.py
nataliyah123 Sep 2, 2020
82e26b1
Update hamming.py
nataliyah123 Sep 2, 2020
6d58948
Update cohens_kappa.py
nataliyah123 Sep 2, 2020
5689900
Update matthews_correlation_coefficient.py
nataliyah123 Sep 2, 2020
6f7cc89
Update multilabel_confusion_matrix.py
nataliyah123 Sep 2, 2020
b4f2c4c
Update r_square.py
nataliyah123 Sep 2, 2020
c5328d0
Update r_square.py
nataliyah123 Sep 2, 2020
b8ebbf8
Update cohens_kappa.py
nataliyah123 Sep 2, 2020
3c2d1a4
Update multilabel_confusion_matrix.py
nataliyah123 Sep 2, 2020
05ce484
Update matthews_correlation_coefficient.py
nataliyah123 Sep 2, 2020
9ce7078
Update hamming.py
nataliyah123 Sep 4, 2020
78fdfec
Update hamming.py
nataliyah123 Sep 4, 2020
4e2ca52
Update r_square.py
nataliyah123 Sep 4, 2020
738886a
Update matthews_correlation_coefficient.py
nataliyah123 Sep 4, 2020
37e8220
Update multilabel_confusion_matrix.py
nataliyah123 Sep 4, 2020
2773675
Update r_square.py
nataliyah123 Sep 4, 2020
30f01a7
Update matthews_correlation_coefficient.py
nataliyah123 Sep 4, 2020
af599ef
Update multilabel_confusion_matrix.py
nataliyah123 Sep 4, 2020
199ee0f
Update r_square.py
nataliyah123 Sep 4, 2020
120222f
Update r_square.py
nataliyah123 Sep 4, 2020
44fe8f6
keras Api example
nataliyah123 Sep 5, 2020
1230dcf
Update cohens_kappa.py
nataliyah123 Sep 5, 2020
3984ece
Update r_square.py
nataliyah123 Sep 5, 2020
a0718ce
Update r_square.py
nataliyah123 Sep 5, 2020
20d26fe
Update matthews_correlation_coefficient.py
nataliyah123 Sep 5, 2020
50a4ef1
uncommented update_state
nataliyah123 Sep 6, 2020
ae251ac
requested_change
nataliyah123 Sep 6, 2020
e8df84f
Update cohens_kappa.py
nataliyah123 Sep 6, 2020
d1acae0
Update r_square.py
nataliyah123 Sep 6, 2020
73263d6
Update multilabel_confusion_matrix.py
nataliyah123 Sep 6, 2020
e607e61
Update multilabel_confusion_matrix.py
nataliyah123 Sep 6, 2020
5b6f6e2
Update matthews_correlation_coefficient.py
nataliyah123 Sep 6, 2020
97bd120
Update hamming.py
nataliyah123 Sep 6, 2020
687c0b7
Update geometric_mean.py
nataliyah123 Sep 6, 2020
e1ec9c7
Update cohens_kappa.py
nataliyah123 Sep 6, 2020
0250834
Update hamming.py
nataliyah123 Sep 6, 2020
cd166b4
Update multilabel_confusion_matrix.py
nataliyah123 Sep 6, 2020
27d58b0
Update hamming.py
nataliyah123 Sep 6, 2020
3c266b0
Update hamming.py
nataliyah123 Sep 14, 2020
edc9d2f
y_true, y_pred are not accepted in hamming.py
nataliyah123 Sep 14, 2020
dc27baf
Update hamming.py
nataliyah123 Sep 14, 2020
fea2499
Update hamming.py
nataliyah123 Sep 14, 2020
e4cf1aa
assigned metric.update_Status to result
nataliyah123 Sep 14, 2020
2589f43
Update matthews_correlation_coefficient.py
nataliyah123 Sep 14, 2020
3162aa9
assigned m.update_status (yactual,ypred) to result
nataliyah123 Sep 14, 2020
36410b8
Update r_square.py
nataliyah123 Sep 14, 2020
0097957
Update r_square.py
nataliyah123 Sep 14, 2020
430894e
Update multilabel_confusion_matrix.py
nataliyah123 Sep 14, 2020
64d28c8
Update matthews_correlation_coefficient.py
nataliyah123 Sep 14, 2020
ca7ddc8
Update geometric_mean.py
nataliyah123 Sep 14, 2020
9a86082
Update cohens_kappa.py
nataliyah123 Sep 14, 2020
53017f1
Update geometric_mean.py
nataliyah123 Sep 14, 2020
e5629b1
Update hamming.py
nataliyah123 Sep 14, 2020
adeda64
Update cohens_kappa.py
nataliyah123 Sep 14, 2020
753e36e
Update matthews_correlation_coefficient.py
nataliyah123 Sep 14, 2020
4a936e2
Update multilabel_confusion_matrix.py
nataliyah123 Sep 14, 2020
63d7807
Update multilabel_confusion_matrix.py
nataliyah123 Sep 14, 2020
86213f8
Update matthews_correlation_coefficient.py
nataliyah123 Sep 14, 2020
d9ff82e
Update hamming.py
nataliyah123 Sep 14, 2020
ed0dc7e
Update hamming.py
nataliyah123 Sep 14, 2020
bfa3613
Update cohens_kappa.py
nataliyah123 Sep 14, 2020
dd00321
Update hamming.py
nataliyah123 Sep 14, 2020
589fbed
Update cohens_kappa.py
nataliyah123 Sep 14, 2020
5b7a915
Update hamming.py
nataliyah123 Sep 14, 2020
2342ae5
Update multilabel_confusion_matrix.py
nataliyah123 Sep 14, 2020
37c72ba
Update cohens_kappa.py
nataliyah123 Sep 14, 2020
0a9a261
Update multilabel_confusion_matrix.py
nataliyah123 Sep 14, 2020
6a1fb84
Update cohens_kappa.py
nataliyah123 Sep 14, 2020
d3ed8ed
Update multilabel_confusion_matrix.py
nataliyah123 Sep 14, 2020
c7d881b
Update multilabel_confusion_matrix.py
nataliyah123 Sep 15, 2020
37bd5ed
unified diff
nataliyah123 Sep 15, 2020
dfc6dc0
Update multilabel_confusion_matrix.py
nataliyah123 Sep 15, 2020
19c0a6c
fixed example
nataliyah123 Sep 17, 2020
6c0e971
added support for blankline
nataliyah123 Sep 20, 2020
ce34419
'space'
nataliyah123 Sep 20, 2020
f241471
Merge branch '#2066-doc-update-metrics' of https://github.com/nataliy…
nataliyah123 Sep 20, 2020
9aa2598
update metrics cohen kappa
nataliyah123 Sep 2, 2020
46c44c0
Update cohens_kappa.py
nataliyah123 Sep 2, 2020
867fa7f
Update cohens_kappa.py
nataliyah123 Sep 2, 2020
7c3557e
Update cohens_kappa.py
nataliyah123 Sep 2, 2020
22516db
Update hamming.py
nataliyah123 Sep 2, 2020
d1cabfb
Update cohens_kappa.py
nataliyah123 Sep 2, 2020
e51d182
Update matthews_correlation_coefficient.py
nataliyah123 Sep 2, 2020
7daf2f2
Update multilabel_confusion_matrix.py
nataliyah123 Sep 2, 2020
591fae3
Update r_square.py
nataliyah123 Sep 2, 2020
da70765
Update r_square.py
nataliyah123 Sep 2, 2020
127bd4e
Update cohens_kappa.py
nataliyah123 Sep 2, 2020
2984c08
Update multilabel_confusion_matrix.py
nataliyah123 Sep 2, 2020
558717b
Update matthews_correlation_coefficient.py
nataliyah123 Sep 2, 2020
6888bd2
Update hamming.py
nataliyah123 Sep 4, 2020
f36b72e
Update hamming.py
nataliyah123 Sep 4, 2020
c62394a
Update r_square.py
nataliyah123 Sep 4, 2020
cd3b849
Update matthews_correlation_coefficient.py
nataliyah123 Sep 4, 2020
620ddc2
Update multilabel_confusion_matrix.py
nataliyah123 Sep 4, 2020
8ad555a
Update r_square.py
nataliyah123 Sep 4, 2020
d7b0c6b
Update matthews_correlation_coefficient.py
nataliyah123 Sep 4, 2020
09f6912
Update multilabel_confusion_matrix.py
nataliyah123 Sep 4, 2020
f3419c2
Update r_square.py
nataliyah123 Sep 4, 2020
8cfed76
Update r_square.py
nataliyah123 Sep 4, 2020
5a832eb
keras Api example
nataliyah123 Sep 5, 2020
552503a
Update cohens_kappa.py
nataliyah123 Sep 5, 2020
1e4bf35
Update r_square.py
nataliyah123 Sep 5, 2020
21a6a23
Update r_square.py
nataliyah123 Sep 5, 2020
0b8e732
Update matthews_correlation_coefficient.py
nataliyah123 Sep 5, 2020
080ae65
uncommented update_state
nataliyah123 Sep 6, 2020
04ed329
requested_change
nataliyah123 Sep 6, 2020
5b5da62
Update cohens_kappa.py
nataliyah123 Sep 6, 2020
480c103
Update r_square.py
nataliyah123 Sep 6, 2020
12cd059
Update multilabel_confusion_matrix.py
nataliyah123 Sep 6, 2020
f3f920e
Update multilabel_confusion_matrix.py
nataliyah123 Sep 6, 2020
cf1c24b
Update matthews_correlation_coefficient.py
nataliyah123 Sep 6, 2020
3349e87
Update hamming.py
nataliyah123 Sep 6, 2020
c4a0d61
Update geometric_mean.py
nataliyah123 Sep 6, 2020
ad3e744
Update cohens_kappa.py
nataliyah123 Sep 6, 2020
1e784d2
Update hamming.py
nataliyah123 Sep 6, 2020
4684cbe
Update multilabel_confusion_matrix.py
nataliyah123 Sep 6, 2020
9aa6a97
Update hamming.py
nataliyah123 Sep 6, 2020
b9a2057
Update hamming.py
nataliyah123 Sep 14, 2020
6b767ea
y_true, y_pred are not accepted in hamming.py
nataliyah123 Sep 14, 2020
1d74c2e
Update hamming.py
nataliyah123 Sep 14, 2020
45145cf
Update hamming.py
nataliyah123 Sep 14, 2020
09c6fb8
assigned metric.update_Status to result
nataliyah123 Sep 14, 2020
e49474e
Update matthews_correlation_coefficient.py
nataliyah123 Sep 14, 2020
4cd4ae6
assigned m.update_status (yactual,ypred) to result
nataliyah123 Sep 14, 2020
f29455b
Update r_square.py
nataliyah123 Sep 14, 2020
095e814
Update r_square.py
nataliyah123 Sep 14, 2020
8cef4bb
Update multilabel_confusion_matrix.py
nataliyah123 Sep 14, 2020
ee80f36
Update matthews_correlation_coefficient.py
nataliyah123 Sep 14, 2020
4f37fda
Update geometric_mean.py
nataliyah123 Sep 14, 2020
31c8742
Update cohens_kappa.py
nataliyah123 Sep 14, 2020
83660c8
Update geometric_mean.py
nataliyah123 Sep 14, 2020
4e83541
Update hamming.py
nataliyah123 Sep 14, 2020
1b961ed
Update cohens_kappa.py
nataliyah123 Sep 14, 2020
6957efb
Update matthews_correlation_coefficient.py
nataliyah123 Sep 14, 2020
058cea4
Update multilabel_confusion_matrix.py
nataliyah123 Sep 14, 2020
a3cf087
Update multilabel_confusion_matrix.py
nataliyah123 Sep 14, 2020
8d57deb
Update matthews_correlation_coefficient.py
nataliyah123 Sep 14, 2020
48d4516
Update hamming.py
nataliyah123 Sep 14, 2020
c774e0f
Update hamming.py
nataliyah123 Sep 14, 2020
5540c9a
Update cohens_kappa.py
nataliyah123 Sep 14, 2020
987b2b6
Update hamming.py
nataliyah123 Sep 14, 2020
a4b1f45
Update cohens_kappa.py
nataliyah123 Sep 14, 2020
4135dcb
Update hamming.py
nataliyah123 Sep 14, 2020
c22457a
Update multilabel_confusion_matrix.py
nataliyah123 Sep 14, 2020
7bca624
Update cohens_kappa.py
nataliyah123 Sep 14, 2020
8dab2de
Update multilabel_confusion_matrix.py
nataliyah123 Sep 14, 2020
193d93f
Update cohens_kappa.py
nataliyah123 Sep 14, 2020
3979510
Update multilabel_confusion_matrix.py
nataliyah123 Sep 14, 2020
4e1ec0d
Update multilabel_confusion_matrix.py
nataliyah123 Sep 15, 2020
857a2b9
unified diff
nataliyah123 Sep 15, 2020
109f211
Update multilabel_confusion_matrix.py
nataliyah123 Sep 15, 2020
7ebb741
fixed example
nataliyah123 Sep 17, 2020
3f1a30a
added support for blankline
nataliyah123 Sep 20, 2020
568f75c
'space'
nataliyah123 Sep 20, 2020
eaf045a
update metrics cohen kappa
nataliyah123 Sep 2, 2020
6b567f3
Update hamming.py
nataliyah123 Sep 2, 2020
ac6b80b
Update multilabel_confusion_matrix.py
nataliyah123 Sep 2, 2020
f0b20ef
Update geometric_mean.py
nataliyah123 Sep 14, 2020
a43194b
Update geometric_mean.py
nataliyah123 Sep 14, 2020
796d4cb
added support for blankline
nataliyah123 Sep 20, 2020
4a0edb4
'space'
nataliyah123 Sep 20, 2020
e7c574c
r_square file
nataliyah123 Sep 21, 2020
7d85e4f
all
nataliyah123 Sep 21, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Added stochastic depth layer (#2154)
* Added stochastic depth layer

* Fixed code style and added missing __init__ entry

* Fixed tests and style

* Fixed code style

* Updated CODEOWNERS

* Added codeowners for tests

* Changes after code review

* Test and formatting fixes

* Fixed doc string

* Added mixed precision test

* Further code review changes

* Code review changes
  • Loading branch information
MHStadler authored Sep 16, 2020
commit dbcd5aaef6da32bcd24040c55a49cef1538ef765
2 changes: 2 additions & 0 deletions .github/CODEOWNERS
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@
/tensorflow_addons/layers/tests/esn_test.py @pedrolarben
/tensorflow_addons/layers/snake.py @failure-to-thrive
/tensorflow_addons/layers/tests/snake_test.py @failure-to-thrive
/tensorflow_addons/layers/stochastic_depth.py @mhstadler @windqaq
/tensorflow_addons/layers/tests/stochastic_depth_test.py @mhstadler @windqaq
/tensorflow_addons/layers/noisy_dense.py @leonshams
/tensorflow_addons/layers/tests/noisy_dense_test.py @leonshams

Expand Down
1 change: 1 addition & 0 deletions tensorflow_addons/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,5 @@
from tensorflow_addons.layers.tlu import TLU
from tensorflow_addons.layers.wrappers import WeightNormalization
from tensorflow_addons.layers.esn import ESN
from tensorflow_addons.layers.stochastic_depth import StochasticDepth
from tensorflow_addons.layers.noisy_dense import NoisyDense
88 changes: 88 additions & 0 deletions tensorflow_addons/layers/stochastic_depth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import tensorflow as tf
from typeguard import typechecked


@tf.keras.utils.register_keras_serializable(package="Addons")
class StochasticDepth(tf.keras.layers.Layer):
"""Stochastic Depth layer.

Implements Stochastic Depth as described in
[Deep Networks with Stochastic Depth](https://arxiv.org/abs/1603.09382), to randomly drop residual branches
in residual architectures.

Usage:
Residual architectures with fixed depth, use residual branches that are merged back into the main network
by adding the residual branch back to the input:

>>> input = np.ones((1, 3, 3, 1), dtype = np.float32)
>>> residual = tf.keras.layers.Conv2D(1, 1)(input)
>>> output = tf.keras.layers.Add()([input, residual])
>>> output.shape
TensorShape([1, 3, 3, 1])

StochasticDepth acts as a drop-in replacement for the addition:

>>> input = np.ones((1, 3, 3, 1), dtype = np.float32)
>>> residual = tf.keras.layers.Conv2D(1, 1)(input)
>>> output = tfa.layers.StochasticDepth()([input, residual])
>>> output.shape
TensorShape([1, 3, 3, 1])

At train time, StochasticDepth returns:

$$
x[0] + b_l * x[1],
$$

where $b_l$ is a random Bernoulli variable with probability $P(b_l = 1) = p_l$

At test time, StochasticDepth rescales the activations of the residual branch based on the survival probability ($p_l$):

$$
x[0] + p_l * x[1]
$$

Arguments:
survival_probability: float, the probability of the residual branch being kept.

Call Arguments:
inputs: List of `[shortcut, residual]` where `shortcut`, and `residual` are tensors of equal shape.

Output shape:
Equal to the shape of inputs `shortcut`, and `residual`
"""

@typechecked
def __init__(self, survival_probability: float = 0.5, **kwargs):
super().__init__(**kwargs)

self.survival_probability = survival_probability

def call(self, x, training=None):
if not isinstance(x, list) or len(x) != 2:
raise ValueError("input must be a list of length 2.")

shortcut, residual = x

# Random bernoulli variable indicating whether the branch should be kept or not or not
b_l = tf.keras.backend.random_bernoulli([], p=self.survival_probability)

def _call_train():
return shortcut + b_l * residual

def _call_test():
return shortcut + self.survival_probability * residual

return tf.keras.backend.in_train_phase(
_call_train, _call_test, training=training
)

def compute_output_shape(self, input_shape):
return input_shape[0]

def get_config(self):
base_config = super().get_config()

config = {"survival_probability": self.survival_probability}

return {**base_config, **config}
58 changes: 58 additions & 0 deletions tensorflow_addons/layers/tests/stochastic_depth_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import pytest
import numpy as np
import tensorflow as tf

from tensorflow_addons.layers.stochastic_depth import StochasticDepth
from tensorflow_addons.utils import test_utils

_KEEP_SEED = 1111
_DROP_SEED = 2222


@pytest.mark.parametrize("seed", [_KEEP_SEED, _DROP_SEED])
@pytest.mark.parametrize("training", [True, False])
def stochastic_depth_test(seed, training):
np.random.seed(seed)
tf.random.set_seed(seed)

survival_probability = 0.5

shortcut = np.asarray([[0.2, 0.1, 0.4]]).astype(np.float32)
residual = np.asarray([[0.2, 0.4, 0.5]]).astype(np.float32)

if training:
if seed == _KEEP_SEED:
# shortcut + residual
expected_output = np.asarray([[0.4, 0.5, 0.9]]).astype(np.float32)
elif seed == _DROP_SEED:
# shortcut
expected_output = np.asarray([[0.2, 0.1, 0.4]]).astype(np.float32)
else:
# shortcut + p_l * residual
expected_output = np.asarray([[0.3, 0.3, 0.65]]).astype(np.float32)

test_utils.layer_test(
StochasticDepth,
kwargs={"survival_probability": survival_probability},
input_data=[shortcut, residual],
expected_output=expected_output,
)


@pytest.mark.usefixtures("run_with_mixed_precision_policy")
def test_with_mixed_precision_policy():
policy = tf.keras.mixed_precision.experimental.global_policy()

shortcut = np.asarray([[0.2, 0.1, 0.4]])
residual = np.asarray([[0.2, 0.4, 0.5]])

output = StochasticDepth()([shortcut, residual])

assert output.dtype == policy.compute_dtype


def test_serialization():
stoch_depth = StochasticDepth(survival_probability=0.5)
serialized_stoch_depth = tf.keras.layers.serialize(stoch_depth)
new_layer = tf.keras.layers.deserialize(serialized_stoch_depth)
assert stoch_depth.get_config() == new_layer.get_config()