Skip to content

Commit

Permalink
lot of big changes.
Browse files Browse the repository at this point in the history
  • Loading branch information
vanangamudi committed Jul 3, 2019
1 parent 9fc1b88 commit ba74da5
Show file tree
Hide file tree
Showing 10 changed files with 173 additions and 995 deletions.
Empty file removed __init__.py
Empty file.
23 changes: 17 additions & 6 deletions datafeed.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from config import CONFIG
from pprint import pprint, pformat

import logging
Expand All @@ -9,8 +8,8 @@

import random

from .utilz import tqdm
from .debug import memory_consumed
from anikattu.utilz import tqdm
from anikattu.debug import memory_consumed

from collections import Counter

Expand Down Expand Up @@ -98,14 +97,19 @@ def next_batch(self, batch_size=None, apply_batchop=True, **kwargs):
.format(self.name, self._exhausted_count))

return self.batch(batch_size=batch_size, apply_batchop=apply_batchop)
except KeyboardInterrupt:
raise KeyboardInterrupt
except SystemExit:
exit(1)
except:
log.exception('batch failed')
return self.next_batch(apply_batchop=apply_batchop)

def nth_batch(self, n, apply_batchop=True):
b = self.data[ n * self.batch_size : (n+1) * self.batch_size ]
def nth_batch(self, n, batch_size=None, apply_batchop=True):
if not batch_size:
batch_size = self.batch_size

b = self.data[ n * batch_size : (n+1) * batch_size ]
if apply_batchop:
return self._batchop(b)

Expand Down Expand Up @@ -238,6 +242,8 @@ def next_batch(self, batch_size=None, apply_batchop=True, sampling_distribution=
.format(self.name, self._exhausted_count))

return self.batch(batch_size=batch_size, apply_batchop=apply_batchop, sampling_distribution=sampling_distribution)
except KeyboardInterrupt:
raise KeyboardInterrupt
except SystemExit:
exit(1)
except:
Expand All @@ -248,7 +254,12 @@ def next_batch(self, batch_size=None, apply_batchop=True, sampling_distribution=
def nth_batch(self, n, apply_batchop=True):
b = []
for fname, feed in self.datafeeds.items():
b.append(random.choice(feed.nth_batch(min(n, random.choice(range(feed.num_batch))), apply_batchop=False)))
b.append(
random.choice(
feed.nth_batch(
min(n, random.choice(range(feed.num_batch))),
apply_batchop=False)))

if len(b) == self.batch_size: break

if apply_batchop:
Expand Down
30 changes: 30 additions & 0 deletions debug.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,36 @@
import os
import psutil
import linecache
import tracemalloc
process = psutil.Process(os.getpid())

def memory_consumed():
return process.memory_info().rss

def display_tracemalloc_top(snapshot, key_type='lineno', limit=3):
"""
from https://stackoverflow.com/a/45679009/1685729
"""
snapshot = snapshot.filter_traces((
tracemalloc.Filter(False, "<frozen importlib._bootstrap>"),
tracemalloc.Filter(False, "<unknown>"),
))
top_stats = snapshot.statistics(key_type)

print("Top %s lines" % limit)
for index, stat in enumerate(top_stats[:limit], 1):
frame = stat.traceback[0]
# replace "/path/to/module/file.py" with "module/file.py"
filename = os.sep.join(frame.filename.split(os.sep)[-2:])
print("#%s: %s:%s: %.1f KiB"
% (index, filename, frame.lineno, stat.size / 1024))
line = linecache.getline(frame.filename, frame.lineno).strip()
if line:
print(' %s' % line)

other = top_stats[limit:]
if other:
size = sum(stat.size for stat in other)
print("%s other: %.1f KiB" % (len(other), size / 1024))
total = sum(stat.size for stat in top_stats)
print("Total allocated size: %.1f KiB" % (total / 1024))
6 changes: 3 additions & 3 deletions nlp_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,9 @@ def experiment(config, ROOT_DIR, model, VOCAB, LABELS, datapoints=[[], [], []],

for e in range(eons):

if not trainer.train():
raise Exception

predictor.model.load_state_dict(trainer.best_model[1])

dump = open('{}/results/eon_{}.csv'.format(ROOT_DIR, e), 'w')
Expand All @@ -273,9 +276,6 @@ def experiment(config, ROOT_DIR, model, VOCAB, LABELS, datapoints=[[], [], []],



if not trainer.train():
raise Exception


except KeyboardInterrupt:
return locals()
Expand Down
1 change: 0 additions & 1 deletion trainer/__init__.py

This file was deleted.

277 changes: 0 additions & 277 deletions trainer/lm.py

This file was deleted.

Loading

0 comments on commit ba74da5

Please sign in to comment.