Skip to content

Commit

Permalink
Delete SymIntArrayRef wrapper struct (pytorch#84837)
Browse files Browse the repository at this point in the history
Since we separated at::foo and at::foo_symint, there is no benefit
to trying to make initializer lists work in both cases.  So we can
get rid of the separate wrapper struct.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: pytorch#84837
Approved by: https://github.com/kit1980
  • Loading branch information
ezyang authored and pytorchmergebot committed Sep 12, 2022
1 parent 8cdc067 commit 9c78f59
Show file tree
Hide file tree
Showing 12 changed files with 35 additions and 215 deletions.
2 changes: 1 addition & 1 deletion .github/ci_commit_pins/xla.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
e0dcc3171c8024ab288551d105fba24fbfae7332
09be9870437684ba2da6741af3eb10126c04aede
2 changes: 0 additions & 2 deletions aten/src/ATen/core/ivalue.h
Original file line number Diff line number Diff line change
Expand Up @@ -565,8 +565,6 @@ struct TORCH_API IValue final {
}
}

IValue(c10::SymIntArrayRef v);

bool isSymInt() const {
return Tag::SymInt == tag;
}
Expand Down
1 change: 0 additions & 1 deletion aten/src/ATen/core/ivalue_inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1999,7 +1999,6 @@ inline IValue::IValue(at::ArrayRef<T> v) : IValue(c10::List<T>()) {
list.push_back(e);
}
}
inline IValue::IValue(c10::SymIntArrayRef v) : IValue(at::ArrayRef<c10::SymInt>(v.data(), v.size())) {}
template <class T, IValue::enable_if_ivalue_constructible<T>>
inline IValue::IValue(const std::vector<T>& v) : IValue(c10::List<T>()) {
auto list = to<c10::List<T>>();
Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/native/metal/ops/MetalReshape.mm
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ Tensor view(const Tensor& input, c10::SymIntArrayRef sym_size) {

Tensor reshape(const Tensor& input, IntArrayRef shape) {
TORCH_CHECK(input.is_metal());
return view(input, c10::SymIntArrayRef::fromIntArrayRef(shape));
return view(input, c10::fromIntArrayRef(shape));
}

Tensor flatten_using_ints(
Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/test/extension_backend_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ Tensor empty_strided_override(
c10::optional<c10::Device> device,
c10::optional<bool> pin_memory) {

return empty_override(SymIntArrayRef::fromIntArrayRef(size), dtype, layout, device, pin_memory, c10::nullopt);
return empty_override(fromIntArrayRef(size), dtype, layout, device, pin_memory, c10::nullopt);
}

TORCH_LIBRARY_IMPL(aten, ORT, m) {
Expand Down
217 changes: 21 additions & 196 deletions c10/core/SymIntArrayRef.h
Original file line number Diff line number Diff line change
@@ -1,15 +1,3 @@
// This file defines `SymIntArrayRef` which serves as the view onto
// std::vector<SymInt>. This class is conceptually and mostly functionally
// equivalent to ArrayRef<SymInt>.
//
// However, ArrayRef<SymInt> can't be used directly as it introduces ambiguity
// in the following cases:
// - a.expand({1, 2, 3}) matches two overloads:
// 1. `at::Tensor Tensor::expand(c10::SymIntArrayRef size, bool implicit)`
// 2. `at::Tensor Tensor::expand(at::IntArrayRef size, bool implicit)`
// Introducing `SymIntArrayRef` allows finer-grained control over
// which overload is used.

#pragma once

#include <c10/core/SymInt.h>
Expand All @@ -23,196 +11,33 @@
#include <vector>

namespace c10 {
/// SymIntArrayRef - Represent a constant reference to an array (0 or more
/// elements consecutively in memory), i.e. a start pointer and a length. It
/// allows various APIs to take consecutive elements easily and conveniently.
///
/// This class does not own the underlying data; it is expected to be used in
/// situations where the data resides in some other buffer, whose lifetime
/// extends past that of the SymIntArrayRef. For this reason, it is not in
/// general safe to store a SymIntArrayRef.
///
/// This is intended to be trivially copyable, so it should be passed by
/// value.

class SymIntArrayRef final {
public:
using iterator = const c10::SymInt*;
using const_iterator = const c10::SymInt*;
using size_type = size_t;
using value_type = c10::SymInt;

using reverse_iterator = std::reverse_iterator<iterator>;

private:
ArrayRef<c10::SymInt> wrapped_symint_array_ref;

public:
/// @name Constructors
/// @{

/// Construct an empty SymIntArrayRef.
/* implicit */ constexpr SymIntArrayRef() {}

/* implicit */ SymIntArrayRef(const std::vector<c10::SymInt>& Vec)
: wrapped_symint_array_ref(Vec) {}

/// Construct an SymIntArrayRef from a pointer and length.
C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA SymIntArrayRef(
const c10::SymInt* data,
size_t length)
: wrapped_symint_array_ref(data, length) {}

template <typename U>
/* implicit */ SymIntArrayRef(
const SmallVectorTemplateCommon<c10::SymInt, U>& Vec)
: wrapped_symint_array_ref(Vec) {}

/// Construct an SymIntArrayRef from a range.
C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA SymIntArrayRef(
const c10::SymInt* begin,
const c10::SymInt* end)
: wrapped_symint_array_ref(begin, end) {}

/// Construct an SymIntArrayRef from a C array.
template <size_t N>
/* implicit */ constexpr SymIntArrayRef(const c10::SymInt (&Arr)[N])
: wrapped_symint_array_ref(Arr) {}

// Prefer using a more semantic constructor, like
// fromIntArrayRefKnownNonNegative
static SymIntArrayRef fromIntArrayRefUnchecked(IntArrayRef array_ref) {
return SymIntArrayRef(
reinterpret_cast<const SymInt*>(array_ref.data()), array_ref.size());
}

static SymIntArrayRef fromIntArrayRefKnownNonNegative(IntArrayRef array_ref) {
return fromIntArrayRefUnchecked(array_ref);
}

static SymIntArrayRef fromIntArrayRef(IntArrayRef array_ref) {
for (size_t i = 0; i < array_ref.size(); ++i) {
TORCH_CHECK(
SymInt::check_range(array_ref[i]),
"IntArrayRef contains an int that cannot be represented as a SymInt: ",
array_ref[i]);
}
return SymIntArrayRef(
reinterpret_cast<const SymInt*>(array_ref.data()), array_ref.size());
}

/// @}
/// @name Simple Operations
/// @{

constexpr iterator begin() const {
return wrapped_symint_array_ref.begin();
}
constexpr iterator end() const {
return wrapped_symint_array_ref.end();
}

// These are actually the same as iterator, since SymIntArrayRef only
// gives you const iterators.
constexpr const_iterator cbegin() const {
return wrapped_symint_array_ref.cbegin();
}
constexpr const_iterator cend() const {
return wrapped_symint_array_ref.cend();
}

/// empty - Check if the array is empty.
constexpr bool empty() const {
return size() == 0;
}

constexpr const c10::SymInt* data() const {
return wrapped_symint_array_ref.data();
}

/// size - Get the array size.
constexpr size_t size() const {
return wrapped_symint_array_ref.size();
}

/// front - Get the first element.
C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA const c10::SymInt& front() const {
return wrapped_symint_array_ref.front();
}

/// back - Get the last element.
C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA const c10::SymInt& back() const {
return wrapped_symint_array_ref.back();
}

/// equals - Check for element-wise equality.
constexpr bool equals(SymIntArrayRef RHS) const {
return this->wrapped_symint_array_ref.equals(RHS.wrapped_symint_array_ref);
}

/// slice(n, m) - Take M elements of the array starting at element N
C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA SymIntArrayRef
slice(size_t N, size_t M) const {
return SymIntArrayRef(wrapped_symint_array_ref.data() + N, M);
}

/// slice(n) - Chop off the first N elements of the array.
C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA SymIntArrayRef slice(size_t N) const {
return slice(N, size() - N);
}

/// @}
/// @name Operator Overloads
/// @{
constexpr const c10::SymInt& operator[](size_t Index) const {
return wrapped_symint_array_ref[Index];
}

/// Vector compatibility
C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA const c10::SymInt& at(size_t Index) const {
return wrapped_symint_array_ref.at(Index);
}

/// Disallow accidental assignment from a temporary.
///
/// The declaration here is extra complicated so that "arrayRef = {}"
/// continues to select the move assignment operator.
template <typename U>
typename std::enable_if<std::is_same<U, c10::SymInt>::value, SymIntArrayRef>::
type&
operator=(U&& Temporary) = delete;

/// Disallow accidental assignment from a temporary.
///
/// The declaration here is extra complicated so that "arrayRef = {}"
/// continues to select the move assignment operator.
template <typename U>
typename std::enable_if<std::is_same<U, c10::SymInt>::value, SymIntArrayRef>::
type&
operator=(std::initializer_list<U>) = delete;

/// @}
/// @name Expensive Operations
/// @{
std::vector<c10::SymInt> vec() const {
return wrapped_symint_array_ref.vec();
}

friend std::ostream& operator<<(
std::ostream& out,
const SymIntArrayRef& list);
/// @}
};
using SymIntArrayRef = ArrayRef<SymInt>;

TORCH_API at::IntArrayRef asIntArrayRefSlow(c10::SymIntArrayRef ar);
TORCH_API at::IntArrayRef asIntArrayRefUnchecked(c10::SymIntArrayRef ar);
TORCH_API c10::optional<at::IntArrayRef> asIntArrayRefSlowOpt(
c10::SymIntArrayRef ar);

inline std::ostream& operator<<(
std::ostream& out,
const c10::SymIntArrayRef& list) {
return out << list.wrapped_symint_array_ref;
// Prefer using a more semantic constructor, like
// fromIntArrayRefKnownNonNegative
inline SymIntArrayRef fromIntArrayRefUnchecked(IntArrayRef array_ref) {
return SymIntArrayRef(
reinterpret_cast<const SymInt*>(array_ref.data()), array_ref.size());
}

inline SymIntArrayRef fromIntArrayRefKnownNonNegative(IntArrayRef array_ref) {
return fromIntArrayRefUnchecked(array_ref);
}

inline SymIntArrayRef fromIntArrayRef(IntArrayRef array_ref) {
for (size_t i = 0; i < array_ref.size(); ++i) {
TORCH_CHECK(
SymInt::check_range(array_ref[i]),
"IntArrayRef contains an int that cannot be represented as a SymInt: ",
array_ref[i]);
}
return SymIntArrayRef(
reinterpret_cast<const SymInt*>(array_ref.data()), array_ref.size());
}

} // namespace c10
11 changes: 4 additions & 7 deletions c10/core/TensorImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -603,7 +603,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
return sym_sizes_custom();
}
// Sizes guaranteed to be non-negative, so unchecked cast is OK
return c10::SymIntArrayRef::fromIntArrayRefKnownNonNegative(
return c10::fromIntArrayRefKnownNonNegative(
sizes_and_strides_.sizes_arrayref());
}

Expand All @@ -620,8 +620,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
return extra_meta_->sizes_;
} else {
// Sizes guaranteed to be non-negative, so unchecked cast is OK
return c10::SymIntArrayRef::fromIntArrayRefKnownNonNegative(
sizes_default());
return c10::fromIntArrayRefKnownNonNegative(sizes_default());
}
}

Expand Down Expand Up @@ -733,8 +732,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomStrides))) {
return sym_strides_custom();
}
// strides guaranteed to be non-negative, so unchecked cast is OK
return c10::SymIntArrayRef::fromIntArrayRefUnchecked(strides_default());
return c10::fromIntArrayRefKnownNonNegative(strides_default());
}

