forked from EIDOSLAB/entangling-disentangling-bias
-
Notifications
You must be signed in to change notification settings - Fork 0
/
EnD.py
113 lines (88 loc) · 3.84 KB
/
EnD.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import torch
import torch.nn.functional as F
import numpy as np
class pattern_norm(torch.nn.Module):
    """Rescale every sample of a batch to unit L2 norm.

    For inputs with more than two dimensions, each sample (everything after
    the first dimension) is flattened, divided by its L2 norm and reshaped
    back to the original size. Inputs with one or two dimensions are returned
    unchanged.

    Args:
        scale: Stored on the module as ``self.scale``; ``forward`` does not
            read it.
    """

    def __init__(self, scale=1.0):
        super(pattern_norm, self).__init__()
        self.scale = scale

    def forward(self, input):
        """Return ``input`` with each sample normalised to unit L2 norm.

        Args:
            input: Tensor whose first dimension indexes the samples.

        Returns:
            A tensor of the same size as ``input``.
        """
        sizes = input.size()
        # NOTE(review): the nesting of this branch was reconstructed from a
        # copy without indentation; as written, tensors with <= 2 dims are
        # passed through without normalisation - confirm this is intended.
        if len(sizes) > 2:
            # flatten/reshape (unlike view) also accept non-contiguous
            # tensors, e.g. activations that went through transpose/permute.
            flat = input.flatten(start_dim=1)
            flat = torch.nn.functional.normalize(flat, p=2, dim=1, eps=1e-12)
            input = flat.reshape(sizes)
        return input
class Hook():
    """Record the arguments of the latest hook call on a module.

    After the hooked module runs, ``self.input`` and ``self.output`` hold the
    two values PyTorch passed to the hook (for a forward hook: the positional
    inputs tuple and the module output). Both are ``None`` until the hook has
    fired at least once.

    The object can also be used as a context manager, in which case the hook
    is removed on exit.

    Args:
        module: The ``torch.nn.Module`` to attach the hook to.
        backward: If truthy, register a backward hook instead of a forward
            hook.
    """

    def __init__(self, module, backward=False):
        self.module = module
        # Defined up front so that reading them before the first hook call
        # yields None instead of raising AttributeError.
        self.input = None
        self.output = None
        if backward:
            # NOTE(review): PyTorch documents register_backward_hook as
            # deprecated in favour of register_full_backward_hook; kept
            # as-is here so the captured values do not change.
            self.hook = module.register_backward_hook(self.hook_fn)
        else:
            self.hook = module.register_forward_hook(self.hook_fn)

    def hook_fn(self, module, input, output):
        """Store the values handed to the hook; returns None so they are not altered."""
        self.input = input
        self.output = output

    def close(self):
        """Detach the hook from the module."""
        self.hook.remove()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False
# For each discriminatory class, orthogonalize samples
def abs_orthogonal_blind(output, gram, target_labels, bias_labels):
    """Mean absolute Gram sum over sample pairs that share a bias label.

    For every distinct value in ``bias_labels`` the strictly-lower-triangular
    entries of ``gram`` belonging to pairs of samples with that value are
    summed, and the absolute value of that sum is accumulated. The total is
    divided by the number of such pairs across all values (when non-zero).

    Args:
        output: Tensor used only to pick the device of the returned loss.
        gram: Square matrix of pairwise products between samples; it is
            multiplied element-wise with a strictly-lower-triangular mask.
        target_labels: Unused.
        bias_labels: Per-sample bias labels; a 1-D tensor is treated as a
            column vector.

    Returns:
        A scalar tensor on ``output.device``; zero if no pair shares a label.
    """
    loss = torch.tensor(0.).to(output.device)
    pair_total = 0.

    for label in torch.unique(bias_labels):
        members = (bias_labels == label).type(torch.float)
        if members.dim() < 2:
            members = members.unsqueeze(dim=1)

        # Outer product marks every (i, j) with both samples in this class;
        # tril(-1) keeps each unordered pair once and drops the diagonal.
        same_class = torch.mm(members, members.transpose(0, 1))
        pair_mask = torch.tril(same_class, diagonal=-1)

        pair_count = pair_mask.sum()
        pair_total += pair_count
        if pair_count > 0:
            loss += torch.abs(torch.sum(gram * pair_mask))

    if pair_total > 0:
        loss /= pair_total
    return loss
