This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

customOp Exception: unknown storage type: -1 #16365

Open
jiashu-zhu opened this issue Oct 3, 2019 · 23 comments

Comments

@jiashu-zhu

Description

I get the exception "Exception: unknown storage type: -1" when I use my focal loss.

My focal loss:

The shape of out_data[0] is (batch_size, 2, anchor_num).
The shape of in_data[1] is (batch_size, anchor_num).

import mxnet as mx
import numpy as np


class FocalLossOperator(mx.operator.CustomOp):
    def __init__(self, gamma, alpha):
        super(FocalLossOperator, self).__init__()
        self.gamma = gamma
        self.alpha = alpha

    def forward(self, is_train, req, in_data, out_data, aux):
        # Numerically stable softmax over the class axis (axis=1);
        # in_data[0] has shape (batch_size, 2, anchor_num)
        y = mx.nd.exp(in_data[0] - mx.nd.max_axis(in_data[0], axis=1).reshape((in_data[0].shape[0], 1, -1)))
        y /= mx.nd.sum(y, axis=1).reshape((in_data[0].shape[0], 1, -1))

        self.assign(out_data[0], req[0], y)

    def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
        # Flatten probabilities to (batch_size * anchor_num, 2) and labels to (batch_size * anchor_num,)
        y_numpy = out_data[0].asnumpy().transpose((0, 2, 1))
        label_numpy = in_data[1].asnumpy()
        y_numpy = y_numpy.reshape((-1, 2))
        label_numpy = label_numpy.reshape((-1))
        # Anchors labelled -1 are ignored: use class 0 as a placeholder, zero their gradient below
        indices = np.where(label_numpy == -1)[0]
        label_numpy[indices] = 0
        label_idx = label_numpy.astype(int)  # np.int is removed in recent NumPy
        # Probability of the true class, p_t, for every anchor
        self.pro_truth = mx.nd.array(y_numpy[np.arange(y_numpy.shape[0]), label_idx])

        # Gradient for the non-target classes (i != j)
        pro_truth = (self.pro_truth + 1e-14).reshape((self.pro_truth.shape[0], 1))
        grad = self.alpha * mx.nd.power(1 - pro_truth, self.gamma - 1) * \
               (self.gamma * (-1 * pro_truth * mx.nd.array(y_numpy)) * mx.nd.log(pro_truth) + mx.nd.array(y_numpy) * (1 - pro_truth))

        # Gradient for the target class (i == j) overwrites those entries
        pro_truth = self.pro_truth + 1e-14

        grad_numpy = grad.asnumpy()
        grad_numpy[np.arange(y_numpy.shape[0]), label_idx] = (
                    self.alpha * mx.nd.power(1 - pro_truth, self.gamma) * (
                    self.gamma * pro_truth * mx.nd.log(pro_truth) + pro_truth - 1)).asnumpy()
        # Average over all anchors, then zero the ignored ones
        grad_numpy /= label_numpy.shape[0]
        grad_numpy[indices, :] = 0
        # Restore the (batch_size, 2, anchor_num) layout of in_data[0]
        grad = mx.nd.array(grad_numpy)
        grad = grad.reshape(out_data[0].shape[0], -1, out_data[0].shape[1]).transpose((0, 2, 1))

        self.assign(in_grad[0], req[0], grad)
        # Labels receive no gradient
        self.assign(in_grad[1], req[1], 0)

@mx.operator.register('FocalLoss')
class FocalLossProp(mx.operator.CustomOpProp):
    def __init__(self, gamma, alpha):
        # need_top_grad=False: loss layer, backward() does not use out_grad
        super(FocalLossProp, self).__init__(need_top_grad=False)

        # Custom operator kwargs arrive as strings
        self.gamma = float(gamma)
        self.alpha = float(alpha)

    def list_arguments(self):
        return ['data', 'labels']

    def list_outputs(self):
        return ['output']

    def infer_shape(self, in_shape):
        # The output (softmax probabilities) has the same shape as the data
        data_shape = in_shape[0]
        labels_shape = in_shape[1]
        out_shape = data_shape
        return [data_shape, labels_shape], [out_shape], []

    def create_operator(self, ctx, shapes, dtypes):
        return FocalLossOperator(self.gamma, self.alpha)
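The two gradient branches in backward (commented i != j and i == j) appear to be the derivative of the focal loss with respect to the pre-softmax input in_data[0]. Assuming the standard definition FL = -alpha (1 - p_t)^gamma log p_t, with p the softmax over the class axis and t the label of an anchor, the chain rule gives the expressions below; the operator then divides by the number of anchors and zeroes anchors labelled -1.

```latex
\begin{aligned}
p_j &= \frac{e^{x_j}}{\sum_k e^{x_k}}, \qquad
\frac{\partial p_t}{\partial x_j} = p_t\,(\delta_{tj} - p_j), \\
\mathrm{FL} &= -\alpha\,(1 - p_t)^{\gamma}\log p_t, \\
\frac{\partial \mathrm{FL}}{\partial p_t} &= \alpha\gamma\,(1 - p_t)^{\gamma - 1}\log p_t - \frac{\alpha\,(1 - p_t)^{\gamma}}{p_t}, \\
\frac{\partial \mathrm{FL}}{\partial x_t} &= \alpha\,(1 - p_t)^{\gamma}\left(\gamma\,p_t\log p_t + p_t - 1\right), \\
\left.\frac{\partial \mathrm{FL}}{\partial x_j}\right|_{j \ne t} &= \alpha\,(1 - p_t)^{\gamma - 1}\left(-\gamma\,p_t\,p_j\log p_t + p_j\,(1 - p_t)\right).
\end{aligned}
```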

Error Message:

Error in CustomOp.backward: Traceback (most recent call last):
  File "/home/anaconda2/lib/python2.7/site-packages/mxnet/operator.py", line 1020, in backward_entry
    stype=stype))
  File "/home/anaconda2/lib/python2.7/site-packages/mxnet/ndarray/sparse.py", line 1187, in _ndarray_cls
    raise Exception("unknown storage type: %s"%stype)
Exception: unknown storage type: -1

terminate called after throwing an instance of 'dmlc::Error'
what(): [12:17:03] src/operator/custom/custom.cc:418: Check failed: reinterpret_cast(params.info->callbacks[kCustomOpBackward])( ptrs.size(), const_cast<void**>(ptrs.data()), const_cast<int*>(tags.data()), reinterpret_cast<const int*>(req.data()), static_cast(ctx.is_train), params.info->contexts[kCustomOpBackward])

Stack trace returned 8 entries:
[bt] (0) /home/anaconda2/lib/python2.7/site-packages/mxnet/libmxnet.so(+0x40b29a) [0x7feccd0c829a]
[bt] (1) /home/anaconda2/lib/python2.7/site-packages/mxnet/libmxnet.so(+0x40b8b1) [0x7feccd0c88b1]
[bt] (2) /home/anaconda2/lib/python2.7/site-packages/mxnet/libmxnet.so(+0x6c6239) [0x7feccd383239]
[bt] (3) /home/anaconda2/lib/python2.7/site-packages/mxnet/libmxnet.so(+0x6e1020) [0x7feccd39e020]
[bt] (4) /home/anaconda2/lib/python2.7/site-packages/mxnet/libmxnet.so(+0x6c7078) [0x7feccd384078]
[bt] (5) /home/anaconda2/bin/../lib/libstdc++.so.6(+0xafc5c) [0x7fed70a50c5c]
[bt] (6) /lib/x86_64-linux-gnu/libpthread.so.0(+0x76ba) [0x7fed78a076ba]
[bt] (7) /lib/x86_64-linux-gnu/libc.so.6(clone+0x6d) [0x7fed7802d41d]

@mxnet-label-bot
Contributor

Hey, this is the MXNet Label Bot.
Thank you for submitting the issue! I will try and suggest some labels so that the appropriate MXNet community members can help resolve it.
Here are my recommended label(s): Build

@jiashu-zhu jiashu-zhu changed the title Exception: unknown storage type: -1 customOp Exception: unknown storage type: -1 Oct 3, 2019
@chinakook
Contributor

I have this problem too. It may be a compatibility problem.

@chinakook
Contributor

I think it's a serious bug, as most Python custom operators encounter this error.

@jiashu-zhu
Author

I think it's serious bug as most python custom operators encounter this error.

So how can I use this custom operator? I really need focal loss for my experiment.
I can use other custom operators without this problem. @chinakook

@chinakook
Contributor

You can define the storage type the same way the parent class CustomOp does, maybe something like ['default'].

@wkcn
Member

wkcn commented Oct 6, 2019

I could not reproduce this exception with the following script:

import mxnet as mx
import numpy as np


class FocalLossOperator(mx.operator.CustomOp):
    def __init__(self, gamma, alpha):
        super(FocalLossOperator, self).__init__()
        self.gamma = gamma
        self.alpha = alpha

    def forward(self, is_train, req, in_data, out_data, aux):
        #print('forward')
        #print(in_data[0].shape)
        y = mx.nd.exp(in_data[0] - mx.nd.max_axis(in_data[0], axis=1).reshape((in_data[0].shape[0], 1, -1)))
        y /= mx.nd.sum(y, axis=1).reshape((in_data[0].shape[0],1, -1))

        self.assign(out_data[0], req[0], y)

    def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
        y_numpy = out_data[0].asnumpy().transpose((0,2,1))
        label_numpy = in_data[1].asnumpy()
        y_numpy = y_numpy.reshape((-1,2))
        label_numpy = label_numpy.reshape((-1))
        #print(len(np.where(label_numpy == -1)[0]))
        indices = np.where(label_numpy == -1)[0]
        label_numpy[indices] = 0
        self.pro_truth = mx.nd.array(y_numpy[np.arange(y_numpy.shape[0]), label_numpy.astype(int)])

        # print(len(indices))
        # i!=j
        pro_truth = (self.pro_truth + 1e-14).reshape((self.pro_truth.shape[0], 1))
        grad = self.alpha * mx.nd.power(1 - pro_truth, self.gamma - 1) * \
               (self.gamma * (-1 * pro_truth * mx.nd.array(y_numpy)) * mx.nd.log(pro_truth) + mx.nd.array(y_numpy) * (1 - pro_truth))

        # i==j
        pro_truth = self.pro_truth + 1e-14

        grad_numpy = grad.asnumpy()
        grad_numpy[np.arange(y_numpy.shape[0]), label_numpy.astype(int)] = (
                    self.alpha * mx.nd.power(1 - pro_truth, self.gamma) * (
                    self.gamma * pro_truth * mx.nd.log(pro_truth) + pro_truth - 1)).asnumpy()
        grad_numpy /= label_numpy.shape[0]
        grad_numpy[indices,:] = 0
        #grad_numpy = grad_numpy.reshape((out_data[0].shape[0],-1,out_data[0].shape[1])).transpose((0,2,1))
        grad = mx.nd.array(grad_numpy)
        grad = grad.reshape(out_data[0].shape[0],-1,out_data[0].shape[1]).transpose((0,2,1))

        self.assign(in_grad[0], req[0], grad)

@mx.operator.register('FocalLoss')
class FocalLossProp(mx.operator.CustomOpProp):
    def __init__(self, gamma, alpha):
        super(FocalLossProp, self).__init__(need_top_grad=False)

        self.gamma = float(gamma)
        self.alpha = float(alpha)

    def list_arguments(self):
        return ['data', 'labels']

    def list_outputs(self):
        return ['output']

    def infer_shape(self, in_shape):
        data_shape = in_shape[0]
        labels_shape = in_shape[1]
        out_shape = data_shape
        return [data_shape, labels_shape], [out_shape], []

    def create_operator(self, ctx, shapes, dtypes):
        return FocalLossOperator(self.gamma, self.alpha)

class FocalLossGluon(mx.gluon.nn.HybridBlock):
    def hybrid_forward(self, F, x, label):
        return F.Custom(x, label, gamma=1, alpha=1, op_type='FocalLoss')

if __name__ == '__main__':
    batch_size = 3
    num_anchor = 4
    x = mx.nd.zeros((batch_size, 2, num_anchor))
    label = mx.nd.zeros((batch_size, num_anchor))
    x.attach_grad()
    with mx.autograd.record():
        y = mx.nd.Custom(x, label, gamma=1, alpha=1, op_type='FocalLoss')
        y.backward()
    print(y)
    print(x.grad)

    block = FocalLossGluon()
    block.hybridize()
    for _ in range(2):
        with mx.autograd.record():
            y = block(x, label)
            y.backward()
        print(y)
        print(x.grad)
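For comparison with the x.grad printed by this script, here is a NumPy-only sketch (an independent reference written for this thread, not MXNet code) that recomputes the operator's analytic gradient and checks it against central finite differences of the mean focal loss. It assumes the (batch_size, num_class, anchor_num) layout and that anchors labelled -1 are ignored, as in the operator; all helper names are hypothetical.

```python
import numpy as np


def softmax_axis1(x):
    # Numerically stable softmax over the class axis, as in FocalLossOperator.forward
    e = np.exp(x - x.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)


def _true_class_prob(p, label, eps):
    # Anchors labelled -1 get placeholder class 0; callers mask them out
    lab = label.astype(np.int64)
    valid = lab != -1                                   # (B, A)
    idx = np.where(valid, lab, 0)[:, None, :]           # (B, 1, A)
    p_t = np.take_along_axis(p, idx, axis=1) + eps      # (B, 1, A)
    return p_t, idx, valid


def focal_loss(x, label, gamma, alpha, eps=1e-14):
    """Mean focal loss over all anchors; ignored anchors contribute 0."""
    p_t, _, valid = _true_class_prob(softmax_axis1(x), label, eps)
    p_t = p_t[:, 0, :]
    loss = -alpha * (1.0 - p_t) ** gamma * np.log(p_t)
    return np.where(valid, loss, 0.0).sum() / valid.size


def focal_loss_grad(x, label, gamma, alpha, eps=1e-14):
    """Analytic gradient w.r.t. x, mirroring the i != j and i == j branches."""
    p = softmax_axis1(x)
    p_t, idx, valid = _true_class_prob(p, label, eps)
    # i != j: formula evaluated for every class first
    grad = alpha * (1 - p_t) ** (gamma - 1) * (-gamma * p_t * p * np.log(p_t) + p * (1 - p_t))
    # i == j: overwrite the target-class entries
    target = alpha * (1 - p_t) ** gamma * (gamma * p_t * np.log(p_t) + p_t - 1)
    np.put_along_axis(grad, idx, target, axis=1)
    # Average over all anchors and zero the ignored ones
    return grad / valid.size * valid[:, None, :]


def numerical_grad(f, x, h=1e-6):
    # Central finite differences, perturbing one element at a time (modifies x temporarily)
    g = np.zeros_like(x)
    for i in np.ndindex(*x.shape):
        old = x[i]
        x[i] = old + h
        f_plus = f(x)
        x[i] = old - h
        f_minus = f(x)
        x[i] = old
        g[i] = (f_plus - f_minus) / (2 * h)
    return g
```

With the all-zero input of the script above (gamma=1, alpha=1, batch_size=3, num_anchor=4) every p_t is 0.5, so the formula gives about -0.0353 for the label class and +0.0353 for the other class at every anchor.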

@chinakook
Contributor

@jiashu-zhu Please paste your model here.

@jiashu-zhu
Author

I just use this FocalLoss to replace SoftmaxOutput in RetinaFace. @chinakook

@jiashu-zhu
Author

Does this FocalLoss work in your code? @wkcn

@wkcn
Member

wkcn commented Oct 6, 2019

@jiashu-zhu Yes, it works in my code.

@chinakook
Contributor

chinakook commented Oct 7, 2019

All FPN ops in this repo get this error. It may be a bug in custom op.

@jiashu-zhu
Author

Thanks a lot! Do you have any ideas for making it work? I can try them. @chinakook

@chinakook
Copy link
Contributor

Use an older MXNet version.

@jiashu-zhu
Copy link
Author

Many thanks, I will try it @chinakook

@wkcn
Copy link
Member

wkcn commented Oct 7, 2019

Could you please tell me which SoftmaxOutput you replaced with FocalLoss?
A minimal reproducible example would help.

@jiashu-zhu
Copy link
Author

I replaced the SoftmaxOutput at line 403 of rcnn/symbol/symbol_common, and I use ResNet-152 as my backbone (you can download it from the RetinaFace homepage); all other settings remain default. @wkcn

@cccorn

cccorn commented Oct 8, 2019

I got the same problem. I tried mxnet version 1.5.0, 1.4.1, 1.3.1, 1.2.1, 1.1.0, and only version 1.1.0 works for me.

@wkcn
Member

wkcn commented Oct 9, 2019

I need a minimal reproducible example to check the bug, since I am busy and have little time for it.

Does it work in the MNIST classification example? https://github.com/apache/incubator-mxnet/blob/master/example/image-classification/symbols/lenet.py

@anirudh2290
Member

Thanks for reporting this @jiashu-zhu @cccorn @chinakook . Thanks a lot @wkcn for offering to help. This issue may have been introduced with the Sparse Tensor support for custom op. Have you tried commenting out the declare_backward_dependency in CustomOpProp https://github.com/dingjiansw101/RoITransformer_DOTA/blob/master/fpn/operator_py/fpn_psroi_rotatedpooling.py#L128 to see if that fixes the issue? Sorry, I am a little pressed for time right now and won't be able to dig into the issue currently. Can you try this workaround for now?

@wkcn
Copy link
Member

wkcn commented Oct 31, 2019

I ran into the same problem in Deformable ConvNets when I set ENABLE_OHEM: false :(
https://github.com/msracver/Deformable-ConvNets/blob/master/experiments/fpn/cfgs/resnet_v1_101_coco_trainval_fpn_dcn_end2end_ohem.yaml#L82

While trying to track down the problem, I found that some NDArrays are not initialized:
https://github.com/apache/incubator-mxnet/blob/master/src/c_api/c_api.cc#L588

I commented out the declare_backward_dependency, but it did not work.
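The traceback ends in _ndarray_cls (mxnet/ndarray/sparse.py), which wraps each array handed to CustomOp.backward in a Python class chosen by its storage type. The sketch below is a hypothetical, simplified model of that dispatch, not MXNet's code; the ids are assumed from MXNet's storage-type enum (-1 undefined, 0 default, 1 row_sparse, 2 csr). It illustrates why a single uninitialized NDArray, assuming it reports the undefined id -1, aborts the whole backward call with exactly this message.

```python
# Hypothetical, simplified model of the storage-type dispatch seen in the traceback.
# Assumed ids, modelled on MXNet's storage-type enum.
STYPE_UNDEFINED = -1
STYPE_DEFAULT = 0
STYPE_ROW_SPARSE = 1
STYPE_CSR = 2

# Names stand in for the NDArray, RowSparseNDArray and CSRNDArray classes
STYPE_TO_CLASS = {
    STYPE_DEFAULT: "NDArray",
    STYPE_ROW_SPARSE: "RowSparseNDArray",
    STYPE_CSR: "CSRNDArray",
}


def ndarray_cls_sketch(stype):
    """Return the wrapper class name for a storage type id, or raise like the traceback."""
    try:
        return STYPE_TO_CLASS[stype]
    except KeyError:
        raise Exception("unknown storage type: %s" % stype) from None


def wrap_backward_args(stypes):
    """Wrap every array passed to backward; one undefined (-1) entry aborts the call."""
    return [ndarray_cls_sketch(s) for s in stypes]
```

Under this model, wrap_backward_args([0, 0, -1]) raises "unknown storage type: -1" even though the other arrays are fine, which is consistent with wkcn's observation that an uninitialized NDArray, rather than the focal-loss math, triggers the error.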

@wkcn
Copy link
Member

wkcn commented Oct 31, 2019

A temporary solution: comment out all need_top_grad=False and declare_backward_dependency.

@jiashu-zhu
Copy link
Author

Thanks for your kind help! @wkcn
Using an older version (like MXNet v1.1.0) is also a temporary solution.

7 participants