diff --git a/src/ReinforcementLearningEnvironments/src/environments/examples/TicTacToeEnv.jl b/src/ReinforcementLearningEnvironments/src/environments/examples/TicTacToeEnv.jl index ff4c89b4d..027502793 100644 --- a/src/ReinforcementLearningEnvironments/src/environments/examples/TicTacToeEnv.jl +++ b/src/ReinforcementLearningEnvironments/src/environments/examples/TicTacToeEnv.jl @@ -72,6 +72,7 @@ RLBase.players(::TicTacToeEnv) = (Player(:Cross), Player(:Nought)) RLBase.state(env::TicTacToeEnv, ::Observation, ::DefaultPlayer) = state(env, Observation{Int}(), Player(:Any)) RLBase.state(env::TicTacToeEnv, ::Observation{BitArray{3}}, player) = env.board +RLBase.state(env::TicTacToeEnv, ::Observation{BitArray{3}}) = env.board RLBase.state(env::TicTacToeEnv, ::RLBase.AbstractStateStyle) = state(env::TicTacToeEnv, Observation{Int}(), Player(1)) RLBase.state(env::TicTacToeEnv, ::Observation{Int}, player::Player) = get_tic_tac_toe_state_info()[env].index diff --git a/src/ReinforcementLearningEnvironments/test/environments/examples/tic_tac_toe.jl b/src/ReinforcementLearningEnvironments/test/environments/examples/tic_tac_toe.jl index 0eca516ff..f5b15f289 100644 --- a/src/ReinforcementLearningEnvironments/test/environments/examples/tic_tac_toe.jl +++ b/src/ReinforcementLearningEnvironments/test/environments/examples/tic_tac_toe.jl @@ -3,15 +3,15 @@ using ReinforcementLearningEnvironments, ReinforcementLearningBase, ReinforcementLearningCore trajectory_1 = Trajectory( - CircularArraySARTSTraces(; capacity = 1), + CircularArraySARTSTraces(; capacity=1), BatchSampler(1), - InsertSampleRatioController(n_inserted = -1), + InsertSampleRatioController(n_inserted=-1), ) trajectory_2 = Trajectory( - CircularArraySARTSTraces(; capacity = 1), + CircularArraySARTSTraces(; capacity=1), BatchSampler(1), - InsertSampleRatioController(n_inserted = -1), + InsertSampleRatioController(n_inserted=-1), ) multiagent_policy = MultiAgentPolicy(PlayerTuple( @@ -30,6 +30,7 @@ @test length(state_space(env, Observation{Int}())) == 5478 @test RLBase.state(env, Observation{BitArray{3}}(), Player(:Cross)) == env.board + @test RLBase.state(env, Observation{BitArray{3}}()) == env.board @test RLBase.state_space(env, Observation{BitArray{3}}(), Player(:Cross)) isa ArrayProductDomain @test RLBase.state_space(env, Observation{String}(), Player(:Cross)) isa DomainSets.FullSpace{String} @test RLBase.state(env, Observation{String}(), Player(:Cross)) isa String