Move RecordFunction into ATen (pytorch#37548)
Summary:
Pull Request resolved: pytorch#37548

Moving RecordFunction from the torch::autograd::profiler namespace into the at namespace.

Test Plan:
CI

Imported from OSS

Differential Revision: D21315852

fbshipit-source-id: 4a4dbabf116c162f9aef0da8606590ec3f3847aa
Ilia Cherniavskii authored and facebook-github-bot committed May 7, 2020
1 parent c24c5f9 commit 2d708ce
Showing 21 changed files with 101 additions and 123 deletions.
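For reference, a minimal sketch of registering an observer against the relocated API, using the at:: names introduced in this diff (registerExampleObserver and the printf body are illustrative, not part of the commit):

#include <ATen/record_function.h>
#include <cstdio>

void registerExampleObserver() {
  at::addGlobalCallback(at::RecordFunctionCallback(
      [](const at::RecordFunction& fn) {
        // start callback: runs when an observed function begins
        std::printf("enter: %s\n", fn.name().str());
        return true;
      },
      [](const at::RecordFunction& /* fn */) {
        // end callback: runs when the observed function finishes
      })
      .scopes({at::RecordScope::FUNCTION, at::RecordScope::USER_SCOPE}));
}
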
14 changes: 6 additions & 8 deletions android/pytorch_android/src/main/cpp/pytorch_jni_jit.cpp
@@ -6,7 +6,7 @@
#include <fbjni/ByteBuffer.h>
#include <fbjni/fbjni.h>

#include <torch/csrc/autograd/record_function.h>
#include <ATen/record_function.h>
#include <torch/csrc/jit/runtime/print_handler.h>
#include <torch/script.h>
#include "caffe2/serialize/read_adapter_interface.h"
@@ -88,12 +88,12 @@ class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {

#ifdef TRACE_ENABLED
static bool onFunctionEnter(
const RecordFunction& fn) {
const at::RecordFunction& fn) {
Trace::beginSection(fn.name().str());
return true;
}

static void onFunctionExit(const RecordFunction&) {
static void onFunctionExit(const at::RecordFunction&) {
Trace::endSection();
}
#endif
@@ -112,12 +112,10 @@ class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
#endif

#ifdef TRACE_ENABLED
pushCallback(
at::addGlobalCallback(at::RecordFunctionCallback(
&onFunctionEnter,
&onFunctionExit,
/* need_inputs */ false,
/* sampling_prob */ 1.0,
/* scopes */ {RecordScope::FUNCTION, RecordScope::USER_SCOPE});
&onFunctionExit)
.scopes({RecordScope::FUNCTION, RecordScope::USER_SCOPE}));
#endif
}

aten/src/ATen/record_function.cpp
@@ -1,13 +1,9 @@
#include <ATen/record_function.h>
#include <algorithm>
#include <torch/csrc/autograd/record_function.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/profiler.h>
#include <cstdlib>
#include <random>

namespace torch {
namespace autograd {
namespace profiler {
namespace at {

namespace {

@@ -206,7 +202,7 @@ inline CallbackManager& manager() {
/* static */
double RecordFunctionCallback::sample_zero_one() {
static thread_local auto gen =
torch::make_unique<std::mt19937>(std::random_device()());
std::make_unique<std::mt19937>(std::random_device()());
std::uniform_real_distribution<double> dist(0.0, 1.0);
return dist(*gen);
}
@@ -303,18 +299,6 @@ void RecordFunction::_before(std::string name, int64_t sequence_nr) {
manager().runStartCallbacks(*this);
}

void RecordFunction::_before(Node* fn, int64_t sequence_nr) {
if (!active_) {
return;
}
fn_ = fn;
name_ = StringView(fn->name());
sequence_nr_ = (sequence_nr >= 0) ? sequence_nr : fn->sequence_nr();
thread_id_ = currentThreadId();

manager().runStartCallbacks(*this);
}

RecordFunction::~RecordFunction() {
_end();
}
@@ -335,6 +319,4 @@ RecordFunction* RecordFunction::current() {
return current_record_func_;
}

} // namespace profiler
} // namespace autograd
} // namespace torch
} // namespace at
aten/src/ATen/record_function.h
@@ -3,16 +3,12 @@
#include <ATen/core/ivalue.h>
#include <ATen/ThreadLocalState.h>
#include <c10/util/SmallVector.h>
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/csrc/utils/memory.h>
#include <c10/macros/Export.h>
#include <memory>

#include <functional>

namespace torch { namespace autograd {

struct Node;

namespace profiler {
namespace at {

// Kind of record function scope;
// workaround for the older GCC versions:
@@ -33,23 +29,19 @@ enum class TORCH_API RecordScope : uint8_t {
# pragma GCC diagnostic pop
#endif

} // namespace profiler
} // namespace autograd
} // namespace torch
} // namespace at

namespace std {
template <>
struct hash<torch::autograd::profiler::RecordScope> {
struct hash<at::RecordScope> {
inline size_t operator()(
const torch::autograd::profiler::RecordScope& sc) const {
const at::RecordScope& sc) const {
return static_cast<std::size_t>(sc);
}
};
} // namespace std

namespace torch {
namespace autograd {
namespace profiler {
namespace at {

struct TORCH_API StringView {
StringView() : StringView(nullptr) {}
@@ -98,10 +90,6 @@ struct TORCH_API RecordFunction {
RecordFunction(const RecordFunction&) = delete;
RecordFunction& operator=(const RecordFunction&) = delete;

inline Node* func() const {
return fn_;
}

inline const StringView& name() const {
return name_;
}
@@ -138,7 +126,6 @@
// start callbacks
void _before(const char* name, int64_t sequence_nr = -1);
void _before(std::string name, int64_t sequence_nr = -1);
void _before(Node* fn, int64_t sequence_nr = -1);

template<typename F>
void _before(
@@ -181,7 +168,6 @@
bool needs_inputs_ = false;

private:
Node* fn_ = nullptr;
StringView name_;
int64_t sequence_nr_ = -1;
std::vector<c10::IValue> inputs_;
@@ -317,7 +303,7 @@ class TORCH_API RecordFunctionCallback {
// Using macro to minimize inputs copies,
// optional argument - function's seq_no
#define RECORD_FUNCTION_WITH_SCOPE(scope, fn, inputs, ...) \
torch::autograd::profiler::RecordFunction guard(scope); \
at::RecordFunction guard(scope); \
if (guard.active_) { \
guard._setCurrent(); \
if (guard.needs_inputs_) { \
@@ -329,17 +315,17 @@

#define RECORD_FUNCTION(fn, inputs, ...) \
RECORD_FUNCTION_WITH_SCOPE( \
torch::autograd::profiler::RecordScope::FUNCTION, \
at::RecordScope::FUNCTION, \
fn, inputs, ##__VA_ARGS__)

#define RECORD_TORCHSCRIPT_FUNCTION(mn, inputs) \
RECORD_FUNCTION_WITH_SCOPE( \
torch::autograd::profiler::RecordScope::TORCHSCRIPT_FUNCTION, mn, inputs)
at::RecordScope::TORCHSCRIPT_FUNCTION, mn, inputs)

// Custom user scopes in C++; similar to Python's 'with record_function("..."):'
#define RECORD_USER_SCOPE(fn) \
RECORD_FUNCTION_WITH_SCOPE( \
torch::autograd::profiler::RecordScope::USER_SCOPE, fn, {})
at::RecordScope::USER_SCOPE, fn, {})

// Notes:
// - two types of callbacks are provided: thread local and global
@@ -455,6 +441,4 @@ class TORCH_API DisableRecordFunctionGuard : public RecordFunctionGuard {
virtual ~DisableRecordFunctionGuard() {}
};

} // namespace profiler
} // namespace autograd
} // namespace torch
} // namespace at
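As a usage sketch for the macros above: a custom C++ scope, analogous to Python's 'with record_function("..."):', can be recorded as follows (run_expensive_step and the scope name are hypothetical, used only for illustration):

#include <ATen/record_function.h>

void run_expensive_step() {
  // Emits a USER_SCOPE record named "expensive_step"; any registered
  // callbacks observe it from this point until the function returns.
  RECORD_USER_SCOPE("expensive_step");
  // ... work attributed to the scope above ...
}
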
14 changes: 7 additions & 7 deletions binaries/record_function_benchmark.cc
@@ -1,5 +1,5 @@
#include <torch/torch.h>
#include <torch/csrc/autograd/record_function.h>
#include <ATen/record_function.h>

#include "c10/util/Flags.h"

@@ -22,20 +22,20 @@ using namespace torch::autograd;

void setupCallbacks() {
// non-sampled callback
profiler::addGlobalCallback(profiler::RecordFunctionCallback(
[&](const profiler::RecordFunction& fn) {
at::addGlobalCallback(at::RecordFunctionCallback(
[&](const at::RecordFunction& fn) {
return true;
},
[](const profiler::RecordFunction&) {})
[](const at::RecordFunction&) {})
.needsInputs(true));

// sampled
for (auto idx = 0; idx < kNumSampledCb; ++idx) {
profiler::addGlobalCallback(profiler::RecordFunctionCallback(
[](const profiler::RecordFunction& fn) {
at::addGlobalCallback(at::RecordFunctionCallback(
[](const at::RecordFunction& fn) {
return true;
},
[](const profiler::RecordFunction&) {})
[](const at::RecordFunction&) {})
.needsInputs(true)
.samplingProb(kSampingProb)
);
3 changes: 0 additions & 3 deletions docs/source/autograd.rst
@@ -141,9 +141,6 @@ and nvprof based (registers both CPU and GPU activity) using
.. autoclass:: torch.autograd.profiler.profile
:members:

.. autoclass:: torch.autograd.profiler.record_function
:members:

.. autoclass:: torch.autograd.profiler.emit_nvtx
:members:

6 changes: 3 additions & 3 deletions docs/source/notes/large_scale_deployments.rst
@@ -26,8 +26,8 @@ gathering information about PyTorch workloads running in a given process or
across the entire set of machines.

New callbacks for any operator invocation can be added with
``torch::autograd::profiler::addGlobalCallback``. Hooks will be called with
``torch::autograd::profiler::RecordFunction`` struct that describes invocation
``torch::addGlobalCallback``. Hooks will be called with
``torch::RecordFunction`` struct that describes invocation
context (e.g. `name`). If enabled, ``RecordFunction::inputs()`` contains arguments
of the function represented as ``torch::IValue`` variant type. Note, that inputs
logging is relatively expensive and thus has to be enabled explicitly.
@@ -42,7 +42,7 @@ application down to the operator callbacks.

Invoking callbacks adds some overhead, so usually it's useful to just randomly
sample operator invocations. This can be enabled on per-callback basis with an
optional sampling rate passed into ``torch::autograd::profiler::addGlobalCallback``.
optional sampling rate passed into ``torch::addGlobalCallback``.

Note, that ``addGlobalCallback`` is not thread-safe and can be called only when no
PyTorch operator is running. Usually, it's a good idea to call them once during
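To illustrate the documentation above, a sketch of a sampled global callback registered through the torch:: aliases; the 1% sampling rate and the callback bodies are illustrative:

#include <torch/torch.h>

void addSampledObserver() {
  torch::addGlobalCallback(torch::RecordFunctionCallback(
      [](const torch::RecordFunction& /* fn */) {
        // invoked on operator entry for the sampled fraction of calls
        return true;
      },
      [](const torch::RecordFunction& /* fn */) {
        // invoked on operator exit
      })
      .needsInputs(true)     // opt in explicitly; inputs logging is relatively expensive
      .samplingProb(0.01));  // observe roughly 1% of invocations
}
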
32 changes: 16 additions & 16 deletions test/cpp/jit/test_misc.cpp
@@ -761,47 +761,47 @@ void checkScopeCallbacks() {
bool found_function_scope = false;
bool found_method_scope = false;
bool found_user_scope = false;
profiler::addGlobalCallback(profiler::RecordFunctionCallback(
[&](const profiler::RecordFunction& fn) {
if (fn.scope() == profiler::RecordScope::FUNCTION &&
at::addGlobalCallback(at::RecordFunctionCallback(
[&](const at::RecordFunction& fn) {
if (fn.scope() == at::RecordScope::FUNCTION &&
std::string(fn.name().str()) == "test_function") {
found_function_scope = true;
}
if (fn.scope() == profiler::RecordScope::TORCHSCRIPT_FUNCTION &&
if (fn.scope() == at::RecordScope::TORCHSCRIPT_FUNCTION &&
std::string(fn.name().str()) == "test_method") {
found_method_scope = true;
}
if (fn.scope() == profiler::RecordScope::USER_SCOPE &&
if (fn.scope() == at::RecordScope::USER_SCOPE &&
std::string(fn.name().str()) == "test_user_scope") {
found_user_scope = true;
}
},
[](const profiler::RecordFunction&) {}));
[](const at::RecordFunction&) {}));

bool bad_scope = false;
auto pushScopedCallback = [&](profiler::RecordScope scope, size_t& cnt) {
profiler::addGlobalCallback(
profiler::RecordFunctionCallback(
[&bad_scope, &cnt, scope](const profiler::RecordFunction& fn) {
auto pushScopedCallback = [&](at::RecordScope scope, size_t& cnt) {
at::addGlobalCallback(
at::RecordFunctionCallback(
[&bad_scope, &cnt, scope](const at::RecordFunction& fn) {
if (fn.scope() == scope) {
++cnt;
} else {
bad_scope = true;
}
return true;
},
[](const profiler::RecordFunction&) {})
[](const at::RecordFunction&) {})
.scopes({scope}));
};

size_t fun_cnt = 0;
pushScopedCallback(profiler::RecordScope::FUNCTION, fun_cnt);
pushScopedCallback(at::RecordScope::FUNCTION, fun_cnt);
size_t ts_fun_cnt = 0;
pushScopedCallback(profiler::RecordScope::TORCHSCRIPT_FUNCTION, ts_fun_cnt);
pushScopedCallback(at::RecordScope::TORCHSCRIPT_FUNCTION, ts_fun_cnt);
size_t user_scope_cnt = 0;
pushScopedCallback(profiler::RecordScope::USER_SCOPE, user_scope_cnt);
pushScopedCallback(at::RecordScope::USER_SCOPE, user_scope_cnt);

TORCH_CHECK(profiler::hasCallbacks());
TORCH_CHECK(at::hasCallbacks());

{
RECORD_TORCHSCRIPT_FUNCTION("test_method", {});
@@ -874,7 +874,7 @@ void testRecordFunction() {

checkTracedInputs(eager_inputs);
checkTracedInputs(jit_inputs);
profiler::clearCallbacks();
at::clearCallbacks();

// test sampled callbacks
int sampled_cb_ctr = 0;
1 change: 0 additions & 1 deletion tools/build_variables.bzl
@@ -54,7 +54,6 @@ libtorch_core_sources = [
"torch/csrc/autograd/functions/utils.cpp",
"torch/csrc/autograd/input_buffer.cpp",
"torch/csrc/autograd/profiler.cpp",
"torch/csrc/autograd/record_function.cpp",
"torch/csrc/autograd/record_function_ops.cpp",
"torch/csrc/autograd/saved_variable.cpp",
"torch/csrc/autograd/variable.cpp",
19 changes: 19 additions & 0 deletions torch/csrc/api/include/torch/utils.h
@@ -1,6 +1,7 @@
#pragma once

#include <ATen/Parallel.h>
#include <ATen/record_function.h>
#include <torch/csrc/autograd/grad_mode.h>
#include <torch/csrc/api/include/torch/types.h>
#include <cstdint>
@@ -88,4 +89,22 @@ inline bool equal_if_defined(Tensor t1, Tensor t2) {
return ((!t1.defined() && !t2.defined()) || (t1.defined() && t2.defined() && torch::equal(t1, t2)));
}

// RecordFunction API
using at::RecordFunctionCallback;
using at::addThreadLocalCallback;
using at::hasThreadLocalCallbacks;
using at::clearThreadLocalCallbacks;
using at::addGlobalCallback;
using at::removeCallback;
using at::hasGlobalCallbacks;
using at::clearGlobalCallbacks;
using at::hasCallbacks;
using at::clearCallbacks;
using at::enableRecordFunction;
using at::isRecordFunctionEnabled;
using at::RecordFunctionGuard;
using at::DisableRecordFunctionGuard;
using at::CallbackHandle;
using at::RecordFunction;

} // namespace torch
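A small usage sketch for the guards re-exported above; the assumed semantics (suggested by the names) are that callbacks are suppressed on the current thread while the guard is alive:

#include <torch/torch.h>

void runWithoutObservers() {
  // Assumed semantics: RecordFunction callbacks registered elsewhere are not
  // invoked for work performed while this guard is in scope on this thread.
  torch::DisableRecordFunctionGuard no_observers;
  // ... run operators that should not be reported to observers ...
}
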
3 changes: 1 addition & 2 deletions torch/csrc/autograd/function.h
@@ -113,8 +113,7 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
/// function call.
variable_list operator()(variable_list&& inputs) {
RECORD_FUNCTION(
this, std::vector<c10::IValue>(inputs.begin(), inputs.end()));

name(), std::vector<c10::IValue>(inputs.begin(), inputs.end()), sequence_nr());
// In the first iteration of named tensors, autograd ignores names and
// operates on unnamed tensors. In the long term, autograd should
// probably operate with names.
(Diffs for the remaining changed files are not shown.)
