-
Notifications
You must be signed in to change notification settings - Fork 12
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Examples] refactor examples and add github action to automatically validate the scripts. #138
Conversation
e517a1c
to
586c344
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Left some minor comments. Otherwise LGTM
parser = argparse.ArgumentParser() | ||
parser.add_argument("--method", type=str, default="explicit") | ||
parser.add_argument("--device", type=str, default="cuda") | ||
args = parser.parse_args() | ||
|
||
# load the dataset, we only need the train dataset |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: this line is not grammatically correct
parser.add_argument("--device", type=str, default="cuda") | ||
args = parser.parse_args() | ||
|
||
# load the dataset, we only need the train dataset |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same grammatical issue
@@ -42,17 +49,18 @@ def get_mnist_indices_and_adjust_labels(dataset): | |||
sampler=SubsetSampler(range(1000)), | |||
) | |||
|
|||
model = train_mnist_lr(train_loader_full) | |||
model.cuda() | |||
model = train_cifar2_resnet9(train_loader, num_epochs=3, num_classes=10) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe rename this function as train_cifar_resnet9?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
good catch. I have revised the module to be named as dattri.benchmark.datasets.cifar
that contains both cifar2 and cifar10 functions.
examples/readme.md
Outdated
@@ -0,0 +1,25 @@ | |||
# `dattri` examples | |||
This folder contains bite-sized examples which can help users to build their own application by `dattri`. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This folder contains bite-sized examples that can help users build their own applications with dattri
.
examples/readme.md
Outdated
## Noisy label detection | ||
This section contains using different attributors to detect the noisy label in various datasets. | ||
|
||
[Use influence function to detect noisy label in Mnist10 + Logistic regression.](./noisy_label_detection/influence_function_noisy_label.py) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
label -> labels
examples/readme.md
Outdated
|
||
[Use influence function to detect noisy label in Mnist10 + Logistic regression.](./noisy_label_detection/influence_function_noisy_label.py) | ||
|
||
[Use TracIN to detect noisy label in Mnist10 + MLP.](./noisy_label_detection/tracin_noisy_label.py) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
label -> labels
examples/readme.md
Outdated
|
||
[Use TracIN to detect noisy label in Mnist10 + MLP.](./noisy_label_detection/tracin_noisy_label.py) | ||
|
||
[Use TRAK to detect noisy label in CIFAR10 + ResNet-9.](./noisy_label_detection/trak_noisy_label.py) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
label -> labels
examples/readme.md
Outdated
|
||
## Use pretrained checkpoints and pre-calculated ground truth | ||
|
||
This section contains examples to use the pretrained checkpoints and pre-calculated ground truth provided by `dattri` to evaluate the data attribution methods. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
to use -> using
examples/readme.md
Outdated
|
||
## Estimate the brittleness | ||
|
||
This section contains examples to use attribution score to estimate the brittleness of a model. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
to use -> using
attribution score -> attribution scores
586c344
to
1561cbb
Compare
Description
1. Motivation and Context
Our example collection is kind of messy and hard to navigate. This PR aims to make the organized and add github action to test them to enable long-term usability.
This PR does not add any examples. New example scripts will be updated according to the necessity and request.
2. Summary of the change
/example
->/examples
dattri
repo (add cifar10 support,AttributionTask
will load the checkpoints in the same device as the model, auc will transform the score to cpu to avoid user redundant code).3. What tests have been added/updated for the change?