Merge pull request #13 from alvinwan/superclass-eval
Superclass eval
alvinwan authored Sep 23, 2020
2 parents 96cd44e + 2f70737 commit b8ef680
Showing 7 changed files with 248 additions and 122 deletions.
18 changes: 18 additions & 0 deletions README.md
@@ -383,6 +383,24 @@ python main.py --dataset=CIFAR10 --arch=wrn28_10_cifar10 --hierarchy=induced-wrn
</div>
</details>

<details><summary><b>Running zero-shot evaluation on superclasses.</b></summary>

```
# get wnids for animal and vehicle -- use the wnids it outputs in the commands below
nbdt-wnids --classes animal vehicle
# evaluate CIFAR10-trained ResNet18 on "Animal vs. Vehicle" superclasses, with images from TinyImagenet200
python main.py --dataset-test=TinyImagenet200 --dataset=CIFAR10 --disable-test-eval --eval --analysis=Superclass --superclass-wnids n00015388 n04524313 --pretrained
# download public checkpoint
wget https://github.com/alvinwan/neural-backed-decision-trees/releases/download/0.0.1/ckpt-CIFAR100-ResNet18-induced-ResNet18-SoftTreeSupLoss.pth -O checkpoint/ckpt-CIFAR10-ResNet18-induced-SoftTreeSupLoss.pth
# evaluate CIFAR10-trained NBDT-ResNet18 on "Animal vs. Vehicle" superclasses, with images from TinyImagenet200
python main.py --dataset-test=TinyImagenet200 --dataset=CIFAR10 --disable-test-eval --eval --analysis=SuperclassNBDT --superclass-wnids n00015388 n04524313 --loss=SoftTreeSupLoss --resume
```

</details>
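
For intuition, here is a minimal standalone sketch of the metric these commands report -- not the repository's implementation. It assumes CIFAR10's standard class order; the `mapping` tensor and `superclass_accuracy` helper are hypothetical names. Both fine-grained predictions and labels are collapsed onto the two superclasses (0 = animal, 1 = vehicle) before scoring:

```
import torch

# CIFAR10 classes in order: airplane, automobile, bird, cat, deer,
# dog, frog, horse, ship, truck -> superclass 0 = animal, 1 = vehicle
mapping = torch.tensor([1, 1, 0, 0, 0, 0, 0, 0, 1, 1])

def superclass_accuracy(outputs, targets):
    # outputs: (N, 10) fine-grained logits; targets: (N,) fine-grained labels
    predicted = mapping[outputs.argmax(1)]
    return (predicted == mapping[targets]).float().mean().item()

# Toy batch: a model that confuses cat with dog and ship with truck
# still scores 100% on the 2-way superclass problem.
outputs = torch.eye(10)[[5, 9]]   # predicts dog, truck
targets = torch.tensor([3, 8])    # true labels: cat, ship
print(superclass_accuracy(outputs, targets))  # 1.0
```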

# Results

We compare against all previous decision-tree-based methods that report on CIFAR10, CIFAR100, and/or ImageNet, including methods that hinder interpretability by using impure leaves or a random forest. Of all these methods, we report the baseline with the highest accuracy: Deep Neural Decision Forest (DNDF updated with ResNet18), Explainable Observer-Classifier (XOC), Deep Convolutional Decision Jungle (DCDJ), Network of Experts (NofE), Deep Decision Network (DDN), and Adaptive Neural Trees (ANT).
43 changes: 31 additions & 12 deletions main.py
@@ -21,13 +21,13 @@
from nbdt.models.utils import load_state_dict, make_kwarg_optional

maybe_install_wordnet()

datasets = data.cifar.names + data.imagenet.names + data.custom.names
parser = argparse.ArgumentParser(description='PyTorch CIFAR Training')
parser.add_argument('--batch-size', default=512, type=int,
                    help='Batch size used for training')
parser.add_argument('--epochs', '-e', default=200, type=int,
                    help='By default, lr schedule is scaled accordingly')
parser.add_argument('--dataset', default='CIFAR10', choices=data.cifar.names + data.imagenet.names + data.custom.names)
parser.add_argument('--dataset', default='CIFAR10', choices=datasets)
parser.add_argument('--arch', default='ResNet18', choices=list(models.get_model_choices()))
parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
@@ -40,6 +40,15 @@
parser.add_argument('--pretrained', action='store_true',
                    help='Download pretrained model. Not all models support this.')
parser.add_argument('--eval', help='eval only', action='store_true')
parser.add_argument('--dataset-test', choices=datasets,
                    help='If not set, automatically set to train dataset')
parser.add_argument('--disable-test-eval',
                    help='Allows you to run model inference on a test dataset '
                         'different from the train dataset. Use an analyzer to '
                         'define a metric.',
                    action='store_true')

# options specific to this project and its dataloaders
parser.add_argument('--loss', choices=loss.names, default='CrossEntropyLoss')
parser.add_argument('--metric', choices=metrics.names, default='top1')
parser.add_argument('--analysis', choices=analysis.names, help='Run analysis after each epoch')
@@ -58,21 +67,24 @@

# Data
print('==> Preparing data..')
dataset = getattr(data, args.dataset)
dataset_train = getattr(data, args.dataset)
dataset_test = getattr(data, args.dataset_test or args.dataset)

transform_train = dataset.transform_train()
transform_test = dataset.transform_val()
transform_train = dataset_train.transform_train()
transform_test = dataset_test.transform_val()

dataset_kwargs = generate_kwargs(args, dataset, name=f'Dataset {args.dataset}', keys=data.custom.keys, globals=globals())
trainset = dataset(**dataset_kwargs, root='./data', train=True, download=True, transform=transform_train)
testset = dataset(**dataset_kwargs, root='./data', train=False, download=True, transform=transform_test)
dataset_train_kwargs = generate_kwargs(args, dataset_train, name=f'Dataset {dataset_train.__name__}', keys=data.custom.keys, globals=globals())
dataset_test_kwargs = generate_kwargs(args, dataset_test, name=f'Dataset {dataset_test.__name__}', keys=data.custom.keys, globals=globals())
trainset = dataset_train(**dataset_train_kwargs, root='./data', train=True, download=True, transform=transform_train)
testset = dataset_test(**dataset_test_kwargs, root='./data', train=False, download=True, transform=transform_test)

assert trainset.classes == testset.classes, (trainset.classes, testset.classes)
assert trainset.classes == testset.classes or args.disable_test_eval, (trainset.classes, testset.classes)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)

Colors.cyan(f'Training with dataset {args.dataset} and {len(trainset.classes)} classes')
Colors.cyan(f'Testing with dataset {args.dataset_test or args.dataset} and {len(testset.classes)} classes')

# Model
print('==> Building model..')
@@ -162,10 +174,11 @@ def test(epoch, checkpoint=True):
for batch_idx, (inputs, targets) in enumerate(testloader):
    inputs, targets = inputs.to(device), targets.to(device)
    outputs = net(inputs)
    loss = criterion(outputs, targets)

    test_loss += loss.item()
    metric.forward(outputs, targets)
    if not args.disable_test_eval:
        loss = criterion(outputs, targets)
        test_loss += loss.item()
        metric.forward(outputs, targets)
    stat = analyzer.update_batch(outputs, targets)

    progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d) %s' % (
@@ -185,6 +198,12 @@ def test(epoch, checkpoint=True):
torch.save(state, f'./checkpoint/{checkpoint_fname}.pth')
best_acc = acc

if args.disable_test_eval and (not args.analysis or args.analysis == 'Noop'):
    Colors.red(
        ' * Warning: `disable_test_eval` is used but no custom metric '
        '`--analysis` is supplied. I suggest supplying an analysis to '
        'perform custom loss and accuracy calculation.')

if args.eval:
    if not args.resume and not args.pretrained:
        Colors.red(' * Warning: Model is not loaded from checkpoint. '
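
With `--disable-test-eval`, the analyzer named by `--analysis` is the only component scoring batches, as the warning above notes. Below is a minimal sketch of an analyzer conforming to the hook interface this commit exercises (`start_test`, `update_batch`, `end_test`); the class itself and its scoring are hypothetical, not part of the repository:

```
class MyMetric:
    """Hypothetical analyzer: computes its own top-1 accuracy."""

    def __init__(self, classes=()):
        self.classes = classes
        self.correct = self.total = 0

    def start_test(self, epoch):
        # called once before the eval loop
        self.correct = self.total = 0

    def update_batch(self, outputs, targets):
        # called per batch; the returned string is shown in the progress bar
        predicted = outputs.max(1)[1]
        self.total += targets.size(0)
        self.correct += (predicted == targets).sum().item()
        return f'MyMetric: {self.correct / (self.total or 1):.2%}'

    def end_test(self, epoch):
        # called once after the eval loop
        print(f'MyMetric accuracy: {self.correct}/{self.total}')
```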
137 changes: 131 additions & 6 deletions nbdt/analysis.py
@@ -1,21 +1,26 @@
from nbdt.utils import set_np_printoptions
from nbdt.utils import set_np_printoptions, Colors
from nbdt.graph import wnid_to_synset, synset_to_wnid
from nbdt.model import (
    SoftEmbeddedDecisionRules as SoftRules,
    HardEmbeddedDecisionRules as HardRules
)
from collections import defaultdict
import torch
from nbdt import metrics
import functools
import numpy as np


__all__ = names = (
    'Noop', 'ConfusionMatrix', 'ConfusionMatrixJointNodes',
    'IgnoredSamples', 'HardEmbeddedDecisionRules', 'SoftEmbeddedDecisionRules')
keys = ('path_graph', 'path_wnids', 'classes', 'dataset', 'metric')
    'IgnoredSamples', 'HardEmbeddedDecisionRules', 'SoftEmbeddedDecisionRules',
    'Superclass', 'SuperclassNBDT')
keys = ('path_graph', 'path_wnids', 'classes', 'dataset', 'metric',
        'dataset_test', 'superclass_wnids')


def add_arguments(parser):
    pass
    parser.add_argument('--superclass-wnids', nargs='*', type=str)


def start_end_decorator(obj, name):
@@ -52,7 +57,7 @@ def __exit__(self, type, value, traceback):

class Noop:

    accepts_classes = lambda trainset, **kwargs: trainset.classes
    accepts_classes = lambda testset, **kwargs: testset.classes

    def __init__(self, classes=()):
        set_np_printoptions()
@@ -183,6 +188,7 @@ class DecisionRules(Noop):

    def __init__(self, *args, Rules=HardRules, metric='top1', **kwargs):
        self.rules = Rules(*args, **kwargs)
        super().__init__(self.rules.tree.classes)
        self.metric = getattr(metrics, metric)()

    def start_test(self, epoch):
@@ -197,7 +203,7 @@ def update_batch(self, outputs, targets):

    def end_test(self, epoch):
        super().end_test(epoch)
        accuracy = round(self.metric.correct / self.metric.total * 100., 2)
        accuracy = round(self.metric.correct / (self.metric.total or 1) * 100., 2)
        print(f'{self.name} Accuracy: {accuracy}%, {self.metric.correct}/{self.metric.total}')


@@ -214,3 +220,122 @@ class SoftEmbeddedDecisionRules(DecisionRules):

    def __init__(self, *args, Rules=None, **kwargs):
        super().__init__(*args, Rules=SoftRules, **kwargs)


class Superclass(DecisionRules):
    """Evaluate provided model on superclasses.

    Each wnid must be a hypernym of at least one label in the test set.
    This metric will convert each predicted class into the corresponding
    wnid and report accuracy on this len(wnids)-class problem.
    """

    accepts_dataset = lambda trainset, **kwargs: trainset.__class__.__name__
    accepts_dataset_test = lambda testset, **kwargs: testset.__class__.__name__
    name = 'Superclass'
    accepts_superclass_wnids = True

    def __init__(self, *args, superclass_wnids, dataset_test=None,
                 Rules=SoftRules, metric=None, **kwargs):
        """Pass wnids to classify.

        Assumes index of each wnid is the index of the wnid in the rules.wnids
        list. This agrees with Node.wnid_to_class_index as of writing, since
        rules.wnids = get_wnids(...).
        """
        # TODO: for now, ignores metric
        super().__init__(*args, **kwargs)

        kwargs['dataset'] = dataset_test
        kwargs.pop('path_graph', '')
        kwargs.pop('path_wnids', '')
        self.rules_test = Rules(*args, **kwargs)
        self.superclass_wnids = superclass_wnids
        self.total = self.correct = 0

        self.mapping_target, self.new_to_old_classes_target = Superclass.build_mapping(self.rules_test.tree.wnids_leaves, superclass_wnids)
        self.mapping_pred, self.new_to_old_classes_pred = Superclass.build_mapping(self.rules.tree.wnids_leaves, superclass_wnids)

        mapped_classes = [self.classes[i] for i in (self.mapping_target >= 0).nonzero()]
        Colors.cyan(
            f'==> Mapped {len(mapped_classes)} classes to your superclasses: '
            f'{mapped_classes}')

    @staticmethod
    def build_mapping(dataset_wnids, superclass_wnids):
        new_to_old_classes = defaultdict(lambda: [])
        mapping = []
        for old_index, dataset_wnid in enumerate(dataset_wnids):
            synset = wnid_to_synset(dataset_wnid)
            hypernyms = Superclass.all_hypernyms(synset)
            hypernym_wnids = list(map(synset_to_wnid, hypernyms))

            value = -1
            for new_index, superclass_wnid in enumerate(superclass_wnids):
                if superclass_wnid in hypernym_wnids:
                    value = new_index
                    break
            mapping.append(value)
            new_to_old_classes[value].append(old_index)
        mapping = torch.Tensor(mapping)
        return mapping, new_to_old_classes

    @staticmethod
    def all_hypernyms(synset):
        hypernyms = []
        frontier = [synset]
        while frontier:
            current = frontier.pop(0)
            hypernyms.append(current)
            frontier.extend(current.hypernyms())
        return hypernyms

    def forward(self, outputs, targets):
        if self.mapping_target.device != targets.device:
            self.mapping_target = self.mapping_target.to(targets.device)

        if self.mapping_pred.device != outputs.device:
            self.mapping_pred = self.mapping_pred.to(outputs.device)

        targets = self.mapping_target[targets]
        outputs = outputs[targets >= 0]
        targets = targets[targets >= 0]

        outputs[:, self.mapping_pred < 0] = -100
        if outputs.size(0) == 0:
            return torch.Tensor([]), torch.Tensor([])
        predicted = outputs.max(1)[1]
        predicted = self.mapping_pred[predicted].to(targets.device)
        return predicted, targets

    def update_batch(self, outputs, targets):
        predicted, targets = self.forward(outputs, targets)

        n_samples = predicted.size(0)
        self.total += n_samples
        self.correct += (predicted == targets).sum().item()
        accuracy = round(self.correct / (float(self.total) or 1) * 100, 2)
        return f'{self.name}: {accuracy}%'


class SuperclassNBDT(Superclass):

    name = 'Superclass-NBDT'

    def __init__(self, *args, Rules=None, **kwargs):
        super().__init__(*args, Rules=SoftRules, **kwargs)

    def forward(self, outputs, targets):
        outputs = self.rules.get_node_logits(
            outputs,
            new_to_old_classes=self.new_to_old_classes_pred,
            num_classes=max(self.new_to_old_classes_pred) + 1)
        predicted = outputs.max(1)[1].to(targets.device)

        if self.mapping_target.device != targets.device:
            self.mapping_target = self.mapping_target.to(targets.device)

        targets = self.mapping_target[targets]
        predicted = predicted[targets >= 0]
        targets = targets[targets >= 0]
        return predicted, targets
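
To make the mapping concrete: `build_mapping` assigns each dataset class the index of the first superclass wnid found in its hypernym closure, or -1 if none matches. Here is a standalone sketch of that lookup -- assuming a recent nltk with the WordNet corpus downloaded, and simplifying wnid-to-synset lookup to the first noun sense:

```
from nltk.corpus import wordnet as wn

def all_hypernyms(synset):
    # BFS over the hypernym graph, including the synset itself
    hypernyms, frontier = [], [synset]
    while frontier:
        current = frontier.pop(0)
        hypernyms.append(current)
        frontier.extend(current.hypernyms())
    return hypernyms

superclasses = [wn.synset_from_pos_and_offset('n', 15388),    # n00015388, animal
                wn.synset_from_pos_and_offset('n', 4524313)]  # n04524313, vehicle

for name in ('cat', 'airplane', 'ship', 'frog'):
    synset = wn.synsets(name, pos='n')[0]
    index = next((i for i, s in enumerate(superclasses)
                  if s in all_hypernyms(synset)), -1)
    print(name, '->', index)  # cat -> 0, airplane -> 1, ship -> 1, frog -> 0
```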
39 changes: 23 additions & 16 deletions nbdt/bin/nbdt-wnids
@@ -20,24 +20,30 @@ datasets = ('CIFAR10', 'CIFAR100', 'Cityscapes') + data.imagenet.names + \
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', choices=datasets, default='CIFAR10')
parser.add_argument('--root', default='./nbdt/wnids')
parser.add_argument('--classes', type=str, nargs='*',
                    help='INSTEAD of writing WNIDs for a dataset, convert JUST'
                         ' these class names to WNIDs.')
data.custom.add_arguments(parser)
args = parser.parse_args()

dataset = getattr(data, args.dataset)
dataset_kwargs = generate_kwargs(args, dataset,
                                 name=f'Dataset {args.dataset}',
                                 keys=data.custom.keys,
                                 globals=globals())
if args.dataset not in ['Cityscapes', 'PascalContext', 'LookIntoPerson', 'ADE20K']:
    dataset_kwargs['download'] = True
if args.classes:
    classes = args.classes
else:
    dataset = getattr(data, args.dataset)
    dataset_kwargs = generate_kwargs(args, dataset,
                                     name=f'Dataset {args.dataset}',
                                     keys=data.custom.keys,
                                     globals=globals())
    if args.dataset not in ['Cityscapes', 'PascalContext', 'LookIntoPerson', 'ADE20K']:
        dataset_kwargs['download'] = True

dataset = dataset(**dataset_kwargs, root='./data', download=True)
dataset = dataset(**dataset_kwargs, root='./data')

classes = dataset.classes
if args.dataset == 'Cityscapes':
classes = [cls.name for cls in dataset.classes if not cls.ignore_in_eval]
if args.dataset == 'PascalContext':
classes = [cls for cls in dataset.classes if cls != 'background']
classes = dataset.classes
if args.dataset == 'Cityscapes':
classes = [cls.name for cls in dataset.classes if not cls.ignore_in_eval]
if args.dataset == 'PascalContext':
classes = [cls for cls in dataset.classes if cls != 'background']

path = Path(os.path.join(args.root, f'{args.dataset}.txt'))
os.makedirs(path.parent, exist_ok=True)
@@ -105,7 +111,7 @@ hardcoded_mapping = {
}

wnids = []
for i, cls in enumerate(dataset.classes):
for i, cls in enumerate(classes):
if cls in hardcoded_mapping:
synset = hardcoded_mapping[cls]
else:
@@ -119,8 +125,9 @@ for i, cls in enumerate(dataset.classes):
print(f'{wnid}: ({cls}) {synset.definition()}')
wnids.append(wnid)

write_wnids(wnids, path)
if not args.classes:
write_wnids(wnids, path)
Colors.green(f'==> Wrote to {path}')

if failures:
Colors.red(f'==> Warning: failed to find wordnet IDs for {failures}')
Colors.green(f'==> Wrote to {path}')
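
For reference, a sketch of the name-to-wnid conversion the new `--classes` branch performs -- assuming nltk's WordNet corpus; the real script also consults the `hardcoded_mapping` above, omitted here, and this sketch simplifies lookup to the first noun sense. The output matches the wnids used in the README example:

```
from nltk.corpus import wordnet as wn

for name in ('animal', 'vehicle'):
    synset = wn.synsets(name, pos='n')[0]          # simplification: first noun sense
    wnid = f'{synset.pos()}{synset.offset():08d}'  # e.g. n00015388
    print(f'{wnid}: ({name}) {synset.definition()}')
# n00015388: (animal) a living organism characterized by voluntary movement
# n04524313: (vehicle) a conveyance that transports people or objects
```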