|
| 1 | +#include <ATen/record_function.h> |
| 2 | +#include <torch/nativert/executor/GraphExecutorBase.h> |
| 3 | + |
| 4 | +#include <c10/util/Logging.h> |
| 5 | +#include <caffe2/core/timer.h> |
| 6 | + |
| 7 | +namespace torch::nativert { |
| 8 | + |
| 9 | +GraphExecutorBase::GraphExecutorBase( |
| 10 | + const Graph& graph, |
| 11 | + std::vector<std::unique_ptr<OpKernel>> nodeKernels, |
| 12 | + const ExecutorConfig& executorConfig) |
| 13 | + : graph_(graph), |
| 14 | + nodeKernels_(std::move(nodeKernels)), |
| 15 | + executorConfig_(executorConfig), |
| 16 | + execPlan_(ExecutionPlanner{graph_}.createPlan()) {}; |
| 17 | + |
| 18 | +void GraphExecutorBase::fillUserInputs( |
| 19 | + ExecutionFrame& frame, |
| 20 | + std::vector<c10::IValue> inputs) { |
| 21 | + RECORD_USER_SCOPE("Executor::fillUserInputs"); |
| 22 | + const auto& inputValues = graph_.userInputs(); |
| 23 | + TORCH_CHECK_EQ(inputValues.size(), inputs.size()); |
| 24 | + |
| 25 | + // load user input tensor into execution frame |
| 26 | + for (size_t i = 0; i < inputValues.size(); i++) { |
| 27 | + if (inputValues[i]) { |
| 28 | + frame.setIValue(inputValues[i]->id(), std::move(inputs[i])); |
| 29 | + } |
| 30 | + } |
| 31 | +} |
| 32 | + |
| 33 | +ProfileMetrics GraphExecutorBase::benchmarkIndividualNodes( |
| 34 | + ExecutionFrame& executionFrame, |
| 35 | + std::vector<std::vector<c10::IValue>> inputsList, |
| 36 | + const uint32_t warmupRuns, |
| 37 | + const uint32_t mainRuns) { |
| 38 | + // TODO: add support for memory profiling |
| 39 | + TORCH_CHECK(warmupRuns >= 1 && mainRuns >= 1); |
| 40 | + |
| 41 | + ProfileMetrics results; |
| 42 | + const auto numNodes = static_cast<uint32_t>(nodeKernels_.size()); |
| 43 | + results.timePerNode.resize(numNodes, 0); |
| 44 | + if (inputsList.empty()) { |
| 45 | + auto i = 0; |
| 46 | + for (const auto& nodeKernel : nodeKernels_) { |
| 47 | + std::string target(nodeKernel->node()->target()); |
| 48 | + results.timePerNode[i] = 0; |
| 49 | + results.timePerNodeType[target] = 0; |
| 50 | + results.instancesPerNodeType[target]++; |
| 51 | + if (nodeKernel->hasPrimKernel()) { |
| 52 | + results.primNodesCount++; |
| 53 | + results.primNodes.insert(target); |
| 54 | + } else if (nodeKernel->hasStaticDispatch()) { |
| 55 | + results.staticDispatchNodesCount++; |
| 56 | + results.staticDispatchNodes.insert(target); |
| 57 | + } |
| 58 | + i++; |
| 59 | + } |
| 60 | + results.totalNodesCount = numNodes; |
| 61 | + for (const auto& p : results.timePerNodeType) { |
| 62 | + const std::string& kind = p.first; |
| 63 | + results.percentPerNodeType[kind] = 0; |
| 64 | + } |
| 65 | + return results; |
| 66 | + } |
| 67 | + |
| 68 | + // Warmup |
| 69 | + for (uint32_t i = 0; i < warmupRuns; i++) { |
| 70 | + for (const auto& inputs : inputsList) { |
| 71 | + execute(executionFrame, inputs); |
| 72 | + } |
| 73 | + } |
| 74 | + |
| 75 | + // Execute kernels |
| 76 | + caffe2::Timer timer; |
| 77 | + for (uint32_t i = 0; i < mainRuns; i++) { |
| 78 | + for (auto inputs : inputsList) { |
| 79 | + const auto& inputValues = graph_.userInputs(); |
| 80 | + |
| 81 | + TORCH_CHECK_EQ(inputValues.size(), inputs.size()); |
| 82 | + for (size_t j = 0; j < inputValues.size(); j++) { |
| 83 | + executionFrame.setIValue(inputValues[j]->id(), std::move(inputs[j])); |
| 84 | + } |
| 85 | + for (NodeIndex nodeIdx = 0; nodeIdx < nodeKernels_.size(); ++nodeIdx) { |
| 86 | + timer.Start(); |
| 87 | + nodeKernels_[nodeIdx]->compute(executionFrame); |
| 88 | + float millis = timer.MilliSeconds(); |
| 89 | + results.timePerNode[nodeIdx] += millis; |
| 90 | + } |
| 91 | + } |
| 92 | + } |
| 93 | + |
| 94 | + // Summarize results |
| 95 | + const float numTotalIters = |
| 96 | + (static_cast<float>(mainRuns) * static_cast<float>(inputsList.size())); |
| 97 | + for (const auto i : c10::irange(numNodes)) { |
| 98 | + const Node* node = nodeKernels_[i]->node(); |
| 99 | + std::string target(node->target()); |
| 100 | + results.timePerNode[i] /= numTotalIters; |
| 101 | + results.timePerNodeType[target] += results.timePerNode[i]; |
| 102 | + results.instancesPerNodeType[target]++; |
| 103 | + if (nodeKernels_[i]->hasPrimKernel()) { |
| 104 | + results.primNodes.insert(target); |
| 105 | + results.primNodesCount++; |
| 106 | + } else if (nodeKernels_[i]->hasStaticDispatch()) { |
| 107 | + results.staticDispatchNodes.insert(target); |
| 108 | + results.staticDispatchNodesCount++; |
| 109 | + } |
| 110 | + results.totalTime += results.timePerNode[i]; |
| 111 | + } |
| 112 | + results.totalNodesCount = numNodes; |
| 113 | + for (const auto& r : results.timePerNodeType) { |
| 114 | + const std::string& target = r.first; |
| 115 | + results.percentPerNodeType[target] = r.second * 100.0 / results.totalTime; |
| 116 | + } |
| 117 | + return results; |
| 118 | +} |
| 119 | + |
| 120 | +} // namespace torch::nativert |
0 commit comments