Skip to content

Commit

Permalink
Merge pull request #1068 from MarkZH/multiple-move-comments
Browse files Browse the repository at this point in the history
Handle multiple comments on a move/variation/game
  • Loading branch information
niklasf authored Oct 4, 2024
2 parents 08697b2 + 2d5755e commit d4b3190
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 92 deletions.
124 changes: 65 additions & 59 deletions chess/pgn.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,10 @@ def repl(match: typing.Match[str]) -> str:
return repl


def _standardize_comments(comment: Union[str, list[str]]) -> list[str]:
return [] if not comment else [comment] if isinstance(comment, str) else comment


TAG_ROSTER = ["Event", "Site", "Date", "Round", "White", "Black", "Result"]


Expand Down Expand Up @@ -200,24 +204,25 @@ class GameNode(abc.ABC):
variations: List[ChildNode]
"""A list of child nodes."""

comment: str
comments: list[str]
"""
A comment that goes behind the move leading to this node. Comments
that occur before any moves are assigned to the root node.
"""

starting_comment: str
starting_comments: list[str]

nags: Set[int]

def __init__(self, *, comment: str = "") -> None:
def __init__(self, *, comment: Union[str, list[str]] = "") -> None:
self.parent = None
self.move = None
self.variations = []
self.comment = comment
self.comments = _standardize_comments(comment)

# Deprecated: These should be properties of ChildNode, but need to
# remain here for backwards compatibility.
self.starting_comment = ""
self.starting_comments = []
self.nags = set()

@abc.abstractmethod
Expand Down Expand Up @@ -389,7 +394,7 @@ def remove_variation(self, move: Union[int, chess.Move, GameNode]) -> None:
"""Removes a variation."""
self.variations.remove(self.variation(move))

def add_variation(self, move: chess.Move, *, comment: str = "", starting_comment: str = "", nags: Iterable[int] = []) -> ChildNode:
def add_variation(self, move: chess.Move, *, comment: Union[str, list[str]] = "", starting_comment: Union[str, list[str]] = "", nags: Iterable[int] = []) -> ChildNode:
"""Creates a child node with the given attributes."""
# Instanciate ChildNode only in this method.
return ChildNode(self, move, comment=comment, starting_comment=starting_comment, nags=nags)
Expand Down Expand Up @@ -420,7 +425,7 @@ def mainline_moves(self) -> Mainline[chess.Move]:
"""Returns an iterable over the main moves after this node."""
return Mainline(self, lambda node: node.move)

def add_line(self, moves: Iterable[chess.Move], *, comment: str = "", starting_comment: str = "", nags: Iterable[int] = []) -> GameNode:
def add_line(self, moves: Iterable[chess.Move], *, comment: Union[str, list[str]] = "", starting_comment: Union[str, list[str]] = "", nags: Iterable[int] = []) -> GameNode:
"""
Creates a sequence of child nodes for the given list of moves.
Adds *comment* and *nags* to the last node of the line and returns it.
Expand All @@ -433,11 +438,8 @@ def add_line(self, moves: Iterable[chess.Move], *, comment: str = "", starting_c
starting_comment = ""

# Merge comment and NAGs.
if node.comment:
node.comment += " " + comment
else:
node.comment = comment

comments = _standardize_comments(comment)
node.comments.extend(comments)
node.nags.update(nags)

return node
Expand All @@ -449,7 +451,7 @@ def eval(self) -> Optional[chess.engine.PovScore]:
Complexity is `O(n)`.
"""
match = EVAL_REGEX.search(self.comment)
match = EVAL_REGEX.search(" ".join(self.comments))
if not match:
return None

