diff --git a/images/regression_ensemble.png b/images/regression_ensemble.png new file mode 100644 index 0000000..8ce1939 Binary files /dev/null and b/images/regression_ensemble.png differ diff --git a/images/triangular_simplex.png b/images/triangular_simplex.png new file mode 100644 index 0000000..76a635e Binary files /dev/null and b/images/triangular_simplex.png differ diff --git a/src/data/__pycache__/toy_loading.cpython-37.pyc b/src/data/__pycache__/toy_loading.cpython-37.pyc new file mode 100644 index 0000000..7079e1b Binary files /dev/null and b/src/data/__pycache__/toy_loading.cpython-37.pyc differ diff --git a/src/data/toy_loading.py b/src/data/toy_loading.py new file mode 100644 index 0000000..fb88502 --- /dev/null +++ b/src/data/toy_loading.py @@ -0,0 +1,136 @@ +import pandas as pd +import numpy as np +import torch +from torch.utils import data as tdata + + +def get_toy_dataset( + target_generator_fn, + noise_generator_fn, + train_limits=(-1.0, 1.0), + test_limits=(-1.5, 1.5), ood_abs_limits=(1.1, 1.3), + train_samples=100, test_samples=200, + ood_samples=40, random_state=0 +): + """Generates one-dimensional regression dataset""" + np.random.seed(random_state) + x_train = np.random.uniform(train_limits[0], train_limits[1], (train_samples,)) + + y_train = target_generator_fn(x_train) + np.random.seed(random_state) + y_noise = noise_generator_fn(x_train) * np.random.randn(y_train.shape[0]) + y_train += y_noise + + np.random.seed(random_state) + x_ood_1 = np.random.uniform(ood_abs_limits[0], ood_abs_limits[1], (ood_samples // 2,)) + np.random.seed(random_state) + x_ood_2 = np.random.uniform(-ood_abs_limits[1], -ood_abs_limits[0], (ood_samples // 2,)) + x_ood = np.concatenate([x_ood_1, x_ood_2], axis=0) + + y_ood = target_generator_fn(x_ood) + np.random.seed(random_state) + y_ood += noise_generator_fn(x_ood) * np.random.randn(y_ood.shape[0]) + + x_test = np.linspace(test_limits[0], test_limits[1], test_samples) + y_test = target_generator_fn(x_test) + + 
train_data, test_data, ood_data = [ + tdata.TensorDataset( + torch.Tensor(x_c).unsqueeze(1), + torch.Tensor(y_c).unsqueeze(1) + ) for (x_c, y_c) in zip( + [x_train, x_test, x_ood], [y_train, y_test, y_ood] + ) + ] + return train_data, test_data, ood_data, y_noise + +def get_arrays_from_loader(loader): + first_elems = [] + second_elems = [] + for item in loader: + first_elems += [item[0]] + second_elems += [item[1]] + return torch.cat(first_elems, dim=0), torch.cat(second_elems, dim=0) + +def get_table_loaders( + train_data, test_data, batch_size, ood_data=None, ood_test_data=None, + ood_batch_size=None, shuffle=True, normalize_targets=False, target_id=-1, +): + feature_len = train_data.shape[1] - 1 + if target_id == -1: + x_train, y_train = train_data[:, :-1], train_data[:, -1:] + x_test, y_test = test_data[:, :-1], test_data[:, -1:] + elif target_id > -1: + idxs = list(range(train_data.shape[1])) + idxs.pop(target_id) + x_train, y_train = train_data[:, idxs],\ + train_data[:, target_id].reshape(-1,1) + x_test, y_test = test_data[:, idxs],\ + test_data[:, target_id].reshape(-1,1) + else: + raise ValueError("Provide target_id >= -1") + + # Normalize train/test features & targets (if necessary) + x_means, x_stds = x_train.mean(axis=0), x_train.std(axis=0) + if normalize_targets: + y_means, y_stds = y_train.mean(axis=0), y_train.std(axis=0) + else: + y_means, y_stds = 0., 1. 
+ x_train = (x_train - x_means) / x_stds + y_train = (y_train - y_means) / y_stds + x_test = (x_test - x_means) / x_stds + y_test = (y_test - y_means) / y_stds + # Normalize ood features + if ood_data is not None: + if target_id > -1: + idxs = list(range(train_data.shape[1])) + idxs.pop(target_id) + x_ood = ood_data[:, idxs] + x_ood = (x_ood - x_means) / x_stds + else: + x_ood = ood_data[:, :feature_len] + x_ood = (x_ood - x_means) / x_stds + assert not np.isnan(x_ood).any() + if ood_test_data is not None: + if target_id > -1: + idxs = list(range(train_data.shape[1])) + idxs.pop(target_id) + x_ood_test = ood_test_data[:, idxs] + x_ood_test = (x_ood_test - x_means) / x_stds + else: + x_ood_test = ood_test_data[:, :feature_len] + x_ood_test = (x_ood_test - x_means) / x_stds + assert not np.isnan(x_ood_test).any() + + assert not np.isnan(y_test).any() + assert not np.isnan(y_train).any() + assert not np.isnan(x_test).any() + assert not np.isnan(x_train).any() + ood_loader = None + ood_test_loader = None + # Initialize loaders + train_loader = tdata.DataLoader( + tdata.TensorDataset( + torch.Tensor(x_train), torch.Tensor(y_train) + ), + batch_size=batch_size, + shuffle=shuffle + ) + test_loader = tdata.DataLoader( + tdata.TensorDataset( + torch.Tensor(x_test), torch.Tensor(y_test) + ), + batch_size=batch_size, shuffle=False + ) + if ood_data is not None: + ood_loader = tdata.DataLoader( + tdata.TensorDataset(torch.Tensor(x_ood)), + batch_size=ood_batch_size, shuffle=shuffle + ) + if ood_test_data is not None: + ood_test_loader = tdata.DataLoader( + tdata.TensorDataset(torch.Tensor(x_ood_test)), + batch_size=ood_batch_size, shuffle=False + ) + return train_loader, test_loader, ood_loader, ood_test_loader,\ + [torch.FloatTensor([y_means]), torch.FloatTensor([y_stds])] diff --git a/src/distributions/__pycache__/distributions.cpython-37.pyc b/src/distributions/__pycache__/distributions.cpython-37.pyc new file mode 100644 index 0000000..a437025 Binary files /dev/null and 
b/src/distributions/__pycache__/distributions.cpython-37.pyc differ diff --git a/src/distributions/__pycache__/mixture_distribution.cpython-37.pyc b/src/distributions/__pycache__/mixture_distribution.cpython-37.pyc new file mode 100644 index 0000000..433e137 Binary files /dev/null and b/src/distributions/__pycache__/mixture_distribution.cpython-37.pyc differ diff --git a/src/distributions/__pycache__/prior_distribution.cpython-37.pyc b/src/distributions/__pycache__/prior_distribution.cpython-37.pyc new file mode 100644 index 0000000..8272e92 Binary files /dev/null and b/src/distributions/__pycache__/prior_distribution.cpython-37.pyc differ diff --git a/src/distributions/distributions.py b/src/distributions/distributions.py new file mode 100644 index 0000000..3a7f366 --- /dev/null +++ b/src/distributions/distributions.py @@ -0,0 +1,270 @@ +import math +import torch + +from torch.distributions import constraints, Distribution, Normal +from torch.distributions import register_kl +from torch.distributions.kl import kl_divergence +from torch.distributions.independent import Independent + +from src.utils.func_utils import mvdigamma, rel_error + + +class DiagonalWishart(Distribution): + r""" + Creates a diagonal version of Wishart distribution parameterized + by its scale :attr:`scale_diag` and degrees of freedom :attr:`df`. + + Args: + scale_diag (Tensor) (or L): scale of the distribution with shapes (bs, ..., p), + where p is the dimensionality of a distribution. + df (Tensor) (or \nu): degrees of freedom with shapes (bs, ...). It should have + the same shape as :attr:`scale_diag`, but without last dim. 
+ """ + arg_constraints = {'scale_diag': constraints.positive, 'df': constraints.positive} + support = constraints.positive + has_rsample = False + + def __init__(self, scale_diag, df, validate_args=True): + if scale_diag.dim() < 1 or df.dim() < 1: + raise ValueError("scale_diag or df must be at least one-dimensional.") + if df.size(-1) == 1 and scale_diag.size(-1) != 1: + raise ValueError("df shouldn't end with dimensionality 1 if scale_diag doesn't") + df_ = df.unsqueeze(-1) # add dim on right + self.scale_diag, df_ = torch.broadcast_tensors(scale_diag, df_) + self.df = df_[..., 0] # drop rightmost dim + + batch_shape, event_shape = self.scale_diag.shape[:-1], self.scale_diag.shape[-1:] + self.dimensionality = event_shape.numel() + if (self.df <= (self.dimensionality - 1)).any(): + raise ValueError("df must be greater than dimensionality - 1") + super(DiagonalWishart, self).__init__(batch_shape, event_shape, validate_args=validate_args) + + @property + def mean(self): + return self.df.unsqueeze(-1) * self.scale_diag + + @property + def variance(self): + return 2 * self.df.unsqueeze(-1) * self.scale_diag.pow(2) + + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + + tr_term = -0.5 * torch.div(value, self.scale_diag).sum(dim=-1) + norm_term = 0.5 * (self.df - self.dimensionality - 1) * torch.log(value).sum(dim=-1) + + return -self.log_normalizer() + norm_term + tr_term + + def log_normalizer(self): + log_normalizer_1 = 0.5 * self.df * self.dimensionality * math.log(2) + log_normalizer_2 = 0.5 * self.df * self.scale_diag.log().sum(dim=-1) + log_normalizer_3 = torch.mvlgamma(0.5 * self.df, self.dimensionality) + return log_normalizer_1 + log_normalizer_2 + log_normalizer_3 + + def log_expectation(self): + mvdigamma_term = mvdigamma(0.5 * self.df, self.dimensionality) + other_terms = self.dimensionality * math.log(2) + torch.log(self.scale_diag).sum(dim=-1) + return mvdigamma_term + other_terms + + def entropy(self): + return 
self.log_normalizer() - 0.5 * (self.df - self.dimensionality - 1) * self.log_expectation()\ + + 0.5 * self.df * self.dimensionality + + +class NormalDiagonalWishart(Distribution): + r""" + Creates a diagonal version of Normal-Wishart distribution parameterized + by its mean :attr:`mean`, diagonal precision :attr:`precision_diag`, + degrees of freedom :attr:`df` and belief in mean :attr:`belief` + + Args: + loc (Tensor) (or m): location of the distribution with shapes (bs, ..., p), + where p is dimensionality of the distribution. + precision_diag (Tensor or float) (or L): precision of the distribution with shapes (bs, ..., p), + where p is dimensionality of the distribution. It should have the same shape + as :attr:`mean`. + belief (Tensor or float) (or \kappa): confidence of belief in mean with shapes (bs, ...). + It should have the same shape as :attr:`mean`, but without last dim. + df (Tensor or float) (or \nu): degrees of freedom with shapes (bs, ...). It should have + the same shape as :attr:`mean`, but without last dim. 
+ """ + arg_constraints = { + 'precision_diag': constraints.positive, + 'belief': constraints.positive, + 'df': constraints.positive, + } + support = constraints.real + has_rsample = False + + def __init__(self, loc, precision_diag, belief, df, validate_args=True): + precision_diag, belief, df = self.convert_float_params_to_tensor( + loc, precision_diag, belief, df + ) + if loc.dim() < 1 or precision_diag.dim() < 1 or df.dim() < 1 or belief.dim() < 1: + raise ValueError("loc, precision_diag, df, belief must be at least one-dimensional.") + if belief.size(-1) == 1 and precision_diag.size(-1) != 1: + raise ValueError("belief shouldn't end with dimensionality 1 if precision_diag doesn't") + if df.size(-1) == 1 and precision_diag.size(-1) != 1: + raise ValueError("df shouldn't end with dimensionality 1 if precision_diag doesn't") + df_, belief_ = df.unsqueeze(-1), belief.unsqueeze(-1) # add dim on right + self.loc, self.precision_diag, df_, belief_ = torch.broadcast_tensors(loc, precision_diag, df_, belief_) + self.df, self.belief = df_[..., 0], belief_[..., 0] # drop rightmost dim + + batch_shape, event_shape = self.loc.shape[:-1], self.loc.shape[-1:] + self.dimensionality = event_shape.numel() + if (self.df <= (self.dimensionality + 1)).any(): + raise ValueError("df must be greater than dimensionality + 1 to have expectation") + super(NormalDiagonalWishart, self).__init__( + batch_shape, event_shape, validate_args=validate_args + ) + + def log_prob(self, value_mean, value_precision): + if self._validate_args: + self._validate_sample(value_mean) + self._validate_sample(value_precision) + if (value_precision <= 0).any(): + raise ValueError("desired precision must be greater that 0") + wishart_log_prob = DiagonalWishart( + self.precision_diag, self.df + ).log_prob(value_precision) + normal_log_prob = Independent( + Normal( + self.loc, ( + 1 / (self.belief.unsqueeze(-1) * value_precision) + ).pow(0.5) + ), 1 + ).log_prob(value_mean) + return normal_log_prob + 
wishart_log_prob + + def expectation_entropy_normal(self): + return 0.5 * ( + self.dimensionality * ( + 1 + math.log(2 * math.pi) + ) - torch.log( + 2 * self.precision_diag * self.belief.unsqueeze(-1) + ).sum(dim=-1) + - mvdigamma(0.5 * self.df, self.dimensionality) + ) + + def entropy(self): + wishart_entropy = DiagonalWishart(self.precision_diag, self.df).entropy() + expectation_entropy_normal = self.expectation_entropy_normal() + return wishart_entropy + expectation_entropy_normal + + def convert_float_params_to_tensor(self, loc, precision_diag, belief, df): + if isinstance(precision_diag, float): + precision_diag = precision_diag * torch.ones_like(loc).to(loc.device) + if isinstance(belief, float): + belief = belief * torch.ones_like(loc).to(loc.device)[..., 0] + if isinstance(df, float): + df = df * torch.ones_like(loc).to(loc.device)[..., 0] + return precision_diag, belief, df + + +@register_kl(DiagonalWishart,DiagonalWishart) +def kl_diag_wishart(p: DiagonalWishart, q: DiagonalWishart): + if p.event_shape != q.event_shape: + raise ValueError("KL-divergence between two Diagonal Wisharts with\ + different event shapes cannot be computed") + log_det_term = -(0.5 * q.df) * torch.div( + p.scale_diag, q.scale_diag + ).log().sum(dim=-1) + tr_term = (0.5 * p.df) * ( + torch.div(p.scale_diag, q.scale_diag).sum(dim=-1) - p.dimensionality + ) + mvlgamma_term = torch.mvlgamma(0.5 * q.df, q.dimensionality) - torch.mvlgamma(0.5 * p.df, p.dimensionality) + digamma_term = 0.5 * (p.df - q.df) * mvdigamma(0.5 * p.df, p.dimensionality) + return log_det_term + tr_term + mvlgamma_term + digamma_term + + +@register_kl(NormalDiagonalWishart, NormalDiagonalWishart) +def kl_normal_diag_wishart(p: NormalDiagonalWishart, q: NormalDiagonalWishart): + if p.event_shape != q.event_shape: + raise ValueError("KL-divergence between two Normal Diagonal Wisharts with\ + different event shapes cannot be computed") + + wishart_KL = kl_divergence( + DiagonalWishart(p.precision_diag, p.df), + 
DiagonalWishart(q.precision_diag, q.df) + ) + weighted_mse_term = torch.sum( + 0.5 * q.belief.unsqueeze(-1) *\ + (p.loc - q.loc).pow(2) * p.precision_diag * p.df.unsqueeze(-1), + dim=-1 + ) + expected_conditioned_normal_KL = ( + weighted_mse_term + (0.5 * p.dimensionality) * ( + torch.div(q.belief, p.belief) - torch.div(q.belief, p.belief).log() - 1 + ) + ) + + return expected_conditioned_normal_KL + wishart_KL + + +if __name__ == '__main__': + import numpy as np + from scipy.stats import wishart + x = np.linspace(1e-6, 20, 100) + + print("Testing wishart entropy/logprob vs scipy implementation...") + for k in range(1000): + df_val = torch.randn(1).exp() + 2 + scale_val = torch.randn(1).exp() + + scipy_dist = wishart(df=df_val.item(), scale=scale_val.item()) + torch_dist = DiagonalWishart( + scale_val.unsqueeze(-1), + df_val + ) + + torch_ent = torch_dist.entropy()[0] + scipy_ent = torch.FloatTensor([scipy_dist.entropy()]) + if (rel_error(torch_ent, scipy_ent) > 1e-3).any(): + raise ValueError("Entropies of torch and scipy versions doesn't match") + + scipy_w = torch.FloatTensor(scipy_dist.logpdf(x)) + torch_w = torch_dist.log_prob(torch.FloatTensor(x).unsqueeze(-1)) + + if (rel_error(torch_w, scipy_w) > 1e-6).any(): + raise ValueError("Log pdf of torch and scipy versions doesn't match") + + print("Testing wishart KL divergence...") + df1, scale1 = torch.randn(32).exp() + 2, torch.randn(32).exp() + 1e-5 + df2, scale2 = torch.randn(32).exp() + 2, torch.randn(32).exp() + 1e-5 + init_df1, init_scale1 = df1[0].clone(), scale1[0].clone() + dist2 = DiagonalWishart(scale2.unsqueeze(-1), df2) + df1.requires_grad, scale1.requires_grad = True, True + gamma = 0.1 + for k in range(10000): + dist1 = DiagonalWishart(scale1.unsqueeze(-1), df1) + loss = kl_divergence(dist1, dist2).mean() + if k % 1000 == 0: + print(k, loss.item()) + loss.backward() + with torch.no_grad(): + scale1 = scale1 - gamma * scale1.grad + df1 = df1 - gamma * df1.grad + scale1.requires_grad, 
df1.requires_grad = True, True + print('df1 init', init_df1, init_scale1) + print('df1 final', df1[0], scale1[0]) + print('df2', df2[0], scale2[0]) + + print("All tests passed.") + + print("Testing normal wishart...") + y = np.linspace(5, 20, 100) + torch_dist = NormalDiagonalWishart( + torch.tensor([10]).float().view(1, 1).repeat(100, 1), + torch.tensor([2.57]).float().view(1, 1).repeat(100, 1), + torch.tensor([0.7]).float().repeat(100), + torch.tensor([3.33]).float().repeat(100), + ) + + ex_w = torch_dist.log_prob( + torch.FloatTensor(x).unsqueeze(-1), + torch.FloatTensor(y).unsqueeze(-1) + ) + print(ex_w.shape) + ex_w = torch_dist.entropy()[0] + #print(ex_w) diff --git a/src/distributions/mixture_distribution.py b/src/distributions/mixture_distribution.py new file mode 100644 index 0000000..67efe71 --- /dev/null +++ b/src/distributions/mixture_distribution.py @@ -0,0 +1,78 @@ +import torch +from itertools import combinations +from torch.distributions import Distribution, Normal, kl_divergence + + +class GaussianDiagonalMixture(Distribution): + r""" + Creates a mixture of diagonal Normal distributions parameterized + by their means :attr:`means` and scales :attr:`scales`. 
+ """ + def __init__(self, means, scales): + assert len(means) == len(scales) + assert means[0].size(-1) == 1 and scales[0].size(-1) == 1 + + self.distributions = [] + for i in range(len(means)): + self.distributions.append( + Normal(means[i], scales[i], validate_args=True) + ) + + def expected_mean(self): + return sum([dist.mean for dist in self.distributions]) / len(self.distributions) + + def expected_entropy(self): + return sum([dist.entropy().squeeze() for dist in self.distributions]) / len(self.distributions) + + def expected_pairwise_kl(self): + curr_sum_pairwise_kl = None + num_pairs = 0 + + for dist1, dist2 in combinations(self.distributions, r=2): + num_pairs += 1 + if curr_sum_pairwise_kl is None: + curr_sum_pairwise_kl = kl_divergence(dist1, dist2) + else: + curr_sum_pairwise_kl += kl_divergence(dist1, dist2) + return curr_sum_pairwise_kl.squeeze() / num_pairs + + def variance_of_expected(self): + avg_mean = self.expected_mean() + return sum([(dist.mean.pow(2) - avg_mean.pow(2)).squeeze() for dist in self.distributions]) / len(self.distributions) + + def log_variance_of_expected(self): + return self.variance_of_expected().log() + + def expected_variance(self): + return sum([dist.variance.squeeze() for dist in self.distributions]) / len(self.distributions) + + def log_expected_variance(self): + return self.expected_variance().log() + + def total_variance(self): + return self.variance_of_expected() + self.expected_variance() + + def log_total_variance(self): + return self.total_variance().log() + + def estimated_total_entropy(self): + return self.expected_entropy() + self.expected_pairwise_kl() + + def log_prob(self, value): + mean = self.expected_mean() + var = self.total_variance().unsqueeze(-1) + return Normal(mean, var.pow(0.5)).log_prob(value) + + +if __name__ == "__main__": + ex_means = [torch.ones(32, 1) for _ in range(5)] + ex_vars = [2 * torch.ones(32, 1) for _ in range(5)] + mixture_dis = GaussianDiagonalMixture(ex_means, ex_vars) + 
print(mixture_dis.expected_mean().shape) + print(mixture_dis.log_prob(torch.zeros(32, 1)).shape) + + print(mixture_dis.expected_entropy().shape) + print(mixture_dis.expected_pairwise_kl().shape) + print(mixture_dis.variance_of_expected().shape) + print(mixture_dis.expected_variance().shape) + print(mixture_dis.total_variance().shape) diff --git a/src/distributions/prior_distribution.py b/src/distributions/prior_distribution.py new file mode 100644 index 0000000..1344aee --- /dev/null +++ b/src/distributions/prior_distribution.py @@ -0,0 +1,122 @@ +import math, torch + +from torch.distributions import StudentT +from .distributions import NormalDiagonalWishart +from src.utils.func_utils import mvdigamma + + +class NormalWishartPrior(NormalDiagonalWishart): + + def forward(self): + self.precision_coeff = (self.belief + 1) / ( + self.belief * (self.df - self.dimensionality + 1) + ) + return StudentT( + (self.df - self.dimensionality + 1).unsqueeze(-1), + loc=self.loc, + scale=(self.precision_coeff.unsqueeze(-1) / self.precision_diag).pow(0.5), + ) + + def predictive_posterior_log_prob(self, value): + return self.forward().log_prob(value) + + def predictive_posterior_variance(self): + variance_res = self.forward().variance + if variance_res.size(-1) != 1: + raise ValueError("Predictive posterior returned entropy with incorrect shapes") + return variance_res[..., 0] + + def log_predictive_posterior_variance(self): + return self.predictive_posterior_variance().log() + + def predictive_posterior_entropy(self): + entropy_res = self.forward().entropy() + if entropy_res.size(-1) != 1: + raise ValueError("Predictive posterior returned entropy with incorrect shapes") + return entropy_res[..., 0] + + def entropy_ub(self): + return self.expected_pairwise_kl() + self.expected_entropy() + + def expected_entropy(self): + mvdigamma_term = mvdigamma(0.5 * self.df, self.dimensionality) + return 0.5 * ( + self.dimensionality * (1 + math.log(2 * math.pi)) + - (2 * 
self.precision_diag).log().sum(dim=-1) + - mvdigamma_term.squeeze() + ) + + def expected_log_prob(self, value): + neg_mse_term = -torch.sum( + (self.loc - value).pow(2) * self.precision_diag * self.df.unsqueeze(-1), + dim = -1 + ) + mvdigamma_term = mvdigamma(0.5 * self.df, self.dimensionality) + + reg_terms = (2 * self.precision_diag).log().sum(dim=-1) + mvdigamma_term + conf_term = -self.dimensionality * self.belief.pow(-1) + return 0.5 * (neg_mse_term + reg_terms + conf_term) + + def mutual_information(self): + predictive_posterior_entropy = self.predictive_posterior_entropy() + expected_entropy = self.expected_entropy() + return predictive_posterior_entropy - expected_entropy + + def expected_pairwise_kl(self): + term1 = 0.5 * ( + self.df * self.dimensionality / (self.df - self.dimensionality - 1) -\ + self.dimensionality + ) + term2 = 0.5 * ( + self.df * self.dimensionality / (self.df - self.dimensionality - 1) +\ + self.dimensionality + ) / self.belief + return term1 + term2 + + def variance_of_expected(self): + return self.expected_variance() / self.belief + + def log_variance_of_expected(self): + return self.variance_of_expected().log() + + def expected_variance(self): + result = 1 / (self.precision_diag * (self.df.unsqueeze(-1) - self.dimensionality - 1)) + if result.size(-1) != 1: + raise ValueError("Expected variance currently supports only one-dimensional targets") + + return result[..., 0] + + def log_expected_variance(self): + return self.expected_variance().log() + + def total_variance(self): + tv = self.variance_of_expected() + self.expected_variance() + ppv = self.predictive_posterior_variance() + + rel_diff = (tv - ppv).abs() / tv.abs().pow(0.5) / ppv.abs().pow(0.5) + assert (rel_diff < 1e-6).all() + return tv + + def log_total_variance(self): + return self.total_variance().log() + + +if __name__ == '__main__': + ex_mean = torch.zeros(32, 200, 400, 1) + ex_var = torch.ones(32, 200, 400, 1) + ex_belief = torch.ones(32, 200, 400) + ex_df = 10 * 
torch.ones(32, 200, 400) + + ex_dist = NormalWishartPrior(ex_mean, ex_var, ex_belief, ex_df) + print(ex_dist.predictive_posterior_log_prob(2 * torch.ones(32, 200, 400, 1)).shape) + print(ex_dist.log_prob(2 * torch.ones(32, 200, 400, 1), 2 * torch.ones(32, 200, 400, 1)).shape) + + print(ex_dist.predictive_posterior_entropy().shape) #Total + print(ex_dist.expected_entropy().shape) #Data + print(ex_dist.mutual_information().shape) #Knowledge + print(ex_dist.expected_pairwise_kl().shape) #Knowledge + print(ex_dist.variance_of_expected().shape) #Knowledge + print(ex_dist.expected_variance().shape) #Data + print(ex_dist.total_variance().shape) #Total + print(ex_dist.predictive_posterior_variance().shape) #Total + diff --git a/src/models/__pycache__/simple_model.cpython-37.pyc b/src/models/__pycache__/simple_model.cpython-37.pyc new file mode 100644 index 0000000..aed213d Binary files /dev/null and b/src/models/__pycache__/simple_model.cpython-37.pyc differ diff --git a/src/models/simple_model.py b/src/models/simple_model.py new file mode 100644 index 0000000..6ccb508 --- /dev/null +++ b/src/models/simple_model.py @@ -0,0 +1,81 @@ +import torch +import torch.nn as nn + + +class GaussianNoise(nn.Module): + def __init__(self, mean=0.0, sigma=0.05): + super(GaussianNoise, self).__init__() + self.mean = mean + self.sigma = sigma + + def forward(self, input): + if not self.training: + return input + noise = input.clone().normal_(self.mean, self.sigma) + return input + noise + + +class SimpleModel(nn.Module): + def __init__( + self, input_dim, output_dim, num_units, + num_hidden=1, activation=nn.ReLU, isPrior=False, + drop_rate=0.0, use_bn=False, noise_level=0.05 + ): + super(SimpleModel, self).__init__() + + self.input_dim = input_dim + self.output_dim = output_dim + self.num_hidden = num_hidden + self.isPrior = isPrior + self.use_bn = use_bn + + # dense network with %num_units hidden layers + self.features, curr_dim = [], input_dim + 
self.features.append(GaussianNoise(sigma=noise_level)) + for _ in range(num_hidden): + self.features.append(nn.Linear(curr_dim, num_units)) + if self.use_bn: + self.features.append(nn.BatchNorm1d(num_units)) + self.features.append(activation()) + if drop_rate > 0.0: + self.features.append(nn.Dropout(drop_rate)) + curr_dim = num_units + self.features = nn.Sequential(*self.features) + + # generate stats of output distribution + self.layer_mean = nn.Linear(num_units, output_dim) + self.layer_std = nn.Sequential( + nn.Linear(num_units, output_dim), + nn.Softplus() + ) + if isPrior: + self.layer_beta = nn.Sequential( + nn.Linear(num_units, 1), + nn.Softplus() + ) + + self._initialize_weights() + + def forward(self, x): + x = x.view(-1, self.input_dim) + x = self.features(x) + + mean = self.layer_mean(x) + std = self.layer_std(x) + 1e-6 + if self.isPrior: + beta = self.layer_beta(x) + 1e-6 + kappa = beta + nu = beta + self.output_dim + 1 + return mean, std, kappa[..., 0], nu[..., 0] + else: + return mean, std + + def _initialize_weights(self): + """Initialize weights as in + `Probabilistic Backpropagation for Scalable Learning of Bayesian Neural Networks` + (https://arxiv.org/pdf/1502.05336.pdf), section 3.5 + """ + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 1.0 / (m.weight.size(1) + 1)) + nn.init.constant_(m.bias, 0) diff --git a/src/training/__pycache__/ood_trainers.cpython-37.pyc b/src/training/__pycache__/ood_trainers.cpython-37.pyc new file mode 100644 index 0000000..f589172 Binary files /dev/null and b/src/training/__pycache__/ood_trainers.cpython-37.pyc differ diff --git a/src/training/__pycache__/trainers.cpython-37.pyc b/src/training/__pycache__/trainers.cpython-37.pyc new file mode 100644 index 0000000..1bca611 Binary files /dev/null and b/src/training/__pycache__/trainers.cpython-37.pyc differ diff --git a/src/training/ood_trainers.py b/src/training/ood_trainers.py new file mode 100644 index 0000000..7c9300d --- 
/dev/null +++ b/src/training/ood_trainers.py @@ -0,0 +1,292 @@ +from typing import Tuple +from itertools import cycle + +import torch +from torch.distributions import Distribution +import torch.nn as nn +from torch.optim import Adam + +from src.training.trainers import DistributionRKLTrainer +from src.utils.func_utils import reduce_tensor, params_rmse + + +class DistributionRKLTrainerWithOOD(DistributionRKLTrainer): + @property + def uncertainty_methods(self): + return [ + 'predictive_posterior_entropy', 'expected_entropy', + 'mutual_information', 'expected_pairwise_kl', + 'variance_of_expected', 'expected_variance', + 'total_variance' + ] + + def train_step(self, x, y, x_ood): + self.optimizer.zero_grad() + + predicted_params = self.model(x) + prior_params = self.prior_converter(x) + + ordinary_loss = self.loss_function( + predicted_params, + prior_params, + y + ) + + self.switch_bn_updates("eval") + predicted_ood_params = self.model(x_ood) + prior_params = self.prior_converter(x_ood) + + ood_loss = self.loss_function( + predicted_ood_params, + prior_params + ) + self.switch_bn_updates("train") + loss = ordinary_loss + self.loss_params["ood_coeff"] * ood_loss + loss.backward() + torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0) + self.optimizer.step() + return ordinary_loss.item(), self.loss_params["ood_coeff"] * ood_loss.item() + + def loss_function(self, predicted_params, prior_params, targets=None): + if targets is None: + return self.loss_params["inv_real_beta"] * self.rkl_loss( + predicted_params, prior_params, reduction='mean' + ) + else: + predicted_dist = self.distribution(*predicted_params) + inv_beta = self.loss_params["inv_real_beta"] + return -predicted_dist.expected_log_prob(targets).mean() + inv_beta * self.rkl_loss( + predicted_params, prior_params, reduction='mean' + ) + + def eval_step(self, dataloader: torch.utils.data.DataLoader) -> Tuple[float, list]: + self.model.eval() + acc_eval_loss = 0.0 + acc_metrics = [0.0 for m in 
self.test_metrics] + with torch.no_grad(): + for i, (x, y) in enumerate(dataloader): + x, y = x.to(self.device), y.to(self.device) + predicted_params = self.model(x) + prior_params = self.prior_converter(x) + acc_eval_loss += self.loss_function( + predicted_params, + prior_params, + y + ).item() / len(dataloader) + for i, metric in enumerate(self.test_metrics): + acc_metrics[i] += metric( + predicted_params, + y + ) / len(dataloader) + return acc_eval_loss, acc_metrics + + def train(self, dataloader, oodloader, num_epochs, eval_dataloader=None, log_per=0, verbose=True): + with torch.no_grad(): + self.estimate_avg_mean_var(dataloader) + + trainloss_hist, oodloss_hist, valloss_hist, metrics_hist = [], [], [], [] + + for e in range(num_epochs): + self.model.train() + acc_train_loss = 0.0 + acc_ood_loss = 0.0 + """With only zip() the iterator will be exhausted when the length + is equal to that of the smallest dataset. + But with the use of cycle(), we will repeat the smallest dataset again unless + our iterator looks at all the samples from the largest dataset.""" + for (x, y), (x_ood,) in zip(dataloader, cycle(oodloader)): + x, y, x_ood = x.to(self.device), y.to(self.device), x_ood.to(self.device) + c_losses = self.train_step(x, y, x_ood) + acc_train_loss += c_losses[0] / len(dataloader) + acc_ood_loss += c_losses[1] / len(dataloader) + + trainloss_hist += [acc_train_loss] + oodloss_hist += [acc_ood_loss] + + if eval_dataloader and log_per > 0 and self.test_metrics: + if e % log_per == 0 or e == (num_epochs - 1): + acc_eval_loss, acc_metrics = self.eval_step(eval_dataloader) + + if verbose: + print("Epoch %d train loss %.3f ood loss %.3f eval loss %.3f" % ( + e, acc_train_loss, acc_ood_loss, acc_eval_loss + ), 'eval ' + ','.join(m.__name__ + " %.3f" % acc_metrics[i] for i, m in enumerate(self.test_metrics)), + flush=True + ) + valloss_hist += [acc_eval_loss] + metrics_hist += [acc_metrics] + + if self.scheduler: + self.scheduler.step() + + return trainloss_hist, 
oodloss_hist, valloss_hist, metrics_hist + + def estimate_avg_mean_var(self, dataloader): + self.avg_mean = None + for _, y in dataloader: + if self.avg_mean is None: + self.avg_mean = y.mean(dim=0) / len(dataloader) + else: + self.avg_mean += y.mean(dim=0) / len(dataloader) + sum_var = torch.zeros_like(self.avg_mean) + num_samples = 0 + for _, y in dataloader: + avg_mean = torch.repeat_interleave(self.avg_mean.unsqueeze(0), repeats=y.size(0), dim=0) + sum_var += (y - avg_mean).pow(2).sum(dim=0) + num_samples += y.size(0) + self.avg_scatter = sum_var / num_samples + + def prior_converter(self, inputs): + avg_mean_r = torch.repeat_interleave( + self.avg_mean.unsqueeze(0), repeats=inputs.size(0), dim=0 + ).to(inputs.device) + prior_kappa, prior_nu = self.loss_params['prior_beta'],\ + self.loss_params['prior_beta'] + self.model.output_dim + 1 + avg_precision_r = torch.repeat_interleave( + (1 / (prior_nu * self.avg_scatter.unsqueeze(0))), repeats=inputs.size(0), dim=0 + ).to(inputs.device) + + all_params = [avg_mean_r, avg_precision_r] + return all_params + [prior_kappa, prior_nu] + + def nll_loss(self, predicted_params, targets, reduction='mean'): + assert reduction in ['mean', 'sum', 'none'] + predicted_dist = self.distribution(*predicted_params) + batched_loss = -predicted_dist.predictive_posterior_log_prob(targets) + assert batched_loss.dim() < 2 or batched_loss.size(-1) == 1 + return reduce_tensor(batched_loss, reduction) + + def switch_bn_updates(self, mode): + if self.model.use_bn: + for m in self.model.modules(): + if isinstance(m, nn.BatchNorm1d) or isinstance(m, nn.BatchNorm2d): + if mode == 'train': + m.train() + elif mode == 'eval': + m.eval() + + def check_loss_params(self, loss_params): + for req_key in ['inv_real_beta', 'ood_coeff', 'prior_beta']: + if req_key not in loss_params.keys(): + raise Exception("Rkl loss params dict should contain key", req_key) + + +class DistributionEnsembleToPriorDistiller(DistributionRKLTrainer): + def __init__( + self, 
teacher_models: list, *args, **kwargs + ): + super(DistributionEnsembleToPriorDistiller, self).__init__(*args, **kwargs) + self.teacher_models = teacher_models + self.num_steps = 1 + for model in self.teacher_models: + model.eval() + self.loss_params['temperature'] = self.loss_params['max_temperature'] + + @property + def uncertainty_methods(self): + return [ + 'predictive_posterior_entropy', 'expected_entropy', + 'mutual_information', 'expected_pairwise_kl', + 'variance_of_expected', 'expected_variance', + 'total_variance' + ] + + def train_step(self, x, y): + x += torch.empty(x.shape).normal_( + mean=0, std=self.loss_params["noise_level"] + ).to(x.device) + + if "max_steps" in self.loss_params.keys(): + T_0 = self.loss_params["max_temperature"] + first_part = float(0.2 * self.loss_params["max_steps"]) + third_part = float(0.6 * self.loss_params["max_steps"]) + if self.num_steps < first_part: + self.loss_params["temperature"] = T_0 + elif self.num_steps < third_part: + self.loss_params["temperature"] = T_0 - (T_0 - 1) * min( + float(self.num_steps - first_part) / float(0.4 * self.loss_params["max_steps"]), + 1.0 + ) + else: + self.loss_params["temperature"] = 1.0 + + self.optimizer.zero_grad() + + predicted_params = self.model(x) + loss = self.loss_function( + predicted_params, + x + ) + loss.backward() + torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0) + + self.optimizer.step() + self.num_steps += 1 + return loss.item() + + def eval_step(self, dataloader: torch.utils.data.DataLoader) -> Tuple[float, list]: + self.model.eval() + acc_eval_loss = 0.0 + acc_metrics = [0.0 for m in self.test_metrics] + with torch.no_grad(): + for x, y in dataloader: + x, y = x.to(self.device), y.to(self.device) + predicted_params = self.model(x) + acc_eval_loss += self.loss_function( + predicted_params, + x + ).item() / len(dataloader) + for i, metric in enumerate(self.test_metrics): + acc_metrics[i] += metric( + predicted_params, + y + ) / len(dataloader) + return 
acc_eval_loss, acc_metrics + + def loss_function(self, predicted_params, x): + T = self.loss_params["temperature"] + with torch.no_grad(): + all_teachers_means, all_teachers_vars = [], [] + aggr_teachers_mean = torch.zeros_like(predicted_params[0]).to(self.device) + aggr_teachers_var = torch.zeros_like(predicted_params[1]).to(self.device) + for i, teacher in enumerate(self.teacher_models): + teacher_params = teacher(x) + aggr_teachers_mean += teacher_params[0] / len(self.teacher_models) + aggr_teachers_var += teacher_params[1].pow(2) / len(self.teacher_models) + all_teachers_means.append(teacher_params[0]) + all_teachers_vars.append(teacher_params[1].pow(2)) + for i, _ in enumerate(self.teacher_models): + all_teachers_means[i] = (T - 1) * aggr_teachers_mean / (T + 1) +\ + 2 * all_teachers_means[i] / (T + 1) + all_teachers_vars[i] = (T - 1) * aggr_teachers_var / (T + 1) +\ + 2 * all_teachers_vars[i] / (T + 1) + + new_nu = (predicted_params[-1] - self.model.output_dim - 1) * T +\ + self.model.output_dim + 1 + new_kappa = predicted_params[-2] * T + #new_nu = (predicted_params[-1] - self.model.output_dim - 2) * (1 / T) +\ + # self.model.output_dim + 2 + #new_kappa = predicted_params[-2] * (1 / T) + + predicted_dist = self.distribution(*[ + predicted_params[0], predicted_params[1], new_kappa, new_nu + ]) + all_losses = [] + for i in range(len(self.teacher_models)): + all_losses.append(-predicted_dist.log_prob( + all_teachers_means[i], 1 / all_teachers_vars[i] + ).sum()) + + return sum(all_losses) / len(all_losses) / T + + def nll_loss(self, predicted_params, targets, reduction='mean'): + assert reduction in ['mean', 'sum', 'none'] + predicted_dist = self.distribution(*predicted_params) + batched_loss = -predicted_dist.predictive_posterior_log_prob(targets) + assert batched_loss.dim() < 2 or batched_loss.size(-1) == 1 + return reduce_tensor(batched_loss, reduction) + + def check_loss_params(self, loss_params): + for req_key in ['max_temperature', 'noise_level']: + if 
req_key not in loss_params.keys(): + raise Exception("NLL loss params dict should contain key", req_key) diff --git a/src/training/trainers.py b/src/training/trainers.py new file mode 100644 index 0000000..bc7db84 --- /dev/null +++ b/src/training/trainers.py @@ -0,0 +1,297 @@ +from typing import Tuple + +import torch +from torch.distributions import Distribution +from torch.distributions.kl import kl_divergence +from torch.utils.data import SequentialSampler + +from sklearn.metrics import mean_squared_error + + +def reduce_tensor(vec: torch.Tensor, reduction: str = 'mean'): + """Global reduction of tensor based on str + + Args: + vec: torch.FloatTensor + reduction: str, one of ['sum', 'mean', 'none'], default 'mean' + """ + assert reduction in ['sum', 'mean', 'none'] + if reduction == 'mean': + return vec.mean() + elif reduction == 'sum': + return vec.sum() + elif reduction == 'none': + return vec + + +def params_rmse(predicted_params, targets): + assert not torch.isnan(predicted_params[0]).any() + return mean_squared_error(predicted_params[0].cpu(), targets.cpu()) ** 0.5 + + + + +class DistributionMLETrainer: + """This class implements MLE training for a model that outputs parameters + of some distribution. Note that both trained model and optimizer instances + are created inside it. 
+ """ + def __init__( + self, model_params: dict, model: torch.nn.Module, + optim_params: dict, distribution=Distribution, + optimizer=torch.optim.Adam, scheduler=None, scheduler_params=None, + test_metrics=[params_rmse], device='cuda:0' + ): + self.model = model(**model_params).to(device) + self.distribution = distribution + self.optimizer = optimizer(self.model.parameters(), **optim_params) + self.scheduler = None + if scheduler is not None: + self.scheduler = scheduler(self.optimizer, **scheduler_params) + self.test_metrics = test_metrics + self.device = device + + @property + def uncertainty_methods(self): + return ['entropy'] + + def train_step(self, x: torch.FloatTensor, y: torch.FloatTensor) -> float: + self.optimizer.zero_grad() + predicted_params = self.model(x) + loss = self.loss_function( + predicted_params, + y + ) + loss.backward() + torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0) + self.optimizer.step() + return loss.item() + + def eval_step(self, dataloader: torch.utils.data.DataLoader) -> Tuple[float, list]: + self.model.eval() + acc_eval_loss = 0.0 + acc_metrics = [0.0 for m in self.test_metrics] + with torch.no_grad(): + for x, y in dataloader: + x, y = x.to(self.device), y.to(self.device) + predicted_params = self.model(x) + acc_eval_loss += self.loss_function( + predicted_params, + y + ).item() / len(dataloader) + for i, metric in enumerate(self.test_metrics): + acc_metrics[i] += metric( + predicted_params, + y + ) / len(dataloader) + return acc_eval_loss, acc_metrics + + def train(self, dataloader: torch.utils.data.DataLoader, + num_epochs: int, eval_dataloader: torch.utils.data.DataLoader = None, + log_per: int = 0, verbose: str =True + ) -> Tuple[list, list, list]: + trainloss_hist, valloss_hist, metrics_hist = [], [], [] + + for e in range(num_epochs): + self.model.train() + acc_train_loss = 0.0 + for x, y in dataloader: + x, y = x.to(self.device), y.to(self.device) + acc_train_loss += self.train_step(x, y) / len(dataloader) + + 
trainloss_hist += [acc_train_loss] + + if eval_dataloader and log_per > 0 and self.test_metrics: + if e % log_per == 0 or e == (num_epochs - 1): + acc_eval_loss, acc_metrics = self.eval_step(eval_dataloader) + if verbose: + print("Epoch %d train loss %.3f eval loss %.3f" % ( + e, acc_train_loss, acc_eval_loss + ), 'eval ' + ','.join( + m.__name__ + " %.3f" % acc_metrics[i] + for i, m in enumerate(self.test_metrics) + ), flush=True + ) + valloss_hist += [acc_eval_loss] + metrics_hist += [acc_metrics] + + if self.scheduler: + self.scheduler.step() + + return trainloss_hist, valloss_hist, metrics_hist + + def compute_unsupervised_metric_through_data( + self, dataloader: torch.utils.data.DataLoader, metric + ) -> torch.FloatTensor: + metric_scores = [] + self.model.eval() + with torch.no_grad(): + for items in dataloader: + if isinstance(items, tuple) or isinstance(items, list): + predicted_params = self.model(items[0].to(self.device)) + else: + predicted_params = self.model(items.to(self.device)) + metric_scores += metric( + predicted_params + ).cpu().tolist() + return torch.FloatTensor(metric_scores) + + def loss_function(self, predicted_params, targets): + return self.nll_loss(predicted_params, targets, reduction='mean') + + def nll_loss(self, predicted_params, targets, reduction='mean'): + assert reduction in ['mean', 'sum', 'none'] + predicted_dist = self.distribution(*predicted_params) + batched_loss = -predicted_dist.log_prob(targets) + assert batched_loss.dim() < 2 or batched_loss.size(-1) == 1 + return reduce_tensor(batched_loss, reduction) + + def get_predicted_params(self, dataloader: torch.utils.data.DataLoader) -> list: + all_predicted_params = [] + self.model.eval() + with torch.no_grad(): + for items in dataloader: + if isinstance(items, tuple) or isinstance(items, list): + predicted_params = self.model(items[0].to(self.device)) + else: + predicted_params = self.model(items.to(self.device)) + if len(all_predicted_params) == 0: + all_predicted_params = 
[param.cpu().tolist() for param in predicted_params] + else: + for i in range(len(all_predicted_params)): + all_predicted_params[i] += predicted_params[i].cpu().tolist() + return [torch.FloatTensor(param) for param in all_predicted_params] + + def eval_uncertainty(self, dataloader, method: str = 'entropy'): + some_metric = lambda params: getattr(self.distribution(*params), method)() + return self.compute_unsupervised_metric_through_data(dataloader, some_metric) + + def save_model(self, dir: str): + torch.save(self.model.state_dict(), dir + '.ckpt') + + def load_model(self, dir: str): + self.model.load_state_dict(torch.load(dir + '.ckpt')) + + +class DistributionRKLTrainer(DistributionMLETrainer): + """This class replaces standard MLE loss with Reverse-KL. + """ + def __init__(self, loss_params: dict, *args, **kwargs): + super(DistributionRKLTrainer, self).__init__(*args, **kwargs) + self.check_loss_params(loss_params) + self.loss_params = loss_params + + def loss_function(self, predicted_params, targets): + target_params = self.converter(targets) + return self.rkl_loss(predicted_params, target_params, reduction='mean') + + def rkl_loss(self, predicted_params, target_params, reduction='mean'): + assert reduction in ['mean', 'sum', 'none'] + predicted_dist = self.distribution(*predicted_params) + true_dist = self.distribution(*target_params) + batched_loss = kl_divergence(predicted_dist, true_dist) + assert batched_loss.dim() < 2 or batched_loss.size(-1) == 1 + return reduce_tensor(batched_loss, reduction) + + def converter(self, targets): + """Extend targets with manually specified params to parametrize target distribution""" + return [targets] + self.loss_params["real_params"] + + def check_loss_params(self, loss_params): + for req_key in ['target_lambda', 'target_kappa', 'target_nu']: + if req_key not in loss_params.keys(): + raise Exception("Rkl loss params dict should contain key", req_key) + + +class DistributionEnsembleMLETrainer: + """This class sequentially 
trains multiple models and combines their outputs in an + ensemble distribution. Besides more accurate predictions, this allows + us to decompose uncertainty measures. + """ + def __init__( + self, n_models: int, mixture_distribution=Distribution, + *args, **kwargs + ): + self.trainers = [ + DistributionMLETrainer(*args, **kwargs) for _ in range(n_models) + ] + self.mixture_distribution = mixture_distribution + + @property + def uncertainty_methods(self): + return [ + 'expected_entropy', 'expected_pairwise_kl', + 'variance_of_expected', 'expected_variance', + 'total_variance' + ] + + def train(self, dataloader: torch.utils.data.DataLoader, + num_epochs: int, eval_dataloader: torch.utils.data.DataLoader = None, + log_per: int = 0, verbose: str =True + ) -> Tuple[list, list, list]: + train_hists, val_hists, metrics_hists = [], [], [] + for i, trainer in enumerate(self.trainers): + if verbose: + print('-'*20, flush=True) + print("Model %d" % i, flush=True) + res = trainer.train(dataloader, num_epochs, eval_dataloader, log_per, verbose) + train_hists.append(res[0]) + val_hists.append(res[1]) + metrics_hists.append(res[2]) + return train_hists, val_hists, metrics_hists + + def nll_loss(self, predicted_params, targets, reduction='mean'): + assert reduction in ['mean', 'sum', 'none'] + predicted_dist = self.mixture_distribution(*predicted_params) + batched_loss = -predicted_dist.log_prob(targets) + assert batched_loss.dim() < 2 or batched_loss.size(-1) == 1 + return reduce_tensor(batched_loss, reduction) + + def get_predicted_params(self, dataloader: torch.utils.data.DataLoader) -> list: + if not isinstance(dataloader.sampler, SequentialSampler): + print(dataloader.batch_sampler) + raise ValueError("To merge predicted params correctly dataloader shouldn't shuffle") + all_means, all_stds = [], [] + for trainer in self.trainers: + cmean, cstd = trainer.get_predicted_params(dataloader) + + all_means.append(cmean) + all_stds.append(cstd) + return all_means, all_stds + + def 
compute_unsupervised_metric_through_data( + self, dataloader: torch.utils.data.DataLoader, metric + ) -> torch.FloatTensor: + metric_scores = [] + for trainer in self.trainers: + trainer.model.eval() + with torch.no_grad(): + for items in dataloader: + all_means, all_stds = [], [] + if isinstance(items, tuple) or isinstance(items, list): + for trainer in self.trainers: + cmean, cstd = trainer.model(items[0].to(trainer.device)) + all_means.append(cmean) + all_stds.append(cstd) + else: + for trainer in self.trainers: + cmean, cstd = trainer.model(items.to(trainer.device)) + all_means.append(cmean) + all_stds.append(cstd) + if len(all_means[0]) > 1: + metric_scores += metric( + [all_means, all_stds] + ).cpu().tolist() + return torch.FloatTensor(metric_scores) + + def save_model(self, dir: str): + for i, trainer in enumerate(self.trainers): + trainer.save_model(dir + '_' + str(i)) + + def load_model(self, dir: str): + for i, trainer in enumerate(self.trainers): + trainer.load_model(dir + '_' + str(i)) + + def eval_uncertainty(self, dataloader, method: str = 'expected_pairwise_kl'): + some_metric = lambda params: getattr(self.mixture_distribution(*params), method)() + return self.compute_unsupervised_metric_through_data(dataloader, some_metric) diff --git a/src/utils/__pycache__/func_utils.cpython-37.pyc b/src/utils/__pycache__/func_utils.cpython-37.pyc new file mode 100644 index 0000000..7a02b2f Binary files /dev/null and b/src/utils/__pycache__/func_utils.cpython-37.pyc differ diff --git a/src/utils/func_utils.py b/src/utils/func_utils.py new file mode 100644 index 0000000..9a7db93 --- /dev/null +++ b/src/utils/func_utils.py @@ -0,0 +1,100 @@ +from typing import Union +import torch + +from sklearn.metrics import mean_squared_error + + +def percentile(t: torch.tensor, q: float) -> Union[int, float]: + """ + Return the ``q``-th percentile of the flattened input tensor's data. + + CAUTION: + * Needs PyTorch >= 1.1.0, as ``torch.kthvalue()`` is used. 
+ * Values are not interpolated, which corresponds to + ``numpy.percentile(..., interpolation="nearest")``. + + :param t: Input tensor. + :param q: Percentile to compute, which must be between 0 and 100 inclusive. + :return: Resulting value (scalar). + """ + # Note that ``kthvalue()`` works one-based, i.e. the first sorted value + # indeed corresponds to k=1, not k=0! Use float(q) instead of q directly, + # so that ``round()`` returns an integer, even if q is a np.float32. + k = 1 + round(.01 * float(q) * (t.numel() - 1)) + result = t.view(-1).kthvalue(k).values.item() + return result + + +def mvdigamma(vec: torch.FloatTensor, p: int, reduction: str = 'sum'): + """Implements batched Multivariate digamma function over a given float vector + + Args: + vec: torch.FloatTensor of shapes (bs, ...), inp to apply function on + p: int, dimensionality + reduction: str, one of ['sum', 'mean', 'none'], default 'sum' + Returns: + Tensor with same shapes as vec, where the mvdigamma function is + computed for each position independently + """ + assert reduction in ['sum', 'mean'] + + increasing_numbers = torch.arange( + 1, p + 1, dtype=torch.float, requires_grad=False + ) + output = torch.digamma( + vec.unsqueeze(-1) + 0.5 * (1 - increasing_numbers.to(vec.device)) + ) + + if reduction == 'sum': + return output.sum(axis=-1) + elif reduction == 'mean': + return output.mean(axis=-1) + + +def reduce_tensor(vec: torch.Tensor, reduction: str = 'mean'): + """Global reduction of tensor based on str + + Args: + vec: torch.FloatTensor + reduction: str, one of ['sum', 'mean', 'none'], default 'mean' + """ + assert reduction in ['sum', 'mean', 'none'] + if reduction == 'mean': + return vec.mean() + elif reduction == 'sum': + return vec.sum() + elif reduction == 'none': + return vec + + +def params_rmse(predicted_params, targets): + assert not torch.isnan(predicted_params[0]).any() + return mean_squared_error(predicted_params[0].cpu(), targets.cpu()) ** 0.5 + + +def rel_error(value1, value2): 
+ value1_norm = value1.norm(p=2, dim=-1) + value2_norm = value2.norm(p=2, dim=-1) + diff_norm = (value1 - value2).norm(p=2, dim=-1) + return diff_norm / (value1_norm.pow(0.5) * value2_norm.pow(0.5)) + + +if __name__ == '__main__': + from scipy.special import digamma + + for k in range(1000): + rand_nu = torch.randn(512).exp() + 1e-5 + digammas_scipy = torch.zeros(512) + digammas_torch = mvdigamma(rand_nu, 1) + + for k in range(512): + digammas_scipy[k] = digamma(rand_nu[k].item()) + + if (rel_error(digammas_scipy, digammas_torch) > 1e-6).any(): + raise Exception("Digamma functions of torch and scipy doesn't match") + + if ((rand_nu.log() - 1 / (2 * rand_nu) - digammas_torch) < 0.0).any(): + raise Exception("Upper inequality doesn't satisfied") + + if ((rand_nu.log() - 1 / (rand_nu) - digammas_torch) > 0.0).any(): + raise Exception("Lower inequality doesn't satisfied") diff --git a/uncertainty_example.ipynb b/uncertainty_example.ipynb new file mode 100644 index 0000000..1dd7a0c --- /dev/null +++ b/uncertainty_example.ipynb @@ -0,0 +1,1618 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Structure of this notebook:\n", + "#### 1) Uncertainty and deep ensembles for classification\n", + "#### 2) Uncertainty and deep ensembles for regression\n", + "#### 3) Ensemble distribution distillation and Prior networks for regression" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Deep ensembles for classification" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Why ensembles?\n", + "\n", + "#### Assume that we have an ensemble of different probabilistic models that predict $y$ having features $x^*$: \n", + "#### $$ \\{ P(y|x^*, \\theta_m) \\}_{m=1..M}, \\theta_m \\sim P(\\theta|D),$$\n", + "#### where $P(\\theta|D)$ is the probability to train such model, having dataset $D$, some optimization procedure and model architecture.\n", + "#### In deep ensembles with train each model in 
the ensemble with different random inits.\n", + "#### Each model $P(y|x^*, \\theta_m)$ captures a different estimate of data uncertainty; it has its own view and its own local optimum for the task.\n", + "\n", + "#### H is the entropy; for a discrete distribution with class probabilities $p_i$:\n", + "$$H(p) = -\\sum_{i=1..N} p_i \\cdot \\log p_i$$\n", + "\n", + "#### Through the ensemble we can calculate *Total uncertainty*:\n", + "#### $$H[\\mathbb{E}_{\\theta \\sim P(\\theta|D)} [P(y|x^*, \\theta)]] \\simeq H[\\frac{1}{M} \\sum_{m=1}^M [P(y|x^*, \\theta_m)]]$$\n", + "#### and *Expected data (aleatoric) uncertainty*:\n", + "#### $$\\mathbb{E}_{\\theta \\sim P(\\theta|D)} [H[P(y|x^*, \\theta)]] \\simeq \\frac{1}{M} \\sum_{m=1}^M [H[P(y|x^*, \\theta_m)]]$$\n", + "#### *Knowledge (epistemic) uncertainty* is the difference between *Total uncertainty* and *Expected Data uncertainty*:\n", + "#### $$\\mathcal{I}(y, \\theta| x^*, D) = H[\\mathbb{E}_{\\theta \\sim P(\\theta|D)} [P(y|x^*, \\theta)]] - \\mathbb{E}_{\\theta \\sim P(\\theta|D)} [H[P(y|x^*, \\theta)]]$$" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**To visualize the behaviour of an ensemble in different cases, let's consider 3-class classification and show it on a triangular simplex: each corner of the triangle is a class, and each point is the prediction of one model in the ensemble. 
The closer a point is to a corner, the greater its prediction of the probability of this class**" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**There are different ways to generate ensembles; you can find more about them in the video about \"Ensemble Generation\".**\n", + "\n", + "**Here we will use the simplest, yet the most effective way to generate ensembles for uncertainty estimation: we will train different models with different random seeds.**" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**For simplicity, we will show all concepts on easy tasks, where you can train your model fast and play with it. You can find large-scale experiments in the scientific papers introduced in our track.**" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1.6.0 10.2\n" + ] + } + ], + "source": [ + "import torch\n", + "import torchvision\n", + "import torchvision.transforms as transforms\n", + "import matplotlib.pyplot as plt\n", + "import torch.nn as nn\n", + "import seaborn as sns\n", + "import sklearn\n", + "\n", + "print(torch.__version__, torch.version.cuda)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**MNIST dataset**" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "n_epochs = 3\n", + "batch_size_train = 64\n", + "batch_size_test = 200\n", + "learning_rate = 0.01\n", + "momentum = 0.5\n", + "num_networks = 5 # Ensemble size\n", + "\n", + "random_seed = 1\n", + "torch.manual_seed(random_seed)\n", + "device='cuda:0'" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "train_loader = torch.utils.data.DataLoader(\n", + " torchvision.datasets.MNIST('./data/', train=True, download=True,\n", + 
" transform=torchvision.transforms.Compose([\n", + " torchvision.transforms.ToTensor(),\n", + " torchvision.transforms.Normalize(\n", + " (0.1307,), (0.3081,))\n", + " ])),\n", + " batch_size=batch_size_train, shuffle=True)\n", + "\n", + "test_loader = torch.utils.data.DataLoader(\n", + " torchvision.datasets.MNIST('./data/', train=False, download=True,\n", + " transform=torchvision.transforms.Compose([\n", + " torchvision.transforms.ToTensor(),\n", + " torchvision.transforms.Normalize(\n", + " (0.1307,), (0.3081,))\n", + " ])),\n", + " batch_size=batch_size_test, shuffle=True)\n", + "\n", + "classes = tuple(str(i) for i in range(10))" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n" + ] + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + " 9 3 0 1\n" + ] + } + ], + "source": [ + "import numpy as np\n", + "\n", + "# functions to show an image\n", + "def imshow(img):\n", + " npimg = img.numpy()\n", + " plt.imshow(np.transpose(npimg, (1, 2, 0)))\n", + " plt.show()\n", + "\n", + "\n", + "# get some random training images\n", + "dataiter = iter(train_loader)\n", + "images, labels = dataiter.next()\n", + "\n", + "# show images\n", + "imshow(torchvision.utils.make_grid(images[:4]))\n", + "# print labels\n", + "print(' '.join('%5s' % classes[labels[j]] for j in range(4)))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We will use simple network with two convolutional layers and two linear layers" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "import torch.nn.functional as F\n", + "\n", + "\n", + "class Net(nn.Module):\n", + " def __init__(self):\n", + " super(Net, self).__init__()\n", + " self.conv1 = nn.Conv2d(1, 10, kernel_size=5)\n", + " self.conv2 = nn.Conv2d(10, 20, kernel_size=5)\n", + " self.conv2_drop = nn.Dropout2d()\n", + " self.fc1 = nn.Linear(320, 50)\n", + " self.fc2 = nn.Linear(50, 10)\n", + "\n", + " def forward(self, x):\n", + " x = F.relu(F.max_pool2d(self.conv1(x), 2))\n", + " x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))\n", + " x = x.view(-1, 320)\n", + " x = F.relu(self.fc1(x))\n", + " x = F.dropout(x, training=self.training)\n", + " x = self.fc2(x)\n", + " return F.log_softmax(x, dim=1)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "import torch.optim as optim\n", + "\n", + "criterion = nn.CrossEntropyLoss()" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "ensemble = []\n", + "for i in 
range(num_networks):\n", + " new_net = Net()\n", + " new_net.to(device)\n", + " ensemble.append(new_net)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Network: 0 epoch: 1 step: 299 loss: 0.23045 \n", + "Network: 0 epoch: 1 step: 599 loss: 0.10454 \n", + "Network: 0 epoch: 1 step: 899 loss: 0.07892 \n", + "Network: 0 epoch: 2 step: 299 loss: 0.06506 \n", + "Network: 0 epoch: 2 step: 599 loss: 0.0577 \n", + "Network: 0 epoch: 2 step: 899 loss: 0.05441 \n", + "Network: 0 epoch: 3 step: 299 loss: 0.04942 \n", + "Network: 0 epoch: 3 step: 599 loss: 0.04697 \n", + "Network: 0 epoch: 3 step: 899 loss: 0.04368 \n", + "Network: 1 epoch: 1 step: 299 loss: 0.22487 \n", + "Network: 1 epoch: 1 step: 599 loss: 0.10296 \n", + "Network: 1 epoch: 1 step: 899 loss: 0.08061 \n", + "Network: 1 epoch: 2 step: 299 loss: 0.06297 \n", + "Network: 1 epoch: 2 step: 599 loss: 0.05738 \n", + "Network: 1 epoch: 2 step: 899 loss: 0.05285 \n", + "Network: 1 epoch: 3 step: 299 loss: 0.04823 \n", + "Network: 1 epoch: 3 step: 599 loss: 0.0444 \n", + "Network: 1 epoch: 3 step: 899 loss: 0.04269 \n", + "Network: 2 epoch: 1 step: 299 loss: 0.24886 \n", + "Network: 2 epoch: 1 step: 599 loss: 0.10192 \n", + "Network: 2 epoch: 1 step: 899 loss: 0.07642 \n", + "Network: 2 epoch: 2 step: 299 loss: 0.06149 \n", + "Network: 2 epoch: 2 step: 599 loss: 0.05567 \n", + "Network: 2 epoch: 2 step: 899 loss: 0.05212 \n", + "Network: 2 epoch: 3 step: 299 loss: 0.04732 \n", + "Network: 2 epoch: 3 step: 599 loss: 0.04288 \n", + "Network: 2 epoch: 3 step: 899 loss: 0.0423 \n", + "Network: 3 epoch: 1 step: 299 loss: 0.24683 \n", + "Network: 3 epoch: 1 step: 599 loss: 0.09415 \n", + "Network: 3 epoch: 1 step: 899 loss: 0.07193 \n", + "Network: 3 epoch: 2 step: 299 loss: 0.0601 \n", + "Network: 3 epoch: 2 step: 599 loss: 0.05246 \n", + "Network: 3 epoch: 2 step: 899 loss: 0.04885 \n", + "Network: 3 epoch: 3 
step: 299 loss: 0.04606 \n", + "Network: 3 epoch: 3 step: 599 loss: 0.04303 \n", + "Network: 3 epoch: 3 step: 899 loss: 0.03979 \n", + "Network: 4 epoch: 1 step: 299 loss: 0.24504 \n", + "Network: 4 epoch: 1 step: 599 loss: 0.10257 \n", + "Network: 4 epoch: 1 step: 899 loss: 0.07364 \n", + "Network: 4 epoch: 2 step: 299 loss: 0.0606 \n", + "Network: 4 epoch: 2 step: 599 loss: 0.0552 \n", + "Network: 4 epoch: 2 step: 899 loss: 0.04713 \n", + "Network: 4 epoch: 3 step: 299 loss: 0.0453 \n", + "Network: 4 epoch: 3 step: 599 loss: 0.04275 \n", + "Network: 4 epoch: 3 step: 899 loss: 0.04027 \n", + "Finished Training\n", + "CPU times: user 1min 44s, sys: 265 ms, total: 1min 44s\n", + "Wall time: 1min 43s\n" + ] + } + ], + "source": [ + "%%time\n", + "for net_id, net in enumerate(ensemble):\n", + " optimizer = optim.SGD(net.parameters(), lr=learning_rate,\n", + " momentum=momentum)\n", + " for epoch in range(n_epochs): # loop over the dataset multiple times\n", + " running_loss = 0.0\n", + " for i, data in enumerate(train_loader, 0):\n", + " # get the inputs; data is a list of [inputs, labels]\n", + " inputs, labels = data\n", + " inputs, labels = inputs.cuda(), labels.cuda()\n", + " # zero the parameter gradients\n", + " optimizer.zero_grad()\n", + " # forward + backward + optimize\n", + " outputs = net(inputs)\n", + "\n", + " loss = criterion(outputs, labels)\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + " # print statistics\n", + " running_loss += loss.item()\n", + " if i % 300 == 299: \n", + " print(f'Network: {net_id} epoch: {epoch + 1} step: {i} loss: {round(running_loss/2000, 5)} ')\n", + " running_loss = 0.0\n", + "\n", + "print('Finished Training')\n" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "for net in ensemble:\n", + " net.eval()" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "dataiter = iter(test_loader)\n", 
+ "images, labels = dataiter.next()\n", + "images, labels = images.to(device), labels.to(device)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "predictions = net(images)\n", + "probabilities = predictions.exp()" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "def predict_uncertainties(ensemble, images):\n", + " all_probabilities = []\n", + " for net in ensemble:\n", + " net.eval()\n", + " probs = net(images).exp()\n", + " all_probabilities.append(probs)\n", + " # Now probs is a list with num_networks tensors of size (bs x num_classes)\n", + " expected_probability = torch.mean(torch.stack(all_probabilities), dim=0)\n", + " # Total uncertainty\n", + " entropy_of_expected = torch.sum(-expected_probability * torch.log(expected_probability), dim=1) \n", + " # Data uncertainty\n", + " expected_entropy = torch.mean(\n", + " torch.stack(\n", + " [torch.sum(-prob * torch.log(prob), dim=1) for prob in all_probabilities]),\n", + " dim=0)\n", + " return expected_probability, entropy_of_expected, expected_entropy" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "expected_probability, entropy_of_expected, expected_entropy = predict_uncertainties(ensemble, images)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "knowledge_uncertainty = entropy_of_expected - expected_entropy" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Predictions of one network" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Accuracy of the network on the 10000 test images: 96 %\n" + ] + } + ], + "source": [ + "net = ensemble[0]\n", + "correct = 0\n", + "total = 0\n", + "with torch.no_grad():\n", + " for data in 
test_loader:\n", + " images, labels = data[0].to('cuda'), data[1].to('cuda')\n", + " outputs = net(images)\n", + " _, predicted = torch.max(outputs.data, 1)\n", + " total += labels.size(0)\n", + " correct += (predicted == labels).sum().item()\n", + "\n", + "print('Accuracy of the network on the 10000 test images: %d %%' % (\n", + " 100 * correct / total))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Ensemble mean prediction" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Accuracy of the network on the 10000 test images: 97 %\n" + ] + } + ], + "source": [ + "correct = 0\n", + "total = 0\n", + "with torch.no_grad():\n", + " for data in test_loader:\n", + " images, labels = data[0].to('cuda'), data[1].to('cuda')\n", + " outputs = torch.log(torch.mean(torch.stack([net(images).exp() for net in ensemble]), dim=0))\n", + " _, predicted = torch.max(outputs.data, 1)\n", + " total += labels.size(0)\n", + " correct += (predicted == labels).sum().item()\n", + "\n", + "print('Accuracy of the network on the 10000 test images: %d %%' % (\n", + " 100 * correct / total))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Uncertainties" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [], + "source": [ + "all_expected_probability = []\n", + "all_entropy_of_expected = []\n", + "all_expected_entropy = []\n", + "all_knowledge_uncertainty = []\n", + "all_labels = []\n", + "with torch.no_grad():\n", + " for data in test_loader:\n", + " images, labels = data[0].to('cuda'), data[1].to('cuda')\n", + " expected_probability, entropy_of_expected, expected_entropy = predict_uncertainties(ensemble, images)\n", + " knowledge_uncertainty = entropy_of_expected - expected_entropy\n", + "\n", + " all_expected_probability.append(expected_probability.cpu().numpy())\n", + " 
all_entropy_of_expected.append(entropy_of_expected.cpu().numpy())\n", + " all_expected_entropy.append(expected_entropy.cpu().numpy())\n", + " all_knowledge_uncertainty.append(knowledge_uncertainty.cpu().numpy())\n", + " all_labels.append(labels.cpu().numpy())" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [], + "source": [ + "all_expected_probability = np.vstack(all_expected_probability)\n", + "all_entropy_of_expected = np.hstack(all_entropy_of_expected)\n", + "all_expected_entropy = np.hstack(all_expected_entropy)\n", + "all_knowledge_uncertainty = np.hstack(all_knowledge_uncertainty)\n", + "all_labels = np.hstack(all_labels)" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [], + "source": [ + "all_predictions = np.argmax(all_expected_probability, axis=1)\n", + "\n", + "errors = (all_predictions != all_labels).astype(int)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Total uncertainty" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**We can use total uncertainty to understand whether we are certain about our predictions. If we are not certain, we can fall back to some other model or to human evaluation. You can find more about use cases in our video about uncertainty applications.**" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/iv-provilkov/miniconda3/envs/unc1/lib/python3.7/site-packages/seaborn/distributions.py:2551: FutureWarning: `distplot` is a deprecated function and will be removed in a future version. 
Please adapt your code to use either `displot` (a figure-level function with similar flexibility) or `histplot` (an axes-level function for histograms).\n", + " warnings.warn(msg, FutureWarning)\n" + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAs0AAAHmCAYAAACIzLPpAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAz00lEQVR4nO3deXxcZ33v8e9vJMuL5E2WIy+JLWdzYmd1nDgQmsgsIRCahfa2SUqaQCC5KW25hdLyAm4JUAqlvfS2t6ElIZAFiCkQQjZKEkAJIavtLLbjOHHiBe+brMW2JEvzu3+cM/ZIHukZ2TNzxtLn/fK8NPOc7TejR/L3HD3zjLm7AAAAAPQvlXQBAAAAQLkjNAMAAAABhGYAAAAggNAMAAAABBCaAQAAgABCMwAAABBAaAaGMDO73szczBqTPGYSdSR53CNhZnVmdreZbYprb0q6psGIa74z6TqOVmbWZGZrk64DwKEIzcBRwMwa4zCSufWYWbOZLTezu8zsEjOzAh/zFjO7opD7LIb4tbnFzCYkXUuB/B9JfyzpPyVdK+kr/a1oZg3xcz/rSA5YqP0MZ+Xy82JmZ8W1NCRdCzDUEJqBo8u9ioLU9ZI+J+mXkhol/VzSozmC4z2SRkt68jCO9QVJVxzGdkdyzMPRqKjWCWVQSyG8R9Iv3P1L7v49d39sgHUbFD33s47wmIXaz3B2uD8vfV0safYRbH9WXEtDAWoBkKUy6QIADMpSd/9edoOZfVLS1yV9UlGofl9mmbv3SOopRWFmNtbd20p5zJByqmUQpkjalXQRCDOzEZIq3L2jUPt0965C7QtAYXGlGTjKuXuPu39K0lOSLjGzd2SW9TO+eFT859tVZrbXzHab2TIz+6d4eYOZebz6ddnDQrL24WZ2p5m9y8yeMrN2SQ/2d8wslfGx15lZp5m9YmZX9V2pv3Gxffcdr/OFePGarFpvGaiWeNzwrWb2OzPrir/eamaT+jneO83sr83szbju183suhzPLyczqzazr2ZtvyUetzwza51b4tfY1Pt1v76ffV4v6dfxw+9mrd80yOMOuB8zS5nZ58zsyXj7LjNbb2b/0ff1GoyB+onlGNdrZmvj9lPM7GEzazOzFjP7sZlNybGPcWb2FTNbaWYdZrYz7qtX9Vlvavxc1sfPbZOZ3WZmx/RZ75a43rlm9g0z2yCpQ9JVefy8/LGZPRAfo9PMdpjZ/WZ2Rp7PvSl+/tPM7F6LhmbtNbNfmNnJ2TVK+m788NdZtdxpZlfG9z926HdDMrMVZrbarLDDvIChhCvNwNBxh6R3SLpUUYDuz62SPiLpbknfUPR74CRJ74yXb1c0BOQeSb+RdFs/+5kv6Q8k3S7prjxr/EdJ1ZK+GT/+sKR7zWyUu9+Z5z6yfUvSOElXSvorSTvi9lf628DMxkt6WtKJkr4jaamksyXdLOmdZnaeu7f12ewfFA3z+JakznjdO81stbv/dqACLboa+QtJF0j6saIxyyfF+7jYzOa7+wZJ90larUNf96f72fWTcV2fjdf9Tdy+dZDHHXA/kqokfVrSTyT9TNIeSedKukHSO8zsnBJeHZ0uqUnST+OazpR0k6I+cHFmJYuGKT0laa6i5/4fkioUfZ8/IGlRvN4MSc8oeo53SHpTUb+4WdLC+DVq6VPD9yXtU/R6uqQXFf55+XNJO+NlWySdIOlGSb81s3nu/kYez71a0ffqWUXfq1mSPiHpZ2Z2WvxXlfskTY33/Q+SVsbbvinphfjYH
1H0M3uAmZ0vaY6kz7m7C0Bu7s6NG7cyvykat+uS/nqAdebF6/wkq+36uK0xq22XpEfyOKZLunOAZS7p3TmW5Tpmpm2dpPFZ7ePjtl2SRoeO3c++b4nbGvJc/ytx25/1WffjcfuXc2z/oqSqrPbpisLzvXm8jh+L9/H1Pu2Xxu335Pu6D9Avrj+S4wb2Y9nfm6z2G+Jt/uhw6s/1vcla1iRpbZ+2tf0c79a4fXZW2zfjthtz7DuVdf9nkrZJOrbPOvMldUu6JUc/a5JUOcifl+ocbafGfeibeTz3pnj/f9On/dNx+3vzfF3/IV42p0/77fHznZZPv+PGbbjeGJ4BDB2t8ddxgfVaJM01s9OO8Hgvu/vjg9zmPzzryl18/z8lTVQU3ErhSkVX0/teEfxW3H5ljm2+6VlXU919o6TXFV25zed4aUlfzW5094clvSTpcjMrxu/ighzXI/skycwqzGyCmdVJ+lW8yoKCVj2wTe7+X33aMnWcJEXDSSRdJWmlux9y1dfd0/F64xVddX5AUodFQ3bq4ue2VtFV/4v7bi/p/7p792CKdvc98TEtHjZSp6ivrVL+r19a0r/1aev13PNwu6LQfEOmwcyqFc3W8nN335TnfoBhidAMDB2ZsNw64FrS/1IUUpfFY12/bWaHE9xeH2yBOvjn4myvxl+PP4z9HY5Zklb1DT7x49f7qeOtHG07JeUzpneWorDXnGPZCkljJdXlsZ/BKthxzeyPzOw5RcMSmhUFvsxrMrEAtearv++DdPB7UaeoppcC+5qt6P/AGxQ9n7632ZLqc2w36H5vZmeb2UOS2hSdtGaOcbryf/02+aFvOOz73Afk7mskPS7p2nj4jiT9kaK+8O086wCGLcY0A0NH5k1FqwZayd1/ZtEcru+XdJGkdysKDr8xs3d7/uNT9x5uoUcgqd9Z/c3AMeTfNGVmH5T0Q0nPKxpD+ztFb4CrkPTfOvyLLwONne3v+zzQTCiD/V5k1v+e+h+Tvy9H26D6fTxu+klFJ7NfVvTzuUfR8/+/kmry3FWhnvttkn4k6TJF49RvUDTW+eFB7AMYlgjNwNCR+ZNr8D8/d9+lKCx8L363/Nck/Y2kyxX9h1ospyoaR5ptTvw1+yriLkm1ObbPdRV4sG9cekvSbDOrzL7abGaVkk5W7quZR+ItRbOaTHD33X2WzVEUpnYcslV+BnrugznuQPu5VlFIXujuBwKjmZ0y+HJ7yUyrl+v7PEvS/sPc7w5FV8PPDKy3WtHzrjqMYUaDcaWiYHyZu/86e0E8+0hngY8X+nnIjOO+wcyWK3qj6D8OdsgJMBwxPAM4ysXjTP9Z0cwZj/gAszlkxqRmt7l75o1uUu8A067cgeZI3ByPJc3UM17S/5S0W9ITWeu9LultZjYma92Jimbb6Ks9/ppvrfdLmizpo33aPxa3/zTP/eTrfkW/az+T3Whm71M0m8MDmXG2h2Gg5z6Y4w60nx5FQezA/xfxidbnD7PmjMwwh3f3qe9qSdMOd6fxc7pX0hwzu6Hv8syUau6+U9Ijkj4Yzx5xyHpmNnkQh+7v5yVzhbjX1eB46rdDpsorgAF/Htx9v6Q7Jb1XB6drvKMIdQBDDleagaPLPDP7UHx/rKJxl1dIminpUUnXBLYfK2mzmT2gKChvU3RV72ZFV+cezFr3WUnvNrO/lbReUb5edIT175D0nJll5pL9sKQZkj6afRVT0r8ruhL+KzO7R9Gn/X1M0UwbfYPGs/HXfzSz7yu6Krrc3Zf3U8PXJf0PSbea2TxFr8PZiq7Ur4qXF9Kdkq6T9LfxsJgnFU1r9meKpnX77BHs+1VF42T/zMz2Kjr52ObuvxrkcQfaz48VTS34KzO7W9IIRX1ujI6Au68ys8cl3RQH2ZcUfZrdlYquAo/of+ugzyuaQvHbZnaxounnTNH3uVLR1XMp6vdPSXoyfm4vKjo5OF7RX13uVjRrRj76+3n5uaIhHfeY2b8r+jm7QNHwqDdV+
P+HX1D0psHPxSeaeyStcffnsta5XdHMG1dLesLzm/IOQNLTd3Djxi1808EpwTK3HkVvKFqhaDzmJf1sd72ypp9SNB/tVxWNT92p6E/DaxXNV3xSn21PUhTEWzPHzVo20PRavY7Zp+3dkr6oKFR0Slom6Zp+9vNpRSG5U9EbCD+Sa9/xun+jaDjC/nj5Lf3VErdPVjQt2YZ4mw2Kpi6rCz2XrGVN6jM12ADfv+r4dX9LUpeik5V7JM3MsW7eU87F679f0VzTHfG2TYd53IH28zFFwbpD0mZF42Jrc9U6mPoVnQD9KO5j7YoC5qm5Xtu4nzbl2EejckyXp+hE6+uKAniXov7+Gx06ZV2dpH9SdOW7Q9EJwzJJ/6qsqdk0wNSGefy8XKgonLfF+39Y0mn9PM+82uL2BmX196z26+LvV1d/3w9Jv4yXXZtvX+PGbbjfzH2wwwEBAMDRzMwekfQ2RXMz53rDI4A+GNMMAMAwYmYnKhrT/D0CM5A/rjQDADAMmNkCRcNf/jL+eqq7r020KOAowpVmAACGh5sVvX9hnKQ/ITADg8OVZgAAACCAK80AAABAwFExT3NdXZ03NDQkXUZi9uzZo+rq6qTLQBmgLyCDvoAM+gIy6AuFsWTJkh3ufsiHGx0VobmhoUGLFy9OuozENDU1qbGxMekyUAboC8igLyCDvoAM+kJhmNm6XO0MzwAAAAACCM0AAABAAKEZAAAACCA0AwAAAAGEZgAAACCA0AwAAAAEEJoBAACAAEIzAAAAEEBoBgAAAAIIzQAAAEAAoRkAAAAIIDQDAAAAAYRmAAAAIIDQDAAAAAQQmgEAAIAAQjMAAAAQQGgGAAAAAgjNAAAAQAChGQAAAAioTLoAhO3a06UfPLe+V9s1C2YkVA0AAMDww5VmAAAAIIDQDAAAAAQQmgEAAIAAQjMAAAAQQGgGAAAAAgjNAAAAQAChGQAAAAggNAMAAAABhGYAAAAggNAMAAAABBCaAQAAgABCMwAAABBAaAYAAAACCM0AAABAAKEZAAAACCA0AwAAAAGEZgAAACCA0AwAAAAEEJoBAACAAEIzAAAAEEBoBgAAAAIIzQAAAEAAoRkAAAAIIDQDAAAAAYRmAAAAIIDQDAAAAAQQmgEAAIAAQjMAAAAQQGgGAAAAAgjNAAAAQEDRQrOZHWdmvzazV81shZl9Im6/xcw2mtlL8e39xaoBAAAAKITKIu67W9Kn3H2pmY2VtMTMHouX/Yu7/3MRjw0AAAAUTNFCs7tvlrQ5vt9mZislTS/W8QAAAIBiMXcv/kHMGiQ9Kek0SZ+UdL2kVkmLFV2Nbs6xzY2SbpSk+vr6cxYtWlT0OsvV7pZWpStH9Wqrra5KqBokqb29XTU1NUmXgTJAX0AGfQEZ9IXCWLhw4RJ3n9+3veih2cxqJD0h6Svufp+Z1UvaIcklfVnSVHf/yED7mD9/vi9evLiodZaz+x5+VB11p/Rqu2bBjISqQZKamprU2NiYdBkoA/QFZNAXkEFfKAwzyxmaizp7hpmNkPQTSd939/skyd23unuPu6cl3S7pvGLWAAAAABypYs6eYZLukLTS3b+R1T41a7UrJS0vVg0AAABAIRRz9owLJF0raZmZvRS3fVbS1WZ2lqLhGWsl3VTEGgAAAIAjVszZM56SZDkWPVKsYwIAAADFwCcCAgAAAAGEZgAAACCA0AwAAAAEEJoBAACAAEIzAAAAEEBoBgAAAAIIzQAAAEAAoRkAAAAIIDQDAAAAAYRmAAAAIIDQDAAAAAQQmgEAAIAAQjMAAAAQQGgGAAAAAgjNAAAAQAChGQAAAAggNAMAAAABhGYAAAAggNAMAAAABBCaAQAAgABCMwAAABBAaAYAAAACCM0AAABAAKEZAAAACCA0AwAAAAGEZgAAACCA0AwAAAAEEJoBAACAAEIzAAAAEEBoBgAAAAIIzQAAAEAAoRkAAAAIIDQDAAAAAYRmAAAAIIDQDAAAAAQQmgEAAIAAQ
jMAAAAQQGgGAAAAAgjNAAAAQAChGQAAAAggNAMAAAABhGYAAAAggNAMAAAABBCaAQAAgABCMwAAABBAaAYAAAACCM0AAABAAKEZAAAACCA0AwAAAAGEZgAAACCA0AwAAAAEEJoBAACAAEIzAAAAEEBoBgAAAAIIzQAAAEAAoRkAAAAIIDQDAAAAAYRmAAAAIIDQDAAAAAQQmgEAAIAAQjMAAAAQQGgGAAAAAgjNAAAAQAChGQAAAAggNAMAAAABhGYAAAAgoGih2cyOM7Nfm9mrZrbCzD4Rt9ea2WNm9kb8dWKxagAAAAAKoZhXmrslfcrd50g6X9LHzWyOpM9I+qW7nyTpl/FjAAAAoGwVLTS7+2Z3Xxrfb5O0UtJ0SZdLuite7S5JVxSrBgAAAKAQSjKm2cwaJJ0t6TlJ9e6+OV60RVJ9KWoAAAAADpe5e3EPYFYj6QlJX3H3+8xst7tPyFre7O6HjGs2sxsl3ShJ9fX15yxatKiodZaz3S2tSleO6tVWW12VUDVIUnt7u2pqapIuA2WAvoAM+gIy6AuFsXDhwiXuPr9ve2UxD2pmIyT9RNL33f2+uHmrmU11981mNlXStlzbuvttkm6TpPnz53tjY2MxSy1r9z38qDrqTunV1rhgRkLVIElNTU0azj8LOIi+gAz6AjLoC8VVzNkzTNIdkla6+zeyFj0g6br4/nWSflasGgAAAIBCKOaV5gskXStpmZm9FLd9VtLXJP2Xmd0gaZ2kPypiDQAAAMARK1podvenJFk/i99VrOMCAAAAhcYnAgIAAAABhGYAAAAggNAMAAAABBCaAQAAgABCMwAAABBAaAYAAAACCM0AAABAAKEZAAAACCA0AwAAAAGEZgAAACCA0AwAAAAEEJoBAACAAEIzAAAAEEBoBgAAAAIIzQAAAEAAoRkAAAAIIDQDAAAAAYRmAAAAIIDQDAAAAAQQmgEAAIAAQjMAAAAQQGgGAAAAAgjNAAAAQAChGQAAAAggNAMAAAABhGYAAAAggNAMAAAABBCaAQAAgABCMwAAABBAaAYAAAACCM0AAABAAKEZAAAACCA0AwAAAAGEZgAAACCA0AwAAAAEEJoBAACAAEIzAAAAEEBoBgAAAAIIzQAAAEAAoRkAAAAIIDQDAAAAAYRmAAAAIIDQDAAAAAQQmgEAAIAAQjMAAAAQQGgGAAAAAgjNAAAAQAChGQAAAAggNAMAAAABhGYAAAAggNAMAAAABBCaAQAAgABCMwAAABBAaAYAAAACCM0AAABAAKEZAAAACCA0AwAAAAGEZgAAACCA0AwAAAAEEJoBAACAAEIzAAAAEEBoBgAAAAIIzQAAAEAAoRkAAAAIIDQDAAAAAXmFZjO7z8wuNTNCNgAAAIadfEPwNyVdI+kNM/uamc0uYk0AAABAWckrNLv74+7+J5LmSVor6XEze9rMPmxmI3JtY2bfMbNtZrY8q+0WM9toZi/Ft/cX4kkAAAAAxZT3cAszmyTpekkflfSipH9VFKIf62eTOyVdkqP9X9z9rPj2yKCqBQAAABJQmc9KZvZTSbMl3SPp9919c7zoh2a2ONc27v6kmTUUpEoAAAAgQfleab7d3ee4+1czgdnMRkqSu88f5DH/3MxeiYdvTBzktgAAAEDJmbuHVzJb6u7zQm05tmuQ9JC7nxY/rpe0Q5JL+rKkqe7+kX62vVHSjZJUX19/zqJFi8LPZoja3dKqdOWoXm211VUJVYMktbe3q6amJukyUAboC8igLyCDvlAYCxcuXJLrovCAwzPMbIqk6ZJGm9nZkixeNE7SmMEW4e5bs/Z9u6SHBlj3Nkm3SdL8+fO9sbFxsIcbMu57+FF11J3Sq61xwYyEqkGSmpqaNJx/FnAQfQEZ9AVk0BeKKzSm+b2K3vx3rKRvZLW3SfrsYA9mZlOzxkNfKWn5QOsDAAAA5WDA0Ozud0m6y8z+wN1/Mpgdm9m9khol1ZnZBklfkNRoZmcpGp6xVtJNh1EzA
AAAUFKh4RkfcvfvSWows0/2Xe7u38ixWWbZ1Tma7xh8iQAAAECyQsMzquOvjCoHAADAsBUanvGt+OsXS1MOAAAAUH7ymqfZzL5uZuPMbISZ/dLMtpvZh4pdHAAAAFAO8v1wk4vdvVXSBxS9ge9ESZ8uVlEAAABAOck3NGeGcVwq6Ufu3lKkegAAAICyE3ojYMZDZvaapH2SbjazyZI6ilcWAAAAUD7yutLs7p+R9HZJ8919v6Q9ki4vZmEAAABAucj3SrMknaJovubsbe4ucD0AAABA2ckrNJvZPZJOkPSSpJ642UVoBgAAwDCQ75Xm+ZLmuLsXsxgAAACgHOU7e8ZySVOKWQgAAABQrvK90lwn6VUze15SZ6bR3S8rSlUAAABAGck3NN9SzCIAAACAcpZXaHb3J8xspqST3P1xMxsjqaK4pQEAAADlIa8xzWb2MUk/lvStuGm6pPuLVBMAAABQVvJ9I+DHJV0gqVWS3P0NSccUqygAAACgnOQbmjvdvSvzIP6AE6afAwAAwLCQb2h+wsw+K2m0mb1H0o8kPVi8sgAAAIDykW9o/oyk7ZKWSbpJ0iOSPl+sogAAAIByku/sGWkzu1/S/e6+vbglAQAAAOVlwCvNFrnFzHZIWiVplZltN7O/K015AAAAQPJCwzP+StGsGee6e62710paIOkCM/urolcHAAAAlIFQaL5W0tXuvibT4O5vSfqQpD8tZmEAAABAuQiF5hHuvqNvYzyueURxSgIAAADKSyg0dx3mMgAAAGDICM2ecaaZteZoN0mjilAPAAAAUHYGDM3uXlGqQgAAAIByle+HmwAAAADDFqEZAAAACCA0AwAAAAGEZgAAACCA0AwAAAAEEJoBAACAAEIzAAAAEEBoBgAAAAIIzQAAAEAAoRkAAAAIIDQDAAAAAYRmAAAAIIDQDAAAAAQQmgEAAIAAQjMAAAAQQGgGAAAAAgjNAAAAQAChGQAAAAggNAMAAAABhGYAAAAggNAMAAAABBCaAQAAgABCMwAAABBAaAYAAAACCM0AAABAAKEZAAAACCA0AwAAAAGEZgAAACCA0AwAAAAEEJoBAACAAEIzAAAAEEBoBgAAAAIIzQAAAEAAoRkAAAAIIDQDAAAAAYRmAAAAIIDQDAAAAAQQmgEAAIAAQjMAAAAQQGgGAAAAAooWms3sO2a2zcyWZ7XVmtljZvZG/HVisY4PAAAAFEoxrzTfKemSPm2fkfRLdz9J0i/jxwAAAEBZK1podvcnJe3q03y5pLvi+3dJuqJYxwcAAAAKpdRjmuvdfXN8f4uk+hIfHwAAABg0c/fi7dysQdJD7n5a/Hi3u0/IWt7s7jnHNZvZjZJulKT6+vpzFi1aVLQ6y93ullalK0f1aqutrkqoGiSpvb1dNTU1SZeBMkBfQAZ9ARn0hcJYuHDhEnef37e9ssR1bDWzqe6+2cymStrW34rufpuk2yRp/vz53tjYWKISy899Dz+qjrpTerU1LpiRUDVIUlNTk4bzzwIOoi8gg76ADPpCcZV6eMYDkq6L718n6WclPj4AAAAwaMWccu5eSc9Imm1mG8zsBklfk/QeM3tD0rvjxwAAAEBZK9rwDHe/up9F7yrWMQEAAIBi4BMBAQAAgABCMwAAABBAaAYAAAACCM0AAABAAKEZAAAACCA0AwAAAAGEZgAAACCA0AwAAAAEEJoBAACAAEIzAAAAEEBoBgAAAAIIzQAAAEAAoRkAAAAIIDQDAAAAAYRmAAAAIIDQDAAAAAQQmgEAAIAAQjMAAAAQQGgGAAAAAgjNAAAAQAChGQAAAAggNAMAAAABhGYAAAAggNAMAAAABBCaAQAAgABCMwAAABBAaAYAAAACCM0AAABAAKEZAAAACCA0AwAAAAGEZgAAACCA0AwAAAAEEJoBAACAAEIzAAAAEEBoBgAAAAIIzQAAAEAAoRkAAAAIIDQDAAAAAYRmAAAAIIDQDAAAAAQQmgEAAIAAQjMAAAAQQGgGA
AAAAgjNAAAAQAChGQAAAAggNAMAAAABhOajyP6etPZ19SRdBgAAwLBTmXQByE9bx37d8dQatXd262O/d3zS5QAAAAwrXGk+CuzudN3+m7e0e+9+pcz0nafW6K3t7UmXBQAAMGwQmsvchua9uvWVbrV1dOvDFzToo++YpbS7rrn9Oa3fuTfp8gAAAIYFQnOZ+/SPXtHebukjF8zSzEnVOmbcKH3kHbPU0d2jj9z1gtJpT7pEAACAIY/QXMY2t+zTs2t26sJpKR1XO+ZA+9Txo/V3H5ij1dvatXhdc4IVAgAADA+E5jL28Cub5S7NO+bQb9Mlp03RmKoK/fTFDQlUBgAAMLwQmsvYAy9v0unTx2vyaDtk2ZiqSr3vtKl66JXN6tjPNHQAAADFRGguU2t27NErG1r0+2dO7XedP5g3XW0d3Xp85dYSVgYAADD8EJrL1EMvb5IkfeCMaf2uc/7xkzRt/Cjdt3RjqcoCAAAYlgjNZcjd9cDLm3Ruw0RNmzC63/VSKdMVZ0/XE69v1/a2zhJWCAAAMLwQmsvQa1va9Ma2dl12Zv9XmTM+OG+6etJRyAYAAEBxEJrL0IMvb1JFyvS+0/sfz5xx4jFjdeax43XfUmbRAAAAKBZCcxl6fOVWnX98repqRua1/qVnTNWKTa3a0tJR5MoAAACGJ0JzmWnZu1+vb23X+bMm5b3NBSfWSZKefnNHscoCAAAY1gjNZWbJ+l2SpPkNtXlvc+qUcaqtrtJvV+8sVlkAAADDGqG5zCxe26zKlOms4ybkvU0qZXrb8ZP09Js75O7FKw4AAGCYIjSXmcXrmjV32jiNrqoY1HZvO2GSNrd0aM2OPUWqDAAAYPgiNJeRru60Xv7dbp0zM/+hGRmZcc2/fZMhGgAAAIVGaC4jKza1qLM7rfkNEwe9bcOkMZo2fpSe4c2AAAAABUdoLiNL1jVLkubPHHxoNjO9/cQ6PfPmTqXTjGsGAAAopERCs5mtNbNlZvaSmS1OooZy9MLaXZpRO0bHjBt1WNtfcOIkNe/dr1c3txa4MgAAgOEtySvNC939LHefn2ANZcPdtWRd82FdZc54+wnM1wwAAFAMDM8oE+t27tWO9i6dcxjjmTPqx43SCZOr9TRvBgQAACiopEKzS3rUzJaY2Y0J1VBWFh8Yzzz4mTOyXXBinZ5fs0v7e9KFKAsAAACSLIkPwzCz6e6+0cyOkfSYpL9w9yf7rHOjpBslqb6+/pxFixaVvM5S+u7yTr2wpVv//q4xSpn1Wra7pVXpyt7jnGurq3Lu54Ut3br1pU797/NH6YQJg5vrGeWvvb1dNTU1SZeBMkBfQAZ9ARn0hcJYuHDhklzDhyuTKMbdN8Zft5nZTyWdJ+nJPuvcJuk2SZo/f743NjaWusyS+ocXn9B5J4zWOxeed8iy+x5+VB11p/Rqa1wwI+d+5rZ16taXHldPbYMaLzyhKLUiOU1NTRrqPwvID30BGfQFZNAXiqvkwzPMrNrMxmbuS7pY0vJS11FO9nZ1a/W2dp157IQj3tfksSM1q65az69pPvLCAAAAICmZK831kn5q0RCESkk/cPf/TqCOsrFyc6vSLp02fXxB9nduw0Q9+upWpdOuVMrCGwAAAGBAJQ/N7v6WpDNLfdxytmxDiyTp9AKF5vNmTdJ/Ld6gN7a1a/aUsQXZJwAAwHDGlHNlYPmmVtXVjFT9uJEF2d95DdEMHM+v3VWQ/QEAAAx3hOYysHxji06fPk5mhRlKcVztaNWPG6kX1hCaAQAACiGR2TNwUMf+Hr2xrV3vmVM/qO1+8Nz6nO3XLJghM9O5DbV6fs0uuXvBwjgAAMBwxZXmhK3c3KqetBfsTYAZC2bVaktrhzY07yvofgEAAIYjQnPClm8s7JsAM86dFY9rZogGAADAESM0J2zZxhbVVldp6vhR4ZUH4eRjxmr86BF6gTcDAgAAHDFCc8KWb2zVadPHF3zccSplmj9zIjNoAAAAFAChOUEd+3v0+
tY2nT59XFH2f+6sWr21fY92tHcWZf8AAADDBaE5Qau2tKk77TptWmHHM2ecF49rZuo5AACAI0NoTtCy+E2AhZ45I+O0aeM1akSKIRoAAABHiNCcoBWbWjRhzAgdO3F0UfZfVZnS2cdN5M2AAAAAR4jQnKBlG1t02rTCvwkw27mzavXqpla1dewv2jEAAACGOkJzQjq7e7RqS1vRhmZkLJhVq7RLS9fvLupxAAAAhjJCc0Le2Nqu/T1e8A816evsGRNUmTI9v2ZnUY8DAAAwlBGaE3LwTYDFmW4uY0xVpeZOH68X1jQX9TgAAABDGaE5Ics2tmjcqErNqB1T9GOd1zBRL23Yrc7unqIfCwAAYCgiNCdkxcaWonwSYC7nNtSqqzutVza0FP1YAAAAQxGhOQH7e9JaWYI3AWac2xB9yMnzfMgJAADAYSE0J+D1rW3q6k6XLDRPrK7SyfU1hGYAAIDDRGhOwIqNrZJU9Jkzsp3bUKul65rVk/aSHRMAAGCoIDQnYNnGFtWMrNTMErwJMOO8WbVq6+zW8o2MawYAABgsQnMClm1s0dxp45RKFf9NgBnvOLFOZtKTr28v2TEBAACGCkJziXX3pLVyc2tJh2ZI0qSakTp9+ng9QWgGAAAYNEJzia3e3q7O7rROP7a0oVmSLjp5spaub1bL3v0lPzYAAMDRjNBcYsviuZLnTksmNKdd+u2bO0p+bAAAgKMZobnEVmxqVXVVhY6vqy75sc86boLGjqpkXDMAAMAgEZpLLHoT4PiSvgkwo7IipXecWKcnXt8ud6aeAwAAyBehuYS6e9J6dVOr5k4fl1gNF508WZtbOvTGtvbEagAAADjaEJpL6LUtbdq3v0fzZkxMrIYLT54sSXpiFUM0AAAA8kVoLqEl65olSfNmJheap00YrZOOqWHqOQAAgEEgNJfQknXNmjJulKaNH5VoHRedPFnPr9mlvV3didYBAABwtKhMuoDhZOn6Zp0zc6LMivcmwB88t/6QtmsWzOj1eOEpx+jbT63RE6u2632nTy1aLQAAAEMFV5pLZGtrhzY079PZMyYkXYoWzKpVXU2VHnh5U9KlAAAAHBUIzSWyNB7PfE6C45kzKitSuvT0qfrla9vU1sGnAwIAAIQQmktk6fpmVVWmEvkkwFwuO2uaurrTenTF1qRLAQAAKHuE5hJZsq5ZZ0wfr6rK8njJ582YqGMnjmaIBgAAQB7KI8ENcZ3dPVq+sTXRqeb6MjP9/pnT9NTqHdrZ3pl0OQAAAGWN0FwCyze2qqsnneiHmuRy2ZnT1JN2PbJsc9KlAAAAlDVCcwksPfChJhOSLaSPU6aM1UnH1DBEAwAAIIDQXAJL1jXruNrROmZssh9q0peZ6fKzpumFtc3a0Lw36XIAAADKFqG5yNw9+lCTMhuakXHF2dOVMunuZ9YlXQoAAEDZIjQX2Zvb27WtrVPnzqpNupScjp04RpeeMU0/eG69WpmzGQAAICdCc5E1rdouSbro5MkJV9K/my48Xu2d3Tk/ghsAAACE5qJrWrVdJx5To2Mnjkm6lH6dNn28Ljhxkr7z1Bp1dvckXQ4AAEDZITQX0Z7Obj2/Zpcay/gqc8ZNF56gbW2d+tlLzKQBAADQV2XSBQxlz7y5U109aTXOPibROnINu7hmwYxej3/vpDqdOnWcbnvyLf3hvGOVSlmpygMAACh7XGkuoqbXt2lMVYXOnVWeM2dkMzPddOHxWr2tXY8s58NOAAAAshGai8Td1bRqu95+wiSNrKxIupy8fOCMqZo7bZy+9OCrzKQBAACQhdBcJG/t2KMNzft0UcJDMwajsiKlr37wdO1o79Q//2JV0uUAAACUDUJzkWSmmjsa3gSY7YxjJ+hP39age55dp6Xrm5MuBwAAoCzwRsAiaVq1TSdMrtZxteU51Vx/czJfs2CG/vq9s/WLFVv02fuW6cG/eIdGVHBuBQAAhjfSUBHs6ezWc2t26aKTj56hGdlqRlbqi5fN1
Wtb2vS1n7+WdDkAAACJIzQXwYMvb1JXd1qXnjEl6VIO28Vzp+j6tzfojqfW6M7frkm6HAAAgEQxPKMI7n1+vU6ur9G8GeU/1dxA/vcH5mjT7n364kOvatqE0bp47tF7EgAAAHAkuNJcYCs2tejlDS266twZMju6PyCkImX616vO1hnHTtBfLnpRz761M+mSAAAAEkFoLrBFz/9OVZUpfXDe9KRLKYjRVRW647r5mjZhtK694zn98IXcbyAEAAAYygjNBbSvq0f3v7hR7z9tiiaMqUq6nIKpqxmpn958gc4/fpL+9ifL9OWHXlV3TzrpsgAAAEqGMc0F9NArm9TW2a2rz5uRdCmHLddUdNcsmKHxY0bou9efq79/eKXueGqNFq/dpb+/4nSdfuz4BKoEAAAoLa40F9CiF36n4ydX67xZtUmXUhSVFSndctlc/dvVZ2vj7g5ddutT+rufLdfuvV1JlwYAAFBUXGkukMVrd2nJumZ97v2nHvVvAOwr19Xnmy86Qet37dE9z67Tj5ds0FXnztANvzdL0yeMTqBCAACA4iI0F0B3T1qfv3+5pk8YrT85/+gdmjEYo6sq9MXLT9M1C2bqW0+8qbufWau7n1mr98yp1xVnT1fj7MkaWVmRdJkAAAAFQWgugLueWafXtrTpPz90jsZUDa+XdPaUsfrGH5+lT713tr771Br99MWN+vnyLRo/eoTeM6dejbMn6/dOnKzxY0YkXSoAAMBhG14Jrwi2tnboXx57XY2zJ+u9c+uTLicx0yeM1uc/MEczJ1Vr9bZ2vbxhtx5+ZbN+vGSDKlKm06eP13mzajV/5kSdPWOiJo8dmXTJAAAAeSM0H6G/f3ilunrS+uJlc4fcWOaQXGOdK1Km2VPGavaUsepJuzY079WIipSeW7NTd/52rW578i1JUl1NlU6dOk6nTh2nU6aM1alTx+mEyTWqquS9qQAAoPwQmo/APc+u04Mvb9In3nWSZk6qTrqcslORsgOvy5VnH6sPnDFNG5v3aePufdrS0qHmvV268+m16uqO5nxOmTRtwmg1TKrWjEljNLN2jGZOqtaM2jGaOn6UJowZMexOTAAAQHkgNB+mu55eqy88sELvPvUYfXzhiUmXc1QYUZFSQ121GuoOnmD0pF072ju1paVD29o6tHNPl9bu3KOl65u1t6un1/YjK1OaMn6U6seN0tT466TqKk2srtKk6irVxreJ1VWqqapUKkXABgAAhUFoPgzfeWqNvvTQq3rPnHrdes08hhQcgYqUqX5cFID72tfVo117urRrb5da9+1X6779aunYr22tHVq9rV2t+/arO+397ntMVYWqR1aqOvN1ZKVq4q/VVRUaNaJCVZUpVVWkoq997o+Mb1F7xSHrjKzs/bWqIqXKCvoCAABDUSKh2cwukfSvkiokfdvdv5ZEHYO1elu7vv7fr+nRV7fqvXPr9f+uJjAX0+iqCk2vGq3pE3PP/ezu6upJa09nj/Z0dmtvV7f2dPZo9pSxau/s1p7Obu2J217f2qbtbZ3a2LxPnd096uxOq7vH1Z1OqyftGiB7D4qZNCKVUkXKVFlhGlER3R+RMlVWpFRZYapMmSpTKY2osHi9zP1UvF60PPM1s96IipQ2b+rUM/tWakQqa18VqXifB/eVvX32eiNSB49ZmTKlzJRKRScvKcvcsh6nTBVx28H70TZmJpOUMlNm1EzmfnY7Q2oAAENByUOzmVVIulXSeyRtkPSCmT3g7q+WupZ8dPek9dLvdusnSzfqvxb/TqNHVOivLz5ZN110gkZwVTFRZqaRlRUaWVmh2uqqXsuqR/bu2ucfP2nAffWkXT3pKER3p109Pa7u9MFQvb8na/mB+30fp+P9SGl39bgrnfbofqbtwGPXvv1ppdNST1Z7Op3ZTn22d1WkTF37u/XrDWvV3ZMuWNAvhUyQtjiAm6KGzH2zOGRLcfuh4VuKt423SWUF8lRKqrD4hCDrpKWyz+OKVFZbvDzTV
pF9IhM/zl6eOdHIta/sx+rzfCxT84HHB9szzy2qMXpckcq6ZT1O2cFjpFKm9i5Xy779B2rKLE9xogIARZHElebzJK1297ckycwWSbpcUlmF5sVrd+n237ylp1fvVFtnt0ZUmK49f6b+4p0nalIN06UNNZlgUlXmnyw/asdr6qg7RVIUwtOZgJ0J1weCdnwi0Cd4ZwJ5Zlt39f6q6Ap+2iX36L67lFam7eC6UrSO4m0yGT4K85l96cB6Llf8L27vu078OFokl/fZPt7GD10nu+7Mc+3YH51YZJ5/9rJ0n9ch83plv6aygydTZetXj+ZsPjRwZ/p4FM4r4r8iZEJ7rpOUVOaEpO86B/7C0PskKHNy0Hv9Q9fNPvk5dJ99Tzgyjwdet3ed2dvZIX/1yDw+cAKWtW7v52I5jysdPHE7eP/QdvVqt0zTgb/K9N2HstbN5zjKal+xtVtdK7Ycuu2BUg5unHldD6lHlnVfWevnOn7vdQ/W37d94OOon/Z8n3evevp5zQ45Th6vr/o9fu5t1eu1PoLn3d+2nASXjSRC83RJv8t6vEHSggTqGFBrx34t29CiS8+YqgtPnqy3nzBJE8ZUhTcESiQznKLMc/5RL/uk4sAJSdbJSCaY92SdiEgHg7yU++QiV9g/eMLi6vHcyzLHrGjboq7q+qg97UpLBwJ/9n7SvfZxcPsDJyl+8OTF48I8R33d7spUn70s54mPZz33A+v3PmnyAU6A+t1Xf8cdYF8W15D9vRmSXlySdAUogVDA9rQr9djPB9hBHscYRB0D7ye80kD7OW7iGP3iry7Mo5rSKds3AprZjZJujB+2m9mqJOp4RtI/JnHg3uok7Ui6CJQF+gIy6AvIoC8gY8j0hZWS7JOJHX5mrsYkQvNGScdlPT42buvF3W+TdFupiipnZrbY3ecnXQeSR19ABn0BGfQFZNAXiiuJP+y+IOkkM5tlZlWSrpL0QAJ1AAAAAHkp+ZVmd+82sz+X9AtFU859x91XlLoOAAAAIF+JjGl290ckPZLEsY9SDFNBBn0BGfQFZNAXkEFfKCLzIf12YgAAAODIMVkVAAAAEEBoLhNmdomZrTKz1Wb2mRzLR5rZD+Plz5lZQwJlogTy6AvXm9l2M3spvn00iTpRfGb2HTPbZmbL+1luZvZvcV95xczmlbpGlEYefaHRzFqyfi/8XalrRGmY2XFm9msze9XMVpjZJ3Ksw++GIiA0l4GsjxZ/n6Q5kq42szl9VrtBUrO7nyjpX1QW00ej0PLsC5L0Q3c/K759u6RFopTulHTJAMvfJ+mk+HajpP8oQU1Ixp0auC9I0m+yfi98qQQ1IRndkj7l7nMknS/p4zn+n+B3QxEQmsvDgY8Wd/cuSZmPFs92uaS74vs/lvQu47M1h6J8+gKGCXd/UtKuAVa5XNLdHnlW0gQzm1qa6lBKefQFDBPuvtndl8b32xR9Dsj0Pqvxu6EICM3lIddHi/f9ATiwjrt3S2qRNKkk1aGU8ukLkvQH8Z/cfmxmx+VYjuEh3/6C4eFtZvaymf3czOYmXQyKLx6qebak5/os4ndDERCagaPPg5Ia3P0MSY/p4F8gAAxfSyXNdPczJf0/SfcnWw6KzcxqJP1E0v9y99ak6xkOCM3lIZ+PFj+wjplVShovaWdJqkMpBfuCu+9098744bclnVOi2lB+8vndgWHA3VvdvT2+/4ikEWZWl3BZKBIzG6EoMH/f3e/LsQq/G4qA0Fwe8vlo8QckXRff/0NJv3Im2R6Kgn2hz7i0yxSNZ8Pw9ICkP43fKX++pBZ335x0USg9M5uSeZ+LmZ2n6P93LqwMQfH3+Q5JK939G/2sxu+GIkjkEwHRW38fLW5mX5K02N0fUPQDco+ZrVb0ZpCrkqsYxZJnX/hLM7tM0Tuod0m6PrGCUVRmdq+kRkl1ZrZB0hckjZAkd/9PRZ+s+n5JqyXtlfThZCpFseXRF/5Q0s1m1i1pn6SruLAyZF0g6VpJy8zsp
bjts5JmSPxuKCY+ERAAAAAIYHgGAAAAEEBoBgAAAAIIzQAAAEAAoRkAAAAIIDQDAAAAAYRmAAAAIIDQDAAAAAQQmgEAAICA/w+WjWpqqFYqgwAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "fig = plt.figure(figsize=(12,8))\n", + "sns.distplot(all_entropy_of_expected, bins=100)\n", + "plt.title('Distribution of total uncertainty', fontsize=18)\n", + "plt.grid()" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [], + "source": [ + "total_uncertainty_order_error_cumsum = np.cumsum(errors[all_entropy_of_expected.argsort()])" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "fig = plt.figure(figsize=(12,8))\n", + "plt.plot(total_uncertainty_order_error_cumsum)\n", + "plt.xlabel('Example_id', fontsize=18)\n", + "plt.ylabel('Errors sum', fontsize=18)\n", + "plt.title(\"Change of total error if we go from low to high total uncertainty\", fontsize=18)\n", + "plt.grid()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Knowledge uncertainty" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "OOD detection" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([200, 1, 28, 28])" + ] + }, + "execution_count": 31, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "images.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": {}, + "outputs": [], + "source": [ + "ood_loader = torch.utils.data.DataLoader(\n", + " torchvision.datasets.FashionMNIST('./data/', train=False, \n", + " transform=torchvision.transforms.Compose([\n", + " torchvision.transforms.ToTensor(),\n", + " ]), download=True), \n", + " batch_size=batch_size_test, shuffle=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + " 9 2 1 1\n" + ] + } + ], + "source": [ + "dataiter = iter(ood_loader)\n", + "images, labels = dataiter.next()\n", + "\n", + "# show images\n", + "imshow(torchvision.utils.make_grid(images[:4]))\n", + "# print labels\n", + "print(' '.join('%5s' % classes[labels[j]] for j in range(4)))" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": {}, + "outputs": [], + "source": [ + "f_all_expected_probability = []\n", + "f_all_entropy_of_expected = []\n", + "f_all_expected_entropy = []\n", + "f_all_knowledge_uncertainty = []\n", + "f_all_labels = []\n", + "with torch.no_grad():\n", + " for data in ood_loader:\n", + " images, labels = data[0].to('cuda'), data[1].to('cuda')\n", + " expected_probability, entropy_of_expected, expected_entropy = predict_uncertainties(ensemble, images)\n", + " knowledge_uncertainty = entropy_of_expected - expected_entropy\n", + "\n", + " f_all_expected_probability.append(expected_probability.cpu().numpy())\n", + " f_all_entropy_of_expected.append(entropy_of_expected.cpu().numpy())\n", + " f_all_expected_entropy.append(expected_entropy.cpu().numpy())\n", + " f_all_knowledge_uncertainty.append(knowledge_uncertainty.cpu().numpy())\n", + " f_all_labels.append(labels.cpu().numpy())" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": {}, + "outputs": [], + "source": [ + "f_all_expected_probability = np.vstack(f_all_expected_probability)\n", + "f_all_entropy_of_expected = np.hstack(f_all_entropy_of_expected)\n", + "f_all_expected_entropy = np.hstack(f_all_expected_entropy)\n", + "f_all_knowledge_uncertainty = np.hstack(f_all_knowledge_uncertainty)\n", + "f_all_labels = np.hstack(f_all_labels)" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "metadata": {}, + "outputs": [], + "source": [ + "f_all_predictions = 
np.argmax(f_all_expected_probability, axis=1)\n", + "\n", + "errors = (f_all_predictions != f_all_labels).astype(int)" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Accuracy: 0.0517\n" + ] + } + ], + "source": [ + "print(f'Accuracy: {(1 - errors).sum() / errors.shape[0]}')" + ] + }, + { + "cell_type": "code", + "execution_count": 186, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/iv-provilkov/miniconda3/envs/unc1/lib/python3.7/site-packages/seaborn/distributions.py:2551: FutureWarning: `distplot` is a deprecated function and will be removed in a future version. Please adapt your code to use either `displot` (a figure-level function with similar flexibility) or `histplot` (an axes-level function for histograms).\n", + " warnings.warn(msg, FutureWarning)\n", + "/home/iv-provilkov/miniconda3/envs/unc1/lib/python3.7/site-packages/seaborn/distributions.py:2551: FutureWarning: `distplot` is a deprecated function and will be removed in a future version. Please adapt your code to use either `displot` (a figure-level function with similar flexibility) or `histplot` (an axes-level function for histograms).\n", + " warnings.warn(msg, FutureWarning)\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 186, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "fig = plt.figure(figsize=(12,8))\n", + "sns.distplot(all_expected_probability.flatten(), label='MNIST', bins=40)\n", + "sns.distplot(f_all_expected_probability.flatten(), label='Fashion MNIST', bins=40)\n", + "plt.title(\"Distributions of mean likelihood in ensemble for all classes in two datasets\", fontsize=18)\n", + "plt.legend(fontsize=18)" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/iv-provilkov/miniconda3/envs/unc1/lib/python3.7/site-packages/seaborn/distributions.py:2551: FutureWarning: `distplot` is a deprecated function and will be removed in a future version. Please adapt your code to use either `displot` (a figure-level function with similar flexibility) or `histplot` (an axes-level function for histograms).\n", + " warnings.warn(msg, FutureWarning)\n", + "/home/iv-provilkov/miniconda3/envs/unc1/lib/python3.7/site-packages/seaborn/distributions.py:2551: FutureWarning: `distplot` is a deprecated function and will be removed in a future version. Please adapt your code to use either `displot` (a figure-level function with similar flexibility) or `histplot` (an axes-level function for histograms).\n", + " warnings.warn(msg, FutureWarning)\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 41, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "fig = plt.figure(figsize=(12,8))\n", + "sns.distplot(all_knowledge_uncertainty, label='MNIST', bins=40)\n", + "sns.distplot(f_all_knowledge_uncertainty, label='Fashion MNIST', bins=40)\n", + "plt.title(\"Distributions of knowledge uncertainty for two datasets\", fontsize=18)\n", + "plt.legend(fontsize=18)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### OOD Detection\n", + "\n", + "**We will use fashion-MNIST as OOD data**" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn.metrics import roc_auc_score" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": {}, + "outputs": [], + "source": [ + "y_true = np.zeros(20000)\n", + "y_true[10000:] = 1\n", + "knowledge_scores = np.hstack([all_knowledge_uncertainty, f_all_knowledge_uncertainty])\n", + "data_scores = np.hstack([all_expected_entropy, f_all_expected_entropy])\n", + "total_scores = np.hstack([all_entropy_of_expected, f_all_entropy_of_expected])" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "ROC-AUC score with knowledge uncertainty: 0.864588025\n", + "ROC-AUC score with data uncertainty: 0.99998885\n", + "ROC-AUC score with total uncertainty: 0.99997749\n" + ] + } + ], + "source": [ + "print(f'ROC-AUC score with knowledge uncertainty: {roc_auc_score(y_true, knowledge_scores)}')\n", + "print(f'ROC-AUC score with data uncertainty: {roc_auc_score(y_true, data_scores)}')\n", + "print(f'ROC-AUC score with total uncertainty: {roc_auc_score(y_true, total_scores)}')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Calibration" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**It is usually useful to check the 
match of the real distribution of classes with the distribution of model probabilities.**" + ] + }, + { + "cell_type": "code", + "execution_count": 165, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/iv-provilkov/miniconda3/envs/unc1/lib/python3.7/site-packages/seaborn/distributions.py:2551: FutureWarning: `distplot` is a deprecated function and will be removed in a future version. Please adapt your code to use either `displot` (a figure-level function with similar flexibility) or `histplot` (an axes-level function for histograms).\n", + " warnings.warn(msg, FutureWarning)\n", + "No handles with labels found to put in legend.\n" + ] + }, + { + "data": { + "text/plain": [ + "Text(0.5, 1.0, 'Calibration of our model')" + ] + }, + "execution_count": 165, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "fig = plt.figure(figsize=(12,8))\n", + "sns.distplot(all_labels, bins=10)\n", + "plt.plot(np.arange(0, 10), all_expected_probability.mean(axis=0), color='red')\n", + "plt.legend()\n", + "plt.title(\"Calibration of our model\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Here our calibration is good--model distribution over classes is very similar to the real distribution of classes.**\n", + "\n", + "**In case of bad calibration -- if your model distribution over classes is very different from real distribution, you should try to fix this. For classification tasks it could be fixed with temperature techniques**" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Deep ensembles for Regression" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**For regression tasks you should consider other distributions, but in general techniques are the same**" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**We still have the same formulas for uncertainties:**\n", + "\n", + "#### Through ensemble we can calculate *Total uncertainty*:\n", + "#### $$H[\\mathbb{E}_{\\theta \\sim P(\\theta|D)} [P(y|x^*, \\theta)]] $$\n", + "#### and *Expected data uncertainty*:\n", + "#### $$\\mathbb{E}_{\\theta \\sim P(\\theta|D)} [H[P(y|x^*, \\theta)]] $$\n", + "#### *Knowledge uncertainty* is the difference between *Total uncertainty* and *Expected Data uncertainty*:\n", + "#### $$\\mathcal{I}(y, \\theta| x^*, D) = H[\\mathbb{E}_{\\theta \\sim P(\\theta|D)} [P(y|x^*, \\theta)]] - \\mathbb{E}_{\\theta \\sim P(\\theta|D)} [H[P(y|x^*, \\theta)]]$$" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Gaussian ensemble for different uncertainty cases:\n", + "" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Here we 
will use a simple toy dataset to illustrate the concepts.** \n", + "**For ease of understanding, we will import models and functions from the source files that you can find in the repository. You can examine them in more detail later. These classes are quite flexible and well thought out, so you can reuse them for other tasks.**" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "metadata": {}, + "outputs": [], + "source": [ + "from src.data.toy_loading import get_toy_dataset, get_arrays_from_loader\n", + "from torch.utils import data\n", + "from torch.distributions import Normal\n", + "from torch.optim import SGD, Adam\n", + "import numpy as np" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "metadata": {}, + "outputs": [], + "source": [ + "import random" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "metadata": {}, + "outputs": [], + "source": [ + "targets_fn = lambda x: np.sin(x) + x / 10.0\n", + "noise_fn = lambda x: 1.0 / (1.0 + np.abs(x)) + 0.1\n" + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train size: 2048 16\n", + "OOD size: 512 4\n", + "Test size: 512 1\n" + ] + } + ], + "source": [ + "train_data, test_data, ood_full_data, y_noise = get_toy_dataset(\n", + " targets_fn, noise_fn, train_limits=(-10, 10), ood_abs_limits=(20, 25),\n", + " test_limits=(-15, 15),\n", + " train_samples=2048, test_samples=512, ood_samples=512,\n", + " random_state=12\n", + ")\n", + "\n", + "trainloader = data.DataLoader(train_data, batch_size=128, shuffle=True)\n", + "testloader = data.DataLoader(test_data, batch_size=512)\n", + "oodloader = data.DataLoader(\n", + " data.TensorDataset(ood_full_data.tensors[0]),\n", + " batch_size=128, shuffle=True\n", + ")\n", + "\n", + "print(\"Train size:\", len(trainloader) * trainloader.batch_size, len(trainloader))\n", + "print(\"OOD size:\", len(oodloader) * oodloader.batch_size, len(oodloader))\n", + 
"print(\"Test size:\", len(testloader) * testloader.batch_size, len(testloader))" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 49, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "X_train, y_train = get_arrays_from_loader(trainloader)\n", + "X_ood, y_ood = ood_full_data.tensors[0], ood_full_data.tensors[1]\n", + "X_test, y_test = get_arrays_from_loader(testloader)\n", + "test_ord = X_test.argsort(0)\n", + "X_test, y_test = X_test[test_ord].squeeze(), y_test[test_ord].squeeze()\n", + "\n", + "plt.figure(figsize=(10,10))\n", + "plt.title(\"Data\")\n", + "plt.rc('font', size=30)\n", + "plt.plot(\n", + " X_test, y_test,\n", + " color=sns.color_palette()[0],\n", + " linewidth=3, label=\"Test data\"\n", + ")\n", + "plt.fill_between(\n", + " X_test,\n", + " y_test - noise_fn(X_test), \n", + " y_test + noise_fn(X_test), \n", + " color=sns.color_palette()[0], alpha=0.1\n", + ")\n", + "plt.fill_between(\n", + " X_test,\n", + " y_test - 2 * noise_fn(X_test), \n", + " y_test + 2 * noise_fn(X_test), \n", + " color=sns.color_palette()[0], alpha=0.1\n", + ")\n", + "plt.scatter(X_ood, y_ood, color='g',label='ood points')\n", + "plt.grid()\n", + "plt.legend(loc='lower left', fontsize=15)" + ] + }, + { + "cell_type": "code", + "execution_count": 50, + "metadata": {}, + "outputs": [], + "source": [ + "from src.models.simple_model import SimpleModel\n", + "from src.training.trainers import DistributionMLETrainer, DistributionEnsembleMLETrainer\n", + "from src.distributions.mixture_distribution import GaussianDiagonalMixture" + ] + }, + { + "cell_type": "code", + "execution_count": 51, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0 train loss 1.538 eval loss 1.366 eval params_rmse 1.008\n", + "Epoch 100 train loss 0.806 eval loss 1.282 eval params_rmse 1.057\n", + "Epoch 149 train loss 0.723 eval loss 1.212 eval params_rmse 1.401\n", + "Trained Gaussian, final rmse 1.401\n" + ] + } + ], + "source": [ + "model_params = {\n", + " \"input_dim\": 1, \"output_dim\": 1,\n", + " 
\"num_units\": 30, \"num_hidden\": 2\n", + "}\n", + "optim_params = {\"lr\": 1e-3, \"weight_decay\": 1e-4}\n", + "NUM_EPOCHS = 150\n", + "LOG_PER = 100\n", + "\n", + "singletrainer = DistributionMLETrainer(\n", + " model_params,\n", + " SimpleModel, optim_params, distribution=Normal,\n", + " optimizer=Adam\n", + ")\n", + "\n", + "strain_hist, stest_hist, smetrics_hist = singletrainer.train(\n", + " trainloader, NUM_EPOCHS, testloader, \n", + " log_per=LOG_PER, verbose=True\n", + ")\n", + "print(\"Trained Gaussian, final rmse %.3f\" % smetrics_hist[-1][0])" + ] + }, + { + "cell_type": "code", + "execution_count": 52, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "model_preds = singletrainer.get_predicted_params(testloader)\n", + "mean_preds, std_preds = model_preds[0][test_ord].squeeze(), model_preds[1][test_ord].squeeze()\n", + "\n", + "fig=plt.figure(figsize=(12,8))\n", + "plt.plot(X_test, y_test, label=\"Real y\")\n", + "plt.plot(X_test, mean_preds, label=\"Mean prediction\")\n", + "plt.fill_between(X_test, mean_preds - std_preds, mean_preds + std_preds, alpha=0.1)\n", + "plt.legend()\n", + "plt.grid()\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 53, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[]" + ] + }, + "execution_count": 53, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "variances = singletrainer.get_predicted_params(testloader)[1]\n", + "\n", + "fig=plt.figure(figsize=(12,8))\n", + "plt.title(\"Predicted Variance for Normal\")\n", + "plt.plot(X_test, variances, color='teal', label='variances')" + ] + }, + { + "cell_type": "code", + "execution_count": 54, + "metadata": {}, + "outputs": [], + "source": [ + "ensembletrainer = DistributionEnsembleMLETrainer(\n", + " 10, GaussianDiagonalMixture,\n", + " model_params,\n", + " SimpleModel, optim_params, distribution=Normal,\n", + " optimizer=Adam\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 55, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--------------------\n", + "Model 0\n", + "Epoch 0 train loss 1.588 eval loss 1.442 eval params_rmse 1.038\n", + "Epoch 100 train loss 0.898 eval loss 1.286 eval params_rmse 0.778\n", + "Epoch 149 train loss 0.773 eval loss 1.308 eval params_rmse 1.221\n", + "--------------------\n", + "Model 1\n", + "Epoch 0 train loss 1.584 eval loss 1.404 eval params_rmse 1.042\n", + "Epoch 100 train loss 0.811 eval loss 1.149 eval params_rmse 1.209\n", + "Epoch 149 train loss 0.770 eval loss 1.267 eval params_rmse 1.409\n", + "--------------------\n", + "Model 2\n", + "Epoch 0 train loss 1.479 eval loss 1.245 eval params_rmse 0.866\n", + "Epoch 100 train loss 0.771 eval loss 1.066 eval params_rmse 0.754\n", + "Epoch 149 train loss 0.677 eval loss 1.039 eval params_rmse 0.957\n", + "--------------------\n", + "Model 3\n", + "Epoch 0 train loss 1.511 eval loss 1.260 eval params_rmse 0.849\n", + "Epoch 100 train loss 0.887 eval loss 1.194 eval params_rmse 1.015\n", + "Epoch 149 train loss 0.809 eval loss 1.181 eval params_rmse 1.307\n", + "--------------------\n", + "Model 4\n", + "Epoch 0 train loss 1.550 eval loss 1.413 eval params_rmse 1.045\n", + "Epoch 100 train loss 
0.752 eval loss 1.257 eval params_rmse 0.976\n", + "Epoch 149 train loss 0.684 eval loss 1.064 eval params_rmse 1.175\n", + "--------------------\n", + "Model 5\n", + "Epoch 0 train loss 1.479 eval loss 1.256 eval params_rmse 0.884\n", + "Epoch 100 train loss 0.769 eval loss 1.211 eval params_rmse 1.011\n", + "Epoch 149 train loss 0.705 eval loss 1.134 eval params_rmse 1.329\n", + "--------------------\n", + "Model 6\n", + "Epoch 0 train loss 1.491 eval loss 1.273 eval params_rmse 0.882\n", + "Epoch 100 train loss 0.879 eval loss 1.249 eval params_rmse 0.821\n", + "Epoch 149 train loss 0.734 eval loss 1.213 eval params_rmse 1.021\n", + "--------------------\n", + "Model 7\n", + "Epoch 0 train loss 1.539 eval loss 1.331 eval params_rmse 0.952\n", + "Epoch 100 train loss 0.787 eval loss 1.273 eval params_rmse 1.258\n", + "Epoch 149 train loss 0.774 eval loss 1.254 eval params_rmse 1.334\n", + "--------------------\n", + "Model 8\n", + "Epoch 0 train loss 1.563 eval loss 1.382 eval params_rmse 1.001\n", + "Epoch 100 train loss 0.761 eval loss 1.239 eval params_rmse 1.104\n", + "Epoch 149 train loss 0.692 eval loss 1.095 eval params_rmse 1.324\n", + "--------------------\n", + "Model 9\n", + "Epoch 0 train loss 1.584 eval loss 1.484 eval params_rmse 0.999\n", + "Epoch 100 train loss 0.720 eval loss 1.098 eval params_rmse 0.843\n", + "Epoch 149 train loss 0.638 eval loss 0.970 eval params_rmse 0.959\n", + "Trained Ensemble, final rmse 0.959\n" + ] + } + ], + "source": [ + "entrain_hist, enval_hist, enmetrics_hist = ensembletrainer.train(\n", + " trainloader, NUM_EPOCHS, testloader, \n", + " log_per=LOG_PER, verbose=True\n", + ")\n", + "print(\"Trained Ensemble, final rmse %.3f\" % enmetrics_hist[-1][-1][0])" + ] + }, + { + "cell_type": "code", + "execution_count": 56, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "ev_scores = ensembletrainer.eval_uncertainty(\n", + " testloader, \"expected_variance\"\n", + ")[test_ord].squeeze()\n", + "\n", + "fig = plt.figure(figsize=(12,8))\n", + "plt.title(\"Expected variance plot\")\n", + "plt.plot(X_test, ev_scores, color='b', label='expected variance (data uncertainty)')\n", + "#%%\n", + "# Show estimated knowledge uncertainty\n", + "voe_scores = ensembletrainer.eval_uncertainty(\n", + " testloader, \"variance_of_expected\"\n", + ")[test_ord].squeeze()\n", + "\n", + "ood_voe_scores = ensembletrainer.eval_uncertainty(\n", + " oodloader, \"variance_of_expected\")\n", + "plt.plot(X_test, voe_scores, label='variance of expected (knowledge uncertainty)', color='r')\n", + "plt.scatter(X_ood, ood_voe_scores, label='ood variance of expected (knowledge uncertainty)', color='g')\n", + "plt.grid()\n", + "plt.legend(fontsize=15)\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Ensemble distribution distillation into Prior Network\n", + "\n", + "**We can distill an ensemble into a single model to reduce computation time during inference while preserving prediction quality.**\n", + "\n", + "**If we use ordinary knowledge distillation (an ensemble of normal models distilled into another normal model), we lose the ability to calculate uncertainty measures. Ensemble distribution distillation solves this problem: it allows us to distill an ensemble into a Prior Network while preserving the ability to calculate uncertainty measures.**\n", + "\n", + "**A Prior Network is a distribution over distributions. 
For example, we can use a Normal-Wishart distribution in the Prior Network for an ensemble of Normal distributions, because the parameters of a Normal distribution can be sampled from a Normal-Wishart distribution.**" + ] + }, + { + "cell_type": "code", + "execution_count": 59, + "metadata": {}, + "outputs": [], + "source": [ + "from src.training.ood_trainers import DistributionEnsembleToPriorDistiller\n", + "from src.distributions.prior_distribution import NormalWishartPrior" + ] + }, + { + "cell_type": "code", + "execution_count": 60, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0 train loss 608.501 eval loss 2305.013 eval params_rmse 1.077\n", + "Epoch 100 train loss 156.554 eval loss 788.144 eval params_rmse 1.065\n", + "Epoch 149 train loss 135.326 eval loss 676.144 eval params_rmse 1.055\n", + "Distilled NWPrior, final rmse 1.055\n" + ] + } + ], + "source": [ + "loss_params = {\n", + " \"max_temperature\": 1.0,\n", + " \"noise_level\": 3.0\n", + "}\n", + "model_params[\"isPrior\"] = True\n", + "prior_distiller = DistributionEnsembleToPriorDistiller(\n", + " [ensembletrainer.trainers[i].model for i in range(10)],\n", + " loss_params, model_params, SimpleModel,\n", + " optim_params, distribution=NormalWishartPrior, optimizer=Adam\n", + ")\n", + "ptrain_hist, pval_hist, pmetrics_hist = prior_distiller.train(\n", + " trainloader, NUM_EPOCHS, testloader, \n", + " log_per=LOG_PER, verbose=True\n", + ")\n", + "print(\"Distilled NWPrior, final rmse %.3f\" % pmetrics_hist[-1][0])" + ] + }, + { + "cell_type": "code", + "execution_count": 64, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "model_preds = prior_distiller.get_predicted_params(testloader)\n", + "mean_preds, std_preds = model_preds[0][test_ord].squeeze(), model_preds[1][test_ord].squeeze()\n", + "\n", + "fig = plt.figure(figsize=(12,8))\n", + "plt.plot(X_test, y_test, label=\"real answers\")\n", + "plt.plot(X_test, mean_preds, label='predictions')\n", + "plt.fill_between(X_test, mean_preds - std_preds, mean_preds + std_preds, alpha=0.1)\n", + "plt.grid()\n", + "plt.legend(fontsize=15)\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 65, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 65, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "# Show data uncertainty\n", + "ev_scores = prior_distiller.eval_uncertainty(\n", + " testloader, \"expected_variance\"\n", + ")[test_ord].squeeze()\n", + "\n", + "fig = plt.figure(figsize=(12,8))\n", + "plt.title(\"Expected variance plot\")\n", + "plt.plot(X_test, ev_scores, color='teal', label='expected variance (data uncertainty)')\n", + "plt.plot(X_test, noise_fn(X_test).pow(2), color='navy', label='noise')\n", + "plt.xlim(-15, 15)\n", + "plt.ylim(0, 1.2)\n", + "plt.grid()\n", + "plt.legend(fontsize=15)" + ] + }, + { + "cell_type": "code", + "execution_count": 66, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "#%%\n", + "# Show knowledge uncertainty\n", + "\n", + "fig = plt.figure(figsize=(12,8))\n", + "voe_scores = prior_distiller.eval_uncertainty(\n", + " testloader, \"variance_of_expected\"\n", + ")[test_ord].squeeze()\n", + "ood_voe_scores = prior_distiller.eval_uncertainty(\n", + " oodloader, \"variance_of_expected\"\n", + ")[test_ord].squeeze()\n", + "plt.plot(X_test, voe_scores, label=\"knowledge uncertainty test\")\n", + "plt.scatter(X_ood, ood_voe_scores, label=\"knowledge uncertainty ood\")\n", + "plt.scatter(X_ood, y_ood, color='g',label='ood points')\n", + "#plt.ylim(0, 2.0)\n", + "plt.grid()\n", + "plt.legend(fontsize=15)\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**This notebook:https://github.com/VProv/uncertainty_example**" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Find more interesting videos:https://ods.ai/tracks/uncertainty-estimation-in-ml-df2020**" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Regression Prior Networks: https://github.com/JanRocketMan/regression-prior-networks**" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.9" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}