Skip to content

Commit 3fcdb47

Browse files
committed
fix dygraph has_grad
1 parent 6151ccd commit 3fcdb47

File tree

3 files changed

+62
-1
lines changed

3 files changed

+62
-1
lines changed

paddle/fluid/imperative/tracer.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ DECLARE_string(tracer_mkldnn_ops_off);
3030
namespace paddle {
3131
namespace imperative {
3232

33+
thread_local bool Tracer::has_grad_ = true;
34+
3335
static std::shared_ptr<Tracer> g_current_tracer(nullptr);
3436

3537
const std::shared_ptr<Tracer>& GetCurrentTracer() { return g_current_tracer; }

paddle/fluid/imperative/tracer.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,9 +118,9 @@ class Tracer {
118118
bool enable_program_desc_tracing_{false};
119119
std::unique_ptr<UniqueNameGenerator> generator_;
120120
platform::Place expected_place_;
121-
bool has_grad_{true};
122121
bool enable_autocast_{false};
123122
GarbageCollectorMap gcs_;
123+
static thread_local bool has_grad_;
124124
};
125125

126126
// To access static variable current_tracer
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
import paddle
17+
import time
18+
import paddle.nn as nn
19+
import numpy as np
20+
import threading
21+
22+
23+
class SimpleNet(nn.Layer):
24+
def __init__(self, in_dim, out_dim):
25+
super(SimpleNet, self).__init__()
26+
self.fc = nn.Linear(in_dim, out_dim)
27+
28+
def forward(self, x):
29+
return self.fc(x)
30+
31+
32+
class TestCases(unittest.TestCase):
33+
@paddle.no_grad()
34+
def thread_1_main(self):
35+
time.sleep(8)
36+
37+
def thread_2_main(self):
38+
in_dim = 10
39+
out_dim = 3
40+
net = SimpleNet(in_dim, out_dim)
41+
for _ in range(1000):
42+
x = paddle.to_tensor(np.random.rand(32, in_dim).astype('float32'))
43+
self.assertTrue(x.stop_gradient)
44+
x = net(x)
45+
self.assertFalse(x.stop_gradient)
46+
47+
def test_main(self):
48+
threads = []
49+
for _ in range(10):
50+
threads.append(threading.Thread(target=self.thread_1_main))
51+
threads.append(threading.Thread(target=self.thread_2_main))
52+
for t in threads:
53+
t.start()
54+
for t in threads:
55+
t.join()
56+
57+
58+
if __name__ == "__main__":
59+
unittest.main()

0 commit comments

Comments
 (0)