Rework execution frame to reduce memory allocations (#11897)
* Revert "Revert "Refactor ExecutionFrame and SessionState to reduce memory all… (#11888)"

This reverts commit d2cbae3.

* Revert prepacked_weights to avoid indirect inclusion in CUDA and TRT code that breaks the build.
yuslepukhin authored Jun 20, 2022
1 parent 6ee2c1b commit 267a424
Showing 33 changed files with 323 additions and 228 deletions.
24 changes: 24 additions & 0 deletions include/onnxruntime/core/framework/ortmemoryinfo.h
@@ -38,6 +38,20 @@ struct OrtMemoryInfo {
return strcmp(name, other.name) < 0;
}

static void HashCombine(size_t h, size_t& seed) {
seed ^= h + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}

// This is to make OrtMemoryInfo a valid key in hash tables
// we ignore device id
size_t Hash() const {
auto h = std::hash<int>()(alloc_type);
HashCombine(std::hash<int>()(mem_type), h);
HashCombine(std::hash<int>()(id), h);
HashCombine(std::hash<const char*>()(name), h);
return h;
}

std::string ToString() const {
std::ostringstream ostr;
ostr << "OrtMemoryInfo:["
@@ -51,6 +65,7 @@ struct OrtMemoryInfo {
}
};

// Required by hash tables
inline bool operator==(const OrtMemoryInfo& left, const OrtMemoryInfo& other) {
return left.mem_type == other.mem_type &&
left.alloc_type == other.alloc_type &&
@@ -61,3 +76,12 @@ inline bool operator==(const OrtMemoryInfo& left, const OrtMemoryInfo& other) {
inline bool operator!=(const OrtMemoryInfo& lhs, const OrtMemoryInfo& rhs) { return !(lhs == rhs); }

std::ostream& operator<<(std::ostream& out, const OrtMemoryInfo& info);

namespace std {
template<>
struct hash<OrtMemoryInfo> {
size_t operator()(const OrtMemoryInfo& i) const {
return i.Hash();
}
};
}
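
A quick, hedged illustration of what the new std::hash specialization enables (the map and helper below are hypothetical and not part of onnxruntime; they only show OrtMemoryInfo being used as a hash-table key):

#include <cstddef>
#include <unordered_map>
// Assumes ortmemoryinfo.h is included so the std::hash<OrtMemoryInfo>
// specialization and the operator== above are visible.

// Hypothetical bookkeeping map keyed by memory location.
std::unordered_map<OrtMemoryInfo, std::size_t> bytes_allocated_per_location;

void RecordAllocation(const OrtMemoryInfo& location, std::size_t num_bytes) {
  // Bucket selection goes through std::hash<OrtMemoryInfo>, which forwards
  // to OrtMemoryInfo::Hash(); collisions are resolved with the operator== above.
  bytes_allocated_per_location[location] += num_bytes;
}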
3 changes: 2 additions & 1 deletion include/onnxruntime/core/graph/graph_viewer.h
@@ -208,7 +208,8 @@ class GraphViewer {
// if we're limiting the view to an IndexedSubGraph we need to create a few pieces of infrastructure that would
// usually come from the full graph
const IndexedSubGraph* filter_info_{nullptr};
std::unordered_set<NodeIndex> filtered_node_indices_;
using FilteredNodeSet = InlinedHashSet<NodeIndex>;
FilteredNodeSet filtered_node_indices_;
std::vector<const NodeArg*> filtered_node_inputs_;
std::vector<const NodeArg*> filtered_node_inputs_including_initializers_;
std::vector<const NodeArg*> filtered_node_outputs_;
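Note on the container swap above: InlinedHashSet is one of onnxruntime's allocation-reducing aliases (declared in the inlined-containers header); its exact backing type is not shown in this diff. The sketch below is only an assumed illustration of the reserve-then-insert pattern such a flat set supports, using a hypothetical helper rather than the real GraphViewer constructor code:

// Hypothetical helper, not part of GraphViewer. Assumes the onnxruntime
// headers defining NodeIndex, IndexedSubGraph, and InlinedHashSet are
// included, and that InlinedHashSet supports reserve() like a flat hash set.
void PopulateFilteredNodes(const onnxruntime::IndexedSubGraph& filter_info,
                           onnxruntime::InlinedHashSet<onnxruntime::NodeIndex>& filtered_node_indices) {
  // Size the set once up front so inserting the node indices does not
  // trigger repeated rehash/regrow allocations.
  filtered_node_indices.reserve(filter_info.nodes.size());
  for (onnxruntime::NodeIndex node_index : filter_info.nodes) {
    filtered_node_indices.insert(node_index);
  }
}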
23 changes: 13 additions & 10 deletions onnxruntime/core/framework/allocation_planner.cc
@@ -77,13 +77,16 @@ std::ostream& operator<<(std::ostream& out, std::pair<const SequentialExecutionP
const SequentialExecutionPlan& plan = *planinfo.first;
const SessionState& session_state = *planinfo.second;
auto& graph = session_state.GetGraphViewer();
std::unordered_map<int, std::string> index_to_name;

const auto& name_idx_map = session_state.GetOrtValueNameIdxMap();
InlinedHashMap<int, std::string_view> index_to_name;
index_to_name.reserve(name_idx_map.Size());

out << "Allocation Plan:\n";
out << "(ort_value_idx) output_name : <allocation plan>\n";
auto plan_size = plan.allocation_plan.size();

for (auto& name_index : session_state.GetOrtValueNameIdxMap()) {
for (auto& name_index : name_idx_map) {
auto index = name_index.second;
index_to_name[index] = name_index.first;
out << "(" << index << ") " << name_index.first << " : ";
@@ -141,10 +144,10 @@ static const KernelCreateInfo& GetKernelCreateInfo(
class PlannerImpl {
public:
PlannerImpl(const Node* parent_node, const onnxruntime::GraphViewer& graph_viewer,
const std::vector<const NodeArg*>& outer_scope_node_args, const ExecutionProviders& providers,
gsl::span<const NodeArg* const> outer_scope_node_args, const ExecutionProviders& providers,
const KernelCreateInfoMap& kernel_create_info_map,
const SubgraphsKernelCreateInfoMaps& subgraphs_kernel_create_info_maps,
const std::unordered_map<OrtValueName, OrtMemoryInfo>& outer_scope_node_arg_to_location_map,
const InlinedHashMap<OrtValueName, OrtMemoryInfo>& outer_scope_node_arg_to_location_map,
const OrtValueNameIdxMap& ort_value_name_idx_map,
const ISequentialPlannerContext& context, SequentialExecutionPlan& plan)
: context_(context),
@@ -166,13 +169,13 @@ class PlannerImpl {

const Node* parent_node_;
const onnxruntime::GraphViewer& graph_viewer_;
const std::vector<const NodeArg*>& outer_scope_node_args_;
gsl::span<const NodeArg* const> outer_scope_node_args_;
const ExecutionProviders& execution_providers_;

const KernelCreateInfoMap& kernel_create_info_map_;
const SubgraphsKernelCreateInfoMaps& subgraphs_kernel_create_info_maps_;

const std::unordered_map<OrtValueName, OrtMemoryInfo>& outer_scope_node_arg_to_location_map_;
const InlinedHashMap<OrtValueName, OrtMemoryInfo>& outer_scope_node_arg_to_location_map_;

const OrtValueNameIdxMap& ort_value_name_idx_map_;

@@ -1331,16 +1334,16 @@ Status PlannerImpl::CreatePlan() {
Status SequentialPlanner::CreatePlan(
const Node* parent_node,
const onnxruntime::GraphViewer& graph_viewer,
const std::vector<const NodeArg*>& outer_scope_node_args,
gsl::span<const NodeArg* const> outer_scope_node_args,
const ExecutionProviders& providers,
const KernelCreateInfoMap& kernel_create_info_map,
const SubgraphsKernelCreateInfoMaps& subgraphs_kernel_create_info_maps,
const std::unordered_map<OrtValueName, OrtMemoryInfo>& outer_scope_node_arg_to_location_map,
const InlinedHashMap<OrtValueName, OrtMemoryInfo>& outer_scope_node_arg_to_location_map,
const OrtValueNameIdxMap& ort_value_name_idx_map,
const ISequentialPlannerContext& context,
std::unique_ptr<SequentialExecutionPlan>& plan) {
std::optional<SequentialExecutionPlan>& plan) {
// allocate/reset here so we know it's clean
plan = std::make_unique<SequentialExecutionPlan>();
plan.emplace();

PlannerImpl planner(parent_node, graph_viewer, outer_scope_node_args, providers,
kernel_create_info_map, subgraphs_kernel_create_info_maps,
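For readers skimming the signature changes in this file and in execution_frame.cc below: passing gsl::span<const T> instead of const std::vector<T>& lets callers hand over any contiguous buffer (an inlined or stack-allocated container as well as an existing vector) without first materializing a std::vector. A minimal, self-contained sketch of that idea, assuming the GSL span used by onnxruntime:

#include <array>
#include <vector>
#include <gsl/gsl>

// Hypothetical helper mirroring the parameter style above: it reads a
// contiguous range of indices without caring how the caller stores them.
int SumIndices(gsl::span<const int> indices) {
  int sum = 0;
  for (int i : indices) sum += i;
  return sum;
}

void Caller() {
  std::array<int, 3> on_stack{1, 2, 3};  // binds to the span directly, no heap use
  std::vector<int> on_heap{4, 5};        // existing vectors also bind, no copy
  int total = SumIndices(on_stack) + SumIndices(on_heap);
  (void)total;
}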
6 changes: 3 additions & 3 deletions onnxruntime/core/framework/allocation_planner.h
@@ -76,14 +76,14 @@ class SequentialPlanner {
// This API allows user to provide a custom planner context.
static Status CreatePlan(
const Node* parent_node, const onnxruntime::GraphViewer& graph,
const std::vector<const NodeArg*>& outer_scope_node_args,
gsl::span<const NodeArg* const> outer_scope_node_args,
const ExecutionProviders& providers,
const KernelCreateInfoMap& kernel_create_info_map,
const SubgraphsKernelCreateInfoMaps& subgraphs_kernel_create_info_maps,
const std::unordered_map<OrtValueName, OrtMemoryInfo>& outer_scope_arg_to_location_map,
const InlinedHashMap<OrtValueName, OrtMemoryInfo>& outer_scope_arg_to_location_map,
const OrtValueNameIdxMap& ort_value_name_idx_map,
const ISequentialPlannerContext& context,
std::unique_ptr<SequentialExecutionPlan>& plan);
std::optional<SequentialExecutionPlan>& plan);
};

} // namespace onnxruntime
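
The plan output parameter changes from std::unique_ptr to std::optional in both this declaration and the definition above, so the plan is constructed in storage the caller already owns instead of on the heap. A toy sketch of the pattern with stand-in types (not the real onnxruntime classes):

#include <optional>
#include <string>

// Stand-in for SequentialExecutionPlan, for illustration only.
struct PlanSketch {
  std::string debug_name;
};

// Mirrors the CreatePlan style: the callee fills caller-owned storage.
void CreatePlanSketch(std::optional<PlanSketch>& plan) {
  plan.emplace();             // construct in place; no separate heap allocation
  plan->debug_name = "main";  // hypothetical field
}

void Caller() {
  std::optional<PlanSketch> plan;  // e.g. a member that the session keeps alive
  CreatePlanSketch(plan);
  if (plan.has_value()) {
    // use *plan ...
  }
}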
61 changes: 32 additions & 29 deletions onnxruntime/core/framework/execution_frame.cc
@@ -27,10 +27,10 @@ namespace onnxruntime {

IExecutionFrame::IExecutionFrame(const OrtValueNameIdxMap& ort_value_idx_map,
const NodeIndexInfo& node_index_info,
const std::vector<int>& fetch_mlvalue_idxs)
gsl::span<const int> fetch_mlvalue_idxs)
: node_index_info_(node_index_info),
all_values_size_(static_cast<size_t>(ort_value_idx_map.MaxIdx()) + 1),
fetch_mlvalue_idxs_(fetch_mlvalue_idxs),
fetch_mlvalue_idxs_(fetch_mlvalue_idxs.begin(), fetch_mlvalue_idxs.end()),
ort_value_idx_map_(ort_value_idx_map) {
ORT_ENFORCE(node_index_info_.GetMaxMLValueIdx() == ort_value_idx_map.MaxIdx(),
"node_index_info and ort_value_idx_map are out of sync and cannot be used");
@@ -55,7 +55,7 @@ Status IExecutionFrame::SetOutputMLValue(int index, const OrtValue& ort_value) {
#endif

#ifdef ENABLE_TRAINING
void IExecutionFrame::UpdateFeeds(const std::vector<int>& feed_mlvalue_idxs, const std::vector<OrtValue>& feeds) {
void IExecutionFrame::UpdateFeeds(gsl::span<const int> feed_mlvalue_idxs, gsl::span<const OrtValue> feeds) {
ORT_ENFORCE(feed_mlvalue_idxs.size() == feeds.size());

for (size_t idx = 0, end = feed_mlvalue_idxs.size(); idx < end; ++idx) {
@@ -68,11 +68,12 @@ void IExecutionFrame::UpdateFeeds(const std::vector<int>& feed_mlvalue_idxs, con
}
}

void IExecutionFrame::UpdateFetches(const std::vector<int>& fetch_mlvalue_idxs, const std::vector<OrtValue>& fetches, const std::unordered_map<int, OrtValue>& initializers) {
void IExecutionFrame::UpdateFetches(gsl::span<const int> fetch_mlvalue_idxs,
gsl::span<const OrtValue> fetches, const std::unordered_map<int, OrtValue>& initializers) {
ORT_ENFORCE(fetch_mlvalue_idxs.size() == fetches.size());

if (!fetches.empty()) {
fetch_mlvalue_idxs_ = fetch_mlvalue_idxs;
fetch_mlvalue_idxs_.assign(fetch_mlvalue_idxs.begin(), fetch_mlvalue_idxs.end());

auto num_fetches = fetch_mlvalue_idxs_.size();

@@ -102,7 +103,7 @@ void IExecutionFrame::UpdateFetches(const std::vector<int>& fetch_mlvalue_idxs,
}
}

Status IExecutionFrame::GetOutputs(const std::vector<int>& fetch_mlvalue_idxs, std::vector<OrtValue>& fetches) {
Status IExecutionFrame::GetOutputs(gsl::span<const int> fetch_mlvalue_idxs, std::vector<OrtValue>& fetches) {
auto num_fetches = fetch_mlvalue_idxs.size();

if (fetches.empty()) {
@@ -213,10 +214,10 @@ int IExecutionFrame::GetNodeIdxToMLValueIdx(int index) const {
return ort_value_idx;
}

void IExecutionFrame::Init(const std::vector<int>& feed_mlvalue_idxs, const std::vector<OrtValue>& feeds,
void IExecutionFrame::Init(gsl::span<const int> feed_mlvalue_idxs, gsl::span<const OrtValue> feeds,
const std::unordered_map<int, OrtValue>& initializers,
const std::function<bool(const std::string& name)>& is_initializer_sparse_func,
const std::vector<OrtValue>& fetches) {
gsl::span<const OrtValue> fetches) {
ORT_ENFORCE(feeds.size() == feed_mlvalue_idxs.size());
ORT_ENFORCE(fetches.empty() || fetches.size() == fetch_mlvalue_idxs_.size());

@@ -280,7 +281,7 @@ void IExecutionFrame::Init(const std::vector<int>& feed_mlvalue_idxs, const std:
*dest.GetMutable<SparseTensor>()));
} else {
#else
ORT_UNUSED_PARAMETER(is_initializer_sparse_func);
ORT_UNUSED_PARAMETER(is_initializer_sparse_func);
#endif // !defined(DISABLE_SPARSE_TENSORS)
if (!dest.IsAllocated()) {
// NOTE: This doesn't need to support ExecutionFrame custom allocators as they only come into play
@@ -331,14 +332,13 @@ bool IExecutionFrame::IsOutput(int ort_value_idx) const {
return std::find(fetch_mlvalue_idxs_.begin(), fetch_mlvalue_idxs_.end(), ort_value_idx) != fetch_mlvalue_idxs_.end();
}

ExecutionFrame::ExecutionFrame(const std::vector<int>& feed_mlvalue_idxs, const std::vector<OrtValue>& feeds,
const std::vector<int>& fetch_mlvalue_idxs, const std::vector<OrtValue>& fetches,
ExecutionFrame::ExecutionFrame(gsl::span<const int> feed_mlvalue_idxs, gsl::span<const OrtValue> feeds,
gsl::span<const int> fetch_mlvalue_idxs, gsl::span<const OrtValue> fetches,
const std::unordered_map<size_t, IExecutor::CustomAllocator>& fetch_allocators,
const SessionState& session_state)
: IExecutionFrame(session_state.GetOrtValueNameIdxMap(), session_state.GetNodeIndexInfo(), fetch_mlvalue_idxs),
session_state_(session_state),
mem_patterns_(nullptr),
planner_(nullptr) {
mem_patterns_(nullptr) {
Init(
feed_mlvalue_idxs, feeds, session_state.GetInitializedTensors(),
#if !defined(DISABLE_SPARSE_TENSORS)
@@ -362,12 +362,12 @@ ExecutionFrame::ExecutionFrame(const std::vector<int>& feed_mlvalue_idxs, const

// map the custom allocators to ort_value_idx entries
if (!fetch_allocators.empty()) {
for (size_t idx = 0, end = fetch_mlvalue_idxs.size(); idx < end; ++idx) {
int ort_value_idx = fetch_mlvalue_idxs[idx];

auto custom_alloc_entry = fetch_allocators.find(idx);
if (custom_alloc_entry != fetch_allocators.cend()) {
custom_allocators_[ort_value_idx] = custom_alloc_entry->second;
custom_allocators_.reserve(fetch_allocators.size());
const auto idx_size = fetch_mlvalue_idxs.size();
for (const auto& e : fetch_allocators) {
if (e.first < idx_size) {
int ort_value_idx = fetch_mlvalue_idxs[e.first];
custom_allocators_.insert_or_assign(ort_value_idx, e.second);
}
}
}
@@ -388,12 +388,13 @@ ExecutionFrame::ExecutionFrame(const std::vector<int>& feed_mlvalue_idxs, const
//if there are some traditional ml value type in inputs disable the memory pattern optimization.
if (all_tensors) {
mem_patterns_ = session_state.GetMemoryPatternGroup(feeds, feed_mlvalue_idxs, inferred_shapes_);
// if no existing patterns, generate one in this executionframe
// if no existing patterns, generate one in this execution frame
if (!mem_patterns_) {
planner_ = std::make_unique<OrtValuePatternPlanner>(*session_state.GetExecutionPlan());
planner_.emplace(*session_state.GetExecutionPlan());
} else {
// pre-allocate the big chunk requested in memory pattern.
// all the internal kernel's input/output tensors will be allocated on these buffer.
buffers_.reserve(mem_patterns_->locations.size());
for (size_t i = 0; i < mem_patterns_->locations.size(); i++) {
const auto& location = mem_patterns_->locations[i];
ORT_ENFORCE(buffers_.find(location) == buffers_.end());
@@ -828,7 +829,7 @@ const AllocPlanPerValue& ExecutionFrame::GetAllocationPlan(int ort_value_idx) {
}

void ExecutionFrame::TraceAllocate(int ort_value_idx, size_t size) {
if (planner_) {
if (planner_.has_value()) {
// don't trace the output tensors or external outputs.
auto& allocation_plan = GetAllocationPlan(ort_value_idx);
if (allocation_plan.alloc_kind == AllocKind::kAllocateOutput ||
@@ -845,7 +846,7 @@ void ExecutionFrame::TraceFree(int ort_value_idx) {

void ExecutionFrame::TraceFree(int ort_value_idx) {
// don't trace free on output tensors.
if (planner_ && !IsOutput(ort_value_idx)) {
if (planner_.has_value() && !IsOutput(ort_value_idx)) {
const SequentialExecutionPlan* p_seq_exec_plan = session_state_.GetExecutionPlan();
const auto& alloc_plan = p_seq_exec_plan->allocation_plan;
ORT_ENFORCE(ort_value_idx >= 0 && static_cast<size_t>(ort_value_idx) < alloc_plan.size());
@@ -870,8 +871,8 @@ void ExecutionFrame::TraceFree(int ort_value_idx) {

// generate memory pattern based on the tracing of memory allocation/free in current execution
// return error if the planner is not setup.
Status ExecutionFrame::GeneratePatterns(MemoryPatternGroup* out) const {
if (!planner_) {
Status ExecutionFrame::GeneratePatterns(MemoryPatternGroup& out) {
if (!planner_.has_value()) {
return Status(ONNXRUNTIME, FAIL, "Memory pattern planner is not enabled on this execution framework.");
}

@@ -889,10 +890,12 @@ bool ExecutionFrame::TryGetInferredShape(int index, TensorShape& shape) const {

// Search for inferred shape.
// If inferred shape is found, it's assigned to "shape" so that caller can use it.
auto it = inferred_shapes_.find(ort_value_idx);
if (it != inferred_shapes_.end()) {
shape = it->second;
return true;
if (inferred_shapes_ != nullptr) {
auto it = inferred_shapes_->find(ort_value_idx);
if (it != inferred_shapes_->end()) {
shape = it->second;
return true;
}
}

// Tell the caller if the search is successful or not.
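One last allocation-oriented detail from execution_frame.cc above: the custom-allocator mapping now reserves the destination and iterates the (typically small) fetch_allocators map once, instead of scanning every fetch index and calling find() per index. A simplified sketch with stand-in types (the real members use onnxruntime's containers and IExecutor::CustomAllocator):

#include <cstddef>
#include <functional>
#include <unordered_map>
#include <vector>

using CustomAllocatorFn = std::function<void()>;  // stand-in for IExecutor::CustomAllocator

void MapCustomAllocators(const std::vector<int>& fetch_mlvalue_idxs,
                         const std::unordered_map<std::size_t, CustomAllocatorFn>& fetch_allocators,
                         std::unordered_map<int, CustomAllocatorFn>& custom_allocators) {
  // Reserve once, then walk the smaller map; out-of-range fetch positions are skipped.
  custom_allocators.reserve(fetch_allocators.size());
  const std::size_t idx_size = fetch_mlvalue_idxs.size();
  for (const auto& entry : fetch_allocators) {
    if (entry.first < idx_size) {
      const int ort_value_idx = fetch_mlvalue_idxs[entry.first];
      custom_allocators.insert_or_assign(ort_value_idx, entry.second);
    }
  }
}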
