add rainbow #724

Merged

@@ -10,70 +10,67 @@
using ReinforcementLearning
using StableRNGs
using Flux
using Flux.Losses

function RL.Experiment(
::Val{:JuliaRL},
::Val{:Rainbow},
::Val{:CartPole},
::Nothing;
seed = 123,
; seed=123
)
rng = StableRNG(seed)

env = CartPoleEnv(; T = Float32, rng = rng)
env = CartPoleEnv(; T=Float32, rng=rng)
ns, na = length(state(env)), length(action_space(env))

n_atoms = 51
agent = Agent(
policy = QBasedPolicy(
learner = RainbowLearner(
approximator = NeuralNetworkApproximator(
model = Chain(
Dense(ns, 128, relu; init = glorot_uniform(rng)),
Dense(128, 128, relu; init = glorot_uniform(rng)),
Dense(128, na * n_atoms; init = glorot_uniform(rng)),
) |> gpu,
optimizer = ADAM(0.0005),
policy=QBasedPolicy(
learner=RainbowLearner(
approximator=Approximator(
model=TwinNetwork(
Chain(
Dense(ns, 128, relu; init=glorot_uniform(rng)),
Dense(128, 128, relu; init=glorot_uniform(rng)),
Dense(128, na * n_atoms; init=glorot_uniform(rng)),
);
sync_freq=100
),
optimiser=ADAM(0.0005),
),
target_approximator = NeuralNetworkApproximator(
model = Chain(
Dense(ns, 128, relu; init = glorot_uniform(rng)),
Dense(128, 128, relu; init = glorot_uniform(rng)),
Dense(128, na * n_atoms; init = glorot_uniform(rng)),
) |> gpu,
optimizer = ADAM(0.0005),
),
n_actions = na,
n_atoms = n_atoms,
Vₘₐₓ = 200.0f0,
Vₘᵢₙ = 0.0f0,
update_freq = 1,
γ = 0.99f0,
update_horizon = 1,
batch_size = 32,
stack_size = nothing,
min_replay_history = 100,
loss_func = (ŷ, y) -> logitcrossentropy(ŷ, y; agg = identity),
target_update_freq = 100,
rng = rng,
n_actions=na,
n_atoms=n_atoms,
Vₘₐₓ=200.0f0,
Vₘᵢₙ=0.0f0,
γ=0.99f0,
update_horizon=1,
rng=rng,
),
explorer = EpsilonGreedyExplorer(
kind = :exp,
ϵ_stable = 0.01,
decay_steps = 500,
rng = rng,
explorer=EpsilonGreedyExplorer(
kind=:exp,
ϵ_stable=0.01,
decay_steps=500,
rng=rng,
),
),
trajectory = CircularArrayPSARTTrajectory(
capacity = 1000,
state = Vector{Float32} => (ns,),
),
trajectory=Trajectory(
container=CircularArraySARTTraces(
capacity=1000,
state=Float32 => (ns,),
),
sampler=BatchSampler{SS′ART}(
batch_size=32,
rng=rng
),
controller=InsertSampleRatioController(
threshold=100,
n_inserted=-1
)
)
)

stop_condition = StopAfterStep(10_000, is_show_progress=!haskey(ENV, "CI"))
hook = TotalRewardPerEpisode()
Experiment(agent, env, stop_condition, hook, "")
Experiment(agent, env, stop_condition, hook)
end

#+ tangle=false
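
Once this experiment file is wired in (see the include and test changes below), it can be run from ReinforcementLearningExperiments. A minimal usage sketch, assuming the E string macro is available as in the test suite further down:

using ReinforcementLearningExperiments
ex = E`JuliaRL_Rainbow_CartPole`   # resolves to the RL.Experiment method defined above
run(ex)                            # 10_000 steps on CartPole, recording the total reward per episode via the hook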
@@ -14,6 +14,7 @@ include(joinpath(EXPERIMENTS_DIR, "JuliaRL_PrioritizedDQN_CartPole.jl"))
include(joinpath(EXPERIMENTS_DIR, "JuliaRL_QRDQN_CartPole.jl"))
include(joinpath(EXPERIMENTS_DIR, "JuliaRL_REMDQN_CartPole.jl"))
include(joinpath(EXPERIMENTS_DIR, "JuliaRL_IQN_CartPole.jl"))
include(joinpath(EXPERIMENTS_DIR, "JuliaRL_Rainbow_CartPole.jl"))

# dynamic loading environments
function __init__() end
2 changes: 1 addition & 1 deletion src/ReinforcementLearningExperiments/test/runtests.jl
@@ -9,8 +9,8 @@ run(E`JuliaRL_PrioritizedDQN_CartPole`)
run(E`JuliaRL_QRDQN_CartPole`)
run(E`JuliaRL_REMDQN_CartPole`)
run(E`JuliaRL_IQN_CartPole`)
run(E`JuliaRL_Rainbow_CartPole`)
# run(E`JuliaRL_BC_CartPole`)
# run(E`JuliaRL_Rainbow_CartPole`)
# run(E`JuliaRL_VMPO_CartPole`)
# run(E`JuliaRL_VPG_CartPole`)
# run(E`JuliaRL_BasicDQN_MountainCar`)
2 changes: 1 addition & 1 deletion src/ReinforcementLearningZoo/src/algorithms/dqns/dqns.jl
@@ -4,5 +4,5 @@ include("prioritized_dqn.jl")
include("qr_dqn.jl")
include("rem_dqn.jl")
include("iqn.jl")
# include("rainbow.jl")
include("rainbow.jl")
# include("common.jl")
171 changes: 53 additions & 118 deletions src/ReinforcementLearningZoo/src/algorithms/dqns/rainbow.jl
@@ -1,140 +1,62 @@
export RainbowLearner

"""
RainbowLearner(;kwargs...)

See paper: [Rainbow: Combining Improvements in Deep Reinforcement Learning](https://arxiv.org/abs/1710.02298)

# Keywords

- `approximator`::[`AbstractApproximator`](@ref): used to get Q-values of a state.
- `target_approximator`::[`AbstractApproximator`](@ref): similar to `approximator`, but used to estimate the target (the next state).
- `loss_func`: the loss function. `Flux.Losses.logitcrossentropy` is recommended; `Flux.Losses.crossentropy` can be numerically unstable here because the unnormalized logits may be negative.
- `Vₘₐₓ::Float32`: the maximum value of distribution.
- `Vₘᵢₙ::Float32`: the minimum value of distribution.
- `n_actions::Int`: number of possible actions.
- `γ::Float32=0.99f0`: discount rate.
- `batch_size::Int=32`
- `update_horizon::Int=1`: length of update ('n' in n-step update).
- `min_replay_history::Int=32`: number of transitions that should be experienced before updating the `approximator`.
- `update_freq::Int=4`: the frequency of updating the `approximator`.
- `target_update_freq::Int=500`: the frequency of syncing `target_approximator`.
- `stack_size::Union{Int, Nothing}=4`: use the recent `stack_size` frames to form a stacked state.
- `default_priority::Float32=1.0f2`: the default priority for newly added transitions. It must be `>= 1`.
- `n_atoms::Int=51`: the number of buckets of the value function distribution.
- `rng = Random.GLOBAL_RNG`
"""
mutable struct RainbowLearner{Tq,Tt,Tf,Ts,R<:AbstractRNG} <: AbstractLearner
approximator::Tq
target_approximator::Tt
loss_func::Tf
sampler::NStepBatchSampler
using Random: AbstractRNG, GLOBAL_RNG
using Flux: params, unsqueeze, softmax
using Flux.Losses: logitcrossentropy
using Functors: @functor

Base.@kwdef mutable struct RainbowLearner{A<:Approximator{<:TwinNetwork}} <: AbstractLearner
approximator::A
Vₘₐₓ::Float32
Vₘᵢₙ::Float32
n_actions::Int
n_atoms::Int
support::Ts
delta_z::Float32
min_replay_history::Int
update_freq::Int
target_update_freq::Int
update_step::Int
default_priority::Float32
β_priority::Float32
rng::R
loss::Float32
γ::Float32
update_horizon::Int = 1
n_atoms::Int = 51
support::AbstractVector = range(Float32(-Vₘₐₓ), Float32(Vₘₐₓ), length=n_atoms)
delta_z::Float32 = convert(Float32, support.step)
default_priority::Float32 = 1.0f2
β_priority::Float32 = 0.5f0
loss_func::Any = (ŷ, y) -> logitcrossentropy(ŷ, y; agg=identity)
rng::AbstractRNG = GLOBAL_RNG
# for logging
loss::Float32 = 0.0f0
end

Functors.functor(x::RainbowLearner) =
(Q=x.approximator, Qₜ=x.target_approximator, S=x.support),
y -> begin
x = @set x.approximator = y.Q
x = @set x.target_approximator = y.Qₜ
x = @set x.support = y.S
x
end
@functor RainbowLearner (support, approximator)
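
For illustration only (not part of this diff), a minimal keyword construction of the new learner, mirroring the CartPole experiment above. Omitted fields fall back to the @kwdef defaults; for example, with Vₘₐₓ = 200f0 and n_atoms = 51 the default support evaluates to range(-200f0, 200f0, length=51) and delta_z to 8f0:

using ReinforcementLearning, Flux   # as in the experiment file above

learner = RainbowLearner(
    approximator = Approximator(
        model = TwinNetwork(
            # CartPole: ns = 4 state dimensions, na = 2 actions, 51 atoms per action
            Chain(Dense(4, 128, relu), Dense(128, 128, relu), Dense(128, 2 * 51));
            sync_freq = 100,            # copy the source weights into the target copy every 100 updates
        ),
        optimiser = ADAM(0.0005),
    ),
    Vₘₐₓ = 200.0f0,
    Vₘᵢₙ = 0.0f0,
    n_actions = 2,
    γ = 0.99f0,
)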

function RainbowLearner(;
approximator,
target_approximator,
loss_func,
Vₘₐₓ,
Vₘᵢₙ,
n_actions,
n_atoms=51,
support=collect(range(Float32(-Vₘₐₓ), Float32(Vₘₐₓ), length=n_atoms)),
stack_size=4,
delta_z=Float32(support[2] - support[1]),
γ=0.99,
batch_size=32,
update_horizon=1,
min_replay_history=32,
update_freq=1,
target_update_freq=500,
update_step=0,
default_priority=1.0f2,
β_priority=0.5f0,
traces=SARTS,
rng=Random.GLOBAL_RNG
)
default_priority >= 1.0f0 || error("default value must be >= 1.0f0")
copyto!(approximator, target_approximator) # force sync
support = send_to_device(device(approximator), support)
sampler = NStepBatchSampler{traces}(;
γ=γ,
n=update_horizon,
stack_size=stack_size,
batch_size=batch_size
)
RainbowLearner(
approximator,
target_approximator,
loss_func,
sampler,
Vₘₐₓ,
Vₘᵢₙ,
n_actions,
n_atoms,
support,
delta_z,
min_replay_history,
update_freq,
target_update_freq,
update_step,
default_priority,
β_priority,
rng,
0.0f0,
)
function (L::RainbowLearner)(s::AbstractArray)
logits = L.approximator(s)
q = L.support .* softmax(reshape(logits, :, L.n_actions))
sum(q, dims=1)
end

function (learner::RainbowLearner)(env)
function (learner::RainbowLearner)(env::AbstractEnv)
s = send_to_device(device(learner.approximator), state(env))
s = Flux.unsqueeze(s, dims=ndims(s) + 1)
logits = learner.approximator(s)
q = learner.support .* softmax(reshape(logits, :, learner.n_actions))
vec(sum(q, dims=1)) |> send_to_host
s = unsqueeze(s, dims=ndims(s) + 1)
s |> learner |> vec |> send_to_host
end
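
For context (an editor's sketch, not PR code), the call operators above reduce the distributional head to scalar Q-values: the network emits n_atoms logits per action, a softmax over the atom dimension turns them into per-action probability masses, and the dot product with the support gives the expected return used for greedy action selection:

using Flux: softmax

n_atoms, n_actions = 51, 2
support = range(-200f0, 200f0, length=n_atoms)
logits = randn(Float32, n_atoms * n_actions)            # stand-in for the approximator output on one state
probs = softmax(reshape(logits, n_atoms, n_actions))    # softmax defaults to dims=1: one distribution per action
q = vec(sum(support .* probs, dims=1))                  # E[Z(s, a)] for every action
greedy_action = argmax(q)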

function RLBase.update!(learner::RainbowLearner, batch::NamedTuple)
Q = learner.approximator
Qₜ = learner.target_approximator
γ = learner.sampler.γ
function RLBase.optimise!(learner::RainbowLearner, batch::NamedTuple)
A = learner.approximator
Q = A.model.source
Qₜ = A.model.target
γ = learner.γ
β = learner.β_priority
loss_func = learner.loss_func
n_atoms = learner.n_atoms
n_actions = learner.n_actions
support = learner.support
delta_z = learner.delta_z
update_horizon = learner.sampler.n
batch_size = learner.sampler.batch_size
update_horizon = learner.update_horizon

D = device(Q)
states = send_to_device(D, batch.state)
rewards = send_to_device(D, batch.reward)
terminals = send_to_device(D, batch.terminal)
next_states = send_to_device(D, batch.next_state)

batch_size = length(terminals)
actions = CartesianIndex.(batch.action, 1:batch_size)

target_support =
@@ -143,11 +65,14 @@ function RLBase.update!(learner::RainbowLearner, batch::NamedTuple)

next_logits = Qₜ(next_states)
next_probs = reshape(softmax(reshape(next_logits, n_atoms, :)), n_atoms, n_actions, :)

next_q = reshape(sum(support .* next_probs, dims=1), n_actions, :)

if haskey(batch, :next_legal_actions_mask)
l′ = send_to_device(D, batch[:next_legal_actions_mask])
next_q .+= ifelse.(l′, 0.0f0, typemin(Float32))
end

next_prob_select = select_best_probs(next_probs, next_q)

target_distribution = project_distribution(
@@ -160,21 +85,21 @@
)

is_use_PER = haskey(batch, :priority) # is use Prioritized Experience Replay

if is_use_PER
updated_priorities = Vector{Float32}(undef, batch_size)
weights = 1.0f0 ./ ((batch.priority .+ 1.0f-10) .^ β)
weights ./= maximum(weights)
weights = send_to_device(D, weights)
# TODO: init on device directly
end

gs = gradient(Flux.params(Q)) do
gs = gradient(params(A)) do
logits = reshape(Q(states), n_atoms, n_actions, :)
select_logits = logits[:, actions]
# The original paper normalized logits, but using normalization and Flux.Losses.crossentropy is not as stable as using Flux.Losses.logitcrossentropy.
batch_losses = loss_func(select_logits, target_distribution)
loss =
is_use_PER ? dot(vec(weights), vec(batch_losses)) * 1 // batch_size :
mean(batch_losses)
loss = is_use_PER ? dot(vec(weights), vec(batch_losses)) * 1 // batch_size : mean(batch_losses)
ignore_derivatives() do
if is_use_PER
updated_priorities .= send_to_host(vec((batch_losses .+ 1.0f-10) .^ β))
@@ -184,9 +109,9 @@ function RLBase.update!(learner::RainbowLearner, batch::NamedTuple)
loss
end

update!(Q, gs)
optimise!(A, gs)

is_use_PER ? updated_priorities : nothing
is_use_PER ? batch.key => updated_priorities : nothing
end

@inline function select_best_probs(probs, q)
@@ -213,3 +138,13 @@ function project_distribution(supports, weights, target_support, delta_z, vmin,
) .* reshape(weights, n_atoms, 1, batch_size)
reshape(sum(projection, dims=1), n_atoms, batch_size)
end
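
For reference (an editor's sketch, not code from this PR), project_distribution performs the categorical projection from the C51/Rainbow papers: every Bellman-updated atom Tzⱼ = clamp(r + γⁿ·zⱼ, Vₘᵢₙ, Vₘₐₓ) spreads its probability mass onto the two nearest atoms of the fixed support. A single-sample version, with terminal-state handling omitted, might look like:

function project_single(support, probs, r, γn, delta_z, vmin, vmax)
    # γn is the n-step discount γ^update_horizon
    proj = zero(probs)
    for (p, z) in zip(probs, support)
        tz = clamp(r + γn * z, vmin, vmax)   # Bellman-updated atom, clipped to the support range
        b = (tz - vmin) / delta_z            # continuous index into the fixed support
        l, u = floor(Int, b), ceil(Int, b)
        if l == u
            proj[l+1] += p                   # lands exactly on an atom
        else
            proj[l+1] += p * (u - b)         # split the mass between the two neighbouring atoms
            proj[u+1] += p * (b - l)
        end
    end
    proj
end

The batched implementation above computes the same projection for all samples in the batch at once.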

function RLBase.optimise!(policy::QBasedPolicy{<:RainbowLearner}, trajectory::Trajectory)
for batch in trajectory
res = optimise!(policy, batch) |> send_to_host
if !isnothing(res)
k, p = res
trajectory[:priority, k] = p
end
end
end