Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 12 additions & 10 deletions python/caffe/pycaffe.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
RMSPropSolver, AdaDeltaSolver, AdamSolver
import caffe.io

import six

# We directly update methods from Net here (rather than using composition or
# inheritance) so that nets created by caffe (e.g., by SGDSolver) will
# automatically have the improved interface.
Expand Down Expand Up @@ -97,7 +99,7 @@ def _Net_forward(self, blobs=None, start=None, end=None, **kwargs):
raise Exception('Input blob arguments do not match net inputs.')
# Set input according to defined shapes and make arrays single and
# C-contiguous as Caffe expects.
for in_, blob in kwargs.iteritems():
for in_, blob in six.iteritems(kwargs):
if blob.shape[0] != self.blobs[in_].shape[0]:
raise Exception('Input is not batch sized')
self.blobs[in_].data[...] = blob
Expand Down Expand Up @@ -145,7 +147,7 @@ def _Net_backward(self, diffs=None, start=None, end=None, **kwargs):
raise Exception('Top diff arguments do not match net outputs.')
# Set top diffs according to defined shapes and make arrays single and
# C-contiguous as Caffe expects.
for top, diff in kwargs.iteritems():
for top, diff in six.iteritems(kwargs):
if diff.shape[0] != self.blobs[top].shape[0]:
raise Exception('Diff is not batch sized')
self.blobs[top].diff[...] = diff
Expand Down Expand Up @@ -174,13 +176,13 @@ def _Net_forward_all(self, blobs=None, **kwargs):
all_outs = {out: [] for out in set(self.outputs + (blobs or []))}
for batch in self._batch(kwargs):
outs = self.forward(blobs=blobs, **batch)
for out, out_blob in outs.iteritems():
for out, out_blob in six.iteritems(outs):
all_outs[out].extend(out_blob.copy())
# Package in ndarray.
for out in all_outs:
all_outs[out] = np.asarray(all_outs[out])
# Discard padding.
pad = len(all_outs.itervalues().next()) - len(kwargs.itervalues().next())
pad = len(six.next(six.itervalues(all_outs))) - len(six.next(six.itervalues(kwargs)))
if pad:
for out in all_outs:
all_outs[out] = all_outs[out][:-pad]
Expand Down Expand Up @@ -215,16 +217,16 @@ def _Net_forward_backward_all(self, blobs=None, diffs=None, **kwargs):
for fb, bb in izip_longest(forward_batches, backward_batches, fillvalue={}):
batch_blobs = self.forward(blobs=blobs, **fb)
batch_diffs = self.backward(diffs=diffs, **bb)
for out, out_blobs in batch_blobs.iteritems():
for out, out_blobs in six.iteritems(batch_blobs):
all_outs[out].extend(out_blobs.copy())
for diff, out_diffs in batch_diffs.iteritems():
for diff, out_diffs in six.iteritems(batch_diffs):
all_diffs[diff].extend(out_diffs.copy())
# Package in ndarray.
for out, diff in zip(all_outs, all_diffs):
all_outs[out] = np.asarray(all_outs[out])
all_diffs[diff] = np.asarray(all_diffs[diff])
# Discard padding at the end and package in ndarray.
pad = len(all_outs.itervalues().next()) - len(kwargs.itervalues().next())
pad = len(six.next(six.itervalues(all_outs))) - len(six.next(six.itervalues(kwargs)))
if pad:
for out, diff in zip(all_outs, all_diffs):
all_outs[out] = all_outs[out][:-pad]
Expand Down Expand Up @@ -256,10 +258,10 @@ def _Net_batch(self, blobs):
------
batch: {blob name: list of blobs} dict for a single batch.
"""
num = len(blobs.itervalues().next())
batch_size = self.blobs.itervalues().next().shape[0]
num = len(six.next(six.itervalues(blobs)))
batch_size = six.next(six.itervalues(self.blobs)).shape[0]
remainder = num % batch_size
num_batches = num / batch_size
num_batches = num // batch_size

# Yield full batches.
for b in range(num_batches):
Expand Down