Skip to content

Commit 7480e83

Browse files
Taylor Robiepytorchmergebot
authored andcommitted
[Profiler] Add disabled and global methods to ProfilerConfig. (pytorch#83891)
`ProfilerState::Disabled` and `ProfilerState::KINETO_ONDEMAND` have special semantics. The former is somewhat intuitive, but the degree of behavior branching on the latter (and why the branching is necessary) is less clear. By factoring the enum checks into methods, we can both clairify intent and future proof in case we ever add other global profiling contexts. Differential Revision: [D38917980](https://our.internmc.facebook.com/intern/diff/D38917980/) Pull Request resolved: pytorch#83891 Approved by: https://github.com/slgong-fb
1 parent 8e6207b commit 7480e83

File tree

5 files changed

+32
-33
lines changed

5 files changed

+32
-33
lines changed

torch/csrc/autograd/profiler_kineto.cpp

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ struct KinetoThreadLocalState : public ProfilerThreadLocalStateBase {
226226
int64_t total_allocated,
227227
int64_t total_reserved,
228228
c10::Device device) override {
229-
if (config_.profile_memory && config_.state != ProfilerState::Disabled) {
229+
if (config_.profile_memory && !config_.disabled()) {
230230
record_queue_.getSubqueue()->emplace_allocation_event(
231231
torch::profiler::impl::getApproximateTime(),
232232
ptr,
@@ -243,7 +243,7 @@ struct KinetoThreadLocalState : public ProfilerThreadLocalStateBase {
243243
int64_t total_allocated,
244244
int64_t total_reserved,
245245
c10::Device device) override {
246-
if (config_.profile_memory && config_.state != ProfilerState::Disabled) {
246+
if (config_.profile_memory && !config_.disabled()) {
247247
record_queue_.getSubqueue()->emplace_ooms_event(
248248
torch::profiler::impl::getApproximateTime(),
249249
alloc_size,
@@ -558,13 +558,18 @@ void enableProfiler(
558558

559559
TORCH_CHECK(
560560
config.state == ProfilerState::KINETO ||
561-
config.state == ProfilerState::KINETO_GPU_FALLBACK ||
562-
config.state == ProfilerState::KINETO_ONDEMAND);
561+
config.state == ProfilerState::KINETO_GPU_FALLBACK || config.global());
563562
TORCH_CHECK(
564563
!activities.empty(), "No activities specified for Kineto profiler");
565564

566-
if (config.state == ProfilerState::KINETO ||
567-
config.state == ProfilerState::KINETO_GPU_FALLBACK) {
565+
if (config.global()) {
566+
KinetoTLSGlobalStateManager::init(config, activities);
567+
568+
TORCH_INTERNAL_ASSERT(
569+
activities.count(ActivityType::CPU),
570+
"Ondemand profiling must enable CPU tracing");
571+
pushProfilingCallbacks<true>(scopes);
572+
} else {
568573
auto state = std::make_shared<KinetoThreadLocalState>(config, activities);
569574
c10::ThreadLocalDebugInfo::_push(c10::DebugInfoKind::PROFILER_STATE, state);
570575

@@ -573,15 +578,6 @@ void enableProfiler(
573578
}
574579
torch::profiler::impl::kineto::startTrace();
575580
}
576-
577-
if (config.state == ProfilerState::KINETO_ONDEMAND) {
578-
KinetoTLSGlobalStateManager::init(config, activities);
579-
580-
TORCH_INTERNAL_ASSERT(
581-
activities.count(ActivityType::CPU),
582-
"Ondemand profiling must enable CPU tracing");
583-
pushProfilingCallbacks<true>(scopes);
584-
}
585581
}
586582

587583
std::unique_ptr<ProfilerResult> disableProfiler() {

torch/csrc/autograd/profiler_legacy.cpp

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ thread_event_lists ProfilerLegacyThreadLocalState::consolidate() {
193193
}
194194

195195
void ProfilerLegacyThreadLocalState::mark(std::string name, bool include_cuda) {
196-
if (config_.state == torch::profiler::impl::ProfilerState::Disabled) {
196+
if (config_.disabled()) {
197197
return;
198198
}
199199
if (config_.state == torch::profiler::impl::ProfilerState::NVTX) {
@@ -225,7 +225,7 @@ void ProfilerLegacyThreadLocalState::pushRange(
225225
const at::RecordFunction& fn,
226226
const bool record_cuda,
227227
std::vector<std::vector<int64_t>>&& shapes) {
228-
if (config_.state == torch::profiler::impl::ProfilerState::Disabled) {
228+
if (config_.disabled()) {
229229
return;
230230
}
231231
if (config_.state == torch::profiler::impl::ProfilerState::NVTX) {
@@ -273,7 +273,7 @@ void ProfilerLegacyThreadLocalState::pushRange(
273273
void ProfilerLegacyThreadLocalState::popRange(
274274
const at::RecordFunction& fn,
275275
const bool record_cuda) {
276-
if (config_.state == torch::profiler::impl::ProfilerState::Disabled) {
276+
if (config_.disabled()) {
277277
return;
278278
}
279279
if (config_.state == torch::profiler::impl::ProfilerState::NVTX) {
@@ -300,8 +300,7 @@ void ProfilerLegacyThreadLocalState::reportMemoryUsage(
300300
int64_t /* total_allocated, unused for legacy */,
301301
int64_t /* total_reserved, unused for legacy */,
302302
c10::Device device) {
303-
if (config_.profile_memory &&
304-
config_.state != torch::profiler::impl::ProfilerState::Disabled) {
303+
if (config_.profile_memory && !config_.disabled()) {
305304
uint64_t thread_id = at::RecordFunction::currentThreadId();
306305
LegacyEvent evt(
307306
EventKind::MemoryAlloc,
@@ -372,9 +371,7 @@ void pushProfilingCallbacksLegacy() {
372371
[](const at::RecordFunction& fn)
373372
-> std::unique_ptr<at::ObserverContext> {
374373
auto state_ptr = ProfilerLegacyThreadLocalState::getTLS();
375-
if (!state_ptr ||
376-
state_ptr->config().state ==
377-
torch::profiler::impl::ProfilerState::Disabled) {
374+
if (!state_ptr || state_ptr->config().disabled()) {
378375
return nullptr;
379376
}
380377
bool record_cuda = state_ptr->config().state ==
@@ -396,9 +393,7 @@ void pushProfilingCallbacksLegacy() {
396393
},
397394
[](const at::RecordFunction& fn, at::ObserverContext*) {
398395
auto state_ptr = ProfilerLegacyThreadLocalState::getTLS();
399-
if (!state_ptr ||
400-
state_ptr->config().state ==
401-
torch::profiler::impl::ProfilerState::Disabled) {
396+
if (!state_ptr || state_ptr->config().disabled()) {
402397
return;
403398
}
404399
bool record_cuda = state_ptr->config().state ==
@@ -454,9 +449,7 @@ thread_event_lists disableProfilerLegacy(
454449

455450
auto state_ptr = static_cast<ProfilerLegacyThreadLocalState*>(state.get());
456451
TORCH_CHECK(
457-
state_ptr &&
458-
state_ptr->config().state !=
459-
torch::profiler::impl::ProfilerState::Disabled,
452+
state_ptr && !state_ptr->config().disabled(),
460453
"Can't disable profiler when it's not running");
461454

462455
if (cleanupTLSState) {

torch/csrc/profiler/api.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,14 @@ ExperimentalConfig::ExperimentalConfig(
1414
return !profiler_metrics.empty();
1515
}
1616

17+
bool ProfilerConfig::disabled() const {
18+
return state == torch::profiler::impl::ProfilerState::Disabled;
19+
}
20+
21+
bool ProfilerConfig::global() const {
22+
return state == torch::profiler::impl::ProfilerState::KINETO_ONDEMAND;
23+
}
24+
1725
namespace {
1826
enum ProfilerIValueIdx {
1927
STATE = 0,
@@ -52,9 +60,7 @@ ProfilerConfig ProfilerConfig::fromIValue(
5260

5361
bool profilerEnabled() {
5462
auto state_ptr = ProfilerThreadLocalStateBase::getTLS();
55-
return state_ptr &&
56-
state_ptr->config().state !=
57-
torch::profiler::impl::ProfilerState::Disabled;
63+
return state_ptr && !state_ptr->config().disabled();
5864
}
5965

6066
TORCH_API ActiveProfilerType profilerType() {

torch/csrc/profiler/api.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,10 @@ struct TORCH_API ProfilerConfig {
6666
with_flops(with_flops),
6767
with_modules(with_modules) {}
6868
~ProfilerConfig() = default;
69+
70+
bool disabled() const;
71+
bool global() const;
72+
6973
ProfilerState state;
7074
ExperimentalConfig experimental_config;
7175
bool report_input_shapes;

torch/csrc/profiler/collection.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -836,7 +836,7 @@ trace_ptr_t addKinetoEvents(
836836
passEventsToKineto(results, start_time_us, end_time_us);
837837

838838
// In on demand mode kineto is directly controlled by other machinery.
839-
if (config.state == ProfilerState::KINETO_ONDEMAND) {
839+
if (config.global()) {
840840
return nullptr;
841841
}
842842

0 commit comments

Comments
 (0)