diff --git a/pgx/_src/games/shogi.py b/pgx/_src/games/shogi.py index 3583e84b2..1a1b39f61 100644 --- a/pgx/_src/games/shogi.py +++ b/pgx/_src/games/shogi.py @@ -16,26 +16,281 @@ from typing import NamedTuple from functools import partial +import numpy as np import jax +from jax import Array import jax.numpy as jnp -from pgx._src.shogi_utils import ( - AROUND_IX, - BETWEEN_IX, - CAN_MOVE, - CAN_MOVE_ANY, - INIT_PIECE_BOARD, - LEGAL_FROM_IDX, - NEIGHBOUR_IX, -) -from pgx._src.struct import dataclass -from pgx._src.types import Array MAX_TERMINATION_STEPS = 512 # From AZ paper TRUE = jnp.bool_(True) FALSE = jnp.bool_(False) + +EMPTY = -1 # 空白 +PAWN = 0 # 歩 +LANCE = 1 # 香 +KNIGHT = 2 # 桂 +SILVER = 3 # 銀 +BISHOP = 4 # 角 +ROOK = 5 # 飛 +GOLD = 6 # 金 +KING = 7 # 玉 +PRO_PAWN = 8 # と +PRO_LANCE = 9 # 成香 +PRO_KNIGHT = 10 # 成桂 +PRO_SILVER = 11 # 成銀 +HORSE = 12 # 馬 +DRAGON = 13 # 龍 + + +# fmt: off +INIT_PIECE_BOARD = jnp.int32([[15, -1, 14, -1, -1, -1, 0, -1, 1], # noqa: E241 + [16, 18, 14, -1, -1, -1, 0, 5, 2], # noqa: E241 + [17, -1, 14, -1, -1, -1, 0, -1, 3], # noqa: E241 + [20, -1, 14, -1, -1, -1, 0, -1, 6], # noqa: E241 + [21, -1, 14, -1, -1, -1, 0, -1, 7], # noqa: E241 + [20, -1, 14, -1, -1, -1, 0, -1, 6], # noqa: E241 + [17, -1, 14, -1, -1, -1, 0, -1, 3], # noqa: E241 + [16, 19, 14, -1, -1, -1, 0, 4, 2], # noqa: E241 + [15, -1, 14, -1, -1, -1, 0, -1, 1]]).flatten() # noqa: E241 +# fmt: on + + +# Can reach from to ignoring pieces on board? +def can_move_to(piece, from_, to): + """Can move from to ?""" + if from_ == to: + return False + x0, y0 = from_ // 9, from_ % 9 + x1, y1 = to // 9, to % 9 + dx = x1 - x0 + dy = y1 - y0 + if piece == PAWN: + if dx == 0 and dy == -1: + return True + else: + return False + elif piece == LANCE: + if dx == 0 and dy < 0: + return True + else: + return False + elif piece == KNIGHT: + if dx in (-1, 1) and dy == -2: + return True + else: + return False + elif piece == SILVER: + if dx in (-1, 0, 1) and dy == -1: + return True + elif dx in (-1, 1) and dy == 1: + return True + else: + return False + elif piece == BISHOP: + if dx == dy or dx == -dy: + return True + else: + return False + elif piece == ROOK: + if dx == 0 or dy == 0: + return True + else: + return False + if piece in (GOLD, PRO_PAWN, PRO_LANCE, PRO_KNIGHT, PRO_SILVER): + if dx in (-1, 0, 1) and dy in (0, -1): + return True + elif dx == 0 and dy == 1: + return True + else: + return False + elif piece == KING: + if abs(dx) <= 1 and abs(dy) <= 1: + return True + else: + return False + elif piece == HORSE: + if abs(dx) <= 1 and abs(dy) <= 1: + return True + elif dx == dy or dx == -dy: + return True + else: + return False + elif piece == DRAGON: + if abs(dx) <= 1 and abs(dy) <= 1: + return True + if dx == 0 or dy == 0: + return True + else: + return False + else: + assert False + + +def is_on_the_way(piece, from_, to, point): + if to == point: + return False + if piece not in (LANCE, BISHOP, ROOK, HORSE, DRAGON): + return False + if not can_move_to(piece, from_, to): + return False + if not can_move_to(piece, from_, point): + return False + + x0, y0 = from_ // 9, from_ % 9 + x1, y1 = to // 9, to % 9 + x2, y2 = point // 9, point % 9 + dx1, dy1 = x1 - x0, y1 - y0 + dx2, dy2 = x2 - x0, y2 - y0 + + def sign(d): + if d == 0: + return 0 + return d > 0 + + if (sign(dx1) != sign(dx2)) or (sign(dy1) != sign(dy2)): + return False + + return abs(dx2) <= abs(dx1) and abs(dy2) <= abs(dy1) + + +CAN_MOVE = np.zeros((14, 81, 81), dtype=jnp.bool_) +for piece in range(14): + for from_ in range(81): + for to in range(81): + CAN_MOVE[piece, from_, to] = can_move_to(piece, from_, to) + +assert CAN_MOVE.sum() == 8228 +CAN_MOVE = jnp.array(CAN_MOVE) + + +# When moves from to , +# is on the way between two points? +BETWEEN = np.zeros((5, 81, 81, 81), dtype=np.bool_) +for i, piece in enumerate((LANCE, BISHOP, ROOK, HORSE, DRAGON)): + for from_ in range(81): + for to in range(81): + for p in range(81): + BETWEEN[i, from_, to, p] = is_on_the_way(piece, from_, to, p) + +BETWEEN = jnp.array(BETWEEN) +assert BETWEEN.sum() == 10564 + + +# Give and , return the legal idx +# E.g. LEGAL_FROM_IDX[Up, to=19] = [20, 21, ..., -1] (filled by -1) +# Used for computing dlshogi action +# +# dir, to, from +# (10, 81, 81) +# +# 0 Up +# 1 Up left +# 2 Up right +# 3 Left +# 4 Right +# 5 Down +# 6 Down left +# 7 Down right +# 8 Up2 left +# 9 Up2 right + +LEGAL_FROM_IDX = -np.ones((10, 81, 8), dtype=jnp.int32) # type: ignore + +for dir_ in range(10): + for to in range(81): + x, y = to // 9, to % 9 + if dir_ == 0: # Up + dx, dy = 0, +1 + elif dir_ == 1: # Up left + dx, dy = -1, +1 + elif dir_ == 2: # Up right + dx, dy = +1, +1 + elif dir_ == 3: # Left + dx, dy = -1, 0 + elif dir_ == 4: # Right + dx, dy = +1, 0 + elif dir_ == 5: # Down + dx, dy = 0, -1 + elif dir_ == 6: # Down left + dx, dy = -1, -1 + elif dir_ == 7: # Down right + dx, dy = +1, -1 + elif dir_ == 8: # Up2 left + dx, dy = -1, +2 + elif dir_ == 9: # Up2 right + dx, dy = +1, +2 + for i in range(8): + x += dx + y += dy + if x < 0 or 8 < x or y < 0 or 8 < y: + break + LEGAL_FROM_IDX[dir_, to, i] = x * 9 + y + if dir_ == 8 or dir_ == 9: + break + +LEGAL_FROM_IDX = jnp.array(LEGAL_FROM_IDX) # type: ignore + + +@jax.jit +@jax.vmap +def can_move_any_ix(from_): + return jnp.nonzero( + (CAN_MOVE[:, from_, :] | CAN_MOVE[:, :, from_]).any(axis=0), + size=36, + fill_value=-1, + )[0] + + +@jax.jit +@jax.vmap +def neighbour_ix(from_): + return jnp.nonzero( + (CAN_MOVE[7, from_, :] | CAN_MOVE[2, :, from_]), + size=10, + fill_value=-1, + )[0] + + +NEIGHBOUR_IX = neighbour_ix(jnp.arange(81)) + + +def between_ix(p, from_, to): + return jnp.nonzero(BETWEEN[p, from_, to], size=8, fill_value=-1)[0] + + +BETWEEN_IX = jax.jit( + jax.vmap( + jax.vmap(jax.vmap(between_ix, (None, None, 0)), (None, 0, None)), + (0, None, None), + ) +)(jnp.arange(5), jnp.arange(81), jnp.arange(81)) + + +CAN_MOVE_ANY = can_move_any_ix(jnp.arange(81)) # (81, 36) + + +def _around(c): + x, y = c // 9, c % 9 + dx = jnp.int32([-1, -1, 0, +1, +1, +1, 0, -1]) + dy = jnp.int32([0, -1, -1, -1, 0, +1, +1, +1]) + + def f(i): + new_x, new_y = x + dx[i], y + dy[i] + return jax.lax.select( + (new_x < 0) | (new_x >= 9) | (new_y < 0) | (new_y >= 9), + -1, + new_x * 9 + new_y, + ) + + return jax.vmap(f)(jnp.arange(8)) + + +AROUND_IX = jax.vmap(_around)(jnp.arange(81)) + + EMPTY = jnp.int32(-1) # 空白 PAWN = jnp.int32(0) # 歩 LANCE = jnp.int32(1) # 香 @@ -94,8 +349,7 @@ def legal_action_mask(self, state: GameState) -> Array: return _legal_action_mask(state) -@dataclass -class Action: +class Action(NamedTuple): is_drop: Array piece: Array to: Array diff --git a/pgx/_src/shogi_utils.py b/pgx/_src/shogi_utils.py index 3247548e7..fac45bb7d 100644 --- a/pgx/_src/shogi_utils.py +++ b/pgx/_src/shogi_utils.py @@ -19,276 +19,6 @@ import jax.numpy as jnp import numpy as np -# fmt: off -INIT_PIECE_BOARD = jnp.int32([[15, -1, 14, -1, -1, -1, 0, -1, 1], # noqa: E241 - [16, 18, 14, -1, -1, -1, 0, 5, 2], # noqa: E241 - [17, -1, 14, -1, -1, -1, 0, -1, 3], # noqa: E241 - [20, -1, 14, -1, -1, -1, 0, -1, 6], # noqa: E241 - [21, -1, 14, -1, -1, -1, 0, -1, 7], # noqa: E241 - [20, -1, 14, -1, -1, -1, 0, -1, 6], # noqa: E241 - [17, -1, 14, -1, -1, -1, 0, -1, 3], # noqa: E241 - [16, 19, 14, -1, -1, -1, 0, 4, 2], # noqa: E241 - [15, -1, 14, -1, -1, -1, 0, -1, 1]]).flatten() # noqa: E241 -# fmt: on - -EMPTY = -1 # 空白 -PAWN = 0 # 歩 -LANCE = 1 # 香 -KNIGHT = 2 # 桂 -SILVER = 3 # 銀 -BISHOP = 4 # 角 -ROOK = 5 # 飛 -GOLD = 6 # 金 -KING = 7 # 玉 -PRO_PAWN = 8 # と -PRO_LANCE = 9 # 成香 -PRO_KNIGHT = 10 # 成桂 -PRO_SILVER = 11 # 成銀 -HORSE = 12 # 馬 -DRAGON = 13 # 龍 - - -# Can reach from to ignoring pieces on board? -def can_move_to(piece, from_, to): - """Can move from to ?""" - if from_ == to: - return False - x0, y0 = from_ // 9, from_ % 9 - x1, y1 = to // 9, to % 9 - dx = x1 - x0 - dy = y1 - y0 - if piece == PAWN: - if dx == 0 and dy == -1: - return True - else: - return False - elif piece == LANCE: - if dx == 0 and dy < 0: - return True - else: - return False - elif piece == KNIGHT: - if dx in (-1, 1) and dy == -2: - return True - else: - return False - elif piece == SILVER: - if dx in (-1, 0, 1) and dy == -1: - return True - elif dx in (-1, 1) and dy == 1: - return True - else: - return False - elif piece == BISHOP: - if dx == dy or dx == -dy: - return True - else: - return False - elif piece == ROOK: - if dx == 0 or dy == 0: - return True - else: - return False - if piece in (GOLD, PRO_PAWN, PRO_LANCE, PRO_KNIGHT, PRO_SILVER): - if dx in (-1, 0, 1) and dy in (0, -1): - return True - elif dx == 0 and dy == 1: - return True - else: - return False - elif piece == KING: - if abs(dx) <= 1 and abs(dy) <= 1: - return True - else: - return False - elif piece == HORSE: - if abs(dx) <= 1 and abs(dy) <= 1: - return True - elif dx == dy or dx == -dy: - return True - else: - return False - elif piece == DRAGON: - if abs(dx) <= 1 and abs(dy) <= 1: - return True - if dx == 0 or dy == 0: - return True - else: - return False - else: - assert False - - -def is_on_the_way(piece, from_, to, point): - if to == point: - return False - if piece not in (LANCE, BISHOP, ROOK, HORSE, DRAGON): - return False - if not can_move_to(piece, from_, to): - return False - if not can_move_to(piece, from_, point): - return False - - x0, y0 = from_ // 9, from_ % 9 - x1, y1 = to // 9, to % 9 - x2, y2 = point // 9, point % 9 - dx1, dy1 = x1 - x0, y1 - y0 - dx2, dy2 = x2 - x0, y2 - y0 - - def sign(d): - if d == 0: - return 0 - return d > 0 - - if (sign(dx1) != sign(dx2)) or (sign(dy1) != sign(dy2)): - return False - - return abs(dx2) <= abs(dx1) and abs(dy2) <= abs(dy1) - - -CAN_MOVE = np.zeros((14, 81, 81), dtype=jnp.bool_) -for piece in range(14): - for from_ in range(81): - for to in range(81): - CAN_MOVE[piece, from_, to] = can_move_to(piece, from_, to) - -assert CAN_MOVE.sum() == 8228 -CAN_MOVE = jnp.array(CAN_MOVE) - - -# When moves from to , -# is on the way between two points? -BETWEEN = np.zeros((5, 81, 81, 81), dtype=np.bool_) -for i, piece in enumerate((LANCE, BISHOP, ROOK, HORSE, DRAGON)): - for from_ in range(81): - for to in range(81): - for p in range(81): - BETWEEN[i, from_, to, p] = is_on_the_way(piece, from_, to, p) - -BETWEEN = jnp.array(BETWEEN) -assert BETWEEN.sum() == 10564 - - -# Give and , return the legal idx -# E.g. LEGAL_FROM_IDX[Up, to=19] = [20, 21, ..., -1] (filled by -1) -# Used for computing dlshogi action -# -# dir, to, from -# (10, 81, 81) -# -# 0 Up -# 1 Up left -# 2 Up right -# 3 Left -# 4 Right -# 5 Down -# 6 Down left -# 7 Down right -# 8 Up2 left -# 9 Up2 right - -LEGAL_FROM_IDX = -np.ones((10, 81, 8), dtype=jnp.int32) # type: ignore - -for dir_ in range(10): - for to in range(81): - x, y = to // 9, to % 9 - if dir_ == 0: # Up - dx, dy = 0, +1 - elif dir_ == 1: # Up left - dx, dy = -1, +1 - elif dir_ == 2: # Up right - dx, dy = +1, +1 - elif dir_ == 3: # Left - dx, dy = -1, 0 - elif dir_ == 4: # Right - dx, dy = +1, 0 - elif dir_ == 5: # Down - dx, dy = 0, -1 - elif dir_ == 6: # Down left - dx, dy = -1, -1 - elif dir_ == 7: # Down right - dx, dy = +1, -1 - elif dir_ == 8: # Up2 left - dx, dy = -1, +2 - elif dir_ == 9: # Up2 right - dx, dy = +1, +2 - for i in range(8): - x += dx - y += dy - if x < 0 or 8 < x or y < 0 or 8 < y: - break - LEGAL_FROM_IDX[dir_, to, i] = x * 9 + y - if dir_ == 8 or dir_ == 9: - break - -LEGAL_FROM_IDX = jnp.array(LEGAL_FROM_IDX) # type: ignore - - -@jax.jit -@jax.vmap -def can_move_any_ix(from_): - return jnp.nonzero( - (CAN_MOVE[:, from_, :] | CAN_MOVE[:, :, from_]).any(axis=0), - size=36, - fill_value=-1, - )[0] - - -@jax.jit -@jax.vmap -def neighbour_ix(from_): - return jnp.nonzero( - (CAN_MOVE[7, from_, :] | CAN_MOVE[2, :, from_]), - size=10, - fill_value=-1, - )[0] - - -NEIGHBOUR_IX = neighbour_ix(jnp.arange(81)) - - -def between_ix(p, from_, to): - return jnp.nonzero(BETWEEN[p, from_, to], size=8, fill_value=-1)[0] - - -BETWEEN_IX = jax.jit( - jax.vmap( - jax.vmap(jax.vmap(between_ix, (None, None, 0)), (None, 0, None)), - (0, None, None), - ) -)(jnp.arange(5), jnp.arange(81), jnp.arange(81)) - - -CAN_MOVE_ANY = can_move_any_ix(jnp.arange(81)) # (81, 36) - -INIT_LEGAL_ACTION_MASK = jnp.zeros(81 * 27, dtype=jnp.bool_) -# fmt: off -ixs = [5, 7, 14, 23, 25, 32, 34, 41, 43, 50, 52, 59, 61, 68, 77, 79, 115, 124, 133, 142, 187, 196, 205, 214, 268, 277, 286, 295, 304, 331] -# fmt: on -for ix in ixs: - INIT_LEGAL_ACTION_MASK = INIT_LEGAL_ACTION_MASK.at[ix].set(True) -assert INIT_LEGAL_ACTION_MASK.shape == (81 * 27,) -assert INIT_LEGAL_ACTION_MASK.sum() == 30 - - -def _around(c): - x, y = c // 9, c % 9 - dx = jnp.int32([-1, -1, 0, +1, +1, +1, 0, -1]) - dy = jnp.int32([0, -1, -1, -1, 0, +1, +1, +1]) - - def f(i): - new_x, new_y = x + dx[i], y + dy[i] - return jax.lax.select( - (new_x < 0) | (new_x >= 9) | (new_y < 0) | (new_y >= 9), - -1, - new_x * 9 + new_y, - ) - - return jax.vmap(f)(jnp.arange(8)) - - -AROUND_IX = jax.vmap(_around)(jnp.arange(81)) - def _to_sfen(state): """Convert state into sfen expression. diff --git a/pgx/shogi.py b/pgx/shogi.py index 57c29fdad..64cbec074 100644 --- a/pgx/shogi.py +++ b/pgx/shogi.py @@ -18,7 +18,6 @@ import pgx.core as core from pgx._src.shogi_utils import ( - INIT_LEGAL_ACTION_MASK, _from_sfen, _to_sfen, ) @@ -30,6 +29,16 @@ TRUE = jnp.bool_(True) FALSE = jnp.bool_(False) +INIT_LEGAL_ACTION_MASK = jnp.zeros(81 * 27, dtype=jnp.bool_) +# fmt: off +ixs = [5, 7, 14, 23, 25, 32, 34, 41, 43, 50, 52, 59, 61, 68, 77, 79, 115, 124, 133, 142, 187, 196, 205, 214, 268, 277, 286, 295, 304, 331] +# fmt: on +for ix in ixs: + INIT_LEGAL_ACTION_MASK = INIT_LEGAL_ACTION_MASK.at[ix].set(True) +assert INIT_LEGAL_ACTION_MASK.shape == (81 * 27,) +assert INIT_LEGAL_ACTION_MASK.sum() == 30 + + @dataclass class State(core.State):