Skip to content

Commit ddf914f

Browse files
committed
Replace custom optional with std::optional
We have C++17, no reason to keep this old thing around. Had to resolve some ambiguous overloads. ghstack-source-id: 20257c2e2e8722980b9de19bbc0d2044f0bc2919 ghstack-comment-id: 2707860333 Pull Request resolved: #9068
1 parent 46d0580 commit ddf914f

File tree

5 files changed

+41
-185
lines changed

5 files changed

+41
-185
lines changed

extension/aten_util/make_aten_functor_from_et_functor.h

-18
Original file line numberDiff line numberDiff line change
@@ -166,24 +166,6 @@ struct type_convert<std::optional<F>, torch::executor::optional<T>> final {
166166
}
167167
};
168168

169-
// Optionals: ETen to ATen.
170-
template <class F, class T>
171-
struct type_convert<torch::executor::optional<F>, std::optional<T>> final {
172-
public:
173-
torch::executor::optional<F> val;
174-
std::unique_ptr<struct type_convert<F, T>> convert_struct;
175-
explicit type_convert(torch::executor::optional<F> value) : val(value) {}
176-
std::optional<T> call() {
177-
if (val.has_value()) {
178-
convert_struct = std::make_unique<struct type_convert<F, T>>(
179-
type_convert<F, T>(val.value()));
180-
return std::optional<T>(convert_struct->call());
181-
} else {
182-
return std::optional<T>();
183-
}
184-
}
185-
};
186-
187169
// ArrayRefs: ATen to ETen.
188170
template <class F, class T>
189171
struct type_convert<c10::ArrayRef<F>, torch::executor::ArrayRef<T>> final {

kernels/portable/cpu/util/reduce_util.h

+31
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,14 @@ size_t get_reduced_dim_product(
163163
const executorch::aten::optional<executorch::aten::ArrayRef<int64_t>>&
164164
dim_list);
165165

166+
// Resolve ambiguity between the above two overloads -- ArrayRef and
167+
// optional are both implicitly constructible from int64_t.
168+
inline size_t get_reduced_dim_product(
169+
const executorch::aten::Tensor& in,
170+
int64_t dim) {
171+
return get_reduced_dim_product(in, executorch::aten::optional<int64_t>(dim));
172+
}
173+
166174
size_t get_out_numel(
167175
const executorch::aten::Tensor& in,
168176
const executorch::aten::optional<int64_t>& dim);
@@ -172,6 +180,12 @@ size_t get_out_numel(
172180
const executorch::aten::optional<executorch::aten::ArrayRef<int64_t>>&
173181
dim_list);
174182

183+
// Resolve ambiguity between the above two overloads -- ArrayRef and
184+
// optional are both implicitly constructible from int64_t.
185+
inline size_t get_out_numel(const executorch::aten::Tensor& in, int64_t dim) {
186+
return get_out_numel(in, executorch::aten::optional<int64_t>(dim));
187+
}
188+
175189
size_t get_init_index(
176190
const executorch::aten::Tensor& in,
177191
const executorch::aten::optional<int64_t>& dim,
@@ -183,6 +197,12 @@ size_t get_init_index(
183197
dim_list,
184198
const size_t out_ix);
185199

200+
inline size_t get_init_index(
201+
const executorch::aten::Tensor& in,
202+
int64_t dim,
203+
const size_t out_ix) {
204+
return get_init_index(in, executorch::aten::optional<int64_t>(dim), out_ix);
205+
}
186206
//
187207
// Iteration Functions
188208
//
@@ -614,6 +634,17 @@ Error resize_reduction_out(
614634
bool keepdim,
615635
executorch::aten::Tensor& out);
616636

637+
// Resolve ambiguity between the above two overloads -- ArrayRef and
638+
// optional are both implicitly constructible from int64_t.
639+
inline Error resize_reduction_out(
640+
const executorch::aten::Tensor& in,
641+
int64_t dim,
642+
bool keepdim,
643+
executorch::aten::Tensor& out) {
644+
return resize_reduction_out(
645+
in, executorch::aten::optional<int64_t>(dim), keepdim, out);
646+
}
647+
617648
#ifndef USE_ATEN_LIB
618649
bool check_reduction_args(
619650
const Tensor& in,

runtime/core/exec_aten/exec_aten.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ template <typename T>
106106
using optional = torch::executor::optional<T>;
107107
using nullopt_t = torch::executor::nullopt_t;
108108
// NOLINTNEXTLINE(facebook-hte-NamespaceScopedStaticDeclaration)
109-
static constexpr nullopt_t nullopt{0};
109+
using std::nullopt;
110110
using ScalarType = torch::executor::ScalarType;
111111
using TensorList = ArrayRef<Tensor>;
112112
using Scalar = torch::executor::Scalar;

runtime/core/portable_type/optional.h

+7-164
Original file line numberDiff line numberDiff line change
@@ -8,175 +8,18 @@
88

99
#pragma once
1010

11-
#include <executorch/runtime/platform/assert.h>
12-
#include <new>
13-
#include <utility> // std::forward and other template magic checks
11+
#include <optional>
1412

1513
namespace executorch {
1614
namespace runtime {
1715
namespace etensor {
1816

19-
/// Used to indicate an optional type with uninitialized state.
20-
struct nullopt_t final {
21-
constexpr explicit nullopt_t(int32_t) {}
22-
};
23-
24-
/// A constant of type nullopt_t that is used to indicate an optional type with
25-
/// uninitialized state.
26-
constexpr nullopt_t nullopt{0};
27-
28-
/// Leaner optional class, subset of c10, std, and boost optional APIs.
29-
template <class T>
30-
class optional final {
31-
public:
32-
/// The type wrapped by the optional class.
33-
using value_type = T;
34-
35-
/// Constructs an optional object that does not contain a value.
36-
/* implicit */ optional() noexcept : storage_(trivial_init), init_(false) {}
37-
38-
/// Constructs an optional object that does not contain a value.
39-
/* implicit */ optional(nullopt_t) noexcept
40-
: storage_(trivial_init), init_(false) {}
41-
42-
/// Constructs an optional object that matches the state of v.
43-
/* implicit */ optional(const optional<T>& v)
44-
: storage_(trivial_init), init_(v.init_) {
45-
if (init_) {
46-
new (&storage_.value_) T(v.storage_.value_);
47-
}
48-
}
49-
50-
/// Constructs an optional object that contains the specified value.
51-
/* implicit */ optional(const T& v) : storage_(v), init_(true) {}
52-
53-
/// Constructs an optional object from v.
54-
/* implicit */ optional(optional<T>&& v) noexcept(
55-
std::is_nothrow_move_constructible<T>::value)
56-
: storage_(trivial_init), init_(v.init_) {
57-
if (init_) {
58-
new (&storage_.value_) T(std::forward<T>(v.storage_.value_));
59-
}
60-
}
61-
62-
/// Constructs an optional object that contains the specified value.
63-
/* implicit */ optional(T&& v) : storage_(std::forward<T>(v)), init_(true) {}
64-
65-
optional& operator=(const optional& rhs) {
66-
if (init_ && !rhs.init_) {
67-
clear();
68-
} else if (!init_ && rhs.init_) {
69-
init_ = true;
70-
new (&storage_.value_) T(rhs.storage_.value_);
71-
} else if (init_ && rhs.init_) {
72-
storage_.value_ = rhs.storage_.value_;
73-
}
74-
return *this;
75-
}
76-
77-
optional& operator=(optional&& rhs) noexcept(
78-
std::is_nothrow_move_assignable<T>::value &&
79-
std::is_nothrow_move_constructible<T>::value) {
80-
if (init_ && !rhs.init_) {
81-
clear();
82-
} else if (!init_ && rhs.init_) {
83-
init_ = true;
84-
new (&storage_.value_) T(std::forward<T>(rhs.storage_.value_));
85-
} else if (init_ && rhs.init_) {
86-
storage_.value_ = std::forward<T>(rhs.storage_.value_);
87-
}
88-
return *this;
89-
}
90-
91-
/// Destroys the stored value if there is one
92-
~optional() {
93-
if (init_) {
94-
storage_.value_.~T();
95-
}
96-
}
97-
98-
optional& operator=(nullopt_t) noexcept {
99-
clear();
100-
return *this;
101-
}
102-
103-
/// Returns true if the object contains a value, false otherwise
104-
explicit operator bool() const noexcept {
105-
return init_;
106-
}
107-
108-
/// Returns true if the object contains a value, false otherwise
109-
bool has_value() const noexcept {
110-
return init_;
111-
}
112-
113-
/// Returns a constant reference to the contained value. Calls ET_CHECK if
114-
/// the object does not contain a value.
115-
T const& value() const& {
116-
ET_CHECK(init_);
117-
return contained_val();
118-
}
119-
120-
/// Returns a mutable reference to the contained value. Calls ET_CHECK if the
121-
/// object does not contain a value.
122-
T& value() & {
123-
ET_CHECK(init_);
124-
return contained_val();
125-
}
126-
127-
/// Returns an rvalue of the contained value. Calls ET_CHECK if the object
128-
/// does not contain a value.
129-
T&& value() && {
130-
ET_CHECK(init_);
131-
return std::forward<T>(contained_val());
132-
}
133-
134-
private:
135-
// Used to invoke the dummy ctor of storage_t in the initializer lists of
136-
// optional_base as default ctor is implicitly deleted because T is nontrivial
137-
struct trivial_init_t {
138-
} trivial_init{};
139-
140-
/**
141-
* A wrapper type that lets us avoid constructing a T when there is no value.
142-
* If there is a value present, the optional class must destroy it.
143-
*/
144-
union storage_t {
145-
/// A small, trivially-constructable alternative to T.
146-
unsigned char dummy_;
147-
/// The constructed value itself, if optional::has_value_ is true.
148-
T value_;
149-
150-
/* implicit */ storage_t(trivial_init_t) {
151-
dummy_ = 0;
152-
}
153-
154-
template <class... Args>
155-
storage_t(Args&&... args) : value_(std::forward<Args>(args)...) {}
156-
157-
~storage_t() {}
158-
};
159-
160-
const T& contained_val() const& {
161-
return storage_.value_;
162-
}
163-
T&& contained_val() && {
164-
return std::move(storage_.value_);
165-
}
166-
T& contained_val() & {
167-
return storage_.value_;
168-
}
169-
170-
void clear() noexcept {
171-
if (init_) {
172-
storage_.value_.~T();
173-
}
174-
init_ = false;
175-
}
176-
177-
storage_t storage_;
178-
bool init_;
179-
};
17+
// NOLINTNEXTLINE(misc-unused-using-decls)
18+
using std::nullopt;
19+
// NOLINTNEXTLINE(misc-unused-using-decls)
20+
using std::nullopt_t;
21+
// NOLINTNEXTLINE(misc-unused-using-decls)
22+
using std::optional;
18023

18124
} // namespace etensor
18225
} // namespace runtime

runtime/core/portable_type/test/optional_test.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,11 @@ TEST(TestOptional, NulloptHasNoValue) {
3636
EXPECT_FALSE(o.has_value());
3737
}
3838

39-
TEST(TestOptional, ValueOfEmptyOptionalShouldDie) {
39+
TEST(TestOptional, ValueOfEmptyOptionalShouldThrow) {
4040
optional<int32_t> o;
4141
EXPECT_FALSE(o.has_value());
4242

43-
ET_EXPECT_DEATH({ (void)o.value(); }, "");
43+
EXPECT_THROW({ (void)o.value(); }, std::bad_optional_access);
4444
}
4545

4646
TEST(TestOptional, IntValue) {

0 commit comments

Comments
 (0)