This repository has been archived by the owner on Nov 17, 2023. It is now read-only.
MXNetError: unknown type for MKLDNN :2 when training Mask RCNN with mxnet-cu101==1.7.0 #19631
Closed
Description
- The GluonCV Mask RCNN training script fails, both with and without Horovod, with
MXNetError: unknown type for MKLDNN :2
when using mxnet-cu101==1.7.0
Error Message
Traceback (most recent call last):
File "/shared/mx_170_mkl_env/lib/python3.8/multiprocessing/pool.py", line 125, in worker
result = (True, func(*args, **kwds))
File "/shared/mx_170_mkl_env/lib/python3.8/site-packages/mxnet/gluon/data/dataloader.py", line 429, in _worker_fn
batch = batchify_fn([_worker_dataset[i] for i in samples])
File "/shared/mx_170_mkl_env/lib/python3.8/site-packages/mxnet/gluon/data/dataloader.py", line 429, in <listcomp>
batch = batchify_fn([_worker_dataset[i] for i in samples])
File "/shared/mx_170_mkl_env/lib/python3.8/site-packages/mxnet/gluon/data/dataset.py", line 219, in __getitem__
return self._fn(*item)
File "/shared/mx_170_mkl_env/lib/python3.8/site-packages/gluoncv-0.8.0-py3.8-linux-x86_64.egg/gluoncv/data/transforms/presets/rcnn.py", line 407, in __call__
cls_target, box_target, box_mask = self._target_generator(
File "/shared/mx_170_mkl_env/lib/python3.8/site-packages/mxnet/gluon/block.py", line 747, in __call__
out = self.forward(*args)
File "/shared/mx_170_mkl_env/lib/python3.8/site-packages/gluoncv-0.8.0-py3.8-linux-x86_64.egg/gluoncv/model_zoo/rcnn/rpn/rpn_target.py", line 157, in forward
ious = mx.nd.contrib.box_iou(anchor, bbox, format='corner').asnumpy()
File "/shared/mx_170_mkl_env/lib/python3.8/site-packages/mxnet/ndarray/ndarray.py", line 2563, in asnumpy
check_call(_LIB.MXNDArraySyncCopyToCPU(
File "/shared/mx_170_mkl_env/lib/python3.8/site-packages/mxnet/base.py", line 246, in check_call
raise get_last_ffi_error()
mxnet.base.MXNetError: Traceback (most recent call last):
File "src/ndarray/./../operator/tensor/.././../common/../operator/nn/mkldnn/mkldnn_base-inl.h", line 246
MXNetError: unknown type for MKLDNN :2
"""
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/shared/gluoncv_master/scripts/instance/mask_rcnn/train_mask_rcnn.py", line 737, in <module>
train(net, train_data, val_data, eval_metric, batch_size, ctx, logger, args)
File "/shared/gluoncv_master/scripts/instance/mask_rcnn/train_mask_rcnn.py", line 559, in train
next_data_batch = next(train_data_iter)
File "/shared/mx_170_mkl_env/lib/python3.8/site-packages/mxnet/gluon/data/dataloader.py", line 484, in __next__
batch = pickle.loads(ret.get(self._timeout))
File "/shared/mx_170_mkl_env/lib/python3.8/multiprocessing/pool.py", line 771, in get
raise self._value
mxnet.base.MXNetError: Traceback (most recent call last):
File "src/ndarray/./../operator/tensor/.././../common/../operator/nn/mkldnn/mkldnn_base-inl.h", line 246
MXNetError: unknown type for MKLDNN :2
GluonCV: v0.9.0
Horovod: v0.21.0
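For readers triaging the message, the trailing number in "unknown type for MKLDNN :2" appears to be the NDArray's mshadow type flag, which would make the offending dtype float16. The sketch below decodes it; the flag table is copied from the mshadow TypeFlag enum as I understand it for MXNet 1.x, and the helper name decode_mkldnn_dtype_error is hypothetical, not part of MXNet. That the float16 input reaches the CPU box_iou call because of the --amp flag is an assumption, not something the traceback confirms.

```python
import re

# mshadow type flags as used by MXNet 1.x NDArrays (assumed from the
# mshadow TypeFlag enum; verify against your MXNet build if in doubt).
MSHADOW_TYPE_FLAGS = {
    0: "float32",
    1: "float64",
    2: "float16",
    3: "uint8",
    4: "int32",
    5: "int8",
    6: "int64",
}


def decode_mkldnn_dtype_error(message):
    """Return the dtype name named by an 'unknown type for MKLDNN :N' message.

    Hypothetical helper for log triage. Returns None when the message does
    not contain the MKLDNN type error, and "unrecognized" when the flag is
    not in the table above.
    """
    match = re.search(r"unknown type for MKLDNN :(\d+)", message)
    if match is None:
        return None
    return MSHADOW_TYPE_FLAGS.get(int(match.group(1)), "unrecognized")


if __name__ == "__main__":
    print(decode_mkldnn_dtype_error("MXNetError: unknown type for MKLDNN :2"))
```

If the decoded dtype is float16, one thing worth checking is whether the data-loader workers are handed half-precision arrays on the CPU, since the failing call in the traceback runs inside a DataLoader worker process.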
To Reproduce
Without Horovod
python gluon-cv/scripts/instance/mask_rcnn/train_mask_rcnn.py --gpus 0,1,2,3,4,5,6,7 --num-workers 4 --amp --lr-decay-epoch 8,10 --epochs 6 --log-interval 10 --val-interval 12 --batch-size 8 --use-fpn --lr 0.01 --lr-warmup-factor 0.001 --lr-warmup 1600 --static-alloc --clip-gradient 1.5 --use-ext --seed 987
Full log: https://gist.github.com/karan6181/efa4ad8f61c3e21cbee9c55fea98b2f0
With Horovod
horovodrun -np 8 -H localhost:8 python gluon-cv/scripts/instance/mask_rcnn/train_mask_rcnn.py --horovod --num-workers 4 --amp --lr-decay-epoch 8,10 --epochs 6 --log-interval 10 --val-interval 12 --batch-size 8 --use-fpn --lr 0.01 --lr-warmup-factor 0.001 --lr-warmup 1600 --static-alloc --clip-gradient 1.5 --use-ext --seed 987
Environment
We recommend using our script for collecting the diagnostic information with the following command
curl --retry 10 -s https://raw.githubusercontent.com/apache/incubator-mxnet/master/tools/diagnose.py | python3
----------Python Info----------
Version : 3.8.5
Compiler : GCC 7.3.0
Build : ('default', 'Sep 4 2020 07:30:14')
Arch : ('64bit', 'ELF')
------------Pip Info-----------
Version : 20.2.4
Directory : /shared/mx_170_mkl_env/lib/python3.8/site-packages/pip
----------MXNet Info-----------
Version : 1.7.0
Directory : /shared/mx_170_mkl_env/lib/python3.8/site-packages/mxnet
Commit Hash : 64f737cdd59fe88d2c5b479f25d011c5156b6a8a
Library : ['/shared/mx_170_mkl_env/lib/python3.8/site-packages/mxnet/libmxnet.so']
Build features:
✔ CUDA
✔ CUDNN
✔ NCCL
✔ CUDA_RTC
✖ TENSORRT
✔ CPU_SSE
✔ CPU_SSE2
✔ CPU_SSE3
✔ CPU_SSE4_1
✔ CPU_SSE4_2
✖ CPU_SSE4A
✔ CPU_AVX
✖ CPU_AVX2
✔ OPENMP
✖ SSE
✔ F16C
✖ JEMALLOC
✔ BLAS_OPEN
✖ BLAS_ATLAS
✖ BLAS_MKL
✖ BLAS_APPLE
✔ LAPACK
✔ MKLDNN
✔ OPENCV
✖ CAFFE
✖ PROFILER
✔ DIST_KVSTORE
✖ CXX14
✖ INT64_TENSOR_SIZE
✔ SIGNAL_HANDLER
✖ DEBUG
✖ TVM_OP
----------System Info----------
Platform : Linux-4.15.0-1060-aws-x86_64-with-glibc2.10
system : Linux
node : ip-192-168-70-159
release : 4.15.0-1060-aws
version : #62-Ubuntu SMP Tue Feb 11 21:23:22 UTC 2020
----------Hardware Info----------
machine : x86_64
processor : x86_64
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
CPU(s): 96
On-line CPU(s) list: 0-95
Thread(s) per core: 2
Core(s) per socket: 24
Socket(s): 2
NUMA node(s): 2
Vendor ID: GenuineIntel
CPU family: 6
Model: 85
Model name: Intel(R) Xeon(R) Platinum 8175M CPU @ 2.50GHz
Stepping: 4
CPU MHz: 1200.134
BogoMIPS: 4999.99
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 32K
L1i cache: 32K
L2 cache: 1024K
L3 cache: 33792K
NUMA node0 CPU(s): 0-23,48-71
NUMA node1 CPU(s): 24-47,72-95
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single pti fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves ida arat pku ospke
----------Network Test----------
Setting timeout: 10
Timing for MXNet: https://github.com/apache/incubator-mxnet, DNS: 0.0025 sec, LOAD: 0.4890 sec.
Timing for Gluon Tutorial(en): http://gluon.mxnet.io, DNS: 0.0145 sec, LOAD: 0.0717 sec.
Error open Gluon Tutorial(cn): https://zh.gluon.ai, <urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1123)>, DNS finished in 0.20147395133972168 sec.
Timing for FashionMNIST: https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/dataset/fashion-mnist/train-labels-idx1-ubyte.gz, DNS: 0.0093 sec, LOAD: 0.4630 sec.
Timing for PYPI: https://pypi.python.org/pypi/pip, DNS: 0.0033 sec, LOAD: 0.0823 sec.
Error open Conda: https://repo.continuum.io/pkgs/free/, HTTP Error 403: Forbidden, DNS finished in 0.0021767616271972656 sec.
----------Environment----------
KMP_DUPLICATE_LIB_OK="True"