Skip to content

Commit 6c66796

Browse files
author
Vincent Moens
committed
[BugFix] Fix tictactoeenv.py
ghstack-source-id: 99a368c Pull Request resolved: #2417
1 parent 60cd104 commit 6c66796

File tree

1 file changed

+11
-9
lines changed

1 file changed

+11
-9
lines changed

torchrl/envs/custom/tictactoeenv.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ def _step(self, state: TensorDict) -> TensorDict:
218218
turn = state["turn"].clone()
219219
action = state["action"]
220220
board.flatten(-2, -1).scatter_(index=action.unsqueeze(-1), dim=-1, value=1)
221-
wins = self.win(state["board"], action)
221+
wins = self.win(board, action)
222222

223223
mask = board.flatten(-2, -1) == -1
224224
done = wins | ~mask.any(-1, keepdim=True)
@@ -234,7 +234,7 @@ def _step(self, state: TensorDict) -> TensorDict:
234234
("player0", "reward"): reward_0.float(),
235235
("player1", "reward"): reward_1.float(),
236236
"board": torch.where(board == -1, board, 1 - board),
237-
"turn": 1 - state["turn"],
237+
"turn": 1 - turn,
238238
"mask": mask,
239239
},
240240
batch_size=state.batch_size,
@@ -260,13 +260,15 @@ def _set_seed(self, seed: int | None):
260260
def win(board: torch.Tensor, action: torch.Tensor):
261261
row = action // 3 # type: ignore
262262
col = action % 3 # type: ignore
263-
return (
264-
board[..., row, :].sum()
265-
== 3 | board[..., col].sum()
266-
== 3 | board.diagonal(0, -2, -1).sum()
267-
== 3 | board.flip(-1).diagonal(0, -2, -1).sum()
268-
== 3
269-
)
263+
if board[..., row, :].sum() == 3:
264+
return True
265+
if board[..., col].sum() == 3:
266+
return True
267+
if board.diagonal(0, -2, -1).sum() == 3:
268+
return True
269+
if board.flip(-1).diagonal(0, -2, -1).sum() == 3:
270+
return True
271+
return False
270272

271273
@staticmethod
272274
def full(board: torch.Tensor) -> bool:

0 commit comments

Comments
 (0)