Skip to content

Commit 0dbf871

Browse files
lantigasoumith
authored andcommitted
Have median reduce over all dims and return just the value when dim is not provided
1 parent c691fc6 commit 0dbf871

File tree

2 files changed

+66
-0
lines changed

2 files changed

+66
-0
lines changed

generic/THCTensorMathReduce.cu

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,71 @@ THCTensor_(maxall)(THCState *state, THCTensor *self) {
330330
return val;
331331
}
332332

333+
THC_API real
334+
THCTensor_(medianall)(THCState *state, THCTensor *self) {
335+
THCAssertSameGPU(THCTensor_(checkGPU)(state, 1, self));
336+
337+
real val;
338+
ptrdiff_t nelem, k;
339+
340+
nelem = THCTensor_(nElement)(state, self);
341+
k = (nelem-1) >> 1;
342+
343+
THLongStorage *size = THLongStorage_newWithSize1(nelem);
344+
THCTensor *view = THCTensor_(newView)(state, self, size);
345+
346+
THLongStorage_free(size);
347+
348+
THCTensor *sorted = THCTensor_(new)(state);
349+
THCudaLongTensor *indices = THCudaLongTensor_new(state);
350+
351+
THCTensor_(sort)(state, sorted, indices, view, 0, 0);
352+
353+
val = THCTensor_(get1d)(state, sorted, k);
354+
355+
THCTensor_(free)(state, view);
356+
THCTensor_(free)(state, sorted);
357+
THCudaLongTensor_free(state, indices);
358+
359+
THCudaCheck(cudaGetLastError());
360+
361+
return val;
362+
}
363+
364+
THC_API void
365+
THCTensor_(median)(THCState *state,
366+
THCTensor *values,
367+
THCudaLongTensor *indices,
368+
THCTensor *self,
369+
long dimension,
370+
int keepdim) {
371+
THCAssertSameGPU(THCTensor_(checkGPU)(state, 1, self));
372+
373+
long t_size_dim, k;
374+
375+
t_size_dim = THCTensor_(size)(state, self, dimension);
376+
377+
k = (t_size_dim-1) >> 1;
378+
379+
THCTensor *sorted = THCTensor_(new)(state);
380+
THCudaLongTensor *sorted_indices = THCudaLongTensor_new(state);
381+
382+
THCTensor_(sort)(state, sorted, sorted_indices, self, dimension, 0);
383+
384+
THCTensor_(narrow)(state, values, sorted, dimension, k, 1);
385+
THCudaLongTensor_narrow(state, indices, sorted_indices, dimension, k, 1);
386+
387+
THCTensor_(free)(state, sorted);
388+
THCudaLongTensor_free(state, sorted_indices);
389+
390+
if (!keepdim) {
391+
THCTensor_(squeeze1d)(state, values, values, dimension);
392+
THCudaLongTensor_squeeze1d(state, indices, indices, dimension);
393+
}
394+
395+
THCudaCheck(cudaGetLastError());
396+
}
397+
333398
THC_API void
334399
THCTensor_(max)(THCState *state,
335400
THCTensor *values,

generic/THCTensorMathReduce.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ THC_API void THCTensor_(max)(THCState *state,
3434

3535
THC_API real THCTensor_(minall)(THCState *state, THCTensor *self);
3636
THC_API real THCTensor_(maxall)(THCState *state, THCTensor *self);
37+
THC_API real THCTensor_(medianall)(THCState *state, THCTensor *self);
3738

3839
THC_API accreal THCTensor_(dist)(THCState *state, THCTensor *self, THCTensor *src,
3940
real value);

0 commit comments

Comments
 (0)