Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions src/main/scala/game/Encoder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,8 @@ object Encoder:
if 0 < i then if board.isCheck() then output(i - 1) += (if legals.isEmpty then "#" else "+")

if i < plies then
val moveIndex = Huffman.read(reader)
legals.partialSort(moveIndex + 1)
val move = legals.get(moveIndex)
val rank = Huffman.read(reader)
val move = legals.selectRank(rank)
output(i) = san(move, legals)
board.play(move)

Expand Down
54 changes: 23 additions & 31 deletions src/main/scala/game/MoveList.scala
Original file line number Diff line number Diff line change
Expand Up @@ -51,37 +51,29 @@ final class MoveList(capacity: Int = 256):
buffer(i) = buffer(j)
buffer(j) = tmp

def partialSort(last: Int): Unit =
require(last <= size)
makeHeap(last)
for i <- last until size do
if buffer(i) < buffer(0) then
swap(0, i)
adjustHeap(0, last)
sortHeap(last)

private def makeHeap(last: Int): Unit =
for parent <- last / 2 until 0 by -1 do adjustHeap(parent - 1, last)

private def adjustHeap(holeIndex: Int, len: Int): Unit =
require(len <= size)
require(holeIndex < size)
var leftChild = holeIndex * 2 + 1
var holeDest = holeIndex
val tmp = buffer(holeDest)
while leftChild < len do
if leftChild + 1 < len && buffer(leftChild) < buffer(leftChild + 1) then leftChild += 1
if tmp < buffer(leftChild) then
buffer(holeDest) = buffer(leftChild)
holeDest = leftChild
leftChild = leftChild * 2 + 1
else leftChild = len
buffer(holeDest) = tmp

private def sortHeap(last: Int): Unit =
for i <- last - 1 until 0 by -1 do
swap(0, i)
adjustHeap(0, i)
def selectRank(rank: Int): Move =
require(rank < size)
// Quickselect. Bounds are small enough that naive pivot selection is ok,
// even on adversarial inputs.
var left = 0
var right = size - 1
while left < right do
val pivot = partition(left, right)
if pivot == rank then return buffer(rank)
if pivot < rank then left = pivot + 1
else right = pivot - 1
buffer(rank)

private def partition(left: Int, right: Int): Int =
val pivot = buffer(right)
var i = left - 1
for j <- left until right do
if buffer(j) < pivot then
i += 1
swap(i, j)
i += 1
swap(i, right)
i

def pretty(): String =
val builder = StringBuilder()
Expand Down