-
Notifications
You must be signed in to change notification settings - Fork 6.8k
Allow stop of arange to be inferred from dims. #12064
Conversation
@anirudh2290 how long should I expect for the CI check to take? |
a5f0abb
to
6cdcae1
Compare
src/operator/tensor/init_op.h
Outdated
@@ -471,6 +473,11 @@ inline bool RangeShape(const nnvm::NodeAttrs& attrs, | |||
<< "Range does not support step=0, received " << param.step; | |||
CHECK(param.repeat > 0) | |||
<< "Range only supports repeat > 0, received " << param.repeat; | |||
if (param.start == param.stop.value()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
though I understand that end=None is already taken, this condition still feels a bit like a hack...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@szha what's the best way to proceed? it seems that everyone is very busy, so doing a discussion over PR comments might take a very long time. ideal would be some kind of realtime chat where we can discuss design alternatives and quickly come to a consensus. is there such a forum?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if that is not desired, here's a simple proposal: introduce a new parameter called "infer_range" that defaults to false. if it is true, then exactly one of param.stop
, param.start
, or param.step
can be None, and will be inferred from the others and the output dimensions. I may only implement something more limited that makes param.stop
be the only inferrable parameter, and leave that to others to implement in the future.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Adding a flag sounds good to me
@szha implemented! |
@yzhliu @nswamy @anirudh2290 hi folks! it would be fantastic if we could get an idea of when you'll be able to review this PR further... it will help us plan other work that depend on this inside our company. |
:as opts}] | ||
(NDArray/arange (float start) ($/option (float stop)) step repeat ctx dtype)) | ||
(NDArray/arange (float start) ($/option (float stop)) step repeat infer-range ctx dtype)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@gigasquid could you help review this part?
:as opts}] | ||
(Symbol/arange (float start) ($/option (float stop)) step repeat nil dtype)) | ||
(Symbol/arange (float start) ($/option (float stop)) step repeat infer-range nil dtype)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
python/mxnet/ndarray/ndarray.py
Outdated
return _internal._arange(start=start, stop=stop, step=step, repeat=repeat, | ||
dtype=dtype, ctx=str(ctx)) | ||
return _internal._arange(start=start, stop=stop, step=step, infer_range=infer_range, | ||
repeat=repeat, dtype=dtype, ctx=str(ctx)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: infer_range and repeat keyword arguments swapped place (not that it matters)
val params = Map("start" -> start, "step" -> step, | ||
"repeat" -> repeat, "ctx" -> ctx.toString, "dtype" -> dType.toString()) | ||
val params = Map("start" -> start, "step" -> step, "repeat" -> repeat, | ||
"infer_range" -> inferRange, "ctx" -> ctx.toString, "dtype" -> dType.toString()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@lanking520 could you help review this part?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure. @taliesinb Could you please try to move the newly introduced param to the end of the function in order to bring backward compatibility?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
please update the doc string for this new parameter. same for symbol
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I actually prefer you create a new arange
method with inferRange
without default in the order you have here, the existing one one call with the default value(false).
Almost all of the methods have context as the last parameter, this one could cause confusion.
val params = Map("start" -> start, "step" -> step, | ||
"repeat" -> repeat, "dtype" -> dType.toString()) | ||
"repeat" -> repeat, "infer_range" -> inferRange, "dtype" -> dType.toString()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@taliesinb Same applies to what I have mentioned above
val params = Map("start" -> start, "step" -> step, | ||
"repeat" -> repeat, "ctx" -> ctx.toString, "dtype" -> dType.toString()) | ||
val params = Map("start" -> start, "step" -> step, "repeat" -> repeat, | ||
"infer_range" -> inferRange, "ctx" -> ctx.toString, "dtype" -> dType.toString()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure. @taliesinb Could you please try to move the newly introduced param to the end of the function in order to bring backward compatibility?
val params = Map("start" -> start, "step" -> step, | ||
"repeat" -> repeat, "dtype" -> dType.toString()) | ||
"repeat" -> repeat, "infer_range" -> inferRange, "dtype" -> dType.toString()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@taliesinb Same applies to what I have mentioned above
b82fd38
to
56bb7fb
Compare
Enabled via a flag.
56bb7fb
to
10851d2
Compare
@lanking520 @nswamy I've implemented the suggestion to have a second operator instead of passing this as an option. It wasn't clear whether this was desired just for Scala or for both Scala and Clojure, so I did it for both. If CI passes then this should be ready for merge unless there are more comments. |
(arange start stop {}))) | ||
|
||
(defn arange-with-inference | ||
"Behaves like arange operator, but infers the stop value from the output shape, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice job adding the clojure function with nice documentation. 👍 If you are feeling up to it you could also add the corresponding test for it here https://github.com/apache/incubator-mxnet/blob/master/contrib/clojure-package/test/org/apache/clojure_mxnet/ndarray_test.clj#L141
(arange start stop {}))) | ||
|
||
(defn arange-with-inference | ||
"Behaves like arange operator, but infers the stop value from the output shape, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same here - Great job adding the clojure functions. If you want to add the corresponding test that would be awesome too https://github.com/apache/incubator-mxnet/blob/master/contrib/clojure-package/test/org/apache/clojure_mxnet/operator_test.clj#L214. It can also be done in a follow up PR if that works better 😸
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@gigasquid happy to try to do that! i haven't ever used Clojure or Scala before. But I've run into a problem even getting the first step, make scalapkg
, to work on macOS. The make initially failed, comlpaining it couldn't find the mvn executable. I assumed that was maven, brew install maven
had me first brew install java
. Then make scalapkg
seemed to be happy, and downloaded a bunch of stuff (including scala). But it failed with this:
[INFO] /Users/taliesinb/git/MXNet/scala-package/init/src/main/scala:-1: info: compiling
[INFO] Compiling 2 source files to /Users/taliesinb/git/MXNet/scala-package/init/target/classes at 1534710433224
Downloading from central: https://repo.maven.apache.org/maven2/org/scala-lang/modules/scala-xml_2.11/1.0.4/scala-xml_2.11-1.0.4.jar
Downloaded from central: https://repo.maven.apache.org/maven2/org/scala-lang/modules/scala-xml_2.11/1.0.4/scala-xml_2.11-1.0.4.jar (648 kB at 755 kB/s)
Downloading from central: https://repo.maven.apache.org/maven2/org/scala-lang/scala-library/2.11.4/scala-library-2.11.4.jar
Downloaded from central: https://repo.maven.apache.org/maven2/org/scala-lang/scala-library/2.11.4/scala-library-2.11.4.jar (5.5 MB at 1.3 MB/s)
Downloading from central: https://repo.maven.apache.org/maven2/org/scala-lang/scala-library/2.11.6/scala-library-2.11.6.jar
Downloaded from central: https://repo.maven.apache.org/maven2/org/scala-lang/scala-library/2.11.6/scala-library-2.11.6.jar (5.6 MB at 1.2 MB/s)
[INFO] compiler plugin: BasicArtifact(org.scalamacros,paradise_2.11.8,2.1.0)
Downloading from central: https://repo.maven.apache.org/maven2/org/scalamacros/paradise_2.11.8/2.1.0/paradise_2.11.8-2.1.0.jar
Downloaded from central: https://repo.maven.apache.org/maven2/org/scalamacros/paradise_2.11.8/2.1.0/paradise_2.11.8-2.1.0.jar (271 kB at 397 kB/s)
[ERROR] error: scala.reflect.internal.MissingRequirementError: object java.lang.Object in compiler mirror not found.
[ERROR] at scala.reflect.internal.MissingRequirementError$.signal(MissingRequirementError.scala:17)
[ERROR] at scala.reflect.internal.MissingRequirementError$.notFound(MissingRequirementError.scala:18)
[INFO] at scala.reflect.internal.Mirrors$RootsBase.getModuleOrClass(Mirrors.scala:53)
This seems to be related to https://issues.scala-lang.org/browse/SI-9103, but that issue is still open. I have no idea whats going on or how to make progress. I reran the make scalapkg
to no effect, here's a gist with the full output: https://gist.github.com/taliesinb/d0f09e9f0202c3983298511383542f59. Do you have any suggestions?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@taliesinb I'm impressed that you jumped in there on the Clojure and Scala 💯 - From the issue it seems like it is the JDK you are using. Using JDK 8 should solve the problems. If you have multiple versions of the JDK installed, you should just be able to switch by using an export of the right JAVA_HOME see here. Give it a try and see how it goes. If you don't want to hold up this PR, I'd be happy to assist on a follow up PR if you'd like 😸
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@gigasquid great that worked! thanks for the help! I'll keep you posted as I (hopefully) make progress with this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@taliesinb are you making a new PR for Clojure or do you want to make changes to this one ? This is good for Scala APIs.
Thanks for the great work 👍
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm @gigasquid I'm finding the instructions in the README.md
file a little unclear in places. For example, under "Build from MXNET Source", I find this instruction a bit cryptic:
then replace the correct jar for your architecture in the project.clj, example
[org.apache.mxnet/mxnet-full_2.11-osx-x86_64-cpu "1.3.0-SNAPSHOT"]
I would find this easier to understand if it was very explicit, such as "replace X with Y in section Z".
Here is what my project.clj contained out of the box:
;; Jars from Nexus
;[org.apache.mxnet/mxnet-full_2.11-osx-x86_64-cpu "1.2.1"]
;[org.apache.mxnet/mxnet-full_2.11-linux-x86_64-cpu "1.2.1"]
;[org.apache.mxnet/mxnet-full_2.11-linux-x86_64-gpu "1.2.1"]
;;; CI
[org.apache.mxnet/mxnet-full_2.11-linux-x86_64-cpu "1.3.0-SNAPSHOT"]
At this point, not knowing what to replace with what, I read the section "Cloning the repo and running from source", which mentions uncommenting rather than replacing. That section is also a bit confusing:
you will need to replace the native version of the line in the project dependencies with your configuration.
Which line? What is a "native version of the line"? Perhaps it could say "you will need to find and uncomment the appropriate line in the dependencies section of the project.clj file, and comment the rest". We could also make the project.clj section clearer so its more obvious what to do:
;; default behavior, to be used by the CI bot on github; comment this line and
;; uncomment the appropriate line in one of the other sections
[org.apache.mxnet/mxnet-full_2.11-linux-x86_64-cpu "1.3.0-SNAPSHOT"]
;; use a prebuilt JAR from Nexus
;[org.apache.mxnet/mxnet-full_2.11-linux-x86_64-cpu "1.2.1"]
;[org.apache.mxnet/mxnet-full_2.11-linux-x86_64-gpu "1.2.1"]
;[org.apache.mxnet/mxnet-full_2.11-osx-x86_64-cpu "1.2.1"]
;;; build a local JAR from source
;[org.apache.mxnet/mxnet-full_2.11-linux-x86_64-cpu "1.3.0-SNAPSHOT"]
;[org.apache.mxnet/mxnet-full_2.11-linux-x86_64-gpu "1.3.0-SNAPSHOT"]
;[org.apache.mxnet/mxnet-full_2.11-osx-x86_64-cpu "1.3.0-SNAPSHOT"]
Now, while the instructions could be a bit clearer, I did figure out the point eventually, and so I tried adding this line and commenting the rest:
[org.apache.mxnet/mxnet-full_2.11-osx-x86_64-cpu "1.3.0-SNAPSHOT"]
After running lein clean
and lein test
I get this:
Generating symbol file
INFO MXNetJVM: Try loading mxnet-scala from native path.
INFO MXNetJVM: Try loading mxnet-scala-osx-x86_64-gpu from native path.
INFO MXNetJVM: Try loading mxnet-scala-osx-x86_64-cpu from native path.
WARN MXNetJVM: MXNet Scala native library not found in path. Copying native library from the archive. Consider installing the library somewhere in the path (for Windows: PATH, for Linux: LD_LIBRARY_PATH), or specifying by Java cmd option -Djava.library.path=[lib path].
INFO org.apache.mxnet.util.NativeLibraryLoader: Replaced .dylib with .jnilib
INFO org.apache.mxnet.util.NativeLibraryLoader: Loading libmxnet-scala.jnilib from /lib/native/ copying to mxnet-scala
[2
That WARN makes it sound like it's not using the library I built earlier using make scalainstall
, which will mean I can't actually test my new functionality! Wasn't make scalainstall
supposed to make the MXNet scala libraries available for everyone on my system?
How should I fix this?
Also, an ergonomics question: the test suite takes a while to run. With Python, it was very easy to just run the new tests I added using e.g. nosetests -v tests/python/unittest/test_operator.py
. Is there a similar incantation for Clojure?
Thanks in advance for your help!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@nswamy I'm now skeptical of my Scala changes. For example, I'm not sure that the additions to arange
in NDArray.scala
are correct. My concern is that the only way that you can even use the new inference feature is via backward inference, because the inference is based on the output shape of the tensor produced by arange
, which must be inferred from a different part of the graph. So unless I'm missing something, calling the imperative arange
function with infer_range = true
will always fail as it has to produce an NDArray
immediately, but this is not possible because backward inference is only relevant for symbols.
The new functionality should work in the symbolic context, however.
EDIT: to answer your original question, I'd prefer to add a Scala test to verify this last claim! If i did the wrong thing before... I don't trust myself now. Plus, if you agree, I should delete the imperative version of the new arange
functionality from all language APIs.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@taliesinb Thanks for the feedback on the wording. I'll update it to be more clear. It seems like you are doing everything exactly right. Once you do a scalainstall it will install locally in your maven a new 1.3.0-SNAPSHOT
. Since you updated your project.clj to use this, it will load up the updated jar. The WARN
is again misleading and can be improved, but it should be working :)
As far as running just one test, you can certainly do that with lein test :only org.apache.clojure-mxnet.ndarray-test
and lein test :only org.apache.clojure-mxnet.operator-test
.
Thanks again for the feedback and let me know if you have any other issues
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@gigasquid thanks for the info. I'll get back to this tomorrow hopefully.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@gigasquid ok we're good to go. I removed the pointless imperative function version of arange-with-inference
, so I only had to add a test to operator_test.clj
.
However, in doing this, I think I've picked up a problem with approx=
, in which it incorrectly returns true if one of the comparisands (is that a word??) is shorter than the other, and differs in the remaining elements that the other does not have.
For example, try change the test starting on line 200 to the following:
(deftest ones
(let [ones (sym/ones [2 2])
exec (sym/simple-bind ones (context/default-context))]
(is (approx= 1e-4
[1 1 1 1 9 9 9 9 9 9]
(-> exec (executor/forward) (executor/outputs) (first))))))
(I've introduced the 9 9 9 9 9 9 here). This test still passes.
I've reported the issue here: #12320, and fixed it in this PR. It doesn't produce any regressions, luckily!
If my new test looks good to you, we should be ready to merge!
@@ -407,11 +407,30 @@ object NDArray extends NDArrayBase { | |||
* @param dType The data type of the `NDArray`. The default datatype is `DType.Float32`. | |||
* @return NDArray of evenly spaced values in the specified range. | |||
*/ | |||
def arange(start: Float, stop: Option[Float] = None, step: Float = 1.0f, | |||
def arange(start: Float, stop: Option[Float] = None, step: Float = 1.0f, repeat: Int = 1, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@taliesinb Thanks for making this change. There is a compile error and CI is breaking.
I also want to slightly change this, I fixed the compile error and modified NDArray/Symbol and pushed a commit to your branch
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@lanking520 FYI
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@nswamy oh thanks! my bad.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@nswamy Looks good!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
bf27a07
to
d28e068
Compare
Include a test of this fix as well.
042e23c
to
b22bb34
Compare
@taliesinb Looks great! Thanks for the Clojure tests and fixing the helper function too. |
Is there any more feedback? If not, I think this is good to merge. |
Thanks for the great work and patience @taliesinb |
* Allow stop of arange to be inferred from dims. Enabled via a flag. * modify NDArray/Symbol to add infer_range param * Add test for arange-with-inference. * Add a comment to readme about JDK 8. * Fix approx=. Include a test of this fix as well.
* Allow stop of arange to be inferred from dims. Enabled via a flag. * modify NDArray/Symbol to add infer_range param * Add test for arange-with-inference. * Add a comment to readme about JDK 8. * Fix approx=. Include a test of this fix as well.
Description
This PR adds the ability for an arange operator to leave the stop value unspecified, so that it will be inferred from the output shape (via backward shape inference). This is important to achieve shape polymorphism for efficient bucketing.
Checklist
Essentials
Please feel free to remove inapplicable items for your PR.
Changes
Comments