From 43a13a4062a3cb4ea273d6491caaabda588f4f87 Mon Sep 17 00:00:00 2001 From: Joao Hespanha Date: Mon, 2 Sep 2024 17:19:58 -0700 Subject: [PATCH] Fix and unit test for issue "Wrong style for state report for TicTacToeEnv() #1079" --- .../src/environments/examples/TicTacToeEnv.jl | 1 + .../test/environments/examples/tic_tac_toe.jl | 9 +++++---- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/ReinforcementLearningEnvironments/src/environments/examples/TicTacToeEnv.jl b/src/ReinforcementLearningEnvironments/src/environments/examples/TicTacToeEnv.jl index ff4c89b4d..027502793 100644 --- a/src/ReinforcementLearningEnvironments/src/environments/examples/TicTacToeEnv.jl +++ b/src/ReinforcementLearningEnvironments/src/environments/examples/TicTacToeEnv.jl @@ -72,6 +72,7 @@ RLBase.players(::TicTacToeEnv) = (Player(:Cross), Player(:Nought)) RLBase.state(env::TicTacToeEnv, ::Observation, ::DefaultPlayer) = state(env, Observation{Int}(), Player(:Any)) RLBase.state(env::TicTacToeEnv, ::Observation{BitArray{3}}, player) = env.board +RLBase.state(env::TicTacToeEnv, ::Observation{BitArray{3}}) = env.board RLBase.state(env::TicTacToeEnv, ::RLBase.AbstractStateStyle) = state(env::TicTacToeEnv, Observation{Int}(), Player(1)) RLBase.state(env::TicTacToeEnv, ::Observation{Int}, player::Player) = get_tic_tac_toe_state_info()[env].index diff --git a/src/ReinforcementLearningEnvironments/test/environments/examples/tic_tac_toe.jl b/src/ReinforcementLearningEnvironments/test/environments/examples/tic_tac_toe.jl index 0eca516ff..f5b15f289 100644 --- a/src/ReinforcementLearningEnvironments/test/environments/examples/tic_tac_toe.jl +++ b/src/ReinforcementLearningEnvironments/test/environments/examples/tic_tac_toe.jl @@ -3,15 +3,15 @@ using ReinforcementLearningEnvironments, ReinforcementLearningBase, ReinforcementLearningCore trajectory_1 = Trajectory( - CircularArraySARTSTraces(; capacity = 1), + CircularArraySARTSTraces(; capacity=1), BatchSampler(1), - InsertSampleRatioController(n_inserted = -1), + InsertSampleRatioController(n_inserted=-1), ) trajectory_2 = Trajectory( - CircularArraySARTSTraces(; capacity = 1), + CircularArraySARTSTraces(; capacity=1), BatchSampler(1), - InsertSampleRatioController(n_inserted = -1), + InsertSampleRatioController(n_inserted=-1), ) multiagent_policy = MultiAgentPolicy(PlayerTuple( @@ -30,6 +30,7 @@ @test length(state_space(env, Observation{Int}())) == 5478 @test RLBase.state(env, Observation{BitArray{3}}(), Player(:Cross)) == env.board + @test RLBase.state(env, Observation{BitArray{3}}()) == env.board @test RLBase.state_space(env, Observation{BitArray{3}}(), Player(:Cross)) isa ArrayProductDomain @test RLBase.state_space(env, Observation{String}(), Player(:Cross)) isa DomainSets.FullSpace{String} @test RLBase.state(env, Observation{String}(), Player(:Cross)) isa String