Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding label check to trades adversarial trainer #2231

Merged
merged 5 commits into from
Aug 17, 2023

Conversation

GiulioZizzo
Copy link
Collaborator

@GiulioZizzo GiulioZizzo commented Jul 28, 2023

Description

We add a dimensionality check to the labels in Trades adversarial training before applying argmax operations.

Fixes #2230

Type of change

Please check all relevant options.

  • Improvement (non-breaking)
  • Bug fix (non-breaking)
  • New feature (non-breaking)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • This change requires a documentation update

Testing

Please describe the tests that you ran to verify your changes. Consider listing any relevant details of your test configuration.

  • Updates to test_adversarial_trainer_trades_pytorch to check with both one hot and class index style of label.

Test Configuration:

  • OS: MacOS
  • Python version: 3.9
  • ART version or commit number: ART 1.15
  • TensorFlow / Keras / PyTorch / MXNet version: Torch 2.0.1

Checklist

  • My code follows the style guidelines of this project
  • I have performed a self-review of my own code
  • I have commented my code
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: GiulioZizzo <giulio.zizzo@yahoo.co.uk>
@codecov-commenter
Copy link

codecov-commenter commented Jul 28, 2023

Codecov Report

Merging #2231 (d43f473) into dev_1.15.1 (c63d5d5) will increase coverage by 1.00%.
The diff coverage is 100.00%.

❗ Your organization is not using the GitHub App Integration. As a result you may experience degraded service beginning May 15th. Please install the Github App Integration for your organization. Read more.

Impacted file tree graph

@@              Coverage Diff               @@
##           dev_1.15.1    #2231      +/-   ##
==============================================
+ Coverage       84.60%   85.61%   +1.00%     
==============================================
  Files             308      308              
  Lines           27464    27470       +6     
  Branches         5045     5046       +1     
==============================================
+ Hits            23237    23519     +282     
+ Misses           2961     2670     -291     
- Partials         1266     1281      +15     
Files Changed Coverage Δ
...nces/trainer/adversarial_trainer_trades_pytorch.py 88.39% <100.00%> (+0.65%) ⬆️

... and 11 files with indirect coverage changes

@beat-buesser beat-buesser self-assigned this Aug 3, 2023
@beat-buesser beat-buesser added this to the ART 1.15.1 milestone Aug 3, 2023
@beat-buesser beat-buesser added the improvement Improve implementation label Aug 3, 2023
@beat-buesser beat-buesser changed the base branch from dev_1.16.0 to dev_1.15.1 August 3, 2023 11:20
@beat-buesser
Copy link
Collaborator

Hi @GiulioZizzo Thank you very much for you pull request! I have changed the target branch to dev_1.15.1 for the next patch release.

@Zaid-Hameed Because you have implemented the TRADES trainer, could you please add a review of the proposed changes?

@beat-buesser beat-buesser self-requested a review August 7, 2023 14:37
Copy link
Collaborator

@Zaid-Hameed Zaid-Hameed Aug 7, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi Giulio, Can you please check if adding from_logits=True makes any changes to tests outcomes i.e., changing line 37 to
classifier, _ = image_dl_estimator(from_logits=True)

This is just to make sure that softmax is not being applied twice in test files.

Signed-off-by: GiulioZizzo <giulio.zizzo@yahoo.co.uk>
Copy link
Collaborator

@Zaid-Hameed Zaid-Hameed left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot Giulio for making the changes.

Copy link
Collaborator

@beat-buesser beat-buesser left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @GiulioZizzo I think you proposal for functionality is good. I'm proposing to follow the ART patterns in solving the issue, what do you think?

@@ -240,7 +245,7 @@ def _batch_process(self, x_batch: np.ndarray, y_batch: np.ndarray) -> Tuple[floa
)

# Check label shape
if self._classifier._reduce_labels: # pylint: disable=W0212
if self._classifier._reduce_labels and y_preprocessed.ndim > 1: # pylint: disable=W0212
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change is not needed because ART uses internally the one-hot-encoded format for classification labels.

output = np.argmax(self.predict(x_test), axis=1)
nb_correct_pred = np.sum(output == np.argmax(y_test, axis=1))
if y_test.ndim > 1:
Copy link
Collaborator

@beat-buesser beat-buesser Aug 15, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The idea of this change is good. Let's follow the ART pattern by using y = check_and_transform_label_format(y, nb_classes=self._classifier.nb_classes) for y and validation_data[1] imported with from art.utils import check_and_transform_label_format at the beginning of this method to ensure they are both one-hot-encoded.

@beat-buesser beat-buesser merged commit 52c240a into Trusted-AI:dev_1.15.1 Aug 17, 2023
@beat-buesser beat-buesser linked an issue Aug 17, 2023 that may be closed by this pull request
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
improvement Improve implementation
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Adding a label check for Trades adversarial trainer.
4 participants