6 changes: 4 additions & 2 deletions Project.toml
@@ -19,6 +19,7 @@ ReinforcementLearningBase = "e575027e-6cd6-5018-9292-cdc6200d2b44"
ReinforcementLearningCore = "de1b191a-4ae0-4afa-a27b-92d07f46b2d6"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
@@ -29,13 +30,14 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
AbstractTrees = "0.3"
BSON = "0.2"
CUDA = "1"
Distributions = "0.23, 0.24"
Distributions = "0.24"
Flux = "0.11"
MacroTools = "0.5"
ReinforcementLearningBase = "0.8.4"
ReinforcementLearningCore = "0.4.5"
ReinforcementLearningCore = "0.5"
Requires = "1"
Setfield = "0.6, 0.7"
StableRNGs = "1.0"
StatsBase = "0.32, 0.33"
StructArrays = "0.4"
TensorBoardLogger = "0.1"
3 changes: 2 additions & 1 deletion README.md
@@ -25,7 +25,7 @@ This project aims to provide some implementations of the most typical reinforcement learning algorithms
- DDPG
- TD3
- SAC
- CFR/OS-MCCFR/ES-MCCFR
- CFR/OS-MCCFR/ES-MCCFR/DeepCFR
- Minimax

If you are looking for tabular reinforcement learning algorithms, you may refer to [ReinforcementLearningAnIntroduction.jl](https://github.com/JuliaReinforcementLearning/ReinforcementLearningAnIntroduction.jl).
@@ -55,6 +55,7 @@ Some built-in experiments are exported to help new users easily run benchmarks
- ``E`JuliaRL_DQN_MountainCar` `` (Thanks to [@felixchalumeau](https://github.com/felixchalumeau))
- ``E`JuliaRL_Minimax_OpenSpiel(tic_tac_toe)` ``
- ``E`JuliaRL_TabularCFR_OpenSpiel(kuhn_poker)` ``
- ``E`JuliaRL_DeepCFR_OpenSpiel(leduc_poker)` ``
- ``E`JuliaRL_DQN_SnakeGame` ``
- ``E`Dopamine_DQN_Atari(pong)` ``
- ``E`Dopamine_Rainbow_Atari(pong)` ``
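Each of these names is an `Experiment` built by the ``E`...` `` string macro, so a listed entry can be run directly. A minimal sketch (assuming OpenSpiel.jl is installed so the `leduc_poker` game is available; exact package requirements may differ):

```julia
using ReinforcementLearningZoo
using OpenSpiel  # the OpenSpiel-backed experiments are loaded on demand via Requires

run(E`JuliaRL_DeepCFR_OpenSpiel(leduc_poker)`)
```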
1 change: 1 addition & 0 deletions src/ReinforcementLearningZoo.jl
@@ -6,6 +6,7 @@ export RLZoo
using ReinforcementLearningBase
using ReinforcementLearningCore
using Setfield: @set
using StableRNGs

include("patch.jl")
include("algorithms/algorithms.jl")
18 changes: 18 additions & 0 deletions src/algorithms/cfr/abstract_cfr_policy.jl
@@ -0,0 +1,18 @@
abstract type AbstractCFRPolicy <: AbstractPolicy end

function Base.run(p::AbstractCFRPolicy, env::AbstractEnv, stop_condition=StopAfterStep(1), hook=EmptyHook())
@assert NumAgentStyle(env) isa MultiAgent
@assert DynamicStyle(env) === SEQUENTIAL
@assert RewardStyle(env) === TERMINAL_REWARD
@assert ChanceStyle(env) === EXPLICIT_STOCHASTIC
@assert DefaultStateStyle(env) isa Information

RLBase.reset!(env)

while true
update!(p, env)
hook(POST_ACT_STAGE, p, env)
stop_condition(p, env) && break
end
update!(p)
end
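A sketch of how this specialized `Base.run` is typically driven. The environment and the `rng` keyword below are illustrative assumptions; any concrete `AbstractCFRPolicy` from this PR (tabular CFR, the MCCFR variants, or DeepCFR) plugs in the same way:

```julia
using ReinforcementLearningZoo, ReinforcementLearningCore, StableRNGs
# assumes ReinforcementLearningEnvironments + OpenSpiel.jl provide `OpenSpielEnv`

env = OpenSpielEnv("kuhn_poker")                  # sequential, terminal-reward, explicit-stochastic
policy = TabularCFRPolicy(; rng = StableRNG(17))  # hypothetical keyword; see tabular_cfr.jl
run(policy, env, StopAfterStep(300))              # 300 CFR iterations, then the final `update!(policy)`
```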
119 changes: 119 additions & 0 deletions src/algorithms/cfr/best_response_policy.jl
@@ -0,0 +1,119 @@
export BestResponsePolicy

using Flux: onehot

struct BestResponsePolicy{E, S, A, X, P<:AbstractPolicy} <: AbstractCFRPolicy
cfr_reach_prob::Dict{S, Vector{Pair{E, Float64}}}
best_response_action_cache::Dict{S,A}
best_response_value_cache::Dict{E,Float64}
best_responder::X
policy::P
end

"""
BestResponsePolicy(policy, env, best_responder)

- `policy`, the policy to be evaluated; it is wrapped by the best response policy.
- `env`, the environment to explore when computing the best response.
- `best_responder`, the player for whom the best response action is computed.
"""
function BestResponsePolicy(policy, env, best_responder; state_type=String, action_type=Int)
# S = typeof(get_state(env)) # TODO: currently this breaks OpenSpielEnv; the information set of the chance player is not available
# A = eltype(get_actions(env)) # TODO: for the chance player this returns an ActionProbPair
S = state_type
A = action_type
E = typeof(env)

p = BestResponsePolicy(
Dict{S, Vector{Pair{E, Float64}}}(),
Dict{S, A}(),
Dict{E, Float64}(),
best_responder,
policy
)

e = copy(env)
@assert e == env "The copy method doesn't seem to be implemented for environment: $env"
@assert hash(e) == hash(env) "The hash method doesn't seem to be implemented for environment: $env"
RLBase.reset!(e) # start from the root!
init_cfr_reach_prob!(p, e)
p
end

function (p::BestResponsePolicy)(env::AbstractEnv)
if get_current_player(env) == p.best_responder
best_response_action(p, env)
else
p.policy(env)
end
end

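# Enumerate the game tree once and, for every information set owned by the best
# responder, record the histories (environment copies) that reach it together with
# their counterfactual reach probabilities (chance and opponent contributions only;
# the best responder's own choices do not discount the reach probability).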
function init_cfr_reach_prob!(p, env, reach_prob=1.0)
if !get_terminal(env)
if get_current_player(env) == p.best_responder
push!(get!(p.cfr_reach_prob, get_state(env), []), env => reach_prob)

for a in get_legal_actions(env)
init_cfr_reach_prob!(p, child(env, a), reach_prob)
end
elseif get_current_player(env) == get_chance_player(env)
for a::ActionProbPair in get_actions(env)
init_cfr_reach_prob!(p, child(env, a), reach_prob * a.prob)
end
else # opponents
for a in get_legal_actions(env)
init_cfr_reach_prob!(p, child(env, a), reach_prob * get_prob(p.policy, env, a))
end
end
end
end

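# Expected return for the best responder at `env`, assuming the best responder plays
# best-response actions while chance and the opponents follow their fixed distributions.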
function best_response_value(p, env)
get!(p.best_response_value_cache, env) do
if get_terminal(env)
get_reward(env, p.best_responder)
elseif get_current_player(env) == p.best_responder
a = best_response_action(p, env)
best_response_value(p, child(env, a))
elseif get_current_player(env) == get_chance_player(env)
v = 0.
for a::ActionProbPair in get_actions(env)
v += a.prob * best_response_value(p, child(env, a))
end
v
else
v = 0.
for a in get_legal_actions(env)
v += get_prob(p.policy, env, a) * best_response_value(p, child(env, a))
end
v
end
end
end

function best_response_action(p, env)
get!(p.best_response_action_cache, get_state(env)) do
best_action, best_action_value = nothing, typemin(Float64)
for a in get_legal_actions(env)
# each information set (`get_state(env)` here) may be reached along several histories;
# the best action maximizes the sum of best-response values weighted by the counterfactual reach probabilities
v = sum(p.cfr_reach_prob[get_state(env)]) do (e, reach_prob)
reach_prob * best_response_value(p, child(e, a))
end
if v > best_action_value
best_action, best_action_value = a, v
end
end
best_action
end
end

RLBase.update!(p::BestResponsePolicy, args...) = nothing

function RLBase.get_prob(p::BestResponsePolicy, env::AbstractEnv)
if get_current_player(env) == p.best_responder
onehot(p(env), get_actions(env))
else
get_prob(p.policy, env)
end
end
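A usage sketch. `UniformPolicy` below is a hypothetical stand-in for whatever strategy is being evaluated; the only contract `BestResponsePolicy` relies on is the two `get_prob` methods plus being callable on an environment, as exercised above. Player ids follow the 0-based convention of the OpenSpiel wrapper:

```julia
using ReinforcementLearningBase  # brings `RLBase`, `AbstractPolicy`, `get_prob`, ...

# A fixed strategy to be evaluated: uniform over the legal actions.
struct UniformPolicy <: AbstractPolicy end
RLBase.get_prob(::UniformPolicy, env, a) = 1 / length(get_legal_actions(env))
RLBase.get_prob(::UniformPolicy, env) =
    get_legal_actions_mask(env) ./ sum(get_legal_actions_mask(env))
(p::UniformPolicy)(env) = rand(get_legal_actions(env))

env = OpenSpielEnv("kuhn_poker")                   # assumes OpenSpiel.jl is loaded
brp = BestResponsePolicy(UniformPolicy(), env, 0)  # best response for player 0
get_prob(brp, env)  # one-hot on player 0's best action; otherwise delegates to the wrapped policy
```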
4 changes: 4 additions & 0 deletions src/algorithms/cfr/cfr.jl
@@ -1,3 +1,7 @@
include("abstract_cfr_policy.jl")
include("tabular_cfr.jl")
include("outcome_sampling_mccfr.jl")
include("external_sampling_mccfr.jl")
include("best_response_policy.jl")
include("nash_conv.jl")
include("deep_cfr.jl")
181 changes: 181 additions & 0 deletions src/algorithms/cfr/deep_cfr.jl
@@ -0,0 +1,181 @@
export DeepCFR

using Statistics: mean
using StatsBase

"""
DeepCFR(;kwargs...)

Symbols used here follow the paper: [Deep Counterfactual Regret Minimization](https://arxiv.org/abs/1811.00164)

# Keyword arguments

- `K`, the number of external-sampling traversals performed for each player in every iteration.
- `t`, the current iteration counter.
- `Π`, the policy network.
- `V`, a dictionary of each player's advantage network.
- `MΠ`, a strategy memory.
- `MV`, a dictionary of each player's advantage memory.
- `reinitialize_freq=1`, the frequency of reinitializing the value networks.
"""
Base.@kwdef mutable struct DeepCFR{TP, TV, TMP, TMV, I, R, P} <: AbstractCFRPolicy
Π::TP
V::TV
MΠ::TMP
MV::TMV
K::Int = 20
t::Int = 1
reinitialize_freq::Int = 1
batch_size_V::Int = 32
batch_size_Π::Int = 32
n_training_steps_V::Int = 1
n_training_steps_Π::Int = 1
rng::R = Random.GLOBAL_RNG
initializer::I = glorot_normal(rng)
max_grad_norm::Float32 = 10.0f0
# for logging
Π_losses::Vector{Float32} = zeros(Float32, n_training_steps_Π)
V_losses::Dict{P, Vector{Float32}} = Dict(k => zeros(Float32, n_training_steps_V) for (k,_) in MV)
Π_norms::Vector{Float32} = zeros(Float32, n_training_steps_Π)
V_norms::Dict{P, Vector{Float32}} = Dict(k => zeros(Float32, n_training_steps_V) for (k,_) in MV)
end
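# Per-iteration flow (see `update!` below): for every non-chance player, run `K`
# external-sampling traversals that fill the advantage memory `MV` and the strategy
# memory `MΠ`, then fit that player's advantage network; the policy network `Π` is
# only fitted from `MΠ` at the very end, via `update!(π::DeepCFR)`.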

function RLBase.get_prob(π::DeepCFR, env::AbstractEnv)
I = send_to_device(device(π.Π), get_state(env))
m = send_to_device(device(π.Π), ifelse.(get_legal_actions_mask(env), 0.f0, -Inf32))
logits = π.Π(Flux.unsqueeze(I, ndims(I)+1)) |> vec
σ = softmax(logits .+ m)
send_to_host(σ)
end

(π::DeepCFR)(env::AbstractEnv) = sample(π.rng, get_actions(env), Weights(get_prob(π, env), 1.0))

"Run one interation"
function RLBase.update!(π::DeepCFR, env::AbstractEnv)
for p in get_players(env)
if p != get_chance_player(env)
for k in 1:π.K
external_sampling!(π, copy(env), p)
end
update_advantage_networks(π, p)
end
end
π.t += 1
end

"Update Π (policy network)"
function RLBase.update!(π::DeepCFR)
Π = π.Π
Π_losses = π.Π_losses
Π_norms = π.Π_norms
D = device(Π)
MΠ = π.MΠ
ps = Flux.params(Π)

for x in ps
x .= π.initializer(size(x)...)
end

for i in 1:π.n_training_steps_Π
batch_inds = rand(π.rng, 1:length(MΠ), π.batch_size_Π)
I = send_to_device(D, Flux.batch([MΠ[:I][i] for i in batch_inds]))
σ = send_to_device(D, Flux.batch([MΠ[:σ][i] for i in batch_inds]))
t = send_to_device(D, Flux.batch([MΠ[:t][i] / π.t for i in batch_inds]))
m = send_to_device(D, Flux.batch([ifelse.(MΠ[:m][i], 0.f0, -Inf32) for i in batch_inds]))
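# Linear CFR weighting: each sample's squared error is scaled by the iteration at
# which it was collected (already divided by the current iteration `π.t` above), so
# strategies from later iterations dominate the fitted average policy.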
gs = gradient(ps) do
logits = Π(I) .+ m
loss = mean(reshape(t, 1, :) .* ((σ .- softmax(logits)) .^ 2))
ignore() do
# println(σ, "!!!",m, "===", Π(I))
Π_losses[i] = loss
end
loss
end
Π_norms[i] = clip_by_global_norm!(gs, ps, π.max_grad_norm)
update!(Π, gs)
end
end

"Update advantage network"
function update_advantage_networks(π, p)
V = π.V[p]
V_losses = π.V_losses[p]
V_norms = π.V_norms[p]
MV = π.MV[p]
if π.t % π.reinitialize_freq == 0
for x in Flux.params(V)
# TODO: inplace
x .= π.initializer(size(x)...)
end
end
if length(MV) >= π.batch_size_V
for i in 1:π.n_training_steps_V
batch_inds = rand(π.rng, 1:length(MV), π.batch_size_V)
I = send_to_device(device(V), Flux.batch([MV[:I][i] for i in batch_inds]))
r̃ = send_to_device(device(V), Flux.batch([MV[:r̃][i] for i in batch_inds]))
t = send_to_device(device(V), Flux.batch([MV[:t][i] / π.t for i in batch_inds]))
m = send_to_device(device(V), Flux.batch([MV[:m][i] for i in batch_inds]))
ps = Flux.params(V)
gs = gradient(ps) do
loss = mean(reshape(t, 1, :) .* ((r̃ .- V(I) .* m) .^ 2))
ignore() do
V_losses[i] = loss
end
loss
end
V_norms[i] = clip_by_global_norm!(gs, ps, π.max_grad_norm)
update!(V, gs)
end
end
end

"CFR Traversal with External Sampling"
function external_sampling!(π::DeepCFR, env::AbstractEnv, p)
if get_terminal(env)
get_reward(env, p)
elseif get_current_player(env) == get_chance_player(env)
env(rand(π.rng, get_actions(env)))
external_sampling!(π, env, p)
elseif get_current_player(env) == p
V = π.V[p]
s = get_state(env)
I = send_to_device(device(V), Flux.unsqueeze(s, ndims(s)+1))
A = get_actions(env)
m = get_legal_actions_mask(env)
σ = masked_regret_matching(V(I) |> send_to_host |> vec, m)
v = zeros(length(σ))
v̄ = 0.
for i in 1:length(m)
if m[i]
v[i] = external_sampling!(π, child(env, A[i]), p)
v̄ += σ[i] * v[i]
end
end
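# Store the sampled advantages r̃(a) = v(a) - v̄ for the traverser (zeroed on illegal
# actions by the mask), tagged with the current iteration for linear weighting.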
push!(π.MV[p], I = s, t = π.t, r̃ = (v .- v̄) .* m, m = m)
else
V = π.V[get_current_player(env)]
s = get_state(env)
I = send_to_device(device(V), Flux.unsqueeze(s, ndims(s)+1))
A = get_actions(env)
m = get_legal_actions_mask(env)
σ = masked_regret_matching(V(I) |> send_to_host |> vec, m)
push!(π.MΠ, I = s, t = π.t, σ = σ, m = m)
a = sample(π.rng, A, Weights(σ, 1.0))
env(a)
external_sampling!(π, env, p)
end
end

"This is the specific regret matching method used in DeepCFR"
function masked_regret_matching(v, m)
v⁺ = max.(v .* m, 0.f0)
s = sum(v⁺)
if s > 0
v⁺ ./= s
else
fill!(v⁺, 0.f0)
# no positive advantage: put all probability mass on the highest-valued legal action
v⁺[argmax(ifelse.(m, v, typemin(eltype(v))))] = 1.0f0
end
v⁺
end
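A small worked example of the normalization path (the numbers are purely illustrative):

```julia
# v = [0.5, -1.0, 2.0] with legal-action mask m = [true, false, true]:
# the positive masked advantages are [0.5, 0.0, 2.0], so σ = [0.2, 0.0, 0.8].
# If no legal action had a positive advantage, all probability mass would instead
# go to the highest-valued legal action.
masked_regret_matching(Float32[0.5, -1.0, 2.0], [true, false, true])
```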