Skip to content

Commit 2e37ab8

Browse files
izdebyfacebook-github-bot
authored andcommitted
Enable bool support for several index methods (pytorch#21435)
Summary: Enable bool tensors for these index methods: - index_select - index_copy - put - take - index_fill Tested via unit tests TODO: Enable index_add in a separate PR as it requires more "side" changes. Pull Request resolved: pytorch#21435 Differential Revision: D15684964 Pulled By: izdeby fbshipit-source-id: 48440e4d44873d70c4577e017dd0d8977e0fa15a
1 parent 61cc03f commit 2e37ab8

File tree

8 files changed

+358
-304
lines changed

8 files changed

+358
-304
lines changed

aten/src/ATen/Declarations.cwrap

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,8 @@
211211
]]
212212
[[
213213
name: _th_index_select
214+
cpu_bool: True
215+
cuda_bool: True
214216
cname: indexSelect
215217
variants:
216218
- function
@@ -226,6 +228,8 @@
226228
[[
227229
name: _th_index_copy_
228230
cname: indexCopy
231+
cpu_bool: True
232+
cuda_bool: True
229233
variants: function
230234
return: argument 0
231235
arguments:
@@ -237,6 +241,8 @@
237241
]]
238242
[[
239243
name: _th_take
244+
cpu_bool: True
245+
cuda_bool: True
240246
cname: take
241247
variants:
242248
- function
@@ -250,6 +256,8 @@
250256
]]
251257
[[
252258
name: _th_put_
259+
cpu_bool: True
260+
cuda_bool: True
253261
cname: put
254262
variants: function
255263
backends:
@@ -277,6 +285,8 @@
277285
]]
278286
[[
279287
name: _th_index_fill_
288+
cpu_bool: True
289+
cuda_bool: True
280290
cname: indexFill
281291
variants: function
282292
return: argument 0

aten/src/TH/generic/THTensorEvenMoreMath.cpp

Lines changed: 141 additions & 141 deletions
Original file line numberDiff line numberDiff line change
@@ -184,117 +184,6 @@ scalar_t THTensor_(maxall)(THTensor *tensor)
184184
return theMax;
185185
}
186186

187-
#if !defined(TH_REAL_IS_BOOL)
188-
189-
void THTensor_(maskedFill)(THTensor *tensor, THByteTensor *mask, scalar_t value)
190-
{
191-
int64_t tensor_size = THTensor_(nElement)(tensor);
192-
int tensor_contig = THTensor_(isContiguous)(tensor);
193-
int mask_contig = THTensor_(isContiguous)(mask);
194-
if (tensor_contig && mask_contig) {
195-
TH_TENSOR_APPLY2_PARALLEL(tensor_size, tensor_contig, mask_contig,
196-
scalar_t, tensor, unsigned char, mask,
197-
if (*mask_data > 1) {
198-
THError("Mask tensor can take 0 and 1 values only");
199-
} else if (*mask_data == 1) {
200-
*tensor_data = value;
201-
},
202-
TH_OMP_OVERHEAD_THRESHOLD);
203-
} else {
204-
TH_TENSOR_APPLY2(scalar_t, tensor, unsigned char, mask,
205-
if (*mask_data > 1) {
206-
THFree(mask_counter);
207-
THFree(tensor_counter);
208-
THError("Mask tensor can take 0 and 1 values only");
209-
} else if (*mask_data == 1) {
210-
*tensor_data = value;
211-
});
212-
}
213-
}
214-
215-
void THTensor_(maskedFillBool)(THTensor *tensor, THBoolTensor *mask, scalar_t value)
216-
{
217-
int64_t tensor_size = THTensor_(nElement)(tensor);
218-
int tensor_contig = THTensor_(isContiguous)(tensor);
219-
int mask_contig = THTensor_(isContiguous)(mask);
220-
if (tensor_contig && mask_contig) {
221-
TH_TENSOR_APPLY2_PARALLEL(tensor_size, tensor_contig, mask_contig,
222-
scalar_t, tensor, bool, mask,
223-
if (*mask_data) {
224-
*tensor_data = value;
225-
},
226-
TH_OMP_OVERHEAD_THRESHOLD);
227-
} else {
228-
TH_TENSOR_APPLY2(scalar_t, tensor, bool, mask,
229-
if (*mask_data) {
230-
*tensor_data = value;
231-
});
232-
}
233-
}
234-
235-
void THTensor_(maskedCopy)(THTensor *tensor, THByteTensor *mask, THTensor* src )
236-
{
237-
THTensor *srct = THTensor_(newContiguous)(src);
238-
scalar_t *src_data = srct->data<scalar_t>();
239-
ptrdiff_t cntr = 0;
240-
ptrdiff_t nelem = THTensor_(nElement)(srct);
241-
if (THTensor_(nElement)(tensor) != THByteTensor_nElement(mask))
242-
{
243-
c10::raw::intrusive_ptr::decref(srct);
244-
THError("Number of elements of destination tensor != Number of elements in mask");
245-
}
246-
TH_TENSOR_APPLY2(scalar_t, tensor, unsigned char, mask,
247-
if (*mask_data > 1)
248-
{
249-
c10::raw::intrusive_ptr::decref(srct);
250-
THFree(mask_counter);
251-
THFree(tensor_counter);
252-
THError("Mask tensor can take 0 and 1 values only");
253-
}
254-
else if (*mask_data == 1)
255-
{
256-
if (cntr == nelem)
257-
{
258-
c10::raw::intrusive_ptr::decref(srct);
259-
THFree(mask_counter);
260-
THFree(tensor_counter);
261-
THError("Number of elements of src < number of ones in mask");
262-
}
263-
*tensor_data = *src_data;
264-
src_data++;
265-
cntr++;
266-
});
267-
c10::raw::intrusive_ptr::decref(srct);
268-
}
269-
270-
void THTensor_(maskedCopyBool)(THTensor *tensor, THBoolTensor *mask, THTensor* src )
271-
{
272-
THTensor *srct = THTensor_(newContiguous)(src);
273-
scalar_t *src_data = srct->data<scalar_t>();
274-
ptrdiff_t cntr = 0;
275-
ptrdiff_t nelem = THTensor_(nElement)(srct);
276-
if (THTensor_(nElement)(tensor) != THBoolTensor_nElement(mask))
277-
{
278-
c10::raw::intrusive_ptr::decref(srct);
279-
THError("Number of elements of destination tensor != Number of elements in mask");
280-
}
281-
TH_TENSOR_APPLY2(scalar_t, tensor, bool, mask,
282-
if (*mask_data)
283-
{
284-
if (cntr == nelem)
285-
{
286-
c10::raw::intrusive_ptr::decref(srct);
287-
THFree(mask_counter);
288-
THFree(tensor_counter);
289-
THError("Number of elements of src < number of ones in mask");
290-
}
291-
*tensor_data = *src_data;
292-
src_data++;
293-
cntr++;
294-
});
295-
c10::raw::intrusive_ptr::decref(srct);
296-
}
297-
298187
void THTensor_(indexSelect)(THTensor *tensor, THTensor *src, int dim, THLongTensor *index)
299188
{
300189
ptrdiff_t i, numel;
@@ -510,6 +399,147 @@ void THTensor_(put)(THTensor *tensor, THLongTensor *index, THTensor *src, int ac
510399
THLongTensor_free(index);
511400
}
512401

402+
void THTensor_(indexFill)(THTensor *tensor, int dim, THLongTensor *index, scalar_t val)
403+
{
404+
ptrdiff_t i, numel;
405+
THTensor *tSlice;
406+
int64_t *index_data;
407+
408+
numel = THLongTensor_nElement(index);
409+
THArgCheck(THTensor_nDimensionLegacyNoScalars(index) == 1, 3, "Index is supposed to be a vector");
410+
THArgCheck(dim < THTensor_nDimensionLegacyNoScalars(tensor), 4,"Indexing dim %d is out of bounds of tensor", dim);
411+
412+
index = THLongTensor_newContiguous(index);
413+
index_data = THLongTensor_data(index);
414+
415+
for (i=0; i<numel; i++)
416+
{
417+
if (tensor->dim() > 1)
418+
{
419+
tSlice = THTensor_(new)();
420+
THTensor_(select)(tSlice, tensor,dim,index_data[i]);
421+
THTensor_(fill)(tSlice, val);
422+
c10::raw::intrusive_ptr::decref(tSlice);
423+
}
424+
else
425+
{
426+
THTensor_(set1d)(tensor, index_data[i], val);
427+
}
428+
}
429+
THLongTensor_free(index);
430+
}
431+
432+
#if !defined(TH_REAL_IS_BOOL)
433+
434+
void THTensor_(maskedFill)(THTensor *tensor, THByteTensor *mask, scalar_t value)
435+
{
436+
int64_t tensor_size = THTensor_(nElement)(tensor);
437+
int tensor_contig = THTensor_(isContiguous)(tensor);
438+
int mask_contig = THTensor_(isContiguous)(mask);
439+
if (tensor_contig && mask_contig) {
440+
TH_TENSOR_APPLY2_PARALLEL(tensor_size, tensor_contig, mask_contig,
441+
scalar_t, tensor, unsigned char, mask,
442+
if (*mask_data > 1) {
443+
THError("Mask tensor can take 0 and 1 values only");
444+
} else if (*mask_data == 1) {
445+
*tensor_data = value;
446+
},
447+
TH_OMP_OVERHEAD_THRESHOLD);
448+
} else {
449+
TH_TENSOR_APPLY2(scalar_t, tensor, unsigned char, mask,
450+
if (*mask_data > 1) {
451+
THFree(mask_counter);
452+
THFree(tensor_counter);
453+
THError("Mask tensor can take 0 and 1 values only");
454+
} else if (*mask_data == 1) {
455+
*tensor_data = value;
456+
});
457+
}
458+
}
459+
460+
void THTensor_(maskedFillBool)(THTensor *tensor, THBoolTensor *mask, scalar_t value)
461+
{
462+
int64_t tensor_size = THTensor_(nElement)(tensor);
463+
int tensor_contig = THTensor_(isContiguous)(tensor);
464+
int mask_contig = THTensor_(isContiguous)(mask);
465+
if (tensor_contig && mask_contig) {
466+
TH_TENSOR_APPLY2_PARALLEL(tensor_size, tensor_contig, mask_contig,
467+
scalar_t, tensor, bool, mask,
468+
if (*mask_data) {
469+
*tensor_data = value;
470+
},
471+
TH_OMP_OVERHEAD_THRESHOLD);
472+
} else {
473+
TH_TENSOR_APPLY2(scalar_t, tensor, bool, mask,
474+
if (*mask_data) {
475+
*tensor_data = value;
476+
});
477+
}
478+
}
479+
480+
void THTensor_(maskedCopy)(THTensor *tensor, THByteTensor *mask, THTensor* src )
481+
{
482+
THTensor *srct = THTensor_(newContiguous)(src);
483+
scalar_t *src_data = srct->data<scalar_t>();
484+
ptrdiff_t cntr = 0;
485+
ptrdiff_t nelem = THTensor_(nElement)(srct);
486+
if (THTensor_(nElement)(tensor) != THByteTensor_nElement(mask))
487+
{
488+
c10::raw::intrusive_ptr::decref(srct);
489+
THError("Number of elements of destination tensor != Number of elements in mask");
490+
}
491+
TH_TENSOR_APPLY2(scalar_t, tensor, unsigned char, mask,
492+
if (*mask_data > 1)
493+
{
494+
c10::raw::intrusive_ptr::decref(srct);
495+
THFree(mask_counter);
496+
THFree(tensor_counter);
497+
THError("Mask tensor can take 0 and 1 values only");
498+
}
499+
else if (*mask_data == 1)
500+
{
501+
if (cntr == nelem)
502+
{
503+
c10::raw::intrusive_ptr::decref(srct);
504+
THFree(mask_counter);
505+
THFree(tensor_counter);
506+
THError("Number of elements of src < number of ones in mask");
507+
}
508+
*tensor_data = *src_data;
509+
src_data++;
510+
cntr++;
511+
});
512+
c10::raw::intrusive_ptr::decref(srct);
513+
}
514+
515+
void THTensor_(maskedCopyBool)(THTensor *tensor, THBoolTensor *mask, THTensor* src )
516+
{
517+
THTensor *srct = THTensor_(newContiguous)(src);
518+
scalar_t *src_data = srct->data<scalar_t>();
519+
ptrdiff_t cntr = 0;
520+
ptrdiff_t nelem = THTensor_(nElement)(srct);
521+
if (THTensor_(nElement)(tensor) != THBoolTensor_nElement(mask))
522+
{
523+
c10::raw::intrusive_ptr::decref(srct);
524+
THError("Number of elements of destination tensor != Number of elements in mask");
525+
}
526+
TH_TENSOR_APPLY2(scalar_t, tensor, bool, mask,
527+
if (*mask_data)
528+
{
529+
if (cntr == nelem)
530+
{
531+
c10::raw::intrusive_ptr::decref(srct);
532+
THFree(mask_counter);
533+
THFree(tensor_counter);
534+
THError("Number of elements of src < number of ones in mask");
535+
}
536+
*tensor_data = *src_data;
537+
src_data++;
538+
cntr++;
539+
});
540+
c10::raw::intrusive_ptr::decref(srct);
541+
}
542+
513543
void THTensor_(indexAdd)(THTensor *tensor, int dim, THLongTensor *index, THTensor *src)
514544
{
515545
ptrdiff_t i, numel;
@@ -551,36 +581,6 @@ void THTensor_(indexAdd)(THTensor *tensor, int dim, THLongTensor *index, THTenso
551581
THLongTensor_free(index);
552582
}
553583

554-
void THTensor_(indexFill)(THTensor *tensor, int dim, THLongTensor *index, scalar_t val)
555-
{
556-
ptrdiff_t i, numel;
557-
THTensor *tSlice;
558-
int64_t *index_data;
559-
560-
numel = THLongTensor_nElement(index);
561-
THArgCheck(THTensor_nDimensionLegacyNoScalars(index) == 1, 3, "Index is supposed to be a vector");
562-
THArgCheck(dim < THTensor_nDimensionLegacyNoScalars(tensor), 4,"Indexing dim %d is out of bounds of tensor", dim);
563-
564-
index = THLongTensor_newContiguous(index);
565-
index_data = THLongTensor_data(index);
566-
567-
for (i=0; i<numel; i++)
568-
{
569-
if (tensor->dim() > 1)
570-
{
571-
tSlice = THTensor_(new)();
572-
THTensor_(select)(tSlice, tensor,dim,index_data[i]);
573-
THTensor_(fill)(tSlice, val);
574-
c10::raw::intrusive_ptr::decref(tSlice);
575-
}
576-
else
577-
{
578-
THTensor_(set1d)(tensor, index_data[i], val);
579-
}
580-
}
581-
THLongTensor_free(index);
582-
}
583-
584584
void THTensor_(gather)(THTensor *tensor, THTensor *src, int dim, THLongTensor *index)
585585
{
586586
int64_t elems_per_row, i, idx;

aten/src/TH/generic/THTensorMath.h

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -62,19 +62,20 @@ TH_API void THTensor_(cmin)(THTensor *r, THTensor *t, THTensor *src);
6262
TH_API void THTensor_(cmaxValue)(THTensor *r, THTensor *t, scalar_t value);
6363
TH_API void THTensor_(cminValue)(THTensor *r, THTensor *t, scalar_t value);
6464

65+
TH_API void THTensor_(indexSelect)(THTensor *tensor, THTensor *src, int dim, THLongTensor *index);
66+
TH_API void THTensor_(indexCopy)(THTensor *tensor, int dim, THLongTensor *index, THTensor *src);
67+
TH_API void THTensor_(take)(THTensor *tensor, THTensor *src, THLongTensor *index);
68+
TH_API void THTensor_(put)(THTensor *tensor, THLongTensor *index, THTensor *src, int accumulate);
69+
TH_API void THTensor_(indexFill)(THTensor *tensor, int dim, THLongTensor *index, scalar_t val);
70+
6571
#if !defined(TH_REAL_IS_BOOL) /* non bool only part */
6672

6773
TH_API void THTensor_(maskedFill)(THTensor *tensor, THByteTensor *mask, scalar_t value);
6874
TH_API void THTensor_(maskedCopy)(THTensor *tensor, THByteTensor *mask, THTensor* src);
6975
TH_API void THTensor_(maskedFillBool)(THTensor *tensor, THBoolTensor *mask, scalar_t value);
7076
TH_API void THTensor_(maskedCopyBool)(THTensor *tensor, THBoolTensor *mask, THTensor* src);
7177

72-
TH_API void THTensor_(indexSelect)(THTensor *tensor, THTensor *src, int dim, THLongTensor *index);
73-
TH_API void THTensor_(indexCopy)(THTensor *tensor, int dim, THLongTensor *index, THTensor *src);
7478
TH_API void THTensor_(indexAdd)(THTensor *tensor, int dim, THLongTensor *index, THTensor *src);
75-
TH_API void THTensor_(indexFill)(THTensor *tensor, int dim, THLongTensor *index, scalar_t val);
76-
TH_API void THTensor_(take)(THTensor *tensor, THTensor *src, THLongTensor *index);
77-
TH_API void THTensor_(put)(THTensor *tensor, THLongTensor *index, THTensor *src, int accumulate);
7879

7980
TH_API void THTensor_(gather)(THTensor *tensor, THTensor *src, int dim, THLongTensor *index);
8081
TH_API void THTensor_(scatter)(THTensor *tensor, int dim, THLongTensor *index, THTensor *src);

aten/src/THC/THCTensorIndex.cu

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -480,3 +480,7 @@ void dispatchTakePut(THCState *state, TensorType *a, TensorType *b, THCudaLongTe
480480

481481
#include <THC/generic/THCTensorIndex.cu>
482482
#include <THC/THCGenerateAllTypes.h>
483+
484+
485+
#include <THC/generic/THCTensorIndex.cu>
486+
#include <THC/THCGenerateBoolType.h>

aten/src/THC/THCTensorMath.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,9 @@
6161
#include <THC/generic/THCTensorIndex.h>
6262
#include <THC/THCGenerateAllTypes.h>
6363

64+
#include <THC/generic/THCTensorIndex.h>
65+
#include <THC/THCGenerateBoolType.h>
66+
6467
#include <THC/generic/THCTensorSort.h>
6568
#include <THC/THCGenerateAllTypes.h>
6669

0 commit comments

Comments
 (0)