forked from leela-zero/leela-zero
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathUCTNodePointer.cpp
164 lines (136 loc) · 4.74 KB
/
UCTNodePointer.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
/*
This file is part of Leela Zero.
Copyright (C) 2018-2019 Gian-Carlo Pascutto and contributors
Leela Zero is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
Leela Zero is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with Leela Zero. If not, see <http://www.gnu.org/licenses/>.
Additional permission under GNU GPL version 3 section 7
If you modify this Program, or any covered work, by linking or
combining it with NVIDIA Corporation's libraries from the
NVIDIA CUDA Toolkit and/or the NVIDIA CUDA Deep Neural
Network library and/or the NVIDIA TensorRT inference library
(or a modified version of those libraries), containing parts covered
by the terms of the respective license agreement, the licensors of
this Program grant you additional permission to convey the resulting
work.
*/
#include "config.h"
#include <atomic>
#include <cassert>
#include <cstring>
#include <memory>
#include "UCTNode.h"
// Shared byte counter for storage attributed to the search tree.
// Constructors add sizeof(UCTNodePointer); inflate() adds sizeof(UCTNode);
// the destructor, move assignment and release() subtract accordingly.
std::atomic<size_t> UCTNodePointer::m_tree_size{0};

// Returns a snapshot of the tree-size counter, in bytes.
size_t UCTNodePointer::get_tree_size() {
    const size_t bytes = m_tree_size.load();
    return bytes;
}
// Adds `sz` bytes to the shared tree-size counter as a single atomic
// read-modify-write.
void UCTNodePointer::increment_tree_size(const size_t sz) {
    m_tree_size.fetch_add(sz);
}
// Subtracts `sz` bytes from the shared tree-size counter.
//
// The subtraction and the underflow check use the same atomic operation:
// fetch_sub returns the value that was actually decremented. The previous
// form (load, assert, then `-=`) had a window in which another thread
// could decrement between the check and the subtraction, so a real
// underflow could go undetected. The release-build effect is unchanged:
// one seq_cst atomic subtraction.
void UCTNodePointer::decrement_tree_size(const size_t sz) {
    const size_t previous = m_tree_size.fetch_sub(sz);
    assert(previous >= sz);
    (void)previous;  // only referenced by assert(); silence NDEBUG warnings
}
// Destructor. If the stored word is an inflated pointer, the owned UCTNode
// is deleted. The tree-size counter is then reduced by this object's own
// size plus sizeof(UCTNode) when a node was owned.
UCTNodePointer::~UCTNodePointer() {
    const auto raw = m_data.load();
    size_t freed_bytes = sizeof(UCTNodePointer);
    if (is_inflated(raw)) {
        delete read_ptr(raw);
        freed_bytes += sizeof(UCTNode);
    }
    decrement_tree_size(freed_bytes);
}
// Move constructor: takes over n's packed word (pointer or vertex/policy)
// and leaves INVALID behind in n. The new object adds its own footprint to
// the tree-size counter.
UCTNodePointer::UCTNodePointer(UCTNodePointer&& n) {
    const auto stolen = n.m_data.exchange(INVALID);
    const auto previous = m_data.exchange(stolen);
    // The freshly constructed object must still hold INVALID here
    // (presumably the member's default initializer in the header -- confirm).
    assert(previous == INVALID);
    (void)previous;  // only referenced by assert()
    increment_tree_size(sizeof(UCTNodePointer));
}
// Builds an uninflated pointer that packs a move and its policy into a
// single 64-bit word:
//   bits [63:32] -- raw bit pattern of `policy` (copied via memcpy)
//   bits [31:16] -- `vertex` reinterpreted as an unsigned 16-bit value
//   bits [15:0]  -- zero (presumably the "uninflated" tag; confirm in header)
UCTNodePointer::UCTNodePointer(const std::int16_t vertex, const float policy) {
    std::uint32_t policy_bits;
    std::memcpy(&policy_bits, &policy, sizeof(policy_bits));
    const auto vertex_bits = static_cast<std::uint16_t>(vertex);

    const auto high_word = static_cast<std::uint64_t>(policy_bits) << 32;
    const auto vertex_field = static_cast<std::uint64_t>(vertex_bits) << 16;
    m_data = high_word | vertex_field;

    increment_tree_size(sizeof(UCTNodePointer));
}
// Move assignment: installs n's packed word here and leaves INVALID in n.
// If this object previously owned an inflated UCTNode, that node is removed
// from the tree-size counter and deleted. No UCTNodePointer is created or
// destroyed, so the pointer footprint itself is not re-counted.
UCTNodePointer& UCTNodePointer::operator=(UCTNodePointer&& n) {
    const auto incoming = n.m_data.exchange(INVALID);
    const auto outgoing = m_data.exchange(incoming);
    if (!is_inflated(outgoing)) {
        return *this;
    }
    decrement_tree_size(sizeof(UCTNode));
    delete read_ptr(outgoing);
    return *this;
}
// Hands ownership of the inflated UCTNode (if any) to the caller and
// leaves this pointer holding INVALID.
//
// Returns the node, which the caller must now delete. Its bytes are removed
// from the tree-size counter here. If the pointer was NOT inflated, there is
// no heap node to hand over. In that case the counter is left untouched and
// nullptr is returned. Previously sizeof(UCTNode) was subtracted
// unconditionally, which could wrap the size_t counter, and the packed
// vertex/policy bits were reinterpreted as a pointer. The destructor and
// move assignment already guard on is_inflated(); this now matches them.
// Callers are expected to call inflate() first if they need a node.
UCTNode* UCTNodePointer::release() {
    auto v = std::atomic_exchange(&m_data, INVALID);
    if (!is_inflated(v)) {
        return nullptr;
    }
    decrement_tree_size(sizeof(UCTNode));
    return read_ptr(v);
}
void UCTNodePointer::inflate() const {
while (true) {
auto v = m_data.load();
if (is_inflated(v)) return;
auto v2 = reinterpret_cast<std::uint64_t>(
new UCTNode(read_vertex(v), read_policy(v)));
assert((v2 & 3ULL) == 0);
v2 |= POINTER;
bool success = m_data.compare_exchange_strong(v, v2);
if (success) {
increment_tree_size(sizeof(UCTNode));
return;
} else {
// this means that somebody else also modified this instance.
// Try again next time
delete read_ptr(v2);
}
}
}
// An uninflated pointer always reports valid; an inflated one defers to
// UCTNode::valid().
bool UCTNodePointer::valid() const {
    const auto raw = m_data.load();
    return !is_inflated(raw) || read_ptr(raw)->valid();
}
// Visit count of the child. An uninflated pointer has never been expanded
// into a node, so it reports 0.
int UCTNodePointer::get_visits() const {
    const auto raw = m_data.load();
    if (!is_inflated(raw)) {
        return 0;
    }
    return read_ptr(raw)->get_visits();
}
// Policy value of the child: read from the UCTNode when inflated, otherwise
// decoded from the packed word.
float UCTNodePointer::get_policy() const {
    const auto raw = m_data.load();
    if (!is_inflated(raw)) {
        return read_policy(raw);
    }
    return read_ptr(raw)->get_policy();
}
// Forwards to UCTNode::get_eval_lcb(color).
// Precondition: the pointer is already inflated (checked by assert only).
float UCTNodePointer::get_eval_lcb(const int color) const {
    const auto raw = m_data.load();
    assert(is_inflated(raw));
    const auto node = read_ptr(raw);
    return node->get_eval_lcb(color);
}
// An uninflated pointer always reports active; an inflated one defers to
// UCTNode::active().
bool UCTNodePointer::active() const {
    const auto raw = m_data.load();
    return !is_inflated(raw) || read_ptr(raw)->active();
}
// Forwards to UCTNode::get_eval(tomove).
// Precondition: only valid on an inflated pointer (checked by assert only).
float UCTNodePointer::get_eval(const int tomove) const {
    const auto raw = m_data.load();
    assert(is_inflated(raw));
    const auto node = read_ptr(raw);
    return node->get_eval(tomove);
}
// Move (vertex) of the child: read from the UCTNode when inflated, otherwise
// decoded from the packed word.
int UCTNodePointer::get_move() const {
    const auto raw = m_data.load();
    if (!is_inflated(raw)) {
        return read_vertex(raw);
    }
    return read_ptr(raw)->get_move();
}