Skip to content

Commit 04285ab

Browse files
authored
[AMP] support setting amp_level in multi-thread (#39198)
1 parent 3e80253 commit 04285ab

File tree

2 files changed

+3
-1
lines changed

2 files changed

+3
-1
lines changed

paddle/fluid/imperative/tracer.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ namespace imperative {
3232

3333
thread_local bool Tracer::has_grad_ = true;
3434

35+
thread_local AmpLevel Tracer::amp_level_ = AmpLevel::O0;
36+
3537
static std::shared_ptr<Tracer> g_current_tracer(nullptr);
3638

3739
const std::shared_ptr<Tracer>& GetCurrentTracer() { return g_current_tracer; }

paddle/fluid/imperative/tracer.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ class Tracer {
126126
platform::Place expected_place_;
127127
GarbageCollectorMap gcs_;
128128
static thread_local bool has_grad_;
129-
AmpLevel amp_level_{AmpLevel::O0};
129+
static thread_local AmpLevel amp_level_;
130130
};
131131

132132
// To access static variable current_tracer

0 commit comments

Comments
 (0)