1 change: 1 addition & 0 deletions README.md
@@ -43,6 +43,7 @@ The [currently available algorithms](./subpopbench/learning/algorithms.py) are:
* Label-Distribution-Aware Margin Loss (**LDAM**, [Cao et al., 2019](https://arxiv.org/abs/1906.07413))
* Balanced Softmax (**BSoftmax**, [Ren et al., 2020](https://arxiv.org/abs/2007.10740))
* Classifier Re-Training (**CRT**, [Kang et al., 2020](https://arxiv.org/abs/1910.09217))
* Uniform Risk Minimization (**URM**, [Krishnamachari et al., 2024](https://openreview.net/forum?id=PgLbS5yp8n))

Send us a PR to add your algorithm! Our implementations use the hyper-parameter grids [described here](./subpopbench/hparams_registry.py).
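For reference, the registry declares each algorithm's defaults and random-search distributions through a `_hparam(name, default_val, random_val_fn)` helper (visible in the diff below). A minimal sketch of what a new entry might look like, with a hypothetical algorithm name and hyper-parameter:

    elif algorithm == 'MyAlgorithm':
        _hparam('my_reg', 0.1, lambda r: 10 ** r.uniform(-2, 0))  # default value, plus a log-uniform random-search draw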

17 changes: 17 additions & 0 deletions subpopbench/hparams_registry.py
@@ -89,6 +89,23 @@ def _hparam(name, default_val, random_val_fn):
        _hparam('stage1_model', 'model.pkl', lambda r: 'model.pkl')
        _hparam('dfr_reg', .1, lambda r: 10**r.uniform(-2, 0.5))

    elif algorithm == 'URM':
        _hparam('urm_lambda', 0.1, lambda r: float(r.uniform(0, 0.2)))

        _hparam('urm_discriminator_hidden_layers', 1, lambda r: int(r.choice([1, 2, 3])))
        _hparam('urm_generator_output', 'tanh', lambda r: str(r.choice(['tanh', 'relu'])))
        _hparam('urm_discriminator_update_freq', 1, lambda r: int(r.choice([1])))

        if dataset in IMAGE_DATASETS + TABULAR_DATASET:
            _hparam('urm_discriminator_lr', 1e-3, lambda r: 10**r.uniform(-5, -3))
        else:
            _hparam('urm_discriminator_lr', 1e-5, lambda r: 10**r.uniform(-6, -5))

        if dataset in TEXT_DATASETS:
            _hparam('urm_discriminator_optimizer', 'adamw', lambda r: str(r.choice(['adamw'])))
        else:
            _hparam('urm_discriminator_optimizer', 'sgd', lambda r: str(r.choice(['sgd'])))
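        # Note (illustrative): during random search each `random_val_fn` above is
        # called with a seeded numpy RandomState `r`, so a draw looks roughly like
        #   r = np.random.RandomState(seed)
        #   urm_lambda = float(r.uniform(0, 0.2))            # uniform on [0, 0.2)
        #   urm_discriminator_lr = 10 ** r.uniform(-5, -3)   # log-uniform on [1e-5, 1e-3]
        # (the exact seeding scheme is defined elsewhere in this registry and is
        # assumed here)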

    # Dataset-and-algorithm-specific hparam definitions
    # Each block of code below corresponds to exactly one hparam; avoid nested conditionals

184 changes: 183 additions & 1 deletion subpopbench/learning/algorithms.py
@@ -36,7 +36,8 @@
    'BSoftmax',
    'CRT',
    'ReWeightCRT',
    'VanillaCRT',
    'URM'
]


@@ -174,6 +175,187 @@ def return_feats(self, x):
    def predict(self, x):
        return self.network(x)

class URM(ERM):
    def __init__(self, data_type, input_shape, num_classes, num_attributes, num_examples, hparams, grp_sizes=None):
        ERM.__init__(self, data_type, input_shape, num_classes, num_attributes,
                     num_examples, hparams, grp_sizes=grp_sizes)

        self._setup_adversarial_net()

    def _modify_generator_output(self):
        """
        Modifies the output activation of the encoder/featurizer so that its
        range matches the range of the uniform noise fed to the discriminator
        """
        print('--> Modifying encoder output:', self.hparams['urm_generator_output'])

        if self.hparams['urm_generator_output'] == 'tanh':
            if self.data_type == 'images' and self.hparams['image_arch'] == 'resnet_sup_in1k':
                self.featurizer.network.layer4[2].relu = nn.Tanh()
            elif self.data_type == 'text' and self.hparams['text_arch'] == 'bert-base-uncased':
                # BERT's pooler activation is already Tanh, so no change is needed
                assert isinstance(self.featurizer.model.pooler.activation, nn.Tanh)
            elif self.data_type == 'tabular':
                self.featurizer.activation = nn.Tanh()
            else:
                raise Exception('unimplemented data_type: %s' % self.data_type)

        elif self.hparams['urm_generator_output'] == 'relu':
            if self.data_type == 'images' and self.hparams['image_arch'] == 'resnet_sup_in1k':
                pass  # the ResNet features already pass through a ReLU
            elif self.data_type == 'text' and self.hparams['text_arch'] == 'bert-base-uncased':
                self.featurizer.model.pooler.activation = nn.ReLU()
            elif self.data_type == 'tabular':
                self.featurizer.activation = nn.ReLU()
            else:
                raise Exception('unimplemented data_type: %s' % self.data_type)

        else:
            raise Exception('unrecognized output activation: %s' % self.hparams['urm_generator_output'])

        # define the [a, b) range of the featurizer's output values;
        # the noise distribution below must match this range
        if self.hparams['urm_generator_output'] == 'tanh':
            self.a, self.b = -1, 1
        elif self.hparams['urm_generator_output'] in ['identity', 'relu', 'sigmoid']:
            self.a, self.b = 0, 1
        elif self.hparams['urm_generator_output'] == 'brelu':
            self.a, self.b = 0, 3
        else:
            raise Exception('unrecognized output activation: %s' % self.hparams['urm_generator_output'])

    def _setup_adversarial_net(self):
        print('--> Initializing discriminator <--')
        self.discriminator = self._init_discriminator()

        # BCEWithLogitsLoss is applied to raw logits
        self.discriminator_loss = torch.nn.BCEWithLogitsLoss(reduction="mean")

        # the featurizer is optimized by self.optimizer only;
        # the discriminator gets its own optimizer
        if self.hparams["urm_discriminator_optimizer"] == 'sgd':
            self.discriminator_optimizer = torch.optim.SGD(
                self.discriminator.parameters(), lr=self.hparams['urm_discriminator_lr'],
                weight_decay=self.hparams['weight_decay'], momentum=0.9)
        elif self.hparams["urm_discriminator_optimizer"] == 'adam':
            self.discriminator_optimizer = torch.optim.Adam(
                self.discriminator.parameters(), lr=self.hparams['urm_discriminator_lr'])
        elif self.hparams["urm_discriminator_optimizer"] == 'adamw':
            self.discriminator_optimizer = torch.optim.AdamW(
                self.discriminator.parameters(), lr=self.hparams['urm_discriminator_lr'],
                weight_decay=self.hparams['weight_decay'])
        else:
            raise Exception('%s unimplemented' % self.hparams["urm_discriminator_optimizer"])

        self._modify_generator_output()
        self.sigmoid = nn.Sigmoid()  # used to compute discriminator accuracy

    def _init_discriminator(self):
        """
        MLP with a configurable number of hidden blocks
        (see 'urm_discriminator_hidden_layers'), 100 units wide
        """
        model = nn.Sequential()

        model.add_module("dense1", nn.Linear(self.featurizer.n_outputs, 100))
        model.add_module("act1", nn.LeakyReLU())

        for i in range(self.hparams['urm_discriminator_hidden_layers']):
            model.add_module("dense%d" % (2 + i), nn.Linear(100, 100))
            model.add_module("act%d" % (2 + i), nn.LeakyReLU())

        # outputs a single logit, consumed by BCEWithLogitsLoss (numerically more stable)
        model.add_module("output", nn.Linear(100, 1))

        return model
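    # Shape sketch (illustrative): with the default single hidden block this is
    #   Linear(n_outputs, 100) -> LeakyReLU -> Linear(100, 100) -> LeakyReLU -> Linear(100, 1),
    # so self._init_discriminator()(torch.randn(4, self.featurizer.n_outputs))
    # would return a (4, 1) tensor of logits.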

    def _generate_noise(self, feats):
        """
        If U is a random variable uniformly distributed on [0, 1), then
        (b-a)*U + a is uniformly distributed on [a, b).
        """
        uniform_noise = torch.rand(feats.size(), dtype=feats.dtype,
                                   layout=feats.layout, device=feats.device)  # U ~ [0, 1)
        return ((self.b - self.a) * uniform_noise) + self.a  # n ~ [a, b)
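    # Sanity check (illustrative): with the tanh range self.a, self.b = -1, 1,
    #   n = self._generate_noise(torch.zeros(1000, 8))
    #   assert n.min() >= -1 and n.max() < 1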

    def _generate_soft_labels(self, size, device, a, b):
        # returns `size` random numbers in [a, b)
        uniform_noise = torch.rand(size, device=device)  # U ~ [0, 1)
        return ((b - a) * uniform_noise) + a

    def get_accuracy(self, y_true, y_prob):
        # y_prob holds binary probabilities
        assert y_true.ndim == 1 and y_true.size() == y_prob.size()
        y_pred = y_prob > 0.5
        return (y_true == y_pred).sum().item() / y_true.size(0)

    def _update_discriminator(self, i, x, y, a, step, feats):
        feats = feats.detach()  # don't backprop through the encoder in this step
        noise = self._generate_noise(feats)

        noise_logits = self.discriminator(noise)  # (N, 1)
        feats_logits = self.discriminator(feats)  # (N, 1)

        # hard targets: noise is "true" (1), features are "fake"/generated (0)
        true_y = torch.tensor([1] * noise.shape[0], device=noise.device, dtype=noise.dtype)
        fake_y = torch.tensor([0] * feats.shape[0], device=feats.device, dtype=feats.dtype)

        # pass raw logits to BCEWithLogitsLoss
        noise_loss = self.discriminator_loss(noise_logits.squeeze(1), true_y)
        feats_loss = self.discriminator_loss(feats_logits.squeeze(1), fake_y)

        d_loss = noise_loss + self.hparams['urm_lambda'] * feats_loss

        # update the discriminator only
        self.discriminator_optimizer.zero_grad()
        d_loss.backward()
        self.discriminator_optimizer.step()
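        # In effect the discriminator minimizes
        #   BCE(D(noise), 1) + urm_lambda * BCE(D(feats), 0),
        # while _compute_loss below pushes the encoder toward the opposite
        # target BCE(D(feats), 1), a GAN-style alternating scheme.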

    def _compute_loss(self, i, x, y, a, step):
        self.activations = {}  # reset activations

        feats = self.return_feats(x)

        classifier_output = self.classifier(feats)
        ce_loss = F.cross_entropy(classifier_output, y)  # standard ERM cross-entropy term

        # train the generator/encoder so the discriminator classifies feats as noise (label 1)
        true_y = torch.tensor(feats.shape[0] * [1], device=feats.device, dtype=feats.dtype)

        g_logits = self.discriminator(feats)
        g_loss = self.discriminator_loss(g_logits.squeeze(1), true_y)  # BCEWithLogitsLoss on raw logits

        loss = ce_loss + self.hparams['urm_lambda'] * g_loss

        return loss, feats

    def predict(self, x):
        # for inference, used in eval_helper.py
        return self.network(x)

    def update(self, minibatch, step):
        all_i, all_x, all_y, all_a = minibatch

        loss, feats = self._compute_loss(all_i, all_x, all_y, all_a, step)

        self.optimizer.zero_grad()
        loss.backward()
        if self.clip_grad:
            torch.nn.utils.clip_grad_norm_(self.network.parameters(), 1.0)
        self.optimizer.step()

        if self.lr_scheduler is not None:
            self.lr_scheduler.step()

        if self.data_type == "text":
            self.network.zero_grad()

        # update the discriminator after the encoder-classifier step (alternating updates)
        if step % self.hparams['urm_discriminator_update_freq'] == 0:
            self._update_discriminator(all_i, all_x, all_y, all_a, step, feats)

        return {'loss': loss.item()}
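Taken together, `update` alternates an encoder-classifier step (cross-entropy plus the `urm_lambda`-weighted fooling term) with a discriminator step that separates features from uniform noise. Below is a self-contained toy sketch of that scheme, independent of the benchmark classes; every module size, learning rate, and tensor in it is made up for illustration:

import torch
import torch.nn as nn
import torch.nn.functional as F

enc = nn.Sequential(nn.Linear(8, 16), nn.Tanh())  # bounded features in (-1, 1)
clf = nn.Linear(16, 2)
disc = nn.Sequential(nn.Linear(16, 100), nn.LeakyReLU(), nn.Linear(100, 1))

opt_g = torch.optim.SGD(list(enc.parameters()) + list(clf.parameters()), lr=1e-2)
opt_d = torch.optim.SGD(disc.parameters(), lr=1e-3)
bce = nn.BCEWithLogitsLoss()
lam = 0.1

x, y = torch.randn(32, 8), torch.randint(0, 2, (32,))
for step in range(10):
    # encoder/classifier step: CE + lam * "features look like noise" term
    feats = enc(x)
    g_loss = F.cross_entropy(clf(feats), y) \
        + lam * bce(disc(feats).squeeze(1), torch.ones(32))
    opt_g.zero_grad(); g_loss.backward(); opt_g.step()

    # discriminator step: uniform noise on [-1, 1) is "true", features are "fake"
    noise = 2 * torch.rand(32, 16) - 1
    d_loss = bce(disc(noise).squeeze(1), torch.ones(32)) \
        + lam * bce(disc(feats.detach()).squeeze(1), torch.zeros(32))
    opt_d.zero_grad(); d_loss.backward(); opt_d.step()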


class GroupDRO(ERM):
    """
15 changes: 13 additions & 2 deletions subpopbench/models/networks.py
@@ -28,6 +28,8 @@ def __init__(self, n_inputs, n_outputs, hparams):
        self.output = nn.Linear(hparams['mlp_width'], n_outputs)
        self.n_outputs = n_outputs

        self.activation = nn.Identity()  # added for URM; does not affect other algorithms

    def forward(self, x):
        x = self.input(x)
        x = self.dropout(x)
@@ -37,6 +39,9 @@ def forward(self, x):
        x = self.dropout(x)
        x = F.relu(x)
        x = self.output(x)

        x = self.activation(x)  # added for URM; does not affect other algorithms

        return x
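    # Note (illustrative): URM's _modify_generator_output() later swaps this
    # Identity for a bounded activation on tabular data, e.g.
    #   featurizer.activation = nn.Tanh()   # bounds features to (-1, 1)
    # so the features match the noise range used by its discriminator.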


@@ -191,6 +196,8 @@ def __init__(self, model, hparams):
        )
        self.dropout = nn.Dropout(classifier_dropout)

        self.activation = nn.Identity()  # added for URM; does not affect other algorithms

    def forward(self, x):
        kwargs = {
            'input_ids': x[:, :, 0],
@@ -199,11 +206,15 @@ def forward(self, x):
        if x.shape[-1] == 3:
            kwargs['token_type_ids'] = x[:, :, 2]
        output = self.model(**kwargs)

        if hasattr(output, 'pooler_output'):
            output = self.dropout(output.pooler_output)
        else:
            output = self.dropout(output.last_hidden_state[:, 0, :])

        output = self.activation(output)  # added for URM; does not affect other algorithms

        return output

def replace_module_prefix(state_dict, prefix, replace_with=""):
    state_dict = {