diff --git a/aten/src/ATen/record_function.cpp b/aten/src/ATen/record_function.cpp index d1b0acb87c286..a75b1a1295dbe 100644 --- a/aten/src/ATen/record_function.cpp +++ b/aten/src/ATen/record_function.cpp @@ -277,10 +277,12 @@ class CallbackManager { bool is_start) { try { if (is_start) { - ctx = rfcb.start()(rf); + ctx = rfcb.start() ? rfcb.start()(rf) : nullptr; } else { - rfcb.end()(rf, ctx.get()); + if (rfcb.end()) { + rfcb.end()(rf, ctx.get()); + } } return true; } catch (const std::exception &e) { diff --git a/aten/src/ATen/record_function.h b/aten/src/ATen/record_function.h index 6b2e085760685..e9939667feb77 100644 --- a/aten/src/ATen/record_function.h +++ b/aten/src/ATen/record_function.h @@ -305,14 +305,16 @@ struct TORCH_API RecordFunction { */ class TORCH_API RecordFunctionCallback { public: + using StartCallback = std::unique_ptr(*)(const RecordFunction&); + using EndCallback = void (*)(const RecordFunction&, ObserverContext*); + // This interface supports observers that require passing an ObserverContext // between start and end callbacks. explicit RecordFunctionCallback( - std::function(const RecordFunction&)> start, - std::function end = - [](const RecordFunction&, ObserverContext*) {}): - start_(std::move(start)), - end_(std::move(end)) { + StartCallback start, + EndCallback end = nullptr) : + start_(start), + end_(end) { scopes_.fill(true); } @@ -368,18 +370,18 @@ class TORCH_API RecordFunctionCallback { return scopes_[(size_t)sc]; } - inline const std::function(const RecordFunction&)>& start() const { + inline StartCallback start() const { return start_; } - inline const std::function& end() const { + inline EndCallback end() const { return end_; } private: friend class CallbackManager; - std::function(const RecordFunction&)> start_; - std::function end_; + StartCallback start_; + EndCallback end_; bool(*should_run_)(const RecordFunctionCallback&) = nullptr; double sampling_prob_ = 1.0; std::array(RecordScope::NUM_SCOPES)> scopes_ = {}; diff --git a/binaries/record_function_benchmark.cc b/binaries/record_function_benchmark.cc index d47cedada40f3..c80f46d756524 100644 --- a/binaries/record_function_benchmark.cc +++ b/binaries/record_function_benchmark.cc @@ -19,10 +19,10 @@ const float kLowSamplingProb = 0.0001; void addTestCallback( double sampling_prob = 1.0, - std::function(const at::RecordFunction&)> fn = - [](const at::RecordFunction&) { return nullptr; }) { + at::RecordFunctionCallback::StartCallback fn = + [](const at::RecordFunction&) -> std::unique_ptr { return nullptr; }) { auto cb = at::RecordFunctionCallback( - std::move(fn), + fn, [](const at::RecordFunction&, at::ObserverContext*) {}) .needsInputs(false); if (sampling_prob < 1.0) { @@ -106,10 +106,10 @@ int main(int argc, char** argv) { at::clearCallbacks(); std::cout << "Checking number of sampled observer invocations" << std::endl; - int cb_count = 0; + static int cb_count = 0; addTestCallback( kLowSamplingProb, - [&](const at::RecordFunction& fn) { + [](const at::RecordFunction&) -> std::unique_ptr { ++cb_count; return nullptr; } diff --git a/test/cpp/jit/test_misc.cpp b/test/cpp/jit/test_misc.cpp index 10f36cc8e3949..445da59aee50e 100644 --- a/test/cpp/jit/test_misc.cpp +++ b/test/cpp/jit/test_misc.cpp @@ -721,12 +721,34 @@ void checkTracedInputs(const TracedTestInputs& inputs) { TORCH_CHECK(found_mul); } +static bool bad_scope = false; +template +std::unique_ptr checkScopeCallback(const at::RecordFunction& fn) { + if (fn.scope() == scope) { + ++(*cnt); + } else { + bad_scope = true; + } + return nullptr; +} + +template +void pushScopedCallback() { + at::addGlobalCallback( + at::RecordFunctionCallback( + checkScopeCallback) + .scopes({scope})); +} + void checkScopeCallbacks() { - bool found_function_scope = false; - bool found_method_scope = false; - bool found_user_scope = false; + static bool found_function_scope; + static bool found_method_scope; + static bool found_user_scope; + found_function_scope = false; + found_method_scope = false; + found_user_scope = false; at::addGlobalCallback(at::RecordFunctionCallback( - [&](const at::RecordFunction& fn) { + [](const at::RecordFunction& fn) -> std::unique_ptr{ if (fn.scope() == at::RecordScope::FUNCTION && std::string(fn.name().str()) == "test_function") { found_function_scope = true; @@ -742,27 +764,17 @@ void checkScopeCallbacks() { return nullptr; })); - bool bad_scope = false; - auto pushScopedCallback = [&](at::RecordScope scope, size_t& cnt) { - at::addGlobalCallback( - at::RecordFunctionCallback( - [&bad_scope, &cnt, scope](const at::RecordFunction& fn) { - if (fn.scope() == scope) { - ++cnt; - } else { - bad_scope = true; - } - return nullptr; - }) - .scopes({scope})); - }; + static size_t fun_cnt; + static size_t ts_fun_cnt; + static size_t user_scope_cnt; - size_t fun_cnt = 0; - pushScopedCallback(at::RecordScope::FUNCTION, fun_cnt); - size_t ts_fun_cnt = 0; - pushScopedCallback(at::RecordScope::TORCHSCRIPT_FUNCTION, ts_fun_cnt); - size_t user_scope_cnt = 0; - pushScopedCallback(at::RecordScope::USER_SCOPE, user_scope_cnt); + bad_scope = false; + fun_cnt = 0; + pushScopedCallback(); + ts_fun_cnt = 0; + pushScopedCallback(); + user_scope_cnt = 0; + pushScopedCallback(); TORCH_CHECK(at::hasCallbacks()); @@ -788,33 +800,33 @@ static bool shouldRunCallback(const RecordFunctionCallback&) { return should_run; } -TEST(RecordFunctionTest, Basic) { +static TracedTestInputs traced_inputs; +static std::unordered_set ts_names; + +std::unique_ptr tracedInputsCallback(const RecordFunction& fn) { + if (fn.scope() == RecordScope::FUNCTION) { + auto inputs = fn.inputs(); + std::vector> sizes; + for (const auto& input : inputs) { + if (input.isTensor()) { + sizes.push_back(input.toTensor().sizes().vec()); + } else if (input.isScalar()) { + sizes.push_back(std::vector()); + } + } + traced_inputs.push_back(std::make_tuple(fn.name().str(), sizes)); + } else if (fn.scope() == RecordScope::TORCHSCRIPT_FUNCTION) { + ts_names.insert(fn.name().str()); + } + return nullptr; +} + +TEST(RecordFunctionTest, TracedTestInputs) { // disabling the inlining of method calls GraphOptimizerEnabledGuard opt_guard(false); // [(fn, [[sizes], [sizes], ...]), ...] - TracedTestInputs traced_inputs; - std::unordered_set ts_names; - addGlobalCallback( - RecordFunctionCallback( - [&](const RecordFunction& fn) { - if (fn.scope() == RecordScope::FUNCTION) { - auto inputs = fn.inputs(); - std::vector> sizes; - for (const auto& input : inputs) { - if (input.isTensor()) { - sizes.push_back(input.toTensor().sizes().vec()); - } else if (input.isScalar()) { - sizes.push_back(std::vector()); - } - } - traced_inputs.push_back(std::make_tuple(fn.name().str(), sizes)); - } else if (fn.scope() == RecordScope::TORCHSCRIPT_FUNCTION) { - ts_names.insert(fn.name().str()); - } - return nullptr; - }) - .needsInputs(true)); + addGlobalCallback(RecordFunctionCallback(tracedInputsCallback).needsInputs(true)); TracedTestInputs eager_inputs, jit_inputs; { @@ -841,28 +853,36 @@ TEST(RecordFunctionTest, Basic) { checkTracedInputs(eager_inputs); checkTracedInputs(jit_inputs); at::clearCallbacks(); +} + +static int sampled_cb_ctr = 0; +std::unique_ptr sampledCallback(const RecordFunction& fn) { + if (std::string(fn.name().str()) == "test") { + ++sampled_cb_ctr; + } + return nullptr; +} + +static int non_sampled_cb_ctr = 0; +std::unique_ptr nonSampledCallback(const RecordFunction& fn) { + if (std::string(fn.name().str()) == "test") { + ++non_sampled_cb_ctr; + } + return nullptr; +} + +TEST(RecordFunctionTest, SampledCallbacks) { + // disabling the inlining of method calls + GraphOptimizerEnabledGuard opt_guard(false); // test sampled callbacks - int sampled_cb_ctr = 0; - auto setup_sampled_callback = [&sampled_cb_ctr](double sampling_prob) { - return addGlobalCallback(RecordFunctionCallback( - [&sampled_cb_ctr](const RecordFunction& fn) { - if (std::string(fn.name().str()) == "test") { - ++sampled_cb_ctr; - } - return nullptr; - }) + sampled_cb_ctr = 0; + auto setup_sampled_callback = [](double sampling_prob) { + return addGlobalCallback(RecordFunctionCallback(sampledCallback) .samplingProb(sampling_prob)); }; - int non_sampled_cb_ctr = 0; - addGlobalCallback(RecordFunctionCallback( - [&non_sampled_cb_ctr](const RecordFunction& fn) { - if (std::string(fn.name().str()) == "test") { - ++non_sampled_cb_ctr; - } - return nullptr; - })); + addGlobalCallback(RecordFunctionCallback(nonSampledCallback)); auto handle = setup_sampled_callback(0.5); @@ -897,13 +917,19 @@ TEST(RecordFunctionTest, Basic) { // test the scope of the callbacks checkScopeCallbacks(); clearCallbacks(); +} + +TEST(RecordFunctionTest, RecordFunctionGuard) { + // disabling the inlining of method calls + GraphOptimizerEnabledGuard opt_guard(false); + + static std::vector fn_names; + static std::mutex guard_mtx; // check record function guard - std::vector fn_names; - std::mutex mtx; addGlobalCallback(RecordFunctionCallback( - [&fn_names, &mtx](const RecordFunction& fn) { - std::lock_guard lock(mtx); + [](const RecordFunction& fn) -> std::unique_ptr{ + std::lock_guard lock(guard_mtx); fn_names.push_back(fn.name().str()); return nullptr; })); @@ -925,20 +951,26 @@ TEST(RecordFunctionTest, Basic) { TORCH_CHECK(fn_names.size() == 1); TORCH_CHECK(fn_names[0] == "B"); clearCallbacks(); +} - // test add/remove - std::vector ids; - auto add_remove_test_add_cb = [&ids](size_t id) { - return addGlobalCallback(RecordFunctionCallback( - [&ids, id](const RecordFunction& fn) { - ids.push_back(id); - return nullptr ; - })); - }; +static std::vector ids; - auto h1 = add_remove_test_add_cb(1); - auto h2 = add_remove_test_add_cb(2); - auto h3 = add_remove_test_add_cb(3); +template +auto add_remove_test_add_cb() { + return addGlobalCallback(RecordFunctionCallback( + [](const RecordFunction& fn) -> std::unique_ptr { + ids.push_back(id); + return nullptr; + })); +} + +TEST(RecordFunctionTest, Callbacks) { + // disabling the inlining of method calls + GraphOptimizerEnabledGuard opt_guard(false); + + auto h1 = add_remove_test_add_cb<1>(); + auto h2 = add_remove_test_add_cb<2>(); + auto h3 = add_remove_test_add_cb<3>(); { RECORD_USER_SCOPE("test"); } @@ -969,8 +1001,7 @@ TEST(RecordFunctionTest, Basic) { // thread local / global callbacks ids.clear(); - addGlobalCallback(RecordFunctionCallback( - [&ids](const RecordFunction& fn) { ids.push_back(1); return nullptr; })); + add_remove_test_add_cb<1>(); { RECORD_USER_SCOPE("test"); } @@ -978,9 +1009,9 @@ TEST(RecordFunctionTest, Basic) { TORCH_CHECK(ids[0] == 1); ids.clear(); - auto th = std::thread([&ids]() { + auto th = std::thread([]() { addThreadLocalCallback(RecordFunctionCallback( - [&ids](const RecordFunction& fn) { ids.push_back(2); return nullptr; })); + [](const RecordFunction& fn) -> std::unique_ptr { ids.push_back(2); return nullptr; })); { RECORD_USER_SCOPE("test_thread"); } }); @@ -1005,22 +1036,19 @@ TEST(RecordFunctionTest, Basic) { }; ids.clear(); { // START: global test - const int test_val = 123; - const std::string test_str = "test str"; addGlobalCallback(RecordFunctionCallback( - [test_val, test_str, &ids](const RecordFunction& /* unused */) { + [](const RecordFunction& /* unused */) -> std::unique_ptr { auto ctx = std::make_unique(); - ctx->a = test_val; - ctx->b = test_str; + ctx->a = 123; + ctx->b = "test_str"; ids.push_back(1); return ctx; }, - [test_val, test_str]( - const RecordFunction& /* unused */, ObserverContext* ctx_ptr) { + [](const RecordFunction& /* unused */, ObserverContext* ctx_ptr) { auto ctx = dynamic_cast(ctx_ptr); TORCH_CHECK(ctx_ptr != nullptr); - TORCH_CHECK(ctx->a == test_val); - TORCH_CHECK(ctx->b == test_str); + TORCH_CHECK(ctx->a == 123); + TORCH_CHECK(ctx->b == "test_str"); })); { RECORD_USER_SCOPE("test"); } @@ -1030,23 +1058,23 @@ TEST(RecordFunctionTest, Basic) { ids.clear(); } // END: global test { // START: thread local test - auto ctx_th = std::thread([&ids]() { + auto ctx_th = std::thread([]() { const int test_val = 234; const std::string test_str = "test thread str"; addThreadLocalCallback(RecordFunctionCallback( - [test_val, test_str, &ids](const RecordFunction& /* unused */) { + [](const RecordFunction& /* unused */) -> std::unique_ptr { auto ctx = std::make_unique(); - ctx->a = test_val; - ctx->b = test_str; + ctx->a = 234; + ctx->b = "test_thread_str"; ids.push_back(2); return ctx; }, - [test_val, test_str]( + []( const RecordFunction& /* unused */, ObserverContext* ctx_ptr) { auto ctx = dynamic_cast(ctx_ptr); TORCH_CHECK(ctx_ptr != nullptr); - TORCH_CHECK(ctx->a == test_val); - TORCH_CHECK(ctx->b == test_str); + TORCH_CHECK(ctx->a == 234); + TORCH_CHECK(ctx->b == "test_thread_str"); })); // Will call both global and thread local callbacks. @@ -1060,13 +1088,16 @@ TEST(RecordFunctionTest, Basic) { } // END: thread local test clearCallbacks(); +} - // test should_run +TEST(RecordFunctionTest, ShouldRun) { + // disabling the inlining of method calls + GraphOptimizerEnabledGuard opt_guard(false); - bool ran = false; should_run = false; + static bool ran = false; addGlobalCallback(RecordFunctionCallback( - [&ran](const RecordFunction& fn) { ran = true; return nullptr; }) + [](const RecordFunction& fn) -> std::unique_ptr { ran = true; return nullptr; }) .setShouldRun(shouldRunCallback)); { RECORD_USER_SCOPE("test"); } @@ -1080,13 +1111,20 @@ TEST(RecordFunctionTest, Basic) { TORCH_CHECK(ran); clearCallbacks(); +} + +TEST(RecordFunctionTest, Basic) { + // disabling the inlining of method calls + GraphOptimizerEnabledGuard opt_guard(false); + + static std::string recorded_op; + static bool has_ids = false; // test propagation of TLS callbacks std::thread t([]() { RecordFunctionGuard enable_rec_fn; - std::string recorded_op; auto handle = addThreadLocalCallback(RecordFunctionCallback( - [&recorded_op](const RecordFunction& fn) { + [](const RecordFunction& fn) -> std::unique_ptr { recorded_op = fn.name().str(); return nullptr; })); @@ -1096,17 +1134,16 @@ TEST(RecordFunctionTest, Basic) { RECORD_USER_SCOPE("test_in_thread"); }); t_child.join(); - TORCH_CHECK(recorded_op == "test_in_thread"); + EXPECT_EQ(recorded_op, "test_in_thread"); removeCallback(handle); }); t.join(); clearCallbacks(); // test set ids - bool has_ids = false; addGlobalCallback( RecordFunctionCallback( - [&has_ids](const RecordFunction& fn) { + [](const RecordFunction& fn) -> std::unique_ptr { has_ids = fn.handle() > 0; return nullptr; }) @@ -1116,7 +1153,7 @@ TEST(RecordFunctionTest, Basic) { clearCallbacks(); has_ids = false; addGlobalCallback(RecordFunctionCallback( - [&has_ids](const RecordFunction& fn) { + [](const RecordFunction& fn) -> std::unique_ptr { has_ids = fn.handle() > 0; return nullptr; })); @@ -1126,10 +1163,9 @@ TEST(RecordFunctionTest, Basic) { } TEST(RecordFunctionTest, OperatorNameOverload) { - std::set operator_names; - + static std::set operator_names; at::addGlobalCallback(at::RecordFunctionCallback( - [&operator_names](const at::RecordFunction& fn) { + [](const at::RecordFunction& fn) -> std::unique_ptr { c10::optional op_name = fn.operator_name(); if (op_name.has_value()) { @@ -1178,6 +1214,8 @@ void checkDebugInfo(c10::DebugInfoKind kind, int model_id) { } TEST(ThreadLocalDebugInfoTest, Basic) { + static std::atomic done{false}; + TORCH_CHECK( c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::TEST_INFO) == nullptr); auto debug_info = std::make_shared(); @@ -1190,10 +1228,9 @@ TEST(ThreadLocalDebugInfoTest, Basic) { // check that thread local debug info is propagated through fork calls TORCH_CHECK( c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::TEST_INFO) == nullptr); - std::atomic done{false}; { c10::DebugInfoGuard guard(c10::DebugInfoKind::TEST_INFO, debug_info); - at::launch([&done]() { + at::launch([]() { checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42); done = true; }); @@ -1206,7 +1243,7 @@ TEST(ThreadLocalDebugInfoTest, Basic) { c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::TEST_INFO) == nullptr); done = false; auto handle = addGlobalCallback(RecordFunctionCallback( - [&done](const RecordFunction&) { + [](const RecordFunction&) -> std::unique_ptr { checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42); done = true; return nullptr; @@ -1236,7 +1273,7 @@ TEST(ThreadLocalDebugInfoTest, Basic) { checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42); checkDebugInfo(c10::DebugInfoKind::TEST_INFO_2, 314); done = false; - at::launch([&done]() { + at::launch([]() { checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42); checkDebugInfo(c10::DebugInfoKind::TEST_INFO_2, 314); done = true; diff --git a/torch/csrc/autograd/init.cpp b/torch/csrc/autograd/init.cpp index 488b7be9bd8a4..7bf11a4d6316d 100644 --- a/torch/csrc/autograd/init.cpp +++ b/torch/csrc/autograd/init.cpp @@ -172,9 +172,7 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject *unused) { at::enableRecordFunction(enable); }); m.def("_set_empty_test_observer", [](bool is_global, double sampling_prob) { - auto cb = at::RecordFunctionCallback( - [](const at::RecordFunction&) { return nullptr; }, - [](const at::RecordFunction&, at::ObserverContext*) {}) + auto cb = at::RecordFunctionCallback(nullptr) .needsInputs(true) .samplingProb(sampling_prob); if (is_global) { diff --git a/torch/csrc/autograd/profiler_legacy.cpp b/torch/csrc/autograd/profiler_legacy.cpp index eb52aec8920dd..d478aa5098221 100644 --- a/torch/csrc/autograd/profiler_legacy.cpp +++ b/torch/csrc/autograd/profiler_legacy.cpp @@ -414,7 +414,7 @@ void pushProfilingCallbacksLegacy() { auto state_ptr = getProfilerTLSState(); TORCH_INTERNAL_ASSERT(state_ptr, "Expected profiler state set"); auto handle = at::addThreadLocalCallback(at::RecordFunctionCallback( - [](const at::RecordFunction& fn) { + [](const at::RecordFunction& fn) -> std::unique_ptr{ auto state_ptr = getProfilerTLSState(); if (!state_ptr || state_ptr->config().state == ProfilerState::Disabled) { return nullptr;