Skip to content

Commit

Permalink
[kineto] Optimize getStepCallbacks for common case of no active callbacks
Browse files Browse the repository at this point in the history

Pull Request resolved: pytorch#77804

If I understand correctly, the result of getStepCallbacks will be empty and unused if there are no sampled callbacks, which is the common case. We can accelerate this case by wrapping the result in an optional, which avoids initializing an empty SmallVector.

Differential Revision: [D36497279](https://our.internmc.facebook.com/intern/diff/D36497279/)

**NOTE FOR REVIEWERS**: This PR has internal Facebook specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D36497279/)!

Approved by: https://github.com/robieta
  • Loading branch information
swolchok authored and pytorchmergebot committed May 24, 2022
1 parent 02c4d87 commit c083489
Show file tree
Hide file tree
Showing 7 changed files with 59 additions and 25 deletions.
12 changes: 6 additions & 6 deletions aten/src/ATen/core/dispatch/Dispatcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -545,9 +545,9 @@ C10_ALWAYS_INLINE_UNLESS_MOBILE Return Dispatcher::call(const TypedOperatorHandl
.template getDispatchKeySetUnboxed<Args...>(args...);
const KernelFunction& kernel = op.operatorDef_->op.lookup(dispatchKeySet);
#ifndef PYTORCH_DISABLE_PER_OP_PROFILING
auto step_callbacks = at::getStepCallbacks(at::RecordScope::FUNCTION);
if (C10_UNLIKELY(!step_callbacks.empty() && op.operatorDef_->op.isObserved())) {
return callWithDispatchKeySlowPath<Return, Args...>(op, step_callbacks, dispatchKeySet, kernel, std::forward<Args>(args)...);
auto step_callbacks = at::getStepCallbacksUnlessEmpty(at::RecordScope::FUNCTION);
if (C10_UNLIKELY(step_callbacks.has_value() && op.operatorDef_->op.isObserved())) {
return callWithDispatchKeySlowPath<Return, Args...>(op, *step_callbacks, dispatchKeySet, kernel, std::forward<Args>(args)...);
}
#endif // PYTORCH_DISABLE_PER_OP_PROFILING
return kernel.template call<Return, Args...>(op, dispatchKeySet, std::forward<Args>(args)...);
Expand All @@ -568,9 +568,9 @@ inline void Dispatcher::callBoxed(const OperatorHandle& op, Stack* stack) const
auto dispatchKeySet = entry.dispatchKeyExtractor().getDispatchKeySetBoxed(stack);
const auto& kernel = entry.lookup(dispatchKeySet);
#ifndef PYTORCH_DISABLE_PER_OP_PROFILING
auto step_callbacks = at::getStepCallbacks(at::RecordScope::FUNCTION);
if (C10_UNLIKELY(!step_callbacks.empty() && entry.isObserved())) {
at::RecordFunction guard(std::move(step_callbacks));
auto step_callbacks = at::getStepCallbacksUnlessEmpty(at::RecordScope::FUNCTION);
if (C10_UNLIKELY(step_callbacks.has_value() && entry.isObserved())) {
at::RecordFunction guard(std::move(*step_callbacks));
auto dispatchKey = dispatchKeySet.highestPriorityTypeId();
auto& schema = op.schema();
auto schema_ref = std::reference_wrapper<const FunctionSchema>(schema);
Expand Down
37 changes: 34 additions & 3 deletions aten/src/ATen/record_function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ class CacheEntry {
// The caller is expected to check `GlobalCallbackManager::get().version()'
// and call CacheEntry::update() if necessary.
StepCallbacks getActiveCallbacks();
c10::optional<StepCallbacks> getActiveCallbacksUnlessEmpty();

// Full rebuild. (E.g. during registration)
void update(const std::vector<RecordFunctionCallback>& callbacks);
Expand All @@ -142,6 +143,8 @@ class CacheEntry {
int tries_left_{-1};
};

C10_ALWAYS_INLINE void getActiveCallbacksImpl();

void rebuildActiveCallbacks();
int sampleTries(double p) const;

Expand Down Expand Up @@ -169,6 +172,7 @@ class LocalCallbackManager {
public:
const RecordFunctionTLS& getTLS() const;
StepCallbacks getActiveCallbacks(const RecordScope scope);
c10::optional<StepCallbacks> getActiveCallbacksUnlessEmpty(const RecordScope scope);

void setTLS(const RecordFunctionTLS& tls);
void seed(uint32_t seed);
Expand All @@ -178,6 +182,8 @@ class LocalCallbackManager {
void clearCallbacks();

private:
void rebuildActiveCallbacksIfNeeded();

void rebuild_all(const GlobalCallbackManager::snapshot_t& global_snapshot);

void rebuild_callback_scopes(
Expand Down Expand Up @@ -271,7 +277,7 @@ void CacheEntry::update(const std::vector<RecordFunctionCallback>& callbacks) {
rebuildActiveCallbacks();
}

StepCallbacks CacheEntry::getActiveCallbacks() {
void CacheEntry::getActiveCallbacksImpl() {
// We rebuild the active set when `sampling_countdown_` reaches zero, so if it
// reaches zero at the start of this function something has gone wrong.
TORCH_INTERNAL_ASSERT(sampling_countdown_ > 0, sampling_countdown_);
Expand All @@ -295,7 +301,18 @@ StepCallbacks CacheEntry::getActiveCallbacks() {
}
}
}
}

// Returns the callbacks that are active for the current step, by value.
// Runs the shared per-call bookkeeping in getActiveCallbacksImpl() first and
// then copies out `active_callbacks_`. The result may be empty; callers that
// only care about the non-empty case can use getActiveCallbacksUnlessEmpty().
StepCallbacks CacheEntry::getActiveCallbacks() {
getActiveCallbacksImpl();
return active_callbacks_;
}

// Variant of getActiveCallbacks() that reports "nothing to run" as
// c10::nullopt instead of an empty StepCallbacks. The same bookkeeping in
// getActiveCallbacksImpl() runs in both cases.
c10::optional<StepCallbacks> CacheEntry::getActiveCallbacksUnlessEmpty() {
getActiveCallbacksImpl();
// The empty case is hinted as the likely one; returning nullopt here means
// no StepCallbacks object is constructed for the caller on that path.
if (C10_LIKELY(active_callbacks_.empty())) {
return c10::nullopt;
}
return active_callbacks_;
}

Expand Down Expand Up @@ -365,15 +382,25 @@ const RecordFunctionTLS& LocalCallbackManager::getTLS() const {
return registered_callbacks_;
}

StepCallbacks LocalCallbackManager::getActiveCallbacks(
const RecordScope scope) {
// Brings this manager's cached callbacks up to date with the global callback
// registry. A rebuild is triggered only when the version reported by
// GlobalCallbackManager differs from the locally recorded `global_version_`,
// which is hinted as the rare case.
void LocalCallbackManager::rebuildActiveCallbacksIfNeeded() {
const auto global_version = GlobalCallbackManager::get().version();
if (C10_UNLIKELY(global_version != global_version_)) {
// NOTE(review): rebuild_all() presumably also refreshes `global_version_`
// from the snapshot; its body is not visible here -- confirm.
rebuild_all(GlobalCallbackManager::get().getSnapshot());
}
}

// Returns the active callbacks for `scope`, by value.
// First rebuilds the cached state if the global callback version has changed,
// then delegates to the cache entry stored at index `scope` (the enum value
// cast to size_t) in `active_callbacks_`.
StepCallbacks LocalCallbackManager::getActiveCallbacks(
const RecordScope scope) {
rebuildActiveCallbacksIfNeeded();
return active_callbacks_[static_cast<size_t>(scope)].getActiveCallbacks();
}

// Same lookup as getActiveCallbacks(scope), but returns whatever the cache
// entry's getActiveCallbacksUnlessEmpty() yields: c10::nullopt when that
// entry has no active callbacks, otherwise the callbacks by value.
c10::optional<StepCallbacks> LocalCallbackManager::getActiveCallbacksUnlessEmpty(
const RecordScope scope) {
rebuildActiveCallbacksIfNeeded();
return active_callbacks_[static_cast<size_t>(scope)].getActiveCallbacksUnlessEmpty();
}

void LocalCallbackManager::setTLS(const RecordFunctionTLS& tls) {
registered_callbacks_ = tls;
rebuild_all(GlobalCallbackManager::get().getSnapshot());
Expand Down Expand Up @@ -572,6 +599,10 @@ StepCallbacks getStepCallbacks(RecordScope scope) {
return LocalCallbackManager::get().getActiveCallbacks(scope);
}

// Public entry point: forwards to
// LocalCallbackManager::get().getActiveCallbacksUnlessEmpty(scope).
// Callers test has_value() on the result and, if set, move the contained
// StepCallbacks into a RecordFunction.
c10::optional<StepCallbacks> getStepCallbacksUnlessEmpty(RecordScope scope) {
return LocalCallbackManager::get().getActiveCallbacksUnlessEmpty(scope);
}

// Returns a const reference to the RecordFunctionTLS held by the
// LocalCallbackManager instance obtained from LocalCallbackManager::get().
const RecordFunctionTLS& get_record_function_tls_() {
return LocalCallbackManager::get().getTLS();
}
Expand Down
2 changes: 2 additions & 0 deletions aten/src/ATen/record_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,8 @@ struct TORCH_API RecordFunction {

TORCH_API StepCallbacks getStepCallbacks(RecordScope scope);

TORCH_API c10::optional<StepCallbacks> getStepCallbacksUnlessEmpty(RecordScope scope);

namespace detail {
template <typename Inputs, typename F, typename... Args>
void record_function_with_scope(RecordFunction& guard, F fn, const Inputs& inputs, Args&&... args) {
Expand Down
7 changes: 4 additions & 3 deletions binaries/record_function_benchmark.cc
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@

#include <torch/torch.h>
#include <ATen/record_function.h>

Expand Down Expand Up @@ -49,9 +50,9 @@ float runPureRecordFunctionBench(int iter) {
typedef std::chrono::microseconds us;
std::chrono::time_point<clock> start_time = clock::now();
for (auto idx = 0; idx < iter; ++idx) {
auto step_callbacks = at::getStepCallbacks(at::RecordScope::USER_SCOPE);
if (!step_callbacks.empty()) {
at::RecordFunction guard(std::move(step_callbacks));
auto step_callbacks = at::getStepCallbacksUnlessEmpty(at::RecordScope::USER_SCOPE);
if (step_callbacks.has_value()) {
at::RecordFunction guard(std::move(*step_callbacks));
guard.before("Test", -1);
}
}
Expand Down
6 changes: 3 additions & 3 deletions torch/csrc/autograd/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -151,9 +151,9 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
// probably operate with names.
at::NoNamesGuard no_names_guard;

auto step_callbacks = at::getStepCallbacks(at::RecordScope::BACKWARD_FUNCTION);
if (!step_callbacks.empty()) {
at::RecordFunction guard(std::move(step_callbacks));
auto step_callbacks = at::getStepCallbacksUnlessEmpty(at::RecordScope::BACKWARD_FUNCTION);
if (C10_UNLIKELY(step_callbacks.has_value())) {
at::RecordFunction guard(std::move(*step_callbacks));
// Using sequence number and thread id to correlate with
// the forward pass function
guard.setForwardThreadId(thread_id_);
Expand Down
8 changes: 4 additions & 4 deletions torch/csrc/jit/runtime/interpreter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -845,11 +845,11 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target {

static void checkAndStartRecordFunction(Frame& frame, Stack& stack) {
if (!frame.record_function) {
auto step_callbacks =
at::getStepCallbacks(at::RecordScope::TORCHSCRIPT_FUNCTION);
if (!step_callbacks.empty()) {
auto step_callbacks = at::getStepCallbacksUnlessEmpty(
at::RecordScope::TORCHSCRIPT_FUNCTION);
if (C10_UNLIKELY(step_callbacks.has_value())) {
auto rec_fn =
std::make_unique<at::RecordFunction>(std::move(step_callbacks));
std::make_unique<at::RecordFunction>(std::move(*step_callbacks));
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rec_fn->isActive());
if (rec_fn->needsInputs()) {
rec_fn->before(
Expand Down
12 changes: 6 additions & 6 deletions torch/csrc/jit/runtime/static/impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1201,9 +1201,9 @@ c10::IValue BlockRunner::run_impl_record_functions(
IValueList&& args,
const KeywordArgs& kwargs) {
auto step_callbacks =
at::getStepCallbacks(at::RecordScope::STATIC_RUNTIME_MODEL);
if (!step_callbacks.empty()) {
at::RecordFunction guard(std::move(step_callbacks));
at::getStepCallbacksUnlessEmpty(at::RecordScope::STATIC_RUNTIME_MODEL);
if (C10_UNLIKELY(step_callbacks.has_value())) {
at::RecordFunction guard(std::move(*step_callbacks));
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(guard.isActive());
guard.needsInputs()
? guard.before(
Expand Down Expand Up @@ -1845,9 +1845,9 @@ std::vector<IValue> ProcessedNode::inputs_ivalue_vec() const {
void ProcessedNode::run() {
#ifndef PYTORCH_DISABLE_PER_OP_PROFILING
auto step_callbacks =
at::getStepCallbacks(at::RecordScope::STATIC_RUNTIME_OP);
if (!step_callbacks.empty()) {
at::RecordFunction guard(std::move(step_callbacks));
at::getStepCallbacksUnlessEmpty(at::RecordScope::STATIC_RUNTIME_OP);
if (C10_UNLIKELY(step_callbacks.has_value())) {
at::RecordFunction guard(std::move(*step_callbacks));
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(guard.isActive());
if (guard.needsInputs()) {
const auto inputs = inputs_ivalue_vec();
Expand Down

0 comments on commit c083489

Please sign in to comment.