Add Oracle Aligned Adversarial Training #2348

Merged: 3 commits, merged Dec 13, 2023
Changes from 1 commit
style check corrections
Signed-off-by: Muhammad Zaid Hameed <Zaid.Hameed@ibm.com>
Muhammad Zaid Hameed authored and Muhammad Zaid Hameed committed Dec 9, 2023
commit 13c9e988b314510067ebbde50e7efca372342f46
5 changes: 2 additions & 3 deletions art/defences/trainer/adversarial_trainer_oaat.py
@@ -27,7 +27,7 @@
 from __future__ import absolute_import, division, print_function, unicode_literals

 import abc
-from typing import Optional, Tuple, List, TYPE_CHECKING
+from typing import Optional, Tuple, TYPE_CHECKING, Sequence

 import numpy as np

@@ -51,7 +51,7 @@ def __init__(
         classifier: "CLASSIFIER_LOSS_GRADIENTS_TYPE",
         proxy_classifier: "CLASSIFIER_LOSS_GRADIENTS_TYPE",
         lpips_classifier: "CLASSIFIER_LOSS_GRADIENTS_TYPE",
-        list_avg_models: List["CLASSIFIER_LOSS_GRADIENTS_TYPE"],
+        list_avg_models: Sequence["CLASSIFIER_LOSS_GRADIENTS_TYPE"],
         attack: EvasionAttack,
         train_params: dict,
     ):
@@ -66,7 +66,6 @@ def __init__(
         :param train_params: parmaters' dictionary related to adversarial training
         """
         self._attack = attack
         self._classifier = classifier
         self._proxy_classifier = proxy_classifier
         self._lpips_classifier = lpips_classifier
         self._list_avg_models = list_avg_models
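Note: the switch from List to Sequence in the constructor signature is a typing-only change; a Sequence hint also accepts tuples and signals that the averaged models are only read, never mutated. A minimal illustrative sketch (the dummy names below are not part of ART):

    # Minimal sketch (illustrative, not ART's API): a Sequence hint accepts any
    # read-only ordered container, so callers may pass a list or a tuple of models.
    from typing import Sequence


    class DummyClassifier:
        """Stand-in for CLASSIFIER_LOSS_GRADIENTS_TYPE."""


    def count_avg_models(list_avg_models: Sequence[DummyClassifier]) -> int:
        # The trainer only iterates over the averaged models; it never mutates the container.
        return len(list_avg_models)


    print(count_avg_models([DummyClassifier(), DummyClassifier()]))  # a list works -> 2
    print(count_avg_models((DummyClassifier(),)))                    # so does a tuple -> 1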
33 changes: 19 additions & 14 deletions art/defences/trainer/adversarial_trainer_oaat_pytorch.py
@@ -737,7 +737,7 @@ def _modify_classifier(
     def get_layer_activations(  # type: ignore
         p_classifier: PyTorchClassifier,
         x: "torch.Tensor",
-        layers: List[Union[int, str]] = None,
+        layers: List[Union[int, str]],
     ) -> Tuple[Dict[str, "torch.Tensor"], List[str]]:
         """
         Return the output of the specified layers for input `x`. `layers` is a list of either layer indices (between 0
@@ -750,21 +750,20 @@ def get_layer_activations(  # type: ignore
         :return: Tuple containing the output dict and a list of layers' names. In dictionary each element is a
                  layer's output where the first dimension is the batch size corresponding to `x'.
         """
-        import torch

         p_classifier.model.train(mode=False)

         list_layer_names = []
         for layer in layers:
             if isinstance(layer, six.string_types):
-                if layer not in p_classifier.layer_names:
+                if layer not in p_classifier._layer_names:  # pylint: disable=W0212
                     raise ValueError(f"Layer name {layer} not supported")
                 layer_name = layer
                 list_layer_names.append(layer_name)

             elif isinstance(layer, int):
                 layer_index = layer
-                layer_name = p_classifier.layer_names[layer_index]
+                layer_name = p_classifier._layer_names[layer_index]  # pylint: disable=W0212
                 list_layer_names.append(layer_name)

             else:
@@ -778,7 +777,7 @@ def hook(model, input, output):  # pylint: disable=W0622,W0613
             return hook

         if not hasattr(p_classifier, "_features"):
-            p_classifier._features: Dict[str, torch.Tensor] = {}
+            p_classifier._features = {}  # pylint: disable=W0212
         # register forward hooks on the layers of choice

         for layer_name in list_layer_names:
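Note: get_layer_activations resolves the requested layer names or indices and then collects intermediate outputs through PyTorch forward hooks stored in a _features dict; dropping the `= None` default also makes the non-Optional `layers` hint honest. A rough standalone sketch of that hook pattern on a plain torch.nn module (illustrative only, not ART's PyTorchClassifier API):

    # Rough sketch of the forward-hook pattern used to collect layer activations.
    # Assumes a plain torch.nn.Module; ART wraps the same idea inside PyTorchClassifier.
    from typing import Dict, List

    import torch
    import torch.nn as nn

    model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
    model.eval()  # counterpart of p_classifier.model.train(mode=False)

    features: Dict[str, torch.Tensor] = {}


    def make_hook(name: str):
        def hook(module, inputs, output):  # pylint: disable=W0613
            features[name] = output
        return hook


    # Register a hook on each requested layer (here selected by name from named_modules).
    layer_names: List[str] = ["0", "2"]
    modules = dict(model.named_modules())
    handles = [modules[name].register_forward_hook(make_hook(name)) for name in layer_names]

    with torch.no_grad():
        model(torch.randn(5, 8))  # one forward pass fills `features`

    for handle in handles:
        handle.remove()  # avoid leaking hooks across calls

    print({name: tensor.shape for name, tensor in features.items()})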
@@ -825,7 +824,7 @@ def calculate_lpips_distance(  # type: ignore
         p_classifier: PyTorchClassifier,
         input_1: "torch.Tensor",
         input_2: "torch.Tensor",
-        layers: List[Union[int, str]] = None,
+        layers: List[Union[int, str]],
     ) -> "torch.Tensor":
         """
         Return the LPIPS distance between input_1 and input_2. `layers` is a list of either layer indices (between 0 and
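Note: calculate_lpips_distance compares the two inputs through the activations of the selected layers of the lpips_classifier. A hedged sketch of an LPIPS-style distance over precomputed activations (simplified; ART's exact computation may differ):

    # Simplified LPIPS-style distance: unit-normalize activations along the channel
    # dimension and average the squared differences, layer by layer.
    from typing import List

    import torch


    def lpips_like_distance(acts_1: List[torch.Tensor], acts_2: List[torch.Tensor]) -> torch.Tensor:
        """Per-sample distance accumulated over layers."""
        total = torch.zeros(acts_1[0].shape[0])
        for act_1, act_2 in zip(acts_1, acts_2):
            # Unit-normalize each activation along the channel dimension, as LPIPS does.
            act_1 = act_1 / (act_1.norm(dim=1, keepdim=True) + 1e-10)
            act_2 = act_2 / (act_2.norm(dim=1, keepdim=True) + 1e-10)
            # Mean squared difference over channels and spatial positions.
            total = total + ((act_1 - act_2) ** 2).flatten(start_dim=1).mean(dim=1)
        return total


    acts_a = [torch.randn(4, 64, 8, 8), torch.randn(4, 128, 4, 4)]
    acts_b = [torch.randn(4, 64, 8, 8), torch.randn(4, 128, 4, 4)]
    print(lpips_like_distance(acts_a, acts_b).shape)  # torch.Size([4]): one distance per sample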
@@ -874,22 +873,28 @@ def update_learning_rate(
             else:
                 l_r = self._train_params["lr"] * 0.5 * (1 + np.cos(epoch / nb_epochs * np.pi))
+
+            for param_group in optimizer.param_groups:
+                param_group["lr"] = l_r

         elif lr_schedule.lower() == "linear":
             l_r = (epoch + 1) * (self._train_params["lr"] / 10)
+
+            for param_group in optimizer.param_groups:
+                param_group["lr"] = l_r

         elif lr_schedule.lower() == "step":
             if epoch >= 75 * nb_epochs / 110:
                 l_r = self._train_params["lr"] * 0.1
             if epoch >= 90 * nb_epochs / 110:
                 l_r = self._train_params["lr"] * 0.01
             if epoch >= 100 * nb_epochs / 110:
                 l_r = self._train_params["lr"] * 0.001
+
+            for param_group in optimizer.param_groups:
+                param_group["lr"] = l_r
         else:
             raise ValueError(f"lr_schedule {lr_schedule} not supported")
-
-        for param_group in optimizer.param_groups:
-            param_group["lr"] = l_r

     def _attack_lpips(  # type: ignore
         self,
         x: np.ndarray,
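Note: the hunk above moves the optimizer update into each schedule branch, presumably so that l_r is only applied within the branch that defines it. A standalone sketch of the three schedules as shown in the diff (cosine, linear, and step with the 75/90/100-of-110 breakpoints); base_lr and the breakpoints mirror the diff, and the optimizer handling is the usual PyTorch param_groups idiom:

    # Sketch of the three learning-rate schedules shown in update_learning_rate.
    import numpy as np


    def scheduled_lr(epoch: int, nb_epochs: int, base_lr: float, lr_schedule: str) -> float:
        if lr_schedule == "cosine":
            return float(base_lr * 0.5 * (1 + np.cos(epoch / nb_epochs * np.pi)))
        if lr_schedule == "linear":
            return (epoch + 1) * (base_lr / 10)
        if lr_schedule == "step":
            l_r = base_lr
            if epoch >= 75 * nb_epochs / 110:
                l_r = base_lr * 0.1
            if epoch >= 90 * nb_epochs / 110:
                l_r = base_lr * 0.01
            if epoch >= 100 * nb_epochs / 110:
                l_r = base_lr * 0.001
            return l_r
        raise ValueError(f"lr_schedule {lr_schedule} not supported")


    # Applying the value is the generic PyTorch pattern:
    #   for param_group in optimizer.param_groups:
    #       param_group["lr"] = scheduled_lr(epoch, nb_epochs, base_lr, lr_schedule)
    for epoch in (0, 80, 95, 105):
        print(epoch, scheduled_lr(epoch, nb_epochs=110, base_lr=0.1, lr_schedule="step"))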
@@ -898,7 +903,7 @@ def _attack_lpips(  # type: ignore
         eps_step: Union[int, float, np.ndarray],
         max_iter: int,
         training_mode: bool,
-    ) -> "torch.Tensor":
+    ) -> np.ndarray:
         """
         Compute adversarial examples with cross entropy and lpips distance.

@@ -916,12 +921,12 @@
         """
         import torch

-        x = torch.from_numpy(x.astype(ART_NUMPY_DTYPE)).to(self._classifier.device)
-        y = torch.from_numpy(y.astype(ART_NUMPY_DTYPE)).to(self._classifier.device)
-        adv_x = torch.clone(x)
+        x_t = torch.from_numpy(x.astype(ART_NUMPY_DTYPE)).to(self._classifier.device)
+        y_t = torch.from_numpy(y.astype(ART_NUMPY_DTYPE)).to(self._classifier.device)
+        adv_x = torch.clone(x_t)

         for i_max_iter in range(max_iter):
-            adv_x = self._one_step_adv_example(adv_x, x, y, eps, eps_step, i_max_iter == 0, training_mode)
+            adv_x = self._one_step_adv_example(adv_x, x_t, y_t, eps, eps_step, i_max_iter == 0, training_mode)

         return adv_x.cpu().detach().numpy()

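Note: the last hunks keep _attack_lpips' NumPy-in/NumPy-out contract: the inputs are converted once to device tensors under new names (x_t, y_t), the attack loop runs on tensors, and a NumPy array is returned, matching the corrected -> np.ndarray annotation. A hedged sketch of that wrapper pattern, with a dummy step standing in for ART's _one_step_adv_example:

    # Sketch of the NumPy-in / NumPy-out wrapper around an iterative attack loop.
    # `one_step_adv_example` is a placeholder for ART's _one_step_adv_example; it takes
    # a random-sign dummy step and projects back into the eps-ball to stay runnable.
    import numpy as np
    import torch

    ART_NUMPY_DTYPE = np.float32  # stand-in for ART's dtype constant
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


    def one_step_adv_example(adv_x: torch.Tensor, x_t: torch.Tensor, eps: float, eps_step: float) -> torch.Tensor:
        # Dummy perturbation step followed by projection back into the eps-ball around x_t.
        perturbed = adv_x + eps_step * torch.sign(torch.randn_like(adv_x))
        return x_t + torch.clamp(perturbed - x_t, -eps, eps)


    def attack_numpy(x: np.ndarray, eps: float, eps_step: float, max_iter: int) -> np.ndarray:
        # Convert once to device tensors under new names, leaving the NumPy inputs untouched.
        x_t = torch.from_numpy(x.astype(ART_NUMPY_DTYPE)).to(device)
        adv_x = torch.clone(x_t)
        for _ in range(max_iter):
            adv_x = one_step_adv_example(adv_x, x_t, eps, eps_step)
        # Return a NumPy array, matching the corrected return annotation.
        return adv_x.cpu().detach().numpy()


    out = attack_numpy(np.random.rand(2, 3, 8, 8).astype(np.float32), eps=0.03, eps_step=0.01, max_iter=5)
    print(out.shape)  # (2, 3, 8, 8)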