From 10725393518df14b9b6976686f72fae792c3f393 Mon Sep 17 00:00:00 2001 From: Jeff Donahue Date: Mon, 5 Oct 2015 15:46:54 -0700 Subject: [PATCH] NetSpec: type-check Function inputs (they must be Top instances) --- python/caffe/net_spec.py | 4 ++++ python/caffe/test/test_net_spec.py | 8 ++++++++ 2 files changed, 12 insertions(+) diff --git a/python/caffe/net_spec.py b/python/caffe/net_spec.py index 93fc01927db..b6520627a4b 100644 --- a/python/caffe/net_spec.py +++ b/python/caffe/net_spec.py @@ -103,6 +103,10 @@ class Function(object): def __init__(self, type_name, inputs, params): self.type_name = type_name + for index, input in enumerate(inputs): + if not isinstance(input, Top): + raise TypeError('%s input %d is not a Top (type is %s)' % + (type_name, index, type(input))) self.inputs = inputs self.params = params self.ntop = self.params.get('ntop', 1) diff --git a/python/caffe/test/test_net_spec.py b/python/caffe/test/test_net_spec.py index fee3c0aaebe..ffe71bacb08 100644 --- a/python/caffe/test/test_net_spec.py +++ b/python/caffe/test/test_net_spec.py @@ -79,3 +79,11 @@ def test_zero_tops(self): net_proto = silent_net() net = self.load_net(net_proto) self.assertEqual(len(net.forward()), 0) + + def test_type_error(self): + """Test that a TypeError is raised when a Function input isn't a Top.""" + data = L.DummyData(ntop=2) # data is a 2-tuple of Tops + r = r"^Silence input 0 is not a Top \(type is <(type|class) 'tuple'>\)$" + with self.assertRaisesRegexp(TypeError, r): + L.Silence(data, ntop=0) # should raise: data is a tuple, not a Top + L.Silence(*data, ntop=0) # shouldn't raise: each elt of data is a Top