Skip to content

Commit

Permalink
style changes after review
Browse files Browse the repository at this point in the history
Signed-off-by: Muhammad Zaid Hameed <Zaid.Hameed@ibm.com>
  • Loading branch information
Muhammad Zaid Hameed authored and Muhammad Zaid Hameed committed Dec 12, 2023
1 parent 13c9e98 commit e472e5f
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 21 deletions.
2 changes: 1 addition & 1 deletion art/defences/trainer/adversarial_trainer_oaat.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def __init__(
:param lpips_classifier: Weight averaging model for calculating activations.
:param list_avg_models: list of models for weight averaging.
:param attack: attack to use for data augmentation in adversarial training
:param train_params: parmaters' dictionary related to adversarial training
:param train_params: parameters' dictionary related to adversarial training
"""
self._attack = attack
self._proxy_classifier = proxy_classifier
Expand Down
34 changes: 14 additions & 20 deletions art/defences/trainer/adversarial_trainer_oaat_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,17 @@
"""
from __future__ import absolute_import, division, print_function, unicode_literals

from collections import OrderedDict
import logging
import os
import time
from typing import Optional, Tuple, TYPE_CHECKING, List, Dict, Union
from collections import OrderedDict
import six

import six
import numpy as np
from tqdm.auto import trange
from art import config

from art import config
from art.defences.trainer.adversarial_trainer_oaat import AdversarialTrainerOAAT
from art.estimators.classification.pytorch import PyTorchClassifier
from art.data_generators import DataGenerator
Expand Down Expand Up @@ -71,7 +72,7 @@ def __init__(
:param lpips_classifier: Weight averaging model for calculating activations.
:param list_avg_models: list of models for weight averaging.
:param attack: attack to use for data augmentation in adversarial training.
:param train_params: training parmaters' dictionary related to adversarial training
:param train_params: training parameters' dictionary related to adversarial training
"""
super().__init__(classifier, proxy_classifier, lpips_classifier, list_avg_models, attack, train_params)
self._classifier: PyTorchClassifier
Expand Down Expand Up @@ -104,7 +105,6 @@ def fit(
:param kwargs: Dictionary of framework-specific arguments. These will be passed as such to the `fit` function of
the target classifier.
"""
import os
import torch

logger.info("Performing adversarial training with OAAT protocol")
Expand Down Expand Up @@ -302,7 +302,6 @@ def fit_generator(
:param kwargs: Dictionary of framework-specific arguments. These will be passed as such to the `fit` function of
the target classifier.
"""
import os
import torch

logger.info("Performing adversarial training with OAAT protocol")
Expand Down Expand Up @@ -895,7 +894,7 @@ def update_learning_rate(
else:
raise ValueError(f"lr_schedule {lr_schedule} not supported")

def _attack_lpips( # type: ignore
def _attack_lpips(
self,
x: np.ndarray,
y: np.ndarray,
Expand Down Expand Up @@ -993,7 +992,7 @@ def _one_step_adv_example(

return x_adv

def _compute_perturbation( # pylint: disable=W0221
def _compute_perturbation(
self, x: "torch.Tensor", x_init: "torch.Tensor", y: "torch.Tensor", training_mode: bool = False
) -> "torch.Tensor":
"""
Expand All @@ -1010,9 +1009,6 @@ def _compute_perturbation( # pylint: disable=W0221
"""
import torch

# Pick a small scalar to avoid division by 0
tol = 10e-8

self._classifier.model.train(mode=training_mode)
self._lpips_classifier.model.train(mode=training_mode)

Expand Down Expand Up @@ -1124,17 +1120,17 @@ def _compute_perturbation( # pylint: disable=W0221

elif self._train_params["norm"] == 1:
ind = tuple(range(1, len(x.shape)))
grad = grad / (torch.sum(grad.abs(), dim=ind, keepdims=True) + tol) # type: ignore
grad = grad / (torch.sum(grad.abs(), dim=ind, keepdims=True) + EPS) # type: ignore

elif self._train_params["norm"] == 2:
ind = tuple(range(1, len(x.shape)))
grad = grad / (torch.sqrt(torch.sum(grad * grad, axis=ind, keepdims=True)) + tol) # type: ignore
grad = grad / (torch.sqrt(torch.sum(grad * grad, axis=ind, keepdims=True)) + EPS) # type: ignore

assert x.shape == grad.shape

return grad

def _apply_perturbation( # pylint: disable=W0221
def _apply_perturbation(
self, x: "torch.Tensor", perturbation: "torch.Tensor", eps_step: Union[int, float, np.ndarray]
) -> "torch.Tensor":
"""
Expand Down Expand Up @@ -1173,8 +1169,6 @@ def _projection(
"""
import torch

# Pick a small scalar to avoid division by 0
tol = 10e-8
values_tmp = values.reshape(values.shape[0], -1)

if norm_p == 2:
Expand All @@ -1187,7 +1181,7 @@ def _projection(
values_tmp
* torch.min(
torch.tensor([1.0], dtype=torch.float32).to(self._classifier.device),
eps / (torch.norm(values_tmp, p=2, dim=1) + tol),
eps / (torch.norm(values_tmp, p=2, dim=1) + EPS),
).unsqueeze_(-1)
)

Expand All @@ -1201,14 +1195,14 @@ def _projection(
values_tmp
* torch.min(
torch.tensor([1.0], dtype=torch.float32).to(self._classifier.device),
eps / (torch.norm(values_tmp, p=1, dim=1) + tol),
eps / (torch.norm(values_tmp, p=1, dim=1) + EPS),
).unsqueeze_(-1)
)

elif norm_p in [np.inf, "inf"]:
if isinstance(eps, np.ndarray):
eps = eps * np.ones_like(values.cpu())
eps = eps.reshape([eps.shape[0], -1]) # type: ignore
eps_array = eps * np.ones_like(values.cpu())
eps = eps_array.reshape([eps_array.shape[0], -1])

values_tmp = values_tmp.sign() * torch.min(
values_tmp.abs(), torch.tensor([eps], dtype=torch.float32).to(self._classifier.device)
Expand Down

0 comments on commit e472e5f

Please sign in to comment.