24 changes: 12 additions & 12 deletions 3rdparty/mshadow/mshadow/dot_engine-inl.h
@@ -299,17 +299,17 @@ struct BLASEngine<cpu, float> {
}
inline static void gemm(Stream<cpu> *stream,
bool transa, bool transb,
-int m, int n, int k, float alpha,
-const float *A, int lda, const float *B, int ldb,
-float beta, float *C, int ldc) {
+index_t m, index_t n, index_t k, float alpha,
+const float *A, index_t lda, const float *B, index_t ldb,
+float beta, float *C, index_t ldc) {
cblas_sgemm(CblasColMajor, GetT(transa), GetT(transb),
m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
}
inline static void batched_gemm(Stream<cpu> *stream,
bool transa, bool transb,
-int m, int n, int k, float alpha,
-const float *A, int lda, const float *B, int ldb,
-float beta, float *C, int ldc, int batch_count,
+index_t m, index_t n, index_t k, float alpha,
+const float *A, index_t lda, const float *B, index_t ldb,
+float beta, float *C, index_t ldc, index_t batch_count,
float **workspace) {
#if (MSHADOW_USE_MKL && INTEL_MKL_VERSION >= 20160000)
// since same m/n/k is used for all single gemms, so we put all gemms into one group
@@ -408,17 +408,17 @@ struct BLASEngine<cpu, double> {
}
inline static void gemm(Stream<cpu> *stream,
bool transa, bool transb,
-int m, int n, int k, double alpha,
-const double *A, int lda, const double *B, int ldb,
-double beta, double *C, int ldc) {
+index_t m, index_t n, index_t k, double alpha,
+const double *A, index_t lda, const double *B, index_t ldb,
+double beta, double *C, index_t ldc) {
cblas_dgemm(CblasColMajor, GetT(transa), GetT(transb),
m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
}
inline static void batched_gemm(Stream<cpu> *stream,
bool transa, bool transb,
-int m, int n, int k, double alpha,
-const double *A, int lda, const double *B, int ldb,
-double beta, double *C, int ldc, int batch_count,
+index_t m, index_t n, index_t k, double alpha,
+const double *A, index_t lda, const double *B, index_t ldb,
+double beta, double *C, index_t ldc, index_t batch_count,
double **workspace) {
#if (MSHADOW_USE_MKL && INTEL_MKL_VERSION >= 20160000)
// since same m/n/k is used for all single gemms, so we put all gemms into one group
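For context on why the gemm and batched_gemm dimension and leading-dimension arguments are widened from int to mshadow's index_t: the large-tensor dot exercised by the new nightly test below has a reduction dimension of 2**31, which no longer fits in a signed 32-bit int. A minimal plain-Python sketch of that arithmetic (the names here are illustrative, not part of the patch):

# Minimal sketch (plain Python): the (1, 2**31) x (2**31, 1) dot in the new
# nightly test implies gemm arguments m = 1, n = 1, k = 2**31.
INT32_MAX = 2**31 - 1          # largest value a signed 32-bit int can hold
INT_OVERFLOW = 2**31           # reduction dimension k of that dot

m, n, k = 1, 1, INT_OVERFLOW   # C(m x n) = A(m x k) * B(k x n)

def fits_in_int32(value):
    return -2**31 <= value <= INT32_MAX

assert fits_in_int32(m) and fits_in_int32(n)
assert not fits_in_int32(k)    # k (and the matching leading dimension) overflows,
                               # hence index_t, which is 64-bit when MXNet is built
                               # with large-tensor (ILP64) support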
24 changes: 12 additions & 12 deletions src/operator/numpy/np_tensordot_op-inl.h
@@ -60,10 +60,10 @@ inline void ShiftAxes(Tuple<int>* axes_summed, const int ndim) {
/**
* Gets matrix dimensions of a and b after transpose and reshape.
*/
-inline void GetMatrixDimensions(int* ad1,
-int* ad2,
-int* bd1,
-int* bd2,
+inline void GetMatrixDimensions(index_t* ad1,
+index_t* ad2,
+index_t* bd1,
+index_t* bd2,
const mxnet::Tuple<int>& a_axes_remained,
const mxnet::Tuple<int>& a_axes_summed,
const mxnet::Tuple<int>& b_axes_remained,
@@ -157,10 +157,10 @@ void MatrixDot(const OpContext& ctx,
const TBlob& b,
const TBlob& out,
const OpReqType req,
-const int ad1,
-const int ad2,
-const int bd1,
-const int bd2,
+const index_t ad1,
+const index_t ad2,
+const index_t bd1,
+const index_t bd2,
const bool aT = false,
const bool bT = false) {
using namespace mshadow;
@@ -266,7 +266,7 @@ void TensordotImpl(const Tuple<int>& a_axes_summed,
GetReorderedAxes(a_axes_summed, &a_axes_remained, &a_axes, b_axes_summed, &b_axes_remained,
&b_axes, a_shape, b_shape);

-int ad1 = 1, ad2 = 1, bd1 = 1, bd2 = 1;
+index_t ad1 = 1, ad2 = 1, bd1 = 1, bd2 = 1;
GetMatrixDimensions(&ad1, &ad2, &bd1, &bd2, a_axes_remained, a_axes_summed,
b_axes_remained, b_axes_summed, a_shape, b_shape);

@@ -435,7 +435,7 @@ void TensordotBackwardImpl(const Tuple<int>& a_axes_summed,
GetReorderedAxes(a_axes_summed, &a_axes_remained, &a_axes, b_axes_summed, &b_axes_remained,
&b_axes, a_shape, b_shape);

-int ad1 = 1, ad2 = 1, bd1 = 1, bd2 = 1;
+index_t ad1 = 1, ad2 = 1, bd1 = 1, bd2 = 1;
GetMatrixDimensions(&ad1, &ad2, &bd1, &bd2, a_axes_remained, a_axes_summed,
b_axes_remained, b_axes_summed, a_shape, b_shape);

@@ -653,7 +653,7 @@ void TensordotIntAxesImpl(const int axes,
GetReorderedAxes(a_axes_summed, &a_axes_remained, &a_axes, b_axes_summed, &b_axes_remained,
&b_axes, a_shape, b_shape);

-int ad1 = 1, ad2 = 1, bd1 = 1, bd2 = 1;
+index_t ad1 = 1, ad2 = 1, bd1 = 1, bd2 = 1;
GetMatrixDimensions(&ad1, &ad2, &bd1, &bd2, a_axes_remained, a_axes_summed,
b_axes_remained, b_axes_summed, a_shape, b_shape);
MatrixDot<xpu>(ctx, a, b, out, req, ad1, ad2, bd1, bd2);
@@ -746,7 +746,7 @@ void TensordotIntAxesBackwardImpl(const int axes,
GetReorderedAxes(a_axes_summed, &a_axes_remained, &a_axes, b_axes_summed, &b_axes_remained,
&b_axes, a_shape, b_shape);

-int ad1 = 1, ad2 = 1, bd1 = 1, bd2 = 1;
+index_t ad1 = 1, ad2 = 1, bd1 = 1, bd2 = 1;
GetMatrixDimensions(&ad1, &ad2, &bd1, &bd2, a_axes_remained, a_axes_summed,
b_axes_remained, b_axes_summed, a_shape, b_shape);

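To make the ad1/ad2/bd1/bd2 parameters above concrete, here is a NumPy-only sketch of the reduction tensordot performs: the remaining axes of a collapse into ad1 and its summed axes into ad2, the summed axes of b collapse into bd1 and its remaining axes into bd2, after which the whole contraction is a single (ad1 x ad2) by (bd1 x bd2) matrix product. The shapes and axes below are invented for illustration; only the factorization mirrors what GetMatrixDimensions and MatrixDot do.

import numpy as onp   # plain NumPy, used only for this standalone sketch

# Invented example: contract axes (1, 2) of a with axes (0, 1) of b,
# i.e. tensordot(a, b, axes=([1, 2], [0, 1])).
a = onp.random.rand(4, 3, 5)
b = onp.random.rand(3, 5, 6)
a_axes_summed, b_axes_summed = [1, 2], [0, 1]
a_axes_remained = [ax for ax in range(a.ndim) if ax not in a_axes_summed]
b_axes_remained = [ax for ax in range(b.ndim) if ax not in b_axes_summed]

# The four matrix dimensions the helper fills in (now as index_t):
ad1 = int(onp.prod([a.shape[ax] for ax in a_axes_remained]))  # 4
ad2 = int(onp.prod([a.shape[ax] for ax in a_axes_summed]))    # 3 * 5 = 15
bd1 = int(onp.prod([b.shape[ax] for ax in b_axes_summed]))    # 3 * 5 = 15
bd2 = int(onp.prod([b.shape[ax] for ax in b_axes_remained]))  # 6

# Transpose + reshape, then one 2-D dot -- the same plan MatrixDot
# executes through BLASEngine::gemm.
a_mat = a.transpose(a_axes_remained + a_axes_summed).reshape(ad1, ad2)
b_mat = b.transpose(b_axes_summed + b_axes_remained).reshape(bd1, bd2)
out = a_mat.dot(b_mat)
ref = onp.tensordot(a, b, axes=(a_axes_summed, b_axes_summed))
assert onp.allclose(out, ref.reshape(ad1, bd2))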
13 changes: 13 additions & 0 deletions tests/nightly/test_np_large_array.py
@@ -36,6 +36,7 @@
LARGE_X = 100000000
SMALL_X = 100
SMALL_Y = 50
+INT_OVERFLOW = 2**31


@use_np
@@ -76,3 +77,15 @@ def test_softmax():
true_output = np.full((SMALL_Y, LARGE_X), (1 / input_data.shape[axis]))
output = npx.softmax(input_data, axis=axis)
assert_almost_equal(output.asnumpy(), true_output, rtol=1e-5, atol=1e-5)

+#@pytest.mark.skip(reason="CI hasn't switched to ILP64 OpenBLAS yet")
+@use_np
+def test_dot():
+    A = np.ones((1, INT_OVERFLOW), dtype='float32')
+    B = np.ones((INT_OVERFLOW, 1), dtype='float32')
+    A.attach_grad()
+    with mx.autograd.record():
+        C = np.dot(A, B)
+    assert_almost_equal(C.asnumpy(), [INT_OVERFLOW], rtol=1e-5, atol=1e-5)
+    C.backward()
+    assert A.grad.shape == (1, INT_OVERFLOW)
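When iterating on this locally, a down-scaled variant of the same test can be a useful sanity check before committing several gigabytes of memory to the 2**31-element arrays. The version below is hypothetical (SMALL_K is not defined in the test file, and the same imports and helpers as above are assumed); it also checks the gradient values, which for all-ones inputs should themselves be all ones.

# Hypothetical small-shape companion to test_dot, for quick local runs only.
SMALL_K = 2**10

@use_np
def test_dot_small():
    A = np.ones((1, SMALL_K), dtype='float32')
    B = np.ones((SMALL_K, 1), dtype='float32')
    A.attach_grad()
    with mx.autograd.record():
        C = np.dot(A, B)
    assert_almost_equal(C.asnumpy(), [SMALL_K], rtol=1e-5, atol=1e-5)
    C.backward()
    assert A.grad.shape == (1, SMALL_K)
    # d(A.B)/dA = B^T, which is all ones for these inputs.
    assert_almost_equal(A.grad.asnumpy(), np.ones((1, SMALL_K)), rtol=1e-5, atol=1e-5)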