Skip to content

Commit

Permalink
Change semantics of maxvisits so it counts weightless visits
Browse files Browse the repository at this point in the history
  • Loading branch information
lightvector committed Aug 4, 2024
1 parent 827f755 commit c2efacc
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 22 deletions.
15 changes: 14 additions & 1 deletion cpp/search/search.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ SearchThread::SearchThread(int tIdx, const Search& search)
history(search.rootHistory),
graphHash(search.rootGraphHash),
graphPath(),
shouldCountPlayout(false),
rand(makeSeed(search,tIdx)),
nnResultBuf(),
statsBuf(),
Expand Down Expand Up @@ -1102,7 +1103,11 @@ bool Search::runSinglePlayout(SearchThread& thread, double upperBoundVisitsLeft)
//Store this value, used for futile-visit pruning this thread's root children selections.
thread.upperBoundVisitsLeft = upperBoundVisitsLeft;

//Prep this value, playoutDescend will set it to true if we do have a playout
thread.shouldCountPlayout = false;

bool finishedPlayout = playoutDescend(thread,*rootNode,true);
(void)finishedPlayout;

//Restore thread state back to the root state
thread.pla = rootPla;
Expand All @@ -1111,7 +1116,7 @@ bool Search::runSinglePlayout(SearchThread& thread, double upperBoundVisitsLeft)
thread.graphHash = rootGraphHash;
thread.graphPath.clear();

return finishedPlayout;
return thread.shouldCountPlayout;
}

bool Search::playoutDescend(
Expand All @@ -1136,6 +1141,7 @@ bool Search::playoutDescend(
double lead = 0.0;
double weight = (searchParams.useUncertainty && nnEvaluator->supportsShorttermError()) ? searchParams.uncertaintyMaxWeight : 1.0;
addLeafValue(node, winLossValue, noResultValue, scoreMean, scoreMeanSq, lead, weight, true, false);
thread.shouldCountPlayout = true;
return true;
}
else {
Expand All @@ -1146,6 +1152,7 @@ bool Search::playoutDescend(
double lead = scoreMean;
double weight = (searchParams.useUncertainty && nnEvaluator->supportsShorttermError()) ? searchParams.uncertaintyMaxWeight : 1.0;
addLeafValue(node, winLossValue, noResultValue, scoreMean, scoreMeanSq, lead, weight, true, false);
thread.shouldCountPlayout = true;
return true;
}
}
Expand All @@ -1171,6 +1178,7 @@ bool Search::playoutDescend(
//Perform the nn evaluation and finish!
node.initializeChildren();
node.state.store(SearchNode::STATE_EXPANDED0, std::memory_order_seq_cst);
thread.shouldCountPlayout = true;
return true;
}
}
Expand Down Expand Up @@ -1221,13 +1229,15 @@ bool Search::playoutDescend(
//Return TRUE though, so that the parent path we traversed increments its edge visits.
//We want the search to continue as best it can, so we increment visits so search will still make progress
//even if this keeps happening in some really bad transposition or something.
thread.shouldCountPlayout = true;
return true;
}

if(bestChildIdx <= -1) {
//This might happen if all moves have been forbidden. The node will just get stuck counting visits without expanding
//and we won't do any search.
addCurrentNNOutputAsLeafValue(node,false);
thread.shouldCountPlayout = true;
return true;
}

Expand Down Expand Up @@ -1295,6 +1305,7 @@ bool Search::playoutDescend(
if(countEdgeVisit && maybeCatchUpEdgeVisits(thread, node, child, nodeState, bestChildIdx)) {
updateStatsAfterPlayout(node,thread,isRoot);
child->virtualLosses.fetch_add(-1,std::memory_order_release);
thread.shouldCountPlayout = true;
return true;
}
}
Expand All @@ -1312,6 +1323,7 @@ bool Search::playoutDescend(
if(countEdgeVisit && maybeCatchUpEdgeVisits(thread, node, child, nodeState, bestChildIdx)) {
updateStatsAfterPlayout(node,thread,isRoot);
child->virtualLosses.fetch_add(-1,std::memory_order_release);
thread.shouldCountPlayout = true;
return true;
}

Expand Down Expand Up @@ -1339,6 +1351,7 @@ bool Search::playoutDescend(
SearchNodeChildrenReference children = node.getChildren(nodeState);
children[bestChildIdx].addEdgeVisits(1);
updateStatsAfterPlayout(node,thread,isRoot);
thread.shouldCountPlayout = true;
}
child->virtualLosses.fetch_add(-1,std::memory_order_release);
// If we didn't count an edge visit, none of the parents need to update either.
Expand Down
4 changes: 4 additions & 0 deletions cpp/search/search.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ struct SearchThread {
//The path we trace down the graph as we do a playout
std::unordered_set<SearchNode*> graphPath;

//Tracks whether this thread did something that "should" be counted as a playout
//for the purpose of playout limits
bool shouldCountPlayout;

Rand rand;

NNResultBuf nnResultBuf;
Expand Down
Loading

0 comments on commit c2efacc

Please sign in to comment.