Skip to content

Commit

Permalink
Overload get() function for Optional type. (apache#9748)
Browse files Browse the repository at this point in the history
* upd

* simplify

* upd

* fix

* upd

* fix docstring
yzh119 authored Feb 16, 2022
1 parent 75122db commit 1da0093
Showing 3 changed files with 7 additions and 2 deletions.
5 changes: 5 additions & 0 deletions include/tvm/runtime/container/optional.h
Original file line number Diff line number Diff line change
@@ -93,6 +93,11 @@ class Optional : public ObjectRef {
ICHECK(data_ != nullptr);
return T(data_);
}
/*!
* \return The internal object pointer with container type of T.
* \note This function do not perform not-null checking.
*/
const ContainerType* get() const { return static_cast<ContainerType*>(data_.get()); }
/*!
* \return The contained value if the Optional is not null
* otherwise return the default_value.
2 changes: 1 addition & 1 deletion src/tir/transforms/inject_rolling_buffer.cc
Original file line number Diff line number Diff line change
@@ -172,7 +172,7 @@ class RollingBufferInjector : public StmtExprMutator {

auto it{std::find_if(
bound_iter_vars.begin(), bound_iter_vars.end(),
[&](Optional<Var> var) { return var && (var.value().get() == loop_var.get()); })};
[&](Optional<Var> var) { return var && (var.get() == loop_var.get()); })};

if (it != bound_iter_vars.end()) {
auto i{std::distance(bound_iter_vars.begin(), it)};
Original file line number Diff line number Diff line change
@@ -44,7 +44,7 @@ class BufferAllocationLocator : public StmtExprMutator {
// create buffers to be allocated at each stmts
for (const auto& kv : buffer_lca) {
const Buffer& buffer = kv.first;
const StmtNode* stmt = kv.second.defined() ? kv.second.value().get() : nullptr;
const StmtNode* stmt = kv.second.get();
if (arg_buffers.count(buffer.get())) {
continue;
}

0 comments on commit 1da0093

Please sign in to comment.