Skip to content

Commit

Permalink
Cleanup.
Browse files Browse the repository at this point in the history
  • Loading branch information
MrBago committed Jun 19, 2017
1 parent d33a40b commit 786ee77
Showing 1 changed file with 8 additions and 11 deletions.
19 changes: 8 additions & 11 deletions python/tests/transformers/named_image_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,27 +32,22 @@

class NamedImageTransformerImagenetTest(SparkDLTestCase):

def test_inceptionV3_prediction(self):
"""
Test inceptionV3 using keras, tensorflow and sparkDL
We run the sparkDL test with and without resizing beforehand
"""

@classmethod
def setUpClass(cls):
super(NamedImageTransformerImagenetTest, cls).setUpClass()

# Compute values used by multiple tests.
imgFiles, images = getSampleImageList()
imageArray = np.empty((len(images), 299, 299, 3), 'uint8')
for i, img in enumerate(images):
assert img is not None and img.mode == "RGB"
imageArray[i] = np.array(img.resize((299, 299)))

# Basic keras flow
# We predict the class probabilities for the images in our test library using keras API.
# Predict the class probabilities for the images in our test library using keras API.
prepedImaged = inception_v3.preprocess_input(imageArray.astype('float32'))
model = inception_v3.InceptionV3()
kerasPredict = model.predict(prepedImaged)
# These values are used by multiple tests so cache them on class setup.
cls.imageArray = imageArray
cls.kerasPredict = kerasPredict

Expand All @@ -75,6 +70,10 @@ def test_buildtfgraphforname(self):
np.testing.assert_array_almost_equal(kerasPredict, tfPredict)

def test_DeepImagePredictorNoReshape(self):
"""
Run sparkDL inceptionV3 transformer on resized images and compare result to cached keras
result.
"""
imageArray = self.imageArray
kerasPredict = self.kerasPredict
def rowWithImage(img):
Expand All @@ -84,8 +83,6 @@ def rowWithImage(img):
return [[getattr(row, field.name) for field in imageIO.imageSchema]]

# test: predictor vs keras on resized images
# Run sparkDL inceptionV3 transformer on resized images and compare result to above keras
# result.
rdd = self.sc.parallelize([rowWithImage(img) for img in imageArray])
dfType = StructType([StructField("image", imageIO.imageSchema)])
imageDf = rdd.toDF(dfType)
Expand Down

0 comments on commit 786ee77

Please sign in to comment.