Skip to content

Replace custom optional with std::optional #9068

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Mar 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 0 additions & 18 deletions extension/aten_util/make_aten_functor_from_et_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -166,24 +166,6 @@ struct type_convert<std::optional<F>, torch::executor::optional<T>> final {
}
};

// Optionals: ETen to ATen.
template <class F, class T>
struct type_convert<torch::executor::optional<F>, std::optional<T>> final {
public:
torch::executor::optional<F> val;
std::unique_ptr<struct type_convert<F, T>> convert_struct;
explicit type_convert(torch::executor::optional<F> value) : val(value) {}
std::optional<T> call() {
if (val.has_value()) {
convert_struct = std::make_unique<struct type_convert<F, T>>(
type_convert<F, T>(val.value()));
return std::optional<T>(convert_struct->call());
} else {
return std::optional<T>();
}
}
};

// ArrayRefs: ATen to ETen.
template <class F, class T>
struct type_convert<c10::ArrayRef<F>, torch::executor::ArrayRef<T>> final {
Expand Down
31 changes: 31 additions & 0 deletions kernels/portable/cpu/util/reduce_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,14 @@ size_t get_reduced_dim_product(
const executorch::aten::optional<executorch::aten::ArrayRef<int64_t>>&
dim_list);

// Resolve ambiguity between the above two overloads -- ArrayRef and
// optional are both implicitly constructible from int64_t.
inline size_t get_reduced_dim_product(
const executorch::aten::Tensor& in,
int64_t dim) {
return get_reduced_dim_product(in, executorch::aten::optional<int64_t>(dim));
}

size_t get_out_numel(
const executorch::aten::Tensor& in,
const executorch::aten::optional<int64_t>& dim);
Expand All @@ -172,6 +180,12 @@ size_t get_out_numel(
const executorch::aten::optional<executorch::aten::ArrayRef<int64_t>>&
dim_list);

// Resolve ambiguity between the above two overloads -- ArrayRef and
// optional are both implicitly constructible from int64_t.
inline size_t get_out_numel(const executorch::aten::Tensor& in, int64_t dim) {
return get_out_numel(in, executorch::aten::optional<int64_t>(dim));
}

size_t get_init_index(
const executorch::aten::Tensor& in,
const executorch::aten::optional<int64_t>& dim,
Expand All @@ -183,6 +197,12 @@ size_t get_init_index(
dim_list,
const size_t out_ix);

inline size_t get_init_index(
const executorch::aten::Tensor& in,
int64_t dim,
const size_t out_ix) {
return get_init_index(in, executorch::aten::optional<int64_t>(dim), out_ix);
}
//
// Iteration Functions
//
Expand Down Expand Up @@ -614,6 +634,17 @@ Error resize_reduction_out(
bool keepdim,
executorch::aten::Tensor& out);

// Resolve ambiguity between the above two overloads -- ArrayRef and
// optional are both implicitly constructible from int64_t.
inline Error resize_reduction_out(
const executorch::aten::Tensor& in,
int64_t dim,
bool keepdim,
executorch::aten::Tensor& out) {
return resize_reduction_out(
in, executorch::aten::optional<int64_t>(dim), keepdim, out);
}

#ifndef USE_ATEN_LIB
bool check_reduction_args(
const Tensor& in,
Expand Down
2 changes: 1 addition & 1 deletion runtime/core/exec_aten/exec_aten.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ template <typename T>
using optional = torch::executor::optional<T>;
using nullopt_t = torch::executor::nullopt_t;
// NOLINTNEXTLINE(facebook-hte-NamespaceScopedStaticDeclaration)
static constexpr nullopt_t nullopt{0};
using std::nullopt;
using ScalarType = torch::executor::ScalarType;
using TensorList = ArrayRef<Tensor>;
using Scalar = torch::executor::Scalar;
Expand Down
171 changes: 7 additions & 164 deletions runtime/core/portable_type/optional.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,175 +8,18 @@

#pragma once

#include <executorch/runtime/platform/assert.h>
#include <new>
#include <utility> // std::forward and other template magic checks
#include <optional>

namespace executorch {
namespace runtime {
namespace etensor {

/// Used to indicate an optional type with uninitialized state.
struct nullopt_t final {
constexpr explicit nullopt_t(int32_t) {}
};

/// A constant of type nullopt_t that is used to indicate an optional type with
/// uninitialized state.
constexpr nullopt_t nullopt{0};

/// Leaner optional class, subset of c10, std, and boost optional APIs.
template <class T>
class optional final {
public:
/// The type wrapped by the optional class.
using value_type = T;

/// Constructs an optional object that does not contain a value.
/* implicit */ optional() noexcept : storage_(trivial_init), init_(false) {}

/// Constructs an optional object that does not contain a value.
/* implicit */ optional(nullopt_t) noexcept
: storage_(trivial_init), init_(false) {}

/// Constructs an optional object that matches the state of v.
/* implicit */ optional(const optional<T>& v)
: storage_(trivial_init), init_(v.init_) {
if (init_) {
new (&storage_.value_) T(v.storage_.value_);
}
}

/// Constructs an optional object that contains the specified value.
/* implicit */ optional(const T& v) : storage_(v), init_(true) {}

/// Constructs an optional object from v.
/* implicit */ optional(optional<T>&& v) noexcept(
std::is_nothrow_move_constructible<T>::value)
: storage_(trivial_init), init_(v.init_) {
if (init_) {
new (&storage_.value_) T(std::forward<T>(v.storage_.value_));
}
}

/// Constructs an optional object that contains the specified value.
/* implicit */ optional(T&& v) : storage_(std::forward<T>(v)), init_(true) {}

optional& operator=(const optional& rhs) {
if (init_ && !rhs.init_) {
clear();
} else if (!init_ && rhs.init_) {
init_ = true;
new (&storage_.value_) T(rhs.storage_.value_);
} else if (init_ && rhs.init_) {
storage_.value_ = rhs.storage_.value_;
}
return *this;
}

optional& operator=(optional&& rhs) noexcept(
std::is_nothrow_move_assignable<T>::value &&
std::is_nothrow_move_constructible<T>::value) {
if (init_ && !rhs.init_) {
clear();
} else if (!init_ && rhs.init_) {
init_ = true;
new (&storage_.value_) T(std::forward<T>(rhs.storage_.value_));
} else if (init_ && rhs.init_) {
storage_.value_ = std::forward<T>(rhs.storage_.value_);
}
return *this;
}

/// Destroys the stored value if there is one
~optional() {
if (init_) {
storage_.value_.~T();
}
}

optional& operator=(nullopt_t) noexcept {
clear();
return *this;
}

/// Returns true if the object contains a value, false otherwise
explicit operator bool() const noexcept {
return init_;
}

/// Returns true if the object contains a value, false otherwise
bool has_value() const noexcept {
return init_;
}

/// Returns a constant reference to the contained value. Calls ET_CHECK if
/// the object does not contain a value.
T const& value() const& {
ET_CHECK(init_);
return contained_val();
}

/// Returns a mutable reference to the contained value. Calls ET_CHECK if the
/// object does not contain a value.
T& value() & {
ET_CHECK(init_);
return contained_val();
}

/// Returns an rvalue of the contained value. Calls ET_CHECK if the object
/// does not contain a value.
T&& value() && {
ET_CHECK(init_);
return std::forward<T>(contained_val());
}

private:
// Used to invoke the dummy ctor of storage_t in the initializer lists of
// optional_base as default ctor is implicitly deleted because T is nontrivial
struct trivial_init_t {
} trivial_init{};

/**
* A wrapper type that lets us avoid constructing a T when there is no value.
* If there is a value present, the optional class must destroy it.
*/
union storage_t {
/// A small, trivially-constructable alternative to T.
unsigned char dummy_;
/// The constructed value itself, if optional::has_value_ is true.
T value_;

/* implicit */ storage_t(trivial_init_t) {
dummy_ = 0;
}

template <class... Args>
storage_t(Args&&... args) : value_(std::forward<Args>(args)...) {}

~storage_t() {}
};

const T& contained_val() const& {
return storage_.value_;
}
T&& contained_val() && {
return std::move(storage_.value_);
}
T& contained_val() & {
return storage_.value_;
}

void clear() noexcept {
if (init_) {
storage_.value_.~T();
}
init_ = false;
}

storage_t storage_;
bool init_;
};
// NOLINTNEXTLINE(misc-unused-using-decls)
using std::nullopt;
// NOLINTNEXTLINE(misc-unused-using-decls)
using std::nullopt_t;
// NOLINTNEXTLINE(misc-unused-using-decls)
using std::optional;

} // namespace etensor
} // namespace runtime
Expand Down
4 changes: 2 additions & 2 deletions runtime/core/portable_type/test/optional_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,11 @@ TEST(TestOptional, NulloptHasNoValue) {
EXPECT_FALSE(o.has_value());
}

TEST(TestOptional, ValueOfEmptyOptionalShouldDie) {
TEST(TestOptional, ValueOfEmptyOptionalShouldThrow) {
optional<int32_t> o;
EXPECT_FALSE(o.has_value());

ET_EXPECT_DEATH({ (void)o.value(); }, "");
EXPECT_THROW({ (void)o.value(); }, std::bad_optional_access);
}

TEST(TestOptional, IntValue) {
Expand Down
Loading