Skip to content

Commit

Permalink
[Shogi] Make core singlefile (#1273)
Browse files Browse the repository at this point in the history
  • Loading branch information
sotetsuk authored Nov 1, 2024
1 parent 1e08b48 commit ffe939e
Show file tree
Hide file tree
Showing 3 changed files with 277 additions and 284 deletions.
280 changes: 267 additions & 13 deletions pgx/_src/games/shogi.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,26 +16,281 @@
from typing import NamedTuple
from functools import partial

import numpy as np
import jax
from jax import Array
import jax.numpy as jnp

from pgx._src.shogi_utils import (
AROUND_IX,
BETWEEN_IX,
CAN_MOVE,
CAN_MOVE_ANY,
INIT_PIECE_BOARD,
LEGAL_FROM_IDX,
NEIGHBOUR_IX,
)
from pgx._src.struct import dataclass
from pgx._src.types import Array

MAX_TERMINATION_STEPS = 512  # From AZ paper

TRUE = jnp.bool_(True)
FALSE = jnp.bool_(False)


# Piece-type ids as plain Python ints.  They are compared against the
# `piece` argument of the table-building helpers in this module.
EMPTY = -1  # empty square
PAWN = 0  # pawn
LANCE = 1  # lance
KNIGHT = 2  # knight
SILVER = 3  # silver general
BISHOP = 4  # bishop
ROOK = 5  # rook
GOLD = 6  # gold general
KING = 7  # king
PRO_PAWN = 8  # promoted pawn (tokin)
PRO_LANCE = 9  # promoted lance
PRO_KNIGHT = 10  # promoted knight
PRO_SILVER = 11  # promoted silver
HORSE = 12  # horse (promoted bishop)
DRAGON = 13  # dragon (promoted rook)


# fmt: off
# Initial piece layout: a 9x9 grid of int32 piece codes, flattened to length 81.
# -1 is the EMPTY code and 0..7 are the unpromoted piece ids (PAWN..KING).
# NOTE(review): codes 14..21 look like the same piece ids offset by 14 for the
# other player -- confirm against the code that consumes this table.
INIT_PIECE_BOARD = jnp.int32([[15, -1, 14, -1, -1, -1, 0, -1, 1],  # noqa: E241
                              [16, 18, 14, -1, -1, -1, 0, 5, 2],  # noqa: E241
                              [17, -1, 14, -1, -1, -1, 0, -1, 3],  # noqa: E241
                              [20, -1, 14, -1, -1, -1, 0, -1, 6],  # noqa: E241
                              [21, -1, 14, -1, -1, -1, 0, -1, 7],  # noqa: E241
                              [20, -1, 14, -1, -1, -1, 0, -1, 6],  # noqa: E241
                              [17, -1, 14, -1, -1, -1, 0, -1, 3],  # noqa: E241
                              [16, 19, 14, -1, -1, -1, 0, 4, 2],  # noqa: E241
                              [15, -1, 14, -1, -1, -1, 0, -1, 1]]).flatten()  # noqa: E241
# fmt: on


# Can <piece,14> reach from <from,81> to <to,81> ignoring pieces on board?
def can_move_to(piece, from_, to):
    """Tell whether <piece> can go from <from_> to <to> on an empty board.

    Squares are integers that decompose as x = sq // 9 and y = sq % 9; the
    move is described by the offset (dx, dy) between the two squares.  Other
    pieces on the board are ignored, so sliding pieces are never blocked here.

    Returns False when <from_> equals <to>.  Fails an assertion when <piece>
    is not one of the 14 known piece ids.
    """
    if from_ == to:
        return False

    dx = to // 9 - from_ // 9
    dy = to % 9 - from_ % 9

    # Geometry shared by several pieces.
    within_one = abs(dx) <= 1 and abs(dy) <= 1
    on_diagonal = dx == dy or dx == -dy
    on_line = dx == 0 or dy == 0

    if piece == PAWN:
        return dx == 0 and dy == -1
    if piece == LANCE:
        return dx == 0 and dy < 0
    if piece == KNIGHT:
        return dx in (-1, 1) and dy == -2
    if piece == SILVER:
        forward = dx in (-1, 0, 1) and dy == -1
        back_diagonal = dx in (-1, 1) and dy == 1
        return forward or back_diagonal
    if piece == BISHOP:
        return on_diagonal
    if piece == ROOK:
        return on_line
    if piece in (GOLD, PRO_PAWN, PRO_LANCE, PRO_KNIGHT, PRO_SILVER):
        # Promoted minor pieces share the gold's move set.
        forward_or_side = dx in (-1, 0, 1) and dy in (0, -1)
        straight_back = dx == 0 and dy == 1
        return forward_or_side or straight_back
    if piece == KING:
        return within_one
    if piece == HORSE:
        return within_one or on_diagonal
    if piece == DRAGON:
        return within_one or on_line

    assert False


def is_on_the_way(piece, from_, to, point):
    """Tell whether <point> lies strictly between <from_> and <to> for <piece>.

    Only the sliding pieces (LANCE, BISHOP, ROOK, HORSE, DRAGON) can have
    squares "on the way"; every other piece yields False.  Both <to> and
    <point> must be reachable from <from_> according to `can_move_to`,
    <point> must lie in the same direction as <to>, and it must not be
    farther away than <to>.  <to> itself is never on the way.

    Squares are integers that decompose as x = sq // 9 and y = sq % 9.
    """
    if to == point:
        return False
    if piece not in (LANCE, BISHOP, ROOK, HORSE, DRAGON):
        return False
    if not can_move_to(piece, from_, to):
        return False
    if not can_move_to(piece, from_, point):
        return False

    x0, y0 = from_ // 9, from_ % 9
    x1, y1 = to // 9, to % 9
    x2, y2 = point // 9, point % 9
    dx1, dy1 = x1 - x0, y1 - y0
    dx2, dy2 = x2 - x0, y2 - y0

    def sign(d):
        # This has to be a genuine three-way sign (-1, 0, +1).  Returning the
        # bool `d > 0` for non-zero d makes a negative offset (False) compare
        # equal to a zero offset (0), which marks off-ray squares as "between"
        # for HORSE and DRAGON -- e.g. the square at offset (-1, 0) would
        # count as on the way to offset (-2, -2).
        return (d > 0) - (d < 0)

    if sign(dx1) != sign(dx2) or sign(dy1) != sign(dy2):
        return False

    return abs(dx2) <= abs(dx1) and abs(dy2) <= abs(dy1)


# CAN_MOVE[piece, from, to]: result of can_move_to for every combination.
CAN_MOVE = np.zeros((14, 81, 81), dtype=np.bool_)
for piece in range(14):
    for from_ in range(81):
        for to in range(81):
            CAN_MOVE[piece, from_, to] = can_move_to(piece, from_, to)

assert CAN_MOVE.sum() == 8228
CAN_MOVE = jnp.array(CAN_MOVE)


# When <lance/bishop/rook/horse/dragon,5> moves from <from,81> to <to,81>,
# is <point,81> on the way between two points?
BETWEEN = np.zeros((5, 81, 81, 81), dtype=np.bool_)
for i, piece in enumerate((LANCE, BISHOP, ROOK, HORSE, DRAGON)):
    for from_ in range(81):
        for to in range(81):
            for p in range(81):
                BETWEEN[i, from_, to, p] = is_on_the_way(piece, from_, to, p)

# Checksum of strictly-between squares on a 9x9 board:
# lance 756 + bishop 1344 + rook 3024 + horse 1344 + dragon 3024.
# (King-like single steps of HORSE/DRAGON have nothing in between.)
assert BETWEEN.sum() == 9492
BETWEEN = jnp.array(BETWEEN)


# Given <dir,10> and <to,81>, return the legal <from> idx
# E.g. LEGAL_FROM_IDX[Up, to=19] = [20, 21, ..., -1] (filled by -1)
# Used for computing dlshogi action
#
# dir, to, from
# (10, 81, 8)
#
# 0 Up
# 1 Up left
# 2 Up right
# 3 Left
# 4 Right
# 5 Down
# 6 Down left
# 7 Down right
# 8 Up2 left
# 9 Up2 right

# Step (dx, dy) applied repeatedly to <to> in order to enumerate candidate
# <from> squares, indexed by direction id.
_FROM_STEPS = (
    (0, +1),  # 0 Up
    (-1, +1),  # 1 Up left
    (+1, +1),  # 2 Up right
    (-1, 0),  # 3 Left
    (+1, 0),  # 4 Right
    (0, -1),  # 5 Down
    (-1, -1),  # 6 Down left
    (+1, -1),  # 7 Down right
    (-1, +2),  # 8 Up2 left
    (+1, +2),  # 9 Up2 right
)

# Start with every slot set to the -1 filler, then write the on-board squares.
LEGAL_FROM_IDX = -np.ones((10, 81, 8), dtype=jnp.int32)  # type: ignore

for dir_, (dx, dy) in enumerate(_FROM_STEPS):
    # The two "Up2" directions contribute at most one <from> square.
    one_hop_only = dir_ in (8, 9)
    for to in range(81):
        x, y = divmod(to, 9)
        for i in range(8):
            x += dx
            y += dy
            if not (0 <= x <= 8 and 0 <= y <= 8):
                break  # walked off the board; remaining slots stay -1
            LEGAL_FROM_IDX[dir_, to, i] = x * 9 + y
            if one_hop_only:
                break

LEGAL_FROM_IDX = jnp.array(LEGAL_FROM_IDX)  # type: ignore


@jax.jit
@jax.vmap
def can_move_any_ix(from_):
    """List the squares linked to <from_> by any piece's move, in either direction.

    A square s is included when CAN_MOVE says some piece can go from <from_>
    to s, or from s to <from_>.  The result always has length 36; unused
    slots hold -1.  The function is vmapped, so it takes an array of squares.
    """
    leaving = CAN_MOVE[:, from_, :]
    arriving = CAN_MOVE[:, :, from_]
    linked = jnp.any(leaving | arriving, axis=0)
    (squares,) = jnp.nonzero(linked, size=36, fill_value=-1)
    return squares


@jax.jit
@jax.vmap
def neighbour_ix(from_):
    """List the squares a king on <from_> can reach or a knight can reach <from_> from.

    The result always has length 10; unused slots hold -1.  The function is
    vmapped, so it takes an array of squares.
    """
    king_targets = CAN_MOVE[7, from_, :]  # 7 is the KING piece id
    knight_sources = CAN_MOVE[2, :, from_]  # 2 is the KNIGHT piece id
    (squares,) = jnp.nonzero(
        king_targets | knight_sources, size=10, fill_value=-1
    )
    return squares


# One row of neighbour squares per board square.
NEIGHBOUR_IX = neighbour_ix(jnp.arange(81))


def between_ix(p, from_, to):
    """List the squares flagged in BETWEEN[p, from_, to], padded with -1 to length 8."""
    (squares,) = jnp.nonzero(BETWEEN[p, from_, to], size=8, fill_value=-1)
    return squares


# Batch between_ix over <to>, then <from>, then the piece index, so the
# resulting table is indexed as [piece index, from, to, slot].
_between_per_to = jax.vmap(between_ix, in_axes=(None, None, 0))
_between_per_from = jax.vmap(_between_per_to, in_axes=(None, 0, None))
_between_per_piece = jax.vmap(_between_per_from, in_axes=(0, None, None))

BETWEEN_IX = jax.jit(_between_per_piece)(
    jnp.arange(5), jnp.arange(81), jnp.arange(81)
)


# For each of the 81 squares, the -1-padded list of squares returned by
# can_move_any_ix.
CAN_MOVE_ANY = can_move_any_ix(jnp.arange(81))  # (81, 36)


def _around(c):
    """Return the 8 squares surrounding square <c>, with -1 for off-board ones.

    Squares decompose as x = c // 9 and y = c % 9.  The neighbours are listed
    in the fixed order given by the (dx, dy) offsets below.
    """
    step_x = jnp.int32([-1, -1, 0, +1, +1, +1, 0, -1])
    step_y = jnp.int32([0, -1, -1, -1, 0, +1, +1, +1])

    # Compute all eight neighbours at once.
    near_x = c // 9 + step_x
    near_y = c % 9 + step_y
    on_board = (near_x >= 0) & (near_x < 9) & (near_y >= 0) & (near_y < 9)
    return jnp.where(on_board, near_x * 9 + near_y, -1)


# One row of 8 surrounding squares per board square.
AROUND_IX = jax.vmap(_around)(jnp.arange(81))


# NOTE: these rebind the plain-int piece ids of the same names defined
# earlier in this file as jnp.int32 scalars.
EMPTY = jnp.int32(-1)  # empty square
PAWN = jnp.int32(0)  # pawn
LANCE = jnp.int32(1)  # lance
Expand Down Expand Up @@ -94,8 +349,7 @@ def legal_action_mask(self, state: GameState) -> Array:
return _legal_action_mask(state)


@dataclass
class Action:
class Action(NamedTuple):
is_drop: Array
piece: Array
to: Array
Expand Down
Loading

0 comments on commit ffe939e

Please sign in to comment.