IntArrayRef strides_default() const {
Expand All @@ -748,8 +746,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
if (has_symbolic_sizes_strides_) {
return extra_meta_->strides_;
} else {
return c10::SymIntArrayRef::fromIntArrayRefKnownNonNegative(
strides_default());
return c10::fromIntArrayRefKnownNonNegative(strides_default());
}
}

Expand Down
3 changes: 2 additions & 1 deletion test/cpp/tensorexpr/test_quantization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,8 @@ TEST_F(Quantization, QuantDequantUInt8_NLC) {
parseIR(graph_string, &*graph);

auto x = 2 * at::rand({1, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
x.unsafeGetTensorImpl()->set_sizes_and_strides({1, 2, 2}, {4, 1, 2});
x.unsafeGetTensorImpl()->set_sizes_and_strides(
std::initializer_list<int64_t>{1, 2, 2}, {4, 1, 2});
auto q = at::quantize_per_tensor(x, 0.1f, 122, at::kQUInt8);
auto y_expected = at::dequantize(q);
TensorExprKernel k(graph);
Expand Down
2 changes: 1 addition & 1 deletion torch/csrc/lazy/core/tensor_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ c10::SymIntArrayRef LTCTensorImpl::sym_sizes_custom() const {
return c10::SymIntArrayRef(sym_sizes_->data(), sym_sizes_->size());
}

return c10::SymIntArrayRef::fromIntArrayRef(sizes_custom());
return c10::fromIntArrayRef(sizes_custom());
}

void LTCTensorImpl::setup_size_properties() {
Expand Down
4 changes: 2 additions & 2 deletions torch/csrc/lazy/ts_backend/ts_native_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ at::Tensor LazyNativeFunctions::empty_strided(
c10::optional<bool> pin_memory) {
TORCH_LAZY_FN_COUNTER("lazy::");
at::Tensor t = empty_symint(
c10::SymIntArrayRef::fromIntArrayRef(size),
c10::fromIntArrayRef(size),
dtype,
layout,
device,
Expand Down Expand Up @@ -410,7 +410,7 @@ at::Tensor LazyNativeFunctions::_unsafe_view(
at::IntArrayRef size) {
TORCH_LAZY_FN_COUNTER("lazy::");
return LazyNativeFunctions::view_copy_symint(
self, c10::SymIntArrayRef::fromIntArrayRef(size));
self, c10::fromIntArrayRef(size));
}

// This is needed by the torch.tensor constructor.
Expand Down
2 changes: 1 addition & 1 deletion torchgen/api/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ def direct_solve(goal: NamedCType) -> str:
elif goal.type == BaseCType(symIntArrayRefT):
try:
r = direct_solve(NamedCType(goal.name, BaseCType(intArrayRefT)))
return f"c10::SymIntArrayRef::fromIntArrayRef({r})"
return f"c10::fromIntArrayRef({r})"
except UnsatError:
return direct_solve(NamedCType(goal.name, longSymVec_ctype))
elif goal.type == BaseCType(SymIntT):
Expand Down
2 changes: 1 addition & 1 deletion torchgen/gen_functionalization_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def gen_composite_view_copy_kernel(g: NativeFunctionsViewGroup) -> Optional[str]
if (!at::detail::computeStride(self.sizes(), self.strides(), shape).has_value()) {
return self.reshape(size);
} else {
auto output = at::_ops::view::call(self, c10::SymIntArrayRef::fromIntArrayRef(size));
auto output = at::_ops::view::call(self, c10::fromIntArrayRef(size));
return output.clone();
}
}
Expand Down

0 comments on commit 9c78f59

Please sign in to comment.