Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 33 additions & 3 deletions lib/THC/generic/THCTensorMathBlas.cu
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ THCTensor_(dot)(THCState *state, THCTensor *self, THCTensor *src)
THC_API void
THCTensor_(addmv)(THCState *state, THCTensor *r_, real beta, THCTensor *t, real alpha, THCTensor *mat, THCTensor *vec)
{
#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE)
#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF)
THAssert(THCTensor_(checkGPU)(state, 4, r_, t, mat, vec));
if( (mat->nDimension != 2) || (vec->nDimension != 1) )
THError("matrix and vector expected");
Expand All @@ -57,6 +57,7 @@ THCTensor_(addmv)(THCState *state, THCTensor *r_, real beta, THCTensor *t, real
if(t->size[0] != mat->size[0])
THError("size mismatch");

#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE)
if(r_ != t)
{
THCTensor_(resizeAs)(state, r_, t);
Expand Down Expand Up @@ -110,6 +111,21 @@ THCTensor_(addmv)(THCState *state, THCTensor *r_, real beta, THCTensor *t, real
THCTensor_(free)(state, cmat);
}

#elif defined(THC_REAL_IS_HALF)
// cuBLAS currently has no Hgemv/SgemvEx, so view vec and t as single-column matrices and use addmm instead
THCTensor *vecAsMatrix = THCTensor_(newWithTensor)(state, vec);
THCTensor_(resize2d)(state, vecAsMatrix, vecAsMatrix->size[0], 1);

THCTensor *tAsMatrix = THCTensor_(newWithTensor)(state, t);
THCTensor_(resize2d)(state, tAsMatrix, tAsMatrix->size[0], 1);

THCTensor_(addmm)(state, r_, beta, tAsMatrix, alpha, mat, vecAsMatrix);

// r_ holds the answer as a matrix; resize it so a vector is returned
THCTensor_(resize1d)(state, r_, r_->size[0]);
THCTensor_(free)(state, vecAsMatrix);
THCTensor_(free)(state, tAsMatrix);
#endif
#else
THError("unimplemented data type");
#endif
Expand All @@ -118,7 +134,7 @@ THCTensor_(addmv)(THCState *state, THCTensor *r_, real beta, THCTensor *t, real
THC_API void
THCTensor_(addr)(THCState *state, THCTensor *r_, real beta, THCTensor *t, real alpha, THCTensor *vec1, THCTensor *vec2)
{
#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE)
#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF)
THAssert(THCTensor_(checkGPU)(state, 4, r_, t, vec1, vec2));
if ( (vec1->nDimension != 1) || (vec2->nDimension != 1) ) {
THError("vector and vector expected");
Expand All @@ -132,12 +148,13 @@ THCTensor_(addr)(THCState *state, THCTensor *r_, real beta, THCTensor *t, real a
THError("size mismatch");
}

#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE)
if (r_ != t) {
THCTensor_(resizeAs)(state, r_, t);
THCTensor_(copy)(state, r_, t);
}

if(beta != 1) {
if(THCNumerics<real>::ne(beta, ScalarConvert<int, real>::to(1))) {
THCTensor_(mul)(state, r_, r_, beta);
}

Expand Down Expand Up @@ -187,6 +204,19 @@ THCTensor_(addr)(THCState *state, THCTensor *r_, real beta, THCTensor *t, real a

THCTensor_(freeCopyTo)(state, cr, r_);
}
#elif defined(THC_REAL_IS_HALF)
// cuBLAS currently has no Hger/SgerEx, so compute the outer product with addmm: (vec1 as column matrix) x (vec2 as row matrix)
THCTensor *vec2T = THCTensor_(newWithTensor)(state, vec2);
THCTensor_(resize2d)(state, vec2T, vec2T->size[0], 1);
THCTensor_(transpose)(state, vec2T, NULL, 0, 1);

THCTensor *vec1M = THCTensor_(newWithTensor)(state, vec1);
THCTensor_(resize2d)(state, vec1M, vec1M->size[0], 1);

THCTensor_(addmm)(state, r_, beta, t, alpha, vec1M, vec2T);
THCTensor_(free)(state, vec2T);
THCTensor_(free)(state, vec1M);
#endif
#else
THError("unimplemented data type");
#endif
Expand Down