Skip to content

Commit

Permalink
Updating parallel parameter of AutoAttack. Cleaning up notebook
Browse files Browse the repository at this point in the history
Signed-off-by: Kieran Fraser <Kieran.Fraser@ibm.com>
  • Loading branch information
kieranfraser committed Sep 8, 2023
1 parent 4b00770 commit f6a3a5d
Show file tree
Hide file tree
Showing 3 changed files with 329 additions and 1,852 deletions.
14 changes: 9 additions & 5 deletions art/attacks/evasion/auto_attack.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ class AutoAttack(EvasionAttack):
"batch_size",
"estimator_orig",
"targeted",
"parallel",
]

_estimator_requirements = (BaseEstimator, ClassifierMixin)
Expand All @@ -70,7 +71,7 @@ def __init__(
batch_size: int = 32,
estimator_orig: Optional["CLASSIFIER_TYPE"] = None,
targeted: bool = False,
in_parallel: bool = False,
parallel: bool = False,
):
"""
Create a :class:`.AutoAttack` instance.
Expand Down Expand Up @@ -143,7 +144,7 @@ def __init__(
self.estimator_orig = estimator

self._targeted = targeted
self.in_parallel = in_parallel
self.parallel = parallel
self._check_params()

def generate(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> np.ndarray:
Expand Down Expand Up @@ -185,7 +186,7 @@ def generate(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> n
if attack.targeted:
attack.set_params(targeted=False)

if self.in_parallel:
if self.parallel:
args.append(
(
deepcopy(x_adv),
Expand Down Expand Up @@ -232,7 +233,7 @@ def generate(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> n
targeted_labels[:, i], nb_classes=self.estimator.nb_classes
)

if self.in_parallel:
if self.parallel:
args.append(
(
deepcopy(x_adv),
Expand All @@ -258,7 +259,7 @@ def generate(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> n
except ValueError as error:
logger.warning("Error completing attack: %s}", str(error))

if self.in_parallel:
if self.parallel:
with multiprocess.get_context("spawn").Pool() as pool:
# Results come back in the order that they were issued
results = pool.starmap(run_attack, args)
Expand Down Expand Up @@ -303,6 +304,9 @@ def run_attack(
:param y: An array of the labels.
:param sample_is_robust: Store the initial robustness of examples.
:param attack: Evasion attack to run.
:param estimator_orig: Original estimator to be attacked by adversarial examples.
:param norm: The norm of the adversarial perturbation. Possible values: "inf", np.inf, 1 or 2.
:param eps: Maximum perturbation that the attacker can introduce.
:return: An array holding the adversarial examples.
"""
# Attack only correctly classified samples
Expand Down
Loading

0 comments on commit f6a3a5d

Please sign in to comment.