
Commit c60d818

yiming0416 authored and pytorchmergebot committed
[nativert] Move GraphExecutorBase to PyTorch core (#156196)
Summary:
Moves the GraphExecutorBase class to PyTorch core. GraphExecutorBase is a lightweight abstraction to execute a graph with execution frames without actually owning the graph or the weights. It is introduced to decouple the state management of the top-level runtime from the kernel executions so that sub graphs from higher order ops can be supported.

Torch Native Runtime RFC: pytorch/rfcs#72

Test Plan: CI

Rollback Plan:

Differential Revision: D76830436

Pull Request resolved: #156196
Approved by: https://github.com/zhxchen17
1 parent 34d8e64 commit c60d818
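
As context for the diff below, a minimal sketch of the calling pattern this abstraction targets (not part of this commit): one executor, built from a graph and its kernels, is shared across requests, while each request supplies its own ExecutionFrame that holds the mutable state. The runRequest helper is hypothetical.

#include <torch/nativert/executor/GraphExecutorBase.h>

// Hypothetical helper: `executor` is any concrete GraphExecutorBase subclass.
// The executor owns neither the graph nor the weights; per-request state
// (inputs, intermediates, outputs) lives in the caller's ExecutionFrame.
std::vector<c10::IValue> runRequest(
    torch::nativert::GraphExecutorBase& executor,
    torch::nativert::ExecutionFrame& frame,
    std::vector<c10::IValue> inputs) {
  return executor.execute(frame, std::move(inputs));
}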

File tree

build_variables.bzl
torch/nativert/executor/GraphExecutorBase.cpp
torch/nativert/executor/GraphExecutorBase.h

3 files changed: +200 -0 lines changed

build_variables.bzl

Lines changed: 1 addition & 0 deletions
@@ -600,6 +600,7 @@ libtorch_nativert_sources = [
     "torch/nativert/executor/Placement.cpp",
     "torch/nativert/executor/ExecutionPlanner.cpp",
     "torch/nativert/executor/ExecutionFrame.cpp",
+    "torch/nativert/executor/GraphExecutorBase.cpp",
     "torch/nativert/executor/OpKernel.cpp",
     "torch/nativert/executor/PlacementUtils.cpp",
     "torch/nativert/executor/Weights.cpp",
torch/nativert/executor/GraphExecutorBase.cpp

Lines changed: 120 additions & 0 deletions
@@ -0,0 +1,120 @@
#include <ATen/record_function.h>
#include <torch/nativert/executor/GraphExecutorBase.h>

#include <c10/util/Logging.h>
#include <caffe2/core/timer.h>

namespace torch::nativert {

GraphExecutorBase::GraphExecutorBase(
    const Graph& graph,
    std::vector<std::unique_ptr<OpKernel>> nodeKernels,
    const ExecutorConfig& executorConfig)
    : graph_(graph),
      nodeKernels_(std::move(nodeKernels)),
      executorConfig_(executorConfig),
      execPlan_(ExecutionPlanner{graph_}.createPlan()) {};

void GraphExecutorBase::fillUserInputs(
    ExecutionFrame& frame,
    std::vector<c10::IValue> inputs) {
  RECORD_USER_SCOPE("Executor::fillUserInputs");
  const auto& inputValues = graph_.userInputs();
  TORCH_CHECK_EQ(inputValues.size(), inputs.size());

  // load user input tensor into execution frame
  for (size_t i = 0; i < inputValues.size(); i++) {
    if (inputValues[i]) {
      frame.setIValue(inputValues[i]->id(), std::move(inputs[i]));
    }
  }
}

ProfileMetrics GraphExecutorBase::benchmarkIndividualNodes(
    ExecutionFrame& executionFrame,
    std::vector<std::vector<c10::IValue>> inputsList,
    const uint32_t warmupRuns,
    const uint32_t mainRuns) {
  // TODO: add support for memory profiling
  TORCH_CHECK(warmupRuns >= 1 && mainRuns >= 1);

  ProfileMetrics results;
  const auto numNodes = static_cast<uint32_t>(nodeKernels_.size());
  results.timePerNode.resize(numNodes, 0);
  if (inputsList.empty()) {
    auto i = 0;
    for (const auto& nodeKernel : nodeKernels_) {
      std::string target(nodeKernel->node()->target());
      results.timePerNode[i] = 0;
      results.timePerNodeType[target] = 0;
      results.instancesPerNodeType[target]++;
      if (nodeKernel->hasPrimKernel()) {
        results.primNodesCount++;
        results.primNodes.insert(target);
      } else if (nodeKernel->hasStaticDispatch()) {
        results.staticDispatchNodesCount++;
        results.staticDispatchNodes.insert(target);
      }
      i++;
    }
    results.totalNodesCount = numNodes;
    for (const auto& p : results.timePerNodeType) {
      const std::string& kind = p.first;
      results.percentPerNodeType[kind] = 0;
    }
    return results;
  }

  // Warmup
  for (uint32_t i = 0; i < warmupRuns; i++) {
    for (const auto& inputs : inputsList) {
      execute(executionFrame, inputs);
    }
  }

  // Execute kernels
  caffe2::Timer timer;
  for (uint32_t i = 0; i < mainRuns; i++) {
    for (auto inputs : inputsList) {
      const auto& inputValues = graph_.userInputs();

      TORCH_CHECK_EQ(inputValues.size(), inputs.size());
      for (size_t j = 0; j < inputValues.size(); j++) {
        executionFrame.setIValue(inputValues[j]->id(), std::move(inputs[j]));
      }
      for (NodeIndex nodeIdx = 0; nodeIdx < nodeKernels_.size(); ++nodeIdx) {
        timer.Start();
        nodeKernels_[nodeIdx]->compute(executionFrame);
        float millis = timer.MilliSeconds();
        results.timePerNode[nodeIdx] += millis;
      }
    }
  }

  // Summarize results
  const float numTotalIters =
      (static_cast<float>(mainRuns) * static_cast<float>(inputsList.size()));
  for (const auto i : c10::irange(numNodes)) {
    const Node* node = nodeKernels_[i]->node();
    std::string target(node->target());
    results.timePerNode[i] /= numTotalIters;
    results.timePerNodeType[target] += results.timePerNode[i];
    results.instancesPerNodeType[target]++;
    if (nodeKernels_[i]->hasPrimKernel()) {
      results.primNodes.insert(target);
      results.primNodesCount++;
    } else if (nodeKernels_[i]->hasStaticDispatch()) {
      results.staticDispatchNodes.insert(target);
      results.staticDispatchNodesCount++;
    }
    results.totalTime += results.timePerNode[i];
  }
  results.totalNodesCount = numNodes;
  for (const auto& r : results.timePerNodeType) {
    const std::string& target = r.first;
    results.percentPerNodeType[target] = r.second * 100.0 / results.totalTime;
  }
  return results;
}

} // namespace torch::nativert
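
As an aside, a hedged usage sketch for benchmarkIndividualNodes() above (not from this commit): the reportNodeCosts helper and the run counts are invented for illustration, and executor, frame, and sampleInputs are assumed to be constructed elsewhere.

#include <iostream>

#include <torch/nativert/executor/GraphExecutorBase.h>

// Hypothetical helper that averages per-node time over the given input sets
// and prints the per-node-type share of total time from ProfileMetrics.
void reportNodeCosts(
    torch::nativert::GraphExecutorBase& executor,
    torch::nativert::ExecutionFrame& frame,
    std::vector<std::vector<c10::IValue>> sampleInputs) {
  const auto metrics = executor.benchmarkIndividualNodes(
      frame, std::move(sampleInputs), /*warmupRuns=*/2, /*mainRuns=*/10);

  std::cout << "total nodes: " << metrics.totalNodesCount
            << ", total time (ms): " << metrics.totalTime << '\n';
  for (const auto& [target, pct] : metrics.percentPerNodeType) {
    std::cout << target << ": " << pct << "% ("
              << metrics.instancesPerNodeType.at(target) << " instances)\n";
  }
}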
torch/nativert/executor/GraphExecutorBase.h

Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
#pragma once

#include <torch/nativert/executor/ExecutionFrame.h>
#include <torch/nativert/executor/ExecutionPlanner.h>
#include <torch/nativert/executor/ExecutorConfig.h>
#include <torch/nativert/executor/OpKernel.h>
#include <torch/nativert/graph/Graph.h>
#include <torch/nativert/graph/GraphSignature.h>

namespace torch::nativert {

struct ProfileMetrics {
  size_t primNodesCount{0};
  size_t staticDispatchNodesCount{0};
  size_t totalNodesCount{0};
  std::vector<float> timePerNode;
  std::unordered_map<std::string, float> timePerNodeType;
  std::unordered_map<std::string, float> percentPerNodeType;
  std::unordered_map<std::string, int> instancesPerNodeType;
  std::unordered_set<std::string> staticDispatchNodes;
  std::unordered_set<std::string> primNodes;
  float totalTime{0};
};

/**
 * GraphExecutorBase is a lightweight abstraction to execute a graph with
 * execution frames without actually owning the graph or the weights. This is
 * introduced to decouple the state management of the top-level runtime from
 * the kernel executions so that sub graphs from higher order ops can be
 * supported.
 */
class GraphExecutorBase {
 public:
  GraphExecutorBase(
      const Graph& graph,
      std::vector<std::unique_ptr<OpKernel>> nodeKernels,
      const ExecutorConfig& executorConfig);
  virtual ~GraphExecutorBase() = default;

  const Graph& graph() const {
    return graph_;
  }

  // This API only returns the flattened UserOutputs,
  // intended to be used for the inference path.
  virtual std::vector<c10::IValue> execute(
      ExecutionFrame& frame,
      std::vector<c10::IValue> inputs) = 0;

  virtual std::vector<c10::IValue> executeWithPrefilledFrame(
      ExecutionFrame& frame) = 0;

  ProfileMetrics benchmarkIndividualNodes(
      ExecutionFrame& executionFrame,
      std::vector<std::vector<c10::IValue>> inputs,
      const uint32_t warmup_runs,
      const uint32_t main_runs);

  std::vector<std::unique_ptr<OpKernel>> stealKernels() {
    return std::move(nodeKernels_);
  }

  void setKernels(std::vector<std::unique_ptr<OpKernel>>&& kernels) {
    nodeKernels_ = std::move(kernels);
  }

 protected:
  void fillUserInputs(ExecutionFrame& frame, std::vector<c10::IValue> inputs);

  const Graph& graph_;

  // cache of the constructed kernels to avoid reconstruction per execution
  std::vector<std::unique_ptr<OpKernel>> nodeKernels_;

  const ExecutorConfig& executorConfig_;

  std::unique_ptr<ExecutionPlan> execPlan_;
};

} // namespace torch::nativert
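
To make the division of labor concrete, here is a minimal sketch (not part of this commit) of what a concrete subclass might look like. The class name is hypothetical; it fills the frame through the protected fillUserInputs() and runs the cached kernels in their stored order. Reading the flattened user outputs back from the frame depends on ExecutionFrame's interface, which this diff does not show, so that step is left as a comment.

#include <torch/nativert/executor/GraphExecutorBase.h>

namespace torch::nativert {

// Hypothetical subclass, for illustration only.
class NaiveSerialExecutor : public GraphExecutorBase {
 public:
  using GraphExecutorBase::GraphExecutorBase;

  std::vector<c10::IValue> execute(
      ExecutionFrame& frame,
      std::vector<c10::IValue> inputs) override {
    // Copy the caller's inputs into the frame, then run every kernel.
    fillUserInputs(frame, std::move(inputs));
    return executeWithPrefilledFrame(frame);
  }

  std::vector<c10::IValue> executeWithPrefilledFrame(
      ExecutionFrame& frame) override {
    // nodeKernels_ is the kernel cache owned by GraphExecutorBase.
    for (auto& kernel : nodeKernels_) {
      kernel->compute(frame);
    }
    // Collecting the flattened user outputs from the frame is omitted here;
    // it depends on ExecutionFrame's interface, which is outside this diff.
    return {};
  }
};

} // namespace torch::nativert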
