Commit

- Initial commits of PID based abstention
- some code fixes and refactoring
thulas committed Nov 14, 2019
1 parent 542f3cf commit 832e51b
Showing 3 changed files with 178 additions and 38 deletions.
160 changes: 141 additions & 19 deletions dac_loss.py
@@ -10,26 +10,45 @@
from torch.nn.modules.loss import _Loss
import pdb
import math

from simple_pid import PID

#for numerical stability
epsilon = 1e-7

#this might be changed from inside dac_sandbox.py
total_epochs = 200
alpha_final = 1.0
alpha_init_factor = 64.
#TODO: remove below; these might be changed from inside the main script
# total_epochs = 200
# alpha_final = 1.0
# alpha_init_factor = 64.




def get_abst_rate(p_out):
"""
Return the abstention rate for a batch of outputs.
p_out: a batch of probabilities over classes (or pre-softmax scores)
returns: the rate of abstention; the abstention class is assumed to be the final class.
"""
#pdb.set_trace()
abst_class_id = p_out.shape[1] - 1
predictions = torch.argmax(p_out,dim=1)
num_abstains = torch.sum(predictions.eq(abst_class_id))
return torch.sum(num_abstains.float())/p_out.shape[0]


#loss calculation and alpha-auto-tune are rolled into one function. This is invoked
#after every iteration
class dac_loss(_Loss):
def __init__(self, model, learn_epochs, use_cuda=False, cuda_device=None):
def __init__(self, model, learn_epochs, total_epochs, use_cuda=False, cuda_device=None,
abst_rate=None, final_abst_rate=None,alpha_final=1.0,alpha_init_factor=64.):
print("using dac loss function\n")
super(dac_loss, self).__init__()
self.model = model
#self.alpha = alpha
self.learn_epochs = learn_epochs
self.total_epochs = total_epochs
self.alpha_final = alpha_final
self.alpha_init_factor = alpha_init_factor
self.use_cuda = use_cuda
self.cuda_device = cuda_device
#self.kappa = kappa #not used
@@ -42,13 +61,51 @@ def __init__(self, model, learn_epochs, use_cuda=False, cuda_device=None):
self.alpha_var = None

self.alpha_thresh_ewma = None #exponentially weighted moving average for alpha_thresh
self.alpha_thresh = None #instantaneous alpha_thresh
self.ewma_mu = 0.05 #mu parameter for EWMA;
self.curr_alpha_factor = None #for alpha initialization
self.alpha_inc = None #linear increase factor of alpha during abstention phase
self.alpha_set_epoch = None

self.abst_rate = None #instantaneous abstention rate
self.abst_rate_ewma = None # exponentially weighted moving average of the abstention rate

self.pid = None
self.final_abst_rate = None
#PID controller for a pre-specified abstention rate
if abst_rate is not None:
#pdb.set_trace()
if abst_rate < 0.:
raise ValueError("Invalid abstention rate of %f. Must be non-negative" %(new_abst_rate))
#self.pid = PID(1.,0.5, 0., sample_time=None,setpoint=abst_rate)
#self.pid.output_limits = (-0.1,0.1)
self.pid = PID(-1.,-0.5, 0., sample_time=None,setpoint=abst_rate)
self.pid.output_limits = (0.,None)

if final_abst_rate is not None:
#if total_epochs is None:
# raise ValueError("total epochs must be specified if final abstention rate is specied")
#else:
self.abst_delta = (abst_rate - final_abst_rate)/(self.total_epochs-self.learn_epochs)
self.final_abst_rate = final_abst_rate

def update_abst_rate(self, abst_delta):
#pdb.set_trace()
if self.pid is not None:
new_abst_rate = max(0.,self.pid.setpoint - abst_delta)
# if new_abst_rate < 0.:
# raise ValueError("Invalid abstention rate of %f. Must be non-negaitve" (%new_abst_rate))
# else:
self.pid.setpoint = new_abst_rate
print("DAC updated abstention rate to %f" %(new_abst_rate))
else:
print("Warning: Cannot update abstention rate as PID has not been initialized")


def __call__(self, input_batch, target_batch, epoch):
global total_epochs, alpha_final
#pdb.set_trace()
#TODO: remove global
#global total_epochs, alpha_final
#pdb.set_trace()
if epoch < self.learn_epochs or not self.model.training:
loss = F.cross_entropy(input_batch, target_batch, reduce=False)
@@ -57,24 +114,40 @@ def __call__(self, input_batch, target_batch, epoch):
h_c = F.cross_entropy(input_batch[:,0:-1],target_batch,reduce=False)


p_out = F.softmax(F.log_softmax(input_batch,dim=1),dim=1)
#p_out = F.softmax(F.log_softmax(input_batch,dim=1),dim=1)
p_out = torch.exp(F.log_softmax(input_batch,dim=1))
p_out_abstain = p_out[:,-1]
#pdb.set_trace()

#abstention rate update
self.abst_rate = get_abst_rate(p_out)
if self.abst_rate_ewma is None:
self.abst_rate_ewma = self.abst_rate
else:
self.abst_rate_ewma = self.ewma_mu*self.abst_rate + (1-self.ewma_mu)*self.abst_rate_ewma

#update instantaneous alpha_thresh
self.alpha_thresh = Variable(((1. - p_out_abstain)*h_c).mean().data)
#update alpha_thresh_ewma
if self.alpha_thresh_ewma is None:
self.alpha_thresh_ewma = Variable(((1. - p_out_abstain)*h_c).mean().data)
self.alpha_thresh_ewma = self.alpha_thresh #Variable(((1. - p_out_abstain)*h_c).mean().data)
else:
self.alpha_thresh_ewma = Variable(self.ewma_mu*((1. - p_out_abstain)*h_c).mean().data + \
# self.alpha_thresh_ewma = Variable(self.ewma_mu*((1. - p_out_abstain)*h_c).mean().data + \
# (1. - self.ewma_mu)*self.alpha_thresh_ewma.data)
self.alpha_thresh_ewma = Variable(self.ewma_mu*self.alpha_thresh.data + \
(1. - self.ewma_mu)*self.alpha_thresh_ewma.data)

# print("\nloss details (pre abstention): %d,%f,%f,%f,%f\n" %(epoch,p_out_abstain.mean(),loss.mean(),h_c.mean(),
# self.alpha_thresh_ewma))

return loss.mean()

else:
#pdb.set_trace()
#calculate cross entropy only over true classes
h_c = F.cross_entropy(input_batch[:,0:-1],target_batch,reduce=False)
p_out = F.softmax(F.log_softmax(input_batch,dim=1),dim=1)
#p_out = F.softmax(F.log_softmax(input_batch,dim=1),dim=1)
p_out = torch.exp(F.log_softmax(input_batch,dim=1))
#probabilities of abstention class
p_out_abstain = p_out[:,-1]

@@ -92,12 +165,30 @@ def __call__(self, input_batch, target_batch, epoch):
#loss = (1. - p_out_abstain)*h_c +
# torch.log(1.+self.alpha_var)*p_out_abstain

try:
#abstention rate update
self.abst_rate = get_abst_rate(p_out)
if self.abst_rate_ewma is None:
self.abst_rate_ewma = self.abst_rate
else:
self.abst_rate_ewma = self.ewma_mu*self.abst_rate + (1-self.ewma_mu)*self.abst_rate_ewma




#update instantaneous alpha_thresh
self.alpha_thresh = Variable(((1. - p_out_abstain)*h_c).mean().data)

#if (epoch == 5):
# pdb.set_trace()

try:
#update alpha_thresh_ewma
if self.alpha_thresh_ewma is None:
self.alpha_thresh_ewma = Variable(((1. - p_out_abstain)*h_c).mean().data)
self.alpha_thresh_ewma = self.alpha_thresh #Variable(((1. - p_out_abstain)*h_c).mean().data)
else:
self.alpha_thresh_ewma = Variable(self.ewma_mu*((1. - p_out_abstain)*h_c).mean().data + \
# self.alpha_thresh_ewma = Variable(self.ewma_mu*((1. - p_out_abstain)*h_c).mean().data + \
# (1. - self.ewma_mu)*self.alpha_thresh_ewma.data)
self.alpha_thresh_ewma = Variable(self.ewma_mu*self.alpha_thresh.data + \
(1. - self.ewma_mu)*self.alpha_thresh_ewma.data)


@@ -109,15 +200,46 @@


#pdb.set_trace()
self.alpha_var = Variable(self.alpha_thresh_ewma.data /alpha_init_factor)
self.alpha_inc = (alpha_final - self.alpha_var.data)/(total_epochs - epoch)
#aggressive initialization of alpha to jump start abstention
self.alpha_var = Variable(self.alpha_thresh_ewma.data /self.alpha_init_factor)
self.alpha_inc = (self.alpha_final - self.alpha_var.data)/(self.total_epochs - epoch)
self.alpha_set_epoch = epoch

else:
# we only update alpha every epoch
#pass
#self.alpha_var = Variable(self.alpha_thresh_ewma.data)

if self.pid is not None:
#delta = self.pid(self.abst_rate_ewma)
control = self.pid(self.abst_rate_ewma)
#print("control %f abst_rate %f abst_rate_ewma %f" %(control, self.abst_rate, self.abst_rate_ewma) )
#pdb.set_trace()

#self.alpha_var = Variable(self.alpha_thresh.data - .05)
#self.alpha_var = Variable(torch.max(self.alpha_thresh_ewma.data - delta,torch.tensor(0.001).cuda()))
try:
self.alpha_var = Variable(torch.tensor(control).clone().detach())
except TypeError:
pdb.set_trace()

else:
control = 0.

if epoch > self.alpha_set_epoch:
self.alpha_var = Variable(self.alpha_var.data + self.alpha_inc)
if self.pid is None:
self.alpha_var = Variable(self.alpha_var.data + self.alpha_inc)
# #self.alpha_var = Variable(self.alpha_var.data/2.)
# #self.alpha_var = Variable(self.alpha_thresh_ewma.data/0.8)
self.alpha_set_epoch = epoch
if self.final_abst_rate is not None:
self.update_abst_rate(self.abst_delta)
# print("delta %f, abst_rate %f abst_rate_ewma %f alpha_thresh %f alpha_thresh_ewma %f alpha_var %f"
# %(delta, self.abst_rate, self.abst_rate_ewma,
# self.alpha_thresh.data, self.alpha_thresh_ewma.data,self.alpha_var))
print("\ncontrol %f, abst_rate %f abst_rate_ewma %f alpha_thresh %f alpha_thresh_ewma %f alpha_var %f"
%(control, self.abst_rate, self.abst_rate_ewma,
self.alpha_thresh.data, self.alpha_thresh_ewma.data,self.alpha_var))

loss = (1. - p_out_abstain)*h_c - \
self.alpha_var*torch.log(1. - p_out_abstain)
@@ -133,9 +255,9 @@ def __call__(self, input_batch, target_batch, epoch):

return loss.mean()

except RuntimeError as e:
except RuntimeError as e:
#pdb.set_trace()
print(e)
print(e)



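The key mechanism added above is that, when an abstention rate is specified, alpha is no longer only ramped on a fixed schedule; each iteration it is taken from the output of a simple_pid controller fed with the EWMA of the observed abstention rate. Below is a minimal sketch of that feedback loop in isolation, using the same controller configuration as the code above (the observed rates are made-up numbers, not produced by this commit):

from simple_pid import PID

# Same configuration as in dac_loss.__init__: the negative Kp/Ki flip the sign of the
# error (setpoint - observed), so abstaining more than the target pushes the control
# output (used as alpha) up, which penalizes abstention harder; the lower output
# limit keeps alpha non-negative.
target_rate = 0.10
pid = PID(-1., -0.5, 0., sample_time=None, setpoint=target_rate)
pid.output_limits = (0., None)

# Hypothetical EWMA abstention rates over a few iterations.
for abst_rate_ewma in [0.30, 0.22, 0.15, 0.10, 0.06]:
    alpha = pid(abst_rate_ewma)
    print("abst_rate_ewma=%.2f -> alpha=%.3f" % (abst_rate_ewma, alpha))

When final_abst_rate is also given, update_abst_rate() lowers pid.setpoint by abst_delta once per epoch, so the target rate itself is annealed from abst_rate toward final_abst_rate over the post-warm-up epochs.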
27 changes: 18 additions & 9 deletions train_dac.py
@@ -83,6 +83,7 @@

parser.add_argument('--label_noise_info', default=None, type=str, help='pickle file containing indices and labels to use for simulating label noise')

parser.add_argument('--abst_rate', default=None, type=float, help='Pre-specified abstention rate; the abstention hyperparameter will be tuned dynamically to stabilize abstention at this rate')

#for wide residual networks
parser.add_argument('--widen_factor', default=10, type=int, help='width of model')
@@ -117,9 +118,10 @@
args.save_epoch_model = int(args.save_epoch_model*args.epdl)


dac_loss.total_epochs = args.epochs
dac_loss.alpha_final = args.alpha_final
dac_loss.alpha_init_factor = args.alpha_init_factor
#TODO: these have to be supplied to the dac_loss function; get rid of globals.
# dac_loss.total_epochs = args.epochs
# dac_loss.alpha_final = args.alpha_final
# dac_loss.alpha_init_factor = args.alpha_init_factor

if not args.log_file is None:
sys.stdout = open(args.log_file,'w')
@@ -138,7 +140,7 @@
#abstain class id is the last class
abstain_class_id = num_classes
#simulate label noise if needed
trainset = label_noise.label_noise(args, trainset)
trainset = label_noise.label_noise(args, trainset, num_classes)
#set data loaders
trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=args.test_batch_size, shuffle=False, num_workers=2)
@@ -335,7 +337,8 @@ def getNetwork(args):
print('Using regular (non-abstaining) loss function during training')
else:
criterion = dac_loss.loss_fn_dict[args.loss_fn](model=net, learn_epochs=args.learn_epochs,
use_cuda=True, cuda_device=cuda_device).cuda(cuda_device)
total_epochs=args.epochs, use_cuda=True, cuda_device=cuda_device, abst_rate=args.abst_rate,
alpha_final=args.alpha_final,alpha_init_factor=args.alpha_init_factor).cuda(cuda_device)

else:
#criterion = nn.CrossEntropyLoss()
@@ -345,7 +348,9 @@ def getNetwork(args):
print('Using regular (non-abstaining) loss function during training')

else:
criterion = dac_loss.loss_fn_dict[args.loss_fn](model=net, learn_epochs=args.learn_epochs)
criterion = dac_loss.loss_fn_dict[args.loss_fn](model=net,
learn_epochs=args.learn_epochs, total_epochs=args.epochs, abst_rate=args.abst_rate,
alpha_final=args.alpha_final,alpha_init_factor=args.alpha_init_factor)


def get_hms(seconds):
Expand Down Expand Up @@ -377,7 +382,7 @@ def train(epoch):
momentum=0.9, weight_decay=5e-4,nesterov=args.nesterov)
print('\n=> Training Epoch #%d, LR=%.4f' %(epoch, cf.learning_rate(args.lr, int(epoch/args.epdl))))

#print('\n=> Training Epoch #%d, LR=%.4f' %(epoch, cf.learning_rate(args.lr, epoch)))
#print('\n=> Training Epoch #%d, LR=%.4f' %(epoch, cf.learning_rate(args.lr, epoch)))

for batch_idx, (inputs, targets) in enumerate(trainloader):
#print(type(inputs))
@@ -421,7 +426,7 @@


def save_train_scores(epoch):
net.eval()
#net.eval()

train_softmax_scores = []

@@ -435,7 +440,11 @@ def save_train_scores(epoch):

train_scores = torch.cat(train_softmax_scores).cpu().numpy()
print('Saving train softmax scores at Epoch %d' %(epoch))
np.save(args.log_file+".train_scores.epoch_"+str(epoch), train_scores)
if args.log_file is None:
fn = 'test'
else:
fn = args.log_file
np.save(fn+".train_scores.epoch_"+str(epoch), train_scores)



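On the train_dac.py side the change is mostly plumbing: the new --abst_rate flag and the existing alpha/epoch settings are forwarded to the loss constructor instead of being patched in as module-level globals. A rough, self-contained sketch of that wiring follows (the toy model, batch, and hyperparameter values are hypothetical; only the dac_loss keyword names come from this commit, and the class is constructed directly rather than looked up through loss_fn_dict):

import torch
import torch.nn as nn
from dac_loss import dac_loss

num_classes = 10
net = nn.Linear(32, num_classes + 1)     # toy model; the extra last output is the abstention class

# Mirrors the keyword arguments train_dac.py now forwards from argparse.
criterion = dac_loss(model=net,
                     learn_epochs=20,    # plain cross-entropy warm-up
                     total_epochs=200,
                     abst_rate=0.1,      # PID setpoint (the new --abst_rate flag)
                     alpha_final=1.0,
                     alpha_init_factor=64.)

optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
for epoch in range(200):
    inputs = torch.randn(64, 32)
    targets = torch.randint(0, num_classes, (64,))
    optimizer.zero_grad()
    loss = criterion(net(inputs), targets, epoch)   # the epoch index drives warm-up and alpha updates
    loss.backward()
    optimizer.step()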
29 changes: 19 additions & 10 deletions utils/label_noise.py
@@ -1,19 +1,21 @@
import numpy as np
import pdb
import sys
import pickle as cp



def label_noise(args, trainset):
def label_noise(args, trainset, num_classes):

if args.rand_labels is not None:

#pdb.set_trace()
rf = args.rand_labels
if rf < 0. or rf > 1.:
print("rand_labels fraction should be between 0 and 1")
sys.exit(0)
#pdb.set_trace()
if args.dataset == 'stl10-labeled' or args.dataset=='tin200':
train_labels = np.asarray(trainset.labels)
else:
train_labels = np.asarray(trainset.train_labels)
#train_labels = np.asarray(trainset.train_labels)
train_labels = np.asarray(trainset.targets)

print("randomizing %f percent of labels " %(rf*100))
n_train = len(train_labels)
@@ -34,15 +36,20 @@ def label_noise(args, trainset):

if args.del_noisy_data:
print("deleting noisy data")
trainset.train_data = np.delete(trainset.train_data,wrong_indices,axis=0)
#trainset.train_data = np.delete(trainset.train_data,wrong_indices,axis=0)
trainset.data = np.delete(trainset.data,wrong_indices,axis=0)
train_labels = np.delete(train_labels,wrong_indices)

#pdb.set_trace()
if args.dataset == 'stl10-labeled' or args.dataset=='tin200':
trainset.labels = train_labels.tolist()
#elif args.dataset == 'mnist':
# trainset.targets = train_labels.tolist()
else:
trainset.train_labels = train_labels.tolist()
#trainset.train_labels = train_labels.tolist()
trainset.targets = train_labels.tolist()

print("training on %d data samples" %(len(trainset.train_data)))
print("training on %d data samples" %(len(trainset.data)))

#save randomized indices if validation or train scores are also being saved
if args.save_val_scores or args.save_train_scores:
@@ -74,7 +81,8 @@ def label_noise(args, trainset):
train_labels = np.asarray(trainset.train_labels)
train_labels_good = np.copy(train_labels)
train_labels[noise_indices] = noise_labels
trainset.train_labels = train_labels.tolist()
#trainset.train_labels = train_labels.tolist()
trainset.targets = train_labels.tolist()

if args.exclude_train_indices is not None:
if args.rand_labels is not None:
@@ -94,7 +102,8 @@ def label_noise(args, trainset):
else: # cifar-10/100 and others
train_labels = np.asarray(trainset.train_labels)
train_labels = np.delete(train_labels,exclude_indices)
trainset.train_labels = train_labels.tolist()
#trainset.train_labels = train_labels.tolist()
trainset.targets = train_labels.tolist()

if args.label_noise_info is not None:
assert(train_labels_good is not None)
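The utils/label_noise.py edits mainly swap the old trainset.train_data / trainset.train_labels attributes for the trainset.data / trainset.targets names used by newer torchvision datasets, and pass num_classes in explicitly. As a simplified, self-contained sketch of the label-randomization step itself (the helper below is hypothetical: it skips the wrong_indices bookkeeping and index saving done above, and a randomized label may occasionally coincide with the original one):

import numpy as np

def randomize_labels(targets, rf, num_classes, seed=0):
    # Overwrite a fraction rf of the labels with uniformly random class ids.
    rng = np.random.RandomState(seed)
    targets = np.asarray(targets).copy()
    idx = rng.choice(len(targets), int(rf * len(targets)), replace=False)
    targets[idx] = rng.randint(0, num_classes, size=len(idx))
    return targets.tolist()

# e.g. for a torchvision CIFAR-10 dataset object:
# trainset.targets = randomize_labels(trainset.targets, rf=0.6, num_classes=10)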
