Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit b011ecc

Browse files
lanking520nswamy
authored andcommitted
[MXNET-357] New Scala API Design (Symbol) (#10660)
* Simplfied current Macros impl to Quasiquote * Change the Symbol Function Field, add SymbolArg * Fix the Macros problem, disable the hidden function _ * Add Implementation for New API * Add examples and comments * Add _contrib_ support * New namespace for Symbol API * Change names and add comments * add TODOs and name changes * Add relative path to MXNET_BASEDIR * Update Base.scala
1 parent 8b53a3d commit b011ecc

File tree

8 files changed

+340
-89
lines changed

8 files changed

+340
-89
lines changed

scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala

+2
Original file line numberDiff line numberDiff line change
@@ -830,6 +830,8 @@ object Symbol {
830830
private val functions: Map[String, SymbolFunction] = initSymbolModule()
831831
private val bindReqMap = Map("null" -> 0, "write" -> 1, "add" -> 3)
832832

833+
val api = SymbolAPI
834+
833835
def pow(sym1: Symbol, sym2: Symbol): Symbol = {
834836
Symbol.createFromListedSymbols("_Power")(Array(sym1, sym2))
835837
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
package org.apache.mxnet
18+
19+
20+
@AddSymbolAPIs(false)
21+
/**
22+
* typesafe Symbol API: Symbol.api._
23+
* Main code will be generated during compile time through Macros
24+
*/
25+
object SymbolAPI {
26+
}

scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainMnist.scala

+21-21
Original file line numberDiff line numberDiff line change
@@ -30,40 +30,40 @@ object TrainMnist {
3030
// multi-layer perceptron
3131
def getMlp: Symbol = {
3232
val data = Symbol.Variable("data")
33-
val fc1 = Symbol.FullyConnected(name = "fc1")()(Map("data" -> data, "num_hidden" -> 128))
34-
val act1 = Symbol.Activation(name = "relu1")()(Map("data" -> fc1, "act_type" -> "relu"))
35-
val fc2 = Symbol.FullyConnected(name = "fc2")()(Map("data" -> act1, "num_hidden" -> 64))
36-
val act2 = Symbol.Activation(name = "relu2")()(Map("data" -> fc2, "act_type" -> "relu"))
37-
val fc3 = Symbol.FullyConnected(name = "fc3")()(Map("data" -> act2, "num_hidden" -> 10))
38-
val mlp = Symbol.SoftmaxOutput(name = "softmax")()(Map("data" -> fc3))
33+
34+
val fc1 = Symbol.api.FullyConnected(data = Some(data), num_hidden = 128, name = "fc1")
35+
val act1 = Symbol.api.Activation (data = Some(fc1), "relu", name = "relu")
36+
val fc2 = Symbol.api.FullyConnected(Some(act1), None, None, 64, name = "fc2")
37+
val act2 = Symbol.api.Activation(data = Some(fc2), "relu", name = "relu2")
38+
val fc3 = Symbol.api.FullyConnected(Some(act2), None, None, 10, name = "fc3")
39+
val mlp = Symbol.api.SoftmaxOutput(name = "softmax", data = Some(fc3))
3940
mlp
4041
}
4142

4243
// LeCun, Yann, Leon Bottou, Yoshua Bengio, and Patrick
4344
// Haffner. "Gradient-based learning applied to document recognition."
4445
// Proceedings of the IEEE (1998)
46+
4547
def getLenet: Symbol = {
4648
val data = Symbol.Variable("data")
4749
// first conv
48-
val conv1 = Symbol.Convolution()()(
49-
Map("data" -> data, "kernel" -> "(5, 5)", "num_filter" -> 20))
50-
val tanh1 = Symbol.Activation()()(Map("data" -> conv1, "act_type" -> "tanh"))
51-
val pool1 = Symbol.Pooling()()(Map("data" -> tanh1, "pool_type" -> "max",
52-
"kernel" -> "(2, 2)", "stride" -> "(2, 2)"))
50+
val conv1 = Symbol.api.Convolution(data = Some(data), kernel = Shape(5, 5), num_filter = 20)
51+
val tanh1 = Symbol.api.tanh(data = Some(conv1))
52+
val pool1 = Symbol.api.Pooling(data = Some(tanh1), pool_type = Some("max"),
53+
kernel = Some(Shape(2, 2)), stride = Some(Shape(2, 2)))
5354
// second conv
54-
val conv2 = Symbol.Convolution()()(
55-
Map("data" -> pool1, "kernel" -> "(5, 5)", "num_filter" -> 50))
56-
val tanh2 = Symbol.Activation()()(Map("data" -> conv2, "act_type" -> "tanh"))
57-
val pool2 = Symbol.Pooling()()(Map("data" -> tanh2, "pool_type" -> "max",
58-
"kernel" -> "(2, 2)", "stride" -> "(2, 2)"))
55+
val conv2 = Symbol.api.Convolution(data = Some(pool1), kernel = Shape(5, 5), num_filter = 50)
56+
val tanh2 = Symbol.api.tanh(data = Some(conv2))
57+
val pool2 = Symbol.api.Pooling(data = Some(tanh2), pool_type = Some("max"),
58+
kernel = Some(Shape(2, 2)), stride = Some(Shape(2, 2)))
5959
// first fullc
60-
val flatten = Symbol.Flatten()()(Map("data" -> pool2))
61-
val fc1 = Symbol.FullyConnected()()(Map("data" -> flatten, "num_hidden" -> 500))
62-
val tanh3 = Symbol.Activation()()(Map("data" -> fc1, "act_type" -> "tanh"))
60+
val flatten = Symbol.api.Flatten(data = Some(pool2))
61+
val fc1 = Symbol.api.FullyConnected(data = Some(flatten), num_hidden = 500)
62+
val tanh3 = Symbol.api.tanh(data = Some(fc1))
6363
// second fullc
64-
val fc2 = Symbol.FullyConnected()()(Map("data" -> tanh3, "num_hidden" -> 10))
64+
val fc2 = Symbol.api.FullyConnected(data = Some(tanh3), num_hidden = 10)
6565
// loss
66-
val lenet = Symbol.SoftmaxOutput(name = "softmax")()(Map("data" -> fc2))
66+
val lenet = Symbol.api.SoftmaxOutput(name = "softmax", data = Some(fc2))
6767
lenet
6868
}
6969

scala-package/init/src/main/scala/org/apache/mxnet/init/Base.scala

+6-1
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,12 @@ object Base {
3737

3838
@throws(classOf[UnsatisfiedLinkError])
3939
private def tryLoadInitLibrary(): Unit = {
40-
val baseDir = System.getProperty("user.dir") + "/init-native"
40+
var baseDir = System.getProperty("user.dir") + "/init-native"
41+
// TODO(lanKing520) Update this to use relative path to the MXNet director.
42+
// TODO(lanking520) baseDir = sys.env("MXNET_BASEDIR") + "/scala-package/init-native"
43+
if (System.getenv().containsKey("MXNET_BASEDIR")) {
44+
baseDir = sys.env("MXNET_BASEDIR")
45+
}
4146
val os = System.getProperty("os.name")
4247
// ref: http://lopica.sourceforge.net/os.html
4348
if (os.startsWith("Linux")) {

scala-package/macros/pom.xml

+38
Original file line numberDiff line numberDiff line change
@@ -52,4 +52,42 @@
5252
<type>${libtype}</type>
5353
</dependency>
5454
</dependencies>
55+
56+
<build>
57+
<plugins>
58+
<plugin>
59+
<groupId>org.apache.maven.plugins</groupId>
60+
<artifactId>maven-jar-plugin</artifactId>
61+
<configuration>
62+
<excludes>
63+
<exclude>META-INF/*.SF</exclude>
64+
<exclude>META-INF/*.DSA</exclude>
65+
<exclude>META-INF/*.RSA</exclude>
66+
</excludes>
67+
</configuration>
68+
</plugin>
69+
<plugin>
70+
<groupId>org.apache.maven.plugins</groupId>
71+
<artifactId>maven-compiler-plugin</artifactId>
72+
</plugin>
73+
<plugin>
74+
<groupId>org.scalatest</groupId>
75+
<artifactId>scalatest-maven-plugin</artifactId>
76+
<configuration>
77+
<environmentVariables>
78+
<MXNET_BASEDIR>${project.parent.basedir}/init-native</MXNET_BASEDIR>
79+
</environmentVariables>
80+
<argLine>
81+
-Djava.library.path=${project.parent.basedir}/native/${platform}/target \
82+
-Dlog4j.configuration=file://${project.basedir}/src/test/resources/log4j.properties
83+
</argLine>
84+
</configuration>
85+
</plugin>
86+
<plugin>
87+
<groupId>org.scalastyle</groupId>
88+
<artifactId>scalastyle-maven-plugin</artifactId>
89+
</plugin>
90+
</plugins>
91+
</build>
92+
5593
</project>

0 commit comments

Comments
 (0)