# For each target class, parallelize samples belonging to
# different discriminatory classes
def abs_parallel(gram, target_labels, bias_labels):
    """Penalise misalignment between same-target samples of different bias.

    For every target class and every unordered pair of distinct bias classes,
    all sample pairs (one sample from each bias class, both in the target
    class) contribute ``(1 + gram) / 2``. The result is
    ``1 - mean(contribution)``, so it is 0 when every such pair has a Gram
    value of 1 and 1 when every pair has a Gram value of -1.

    Each cross-bias pair is counted exactly once regardless of the order of
    the samples in the batch: the pair mask is symmetrised before taking its
    strictly-lower-triangular part. Without that step, pairs in which the
    sample of the first bias class has the smaller batch index would be
    dropped, making the loss depend on batch ordering.

    Args:
        gram: Square matrix of pairwise products between samples; only
            entries strictly below the diagonal are read.
        target_labels: Per-sample target labels; a 1-D tensor is treated as a
            column vector.
        bias_labels: Per-sample bias labels; a 1-D tensor is treated as a
            column vector.

    Returns:
        A scalar tensor on ``gram.device``; zero if no eligible pair exists.
    """
    def _column_mask(labels, value):
        # Float indicator of `labels == value`, shaped as a column if 1-D.
        mask = (labels == value).type(torch.float)
        if mask.dim() < 2:
            mask = mask.unsqueeze(dim=1)
        return mask

    # Bias masks do not depend on the target class: build them once.
    bias_masks = [_column_mask(bias_labels, b) for b in torch.unique(bias_labels)]

    parallel_loss = torch.tensor(0.).to(gram.device)
    M_tot = 0.

    for target_class in torch.unique(target_labels):
        class_mask = _column_mask(target_labels, target_class)
        groups = [class_mask * bias_mask for bias_mask in bias_masks]

        for idx, group in enumerate(groups):
            for other_group in groups[idx + 1:]:
                cross = torch.mm(group, torch.transpose(other_group, 0, 1))
                # Symmetrise so a pair is kept whichever sample comes first
                # in the batch; tril(-1) then keeps it exactly once.
                mask = torch.tril(cross + torch.transpose(cross, 0, 1), diagonal=-1)
                M = mask.sum()
                M_tot += M

                if M > 0:
                    parallel_loss -= torch.sum((1.0 + gram) * mask * 0.5)

    if M_tot > 0:
        parallel_loss = 1.0 + (parallel_loss / M_tot)
    return parallel_loss
def abs_regu(hook, target_labels, bias_labels, alpha=1.0, beta=1.0, sum=True):
    """Compute the weighted orthogonality / parallelism regularisation terms.

    The representation stored in ``hook.output`` is flattened to one row per
    sample, its strictly-lower-triangular Gram matrix is built, and the two
    terms are evaluated on it with ``abs_orthogonal_blind`` and
    ``abs_parallel``.

    Args:
        hook: Object exposing the captured representation as ``hook.output``.
        target_labels: Per-sample target labels.
        bias_labels: Per-sample bias labels.
        alpha: Weight of the orthogonality term; the term is not computed
            when this is 0.
        beta: Weight of the parallelism term; the term is not computed when
            this is 0.
        sum: If true return the weighted sum, otherwise return the two
            weighted terms as a tuple.

    Returns:
        ``alpha * R_ortho + beta * R_parallel`` or the tuple
        ``(alpha * R_ortho, beta * R_parallel)``.
    """
    D = hook.output
    if D.dim() > 2:
        # flatten (unlike view) also works on non-contiguous activations.
        D = D.flatten(start_dim=1)

    gram_matrix = torch.tril(torch.mm(D, torch.transpose(D, 0, 1)), diagonal=-1)
    # Safety clamp; presumably the rows of D are unit-norm so values should
    # already lie in [-1, 1] up to rounding - confirm against the model.
    gram_matrix = torch.clamp(gram_matrix, -1, 1.)

    zero = torch.tensor(0.).to(target_labels.device)
    if alpha != 0:
        R_ortho = abs_orthogonal_blind(D, gram_matrix, target_labels, bias_labels)
    else:
        R_ortho = zero
    if beta != 0:
        R_parallel = abs_parallel(gram_matrix, target_labels, bias_labels)
    else:
        R_parallel = zero

    if sum:
        return alpha * R_ortho + beta * R_parallel
    return alpha * R_ortho, beta * R_parallel