-
Notifications
You must be signed in to change notification settings - Fork 41
/
utils.py
executable file
·78 lines (68 loc) · 2.32 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
from __future__ import print_function
import os
from torch.autograd import Variable
from torch.utils.data import Dataset
import torchvision.models as models
import torch
import torch.nn as nn
import pickle
import pdb
import torch.optim as optim
from PIL import Image
import numpy as np
import random
import torch.backends.cudnn as cudnn
from time import time
from scipy.io import wavfile
def net_frozen(args, model):
    """Freeze part of ``model`` and build an optimizer over what stays trainable.

    Calls ``model.frozen_until(args.frozen_until)`` and then creates an
    optimizer that only receives parameters whose ``requires_grad`` is true.
    NOTE(review): ``frozen_until`` is defined on the model class, not in this
    file; presumably it disables gradients for layers up to the given index.
    Confirm against the model implementation.

    Args:
        args: object exposing ``frozen_until``, ``lr``, ``weight_decay`` and
            ``trainer`` (``'adam'`` or ``'sgd'``, case-insensitive).
        model: module exposing ``frozen_until(...)`` and ``parameters()``.

    Returns:
        Tuple ``(model, optimizer)``.

    Raises:
        ValueError: if ``args.trainer`` names an unsupported optimizer.
    """
    supported = {'adam': optim.Adam, 'sgd': optim.SGD}
    trainer = args.trainer.lower()
    # Validate before touching the model: previously an unknown trainer left
    # ``optimizer`` unbound and failed with UnboundLocalError at the return.
    if trainer not in supported:
        raise ValueError(
            "Unsupported trainer {!r}; expected one of {}".format(
                args.trainer, sorted(supported)))

    print('********************************************************')
    model.frozen_until(args.frozen_until)
    # Only hand the optimizer the parameters that are still trainable.
    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer = supported[trainer](trainable, lr=args.lr,
                                   weight_decay=args.weight_decay)
    print('********************************************************')
    return model, optimizer
def parallelize_model(model):
    """Move ``model`` to CUDA and wrap it in ``DataParallel`` when CUDA exists.

    Without CUDA the model is returned untouched. With CUDA, the model is
    wrapped over every device index reported by ``torch.cuda.device_count()``
    and ``cudnn.benchmark`` is switched on.
    """
    if not torch.cuda.is_available():
        return model

    gpu_ids = range(torch.cuda.device_count())
    wrapped = torch.nn.DataParallel(model.cuda(), device_ids=gpu_ids)
    cudnn.benchmark = True
    return wrapped
def unparallelize_model(model):
    """Strip every ``.module`` wrapper layer and return the innermost object.

    Follows the ``module`` attribute repeatedly, so nested wrappers
    (e.g. a DataParallel inside a DataParallel) are fully unwrapped.
    An object without a ``module`` attribute is returned as-is.
    """
    missing = object()  # sentinel: distinguishes "no attribute" from any real value
    while True:
        inner = getattr(model, 'module', missing)
        if inner is missing:
            return model
        model = inner
def second2str(second):
    """Format a duration given in seconds as ``H:MM:SS (s)``.

    Hours are not capped, and fractional seconds are truncated by ``int``.
    """
    hours = int(second / 3600.)
    remainder = second - hours * 3600.
    minutes = int(remainder / 60.)
    secs = int(remainder - minutes * 60)
    parts = (hours, minutes, secs)
    return "{:d}:{:02d}:{:02d} (s)".format(*parts)
def print_eta(t0, cur_iter, total_iter):
    """
    Print progress, the elapsed time and an estimate of the remaining time.

    t0: start time, as returned by time()
    cur_iter: zero-based index of the iteration that just finished
    total_iter: total number of iterations

    The estimate is the mean time per finished iteration multiplied by the
    number of iterations still to run.
    """
    elapsed = time() - t0
    finished = cur_iter + 1
    remaining = total_iter - cur_iter - 1
    eta = elapsed / float(finished) * remaining
    pieces = [
        'Epoch: ', str(cur_iter + 1), '/', str(total_iter),
        ', time so far: ', second2str(elapsed),
        ', estimated time left: ', second2str(eta),
    ]
    print(''.join(pieces))
def cvt_to_gpu(X):
    """Wrap ``X`` in a ``Variable``, moving it to CUDA first when available."""
    if torch.cuda.is_available():
        return Variable(X.cuda())
    return Variable(X)
def get_length_wav(fn):
    """Return the duration in seconds of the WAV file at path ``fn``.

    Duration is the length of the first axis of the array returned by
    ``wavfile.read`` divided by the sample rate it reports.
    """
    rate, samples = wavfile.read(fn)
    n_frames = samples.shape[0]
    return float(n_frames) / rate