Skip to content

Commit

Permalink
Merge pull request #275 from ska-sa/deadlock-tf2
Browse files Browse the repository at this point in the history
Fixes deadlocking queue inside montblanc
  • Loading branch information
bennahugo authored Jan 12, 2022
2 parents fefb941 + 0a957ae commit 662e180
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 22 deletions.
61 changes: 40 additions & 21 deletions montblanc/impl/rime/tensorflow/RimeSolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import collections
import copy
import itertools
from os import lockf
import threading
import sys
import types
Expand Down Expand Up @@ -226,33 +227,46 @@ def pop(self, key, default=None):
#======================
# Thread pool executors
#======================

tpe = cf.ThreadPoolExecutor

self._descriptor_executor = tpe(1)
self._feed_executors = [tpe(1) for i in range(shards)]
self._compute_executors = [tpe(1) for i in range(shards)]
self._consumer_executor = tpe(1)

class InputsWaiting(object):
    """
    Keep track of the number of inputs waiting
    to be consumed on each shard.

    Every shard owns a separate lock guarding its counter. By default
    each accessor acquires the shard's lock itself. A caller that
    already holds the lock returned by ``getlock(shard)`` must pass
    ``skiplock=True``: the locks are plain ``threading.Lock`` objects
    and are not reentrant, so acquiring one twice from the same
    thread blocks forever.
    """
    def __init__(self, shards):
        """
        Create zeroed counters and one lock for each of ``shards`` shards.
        """
        # Build a distinct lock per shard. ``[threading.Lock()] * shards``
        # repeats a reference to ONE lock object, which would make
        # holding any shard's lock block access to every other shard.
        self._feederqueuelock = [threading.Lock() for _ in range(shards)]
        self._inputs_waiting = np.zeros(shape=(shards,), dtype=np.int32)

    def getlock(self, shard):
        """
        Return the lock guarding the counter of ``shard``, so that a
        caller can hold it across several ``skiplock=True`` calls.
        """
        return self._feederqueuelock[shard]

    def get(self, shard, skiplock=False):
        """
        Return the number of inputs waiting on ``shard``.

        Pass ``skiplock=True`` only while already holding
        ``getlock(shard)``.
        """
        if skiplock:
            return self._inputs_waiting[shard]

        with self._feederqueuelock[shard]:
            return self._inputs_waiting[shard]

    def increment(self, shard, skiplock=False):
        """
        Add one to the number of inputs waiting on ``shard``.

        Pass ``skiplock=True`` only while already holding
        ``getlock(shard)``.
        """
        self._adjust(shard, 1, skiplock)

    def decrement(self, shard, skiplock=False):
        """
        Subtract one from the number of inputs waiting on ``shard``.

        Pass ``skiplock=True`` only while already holding
        ``getlock(shard)``.
        """
        self._adjust(shard, -1, skiplock)

    def _adjust(self, shard, delta, skiplock):
        """
        Add ``delta`` to the counter of ``shard``, acquiring the
        shard's lock unless ``skiplock`` is true.
        """
        if skiplock:
            self._inputs_waiting[shard] += delta
            return

        with self._feederqueuelock[shard]:
            self._inputs_waiting[shard] += delta

self._inputs_waiting = InputsWaiting(shards)
Expand Down Expand Up @@ -404,10 +418,8 @@ def _feed_impl(self, cube, data_sources, data_sinks, global_iter_args):

# Find indices of the emptiest staging_areas and, by implication
# the shard with the least work assigned to it
emptiest_staging_areas = np.argsort(self._inputs_waiting.get())
shard = emptiest_staging_areas[0]
shard = next(which_shard)

shard = next(which_shard) # round robin cycle on shards
# submit a job - once loaded the shard will be incremented for compute to start
feed_f = self._feed_executors[shard].submit(self._feed_actual,
data_sources.copy(), cube.copy(),
descriptor, shard,
Expand All @@ -420,12 +432,8 @@ def _feed_impl(self, cube, data_sources, data_sinks, global_iter_args):
consume_f = self._consumer_executor.submit(self._consume,
data_sinks.copy(), cube.copy(), global_iter_args)

self._inputs_waiting.increment(shard)

yield (feed_f, compute_f, consume_f)

chunks_fed += 1

montblanc.log.info("Done feeding {n} chunks.".format(n=chunks_fed))

def _feed_actual(self, *args):
Expand Down Expand Up @@ -515,13 +523,24 @@ def _feed_actual_impl(self, data_sources, cube,
for (a, ph, ds, ad) in gen }

self._tfrun(staging_area.put_op, feed_dict=feed_dict)

# input finally loaded now increment the number of inputs waiting on the shard
self._inputs_waiting.increment(shard)

def _compute(self, feed_dict, shard):
""" Call the tensorflow compute """

try:
descriptor, enq = self._tfrun(self._tf_expr[shard], feed_dict=feed_dict)
self._inputs_waiting.decrement(shard)
with self._inputs_waiting.getlock(shard):
# only execute once data is indicated to be loaded for
# this shard. Otherwise wait for the next iteration of the
# main event loop to request another compute on this shard
if self._inputs_waiting.get(shard, skiplock=True) > 0:
descriptor, enq = self._tfrun(self._tf_expr[shard], feed_dict=feed_dict)
self._inputs_waiting.decrement(shard, skiplock=True)
else: #postpone work till data becomes available
self._compute_executors[shard].submit(self._compute,
feed_dict, shard)

except Exception as e:
montblanc.log.exception("Compute Exception")
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -753,7 +753,7 @@ def readme():
log.info('install_requires={}'.format(install_requires))

setup(name='montblanc',
version="0.7.1",
version="0.7.2",
description='GPU-accelerated RIME implementations.',
long_description=readme(),
url='http://github.com/ska-sa/montblanc',
Expand Down

0 comments on commit 662e180

Please sign in to comment.