Skip to content

Commit 0738d6f

Browse files
Bugfix.
PiperOrigin-RevId: 478591776
1 parent 3f6d0ac commit 0738d6f

File tree

1 file changed

+15
-1
lines changed

1 file changed

+15
-1
lines changed

tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/membership_inference_attack.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ def _run_trained_attack(attack_input: AttackInputData,
6363
left_out_indices = prepared_attacker_data.left_out_indices
6464
features = prepared_attacker_data.features_all
6565
labels = prepared_attacker_data.labels_all
66+
sample_weights = prepared_attacker_data.sample_weights_all
6667

6768
# We are going to train multiple models on disjoint subsets of the data
6869
# (`features`, `labels`), so we can get the membership scores of all samples,
@@ -85,8 +86,21 @@ def _run_trained_attack(attack_input: AttackInputData,
8586
# Make sure one sample only got score predicted once
8687
assert np.all(np.isnan(scores[test_indices]))
8788

89+
# Setup sample weights if provided.
90+
if sample_weights is not None:
91+
# If sample weights are provided, only the weights at the training indices
92+
# are used for training. The weights at the test indices are not used
93+
# during prediction. Not that 'train' and 'test' refer to the data for the
94+
# attack models, not the data for the original models.
95+
sample_weights_train = np.squeeze(sample_weights[train_indices])
96+
else:
97+
sample_weights_train = None
98+
8899
attacker = models.create_attacker(attack_type, backend=backend)
89-
attacker.train_model(features[train_indices], labels[train_indices])
100+
attacker.train_model(
101+
features[train_indices],
102+
labels[train_indices],
103+
sample_weight=sample_weights_train)
90104
predictions = attacker.predict(features[test_indices])
91105
scores[test_indices] = predictions
92106

0 commit comments

Comments
 (0)