diff --git a/README.md b/README.md
index d67f942..4ac6768 100644
--- a/README.md
+++ b/README.md
@@ -383,6 +383,24 @@ python main.py --dataset=CIFAR10 --arch=wrn28_10_cifar10 --hierarchy=induced-wrn
+
+Running zero-shot evaluation on superclasses.
+
+```
+# get wnids for animal and vehicle -- use the output wnids in the commands below
+nbdt-wnids --classes animal vehicle
+
+# evaluate CIFAR10-trained ResNet18 on "Animal vs. Vehicle" superclasses, with images from TinyImagenet200
+python main.py --dataset-test=TinyImagenet200 --dataset=CIFAR10 --disable-test-eval --eval --analysis=Superclass --superclass-wnids n00015388 n04524313 --pretrained
+
+# download public checkpoint
+wget https://github.com/alvinwan/neural-backed-decision-trees/releases/download/0.0.1/ckpt-CIFAR100-ResNet18-induced-ResNet18-SoftTreeSupLoss.pth -O checkpoint/ckpt-CIFAR10-ResNet18-induced-SoftTreeSupLoss.pth
+
+# evaluate CIFAR10-trained NBDT-ResNet18 on "Animal vs. Vehicle" superclasses, with images from TinyImagenet200
+python main.py --dataset-test=TinyImagenet200 --dataset=CIFAR10 --disable-test-eval --eval --analysis=SuperclassNBDT --superclass-wnids n00015388 n04524313 --loss=SoftTreeSupLoss --resume
+```
+
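+The `Superclass` analysis maps each class in the test dataset to one of the given superclass wnids using WordNet hypernyms, then reports accuracy on the resulting superclass problem; samples whose classes fall under none of the given superclasses are ignored.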
+
 # Results
 
 We compare against all previous decision-tree-based methods that report on CIFAR10, CIFAR100, and/or ImageNet, including methods that hinder interpretability by using impure leaves or a random forest. We report the highest-accuracy baseline among these methods: Deep Neural Decision Forest (DNDF, updated with ResNet18), Explainable Observer-Classifier (XOC), Deep Convolutional Decision Jungle (DCDJ), Network of Experts (NofE), Deep Decision Network (DDN), and Adaptive Neural Trees (ANT).
diff --git a/main.py b/main.py
index 4264047..125767d 100644
--- a/main.py
+++ b/main.py
@@ -21,13 +21,13 @@ from nbdt.models.utils import load_state_dict, make_kwarg_optional
 maybe_install_wordnet()
-
+datasets = data.cifar.names + data.imagenet.names + data.custom.names
 
 parser = argparse.ArgumentParser(description='PyTorch CIFAR Training')
 parser.add_argument('--batch-size', default=512, type=int, help='Batch size used for training')
 parser.add_argument('--epochs', '-e', default=200, type=int, help='By default, lr schedule is scaled accordingly')
-parser.add_argument('--dataset', default='CIFAR10', choices=data.cifar.names + data.imagenet.names + data.custom.names)
+parser.add_argument('--dataset', default='CIFAR10', choices=datasets)
 parser.add_argument('--arch', default='ResNet18', choices=list(models.get_model_choices()))
 parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
 parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
@@ -40,6 +40,15 @@ parser.add_argument('--pretrained', action='store_true', help='Download pretrained model. Not all models support this.')
 parser.add_argument('--eval', help='eval only', action='store_true')
+parser.add_argument('--dataset-test', choices=datasets,
+                    help='If not set, automatically set to train dataset')
+parser.add_argument('--disable-test-eval',
+                    help='Allows you to run model inference on a test dataset '
+                         'different from the train dataset. Use an analyzer to define '
+                         'a metric.',
+                    action='store_true')
+
+# options specific to this project and its dataloaders
 parser.add_argument('--loss', choices=loss.names, default='CrossEntropyLoss')
 parser.add_argument('--metric', choices=metrics.names, default='top1')
 parser.add_argument('--analysis', choices=analysis.names, help='Run analysis after each epoch')
@@ -58,21 +67,24 @@
 # Data
 print('==> Preparing data..')
-dataset = getattr(data, args.dataset)
+dataset_train = getattr(data, args.dataset)
+dataset_test = getattr(data, args.dataset_test or args.dataset)
 
-transform_train = dataset.transform_train()
-transform_test = dataset.transform_val()
+transform_train = dataset_train.transform_train()
+transform_test = dataset_test.transform_val()
 
-dataset_kwargs = generate_kwargs(args, dataset, name=f'Dataset {args.dataset}', keys=data.custom.keys, globals=globals())
-trainset = dataset(**dataset_kwargs, root='./data', train=True, download=True, transform=transform_train)
-testset = dataset(**dataset_kwargs, root='./data', train=False, download=True, transform=transform_test)
+dataset_train_kwargs = generate_kwargs(args, dataset_train, name=f'Dataset {dataset_train.__name__}', keys=data.custom.keys, globals=globals())
+dataset_test_kwargs = generate_kwargs(args, dataset_test, name=f'Dataset {dataset_test.__name__}', keys=data.custom.keys, globals=globals())
+trainset = dataset_train(**dataset_train_kwargs, root='./data', train=True, download=True, transform=transform_train)
+testset = dataset_test(**dataset_test_kwargs, root='./data', train=False, download=True, transform=transform_test)
 
-assert trainset.classes == testset.classes, (trainset.classes, testset.classes)
+assert trainset.classes == testset.classes or args.disable_test_eval, (trainset.classes, testset.classes)
 
 trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=2)
 testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)
 
 Colors.cyan(f'Training with dataset {args.dataset} and {len(trainset.classes)} classes')
+Colors.cyan(f'Testing with dataset {args.dataset_test or args.dataset} and {len(testset.classes)} classes')
 
 # Model
 print('==> Building model..')
@@ -162,10 +174,11 @@ def test(epoch, checkpoint=True):
         for batch_idx, (inputs, targets) in enumerate(testloader):
             inputs, targets = inputs.to(device), targets.to(device)
             outputs = net(inputs)
-            loss = criterion(outputs, targets)
-            test_loss += loss.item()
-            metric.forward(outputs, targets)
+            if not args.disable_test_eval:
+                loss = criterion(outputs, targets)
+                test_loss += loss.item()
+                metric.forward(outputs, targets)
             stat = analyzer.update_batch(outputs, targets)
 
             progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d) %s' % (
@@ -185,6 +198,12 @@ def test(epoch, checkpoint=True):
         torch.save(state, f'./checkpoint/{checkpoint_fname}.pth')
         best_acc = acc
 
+if args.disable_test_eval and (not args.analysis or args.analysis == 'Noop'):
+    Colors.red(
+        ' * Warning: `disable_test_eval` is used but no custom '
+        '`--analysis` is supplied. I suggest supplying an analysis to perform '
+        'custom loss and accuracy calculation.')
+
 if args.eval:
     if not args.resume and not args.pretrained:
         Colors.red(' * Warning: Model is not loaded from checkpoint. '
diff --git a/nbdt/analysis.py b/nbdt/analysis.py
index 6e65590..29f0694 100644
--- a/nbdt/analysis.py
+++ b/nbdt/analysis.py
@@ -1,8 +1,11 @@
-from nbdt.utils import set_np_printoptions
+from nbdt.utils import set_np_printoptions, Colors
+from nbdt.graph import wnid_to_synset, synset_to_wnid
 from nbdt.model import (
     SoftEmbeddedDecisionRules as SoftRules,
     HardEmbeddedDecisionRules as HardRules
 )
+from collections import defaultdict
+import torch
 from nbdt import metrics
 import functools
 import numpy as np
@@ -10,12 +13,14 @@ __all__ = names = (
     'Noop', 'ConfusionMatrix', 'ConfusionMatrixJointNodes',
-    'IgnoredSamples', 'HardEmbeddedDecisionRules', 'SoftEmbeddedDecisionRules')
-keys = ('path_graph', 'path_wnids', 'classes', 'dataset', 'metric')
+    'IgnoredSamples', 'HardEmbeddedDecisionRules', 'SoftEmbeddedDecisionRules',
+    'Superclass', 'SuperclassNBDT')
+keys = ('path_graph', 'path_wnids', 'classes', 'dataset', 'metric',
+        'dataset_test', 'superclass_wnids')
 
 def add_arguments(parser):
-    pass
+    parser.add_argument('--superclass-wnids', nargs='*', type=str)
 
 def start_end_decorator(obj, name):
@@ -52,7 +57,7 @@ def __exit__(self, type, value, traceback):
 
 class Noop:
 
-    accepts_classes = lambda trainset, **kwargs: trainset.classes
+    accepts_classes = lambda testset, **kwargs: testset.classes
 
     def __init__(self, classes=()):
         set_np_printoptions()
@@ -183,6 +188,7 @@ class DecisionRules(Noop):
 
     def __init__(self, *args, Rules=HardRules, metric='top1', **kwargs):
         self.rules = Rules(*args, **kwargs)
+        super().__init__(self.rules.tree.classes)
         self.metric = getattr(metrics, metric)()
 
     def start_test(self, epoch):
@@ -197,7 +203,7 @@ def update_batch(self, outputs, targets):
 
     def end_test(self, epoch):
         super().end_test(epoch)
-        accuracy = round(self.metric.correct / self.metric.total * 100., 2)
+        accuracy = round(self.metric.correct / (self.metric.total or 1) * 100., 2)
         print(f'{self.name} Accuracy: {accuracy}%, {self.metric.correct}/{self.metric.total}')
 
@@ -214,3 +220,122 @@ class SoftEmbeddedDecisionRules(DecisionRules):
 
     def __init__(self, *args, Rules=None, **kwargs):
         super().__init__(*args, Rules=SoftRules, **kwargs)
+
+
+class Superclass(DecisionRules):
+    """Evaluate the provided model on superclasses.
+
+    Each wnid must be a hypernym of at least one label in the test set.
+    This metric will convert each predicted class into the corresponding
+    wnid and report accuracy on this len(wnids)-class problem.
+    """
+
+    accepts_dataset = lambda trainset, **kwargs: trainset.__class__.__name__
+    accepts_dataset_test = lambda testset, **kwargs: testset.__class__.__name__
+    name = 'Superclass'
+    accepts_superclass_wnids = True
+
+    def __init__(self, *args, superclass_wnids, dataset_test=None,
+                 Rules=SoftRules, metric=None, **kwargs):
+        """Pass wnids to classify.
+
+        Assumes the index of each wnid matches its index in the rules.wnids
+        list. This agrees with Node.wnid_to_class_index as of writing, since
+        rules.wnids = get_wnids(...).
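+
+        For example, given superclass wnids n00015388 (animal) and
+        n04524313 (vehicle), a CIFAR10 'cat' maps to superclass index 0
+        and 'automobile' to superclass index 1.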
+ """ + # TODO: for now, ignores metric + super().__init__(*args, **kwargs) + + kwargs['dataset'] = dataset_test + kwargs.pop('path_graph', '') + kwargs.pop('path_wnids', '') + self.rules_test = Rules(*args, **kwargs) + self.superclass_wnids = superclass_wnids + self.total = self.correct = 0 + + self.mapping_target, self.new_to_old_classes_target = Superclass.build_mapping(self.rules_test.tree.wnids_leaves, superclass_wnids) + self.mapping_pred, self.new_to_old_classes_pred = Superclass.build_mapping(self.rules.tree.wnids_leaves, superclass_wnids) + + mapped_classes = [self.classes[i] for i in (self.mapping_target >= 0).nonzero()] + Colors.cyan( + f'==> Mapped {len(mapped_classes)} classes to your superclasses: ' + f'{mapped_classes}') + + @staticmethod + def build_mapping(dataset_wnids, superclass_wnids): + new_to_old_classes = defaultdict(lambda: []) + mapping = [] + for old_index, dataset_wnid in enumerate(dataset_wnids): + synset = wnid_to_synset(dataset_wnid) + hypernyms = Superclass.all_hypernyms(synset) + hypernym_wnids = list(map(synset_to_wnid, hypernyms)) + + value = -1 + for new_index, superclass_wnid in enumerate(superclass_wnids): + if superclass_wnid in hypernym_wnids: + value = new_index + break + mapping.append(value) + new_to_old_classes[value].append(old_index) + mapping = torch.Tensor(mapping) + return mapping, new_to_old_classes + + @staticmethod + def all_hypernyms(synset): + hypernyms = [] + frontier = [synset] + while frontier: + current = frontier.pop(0) + hypernyms.append(current) + frontier.extend(current.hypernyms()) + return hypernyms + + def forward(self, outputs, targets): + if self.mapping_target.device != targets.device: + self.mapping_target = self.mapping_target.to(targets.device) + + if self.mapping_pred.device != outputs.device: + self.mapping_pred = self.mapping_pred.to(outputs.device) + + targets = self.mapping_target[targets] + outputs = outputs[targets >= 0] + targets = targets[targets >= 0] + + outputs[:, self.mapping_pred < 0] = -100 + if outputs.size(0) == 0: + return torch.Tensor([]), torch.Tensor([]) + predicted = outputs.max(1)[1] + predicted = self.mapping_pred[predicted].to(targets.device) + return predicted, targets + + def update_batch(self, outputs, targets): + predicted, targets = self.forward(outputs, targets) + + n_samples = predicted.size(0) + self.total += n_samples + self.correct += (predicted == targets).sum().item() + accuracy = round(self.correct / (float(self.total) or 1), 4) * 100 + return f'{self.name}: {accuracy}%' + + +class SuperclassNBDT(Superclass): + + name = 'Superclass-NBDT' + + def __init__(self, *args, Rules=None, **kwargs): + super().__init__(*args, Rules=SoftRules, **kwargs) + + def forward(self, outputs, targets): + outputs = self.rules.get_node_logits( + outputs, + new_to_old_classes=self.new_to_old_classes_pred, + num_classes=max(self.new_to_old_classes_pred) + 1) + predicted = outputs.max(1)[1].to(targets.device) + + if self.mapping_target.device != targets.device: + self.mapping_target = self.mapping_target.to(targets.device) + + targets = self.mapping_target[targets] + predicted = predicted[targets >= 0] + targets = targets[targets >= 0] + return predicted, targets diff --git a/nbdt/bin/nbdt-wnids b/nbdt/bin/nbdt-wnids index 0ef4489..6d92b3b 100755 --- a/nbdt/bin/nbdt-wnids +++ b/nbdt/bin/nbdt-wnids @@ -20,24 +20,30 @@ datasets = ('CIFAR10', 'CIFAR100', 'Cityscapes') + data.imagenet.names + \ parser = argparse.ArgumentParser() parser.add_argument('--dataset', choices=datasets, default='CIFAR10') 
 parser.add_argument('--root', default='./nbdt/wnids')
+parser.add_argument('--classes', type=str, nargs='*',
+                    help='INSTEAD of writing WNIDs for a dataset, convert JUST'
+                         ' these class names to WNIDs.')
 data.custom.add_arguments(parser)
 args = parser.parse_args()
 
-dataset = getattr(data, args.dataset)
-dataset_kwargs = generate_kwargs(args, dataset,
-    name=f'Dataset {args.dataset}',
-    keys=data.custom.keys,
-    globals=globals())
-if args.dataset not in ['Cityscapes', 'PascalContext', 'LookIntoPerson', 'ADE20K']:
-    dataset_kwargs['download'] = True
+if args.classes:
+    classes = args.classes
+else:
+    dataset = getattr(data, args.dataset)
+    dataset_kwargs = generate_kwargs(args, dataset,
+        name=f'Dataset {args.dataset}',
+        keys=data.custom.keys,
+        globals=globals())
+    if args.dataset not in ['Cityscapes', 'PascalContext', 'LookIntoPerson', 'ADE20K']:
+        dataset_kwargs['download'] = True
 
-dataset = dataset(**dataset_kwargs, root='./data', download=True)
+    dataset = dataset(**dataset_kwargs, root='./data')
 
-classes = dataset.classes
-if args.dataset == 'Cityscapes':
-    classes = [cls.name for cls in dataset.classes if not cls.ignore_in_eval]
-if args.dataset == 'PascalContext':
-    classes = [cls for cls in dataset.classes if cls != 'background']
+    classes = dataset.classes
+    if args.dataset == 'Cityscapes':
+        classes = [cls.name for cls in dataset.classes if not cls.ignore_in_eval]
+    if args.dataset == 'PascalContext':
+        classes = [cls for cls in dataset.classes if cls != 'background']
 
 path = Path(os.path.join(args.root, f'{args.dataset}.txt'))
 os.makedirs(path.parent, exist_ok=True)
@@ -105,7 +111,7 @@ hardcoded_mapping = {
 }
 
 wnids = []
-for i, cls in enumerate(dataset.classes):
+for i, cls in enumerate(classes):
     if cls in hardcoded_mapping:
         synset = hardcoded_mapping[cls]
     else:
@@ -119,8 +125,9 @@ for i, cls in enumerate(dataset.classes):
     print(f'{wnid}: ({cls}) {synset.definition()}')
     wnids.append(wnid)
 
-write_wnids(wnids, path)
+if not args.classes:
+    write_wnids(wnids, path)
+    Colors.green(f'==> Wrote to {path}')
 
 if failures:
     Colors.red(f'==> Warning: failed to find wordnet IDs for {failures}')
-Colors.green(f'==> Wrote to {path}')
diff --git a/nbdt/data/custom.py b/nbdt/data/custom.py
index d6b50c1..f9ba168 100644
--- a/nbdt/data/custom.py
+++ b/nbdt/data/custom.py
@@ -121,36 +121,20 @@ def __init__(self, dataset, include_labels=(0,)):
         ])
 
-class CIFAR10ResampleLabels(ResampleLabelsDataset):
-
-    def __init__(self, *args, root='./data', probability_labels=1, **kwargs):
-        super().__init__(
-            dataset=cifar.CIFAR10(*args, root=root, **kwargs),
-            probability_labels=probability_labels)
-
-
-class CIFAR100ResampleLabels(ResampleLabelsDataset):
-
-    def __init__(self, *args, root='./data', probability_labels=1, **kwargs):
-        super().__init__(
-            dataset=cifar.CIFAR100(*args, root=root, **kwargs),
-            probability_labels=probability_labels)
-
-
-class TinyImagenet200ResampleLabels(ResampleLabelsDataset):
-
-    def __init__(self, *args, root='./data', probability_labels=1, **kwargs):
-        super().__init__(
-            dataset=imagenet.TinyImagenet200(*args, root=root, **kwargs),
-            probability_labels=probability_labels)
+def get_resample_labels_dataset(dataset):
+    class Cls(ResampleLabelsDataset):
+        def __init__(self, *args, root='./data', probability_labels=1, **kwargs):
+            super().__init__(
+                dataset=dataset(*args, root=root, **kwargs),
+                probability_labels=probability_labels)
+    Cls.__name__ = f'{dataset.__name__}ResampleLabels'
+    return Cls
 
-class Imagenet1000ResampleLabels(ResampleLabelsDataset):
-
-    def __init__(self, *args, root='./data', probability_labels=1, **kwargs):
-        super().__init__(
-            dataset=imagenet.Imagenet1000(*args, root=root, **kwargs),
-            probability_labels=probability_labels)
+CIFAR10ResampleLabels = get_resample_labels_dataset(cifar.CIFAR10)
+CIFAR100ResampleLabels = get_resample_labels_dataset(cifar.CIFAR100)
+TinyImagenet200ResampleLabels = get_resample_labels_dataset(imagenet.TinyImagenet200)
+Imagenet1000ResampleLabels = get_resample_labels_dataset(imagenet.Imagenet1000)
 
 class IncludeClassesDataset(IncludeLabelsDataset):
@@ -168,36 +152,20 @@ def __init__(self, dataset, include_classes=()):
         ])
 
-class CIFAR10IncludeLabels(IncludeLabelsDataset):
-
-    def __init__(self, *args, root='./data', include_labels=(0,), **kwargs):
-        super().__init__(
-            dataset=cifar.CIFAR10(*args, root=root, **kwargs),
-            include_labels=include_labels)
-
-
-class CIFAR100IncludeLabels(IncludeLabelsDataset):
-
-    def __init__(self, *args, root='./data', include_labels=(0,), **kwargs):
-        super().__init__(
-            dataset=cifar.CIFAR100(*args, root=root, **kwargs),
-            include_labels=include_labels)
-
-
-class TinyImagenet200IncludeLabels(IncludeLabelsDataset):
-
-    def __init__(self, *args, root='./data', include_labels=(0,), **kwargs):
-        super().__init__(
-            dataset=imagenet.TinyImagenet200(*args, root=root, **kwargs),
-            include_labels=include_labels)
-
+def get_include_labels_dataset(dataset):
+    class Cls(IncludeLabelsDataset):
+        def __init__(self, *args, root='./data', include_labels=(0,), **kwargs):
+            super().__init__(
+                dataset=dataset(*args, root=root, **kwargs),
+                include_labels=include_labels)
+    Cls.__name__ = f'{dataset.__name__}IncludeLabels'
+    return Cls
 
-class Imagenet1000IncludeLabels(IncludeLabelsDataset):
-    def __init__(self, *args, root='./data', include_labels=(0,), **kwargs):
-        super().__init__(
-            dataset=imagenet.Imagenet1000(*args, root=root, **kwargs),
-            include_labels=include_labels)
+CIFAR10IncludeLabels = get_include_labels_dataset(cifar.CIFAR10)
+CIFAR100IncludeLabels = get_include_labels_dataset(cifar.CIFAR100)
+TinyImagenet200IncludeLabels = get_include_labels_dataset(imagenet.TinyImagenet200)
+Imagenet1000IncludeLabels = get_include_labels_dataset(imagenet.Imagenet1000)
 
 class ExcludeLabelsDataset(IncludeLabelsDataset):
@@ -213,33 +181,17 @@ def __init__(self, dataset, exclude_labels=(0,)):
         include_labels=include_labels)
 
-class CIFAR10ExcludeLabels(ExcludeLabelsDataset):
-
-    def __init__(self, *args, root='./data', exclude_labels=(0,), **kwargs):
-        super().__init__(
-            dataset=cifar.CIFAR10(*args, root=root, **kwargs),
-            exclude_labels=exclude_labels)
-
-
-class CIFAR100ExcludeLabels(ExcludeLabelsDataset):
-
-    def __init__(self, *args, root='./data', exclude_labels=(0,), **kwargs):
-        super().__init__(
-            dataset=cifar.CIFAR100(*args, root=root, **kwargs),
-            exclude_labels=exclude_labels)
-
+def get_exclude_labels_dataset(dataset):
+    class Cls(ExcludeLabelsDataset):
+        def __init__(self, *args, root='./data', exclude_labels=(0,), **kwargs):
+            super().__init__(
+                dataset=dataset(*args, root=root, **kwargs),
+                exclude_labels=exclude_labels)
+    Cls.__name__ = f'{dataset.__name__}ExcludeLabels'
+    return Cls
 
-class TinyImagenet200ExcludeLabels(ExcludeLabelsDataset):
-
-    def __init__(self, *args, root='./data', exclude_labels=(0,), **kwargs):
-        super().__init__(
-            dataset=imagenet.TinyImagenet200(*args, root=root, **kwargs),
-            exclude_labels=exclude_labels)
-
-
-class Imagenet1000ExcludeLabels(ExcludeLabelsDataset):
-
-    def __init__(self, *args, root='./data', exclude_labels=(0,), **kwargs):
-        super().__init__(
-            dataset=imagenet.Imagenet1000(*args, root=root, **kwargs),
-            exclude_labels=exclude_labels)
+CIFAR10ExcludeLabels = get_exclude_labels_dataset(cifar.CIFAR10)
+CIFAR100ExcludeLabels = get_exclude_labels_dataset(cifar.CIFAR100)
+TinyImagenet200ExcludeLabels = get_exclude_labels_dataset(imagenet.TinyImagenet200)
+Imagenet1000ExcludeLabels = get_exclude_labels_dataset(imagenet.Imagenet1000)
diff --git a/nbdt/metrics.py b/nbdt/metrics.py
index b0639cd..3175304 100644
--- a/nbdt/metrics.py
+++ b/nbdt/metrics.py
@@ -21,7 +21,7 @@ def forward(self, outputs, targets):
         self.total += targets.size(0)
 
     def report(self):
-        return self.correct / self.total
+        return self.correct / (self.total or 1)
 
     def __repr__(self):
         return f'Top{self.k}: {self.report()}'
diff --git a/nbdt/model.py b/nbdt/model.py
index 2484750..963fb6e 100644
--- a/nbdt/model.py
+++ b/nbdt/model.py
@@ -57,14 +57,19 @@ def __init__(self,
         self.I = torch.eye(len(self.tree.classes))
 
     @staticmethod
-    def get_node_logits(outputs, node):
+    def get_node_logits(outputs, node=None,
+                        new_to_old_classes=None, num_classes=None):
         """Get output for a particular node
 
         The `outputs` above are the outputs of the neural network.
         """
+        assert node or (new_to_old_classes and num_classes), \
+            'Either pass node or (new_to_old_classes mapping and num_classes)'
+        new_to_old_classes = new_to_old_classes or node.child_index_to_class_index
+        num_classes = num_classes or node.num_classes
         return torch.stack([
-            outputs.T[node.child_index_to_class_index[child_index]].mean(dim=0)
-            for child_index in range(node.num_classes)
+            outputs.T[new_to_old_classes[child_index]].mean(dim=0)
+            for child_index in range(num_classes)
         ]).T
 
     @classmethod
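
For illustration only (not part of the patch): a minimal sketch of how the generalized `get_node_logits` coarsens fine-grained logits when given a `new_to_old_classes` mapping instead of a `node`. The class groupings below are hypothetical.

```python
import torch

# Hypothetical mapping from 2 coarse classes to 10 fine-grained class
# indices: coarse class 0 averages the logits of old classes 2-7,
# coarse class 1 averages old classes 0, 1, 8, 9.
new_to_old_classes = {0: [2, 3, 4, 5, 6, 7], 1: [0, 1, 8, 9]}
outputs = torch.randn(4, 10)  # batch of 4 samples, 10 fine-grained logits

coarse = torch.stack([
    outputs.T[new_to_old_classes[child_index]].mean(dim=0)
    for child_index in range(2)  # num_classes=2
]).T
print(coarse.shape)  # torch.Size([4, 2])
```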