Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[Numpy] Add NumPy support for np.linalg.det and np.linalg.slogdet #16800

Merged
merged 16 commits into from
Nov 24, 2019
Prev Previous commit
Next Next commit
beautify
  • Loading branch information
Ubuntu committed Nov 16, 2019
commit d23a1e339f11d1601dfe83b749219bfd5959dc79
3 changes: 2 additions & 1 deletion src/operator/tensor/la_op-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -502,8 +502,9 @@ struct det {
static void op(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 1, DType>& det,
const Tensor<xpu, 3, DType>& LU, const Tensor<xpu, 2, int>& pivot,
const OpContext& ctx, const nnvm::NodeAttrs& attrs) {
if (A.shape_.Size() == 0U)
if (A.shape_.Size() == 0U) {
return;
}
Stream<xpu> *s = ctx.get_stream<xpu>();
Tensor<xpu, 1, DType> sign = ctx.requested[0]
.get_space_typed<xpu, 1, DType>(det.shape_, s);
Expand Down
5 changes: 1 addition & 4 deletions src/operator/tensor/la_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -429,14 +429,11 @@ inline bool DetShape(const nnvm::NodeAttrs& attrs,
CHECK_EQ(in[ndim-2], in[ndim-1]) << "Input A's last two dimension must be equal";
mxnet::TShape out;
if (ndim == 2) {
if (Imperative::Get()->is_np_shape()) {
if (Imperative::Get()->is_np_shape() || in.Size() == 0U) {
out = mxnet::TShape(0, 1);
} else {
out = mxnet::TShape(1, 1);
}
if (in.Size() == 0U) {
out = mxnet::TShape(0, -1);
}
} else {
out = mxnet::TShape(in.begin(), in.end() - 2);
}
Expand Down