@@ -218,7 +218,7 @@ def _step(self, state: TensorDict) -> TensorDict:
218
218
turn = state ["turn" ].clone ()
219
219
action = state ["action" ]
220
220
board .flatten (- 2 , - 1 ).scatter_ (index = action .unsqueeze (- 1 ), dim = - 1 , value = 1 )
221
- wins = self .win (state [ " board" ] , action )
221
+ wins = self .win (board , action )
222
222
223
223
mask = board .flatten (- 2 , - 1 ) == - 1
224
224
done = wins | ~ mask .any (- 1 , keepdim = True )
@@ -234,7 +234,7 @@ def _step(self, state: TensorDict) -> TensorDict:
234
234
("player0" , "reward" ): reward_0 .float (),
235
235
("player1" , "reward" ): reward_1 .float (),
236
236
"board" : torch .where (board == - 1 , board , 1 - board ),
237
- "turn" : 1 - state [ " turn" ] ,
237
+ "turn" : 1 - turn ,
238
238
"mask" : mask ,
239
239
},
240
240
batch_size = state .batch_size ,
@@ -260,13 +260,15 @@ def _set_seed(self, seed: int | None):
260
260
def win(board: torch.Tensor, action: torch.Tensor):
    """Return whether the move just played at ``action`` completed a line of three.

    Args:
        board: tensor of shape ``(..., 3, 3)`` where the current player's
            marks are ``1`` and empty cells are ``-1`` (as established by the
            caller's ``scatter_(..., value=1)`` and ``board == -1`` mask).
            # assumes trailing dims are (row, col) — TODO confirm against env spec
        action: flat cell index in ``[0, 9)`` identifying the move just made.

    Returns:
        A boolean ``torch.Tensor`` (not a Python ``bool``) so the caller can
        combine it element-wise: ``wins | ~mask.any(-1, keepdim=True)``.
    """
    row = action // 3  # type: ignore
    col = action % 3  # type: ignore
    # Parenthesize each comparison explicitly: `==` binds looser than `|`,
    # so `a.sum() == 3 | b.sum() == 3` would evaluate as `== (3 | ...)`.
    # Summing the last axis keeps batch dimensions intact, and `if tensor == 3`
    # would raise on batched boards anyway.
    return (
        (board[..., row, :].sum(-1) == 3)      # row through the played cell
        | (board[..., col].sum(-1) == 3)       # column through the played cell
        | (board.diagonal(0, -2, -1).sum(-1) == 3)           # main diagonal
        | (board.flip(-1).diagonal(0, -2, -1).sum(-1) == 3)  # anti-diagonal
    )
270
272
271
273
@staticmethod
272
274
def full (board : torch .Tensor ) -> bool :
0 commit comments