Skip to content

Commit

Permalink
[WIP] Implement very simple filtering
Browse files Browse the repository at this point in the history
  • Loading branch information
DubiousCactus committed Apr 25, 2019
1 parent 70dc412 commit 6171371
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 0 deletions.
6 changes: 6 additions & 0 deletions common_flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,9 @@

gflags.DEFINE_integer('nb_windows', 25, 'Number of regions to segmentate the'
' gate location on the image')

# Testing / Visualizing
gflags.DEFINE_integer('successive_frames', 5, 'number of successive frames to'
' use for the prediction filter (backward and forward)')
gflags.DEFINE_integer('max_outliers', 3, 'number of successive frames to'
' use for the prediction filter')
28 changes: 28 additions & 0 deletions prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,18 @@
from common_flags import FLAGS
from constants import TEST_PHASE


'''
Looks at previous and future predictions and see if the given prediction is an
outlier, in which case it is recomputed as an interpolation of its direct
neighbours.
'''
def filter_prediction(prediction, previous_predictions, next_predictions):
if all(p == previous_predictions[0] for p in previous_predictions) and
all(p == next_predictions[0] for p in next_predictions):
return previous_predictions[0]


def save_visual_output(input_img, prediction, index):
if FLAGS.img_mode == "rgb":
img_mode = "RGB"
Expand Down Expand Up @@ -101,14 +113,30 @@ def _main():
nb_batches = int(np.ceil(n_samples / FLAGS.batch_size))
localization_accuracy = 0

previous_predictions = []
next_predictions = []
n = 0
step = 10
for i in range(0, nb_batches, step):
inputs, predictions = utils.compute_predictions(
model, test_generator, step, verbose = 1)

for j in range(len(inputs)):
if n > (2*FLAGS.successive_frames + FLAGS.max_outliers):
predictions[j] = filter_prediction(predictions[j],
previous_predictions,
next_predictions)
if len(previous_predictions) > FLAGS.successive_frames:
del previous_predictions[0]
if len(next_predictions) > FLAGS.successive_frames:
del next_predictions[0]

save_visual_output(inputs[j], predictions[j], n)

if j > 1:
previous_predictions.append(prediction[j-1])
if j < len(inputs):
next_predictions.append(prediction[j+1])
n += 1

print("[*] Generating {} prediction images...".format(n))
Expand Down

0 comments on commit 6171371

Please sign in to comment.