Skip to content

Commit

Permalink
Fix and unit test for issue "Wrong style for state report for TicTacT…
Browse files Browse the repository at this point in the history
  • Loading branch information
hespanha committed Sep 3, 2024
1 parent b49bf69 commit 43a13a4
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ RLBase.players(::TicTacToeEnv) = (Player(:Cross), Player(:Nought))

# Default-player query: whichever `Observation` flavor is requested, delegate to the
# three-argument `state` method with `Observation{Int}()` and `Player(:Any)`.
function RLBase.state(env::TicTacToeEnv, ::Observation, ::DefaultPlayer)
    return state(env, Observation{Int}(), Player(:Any))
end
# `Observation{BitArray{3}}` query with an explicit player: the `player` argument is not
# consulted; the environment's `board` field is returned as-is.
function RLBase.state(env::TicTacToeEnv, ::Observation{BitArray{3}}, player)
    return env.board
end
# `Observation{BitArray{3}}` query without a player argument: returns the environment's
# `board` field, the same value as the three-argument method for this observation style.
function RLBase.state(env::TicTacToeEnv, ::Observation{BitArray{3}})
    return env.board
end
# Fallback for any other state style given without a player: delegate to the
# three-argument `state` method with `Observation{Int}()` and `Player(1)`.
# `env` is already known to be a `TicTacToeEnv` from dispatch on this method's
# signature, so it is forwarded without re-asserting its type at the call site.
function RLBase.state(env::TicTacToeEnv, ::RLBase.AbstractStateStyle)
    return state(env, Observation{Int}(), Player(1))
end
# `Observation{Int}` query: look `env` up in the table returned by
# `get_tic_tac_toe_state_info()` and return that entry's `index` field.
# The `player` argument takes part in dispatch only; its value is not read.
function RLBase.state(env::TicTacToeEnv, ::Observation{Int}, player::Player)
    info = get_tic_tac_toe_state_info()[env]
    return info.index
end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
using ReinforcementLearningEnvironments, ReinforcementLearningBase, ReinforcementLearningCore

# First trajectory fixture: one trace container, one sampler, one controller.
# Each component is passed exactly once (the container and controller were previously
# listed twice, differing only in whitespace around `=`).
# NOTE(review): the keyword values (capacity=1, n_inserted=-1) are kept as found; their
# meaning is defined by the trajectory package — confirm there.
trajectory_1 = Trajectory(
    CircularArraySARTSTraces(; capacity=1),
    BatchSampler(1),
    InsertSampleRatioController(n_inserted=-1),
)

# Second trajectory fixture, built with the same three components and keyword values
# as `trajectory_1`.
# Each component is passed exactly once (the container and controller were previously
# listed twice, differing only in whitespace around `=`).
trajectory_2 = Trajectory(
    CircularArraySARTSTraces(; capacity=1),
    BatchSampler(1),
    InsertSampleRatioController(n_inserted=-1),
)

multiagent_policy = MultiAgentPolicy(PlayerTuple(
Expand All @@ -30,6 +30,7 @@
# The `Observation{Int}` state space of `env` is expected to have 5478 elements.
@test length(state_space(env, Observation{Int}())) == 5478

# The `Observation{BitArray{3}}` state compares equal to `env.board` when a player is given ...
@test RLBase.state(env, Observation{BitArray{3}}(), Player(:Cross)) == env.board
# ... and also when the player argument is omitted (two-argument form).
@test RLBase.state(env, Observation{BitArray{3}}()) == env.board
# Types of the state spaces reported for the BitArray and String observation styles.
@test RLBase.state_space(env, Observation{BitArray{3}}(), Player(:Cross)) isa ArrayProductDomain
@test RLBase.state_space(env, Observation{String}(), Player(:Cross)) isa DomainSets.FullSpace{String}
# The `Observation{String}` state is a `String` value.
@test RLBase.state(env, Observation{String}(), Player(:Cross)) isa String
Expand Down

0 comments on commit 43a13a4

Please sign in to comment.