Skip to content

Commit e307f37

Browse files
Add DPEvent to the measurements in DifferentiallyPrivateFactory's next_fn. This requires implementing initial_sample_state in TreeRangeSumQuery for tests to pass.
PiperOrigin-RevId: 479436434
1 parent 0738d6f commit e307f37

File tree

2 files changed

+12
-5
lines changed

2 files changed

+12
-5
lines changed

tensorflow_privacy/privacy/dp_query/dp_query.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -268,10 +268,11 @@ def derive_metrics(self, global_state):
268268

269269
def _zeros_like(arg):
270270
"""A `zeros_like` function that also works for `tf.TensorSpec`s."""
271-
try:
272-
arg = tf.convert_to_tensor(value=arg)
273-
except TypeError:
274-
pass
271+
if not isinstance(arg, tf.TensorSpec):
272+
try:
273+
arg = tf.convert_to_tensor(value=arg)
274+
except TypeError:
275+
pass
275276
return tf.zeros(arg.shape, arg.dtype)
276277

277278

tensorflow_privacy/privacy/dp_query/tree_range_query.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
import distutils
2020
import math
21-
from typing import Optional
21+
from typing import Any, Optional
2222

2323
import attr
2424
import dp_accounting
@@ -136,6 +136,12 @@ def initial_global_state(self):
136136
arity=self._arity,
137137
inner_query_state=self._inner_query.initial_global_state())
138138

139+
def initial_sample_state(self, template: Optional[Any] = None):
140+
"""Implements `tensorflow_privacy.DPQuery.initial_sample_state`."""
141+
unprocessed_sample_state = super().initial_sample_state(template)
142+
sample_params = self.derive_sample_params(self.initial_global_state())
143+
return self.preprocess_record(sample_params, unprocessed_sample_state)
144+
139145
def derive_sample_params(self, global_state):
140146
"""Implements `tensorflow_privacy.DPQuery.derive_sample_params`."""
141147
return (global_state.arity,

0 commit comments

Comments
 (0)