Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions libs/lczero-common
Submodule lczero-common added at b326b1
221 changes: 168 additions & 53 deletions src/trainingdata/rescorer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,13 @@
#include "trainingdata/rescorer.h"

#include <algorithm>
#include <initializer_list>
#include <optional>
#include <source_location>
#include <span>
#include <sstream>
#include <string>
#include <string_view>

#include "gtb-probe.h"
#include "neural/decoder.h"
Expand Down Expand Up @@ -120,8 +124,49 @@ bool deblunderEnabled = false;
float deblunderQBlunderThreshold = 2.0f;
float deblunderQBlunderWidth = 0.0f;

void DataAssert(bool check_result) {
if (!check_result) throw Exception("Range Violation");
struct RangeViolationDetail {
std::string name;
std::string value;
};

template <typename T>
RangeViolationDetail RangeDetail(std::string_view name, const T& value) {
std::ostringstream stream;
stream << std::boolalpha << value;
return {std::string(name), stream.str()};
}

void DataAssert(bool check_result, std::string_view condition_description = {},
std::initializer_list<RangeViolationDetail> details = {},
std::source_location loc = std::source_location::current()) {
if (check_result) return;

std::ostringstream message;
message << "Range Violation";
if (!condition_description.empty()) {
message << ": " << condition_description;
}
message << " @ " << loc.file_name() << ':' << loc.line();
if (const char* function_name = loc.function_name();
function_name && function_name[0] != '\0') {
message << " in " << function_name;
}
if (!details.size()) {
throw Exception(message.str());
}

message << " [";
bool first = true;
for (const auto& detail : details) {
if (!first) {
message << ", ";
}
first = false;
message << detail.name << '=' << detail.value;
}
message << ']';

throw Exception(message.str());
}

void Validate(std::span<const V6TrainingData> fileContents) {
Expand All @@ -131,82 +176,148 @@ void Validate(std::span<const V6TrainingData> fileContents) {
auto& data = fileContents[i];
DataAssert(
data.input_format ==
pblczero::NetworkFormat::INPUT_CLASSICAL_112_PLANE ||
data.input_format ==
pblczero::NetworkFormat::INPUT_112_WITH_CASTLING_PLANE ||
data.input_format ==
pblczero::NetworkFormat::INPUT_112_WITH_CANONICALIZATION ||
data.input_format == pblczero::NetworkFormat::
INPUT_112_WITH_CANONICALIZATION_HECTOPLIES ||
data.input_format ==
pblczero::NetworkFormat::
INPUT_112_WITH_CANONICALIZATION_HECTOPLIES_ARMAGEDDON ||
data.input_format ==
pblczero::NetworkFormat::INPUT_112_WITH_CANONICALIZATION_V2 ||
data.input_format == pblczero::NetworkFormat::
INPUT_112_WITH_CANONICALIZATION_V2_ARMAGEDDON);
DataAssert(data.best_d >= 0.0f && data.best_d <= 1.0f);
DataAssert(data.root_d >= 0.0f && data.root_d <= 1.0f);
DataAssert(data.best_q >= -1.0f && data.best_q <= 1.0f);
DataAssert(data.root_q >= -1.0f && data.root_q <= 1.0f);
DataAssert(data.root_m >= 0.0f);
DataAssert(data.best_m >= 0.0f);
DataAssert(data.plies_left >= 0.0f);
pblczero::NetworkFormat::INPUT_CLASSICAL_112_PLANE ||
data.input_format ==
pblczero::NetworkFormat::INPUT_112_WITH_CASTLING_PLANE ||
data.input_format ==
pblczero::NetworkFormat::INPUT_112_WITH_CANONICALIZATION ||
data.input_format ==
pblczero::NetworkFormat::
INPUT_112_WITH_CANONICALIZATION_HECTOPLIES ||
data.input_format ==
pblczero::NetworkFormat::
INPUT_112_WITH_CANONICALIZATION_HECTOPLIES_ARMAGEDDON ||
data.input_format ==
pblczero::NetworkFormat::INPUT_112_WITH_CANONICALIZATION_V2 ||
data.input_format ==
pblczero::NetworkFormat::
INPUT_112_WITH_CANONICALIZATION_V2_ARMAGEDDON,
"Unsupported input format",
{RangeDetail("data.input_format",
static_cast<int>(data.input_format))});
DataAssert(data.best_d >= 0.0f && data.best_d <= 1.0f,
"data.best_d outside [0.0, 1.0]",
{RangeDetail("data.best_d", data.best_d)});
DataAssert(data.root_d >= 0.0f && data.root_d <= 1.0f,
"data.root_d outside [0.0, 1.0]",
{RangeDetail("data.root_d", data.root_d)});
DataAssert(data.best_q >= -1.0f && data.best_q <= 1.0f,
"data.best_q outside [-1.0, 1.0]",
{RangeDetail("data.best_q", data.best_q)});
DataAssert(data.root_q >= -1.0f && data.root_q <= 1.0f,
"data.root_q outside [-1.0, 1.0]",
{RangeDetail("data.root_q", data.root_q)});
DataAssert(data.root_m >= 0.0f, "data.root_m negative",
{RangeDetail("data.root_m", data.root_m)});
DataAssert(data.best_m >= 0.0f, "data.best_m negative",
{RangeDetail("data.best_m", data.best_m)});
DataAssert(data.plies_left >= 0.0f, "data.plies_left negative",
{RangeDetail("data.plies_left", data.plies_left)});
switch (data.input_format) {
case pblczero::NetworkFormat::INPUT_CLASSICAL_112_PLANE:
DataAssert(data.castling_them_oo <= 1);
DataAssert(data.castling_them_ooo <= 1);
DataAssert(data.castling_us_oo <= 1);
DataAssert(data.castling_us_ooo <= 1);
DataAssert(
data.castling_them_oo <= 1,
"data.castling_them_oo exceeds classical encoding",
{RangeDetail("data.castling_them_oo", data.castling_them_oo)});
DataAssert(
data.castling_them_ooo <= 1,
"data.castling_them_ooo exceeds classical encoding",
{RangeDetail("data.castling_them_ooo", data.castling_them_ooo)});
DataAssert(data.castling_us_oo <= 1,
"data.castling_us_oo exceeds classical encoding",
{RangeDetail("data.castling_us_oo", data.castling_us_oo)});
DataAssert(data.castling_us_ooo <= 1,
"data.castling_us_ooo exceeds classical encoding",
{RangeDetail("data.castling_us_ooo", data.castling_us_ooo)});
break;
default:
// Verifiy at most one bit set.
DataAssert((data.castling_them_oo & (data.castling_them_oo - 1)) == 0);
DataAssert((data.castling_them_ooo & (data.castling_them_ooo - 1)) ==
0);
DataAssert((data.castling_us_oo & (data.castling_us_oo - 1)) == 0);
DataAssert((data.castling_us_ooo & (data.castling_us_ooo - 1)) == 0);
DataAssert(
(data.castling_them_oo & (data.castling_them_oo - 1)) == 0,
"data.castling_them_oo has multiple bits set",
{RangeDetail("data.castling_them_oo", data.castling_them_oo)});
DataAssert(
(data.castling_them_ooo & (data.castling_them_ooo - 1)) == 0,
"data.castling_them_ooo has multiple bits set",
{RangeDetail("data.castling_them_ooo", data.castling_them_ooo)});
DataAssert((data.castling_us_oo & (data.castling_us_oo - 1)) == 0,
"data.castling_us_oo has multiple bits set",
{RangeDetail("data.castling_us_oo", data.castling_us_oo)});
DataAssert((data.castling_us_ooo & (data.castling_us_ooo - 1)) == 0,
"data.castling_us_ooo has multiple bits set",
{RangeDetail("data.castling_us_ooo", data.castling_us_ooo)});
}
if (IsCanonicalFormat(static_cast<pblczero::NetworkFormat::InputFormat>(
data.input_format))) {
// At most one en-passant bit.
DataAssert((data.side_to_move_or_enpassant &
(data.side_to_move_or_enpassant - 1)) == 0);
(data.side_to_move_or_enpassant - 1)) == 0,
"side_to_move_or_enpassant has multiple bits set",
{RangeDetail("data.side_to_move_or_enpassant",
data.side_to_move_or_enpassant)});
} else {
DataAssert(data.side_to_move_or_enpassant <= 1);
DataAssert(data.side_to_move_or_enpassant <= 1,
"side_to_move_or_enpassant exceeds classical encoding",
{RangeDetail("data.side_to_move_or_enpassant",
data.side_to_move_or_enpassant)});
}
DataAssert(data.result_q >= -1 && data.result_q <= 1);
DataAssert(data.result_d >= 0 && data.result_q <= 1);
DataAssert(data.rule50_count <= 100);
DataAssert(data.result_q >= -1 && data.result_q <= 1,
"data.result_q outside [-1, 1]",
{RangeDetail("data.result_q", data.result_q)});
DataAssert(data.result_d >= 0 && data.result_q <= 1,
"data.result_d >= 0 && data.result_q <= 1",
{RangeDetail("data.result_d", data.result_d),
RangeDetail("data.result_q", data.result_q)});
DataAssert(data.rule50_count <= 100, "rule50_count exceeds 100",
{RangeDetail("data.rule50_count", data.rule50_count)});
float sum = 0.0f;
for (size_t j = 0; j < sizeof(data.probabilities) / sizeof(float); j++) {
float prob = data.probabilities[j];
DataAssert((prob >= 0.0f && prob <= 1.0f) || prob == -1.0f ||
std::isnan(prob));
DataAssert(
(prob >= 0.0f && prob <= 1.0f) || prob == -1.0f || std::isnan(prob),
"probabilities entry outside legal range",
{RangeDetail("index", j), RangeDetail("prob", prob)});
if (prob >= 0.0f) {
sum += prob;
}
// Only check best_idx/played_idx for real v6 data.
if (data.visits > 0) {
// Best_idx and played_idx must be marked legal in probabilities.
if (j == data.best_idx || j == data.played_idx) {
DataAssert(prob >= 0.0f);
DataAssert(prob >= 0.0f,
"probabilities entry for played/best move is illegal",
{RangeDetail("index", j), RangeDetail("prob", prob),
RangeDetail("data.best_idx", data.best_idx),
RangeDetail("data.played_idx", data.played_idx)});
}
}
}
if (sum < 0.99f || sum > 1.01f) {
throw Exception("Probability sum error is huge!");
}
DataAssert(data.best_idx <= 1858);
DataAssert(data.played_idx <= 1858);
DataAssert(data.played_q >= -1.0f && data.played_q <= 1.0f);
DataAssert(data.played_d >= 0.0f && data.played_d <= 1.0f);
DataAssert(data.played_m >= 0.0f);
DataAssert(data.best_idx <= 1858, "data.best_idx exceeds 1858",
{RangeDetail("data.best_idx", data.best_idx)});
DataAssert(data.played_idx <= 1858, "data.played_idx exceeds 1858",
{RangeDetail("data.played_idx", data.played_idx)});
DataAssert(data.played_q >= -1.0f && data.played_q <= 1.0f,
"data.played_q outside [-1.0, 1.0]",
{RangeDetail("data.played_q", data.played_q)});
DataAssert(data.played_d >= 0.0f && data.played_d <= 1.0f,
"data.played_d outside [0.0, 1.0]",
{RangeDetail("data.played_d", data.played_d)});
DataAssert(data.played_m >= 0.0f, "data.played_m negative",
{RangeDetail("data.played_m", data.played_m)});
DataAssert(std::isnan(data.orig_q) ||
(data.orig_q >= -1.0f && data.orig_q <= 1.0f));
DataAssert(std::isnan(data.orig_d) ||
(data.orig_d >= 0.0f && data.orig_d <= 1.0f));
DataAssert(std::isnan(data.orig_m) || data.orig_m >= 0.0f);
(data.orig_q >= -1.0f && data.orig_q <= 1.0f),
"data.orig_q outside [-1.0, 1.0]",
{RangeDetail("data.orig_q", data.orig_q)});
DataAssert(
std::isnan(data.orig_d) || (data.orig_d >= 0.0f && data.orig_d <= 1.0f),
"data.orig_d outside [0.0, 1.0]",
{RangeDetail("data.orig_d", data.orig_d)});
DataAssert(std::isnan(data.orig_m) || data.orig_m >= 0.0f,
"data.orig_m negative",
{RangeDetail("data.orig_m", data.orig_m)});
// TODO: if visits > 0 - assert best_idx/played_idx are valid in
// probabilities.
}
Expand Down Expand Up @@ -427,11 +538,15 @@ void ChangeInputFormat(int newInputFormat, V6TrainingData* data,
int ResultForData(const V6TrainingData& data) {
// Ensure we aren't reprocessing some data that has had custom adjustments to
// result training target applied.
DataAssert(data.result_q == -1.0f || data.result_q == 1.0f ||
data.result_q == 0.0f);
DataAssert(
data.result_q == -1.0f || data.result_q == 1.0f || data.result_q == 0.0f,
"data.result_q must be -1, 0, or 1",
{RangeDetail("data.result_q", data.result_q)});
// Paranoia - ensure int cast never breaks the value.
DataAssert(data.result_q ==
static_cast<float>(static_cast<int>(data.result_q)));
DataAssert(
data.result_q == static_cast<float>(static_cast<int>(data.result_q)),
"data.result_q loses precision when cast to int",
{RangeDetail("data.result_q", data.result_q)});
return static_cast<int>(data.result_q);
}

Expand Down