diff --git a/src/operator/tensor/dot-inl.h b/src/operator/tensor/dot-inl.h index 2c9a483567f8..ffdb706e5e3c 100644 --- a/src/operator/tensor/dot-inl.h +++ b/src/operator/tensor/dot-inl.h @@ -246,6 +246,10 @@ inline bool DotForwardInferStorageType(const nnvm::NodeAttrs& attrs, if (target_stype == kRowSparseStorage) { dispatched = storage_type_assign(&out_stype, kRowSparseStorage, dispatch_mode, DispatchMode::kFComputeEx); + // csr.T, rsp/dns -> dns + } else if (target_stype == kDefaultStorage) { + dispatched = storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode, + DispatchMode::kFComputeEx); } } if (!dispatched && lhs_stype == kCSRStorage && rhs_rsp_or_dns &&