compute common dtype based on inputs only (pytorch#25593)
Summary:
Currently we compute the common dtype for TensorIterator based on all inputs and outputs. This is a problem when the dtype of the outputs should differ from the dtype of the inputs (for example, torch.eq). We also have a `dont_compute_common_dtype` method that lets us skip computing a common dtype across all inputs and outputs entirely.

This PR adds the ability to compute the common dtype based only on the inputs via `compute_common_dtype_only_for_inputs`. It also provides a simple method, `input_dtype(int arg=0)`, that allows dispatching based on an input's dtype:

```
AT_DISPATCH_ALL_TYPES(iter.input_dtype(), ...
```
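
The snippet below is a minimal sketch of how the two new calls compose, modeled on the tests added in this PR; the `eq_style_example` function name, the tensor shapes, and the dispatch name string are illustrative, not part of the change:

```
// Minimal sketch (modeled on the new tests): the output keeps its own
// dtype while the inputs are promoted to a common dtype, and the kernel
// dispatches on that promoted input dtype.
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <ATen/native/TensorIterator.h>

void eq_style_example() {
  at::Tensor out = at::ones({1, 1}, at::dtype(at::kBool));
  at::Tensor a = at::ones({1, 1}, at::dtype(at::kFloat));
  at::Tensor b = at::ones({1, 1}, at::dtype(at::kDouble));

  auto iter = at::TensorIterator();
  iter.add_output(out);
  iter.add_input(a);
  iter.add_input(b);
  iter.compute_common_dtype_only_for_inputs();  // promote inputs only
  iter.build();

  // iter.dtype(0) stays kBool; iter.input_dtype() is the promoted kDouble.
  AT_DISPATCH_ALL_TYPES(iter.input_dtype(), "eq_style_example", [&] {
    // scalar_t is double here; the comparison loop would go in this body.
  });
}
```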
Pull Request resolved: pytorch#25593

Differential Revision: D17286352

Pulled By: ifedan

fbshipit-source-id: a94fb608acd2763120992fe85b8dfd02ff21f9ba
ifedan authored and facebook-github-bot committed Sep 11, 2019
1 parent 8f7020b commit e69a6ba
Showing 3 changed files with 109 additions and 13 deletions.
49 changes: 39 additions & 10 deletions aten/src/ATen/native/TensorIterator.cpp
@@ -106,17 +106,17 @@ compute_result_type(at::ArrayRef<OperandInfo> operands,
return compute_result_type(operands, predicates...);
}

std::tuple<Device, ScalarType> TensorIterator::compute_common_type() {
static std::tuple<Device, ScalarType> compute_common_type_(at::ArrayRef<OperandInfo> operands) {
// See [Result type computation] in TensorIterator.h

auto result_type =
compute_result_type(operands_,
compute_result_type(operands,
[](const OperandInfo& op) { return op.tensor.dim() > 0; },
[](const OperandInfo& op) { return !op.tensor.unsafeGetTensorImpl()->is_wrapped_number(); },
[](const OperandInfo& op) { return true; });

if (ScalarType::Bool == std::get<1>(result_type)) {
auto alternate = compute_result_type(operands_,
auto alternate = compute_result_type(operands,
[](const OperandInfo& op) {
return op.tensor.dim() == 0;
}
@@ -130,7 +130,7 @@ std::tuple<Device, ScalarType> TensorIterator::compute_common_type() {
// if non-zero-dim tensor result is an integral type and there's a zero-dim
// floating point operand, we'll promote the floating point type.
if (isIntegralType(std::get<1>(result_type), false)) {
auto alternate = compute_result_type(operands_,
auto alternate = compute_result_type(operands,
[](const OperandInfo& op) {
return isFloatingType(op.tensor.scalar_type()) && op.tensor.dim() == 0;
}
@@ -145,6 +145,10 @@ std::tuple<Device, ScalarType> TensorIterator::compute_common_type() {
return result_type;
}

std::tuple<Device, ScalarType> TensorIterator::compute_common_type() {
return compute_common_type_(operands_);
}

static void validate_dtype(OperandInfo& op, ScalarType common_dtype, int ninputs) {
if (op.tensor.defined()) {
// For binary_ops, we follow casting rules. For unary/nullary types
@@ -182,23 +186,40 @@ static void maybe_promote_common_dtype(OperandInfo& op, ScalarType common_dtype)

void TensorIterator::compute_types() {
bool missing_dtypes = false;
bool missing_output_dtypes = false;
bool has_read_write_op = false;
ScalarType common_dtype = dtype();
for (auto& op : operands_) {
if (!op.tensor.defined() && !op.is_type_defined()) {
missing_dtypes = true;
if (op.is_output) {
missing_output_dtypes = true;
}
}
if (op.is_read_write) {
has_read_write_op = true;
}
}

if (missing_dtypes || compute_common_dtype_) {
auto common_type = compute_common_type();
if (compute_common_dtype_strategy_ == CommonDTypeStrategy::COMPUTE_INPUTS) {
TORCH_CHECK(!missing_output_dtypes, "unable to compute and promote common dtype based only on inputs if there are missing dtypes for outputs");
TORCH_CHECK(!has_read_write_op, "unable to compute and promote common dtype based only on inputs if input is same as output");
}

bool compute_common_dtype = (compute_common_dtype_strategy_ != CommonDTypeStrategy::COMPUTE_NONE);
bool compute_common_dtype_only_for_inputs = (compute_common_dtype_strategy_ == CommonDTypeStrategy::COMPUTE_INPUTS);

if (missing_dtypes || compute_common_dtype) {
auto operands = compute_common_dtype_only_for_inputs ? at::ArrayRef<OperandInfo>(operands_).slice(noutputs()) : operands_;
auto common_type = compute_common_type_(operands);
auto common_device = std::get<0>(common_type);
common_dtype = std::get<1>(common_type);
bool has_cpu_scalar = false;
for (auto& op : operands_) {
if (!op.is_type_defined()) {
op.device = common_device;
op.dtype = common_dtype;
} else if (compute_common_dtype_ &&
} else if (compute_common_dtype &&
(op.device != common_device || op.dtype != common_dtype)) {
if (allow_cpu_scalars_ && op.tensor.defined() && op.tensor.dim() == 0 &&
common_device.is_cuda() && op.tensor.device().is_cpu() &&
@@ -217,12 +238,20 @@ void TensorIterator::compute_types() {
op.dtype = op.tensor.scalar_type();
} else {
op.device = common_device;
op.dtype = common_dtype;
if (compute_common_dtype_only_for_inputs && op.is_output) {
op.dtype = op.tensor.scalar_type();
} else {
op.dtype = common_dtype;
}
}
}

validate_dtype(op, common_dtype, ninputs());
maybe_promote_common_dtype(op, common_dtype);
if (!compute_common_dtype_only_for_inputs) {
validate_dtype(op, common_dtype, ninputs());
}
if (!compute_common_dtype_only_for_inputs || !op.is_output) {
maybe_promote_common_dtype(op, common_dtype);
}

if (op.tensor.defined() && op.device != op.tensor.device()) {
if (op.is_output) {
17 changes: 14 additions & 3 deletions aten/src/ATen/native/TensorIterator.h
@@ -122,6 +122,12 @@ struct CAFFE2_API OperandInfo {

struct SplitUntil32Bit;

enum class CommonDTypeStrategy : uint8_t {
COMPUTE_ALL = 0, // Compute common dtype based on inputs and outputs. Try to promote common dtype to inputs and outputs
COMPUTE_INPUTS = 1, // Compute common dtype based only on inputs. Try to promote common dtype only to inputs
COMPUTE_NONE = 2, // Do not compute and promote common dtype
};

struct CAFFE2_API TensorIterator {
using DimMask = std::bitset<64>;
using PtrVector = SmallVector<char*, 4>;
@@ -177,6 +183,7 @@ struct CAFFE2_API TensorIterator {
IntArrayRef strides(int arg) const { return operands_[arg].stride_bytes; }
void* data_ptr(int arg) const;
ScalarType dtype(int arg=0) const { return operands_[arg].tensor.scalar_type(); }
ScalarType input_dtype(int arg=0) const { return operands_[num_outputs_ + arg].dtype; }
Device device(int arg=0) const { return operands_[arg].device; }
DeviceType device_type(int arg=0) const { return device(arg).type(); }
int64_t element_size(int arg) const { return elementSize(dtype(arg)); }
@@ -192,7 +199,7 @@ struct CAFFE2_API TensorIterator {
}

void cast_outputs() {
if (compute_common_dtype_) {
if (compute_common_dtype_strategy_ == CommonDTypeStrategy::COMPUTE_ALL) {
for(int i=0; i < noutputs(); i++) {
if (operands_[i].original_tensor.defined() && dtype(i) != operands_[i].original_tensor.scalar_type()) {
operands_[i].original_tensor.copy_(operands_[i].tensor);
@@ -295,7 +302,11 @@ struct CAFFE2_API TensorIterator {
}

void dont_compute_common_dtype() {
compute_common_dtype_ = false;
compute_common_dtype_strategy_ = CommonDTypeStrategy::COMPUTE_NONE;
}

void compute_common_dtype_only_for_inputs() {
compute_common_dtype_strategy_ = CommonDTypeStrategy::COMPUTE_INPUTS;
}

void dont_resize_outputs() {
@@ -328,11 +339,11 @@
#endif
SmallVector<OperandInfo, 4> operands_;
int num_outputs_ = 0;
CommonDTypeStrategy compute_common_dtype_strategy_ = CommonDTypeStrategy::COMPUTE_ALL;
bool has_coalesced_dimensions_ = false;
bool accumulate_ = false;
bool resize_outputs_ = true;
bool is_reduction_ = false;
bool compute_common_dtype_ = true;
bool allow_cpu_scalars_ = false;
bool promote_gpu_output_dtypes_ = false;
bool final_output_ = true;
56 changes: 56 additions & 0 deletions aten/src/ATen/test/tensor_iterator_test.cpp
@@ -105,3 +105,59 @@ TEST(TensorIteratorTest, SerialLoopSingleThread) {
});
}

TEST(TensorIteratorTest, InputDType) {
auto iter = at::TensorIterator();
iter.add_output(at::ones({1, 1}, at::dtype(at::kBool)));
iter.add_input(at::ones({1, 1}, at::dtype(at::kFloat)));
iter.add_input(at::ones({1, 1}, at::dtype(at::kDouble)));
iter.dont_compute_common_dtype();
iter.build();
EXPECT_TRUE(iter.input_dtype() == at::kFloat);
EXPECT_TRUE(iter.input_dtype(0) == at::kFloat);
EXPECT_TRUE(iter.input_dtype(1) == at::kDouble);
}

TEST(TensorIteratorTest, ComputeCommonDTypeInputOnly) {
auto iter = at::TensorIterator();
iter.add_output(at::ones({1, 1}, at::dtype(at::kBool)));
iter.add_input(at::ones({1, 1}, at::dtype(at::kFloat)));
iter.add_input(at::ones({1, 1}, at::dtype(at::kDouble)));
iter.compute_common_dtype_only_for_inputs();
iter.build();
EXPECT_TRUE(iter.dtype(0) == at::kBool);
EXPECT_TRUE(iter.dtype(1) == at::kDouble);
EXPECT_TRUE(iter.dtype(2) == at::kDouble);
}

TEST(TensorIteratorTest, DoNotComputeCommonDTypeInputOnly) {
auto iter = at::TensorIterator();
iter.add_output(at::ones({1, 1}, at::dtype(at::kLong)));
iter.add_input(at::ones({1, 1}, at::dtype(at::kFloat)));
iter.add_input(at::ones({1, 1}, at::dtype(at::kDouble)));
iter.compute_common_dtype_only_for_inputs();
iter.dont_compute_common_dtype();
iter.build();
EXPECT_TRUE(iter.dtype(0) == at::kLong);
EXPECT_TRUE(iter.dtype(1) == at::kFloat);
EXPECT_TRUE(iter.dtype(2) == at::kDouble);
}

TEST(TensorIteratorTest, DoNotComputeCommonDTypeIfInputSameAsOutput) {
Tensor inout = at::ones({1, 1}, at::dtype(at::kFloat));
auto iter = at::TensorIterator();
iter.add_output(inout);
iter.add_input(inout);
iter.add_input(at::ones({1, 1}, at::dtype(at::kDouble)));
iter.compute_common_dtype_only_for_inputs();
ASSERT_ANY_THROW(iter.build());
}

TEST(TensorIteratorTest, DoNotComputeCommonDTypeIfOutputIsUndefined) {
Tensor out;
auto iter = at::TensorIterator();
iter.add_output(out);
iter.add_input(at::ones({1, 1}, at::dtype(at::kDouble)));
iter.add_input(at::ones({1, 1}, at::dtype(at::kFloat)));
iter.compute_common_dtype_only_for_inputs();
ASSERT_ANY_THROW(iter.build());
}
