23
23
24
24
import tensorflow as tf
25
25
from tensorflow_privacy .privacy .fast_gradient_clipping import gradient_clipping_utils
26
+ from tensorflow_privacy .privacy .fast_gradient_clipping import layer_registry as lr
26
27
27
28
28
- def get_registry_generator_fn (tape , layer_registry ):
29
+ def get_registry_generator_fn (
30
+ tape : tf .GradientTape , layer_registry : lr .LayerRegistry
31
+ ):
29
32
"""Creates the generator function for `compute_gradient_norms()`."""
30
33
if layer_registry is None :
31
34
# Needed for backwards compatibility.
@@ -53,7 +56,12 @@ def registry_generator_fn(layer_instance, args, kwargs):
53
56
return registry_generator_fn
54
57
55
58
56
- def compute_gradient_norms (input_model , x_batch , y_batch , layer_registry ):
59
+ def compute_gradient_norms (
60
+ input_model : tf .keras .Model ,
61
+ x_batch : tf .Tensor ,
62
+ y_batch : tf .Tensor ,
63
+ layer_registry : lr .LayerRegistry ,
64
+ ):
57
65
"""Computes the per-example loss gradient norms for given data.
58
66
59
67
Applies a variant of the approach given in
@@ -106,7 +114,7 @@ def compute_gradient_norms(input_model, x_batch, y_batch, layer_registry):
106
114
return tf .sqrt (tf .reduce_sum (sqr_norm_tsr , axis = 1 ))
107
115
108
116
109
- def compute_clip_weights (l2_norm_clip , gradient_norms ):
117
+ def compute_clip_weights (l2_norm_clip : float , gradient_norms : tf . Tensor ):
110
118
"""Computes the per-example loss/clip weights for clipping.
111
119
112
120
When the sum of the per-example losses is replaced a weighted sum, where
@@ -132,7 +140,11 @@ def compute_clip_weights(l2_norm_clip, gradient_norms):
132
140
133
141
134
142
def compute_pred_and_clipped_gradients (
135
- input_model , x_batch , y_batch , l2_norm_clip , layer_registry
143
+ input_model : tf .keras .Model ,
144
+ x_batch : tf .Tensor ,
145
+ y_batch : tf .Tensor ,
146
+ l2_norm_clip : float ,
147
+ layer_registry : lr .LayerRegistry ,
136
148
):
137
149
"""Computes the per-example predictions and per-example clipped loss gradient.
138
150
0 commit comments