Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fix Batch input issue with Scala Benchmark (#12848)
Browse files Browse the repository at this point in the history
* add initial change

* add fix

* improved usage of Shape as well as warning message on performance

* change into parallel

* drop dropBack

* apply Andrew's comments

* remove add dim inside img 2 pixel

* addressed Naveen's comment

* update comments
  • Loading branch information
lanking520 authored and nswamy committed Oct 21, 2018
1 parent d8c7375 commit 58f4117
Show file tree
Hide file tree
Showing 10 changed files with 63 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,19 @@ class ShapeSuite extends FunSuite with BeforeAndAfterAll {
assert(Shape(1, 2, 3) === Shape(1, 2, 3))
assert(Shape(1, 2) != Shape(1, 2, 3))
}

test("drop") {
val s = Shape(1, 2, 3)
val s2 = s.drop(1)
assert(s == Shape(1, 2, 3))
assert(s2 == Shape(2, 3))
val s3 = s.drop(2)
assert(s3 == Shape(3))
}

test("slice") {
val s = Shape(1, 2, 3)
val s2 = s.slice(0, 1)
assert(s2 == Shape(1))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import org.apache.mxnet._

trait InferBase {

def loadModel(context: Array[Context]): Any
def loadModel(context: Array[Context], batchInference : Boolean): Any
def loadSingleData(): Any
def loadBatchFileList(batchSize: Int): List[Any]
def loadInputBatch(source: Any): Any
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ object ScalaInferenceBenchmark {

private val logger = LoggerFactory.getLogger(classOf[CLIParserBase])

def loadModel(objectToRun: InferBase, context: Array[Context]):
def loadModel(objectToRun: InferBase, context: Array[Context], batchInference : Boolean):
Any = {
objectToRun.loadModel(context)
objectToRun.loadModel(context, batchInference)
}

def loadDataSet(objectToRun: InferBase):
Expand Down Expand Up @@ -134,7 +134,7 @@ object ScalaInferenceBenchmark {
logger.info("Running single inference call")
// Benchmarking single inference call
NDArrayCollector.auto().withScope {
val loadedModel = loadModel(exampleToBenchmark, context)
val loadedModel = loadModel(exampleToBenchmark, context, false)
val dataSet = loadDataSet(exampleToBenchmark)
val inferenceTimes = runInference(exampleToBenchmark, loadedModel, dataSet, baseCLI.count)
printStatistics(inferenceTimes, "single_inference")
Expand All @@ -144,7 +144,7 @@ object ScalaInferenceBenchmark {
logger.info("Running for batch inference call")
// Benchmarking batch inference call
NDArrayCollector.auto().withScope {
val loadedModel = loadModel(exampleToBenchmark, context)
val loadedModel = loadModel(exampleToBenchmark, context, true)
val batchDataSet = loadBatchDataSet(exampleToBenchmark, baseCLI.batchSize)
val inferenceTimes = runBatchInference(exampleToBenchmark, loadedModel, batchDataSet)
printStatistics(inferenceTimes, "batch_inference")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,9 +175,11 @@ class CLIParser extends CLIParserBase{

class ImageClassifierExample(CLIParser: CLIParser) extends InferBase{

override def loadModel(context: Array[Context]): Classifier = {
override def loadModel(context: Array[Context],
batchInference : Boolean = false): Classifier = {
val dType = DType.Float32
val inputShape = Shape(1, 3, 224, 224)
val batchSize = if (batchInference) CLIParser.batchSize else 1
val inputShape = Shape(batchSize, 3, 224, 224)

val inputDescriptor = IndexedSeq(DataDesc("data", inputShape, dType, "NCHW"))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -203,9 +203,10 @@ class CLIParser extends CLIParserBase {

class SSDClassifierExample(CLIParser: CLIParser)
extends InferBase {
override def loadModel(context: Array[Context]): Any = {
override def loadModel(context: Array[Context], batchInference: Boolean = false): Any = {
val dType = DType.Float32
val inputShape = Shape(1, 3, 512, 512)
val batchSize = if (batchInference) CLIParser.batchSize else 1
val inputShape = Shape(batchSize, 3, 512, 512)
val inputDescriptors = IndexedSeq(DataDesc("data", inputShape, dType, "NCHW"))
new ObjectDetector(CLIParser.modelPathPrefix, inputDescriptors, context)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ class TestCharRnn(CLIParser: CLIParser) extends InferBase {

private var vocab : Map[String, Int] = null

override def loadModel(context: Array[Context]): Any = {
override def loadModel(context: Array[Context], batchInference : Boolean = false): Any = {
val batchSize = 32
val buckets = List(129)
val numHidden = 512
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,16 @@ class ImageClassifier(modelPathPrefix: String,
topK: Option[Int] = None): IndexedSeq[IndexedSeq[(String, Float)]] = {

val scaledImage = ImageClassifier.reshapeImage(inputImage, width, height)
val pixelsNDArray = ImageClassifier.bufferedImageToPixels(scaledImage, inputShape)
val imageShape = inputShape.drop(1)
val pixelsNDArray = ImageClassifier.bufferedImageToPixels(scaledImage, imageShape)
val imgWithBatchNum = NDArray.api.expand_dims(pixelsNDArray, 0)
inputImage.flush()
scaledImage.flush()
handler.execute(pixelsNDArray.dispose())

val output = super.classifyWithNDArray(IndexedSeq(pixelsNDArray), topK)
val output = super.classifyWithNDArray(IndexedSeq(imgWithBatchNum), topK)

handler.execute(pixelsNDArray.dispose())
handler.execute(imgWithBatchNum.dispose())

IndexedSeq(output(0))
}
Expand All @@ -97,14 +100,16 @@ class ImageClassifier(modelPathPrefix: String,
def classifyImageBatch(inputBatch: Traversable[BufferedImage], topK: Option[Int] = None):
IndexedSeq[IndexedSeq[(String, Float)]] = {

val imageBatch = ListBuffer[NDArray]()
for (image <- inputBatch) {
val scaledImage = ImageClassifier.reshapeImage(image, width, height)
val pixelsNDArray = ImageClassifier.bufferedImageToPixels(scaledImage, inputShape)
imageBatch += pixelsNDArray
}
val inputBatchSeq = inputBatch.toIndexedSeq
val imageBatch = inputBatchSeq.indices.par.map(idx => {
val scaledImage = ImageClassifier.reshapeImage(inputBatchSeq(idx), width, height)
val imageShape = inputShape.drop(1)
val imgND = ImageClassifier.bufferedImageToPixels(scaledImage, imageShape)
val imgWithBatch = NDArray.api.expand_dims(imgND, 0).get
handler.execute(imgND.dispose())
imgWithBatch
}).toList
val op = NDArray.concatenate(imageBatch)

val result = super.classifyWithNDArray(IndexedSeq(op), topK)
handler.execute(op.dispose())
handler.execute(imageBatch.foreach(_.dispose()))
Expand Down Expand Up @@ -147,9 +152,9 @@ object ImageClassifier {
* returned by this method after the use.
* </p>
* @param resizedImage BufferedImage to get pixels from
* @param inputImageShape Input shape; for example for resnet it is (1,3,224,224).
* @param inputImageShape Input shape; for example for resnet it is (3,224,224).
Should be same as inputDescriptor shape.
* @return NDArray pixels array
* @return NDArray pixels array with shape (3, 224, 224) in CHW format
*/
def bufferedImageToPixels(resizedImage: BufferedImage, inputImageShape: Shape): NDArray = {
// Get height and width of the image
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,12 @@ class ObjectDetector(modelPathPrefix: String,
: IndexedSeq[IndexedSeq[(String, Array[Float])]] = {

val scaledImage = ImageClassifier.reshapeImage(inputImage, width, height)
val pixelsNDArray = ImageClassifier.bufferedImageToPixels(scaledImage, inputShape)
val output = objectDetectWithNDArray(IndexedSeq(pixelsNDArray), topK)
val imageShape = inputShape.drop(1)
val pixelsNDArray = ImageClassifier.bufferedImageToPixels(scaledImage, imageShape)
val pixelsNDWithBatch = NDArray.api.expand_dims(pixelsNDArray, 0)
handler.execute(pixelsNDArray.dispose())
val output = objectDetectWithNDArray(IndexedSeq(pixelsNDWithBatch), topK)
handler.execute(pixelsNDWithBatch.dispose())
output
}

Expand Down Expand Up @@ -147,13 +150,16 @@ class ObjectDetector(modelPathPrefix: String,
def imageBatchObjectDetect(inputBatch: Traversable[BufferedImage], topK: Option[Int] = None):
IndexedSeq[IndexedSeq[(String, Array[Float])]] = {

val imageBatch = ListBuffer[NDArray]()
for (image <- inputBatch) {
val scaledImage = ImageClassifier.reshapeImage(image, width, height)
val pixelsNdarray = ImageClassifier.bufferedImageToPixels(scaledImage, inputShape)
imageBatch += pixelsNdarray
}
val op = NDArray.concatenate(imageBatch)
val inputBatchSeq = inputBatch.toIndexedSeq
val imageBatch = inputBatchSeq.indices.par.map(idx => {
val scaledImage = ImageClassifier.reshapeImage(inputBatchSeq(idx), width, height)
val imageShape = inputShape.drop(1)
val pixelsND = ImageClassifier.bufferedImageToPixels(scaledImage, imageShape)
val pixelsNDWithBatch = NDArray.api.expand_dims(pixelsND, 0).get
handler.execute(pixelsND.dispose())
pixelsNDWithBatch
})
val op = NDArray.concatenate(imageBatch.toList)

val result = objectDetectWithNDArray(IndexedSeq(op), topK)
handler.execute(op.dispose())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ class Predictor(modelPathPrefix: String,

// rebind with the new batchSize
if (batchSize != inputBatchSize) {
logger.info(s"Latency increased due to batchSize mismatch $batchSize vs $inputBatchSize")
val desc = iDescriptors.map((f : DataDesc) => new DataDesc(f.name,
Shape(f.shape.toVector.patch(batchIndex, Vector(inputBatchSize), 1)), f.dtype, f.layout) )
mxNetHandler.execute(mod.bind(desc, forceRebind = true,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,9 @@ class ImageClassifierSuite extends ClassifierSuite with BeforeAndAfterAll {
val image1 = new BufferedImage(100, 200, BufferedImage.TYPE_BYTE_GRAY)
val image2 = ImageClassifier.reshapeImage(image1, 2, 2)

val result = ImageClassifier.bufferedImageToPixels(image2, Shape(1, 3, 2, 2))
val result = ImageClassifier.bufferedImageToPixels(image2, Shape(3, 2, 2))

assert(result.shape == inputDescriptor(0).shape)
assert(result.shape == inputDescriptor(0).shape.drop(1))
}

test("ImageClassifierSuite-testWithInputImage") {
Expand Down

0 comments on commit 58f4117

Please sign in to comment.