Skip to content

Commit

Permalink
Use new Symbol APIs in Profiler example (apache#11477)
Browse files Browse the repository at this point in the history
* Adding test for scala profiler

* Separated profiler tests into their own methods

* Moved profiler test out of infer folder
  • Loading branch information
andrewfayres authored and nswamy committed Jul 22, 2018
1 parent 049d048 commit 60b6ab6
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ object ProfilerMatMul {

val A = Symbol.Variable("A")
val B = Symbol.Variable("B")
val C = Symbol.dot()(A, B)()
val C = Symbol.api.dot(Some(A), Some(B))

val executor = C.simpleBind(ctx, "write",
Map("A" -> Shape(4096, 4096), "B" -> Shape(4096, 4096)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ object ProfilerNDArray {
val randomRet = (0 until shape.product)
.map(r => scala.util.Random.nextFloat() - 0.5f).toArray
dat.set(randomRet)
val ndArrayRet = NDArray.broadcast_to(Map("shape" -> targetShape))(dat).get
val ndArrayRet = NDArray.api.broadcast_to(dat, Some(targetShape))
require(ndArrayRet.shape == targetShape)
val err = {
// implementation of broadcast
Expand Down Expand Up @@ -122,8 +122,8 @@ object ProfilerNDArray {
}

def reldiff(a: NDArray, b: NDArray): Float = {
val diff = NDArray.sum(NDArray.abs(a - b)).toScalar
val norm = NDArray.sum(NDArray.abs(a)).toScalar
val diff = NDArray.api.sum(NDArray.api.abs(a - b)).toScalar
val norm = NDArray.api.sum(NDArray.api.abs(a)).toScalar
diff / norm
}

Expand Down Expand Up @@ -171,15 +171,15 @@ object ProfilerNDArray {
def testClip(): Unit = {
val shape = Shape(10)
val A = Random.uniform(-10f, 10f, shape)
val B = NDArray.clip(A, -2f, 2f)
val B = NDArray.api.clip(A, -2f, 2f)
val B1 = B.toArray
require(B1.forall { x => x >= -2f && x <= 2f })
}

def testDot(): Unit = {
val a = Random.uniform(-3f, 3f, Shape(3, 4))
val b = Random.uniform(-3f, 3f, Shape(4, 5))
val c = NDArray.dot(a, b)
val c = NDArray.api.dot(a, b)
val A = a.toArray.grouped(4).toArray
val B = b.toArray.grouped(5).toArray
val C = (Array[Array[Float]]() /: A)((acc, row) => acc :+ row.zip(B).map(z =>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.mxnetexamples.profiler

import org.scalatest.{BeforeAndAfterAll, FunSuite}
import org.slf4j.LoggerFactory
import java.io.File

import org.apache.mxnet.Profiler
import org.apache.mxnet.Context

/**
* Integration test for imageClassifier example.
* This will run as a part of "make scalatest"
*/
class ProfilerSuite extends FunSuite with BeforeAndAfterAll {
private val logger = LoggerFactory.getLogger(classOf[ProfilerSuite])

override def beforeAll(): Unit = {
logger.info("Running profiler test...")
val eray = new ProfilerNDArray
val path = System.getProperty("java.io.tmpdir")
val kwargs = Map("file_name" -> path)
logger.info(s"profile file save to $path")

Profiler.profilerSetState("run")
}

override def afterAll(): Unit = {
Profiler.profilerSetState("stop")
}

test("Profiler Broadcast test") {
ProfilerNDArray.testBroadcast()
}

test("Profiler NDArray Saveload test") {
ProfilerNDArray.testNDArraySaveload()
}

test("Profiler NDArray Copy") {
ProfilerNDArray.testNDArrayCopy()
}

test("Profiler NDArray Negate") {
ProfilerNDArray.testNDArrayNegate()
}

test("Profiler NDArray Scalar") {
ProfilerNDArray.testNDArrayScalar()
}

test("Profiler NDArray Onehot") {
ProfilerNDArray.testNDArrayOnehot()
}

test("Profiler Clip") {
ProfilerNDArray.testClip()
}

test("Profiler Dot") {
ProfilerNDArray.testDot()
}
}

0 comments on commit 60b6ab6

Please sign in to comment.