Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Play move with the highest lower confidence bound #817

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions appveyor.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ install:
- cmd: IF %NAME%==opencl set OPENCL=true
- cmd: IF %NAME%==blas set BLAS=true
- cmd: IF %NAME%==blas set GTEST=true
- cmd: set BOOST_ROOT=C:\Libraries\boost_1_69_0
- cmd: IF %BLAS%==true IF NOT EXIST C:\cache\OpenBLAS appveyor DownloadFile https://sjeng.org/ftp/OpenBLAS-0.3.3-win-oldthread.zip
- cmd: IF %BLAS%==true IF NOT EXIST C:\cache\OpenBLAS 7z x OpenBLAS-0.3.3-win-oldthread.zip -oC:\cache\OpenBLAS
- cmd: IF %OPENCL%==true nuget install opencl-nug -Version 0.777.77 -OutputDirectory C:\cache
Expand Down
4 changes: 4 additions & 0 deletions meson.build
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ has_backends = false
# Third party files.
includes += include_directories('third_party', is_system: true)

# Boost
deps += dependency('boost')

# Both protobuf and protoc must be the same version, so couple them together.
protobuf_lib = cc.find_library('libprotobuf', dirs : get_option('protobuf_libdir'), required : false)
if not protobuf_lib.found()
Expand Down Expand Up @@ -118,6 +121,7 @@ files += [
'src/utils/optionsdict.cc',
'src/utils/optionsparser.cc',
'src/utils/random.cc',
'src/utils/stats.cc',
'src/utils/string.cc',
'src/utils/transpose.cc',
'src/utils/weights_adapter.cc',
Expand Down
17 changes: 14 additions & 3 deletions src/engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include "mcts/search.h"
#include "utils/configfile.h"
#include "utils/logging.h"
#include "utils/stats.h"

namespace lczero {
namespace {
Expand Down Expand Up @@ -91,6 +92,10 @@ const OptionId kRamLimitMbId{
"terminal node counted several times, and the estimation assumes that all "
"positions have 30 possible moves. When set to 0, no RAM limit is "
"enforced."};
const OptionId kCIAlphaId{"cialpha", "CIAlpha",
"Confidence interval certainty used for calculating "
"the confidence interval "
"of moves evaluation."};

const size_t kAvgNodeSize = sizeof(Node) + kAvgMovesPerPosition * sizeof(Edge);
const size_t kAvgCacheItemSize =
Expand Down Expand Up @@ -143,12 +148,15 @@ void EngineController::PopulateOptions(OptionsParser* options) {
options->Add<BoolOption>(kPonderId) = true;
options->Add<FloatOption>(kSpendSavedTimeId, 0.0f, 1.0f) = 1.0f;
options->Add<IntOption>(kRamLimitMbId, 0, 100000000) = 0;
options->Add<FloatOption>(kCIAlphaId, 0.0f, 1.0f) = 2e-5f;

ConfigFile::PopulateOptions(options);

// Hide time curve options.
options->HideOption(kTimeMidpointMoveId);
options->HideOption(kTimeSteepnessId);

CreatezTable(options_.Get<float>(kCIAlphaId.GetId()));
}

SearchLimits EngineController::PopulateSearchLimits(
Expand Down Expand Up @@ -187,8 +195,10 @@ SearchLimits EngineController::PopulateSearchLimits(

// How to scale moves time.
const float slowmover = options_.Get<float>(kSlowMoverId.GetId());
const float time_curve_midpoint = options_.Get<float>(kTimeMidpointMoveId.GetId());
const float time_curve_steepness = options_.Get<float>(kTimeSteepnessId.GetId());
const float time_curve_midpoint =
options_.Get<float>(kTimeMidpointMoveId.GetId());
const float time_curve_steepness =
options_.Get<float>(kTimeSteepnessId.GetId());

float movestogo =
ComputeEstimatedMovesToGo(ply, time_curve_midpoint, time_curve_steepness);
Expand Down Expand Up @@ -266,7 +276,8 @@ void EngineController::UpdateFromUciOptions() {
}

// Network.
const auto network_configuration = NetworkFactory::BackendConfiguration(options_);
const auto network_configuration =
NetworkFactory::BackendConfiguration(options_);
if (network_configuration_ != network_configuration) {
network_ = NetworkFactory::LoadNetwork(options_);
network_configuration_ = network_configuration;
Expand Down
3 changes: 3 additions & 0 deletions src/mcts/node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,9 @@ void Node::CancelScoreUpdate(int multivisit) {

void Node::FinalizeScoreUpdate(float v, float d, int multivisit) {
// Recompute Q.
// Need to divide by 4 to scale to 0..1 range probability.
q_squared_diff_ +=
0.25f * (v - q_) * (v - q_) * multivisit * n_ / (n_ + multivisit);
q_ += multivisit * (v - q_) / (n_ + multivisit);
d_ += multivisit * (d - d_) / (n_ + multivisit);

Expand Down
27 changes: 26 additions & 1 deletion src/mcts/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
#include "neural/encoder.h"
#include "neural/writer.h"
#include "utils/mutex.h"
#include "utils/stats.h"

namespace lczero {

Expand Down Expand Up @@ -155,6 +156,7 @@ class Node {
// for terminal nodes.
float GetQ() const { return q_; }
float GetD() const { return d_; }
float GetSquaredDiff() const { return q_squared_diff_; }

// Returns whether the node is known to be draw/lose/win.
bool IsTerminal() const { return is_terminal_; }
Expand Down Expand Up @@ -270,6 +272,8 @@ class Node {
// Averaged draw probability. Works similarly to Q, except that D is not
// flipped depending on the side to move.
float d_ = 0.0f;
// Sum of squared differences to mean. Used for calculating variance.
float q_squared_diff_ = 0.0f;
// Sum of policy priors which have had at least one playout.
float visited_policy_ = 0.0f;
// How many completed visits this node had.
Expand Down Expand Up @@ -308,7 +312,7 @@ class Node {

// A basic sanity check. This must be adjusted when Node members are adjusted.
#if defined(__i386__) || (defined(__arm__) && !defined(__aarch64__))
static_assert(sizeof(Node) == 52, "Unexpected size of Node for 32bit compile");
static_assert(sizeof(Node) == 56, "Unexpected size of Node for 32bit compile");
#else
static_assert(sizeof(Node) == 80, "Unexpected size of Node");
#endif
Expand Down Expand Up @@ -340,6 +344,27 @@ class EdgeAndNode {
float GetD() const {
return (node_ && node_->GetN() > 0) ? node_->GetD() : 0.0f;
}
float GetVariance(float default_var) const {
if (node_ && node_->GetN() > 1) {
return node_->GetSquaredDiff() / (node_->GetN() - 1);
}
return default_var;
}
float GetQLCB() const {
if (!node_) {
return -1e6f;
}
auto visits = node_->GetN();
if (visits < 2) {
// Return large negative value if not enough visits.
return -1e6f + visits;
}

auto stddev = std::sqrt(GetVariance(1.0f) / visits);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Technically this is the standard error (standard deviation divided by sqrt(N)). Maybe stderr is a better variable name.

Or is this the standard deviation of something in another context?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking of the estimated distribution for the mean which has this standard deviation. Standard error seems to be correct too.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, that seems logical, I didn't think about it that way :).

auto z = CachedtQuantile(visits - 1);

return (0.5f + 0.5f * node_->GetQ()) - z * stddev;
}
// N-related getters, from Node (if exists).
uint32_t GetN() const { return node_ ? node_->GetN() : 0; }
int GetNStarted() const { return node_ ? node_->GetNStarted() : 0; }
Expand Down
5 changes: 5 additions & 0 deletions src/mcts/params.cc
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,10 @@ const OptionId SearchParams::kKLDGainAverageInterval{
"kldgain-average-interval", "KLDGainAverageInterval",
"Used to decide how frequently to evaluate the average KLDGainPerNode to "
"check the MinimumKLDGainPerNode, if specified."};
const OptionId SearchParams::kMinimumLCBNRatioId{
"minimum-lcb-n-ratio", "MinimumLCBNRatio",
"Minimum ratio for nodes of the played move the the move with the most "
"nodes."};

void SearchParams::Populate(OptionsParser* options) {
// Here the uci optimized defaults" are set.
Expand Down Expand Up @@ -217,6 +221,7 @@ void SearchParams::Populate(OptionsParser* options) {
options->Add<ChoiceOption>(kHistoryFillId, history_fill_opt) = "fen_only";
options->Add<IntOption>(kKLDGainAverageInterval, 1, 10000000) = 100;
options->Add<FloatOption>(kMinimumKLDGainPerNode, 0.0f, 1.0f) = 0.0f;
options->Add<FloatOption>(kMinimumLCBNRatioId, 0.0f, 1.0f) = 0.32f;

options->HideOption(kLogLiveStatsId);
}
Expand Down
4 changes: 4 additions & 0 deletions src/mcts/params.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@ class SearchParams {
float GetMinimumKLDGainPerNode() const {
return options_.Get<float>(kMinimumKLDGainPerNode.GetId());
}
float GetMinimumLCBNRatio() const {
return options_.Get<float>(kMinimumLCBNRatioId.GetId());
}

// Search parameter IDs.
static const OptionId kMiniBatchSizeId;
Expand Down Expand Up @@ -129,6 +132,7 @@ class SearchParams {
static const OptionId kHistoryFillId;
static const OptionId kMinimumKLDGainPerNode;
static const OptionId kKLDGainAverageInterval;
static const OptionId kMinimumLCBNRatioId;

private:
const OptionsDict& options_;
Expand Down
40 changes: 28 additions & 12 deletions src/mcts/search.cc
Original file line number Diff line number Diff line change
Expand Up @@ -213,12 +213,13 @@ std::vector<std::string> Search::GetVerboseStats(Node* node,
std::vector<EdgeAndNode> edges;
for (const auto& edge : node->Edges()) edges.push_back(edge);

std::sort(
edges.begin(), edges.end(),
[&fpu, &U_coeff](EdgeAndNode a, EdgeAndNode b) {
return std::forward_as_tuple(a.GetN(), a.GetQ(fpu) + a.GetU(U_coeff)) <
std::forward_as_tuple(b.GetN(), b.GetQ(fpu) + b.GetU(U_coeff));
});
std::sort(edges.begin(), edges.end(),
[&fpu, &U_coeff](EdgeAndNode a, EdgeAndNode b) {
return std::forward_as_tuple(a.GetQLCB(),
a.GetQ(fpu) + a.GetU(U_coeff)) <
std::forward_as_tuple(b.GetQLCB(),
b.GetQ(fpu) + b.GetU(U_coeff));
});

std::vector<std::string> infos;
for (const auto& edge : edges) {
Expand All @@ -239,15 +240,15 @@ std::vector<std::string> Search::GetVerboseStats(Node* node,
oss << "(Q: " << std::setw(8) << std::setprecision(5) << edge.GetQ(fpu)
<< ") ";

oss << "(LCB: " << std::setw(8) << std::setprecision(5)
<< std::max(0.0f, edge.GetQLCB()) << ") ";

oss << "(D: " << std::setw(6) << std::setprecision(3) << edge.GetD()
<< ") ";

oss << "(U: " << std::setw(6) << std::setprecision(5) << edge.GetU(U_coeff)
<< ") ";

oss << "(Q+U: " << std::setw(8) << std::setprecision(5)
<< edge.GetQ(fpu) + edge.GetU(U_coeff) << ") ";

oss << "(V: ";
optional<float> v;
if (edge.IsTerminal()) {
Expand Down Expand Up @@ -542,20 +543,35 @@ std::vector<EdgeAndNode> Search::GetBestChildrenNoTemperature(Node* parent,
if (parent == root_node_) {
PopulateRootMoveLimit(&root_limit);
}
unsigned int max_n = 0;
for (auto edge : parent->Edges()) {
if (parent == root_node_ && !root_limit.empty() &&
std::find(root_limit.begin(), root_limit.end(), edge.GetMove()) ==
root_limit.end()) {
continue;
}
if (edge.GetN() > max_n) max_n = edge.GetN();
}

float min_ratio = max_n * params_.GetMinimumLCBNRatio();

// Best child is selected using the following criteria:
// * Minimum LCB node ratio exceeded.
// * Largest lower confidence bound.
// * Largest number of playouts.
// * If two nodes have equal number:
// * If that number is 0, the one with larger prior wins.
// * If that number is larger than 0, the one with larger eval wins.
using El = std::tuple<uint64_t, float, float, EdgeAndNode>;
using El = std::tuple<bool, float, uint64_t, float, float, EdgeAndNode>;
std::vector<El> edges;
for (auto edge : parent->Edges()) {
if (parent == root_node_ && !root_limit.empty() &&
std::find(root_limit.begin(), root_limit.end(), edge.GetMove()) ==
root_limit.end()) {
continue;
}
edges.emplace_back(edge.GetN(), edge.GetQ(0), edge.GetP(), edge);
edges.emplace_back(edge.GetN() > min_ratio, edge.GetQLCB(), edge.GetN(),
edge.GetQ(0), edge.GetP(), edge);
}
const auto middle = (static_cast<int>(edges.size()) > count)
? edges.begin() + count
Expand All @@ -564,7 +580,7 @@ std::vector<EdgeAndNode> Search::GetBestChildrenNoTemperature(Node* parent,

std::vector<EdgeAndNode> res;
std::transform(edges.begin(), middle, std::back_inserter(res),
[](const El& x) { return std::get<3>(x); });
[](const El& x) { return std::get<5>(x); });
return res;
}

Expand Down
59 changes: 59 additions & 0 deletions src/utils/stats.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*
This file is part of Leela Chess Zero.
Copyright (C) 2019 The LCZero Authors

Leela Chess is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

Leela Chess is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with Leela Chess. If not, see <http://www.gnu.org/licenses/>.

Additional permission under GNU GPL version 3 section 7

If you modify this Program, or any covered work, by linking or
combining it with NVIDIA Corporation's libraries from the NVIDIA CUDA
Toolkit and the NVIDIA CUDA Deep Neural Network library (or a
modified version of those libraries), containing parts covered by the
terms of the respective license agreement, the licensors of this
Program grant you additional permission to convey the resulting work.
*/

#include "stats.h"
#include <boost/math/distributions/students_t.hpp>

namespace lczero {

namespace {
auto constexpr kzEntries = 1000;
std::array<float, kzEntries> kzLookup;
} // namespace

void CreatezTable(float ci_alpha) {
for (auto i = 1; i < kzEntries + 1; i++) {
boost::math::students_t dist(i);
auto z = boost::math::quantile(boost::math::complement(dist, ci_alpha));
kzLookup[i - 1] = z;
}
}

float CachedtQuantile(int v) {
if (v < 1) {
return kzLookup[0];
}
if (v < kzEntries) {
return kzLookup[v - 1];
}
// z approaches constant when v is high enough.
// With default lookup table size the function is flat enough that we
// can just return the last entry for all v bigger than it.
return kzLookup[kzEntries - 1];
}

} // namespace lczero
35 changes: 35 additions & 0 deletions src/utils/stats.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
This file is part of Leela Chess Zero.
Copyright (C) 2019 The LCZero Authors

Leela Chess is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

Leela Chess is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with Leela Chess. If not, see <http://www.gnu.org/licenses/>.

Additional permission under GNU GPL version 3 section 7

If you modify this Program, or any covered work, by linking or
combining it with NVIDIA Corporation's libraries from the NVIDIA CUDA
Toolkit and the NVIDIA CUDA Deep Neural Network library (or a
modified version of those libraries), containing parts covered by the
terms of the respective license agreement, the licensors of this
Program grant you additional permission to convey the resulting work.
*/

#pragma once

namespace lczero {

void CreatezTable(float ci_alpha);
float CachedtQuantile(int v);

} // namespace lczero