✨ Merge pull request #56 from xuanlongORZ/evidential_cla
✨ Add evidential classification
o-laurent authored Oct 24, 2023
2 parents a13ba5b + cbe17f2 commit 99c935e
Showing 11 changed files with 448 additions and 28 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/run-tests.yml
@@ -79,7 +79,7 @@ jobs:
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
if: ${{ github.event_name != 'pull_request' }}
if: ${{ github.event_name != 'pull_request' || github.base_ref == 'dev' }}
continue-on-error: true
with:
token: ${{ secrets.CODECOV_TOKEN }}
3 changes: 2 additions & 1 deletion README.md
@@ -57,7 +57,7 @@ To date, the following deep learning baselines have been implemented:
- MIMO
- Packed-Ensembles (see [blog post](https://medium.com/@adrien.lafage/make-your-neural-networks-more-reliable-with-packed-ensembles-7ad0b737a873))
- Bayesian Neural Networks :construction: Work in progress :construction:
- Deep Evidential Regression
- Deep Evidential Classification & Regression

### Post-processing methods

@@ -74,6 +74,7 @@ We provide the following tutorials in our documentation:
- [Improve Top-label Calibration with Temperature Scaling](https://torch-uncertainty.github.io/auto_tutorials/tutorial_scaler.html)
- [Deep Evidential Regression on a Toy Example](https://torch-uncertainty.github.io/auto_tutorials/tutorial_der_cubic.html)
- [Training a LeNet with Monte-Carlo Dropout](https://torch-uncertainty.github.io/auto_tutorials/tutorial_mc_dropout.html)
- [Training a LeNet with Deep Evidential Classification](https://torch-uncertainty.github.io/auto_tutorials/tutorial_evidential_classification.html)

## Awesome Uncertainty repositories

187 changes: 187 additions & 0 deletions auto_tutorials_source/tutorial_evidential_classification.py
@@ -0,0 +1,187 @@
#!/usr/bin/env python
# coding: utf-8

"""
Deep Evidential Classification on a Toy Example
===============================================
This tutorial provides an introductory overview of Deep Evidential Classification (DEC) through a practical example: fitting the MNIST dataset with a LeNet network whose output is modeled as a Dirichlet distribution. The network is trained by minimizing the DEC loss function, composed of a Bayes-risk squared-error term and a regularization term based on the KL divergence.

DEC is an evidential approach to quantifying uncertainty in neural network classifiers. It places a prior distribution over the parameters of the Categorical likelihood, and the network estimates the parameters of this evidential (Dirichlet) distribution.

Training a LeNet with DEC using TorchUncertainty models
-------------------------------------------------------
In this section, we train a neural network based on the models and training routines already implemented in TorchUncertainty (TU).

1. Loading the utilities
~~~~~~~~~~~~~~~~~~~~~~~~
To train a LeNet with the DEC loss function using TorchUncertainty, we have to load the following utilities from TorchUncertainty:

- the cli handler cli_main and the argument parser init_args,
- the model: lenet, which lies in the torch_uncertainty.models.lenet module,
- the classification training routine, in the torch_uncertainty.routines.classification module,
- the evidential objective: DECLoss, which lies in the torch_uncertainty.losses module,
- the datamodule that handles the dataloaders: MNISTDataModule, which lies in the torch_uncertainty.datamodules module.
"""

# %%
from torch_uncertainty import cli_main, init_args
from torch_uncertainty.models.lenet import lenet
from torch_uncertainty.routines.classification import ClassificationSingle
from torch_uncertainty.losses import DECLoss
from torch_uncertainty.datamodules import MNISTDataModule


# %%
# We also need to define an optimizer using torch.optim, the neural network
# utilities from torch.nn, and the partial function from functools, which
# provides the modified default arguments for the DEC loss.
#
# We also import ArgvContext to avoid passing the Jupyter arguments as CLI
# arguments, which would cause errors.

import os
from functools import partial
from pathlib import Path

import torch
from cli_test_helpers import ArgvContext
from torch import nn, optim


# %%
# 2. Creating the Optimizer Wrapper
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# We follow the official DEC implementation and use the Adam optimizer with
# a learning rate of 1e-3, a weight decay of 5e-3, and a step learning-rate
# scheduler.
def optim_lenet(model: nn.Module) -> dict:
    optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=0.005)
    exp_lr_scheduler = optim.lr_scheduler.StepLR(
        optimizer, step_size=7, gamma=0.1
    )
    return {"optimizer": optimizer, "lr_scheduler": exp_lr_scheduler}


# %%
# 3. Creating the necessary variables
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# In the following, we define the root of the logs and fake-parse the
# arguments needed by the PyTorch Lightning Trainer. We use the same MNIST
# classification setup as the original DEC paper, but train for only 5 epochs
# for the sake of time.
root = Path(os.path.abspath(""))

# We mock the arguments for the trainer. Replace with 25 epochs on your machine.
with ArgvContext(
    "file.py",
    "--max_epochs",
    "5",
    "--enable_progress_bar",
    "True",
):
    args = init_args(datamodule=MNISTDataModule)

net_name = "dec-lenet-mnist"

# datamodule
args.root = str(root / "data")
dm = MNISTDataModule(**vars(args))


model = lenet(
    in_channels=dm.num_channels,
    num_classes=dm.num_classes,
)

# %%
# 4. The Loss and the Training Routine
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Next, we define the loss to be used during training. To do so, we set the
# default parameters of the DEC loss using the partial function from
# functools. We then define the training routine using ClassificationSingle,
# the single-model classification routine from
# torch_uncertainty.routines.classification, to which we provide the model,
# the DEC loss, the optimization procedure, and the default arguments.

loss = partial(
    DECLoss,
    reg_weight=1e-2,
)
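The routine apparently expects a loss constructor rather than a loss instance: it calls loss() itself when training starts, which is why partial is used to pre-set keyword arguments. A minimal sketch of this deferred-construction pattern, where ToyLoss is a hypothetical stand-in for DECLoss:

```python
from functools import partial

class ToyLoss:
    # Hypothetical stand-in for DECLoss, used only to show the pattern.
    def __init__(self, reg_weight: float = 0.0) -> None:
        self.reg_weight = reg_weight

# Pre-set keyword arguments now; the routine instantiates the loss later.
loss_ctor = partial(ToyLoss, reg_weight=1e-2)
instance = loss_ctor()
print(instance.reg_weight)  # 0.01
```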

baseline = ClassificationSingle(
    model=model,
    num_classes=dm.num_classes,
    in_channels=dm.num_channels,
    loss=loss,
    optimization_procedure=optim_lenet,
    **vars(args),
)

# %%
# 5. Gathering Everything and Training the Model
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

results = cli_main(baseline, dm, root, net_name, args)

# %%
# 6. Testing the Model
# ~~~~~~~~~~~~~~~~~~~~
# Now that the model is trained, let's test it on MNIST.

import matplotlib.pyplot as plt
import torch
import torchvision
import torchvision.transforms.functional as F

import numpy as np


def imshow(img) -> None:
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()


def rotated_mnist(angle: int) -> None:
    """Rotate the MNIST images and show the images and the confidence.

    Args:
        angle: Rotation angle in degrees.
    """
    rotated_images = F.rotate(images, angle)
    # plot the rotated images
    imshow(torchvision.utils.make_grid(rotated_images[:4, ...]))
    print("Ground truth: ", " ".join(f"{labels[j]}" for j in range(4)))

    evidence = baseline(rotated_images)
    alpha = torch.relu(evidence) + 1
    strength = torch.sum(alpha, dim=1, keepdim=True)
    probs = alpha / strength
    entropy = -1 * torch.sum(probs * torch.log(probs), dim=1, keepdim=True)
    for j in range(4):
        predicted = torch.argmax(probs[j, :])
        print(
            f"Predicted digit for image {j}: {predicted} with strength "
            f"{strength[j, 0]:.3} and entropy {entropy[j, 0]:.3}."
        )


dataiter = iter(dm.val_dataloader())
images, labels = next(dataiter)

with torch.no_grad():
    baseline.eval()
    rotated_mnist(0)
    rotated_mnist(45)
    rotated_mnist(90)
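The quantities computed in rotated_mnist have a limiting case worth keeping in mind: with zero evidence, the Dirichlet is uniform, the probabilities are all 1/K, and the predictive entropy reaches its maximum, log(K). A quick standalone check under that assumption:

```python
import math

K = 10                                        # number of MNIST classes
evidence = [0.0] * K                          # no evidence collected
alpha = [max(e, 0.0) + 1 for e in evidence]   # relu(evidence) + 1
strength = sum(alpha)                         # Dirichlet strength = K
probs = [a / strength for a in alpha]         # each 1/K
entropy = -sum(p * math.log(p) for p in probs)
print(entropy)  # log(10) ≈ 2.303, i.e. maximal uncertainty
```

Rotating the digits should push the model toward this regime: as the inputs drift away from the training distribution, strength drops and entropy rises.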


# %%
# References
# ----------
#
# - **Deep Evidential Classification:** Murat Sensoy, Lance Kaplan, and Melih Kandemir (2018). Evidential Deep Learning to Quantify Classification Uncertainty. `NeurIPS 2018 <https://arxiv.org/pdf/1806.01768>`_.
4 changes: 1 addition & 3 deletions auto_tutorials_source/tutorial_mc_dropout.py
@@ -30,7 +30,6 @@

from torch_uncertainty import cli_main, init_args
from torch_uncertainty.datamodules import MNISTDataModule
from torch_uncertainty.baselines import ResNet
from torch_uncertainty.models.lenet import lenet
from torch_uncertainty.routines.classification import ClassificationEnsemble
from torch_uncertainty.optimization_procedures import optim_cifar10_resnet18
@@ -42,7 +41,6 @@
# arguments, and therefore avoid errors.

import os
from functools import partial
from pathlib import Path

from torch import nn
@@ -77,7 +75,7 @@
"--num_estimators",
"16",
):
args = init_args(network=ResNet, datamodule=MNISTDataModule)
args = init_args(datamodule=MNISTDataModule)

net_name = "mc-dropout-lenet-mnist"

1 change: 1 addition & 0 deletions docs/source/api.rst
@@ -212,6 +212,7 @@ Losses
KLDiv
ELBOLoss
NIGLoss
DECLoss

Post-Processing Methods
-----------------------
11 changes: 11 additions & 0 deletions docs/source/references.rst
@@ -8,6 +8,17 @@ Uncertainty Models

The following uncertainty models are implemented.

Deep Evidential Classification
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

For Deep Evidential Classification, consider citing:

**Evidential Deep Learning to Quantify Classification Uncertainty**

* Authors: *Murat Sensoy, Lance Kaplan, Melih Kandemir*
* Paper: `NeurIPS 2018 <https://arxiv.org/pdf/1806.01768>`__.


Deep Evidential Regression
^^^^^^^^^^^^^^^^^^^^^^^^^^

2 changes: 1 addition & 1 deletion experiments/classification/cifar100/wideresnet.py
@@ -28,7 +28,7 @@
in_channels=dm.num_channels,
loss=nn.CrossEntropyLoss,
optimization_procedure=get_procedure(
f"wideresnet28x10", "cifar100", args.version
"wideresnet28x10", "cifar100", args.version
),
style="cifar",
**vars(args),
10 changes: 2 additions & 8 deletions tests/routines/test_classification.py
@@ -7,7 +7,7 @@
from torch import nn

from torch_uncertainty import cli_main, init_args
from torch_uncertainty.losses import ELBOLoss
from torch_uncertainty.losses import DECLoss, ELBOLoss
from torch_uncertainty.optimization_procedures import optim_cifar10_resnet18
from torch_uncertainty.routines.classification import (
ClassificationEnsemble,
@@ -31,7 +31,6 @@ def test_cli_main_dummy_binary(self):
DummyClassificationBaseline, DummyClassificationDataModule
)

# datamodule
args.root = str(root / "data")
dm = DummyClassificationDataModule(num_classes=1, **vars(args))

@@ -43,7 +42,6 @@ def test_cli_main_dummy_binary(self):
baseline_type="single",
**vars(args),
)

cli_main(model, dm, root, "dummy", args)

def test_cli_main_dummy_ood(self):
@@ -53,7 +51,6 @@ def test_cli_main_dummy_ood(self):
DummyClassificationBaseline, DummyClassificationDataModule
)

# datamodule
args.root = str(root / "data")
dm = DummyClassificationDataModule(**vars(args))
loss = partial(
@@ -70,27 +67,24 @@
baseline_type="single",
**vars(args),
)

cli_main(model, dm, root, "dummy", args)

with ArgvContext("file.py", "--evaluate_ood", "--entropy"):
args = init_args(
DummyClassificationBaseline, DummyClassificationDataModule
)

# datamodule
args.root = str(root / "data")
dm = DummyClassificationDataModule(**vars(args))

model = DummyClassificationBaseline(
num_classes=dm.num_classes,
in_channels=dm.num_channels,
loss=nn.CrossEntropyLoss,
loss=DECLoss,
optimization_procedure=optim_cifar10_resnet18,
baseline_type="single",
**vars(args),
)

cli_main(model, dm, root, "dummy", args)

def test_classification_failures(self):
