-
Notifications
You must be signed in to change notification settings - Fork 6.8k
[MXNET-399] Elemwise_mul between dense and csr on CPU & GPU #10894
Conversation
98c3bd3
to
178173d
Compare
98b0d93
to
5e91b94
Compare
CHECK_EQ(req, kWriteTo) << "elemwise(dns, csr) = csr only supports kWriteTo"; | ||
CHECK(req != kNullOp); | ||
const bool supported_op = std::is_same<OP, mshadow_op::mul>::value || | ||
std::is_same<OP, mshadow_op::div>::value; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove div
CHECK(req != kNullOp); | ||
const bool supported_op = std::is_same<OP, mshadow_op::mul>::value || | ||
std::is_same<OP, mshadow_op::div>::value; | ||
CHECK(supported_op == true) << "elemwise(dns, csr) = csr only supports mul/div"; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove div
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
CHECK_EQ(dns.storage_type(), kDefaultStorage); | ||
CHECK_EQ(csr.storage_type(), kCSRStorage); | ||
CHECK_EQ(req, kWriteTo) << "elemwise(dns, csr) = csr only supports kWriteTo"; | ||
CHECK(req != kNullOp); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if (req == kNullOp) return
9a6e9cd
to
72ed037
Compare
CHECK_EQ(req, kWriteTo) << "elemwise(dns, csr) = csr only supports kWriteTo"; | ||
if (req == kNullOp) return; | ||
const bool supported_op = std::is_same<OP, mshadow_op::mul>::value; | ||
CHECK(supported_op == true) << "elemwise(dns, csr) = csr only supports mul/div"; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove div?
const std::vector<NDArray> &outputs) { | ||
const bool supported_ops = std::is_same<mshadow_op::right, LOP>::value && | ||
std::is_same<mshadow_op::left, ROP>::value; | ||
CHECK(supported_ops); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
err msg is missing
72ed037
to
406210d
Compare
…0894) * support elemwise_mul between dns and csr * address reviews and support for backward when ograd is dns
…0894) * support elemwise_mul between dns and csr * address reviews and support for backward when ograd is dns
Description
As title
Checklist
Essentials
Changes
Comments