This repository was archived by the owner on Nov 17, 2023. It is now read-only.
This repository was archived by the owner on Nov 17, 2023. It is now read-only.
Incompatible input shape #16620
Open
Description
Description
I'm using scala / java 1.5.1 mxnet-full_2.11-osx-x86_64-cpu.
I believe I'm correctly specifying my input shape as (1, 0, 1, 5961). But I'm getting:
Incompatible input shape: expected [1,-1,1,1], got [1,-1,1,64]
I believe it has something to do with my number of factors which is 64.
I've tested the same model in Python and it works perfect.
Here's my code:
List<Context> context = new ArrayList<>();
List<DataDesc> inputDesc = new ArrayList<>();
inputDesc.add(new DataDesc("data",
new Shape(new int[]{1, 0, 1, 5961}),
DType.Float32(),
Layout.NCHW()));
context.add(Context.cpu());
Predictor predictor = new Predictor("/tmp/model", inputDesc, context, 0);
Here's my symbols json:
"nodes": [
{
"op": "null",
"name": "data",
"attrs": {"__storage_type__": "2"},
"inputs": []
},
{
"op": "null",
"name": "w1_weight",
"attrs": {
"__init__": "[\"normal\", {\"sigma\": 0.01}]",
"__lr_mult__": "0.001",
"__shape__": "(5961, 1)",
"__storage_type__": "1",
"__wd_mult__": "0.001"
},
"inputs": []
},
{
"op": "dot",
"name": "dot0",
"inputs": [[0, 0, 0], [1, 0, 0]]
},
{
"op": "null",
"name": "w0_weight",
"attrs": {
"__init__": "[\"normal\", {\"sigma\": 0.01}]",
"__lr_mult__": "0.01",
"__shape__": "(1,)",
"__wd_mult__": "0.01"
},
"inputs": []
},
{
"op": "broadcast_add",
"name": "broadcast_plus0",
"inputs": [[2, 0, 0], [3, 0, 0]]
},
{
"op": "null",
"name": "v",
"attrs": {
"__init__": "[\"normal\", {\"sigma\": 0.001}]",
"__lr_mult__": "0.0001",
"__shape__": "(5961, 64)",
"__storage_type__": "1",
"__wd_mult__": "1e-05"
},
"inputs": []
},
{
"op": "dot",
"name": "dot2",
"inputs": [[0, 0, 0], [5, 0, 0]]
},
{
"op": "square",
"name": "Square",
"inputs": [[6, 0, 0]]
},
{
"op": "_mul_scalar",
"name": "_mulscalar0",
"attrs": {"scalar": "0.5"},
"inputs": [[7, 0, 0]]
},
{
"op": "Concat",
"name": "concat0",
"attrs": {
"dim": "1",
"num_args": "2"
},
"inputs": [[4, 0, 0], [8, 0, 0]]
},
{
"op": "sum",
"name": "sum0",
"attrs": {
"axis": "1",
"keepdims": "True"
},
"inputs": [[9, 0, 0]]
},
{
"op": "square",
"name": "x_square",
"inputs": [[0, 0, 0]]
},
{
"op": "_square_sum",
"name": "_square_sum0",
"attrs": {
"axis": "1",
"keepdims": "True"
},
"inputs": [[5, 0, 0]]
},
{
"op": "dot",
"name": "dot1",
"inputs": [[11, 0, 0], [12, 0, 0]]
},
{
"op": "negative",
"name": "negative0",
"inputs": [[13, 0, 0]]
},
{
"op": "_mul_scalar",
"name": "_mulscalar1",
"attrs": {"scalar": "0.5"},
"inputs": [[14, 0, 0]]
},
{
"op": "elemwise_add",
"name": "Final_Summation",
"inputs": [[10, 0, 0], [15, 0, 0]]
},
{
"op": "null",
"name": "out_label",
"inputs": []
},
{
"op": "LinearRegressionOutput",
"name": "out",
"inputs": [[16, 0, 0], [17, 0, 0]]
}
],
"arg_nodes": [0, 1, 3, 5, 17],
"node_row_ptr": [
0,
1,
2,
3,
4,
5,
6,
7,
8,
9,
10,
11,
12,
13,
14,
15,
16,
17,
18,
19
],
"heads": [[18, 0, 0]],
"attrs": {"mxnet_version": ["int", 10100]}
}
And here's the full output:
22:36:10.674 [org.apache.mxnet.infer.MXNetThreadPoolHandler-0] INFO MXNetJVM::tryLoadLibraryOS - Try loading mxnet-scala from native path.
22:36:10.678 [org.apache.mxnet.infer.MXNetThreadPoolHandler-0] WARN MXNetJVM::<init> - MXNet Scala native library not found in path. Copying native library from the archive. Consider installing the library somewhere in the path (for Windows: PATH, for Linux: LD_LIBRARY_PATH), or specifying by Java cmd option -Djava.library.path=[lib path].
22:36:10.679 [org.apache.mxnet.infer.MXNetThreadPoolHandler-0] WARN MXNetJVM::<init> - LD_LIBRARY_PATH=null
22:36:10.680 [org.apache.mxnet.infer.MXNetThreadPoolHandler-0] WARN MXNetJVM::<init> - java.library.path=/Users/shearn/Library/Java/Extensions:/Library/Java/Extensions:/Network/Library/Java/Extensions:/System/Library/Java/Extensions:/usr/lib/java:.
22:36:10.689 [org.apache.mxnet.infer.MXNetThreadPoolHandler-0] INFO org.apache.mxnet.util.NativeLibraryLoader::loadLibrary - Replaced .dylib with .jnilib
[22:36:11] src/nnvm/legacy_json_util.cc:209: Loading symbol saved by previous version v1.1.0. Attempting to upgrade...
[22:36:11] src/nnvm/legacy_json_util.cc:217: Symbol successfully upgraded!
22:36:11.880 [org.apache.mxnet.infer.MXNetThreadPoolHandler-0] WARN org.apache.mxnet.DataDesc::getBatchAxis - Found Undefined Layout, will use default index 0 for batch axis
Exception in thread "main" org.apache.mxnet.MXNetError: Error in operator concat0: [22:36:11] src/operator/nn/concat.cc:67: Check failed: shape_assign(&(*in_shape)[i], dshape): Incompatible input shape: expected [1,-1,1,1], got [1,-1,1,64]
Stack trace:
[bt] (0) 1 libmxnet.so 0x000000012d540509 mxnet::op::MKLDNNActivationBackward(nnvm::NodeAttrs const&, mxnet::OpContext const&, mxnet::NDArray const&, mxnet::NDArray const&, mxnet::OpReqType const&, mxnet::NDArray const&) + 9113
[bt] (1) 2 libmxnet.so 0x000000012d91a129 mxnet::op::SupportMKLDNNConcat(std::__1::vector<mxnet::NDArray, std::__1::allocator<mxnet::NDArray> > const&) + 7977
[bt] (2) 3 libmxnet.so 0x000000012edaad39 std::__1::__tree<std::__1::__value_type<unsigned long, mxnet::NDArray>, std::__1::__map_value_compare<unsigned long, std::__1::__value_type<unsigned long, mxnet::NDArray>, std::__1::less<unsigned long>, true>, std::__1::allocator<std::__1::__value_type<unsigned long, mxnet::NDArray> > >::erase(std::__1::__tree_const_iterator<std::__1::__value_type<unsigned long, mxnet::NDArray>, std::__1::__tree_node<std::__1::__value_type<unsigned long, mxnet::NDArray>, void*>*, long>) + 50089
[bt] (3) 4 libmxnet.so 0x000000012eda191a std::__1::__tree<std::__1::__value_type<unsigned long, mxnet::NDArray>, std::__1::__map_value_compare<unsigned long, std::__1::__value_type<unsigned long, mxnet::NDArray>, std::__1::less<unsigned long>, true>, std::__1::allocator<std::__1::__value_type<unsigned long, mxnet::NDArray> > >::erase(std::__1::__tree_const_iterator<std::__1::__value_type<unsigned long, mxnet::NDArray>, std::__1::__tree_node<std::__1::__value_type<unsigned long, mxnet::NDArray>, void*>*, long>) + 12170
[bt] (4) 5 libmxnet.so 0x000000012ed30616 MXSymbolInferShapeEx + 2422
[bt] (5) 6 mxnet-scala 0x000000012cda054d Java_org_apache_mxnet_LibInfo_mxSymbolCreateFromFile + 957
[bt] (6) 7 mxnet-scala 0x000000012cda08b3 Java_org_apache_mxnet_LibInfo_mxSymbolInferShape + 195
[bt] (7) 8 ??? 0x0000000112ac8667 0x0 + 4608263783
at org.apache.mxnet.Base$.checkCall(Base.scala:111)
at org.apache.mxnet.Symbol.inferShapeImpl(Symbol.scala:323)
at org.apache.mxnet.Symbol.inferShape(Symbol.scala:291)
at org.apache.mxnet.Symbol.inferShape(Symbol.scala:286)
at org.apache.mxnet.module.DataParallelExecutorGroup.org$apache$mxnet$module$DataParallelExecutorGroup$$bindIthExec(DataParallelExecutorGroup.scala:637)
at org.apache.mxnet.module.DataParallelExecutorGroup$$anonfun$bindExec$2.apply(DataParallelExecutorGroup.scala:384)
at org.apache.mxnet.module.DataParallelExecutorGroup$$anonfun$bindExec$2.apply(DataParallelExecutorGroup.scala:383)
at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
at scala.collection.immutable.Range.foreach(Range.scala:160)
at scala.collection.TraversableLike$class.map(TraversableLike.scala:234)
at scala.collection.AbstractTraversable.map(Traversable.scala:104)
at org.apache.mxnet.module.DataParallelExecutorGroup.bindExec(DataParallelExecutorGroup.scala:383)
at org.apache.mxnet.module.DataParallelExecutorGroup.<init>(DataParallelExecutorGroup.scala:323)
at org.apache.mxnet.module.DataParallelExecutorGroup$Builder.build(DataParallelExecutorGroup.scala:225)
at org.apache.mxnet.module.Module.bind(Module.scala:285)
at org.apache.mxnet.infer.Predictor$$anonfun$loadModule$1.apply$mcV$sp(Predictor.scala:258)
at org.apache.mxnet.infer.Predictor$$anonfun$loadModule$1.apply(Predictor.scala:258)
at org.apache.mxnet.infer.Predictor$$anonfun$loadModule$1.apply(Predictor.scala:258)
at org.apache.mxnet.infer.MXNetThreadPoolHandler$$anon$4.call(MXNetHandler.scala:83)
at java.util.concurrent.FutureTask.run(FutureTask.java:266)
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
at java.lang.Thread.run(Thread.java:748)