diff --git a/pgx/_src/assets/between.npy b/pgx/_src/assets/between.npy deleted file mode 100644 index a019a9860..000000000 Binary files a/pgx/_src/assets/between.npy and /dev/null differ diff --git a/pgx/_src/assets/can_move.npy b/pgx/_src/assets/can_move.npy deleted file mode 100644 index 8d86d62ef..000000000 Binary files a/pgx/_src/assets/can_move.npy and /dev/null differ diff --git a/pgx/_src/shogi_utils.py b/pgx/_src/shogi_utils.py index f4d95603f..3247548e7 100644 --- a/pgx/_src/shogi_utils.py +++ b/pgx/_src/shogi_utils.py @@ -14,6 +14,7 @@ import os +import numpy as np import jax import jax.numpy as jnp import numpy as np @@ -30,20 +31,144 @@ [15, -1, 14, -1, -1, -1, 0, -1, 1]]).flatten() # noqa: E241 # fmt: on +EMPTY = -1 # 空白 +PAWN = 0 # 歩 +LANCE = 1 # 香 +KNIGHT = 2 # 桂 +SILVER = 3 # 銀 +BISHOP = 4 # 角 +ROOK = 5 # 飛 +GOLD = 6 # 金 +KING = 7 # 玉 +PRO_PAWN = 8 # と +PRO_LANCE = 9 # 成香 +PRO_KNIGHT = 10 # 成桂 +PRO_SILVER = 11 # 成銀 +HORSE = 12 # 馬 +DRAGON = 13 # 龍 + + # Can reach from to ignoring pieces on board? -file_path = "assets/can_move.npy" -with open(os.path.join(os.path.dirname(__file__), file_path), "rb") as f: - CAN_MOVE = jnp.load(f) +def can_move_to(piece, from_, to): + """Can move from to ?""" + if from_ == to: + return False + x0, y0 = from_ // 9, from_ % 9 + x1, y1 = to // 9, to % 9 + dx = x1 - x0 + dy = y1 - y0 + if piece == PAWN: + if dx == 0 and dy == -1: + return True + else: + return False + elif piece == LANCE: + if dx == 0 and dy < 0: + return True + else: + return False + elif piece == KNIGHT: + if dx in (-1, 1) and dy == -2: + return True + else: + return False + elif piece == SILVER: + if dx in (-1, 0, 1) and dy == -1: + return True + elif dx in (-1, 1) and dy == 1: + return True + else: + return False + elif piece == BISHOP: + if dx == dy or dx == -dy: + return True + else: + return False + elif piece == ROOK: + if dx == 0 or dy == 0: + return True + else: + return False + if piece in (GOLD, PRO_PAWN, PRO_LANCE, PRO_KNIGHT, PRO_SILVER): + if dx in (-1, 0, 1) and dy in (0, -1): + return True + elif dx == 0 and dy == 1: + return True + else: + return False + elif piece == KING: + if abs(dx) <= 1 and abs(dy) <= 1: + return True + else: + return False + elif piece == HORSE: + if abs(dx) <= 1 and abs(dy) <= 1: + return True + elif dx == dy or dx == -dy: + return True + else: + return False + elif piece == DRAGON: + if abs(dx) <= 1 and abs(dy) <= 1: + return True + if dx == 0 or dy == 0: + return True + else: + return False + else: + assert False + + +def is_on_the_way(piece, from_, to, point): + if to == point: + return False + if piece not in (LANCE, BISHOP, ROOK, HORSE, DRAGON): + return False + if not can_move_to(piece, from_, to): + return False + if not can_move_to(piece, from_, point): + return False + + x0, y0 = from_ // 9, from_ % 9 + x1, y1 = to // 9, to % 9 + x2, y2 = point // 9, point % 9 + dx1, dy1 = x1 - x0, y1 - y0 + dx2, dy2 = x2 - x0, y2 - y0 + + def sign(d): + if d == 0: + return 0 + return d > 0 + + if (sign(dx1) != sign(dx2)) or (sign(dy1) != sign(dy2)): + return False + + return abs(dx2) <= abs(dx1) and abs(dy2) <= abs(dy1) + + +CAN_MOVE = np.zeros((14, 81, 81), dtype=jnp.bool_) +for piece in range(14): + for from_ in range(81): + for to in range(81): + CAN_MOVE[piece, from_, to] = can_move_to(piece, from_, to) + assert CAN_MOVE.sum() == 8228 +CAN_MOVE = jnp.array(CAN_MOVE) # When moves from to , # is on the way between two points? -file_path = "assets/between.npy" -with open(os.path.join(os.path.dirname(__file__), file_path), "rb") as f: - BETWEEN = jnp.load(f) +BETWEEN = np.zeros((5, 81, 81, 81), dtype=np.bool_) +for i, piece in enumerate((LANCE, BISHOP, ROOK, HORSE, DRAGON)): + for from_ in range(81): + for to in range(81): + for p in range(81): + BETWEEN[i, from_, to, p] = is_on_the_way(piece, from_, to, p) + +BETWEEN = jnp.array(BETWEEN) assert BETWEEN.sum() == 10564 + # Give and , return the legal idx # E.g. LEGAL_FROM_IDX[Up, to=19] = [20, 21, ..., -1] (filled by -1) # Used for computing dlshogi action diff --git a/setup.py b/setup.py index 60bc1947f..341d0a1bf 100644 --- a/setup.py +++ b/setup.py @@ -21,7 +21,7 @@ def _read_requirements(fname): keywords="", packages=find_packages(), package_data={ - "": ["LICENSE", "*.svg", "_src/assets/*.npy"] + "": ["LICENSE", "*.svg"] }, include_package_data=True, install_requires=_read_requirements("requirements.txt"),