WIP: add NFSP algorithm and related experiment (#375)
Changes from all commits: 0a14cd4, f6c0216, d03e015, 4eac2af, 5efb4b2, d2e7d0d, 5503d31, 68ac91e, f1dfb37, db903a2, b625a7a, a9b95a9, b254dfd
@@ -0,0 +1,90 @@
# ---
# title: JuliaRL\_NFSP\_KuhnPoker
# cover: assets/logo.svg
# description: NFSP applied to KuhnPokerEnv
# date: 2021-07-18
# author: "[Peter Chen](https://github.com/peterchen96)"
# ---

#+ tangle=false
using ReinforcementLearning
using StableRNGs
using Flux
using Flux.Losses

mutable struct ResultNEpisode <: AbstractHook
Review comment: Why is it set to a subtype of `AbstractHook`?
    episode::Vector{Int}
    results
end
recorder = ResultNEpisode([], [])
Review comment: Please note that it will be set to the global variable in the …
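One possible way to address this (a minimal sketch, not part of the PR; the type name, the `eval_freq` field, and the callback signature are assumptions based on how other `AbstractHook`s in ReinforcementLearning.jl are written) is to let the hook carry its own state and record inside the post-episode callback instead of relying on a global `recorder`:

```julia
# Sketch: a self-contained hook that records nash_conv every `eval_freq` episodes.
mutable struct NashConvEveryNEpisode <: AbstractHook
    eval_freq::Int
    episode_counter::Int
    episode::Vector{Int}
    results::Vector{Float64}
end
NashConvEveryNEpisode(eval_freq) = NashConvEveryNEpisode(eval_freq, 0, Int[], Float64[])

function (hook::NashConvEveryNEpisode)(::PostEpisodeStage, policy, env)
    hook.episode_counter += 1
    if hook.episode_counter % hook.eval_freq == 0
        push!(hook.episode, hook.episode_counter)
        push!(hook.results, RLZoo.nash_conv(policy, env))
    end
end
```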

function RL.Experiment(
    ::Val{:JuliaRL},
    ::Val{:NFSP},
    ::Val{:KuhnPoker},
    ::Nothing;
    seed = 123,
)

    # Encode the KuhnPokerEnv's states for training.
    env = KuhnPokerEnv()
    states = [
        (), (:J,), (:Q,), (:K,),
        (:J, :Q), (:J, :K), (:Q, :J), (:Q, :K), (:K, :J), (:K, :Q),
        (:J, :bet), (:J, :pass), (:Q, :bet), (:Q, :pass), (:K, :bet), (:K, :pass),
        (:J, :pass, :bet), (:J, :bet, :bet), (:J, :bet, :pass), (:J, :pass, :pass),
        (:Q, :pass, :bet), (:Q, :bet, :bet), (:Q, :bet, :pass), (:Q, :pass, :pass),
        (:K, :pass, :bet), (:K, :bet, :bet), (:K, :bet, :pass), (:K, :pass, :pass),
        (:J, :pass, :bet, :pass), (:J, :pass, :bet, :bet), (:Q, :pass, :bet, :pass),
        (:Q, :pass, :bet, :bet), (:K, :pass, :bet, :pass), (:K, :pass, :bet, :bet),
    ] # collect all states
Review comment on lines +30 to +40: How about adding another state representation to this environment?

    states_indexes_Dict = Dict((i, j) for (j, i) in enumerate(states))
    wrapped_env = StateTransformedEnv(
        env;
        state_mapping = s -> [states_indexes_Dict[s]],
        state_space_mapping = ss -> [[i] for i in 1:length(states)]
    )
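For reference, a quick illustration of what this encoding does (a sketch, not part of the diff; the concrete index values simply follow the order of the `states` vector above): each information set is mapped to a one-element index vector.

```julia
# Sketch: the Dict maps each information set to its position in `states`,
# and state_mapping wraps that index in a one-element vector.
states_indexes_Dict[()]             # == 1, since () is the first entry of `states`
states_indexes_Dict[(:J, :Q)]       # == 5
[states_indexes_Dict[(:K, :pass)]]  # what state_mapping returns for (:K, :pass), i.e. [16]
```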

    # set parameters for NFSPAgentManager
    nfsp = NFSPAgentManager(wrapped_env;
        η = 0.1,
        _device = Flux.cpu,
        Optimizer = Flux.Descent,
        rng = StableRNG(seed),
        batch_size = 128,
        learn_freq = 128,
        min_buffer_size_to_learn = 1000,
        hidden_layers = (128, 128),

        # Reinforcement Learning (RL) agent parameters
        rl_loss_func = mse,
        rl_learning_rate = 0.01,
        replay_buffer_capacity = 200_000,
        ϵ_start = 0.06,
        ϵ_end = 0.001,
        ϵ_decay = 20_000_000,
        discount_factor = 1.0f0,
        update_target_network_freq = 19200,

        # Supervised Learning (SL) agent parameters
        sl_learning_rate = 0.01,
        reservoir_buffer_capacity = 2_000_000,
    )
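As context for `η` above: in the NFSP paper it is the anticipatory parameter, i.e. in each episode the agent acts with its best-response (RL) policy with probability η and with its average (SL) policy otherwise. A minimal sketch of that selection rule (the helper name is hypothetical and not part of the PR):

```julia
using Random

# Sketch: NFSP's policy mixing controlled by the anticipatory parameter η.
select_policy(rng::AbstractRNG, η) = rand(rng) < η ? :best_response : :average_policy
```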

    stop_condition = StopAfterEpisode(10_000_000, is_show_progress=!haskey(ENV, "CI"))
    hook = DoEveryNEpisode(; n = 10_000) do t, nfsp, wrapped_env
        push!(recorder.episode, t)
        push!(recorder.results, RLZoo.nash_conv(nfsp, wrapped_env))
    end
    Experiment(nfsp, wrapped_env, stop_condition, hook, "# run NFSP on KuhnPokerEnv")
end

#+ tangle=false
using Plots
ex = E`JuliaRL_NFSP_KuhnPoker`
run(ex)
plot(recorder.episode, recorder.results, xaxis=:log, yaxis=:log, xlabel="episode", ylabel="nash_conv")

savefig("assets/JuliaRL_NFSP_KuhnPoker.png") #hide

# 
Review comment on lines +88 to +90: These are used to generate the progress plot. Whether it is needed depends on how you write the …
@@ -0,0 +1,6 @@
{
    "description": "Neural Fictitious Self-play(NFSP) related experiments.",
    "order": [
        "JuliaRL_NFSP_KuhnPoker.jl"
    ]
}
@@ -4,7 +4,8 @@
     "Policy Gradient",
     "Offline",
     "Search",
-    "CFR"
+    "CFR",
+    "NFSP"
   ]
 }
@@ -0,0 +1,40 @@
include("average_learner.jl")
include("nfsp.jl")
include("nfsp_manager.jl")


function Base.run(
    nfsp::NFSPAgentManager,
    env::AbstractEnv,
    stop_condition = StopAfterEpisode(1),
    hook = EmptyHook(),
)
    @assert NumAgentStyle(env) isa MultiAgent
    @assert DynamicStyle(env) === SEQUENTIAL
    @assert RewardStyle(env) === TERMINAL_REWARD
Review comment: Is this required?
    @assert ChanceStyle(env) === EXPLICIT_STOCHASTIC
    @assert DefaultStateStyle(env) isa InformationSet

    is_stop = false

    while !is_stop
        RLBase.reset!(env)
        hook(PRE_EPISODE_STAGE, nfsp, env)

        while !is_terminated(env) # one episode
            RLBase.update!(nfsp, env)
            hook(POST_ACT_STAGE, nfsp, env)

            if stop_condition(nfsp, env)
                is_stop = true
                break
            end
        end # end of an episode

        if is_terminated(env)
            hook(POST_EPISODE_STAGE, nfsp, env)
        end
    end
    hook(POST_EXPERIMENT_STAGE, nfsp, env)
    hook
end
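For illustration, the custom `run` method above is what the registered experiment eventually dispatches to; a direct call would look like this (a sketch, assuming `nfsp` and `wrapped_env` are built as in the experiment file earlier):

```julia
# Sketch: run the manager on the wrapped environment for a fixed number of
# episodes and collect results through a hook.
hook = EmptyHook()
run(nfsp, wrapped_env, StopAfterEpisode(100_000; is_show_progress = false), hook)
```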
@@ -0,0 +1,107 @@
export AverageLearner

mutable struct AverageLearner{
    Tq<:AbstractApproximator,
    R<:AbstractRNG,
} <: AbstractLearner
    approximator::Tq
    min_reservoir_history::Int
    update_freq::Int
    update_step::Int
    sampler::NStepBatchSampler
    rng::R
end

"""
    AverageLearner(;kwargs...)

In the Neural Fictitious Self-Play (NFSP) algorithm, the `AverageLearner` (the supervised-learning, or SL, component) learns the average policy by imitating the best-response behaviour produced by the RL agent.

See the paper: [Deep Reinforcement Learning from Self-Play in Imperfect-Information Games](https://arxiv.org/pdf/1603.01121.pdf)

# Keywords

- `approximator`::[`AbstractApproximator`](@ref).
- `batch_size::Int=32`
- `update_horizon::Int=1`: length of the update ('n' in n-step update).
- `min_reservoir_history::Int=32`: number of transitions that should be experienced before updating the `approximator`.
- `update_freq::Int=1`: the frequency of updating the `approximator`.
- `stack_size::Union{Int, Nothing}=nothing`: use the recent `stack_size` frames to form a stacked state.
- `traces = SARTS`.
- `rng = Random.GLOBAL_RNG`
"""
function AverageLearner(;
    approximator::Tq,
    batch_size::Int = 32,
    update_horizon::Int = 1,
    min_reservoir_history::Int = 32,
    update_freq::Int = 1,
    update_step::Int = 0,
    stack_size::Union{Int,Nothing} = nothing,
    traces = SARTS,
    rng = Random.GLOBAL_RNG,
) where {Tq}
    sampler = NStepBatchSampler{traces}(;
        γ = 0f0, # no need to set discount factor
        n = update_horizon,
        stack_size = stack_size,
        batch_size = batch_size,
    )
    AverageLearner(
        approximator,
        min_reservoir_history,
        update_freq,
        update_step,
        sampler,
        rng,
    )
end
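For orientation, here is a standalone construction sketch (assumptions: `NeuralNetworkApproximator` from RLCore is used as the `approximator`, and `ns`/`na` are placeholder sizes for the encoded state and the number of actions; none of this appears in the diff):

```julia
using Flux, ReinforcementLearning, StableRNGs

ns, na = 1, 2  # placeholder sizes: length of the encoded state and number of actions
sl_learner = AverageLearner(
    approximator = NeuralNetworkApproximator(
        model = Chain(Dense(ns, 128, relu), Dense(128, na)),
        optimizer = Descent(0.01),
    ),
    batch_size = 128,
    min_reservoir_history = 1_000,
    update_freq = 128,
    rng = StableRNG(123),
)
```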

Flux.functor(x::AverageLearner) = (Q = x.approximator, ), y -> begin
    x = @set x.approximator = y.Q
    x
end

function (learner::AverageLearner)(env)
    env |>
    state |>
    x -> Flux.unsqueeze(x, ndims(x) + 1) |>
    x -> send_to_device(device(learner), x) |>
    learner.approximator |>
    send_to_host |> vec
end

function RLBase.update!(learner::AverageLearner, t::AbstractTrajectory)
    length(t[:terminal]) - learner.sampler.n <= learner.min_reservoir_history && return

    learner.update_step += 1
    learner.update_step % learner.update_freq == 0 || return

    inds, batch = sample(learner.rng, t, learner.sampler)
    if t isa PrioritizedTrajectory
        priorities = update!(learner, batch)
        t[:priority][inds] .= priorities
    else
        update!(learner, batch)
    end
end

function RLBase.update!(learner::AverageLearner, batch::NamedTuple)
    Q = learner.approximator
    _device(x) = send_to_device(device(Q), x)

    local s, a
    @sync begin
        @async s = _device(batch[:state])
        @async a = _device(batch[:action])
    end

    gs = gradient(params(Q)) do
        ŷ = Q(s)
        # behaviour-cloning target: one-hot encoding of the sampled action
        y = Flux.onehotbatch(a, axes(ŷ, 1)) |> _device
        Flux.Losses.crossentropy(ŷ, y)
    end

    update!(Q, gs)
end
Review comment: Why is it set to `false`?