Skip to content

Commit 0361757

Browse files
izdebyfacebook-github-bot
authored andcommitted
Сhange type of a tensor with bools (pytorch#19097)
Summary: **This is **bc-breaking** change** Change dtype of a tensor which was created from bool data. Old behavior: torch.tensor([True, False]) -> uint8 tensor Now: torch.tensor([True, False]) -> bool tensor Tested via tests. Pull Request resolved: pytorch#19097 Reviewed By: ezyang Differential Revision: D15632553 Pulled By: izdeby fbshipit-source-id: b019150844c561a6845710a3c62b12f06b68bbe3
1 parent 22ddddf commit 0361757

File tree

8 files changed

+14
-11
lines changed

8 files changed

+14
-11
lines changed

aten/src/ATen/Declarations.cwrap

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,7 @@
365365
name: _th_equal
366366
cpu_bool: True
367367
cname: equal
368+
cpu_bool: True
368369
variants:
369370
- function
370371
return: bool

aten/src/ATen/core/jit_type.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1254,7 +1254,7 @@ inline at::ScalarType scalarTypeFromJitType(const c10::TypePtr& type) {
12541254
} else if (type == IntType::get()) {
12551255
return at::ScalarType::Long;
12561256
} else if (type == BoolType::get()) {
1257-
return at::ScalarType::Byte;
1257+
return at::ScalarType::Bool;
12581258
}
12591259
AT_ASSERTM(
12601260
0,

test/test_dataloader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1007,7 +1007,7 @@ def test_default_collate_dtype(self):
10071007
arr = [True, False]
10081008
collated = _utils.collate.default_collate(arr)
10091009
self.assertEqual(collated, torch.tensor(arr))
1010-
self.assertEqual(collated.dtype, torch.uint8)
1010+
self.assertEqual(collated.dtype, torch.bool)
10111011

10121012
# Should be a no-op
10131013
arr = ['a', 'b', 'c']

test/test_jit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6723,7 +6723,7 @@ def func7():
67236723
return torch.tensor([[1]])
67246724

67256725
list_input = [func1, func2, func3, func4, func5, func6, func7]
6726-
expected_shape = ["Long(*)", ("Byte(*)"), "Double(*)", "Double()", "Long()", "Byte()", "Long(*, *)"]
6726+
expected_shape = ["Long(*)", ("Bool(*)"), "Double(*)", "Double()", "Long()", "Bool()", "Long(*, *)"]
67276727

67286728
for fn, expect in zip(list_input, expected_shape):
67296729
self.checkScript(fn, ())

test/test_torch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3272,7 +3272,7 @@ def test_inference(default_dtype):
32723272
self.assertIs(default_dtype, torch.tensor(()).dtype)
32733273
self.assertIs(default_dtype, torch.tensor(5.).dtype)
32743274
self.assertIs(torch.int64, torch.tensor(5).dtype)
3275-
self.assertIs(torch.uint8, torch.tensor(True).dtype)
3275+
self.assertIs(torch.bool, torch.tensor(True).dtype)
32763276
self.assertIs(torch.int32, torch.tensor(5, dtype=torch.int32).dtype)
32773277
self.assertIs(default_dtype, torch.tensor(((7, 5), (9, 5.))).dtype)
32783278
self.assertIs(default_dtype, torch.tensor(((5., 5), (3, 5))).dtype)

torch/csrc/autograd/python_variable_indexing.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,11 @@ static Variable applySlicing(const Variable& self, PyObject* index, variable_lis
173173
result = applySelect(result, dim, THPUtils_unpackLong(obj), i);
174174
} else {
175175
result = result.unsqueeze(dim);
176-
handle_var(boolToIndexingTensor(result, var.item<uint8_t>() != 0));
176+
if(scalar_type == at::kBool) {
177+
handle_var(boolToIndexingTensor(result, var.item<bool>() != 0));
178+
} else {
179+
handle_var(boolToIndexingTensor(result, var.item<uint8_t>() != 0));
180+
}
177181
}
178182
} else {
179183
handle_var(var);

torch/csrc/jit/register_special_ops.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,6 @@ void storeLastDimension(
9999
}
100100
}
101101

102-
// bool vector needs to be cast to uint8_t
103102
template <>
104103
void storeLastDimension<bool>(
105104
char* data,
@@ -112,7 +111,7 @@ void storeLastDimension<bool>(
112111
auto seq_size = obj.size();
113112
checkSequenceSize(n, dim, seq_size);
114113
for (int64_t i = 0; i < n; i++) {
115-
*(uint8_t*)data = static_cast<uint8_t>(obj[i]);
114+
*(bool*)data = static_cast<bool>(obj[i]);
116115
data += strides[dim] * elementSize;
117116
}
118117
}
@@ -291,7 +290,7 @@ RegisterOperators reg({
291290
DEFINE_TORCH_TENSOR_OP(
292291
bool,
293292
bool,
294-
at::empty({}, at::CPU(at::kByte).options()).fill_(scalar_val))
293+
at::empty({}, at::CPU(at::kBool).options()).fill_(scalar_val))
295294

296295
// reference python implementation: internal_new_from_data in
297296
// tensor_new.cpp

torch/csrc/utils/tensor_new.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,8 +124,7 @@ ScalarType infer_scalar_type(PyObject *obj) {
124124
return ScalarType::Long;
125125
}
126126
if (PyBool_Check(obj)) {
127-
// TODO: infer Bool when we have Bool ScalarType
128-
return ScalarType::Byte;
127+
return ScalarType::Bool;
129128
}
130129
if (THPVariable_Check(obj)) {
131130
auto var = reinterpret_cast<THPVariable*>(obj)->cdata;
@@ -477,7 +476,7 @@ Tensor indexing_tensor_from_data(
477476
// Specific to tensor indexing, converts an indexing list to an
478477
// indexing tensor (type Byte or Long)
479478
ScalarType inferred_scalar_type = infer_scalar_type(data);
480-
if (inferred_scalar_type == ScalarType::Byte) {
479+
if (inferred_scalar_type == ScalarType::Byte || inferred_scalar_type == ScalarType::Bool) {
481480
auto& idx_type = type.toScalarType(inferred_scalar_type);
482481
return internal_new_from_data(idx_type, inferred_scalar_type, std::move(device), data, false, false, false);
483482
} else {

0 commit comments

Comments
 (0)