Expand All @@ -475,7 +477,7 @@ def eval_depth(self) -> Optional[int]:
Complexity is `O(1)`.
"""
match = EVAL_REGEX.search(self.comment)
match = EVAL_REGEX.search(" ".join(self.comments))
return int(match.group("depth")) if match and match.group("depth") else None

def set_eval(self, score: Optional[chess.engine.PovScore], depth: Optional[int] = None) -> None:
Expand All @@ -492,12 +494,7 @@ def set_eval(self, score: Optional[chess.engine.PovScore], depth: Optional[int]
elif score.white().mate():
eval = f"[%eval #{score.white().mate()}{depth_suffix}]"

self.comment, found = EVAL_REGEX.subn(_condense_affix(eval), self.comment, count=1)

if not found and eval:
if self.comment and not self.comment.endswith(" "):
self.comment += " "
self.comment += eval
self._replace_or_add_annotation(eval, EVAL_REGEX)

def arrows(self) -> List[chess.svg.Arrow]:
"""
Expand All @@ -507,7 +504,7 @@ def arrows(self) -> List[chess.svg.Arrow]:
Returns a list of :class:`arrows <chess.svg.Arrow>`.
"""
arrows = []
for match in ARROWS_REGEX.finditer(self.comment):
for match in ARROWS_REGEX.finditer(" ".join(self.comments)):
for group in match.group("arrows").split(","):
arrows.append(chess.svg.Arrow.from_pgn(group))

Expand All @@ -529,18 +526,19 @@ def set_arrows(self, arrows: Iterable[Union[chess.svg.Arrow, Tuple[Square, Squar
pass
(csl if arrow.tail == arrow.head else cal).append(arrow.pgn()) # type: ignore

self.comment = ARROWS_REGEX.sub(_condense_affix(""), self.comment)
for index in range(len(self.comments)):
self.comments[index] = ARROWS_REGEX.sub(_condense_affix(""), self.comments[index])

self.comments = list(filter(None, self.comments))

prefix = ""
if csl:
prefix += f"[%csl {','.join(csl)}]"
if cal:
prefix += f"[%cal {','.join(cal)}]"

if prefix and self.comment and not self.comment.startswith(" ") and not self.comment.startswith("\n"):
self.comment = prefix + " " + self.comment
else:
self.comment = prefix + self.comment
if prefix:
self.comments.insert(0, prefix)

def clock(self) -> Optional[float]:
"""
Expand All @@ -550,7 +548,7 @@ def clock(self) -> Optional[float]:
Returns the player's remaining time to the next time control after this
move, in seconds.
"""
match = CLOCK_REGEX.search(self.comment)
match = CLOCK_REGEX.search(" ".join(self.comments))
if match is None:
return None
return int(match.group("hours")) * 3600 + int(match.group("minutes")) * 60 + float(match.group("seconds"))
Expand All @@ -569,12 +567,7 @@ def set_clock(self, seconds: Optional[float]) -> None:
seconds_part = f"{seconds:06.3f}".rstrip("0").rstrip(".")
clk = f"[%clk {hours:d}:{minutes:02d}:{seconds_part}]"

self.comment, found = CLOCK_REGEX.subn(_condense_affix(clk), self.comment, count=1)

if not found and clk:
if self.comment and not self.comment.endswith(" ") and not self.comment.endswith("\n"):
self.comment += " "
self.comment += clk
self._replace_or_add_annotation(clk, CLOCK_REGEX)

def emt(self) -> Optional[float]:
"""
Expand All @@ -584,7 +577,7 @@ def emt(self) -> Optional[float]:
Returns the player's elapsed move time use for the comment of this
move, in seconds.
"""
match = EMT_REGEX.search(self.comment)
match = EMT_REGEX.search(" ".join(self.comments))
if match is None:
return None
return int(match.group("hours")) * 3600 + int(match.group("minutes")) * 60 + float(match.group("seconds"))
Expand All @@ -603,12 +596,19 @@ def set_emt(self, seconds: Optional[float]) -> None:
seconds_part = f"{seconds:06.3f}".rstrip("0").rstrip(".")
emt = f"[%emt {hours:d}:{minutes:02d}:{seconds_part}]"

self.comment, found = EMT_REGEX.subn(_condense_affix(emt), self.comment, count=1)
self._replace_or_add_annotation(emt, EMT_REGEX)

def _replace_or_add_annotation(self, text: str, regex: re.Pattern[str]) -> None:
found = 0
for index in range(len(self.comments)):
self.comments[index], found = regex.subn(_condense_affix(text), self.comments[index], count=1)
if found:
break

self.comments = list(filter(None, self.comments))

if not found and emt:
if self.comment and not self.comment.endswith(" ") and not self.comment.endswith("\n"):
self.comment += " "
self.comment += emt
if not found and text:
self.comments.append(text)

@abc.abstractmethod
def accept(self, visitor: BaseVisitor[ResultT]) -> ResultT:
Expand Down Expand Up @@ -664,7 +664,7 @@ class ChildNode(GameNode):
move: chess.Move
"""The move leading to this node."""

starting_comment: str
starting_comments: list[str]
"""
A comment for the start of a variation. Only nodes that
actually start a variation (:func:`~chess.pgn.GameNode.starts_variation()`
Expand All @@ -678,14 +678,14 @@ class ChildNode(GameNode):
node of the game will never have NAGs.
"""

def __init__(self, parent: GameNode, move: chess.Move, *, comment: str = "", starting_comment: str = "", nags: Iterable[int] = []) -> None:
def __init__(self, parent: GameNode, move: chess.Move, *, comment: Union[str, list[str]] = "", starting_comment: Union[str, list[str]] = "", nags: Iterable[int] = []) -> None:
super().__init__(comment=comment)
self.parent = parent
self.move = move
self.parent.variations.append(self)

self.nags.update(nags)
self.starting_comment = starting_comment
self.starting_comments = _standardize_comments(starting_comment)

def board(self) -> chess.Board:
stack: List[chess.Move] = []
Expand Down Expand Up @@ -741,8 +741,8 @@ def end(self) -> ChildNode:
return typing.cast(ChildNode, super().end())

def _accept_node(self, parent_board: chess.Board, visitor: BaseVisitor[ResultT]) -> None:
if self.starting_comment:
visitor.visit_comment(self.starting_comment)
if self.starting_comments:
visitor.visit_comment(self.starting_comments)

visitor.visit_move(parent_board, self.move)

Expand All @@ -753,8 +753,8 @@ def _accept_node(self, parent_board: chess.Board, visitor: BaseVisitor[ResultT])
for nag in sorted(self.nags):
visitor.visit_nag(nag)

if self.comment:
visitor.visit_comment(self.comment)
if self.comments:
visitor.visit_comment(self.comments)

def _accept(self, parent_board: chess.Board, visitor: BaseVisitor[ResultT], *, sidelines: bool = True) -> None:
stack = [_AcceptFrame(self, sidelines=sidelines)]
Expand Down Expand Up @@ -885,8 +885,8 @@ def accept(self, visitor: BaseVisitor[ResultT]) -> ResultT:
board = self.board()
visitor.visit_board(board)

if self.comment:
visitor.visit_comment(self.comment)
if self.comments:
visitor.visit_comment(self.comments)

if self.variations:
self.variations[0]._accept(board, visitor)
Expand Down Expand Up @@ -1137,7 +1137,7 @@ def visit_board(self, board: chess.Board) -> None:
"""
pass

def visit_comment(self, comment: str) -> None:
def visit_comment(self, comment: list[str]) -> None:
"""Called for each comment."""
pass

Expand Down Expand Up @@ -1191,7 +1191,7 @@ def begin_game(self) -> None:
self.game: GameT = self.Game()

self.variation_stack: List[GameNode] = [self.game]
self.starting_comment = ""
self.starting_comments: list[str] = []
self.in_variation = False

def begin_headers(self) -> Headers:
Expand All @@ -1216,22 +1216,23 @@ def visit_result(self, result: str) -> None:
if self.game.headers.get("Result", "*") == "*":
self.game.headers["Result"] = result

def visit_comment(self, comment: str) -> None:
def visit_comment(self, comment: Union[str, list[str]]) -> None:
comments = _standardize_comments(comment)
if self.in_variation or (self.variation_stack[-1].parent is None and self.variation_stack[-1].is_end()):
# Add as a comment for the current node if in the middle of
# a variation. Add as a comment for the game if the comment
# starts before any move.
new_comment = [self.variation_stack[-1].comment, comment]
self.variation_stack[-1].comment = " ".join(filter(None, new_comment))
self.variation_stack[-1].comments.extend(comments)
self.variation_stack[-1].comments = list(filter(None, self.variation_stack[-1].comments))
else:
# Otherwise, it is a starting comment.
new_comment = [self.starting_comment, comment]
self.starting_comment = " ".join(filter(None, new_comment))
self.starting_comments.extend(comments)
self.starting_comments = list(filter(None, self.starting_comments))

def visit_move(self, board: chess.Board, move: chess.Move) -> None:
self.variation_stack[-1] = self.variation_stack[-1].add_variation(move)
self.variation_stack[-1].starting_comment = self.starting_comment
self.starting_comment = ""
self.variation_stack[-1].starting_comments = self.starting_comments
self.starting_comments = []
self.in_variation = True

def handle_error(self, error: Exception) -> None:
Expand Down Expand Up @@ -1399,9 +1400,14 @@ def end_variation(self) -> None:
self.write_token(") ")
self.force_movenumber = True

def visit_comment(self, comment: str) -> None:
def visit_comment(self, comment: Union[str, list[str]]) -> None:
if self.comments and (self.variations or not self.variation_depth):
self.write_token("{ " + comment.replace("}", "").strip() + " } ")
def pgn_format(comments: list[str]) -> str:
edit = map(lambda s: s.replace("{", "").replace("}", ""), comments)
return " ".join(f"{{ {comment} }}" for comment in edit if comment)

comments = _standardize_comments(comment)
self.write_token(pgn_format(comments) + " ")
self.force_movenumber = True

def visit_nag(self, nag: int) -> None:
Expand Down
Loading

0 comments on commit d4b3190

Please sign in to comment.