Skip to content

Commit

Permalink
Merge commit '2efac3ed83a29f57f914e9044fdddd2ce7ecd6b7'
Browse files: browse the repository at this point in the history
  • Loading branch information
soumith committed Jul 21, 2017
2 parents 71ce344 + 2efac3e commit ec2def8
Showing 1 changed file with 14 additions and 27 deletions.
41 changes: 14 additions & 27 deletions torch/lib/THC/generic/THCTensorMathMagma.cu
Original file line number | Diff line number | Diff line change
Expand Up @@ -347,10 +347,10 @@ THC_API void THCTensor_(gesvd2)(THCState *state, THCTensor *ru_, THCTensor *rs_,

THC_API void THCTensor_(getri)(THCState *state, THCTensor *ra_, THCTensor *a)
{
#ifdef USE_MAGMA
THArgCheck(a->nDimension == 2, 2, "A should be 2 dimensional");
THArgCheck(a->size[0] == a->size[1], 2, "A should be square");

#ifdef USE_MAGMA
int info;
int n = a->size[0];
int lwork = n * magma_get_sgetri_nb(n);
Expand Down Expand Up @@ -391,37 +391,23 @@ THC_API void THCTensor_(getri)(THCState *state, THCTensor *ra_, THCTensor *a)
magma_free_pinned(ipiv);
THCTensor_(freeCopyTo)(state, input, ra_);
#else
THArgCheck(a->nDimension == 2, 2, "A should be 2 dimensional");
THArgCheck(a->size[0] == a->size[1], 2, "A should be square");

int n = a->size[0];

// input
THCTensor *input = THCTensor_(newColumnMajor)(state, ra_, a);
// output
THCTensor *output = THCTensor_(newColumnMajor)(state, ra_, a);
THCTensor *input = THCTensor_(newColumnMajor)(state, a, a);
THCTensor_(resizeNd)(state, ra_, 2, input->size, input->stride);

size_t matrices_size = sizeof(real*);

real **matrices1 = (real **)THAlloc(matrices_size);
const real **matrices1_const = (const real **)THAlloc(matrices_size);
real **matrices2 = (real **)THAlloc(matrices_size);
matrices1[0] = THCTensor_(data)(state, input);
matrices1_const[0] = THCTensor_(data)(state, input);
matrices2[0] = THCTensor_(data)(state, output);
real *matrices1[1] = { THCTensor_(data)(state, input) };
real *matrices2[1] = { THCTensor_(data)(state, ra_) };

// Copy pointers to device.
real **d_matrices1, **d_matrices2;
const real **d_matrices1_const;
THCudaCheck(THCudaMalloc(state, (void**)&d_matrices1, matrices_size));
THCudaCheck(THCudaMalloc(state, (void**)&d_matrices1_const, matrices_size));
THCudaCheck(THCudaMalloc(state, (void**)&d_matrices2, matrices_size));
THCudaCheck(THCudaMalloc(state, (void**)&d_matrices1, sizeof(real*)));
THCudaCheck(THCudaMalloc(state, (void**)&d_matrices2, sizeof(real*)));

THCudaCheck(cudaMemcpyAsync(d_matrices1, matrices1, matrices_size,
cudaMemcpyHostToDevice, THCState_getCurrentStream(state)));
THCudaCheck(cudaMemcpyAsync(d_matrices1_const, matrices1_const, matrices_size,
THCudaCheck(cudaMemcpyAsync(d_matrices1, matrices1, sizeof(real*),
cudaMemcpyHostToDevice, THCState_getCurrentStream(state)));
THCudaCheck(cudaMemcpyAsync(d_matrices2, matrices2, matrices_size,
THCudaCheck(cudaMemcpyAsync(d_matrices2, matrices2, sizeof(real*),
cudaMemcpyHostToDevice, THCState_getCurrentStream(state)));
int info;
int *info_gpu;
Expand All @@ -446,11 +432,13 @@ THC_API void THCTensor_(getri)(THCState *state, THCTensor *ra_, THCTensor *a)

// Inverse
#if defined(THC_REAL_IS_FLOAT)
THCudaBlas_Sgetri(state, n, d_matrices1_const, n, ipiv_gpu, d_matrices2, n, info_gpu, 1);
THCudaBlas_Sgetri(state, n, (const real**)d_matrices1, n, ipiv_gpu, d_matrices2, n, info_gpu, 1);
#else
THCudaBlas_Dgetri(state, n, d_matrices1_const, n, ipiv_gpu, d_matrices2, n, info_gpu, 1);
THCudaBlas_Dgetri(state, n, (const real**)d_matrices1, n, ipiv_gpu, d_matrices2, n, info_gpu, 1);
#endif

THCudaCheck(cudaMemcpy(&info, info_gpu, sizeof(int), cudaMemcpyDeviceToHost));

if (info > 0)
THError("CUBLAS getri : U(%d,%d) is 0, U is singular", info, info);
else if (info < 0)
Expand All @@ -460,10 +448,9 @@ THC_API void THCTensor_(getri)(THCState *state, THCTensor *ra_, THCTensor *a)
THCudaCheck(THCudaFree(state, info_gpu));

THCudaCheck(THCudaFree(state, d_matrices1));
THCudaCheck(THCudaFree(state, d_matrices1_const));
THCudaCheck(THCudaFree(state, d_matrices2));

THCTensor_(freeCopyTo)(state, output, input);
THCTensor_(free)(state, input);
#endif
}

Expand Down

0 comments on commit ec2def8

Please sign in to comment.