Skip to content

Commit

Permalink
[PyTorch] Use plain old function pointer for RecordFunctionCallback (pytorch#48629)
Browse files Browse the repository at this point in the history

Summary:
Pull Request resolved: pytorch#48629

Almost no non-test callsite needs to capture any variables, so a plain function pointer is sufficient in place of std::function, and it saves 48 bytes per callback.
ghstack-source-id: 118568240

Test Plan: CI

Reviewed By: dhruvbird

Differential Revision: D25135415

fbshipit-source-id: 5e92dc79da6473ed15d1e381a21ed315879168f3
  • Loading branch information
swolchok authored and facebook-github-bot committed Dec 15, 2020
1 parent 900aa4e commit 7e23ee1
Show file tree
Hide file tree
Showing 6 changed files with 175 additions and 136 deletions.
6 changes: 4 additions & 2 deletions aten/src/ATen/record_function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -277,10 +277,12 @@ class CallbackManager {
bool is_start) {
try {
if (is_start) {
ctx = rfcb.start()(rf);
ctx = rfcb.start() ? rfcb.start()(rf) : nullptr;
}
else {
rfcb.end()(rf, ctx.get());
if (rfcb.end()) {
rfcb.end()(rf, ctx.get());
}
}
return true;
} catch (const std::exception &e) {
Expand Down
20 changes: 11 additions & 9 deletions aten/src/ATen/record_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -305,14 +305,16 @@ struct TORCH_API RecordFunction {
*/
class TORCH_API RecordFunctionCallback {
public:
using StartCallback = std::unique_ptr<ObserverContext>(*)(const RecordFunction&);
using EndCallback = void (*)(const RecordFunction&, ObserverContext*);

// This interface supports observers that require passing an ObserverContext
// between start and end callbacks.
explicit RecordFunctionCallback(
std::function<std::unique_ptr<ObserverContext>(const RecordFunction&)> start,
std::function<void(const RecordFunction&, ObserverContext*)> end =
[](const RecordFunction&, ObserverContext*) {}):
start_(std::move(start)),
end_(std::move(end)) {
StartCallback start,
EndCallback end = nullptr) :
start_(start),
end_(end) {
scopes_.fill(true);
}

Expand Down Expand Up @@ -368,18 +370,18 @@ class TORCH_API RecordFunctionCallback {
return scopes_[(size_t)sc];
}

inline const std::function<std::unique_ptr<ObserverContext>(const RecordFunction&)>& start() const {
inline StartCallback start() const {
return start_;
}

inline const std::function<void(const RecordFunction&, ObserverContext*)>& end() const {
inline EndCallback end() const {
return end_;
}

private:
friend class CallbackManager;
std::function<std::unique_ptr<ObserverContext>(const RecordFunction&)> start_;
std::function<void(const RecordFunction&, ObserverContext*)> end_;
StartCallback start_;
EndCallback end_;
bool(*should_run_)(const RecordFunctionCallback&) = nullptr;
double sampling_prob_ = 1.0;
std::array<bool, static_cast<size_t>(RecordScope::NUM_SCOPES)> scopes_ = {};
Expand Down
10 changes: 5 additions & 5 deletions binaries/record_function_benchmark.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ const float kLowSamplingProb = 0.0001;

void addTestCallback(
double sampling_prob = 1.0,
std::function<std::unique_ptr<at::ObserverContext>(const at::RecordFunction&)> fn =
[](const at::RecordFunction&) { return nullptr; }) {
at::RecordFunctionCallback::StartCallback fn =
[](const at::RecordFunction&) -> std::unique_ptr<at::ObserverContext> { return nullptr; }) {
auto cb = at::RecordFunctionCallback(
std::move(fn),
fn,
[](const at::RecordFunction&, at::ObserverContext*) {})
.needsInputs(false);
if (sampling_prob < 1.0) {
Expand Down Expand Up @@ -106,10 +106,10 @@ int main(int argc, char** argv) {
at::clearCallbacks();

std::cout << "Checking number of sampled observer invocations" << std::endl;
int cb_count = 0;
static int cb_count = 0;
addTestCallback(
kLowSamplingProb,
[&](const at::RecordFunction& fn) {
[](const at::RecordFunction&) -> std::unique_ptr<at::ObserverContext> {
++cb_count;
return nullptr;
}
Expand Down
Loading

0 comments on commit 7e23ee1

Please sign in to comment.