This repository was archived by the owner on May 6, 2021. It is now read-only.
Merged
10 changes: 9 additions & 1 deletion Project.toml
@@ -38,7 +38,7 @@ Flux = "0.11"
IntervalSets = "0.5"
MacroTools = "0.5"
ReinforcementLearningBase = "0.9"
ReinforcementLearningCore = "0.6.1"
ReinforcementLearningCore = "0.6.3"
Requires = "1"
Setfield = "0.6, 0.7"
StableRNGs = "1.0"
@@ -47,3 +47,11 @@ StructArrays = "0.4"
TensorBoardLogger = "0.1"
Zygote = "0.5"
julia = "1.4"

[extras]
OpenSpiel = "ceb70bd2-fe3f-44f0-b81f-41608acaf2f2"
ReinforcementLearningEnvironments = "25e41dd2-4622-11e9-1641-f1adca772921"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "ReinforcementLearningEnvironments", "OpenSpiel"]
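
The `[extras]`/`[targets]` pair above declares `Test`, `ReinforcementLearningEnvironments`, and `OpenSpiel` as test-only dependencies: they are resolved only when the test suite runs, not for ordinary users of the package. A minimal way to exercise them, assuming the package is installed in the active environment:

```julia
# Resolves the [extras] listed under the "test" target and runs test/runtests.jl.
using Pkg
Pkg.test("ReinforcementLearningZoo")
```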
2 changes: 1 addition & 1 deletion src/ReinforcementLearningZoo.jl
@@ -37,7 +37,7 @@ function __init__()
@require ReinforcementLearningEnvironments = "25e41dd2-4622-11e9-1641-f1adca772921" begin
include("experiments/rl_envs/rl_envs.jl")
@require ArcadeLearningEnvironment = "b7f77d8d-088d-5e02-8ac0-89aab2acc977" include("experiments/atari/atari.jl")
# @require OpenSpiel = "ceb70bd2-fe3f-44f0-b81f-41608acaf2f2" include("experiments/open_spiel/open_spiel.jl")
@require OpenSpiel = "ceb70bd2-fe3f-44f0-b81f-41608acaf2f2" include("experiments/open_spiel/open_spiel.jl")
end
end

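
For context, re-enabling the `@require OpenSpiel ...` line means the OpenSpiel experiments are loaded lazily through Requires.jl: nothing is included until both trigger packages are present in the session. A sketch of the intended effect (the files under `experiments/open_spiel/` are assumed to define the OpenSpiel experiments):

```julia
using ReinforcementLearningZoo           # __init__ registers the @require hooks
using ReinforcementLearningEnvironments  # outer trigger: includes experiments/rl_envs/rl_envs.jl
using OpenSpiel                          # inner trigger: includes experiments/open_spiel/open_spiel.jl
```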
2 changes: 1 addition & 1 deletion src/algorithms/cfr/abstract_cfr_policy.jl
@@ -10,7 +10,7 @@ function Base.run(
@assert DynamicStyle(env) === SEQUENTIAL
@assert RewardStyle(env) === TERMINAL_REWARD
@assert ChanceStyle(env) === EXPLICIT_STOCHASTIC
@assert DefaultStateStyle(env) isa Information
@assert DefaultStateStyle(env) isa InformationSet

RLBase.reset!(env)

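
The corrected assertion matters because `DefaultStateStyle(env)` returns a trait instance of a parametric singleton type (for example `InformationSet{String}()`), so checking `isa InformationSet` matches any parameterization. A self-contained sketch of the pattern, with the trait types redefined locally for illustration only:

```julia
# Stand-ins for the RLBase state-style traits; only the `isa` check is the point here.
abstract type AbstractStateStyle end
struct InformationSet{T} <: AbstractStateStyle end

style = InformationSet{String}()   # what DefaultStateStyle(env) might return
@assert style isa InformationSet   # true for any type parameter T
```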
22 changes: 12 additions & 10 deletions src/algorithms/cfr/best_response_policy.jl
@@ -19,13 +19,9 @@ function BestResponsePolicy(
policy,
env,
best_responder;
state_type = String,
action_type = Int,
)
# S = typeof(state(env)) # TODO: currently it will break the OpenSpielEnv. Can not get information set for chance player
# A = eltype(action_space(env)) # TODO: for chance players it will return ActionProbPair
S = state_type
A = action_type
S = eltype(state_space(env))
A = eltype(action_space(env))
E = typeof(env)

p = BestResponsePolicy(
@@ -61,8 +57,10 @@ function init_cfr_reach_prob!(p, env, reach_prob = 1.0)
init_cfr_reach_prob!(p, child(env, a), reach_prob)
end
elseif current_player(env) == chance_player(env)
for a::ActionProbPair in action_space(env)
init_cfr_reach_prob!(p, child(env, a), reach_prob * a.prob)
for (a, pₐ) in zip(action_space(env), prob(env))
if pₐ > 0
init_cfr_reach_prob!(p, child(env, a), reach_prob * pₐ)
end
end
else # opponents
for a in legal_action_space(env)
@@ -81,8 +79,12 @@ function best_response_value(p, env)
best_response_value(p, child(env, a))
elseif current_player(env) == chance_player(env)
v = 0.0
for a::ActionProbPair in action_space(env)
v += a.prob * best_response_value(p, child(env, a))
A, P = action_space(env), prob(env)
@assert length(A) == length(P)
for (a, pₐ) in zip(A, P)
if pₐ > 0
v += pₐ * best_response_value(p, child(env, a))
end
end
v
else
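
The new chance-node branches walk `action_space(env)` and `prob(env)` in lockstep instead of relying on `ActionProbPair` elements, and they skip zero-probability outcomes. A self-contained sketch of the expectation being computed, with illustrative stand-in names and values:

```julia
actions = [:heads, :tails]                  # stand-in for action_space(env) at a chance node
probs   = [0.5, 0.5]                        # stand-in for prob(env)
child_value(a) = a === :heads ? 1.0 : -1.0  # stand-in for best_response_value(p, child(env, a))

function chance_value(actions, probs, child_value)
    v = 0.0
    for (a, pₐ) in zip(actions, probs)
        pₐ > 0 || continue                  # zero-probability branches contribute nothing
        v += pₐ * child_value(a)
    end
    v
end

@assert chance_value(actions, probs, child_value) == 0.0   # 0.5 * 1.0 + 0.5 * (-1.0)
```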
2 changes: 1 addition & 1 deletion src/algorithms/cfr/deep_cfr.jl
@@ -53,7 +53,7 @@ end

"Run one interation"
function RLBase.update!(π::DeepCFR, env::AbstractEnv)
for p in get_players(env)
for p in players(env)
if p != chance_player(env)
for k in 1:π.K
external_sampling!(π, copy(env), p)
21 changes: 14 additions & 7 deletions src/algorithms/cfr/external_sampling_mccfr.jl
@@ -20,6 +20,9 @@ end

RLBase.prob(p::ExternalSamplingMCCFRPolicy, env::AbstractEnv) = prob(p.behavior_policy, env)

RLBase.prob(p::ExternalSamplingMCCFRPolicy, env::AbstractEnv, action) =
prob(p.behavior_policy, env, action)

function ExternalSamplingMCCFRPolicy(; state_type = String, rng = Random.GLOBAL_RNG)
ExternalSamplingMCCFRPolicy(
Dict{state_type,InfoStateNode}(),
@@ -36,7 +39,10 @@ function RLBase.update!(p::ExternalSamplingMCCFRPolicy)
for (k, v) in p.nodes
s = sum(v.cumulative_strategy)
if s != 0
update!(p.behavior_policy, k => v.cumulative_strategy ./ s)
m = v.mask
strategy = zeros(length(m))
strategy[m] .= v.cumulative_strategy ./ s
update!(p.behavior_policy, k => strategy)
else
# The TabularLearner will return a uniform distribution by default.
# So we do nothing here.
@@ -46,30 +52,31 @@ end

"Run one interation"
function RLBase.update!(p::ExternalSamplingMCCFRPolicy, env::AbstractEnv)
for x in get_players(env)
for x in players(env)
if x != chance_player(env)
external_sampling(copy(env), x, p.nodes, p.rng)
end
end
end

function external_sampling(env, i, nodes, rng)
current_player = current_player(env)
player = current_player(env)

if is_terminated(env)
reward(env, i)
elseif current_player == chance_player(env)
env(rand(rng, action_space(env)))
elseif player == chance_player(env)
env(sample(rng, action_space(env), Weights(prob(env), 1.0)))
external_sampling(env, i, nodes, rng)
else
I = state(env)
legal_actions = legal_action_space(env)
M = legal_action_space_mask(env)
n = length(legal_actions)
node = get!(nodes, I, InfoStateNode(n))
node = get!(nodes, I, InfoStateNode(M))
regret_matching!(node; is_reset_neg_regrets = false)
σ, rI, sI = node.strategy, node.cumulative_regret, node.cumulative_strategy

if i == current_player
if i == player
u = zeros(n)
uσ = 0
for (aᵢ, a) in enumerate(legal_actions)
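
The averaging step above now produces one probability per action in the full action space, using the legal-action mask stored in each `InfoStateNode` to scatter the per-legal-action averages back into place (illegal actions keep probability zero). A self-contained sketch with made-up numbers:

```julia
mask = Bool[1, 0, 1, 1]                # stand-in for the stored legal_action_space_mask
cumulative_strategy = [2.0, 1.0, 1.0]  # one accumulator per legal action
s = sum(cumulative_strategy)

strategy = zeros(length(mask))         # full-length vector; illegal actions stay at 0
strategy[mask] .= cumulative_strategy ./ s
@assert strategy == [0.5, 0.0, 0.25, 0.25]
```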
14 changes: 8 additions & 6 deletions src/algorithms/cfr/nash_conv.jl
@@ -2,15 +2,17 @@ export expected_policy_values, nash_conv

function expected_policy_values(π::AbstractPolicy, env::AbstractEnv)
if is_terminated(env)
[reward(env, p) for p in get_players(env) if p != chance_player(env)]
[reward(env, p) for p in players(env) if p != chance_player(env)]
elseif current_player(env) == chance_player(env)
vals = [0.0 for p in get_players(env) if p != chance_player(env)]
for a::ActionProbPair in legal_action_space(env)
vals .+= a.prob .* expected_policy_values(π, child(env, a))
vals = [0.0 for p in players(env) if p != chance_player(env)]
for (a, pₐ) in zip(action_space(env), prob(env))
if pₐ > 0
vals .+= pₐ .* expected_policy_values(π, child(env, a))
end
end
vals
else
vals = [0.0 for p in get_players(env) if p != chance_player(env)]
vals = [0.0 for p in players(env) if p != chance_player(env)]
actions = action_space(env)
probs = prob(π, env)
@assert length(actions) == length(probs)
@@ -30,7 +32,7 @@ function nash_conv(π, env; is_reduce = true, kw...)

σ′ = [
best_response_value(BestResponsePolicy(π, e, i; kw...), e)
for i in get_players(e) if i != chance_player(e)
for i in players(e) if i != chance_player(e)
]

σ = expected_policy_values(π, e)
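
As a reminder of what is being computed here: NashConv is the summed gap, over the non-chance players, between the value of a best response to `π` and the value of playing `π` itself, and it is zero exactly when `π` is a Nash equilibrium. A worked example with made-up per-player values:

```julia
σ′ = [0.25, -0.10]           # stand-in best-response values, one per player
σ  = [0.05, -0.25]           # stand-in expected values when everyone follows π
@assert sum(σ′ .- σ) ≈ 0.35  # the gap nash_conv reports (summed when is_reduce = true)
```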
24 changes: 16 additions & 8 deletions src/algorithms/cfr/outcome_sampling_mccfr.jl
@@ -21,6 +21,9 @@ end

RLBase.prob(p::OutcomeSamplingMCCFRPolicy, env::AbstractEnv) = prob(p.behavior_policy, env)

RLBase.prob(p::OutcomeSamplingMCCFRPolicy, env::AbstractEnv, action) =
prob(p.behavior_policy, env, action)

function OutcomeSamplingMCCFRPolicy(; state_type = String, rng = Random.GLOBAL_RNG, ϵ = 0.6)
OutcomeSamplingMCCFRPolicy(
Dict{state_type,InfoStateNode}(),
@@ -36,7 +39,7 @@ end

"Run one interation"
function RLBase.update!(p::OutcomeSamplingMCCFRPolicy, env::AbstractEnv)
for x in get_players(env)
for x in players(env)
if x != chance_player(env)
outcome_sampling(copy(env), x, p.nodes, p.ϵ, 1.0, 1.0, 1.0, p.rng)
end
@@ -47,7 +50,10 @@ function RLBase.update!(p::OutcomeSamplingMCCFRPolicy)
for (k, v) in p.nodes
s = sum(v.cumulative_strategy)
if s != 0
update!(p.behavior_policy, k => v.cumulative_strategy ./ s)
m = v.mask
strategy = zeros(length(m))
strategy[m] .= v.cumulative_strategy ./ s
update!(p.behavior_policy, k => strategy)
else
# The TabularLearner will return a uniform distribution by default.
# So we do nothing here.
@@ -56,22 +62,24 @@ end
end

function outcome_sampling(env, i, nodes, ϵ, πᵢ, π₋ᵢ, s, rng)
current_player = current_player(env)
player = current_player(env)

if is_terminated(env)
reward(env, i) / s, 1.0
elseif current_player == chance_player(env)
env(rand(rng, action_space(env)))
elseif player == chance_player(env)
x = sample(rng, action_space(env), Weights(prob(env), 1.0))
env(x)
outcome_sampling(env, i, nodes, ϵ, πᵢ, π₋ᵢ, s, rng)
else
I = state(env)
legal_actions = legal_action_space(env)
M = legal_action_space_mask(env)
n = length(legal_actions)
node = get!(nodes, I, InfoStateNode(n))
node = get!(nodes, I, InfoStateNode(M))
regret_matching!(node; is_reset_neg_regrets = false)
σ, rI, sI = node.strategy, node.cumulative_regret, node.cumulative_strategy

if i == current_player
if i == player
aᵢ = rand(rng) >= ϵ ? sample(rng, Weights(σ, 1.0)) : rand(rng, 1:n)
pᵢ = σ[aᵢ] * (1 - ϵ) + ϵ / n
πᵢ′, π₋ᵢ′, s′ = πᵢ * pᵢ, π₋ᵢ, s * pᵢ
Expand All @@ -84,7 +92,7 @@ function outcome_sampling(env, i, nodes, ϵ, πᵢ, π₋ᵢ, s, rng)
env(legal_action_space(env)[aᵢ])
u, πₜₐᵢₗ = outcome_sampling(env, i, nodes, ϵ, πᵢ′, π₋ᵢ′, s′, rng)

if i == current_player
if i == player
w = u * π₋ᵢ
rI .+= w * πₜₐᵢₗ .* ((1:n .== aᵢ) .- σ[aᵢ])
else
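
Chance actions are now drawn from the environment's declared distribution rather than uniformly, using StatsBase's weighted sampling; the second argument of `Weights` is the precomputed total weight, which is `1.0` for a probability vector. A self-contained sketch with stand-in values:

```julia
using Random, StatsBase

rng = MersenneTwister(2021)
actions = ["J", "Q", "K"]                  # stand-in for action_space(env) at a chance node
p = [0.2, 0.5, 0.3]                        # stand-in for prob(env); already sums to 1
a = sample(rng, actions, Weights(p, 1.0))  # "Q" is drawn about half of the time
```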
35 changes: 25 additions & 10 deletions src/algorithms/cfr/tabular_cfr.jl
@@ -1,12 +1,21 @@
export TabularCFRPolicy

struct InfoStateNode
struct InfoStateNode{M<:AbstractVector{Bool}}
strategy::Vector{Float64}
cumulative_regret::Vector{Float64}
cumulative_strategy::Vector{Float64}
mask::M
end

InfoStateNode(n) = InfoStateNode(fill(1 / n, n), zeros(n), zeros(n))
function InfoStateNode(mask)
n = sum(mask)
InfoStateNode(
fill(1 / n, n),
zeros(n),
zeros(n),
mask
)
end

#####
# TabularCFRPolicy
@@ -77,7 +86,10 @@ function RLBase.update!(p::TabularCFRPolicy)
for (k, v) in p.nodes
s = sum(v.cumulative_strategy)
if s != 0
update!(p.behavior_policy, k => v.cumulative_strategy ./ s)
m = v.mask
strategy = zeros(length(m))
strategy[m] .= v.cumulative_strategy ./ s
update!(p.behavior_policy, k => strategy)
else
# The TabularLearner will return a uniform distribution by default.
# So we do nothing here.
@@ -89,7 +101,7 @@ end
function RLBase.update!(p::TabularCFRPolicy, env::AbstractEnv)
w = p.is_linear_averaging ? max(p.n_iteration - p.weighted_averaging_delay, 0) : 1
if p.is_alternating_update
for x in get_players(env)
for x in players(env)
if x != chance_player(env)
cfr!(p.nodes, env, x, w)
regret_matching!(p)
@@ -113,22 +125,25 @@ w: weight
v: counterfactual value **before being weighted by the opponent's reaching probability**
V: a vector containing the `v` of each action at the current information set. Used to calculate the **regret value**
"""
function cfr!(nodes, env, p, w, π = Dict(x => 1.0 for x in get_players(env)))
function cfr!(nodes, env, p, w, π = Dict(x => 1.0 for x in players(env)))
if is_terminated(env)
reward(env, p)
else
if current_player(env) == chance_player(env)
v = 0.0
for a::ActionProbPair in legal_action_space(env)
π′ = copy(π)
π′[current_player(env)] *= a.prob
v += a.prob * cfr!(nodes, child(env, a), p, w, π′)
for (a, pₐ) in zip(action_space(env), prob(env))
if pₐ > 0
π′ = copy(π)
π′[current_player(env)] *= pₐ
v += pₐ * cfr!(nodes, child(env, a), p, w, π′)
end
end
v
else
v = 0.0
legal_actions = legal_action_space(env)
node = get!(nodes, state(env), InfoStateNode(length(legal_actions)))
M = legal_action_space_mask(env)
node = get!(nodes, state(env), InfoStateNode(M))

is_update = isnothing(p) || p == current_player(env)
V = is_update ? Vector{Float64}(undef, length(legal_actions)) : nothing
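
With the parametric `InfoStateNode{M}` above, each node keeps its per-legal-action vectors at length `sum(mask)` and the full-length mask alongside them, which is what the `update!` methods use later to expand the averaged strategy back to the full action space. A small usage sketch that reuses the definition from the diff, with an illustrative mask:

```julia
struct InfoStateNode{M<:AbstractVector{Bool}}
    strategy::Vector{Float64}
    cumulative_regret::Vector{Float64}
    cumulative_strategy::Vector{Float64}
    mask::M
end

function InfoStateNode(mask)
    n = sum(mask)    # number of legal actions at this information set
    InfoStateNode(fill(1 / n, n), zeros(n), zeros(n), mask)
end

node = InfoStateNode(Bool[1, 0, 1])  # 3 actions overall, 2 of them legal here
@assert node.strategy == [0.5, 0.5]  # uniform over the legal actions only
@assert length(node.mask) == 3       # full action-space length kept for later expansion
```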
27 changes: 23 additions & 4 deletions src/algorithms/policy_gradient/multi_thread_env.jl
@@ -20,10 +20,7 @@ struct MultiThreadEnv{E,S,R,AS,SS,L} <: AbstractEnv
end

function Base.show(io::IO, t::MIME"text/markdown", env::MultiThreadEnv)
s = """
# MultiThreadEnv($(length(env)) x $(nameof(env[1])))
"""
show(io, t, Markdown.parse(s))
print(io, "MultiThreadEnv($(length(env)) x $(nameof(env[1])))")
end

"""
@@ -146,3 +143,25 @@ for f in RLBase.ENV_API
@eval RLBase.$f(x::MultiThreadEnv) = $f(x[1])
end
end

#####
# Patches
#####

(env::MultiThreadEnv)(action::EnrichedAction) = env(action.action)

function (π::QBasedPolicy)(env::MultiThreadEnv, ::MinimalActionSet, A)
[A[i][a] for (i, a) in enumerate(π.explorer(π.learner(env)))]
end

function (π::QBasedPolicy)(env::MultiThreadEnv, ::FullActionSet, A)
[A[i][a] for (i,a) in enumerate(π.explorer(π.learner(env), legal_action_space_mask(env)))]
end

function (π::QBasedPolicy)(env::MultiThreadEnv, ::MinimalActionSet, ::Space{<:Vector{<:Base.OneTo{<:Integer}}})
π.explorer(π.learner(env))
end

function (π::QBasedPolicy)(env::MultiThreadEnv, ::FullActionSet, ::Space{<:Vector{<:Base.OneTo{<:Integer}}})
π.explorer(π.learner(env), legal_action_space_mask(env))
end
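
The four `QBasedPolicy` patches dispatch on two things: whether the environment has a `FullActionSet`, in which case the explorer also receives `legal_action_space_mask(env)`, and whether the batched action space is a vector of plain `Base.OneTo` ranges, in which case the explorer's indices are already the actions and need no lookup in `A`. A plain-Julia sketch of that index-versus-lookup distinction, with illustrative values:

```julia
A = [11:20, 21:30]                         # per-environment action containers
indices = [3, 7]                           # indices chosen by the explorer, one per env
@assert [A[i][a] for (i, a) in enumerate(indices)] == [13, 27]    # lookup required

onetos = [Base.OneTo(10), Base.OneTo(10)]  # OneTo spaces: the index is the action
@assert [onetos[i][a] for (i, a) in enumerate(indices)] == indices
```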