Skip to content

Commit 5c8b762

Browse files
Add registry function and tests for tf.keras.layers.EinsumDense.
PiperOrigin-RevId: 511864469
1 parent d7cd3f8 commit 5c8b762

File tree

6 files changed

+637
-87
lines changed

6 files changed

+637
-87
lines changed

tensorflow_privacy/privacy/fast_gradient_clipping/BUILD

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,20 @@ py_library(
66
name = "gradient_clipping_utils",
77
srcs = ["gradient_clipping_utils.py"],
88
srcs_version = "PY3",
9+
deps = [":layer_registry"],
10+
)
11+
12+
py_library(
13+
name = "einsum_utils",
14+
srcs = ["einsum_utils.py"],
15+
srcs_version = "PY3",
916
)
1017

1118
py_library(
1219
name = "layer_registry",
1320
srcs = ["layer_registry.py"],
1421
srcs_version = "PY3",
22+
deps = [":einsum_utils"],
1523
)
1624

1725
py_library(

tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,12 @@
2323

2424
import tensorflow as tf
2525
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
26+
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr
2627

2728

28-
def get_registry_generator_fn(tape, layer_registry):
29+
def get_registry_generator_fn(
30+
tape: tf.GradientTape, layer_registry: lr.LayerRegistry
31+
):
2932
"""Creates the generator function for `compute_gradient_norms()`."""
3033
if layer_registry is None:
3134
# Needed for backwards compatibility.
@@ -53,7 +56,12 @@ def registry_generator_fn(layer_instance, args, kwargs):
5356
return registry_generator_fn
5457

5558

56-
def compute_gradient_norms(input_model, x_batch, y_batch, layer_registry):
59+
def compute_gradient_norms(
60+
input_model: tf.keras.Model,
61+
x_batch: tf.Tensor,
62+
y_batch: tf.Tensor,
63+
layer_registry: lr.LayerRegistry,
64+
):
5765
"""Computes the per-example loss gradient norms for given data.
5866
5967
Applies a variant of the approach given in
@@ -106,7 +114,7 @@ def compute_gradient_norms(input_model, x_batch, y_batch, layer_registry):
106114
return tf.sqrt(tf.reduce_sum(sqr_norm_tsr, axis=1))
107115

108116

109-
def compute_clip_weights(l2_norm_clip, gradient_norms):
117+
def compute_clip_weights(l2_norm_clip: float, gradient_norms: tf.Tensor):
110118
"""Computes the per-example loss/clip weights for clipping.
111119
112120
When the sum of the per-example losses is replaced by a weighted sum, where
@@ -132,7 +140,11 @@ def compute_clip_weights(l2_norm_clip, gradient_norms):
132140

133141

134142
def compute_pred_and_clipped_gradients(
135-
input_model, x_batch, y_batch, l2_norm_clip, layer_registry
143+
input_model: tf.keras.Model,
144+
x_batch: tf.Tensor,
145+
y_batch: tf.Tensor,
146+
l2_norm_clip: float,
147+
layer_registry: lr.LayerRegistry,
136148
):
137149
"""Computes the per-example predictions and per-example clipped loss gradient.
138150

0 commit comments

Comments (0)