Skip to content

Commit

Permalink
config is now passed as an argument not directly imported
Browse files Browse the repository at this point in the history
  • Loading branch information
vanangamudi committed Jul 18, 2018
1 parent 2663e8a commit f1b185f
Showing 1 changed file with 14 additions and 8 deletions.
22 changes: 14 additions & 8 deletions utilz.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def wrapper(*args, **kwargs):
from tqdm import tqdm as _tqdm

def tqdm(a):
return _tqdm(a) if CONFIG().tqdm else a
return _tqdm(a) if config.CONFIG.tqdm else a


def squeeze(lol):
Expand Down Expand Up @@ -151,7 +151,7 @@ def LongVar(array, requires_grad=False):

def Var(array, requires_grad=False):
ret = Variable(torch.Tensor(array), requires_grad=requires_grad)
if CONFIG.cuda:
if config.CONFIG.cuda:
ret = ret.cuda()

return ret
Expand All @@ -167,25 +167,31 @@ def init_hidden(batch_size, cell):
hidden = Variable(torch.zeros(layers, batch_size, cell.hidden_size))
context = Variable(torch.zeros(layers, batch_size, cell.hidden_size))

if CONFIG.cuda:
if config.CONFIG.cuda:
hidden = hidden.cuda()
context = context.cuda()
return hidden, context

if isinstance(cell, (nn.GRU, nn.GRUCell)):
hidden = Variable(torch.zeros(layers, batch_size, cell.hidden_size))
if CONFIG.cuda:
if config.CONFIG.cuda:
hidden = hidden.cuda()
return hidden

class Averager(list):
def __init__(self, filename=None, *args, **kwargs):
def __init__(self, filename=None, ylim=None, *args, **kwargs):
super(Averager, self).__init__(*args, **kwargs)
self.filename = filename
self.ylim = ylim
if filename:
open(filename, 'w').close()
try:
f = '{}.pkl'.format(filename)
if os.path.isfile(f):
log.debug('loading {}'.format(f))
self.extend(pickle.load(open(f, 'rb')))
except:
open(filename, 'w').close()

self.filename = filename

@property
def avg(self):
if len(self):
Expand Down

0 comments on commit f1b185f

Please sign in to comment.