
CSR compression error. CSR support for more than 2D tensors #17256

Open

Wallart opened this issue Jan 9, 2020 · 0 comments

Wallart commented Jan 9, 2020

Hello everyone,
I am trying to use sparse tensors to save memory in my Transformer architecture, so I am applying F.sparse.cast_storage to an attention-weights tensor.

import math

from mxnet import gluon


class ScaledDotProductAttn(gluon.HybridBlock):

    def __init__(self, dim_k, *args, **kwargs):
        super(ScaledDotProductAttn, self).__init__(*args, **kwargs)
        self._dim_k = dim_k

    def hybrid_forward(self, F, *args, **kwargs):
        query, key, value, mask, sparse_pattern = args

        matmul_qk = F.linalg.gemm2(query, key, transpose_b=True)  # seq_len_q, seq_len_k
        scaled_attn_logits = matmul_qk / math.sqrt(self._dim_k)

        if mask is not None:
            scaled_attn_logits = F.broadcast_add(scaled_attn_logits, mask * -1e9)

        attn_weights = F.softmax(scaled_attn_logits)  # seq_len_q, seq_len_k
        if sparse_pattern is not None:
            attn_weights = F.sparse.cast_storage(attn_weights * sparse_pattern, 'csr')

        output = F.linalg.gemm2(attn_weights, value)  # seq_len_q, dim_v
        return output, attn_weights

As you can see, the sparse NDArray is densified on the fly to produce output (because value is not sparse). I then return a dense output and a sparse attn_weights.
The output is eventually used to compute the loss, and attn_weights is kept for plotting when needed.
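
For illustration, here is a minimal sketch of that densification outside the model (this uses mx.nd.dot, which has a sparse-aware implementation, rather than linalg.gemm2; the shapes are made up):

import mxnet as mx

# a sparse (CSR) left operand times a dense right operand: the result
# comes back with dense ('default') storage, so memory is only saved
# while attn_weights itself is held in CSR form
csr = mx.nd.sparse.zeros('csr', (4, 4))
dense = mx.nd.ones((4, 4))
out = mx.nd.dot(csr, dense)
print(out.stype)  # 'default'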

The error occurs when I update the loss metric, which calls asnumpy internally.

Traceback (most recent call last):
  File "/home/wallart/workspaces/Transformer/trainer/transformer_trainer.py", line 77, in train
    self._loss_metric.update(0, [l * self._opts.batch_size for l in losses])
  File "/opt/miniconda3/envs/intelpython3/lib/python3.6/site-packages/mxnet-1.6.0-py3.6.egg/mxnet/metric.py", line 1687, in update
    loss = ndarray.sum(pred).asscalar()
  File "/opt/miniconda3/envs/intelpython3/lib/python3.6/site-packages/mxnet-1.6.0-py3.6.egg/mxnet/ndarray/ndarray.py", line 2553, in asscalar
    return self.asnumpy()[0]
  File "/opt/miniconda3/envs/intelpython3/lib/python3.6/site-packages/mxnet-1.6.0-py3.6.egg/mxnet/ndarray/ndarray.py", line 2535, in asnumpy
    ctypes.c_size_t(data.size)))
  File "/opt/miniconda3/envs/intelpython3/lib/python3.6/site-packages/mxnet-1.6.0-py3.6.egg/mxnet/base.py", line 255, in check_call
    raise MXNetError(py_str(_LIB.MXGetLastError()))
mxnet.base.MXNetError: [10:01:05] src/operator/tensor/././cast_storage-inl.cuh:470: Check failed: dns.shape_.ndim() == 2 (4 vs. 2)

The issue occurs on both MXNet 1.5.1 and 1.6.0rc0.
Everything works if I disable the F.sparse.cast_storage call.
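
For reference, a minimal standalone repro (a sketch: the shape is made up, and I am assuming the behavior matches what happens inside the model):

import mxnet as mx

# 4D input: fails with "Check failed: dns.shape_.ndim() == 2 (4 vs. 2)"
x4d = mx.nd.random.uniform(shape=(2, 8, 16, 16))
# mx.nd.sparse.cast_storage(x4d, 'csr').asnumpy()  # raises MXNetError

# 2D input: works, since CSR storage expects exactly two dimensions
x2d = x4d.reshape((-1, 16))
csr = mx.nd.sparse.cast_storage(x2d, 'csr')
print(csr.stype)  # 'csr'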

EDIT: F.sparse.cast_storage(attn_weights * sparse_pattern, 'csr').asnumpy() produces the same error. It seems the CSR sparse API only supports 2D data (my tensor is 4D). I will switch to the row_sparse format.
Is support for higher-dimensional tensors planned? Or at least an argument to sparsify the last two axes of a tensor?
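
In the meantime, here is a sketch of the reshape-based workaround I have in mind for "sparsify the last two axes" (shapes are assumed, and it stays in plain NDArray mode because a CSR array cannot be reshaped directly):

import mxnet as mx

# hypothetical 4D attention weights: (batch, heads, seq_q, seq_k)
attn = mx.nd.random.uniform(shape=(2, 8, 16, 16))

# collapse the leading axes so the tensor is 2D, cast to CSR, then
# densify and restore the original shape before the final matmul
attn_2d = attn.reshape((-1, 16))                      # (2*8*16, 16)
attn_csr = mx.nd.sparse.cast_storage(attn_2d, 'csr')  # memory-saving form
attn_back = attn_csr.tostype('default').reshape(attn.shape)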

EDIT 2: Can someone change the label to Feature request? I can't do it myself.

@Wallart Wallart added the Bug label Jan 9, 2020
@Wallart Wallart changed the title from "Sparse compression causes errors" to "CSR compression error. CSR support for more than 2D tensors" Jan 9, 2020