add PrioritizedDQN #698

Merged
4 changes: 3 additions & 1 deletion src/ReinforcementLearningCore/src/utils/device.jl
@@ -1,6 +1,6 @@
# TODO: watch https://github.com/JuliaGPU/Adapt.jl/pull/52

export device, send_to_device
export device, send_to_device, send_to_host

using Flux
using CUDA
@@ -9,6 +9,8 @@ using Random

import CUDA: device

send_to_host(x) = send_to_device(Val(:cpu), x)

send_to_device(d) = x -> send_to_device(device(d), x)

send_to_device(::Val{:cpu}, m) = fmap(x -> adapt(Array, x), m)
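
For orientation, here is a minimal usage sketch of the new helper (hypothetical example; only `send_to_device`, `send_to_host`, and the `Val(:cpu)` method shown in this hunk are taken from the diff):

```julia
using ReinforcementLearningCore  # exports device, send_to_device, send_to_host

x = rand(Float32, 4)
y = send_to_device(Val(:cpu), x)  # explicit device target, as defined above
z = send_to_host(y)               # new shorthand for send_to_device(Val(:cpu), y)
z == x                            # true: a CPU round trip leaves the array unchanged
```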
@@ -2,7 +2,7 @@
# title: JuliaRL\_PrioritizedDQN\_CartPole
# cover: assets/JuliaRL_PrioritizedDQN_CartPole.png
# description: PrioritizedDQN applied to CartPole
# date: 2021-05-22
# date: 2022-06-18
# author: "[Jun Tian](https://github.com/findmyway)"
# ---

@@ -16,59 +16,68 @@ function RL.Experiment(
::Val{:JuliaRL},
::Val{:PrioritizedDQN},
::Val{:CartPole},
::Nothing;
save_dir = nothing,
seed = 123,
; seed=123,
n=1,
γ=0.99f0,
is_enable_double_DQN=true
)
rng = StableRNG(seed)

env = CartPoleEnv(; T = Float32, rng = rng)
env = CartPoleEnv(; T=Float32, rng=rng)
ns, na = length(state(env)), length(action_space(env))

agent = Agent(
policy = QBasedPolicy(
learner = PrioritizedDQNLearner(
approximator = NeuralNetworkApproximator(
model = Chain(
Dense(ns, 128, relu; init = glorot_uniform(rng)),
Dense(128, 128, relu; init = glorot_uniform(rng)),
Dense(128, na; init = glorot_uniform(rng)),
) |> gpu,
optimizer = ADAM(),
policy=QBasedPolicy(
learner=PrioritizedDQNLearner(
approximator=Approximator(
model=TwinNetwork(
Chain(
Dense(ns, 128, relu; init=glorot_uniform(rng)),
Dense(128, 128, relu; init=glorot_uniform(rng)),
Dense(128, na; init=glorot_uniform(rng)),
);
sync_freq=100
),
optimiser=ADAM(),
),
target_approximator = NeuralNetworkApproximator(
model = Chain(
Dense(ns, 128, relu; init = glorot_uniform(rng)),
Dense(128, 128, relu; init = glorot_uniform(rng)),
Dense(128, na; init = glorot_uniform(rng)),
) |> gpu,
optimizer = ADAM(),
),
loss_func = (ŷ, y) -> huber_loss(ŷ, y; agg = identity),
stack_size = nothing,
batch_size = 32,
update_horizon = 1,
min_replay_history = 100,
update_freq = 1,
target_update_freq = 100,
rng = rng,
n=n,
γ=γ,
β_priority=0.5f0,
is_enable_double_DQN=is_enable_double_DQN,
loss_func=(ŷ, y) -> huber_loss(ŷ, y; agg=identity),
rng=rng,
),
explorer = EpsilonGreedyExplorer(
kind = :exp,
ϵ_stable = 0.01,
decay_steps = 500,
rng = rng,
explorer=EpsilonGreedyExplorer(
kind=:exp,
ϵ_stable=0.01,
decay_steps=500,
rng=rng,
),
),
trajectory = CircularArrayPSARTTrajectory(
capacity = 1000,
state = Vector{Float32} => (ns,),
),
trajectory=Trajectory(
container=CircularPrioritizedTraces(
CircularArraySARTTraces(
capacity=1000,
state=Float32 => (ns,),
);
default_priority=100.0f0
),
sampler=NStepBatchSampler{SS′ART}(
n=n,
γ=γ,
batch_size=32,
rng=rng
),
controller=InsertSampleRatioController(
threshold=100,
n_inserted=-1
)
)
)

stop_condition = StopAfterStep(10_000)
hook = TotalRewardPerEpisode()
Experiment(agent, env, stop_condition, hook, "# Play CartPole with PrioritizedDQN")
Experiment(agent, env, stop_condition, hook)
end
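
As a usage note, the new experiment can be launched like the existing ones; this mirrors the `runtests.jl` change below and assumes the `E` string macro exported by ReinforcementLearningExperiments:

```julia
using ReinforcementLearningExperiments

run(E`JuliaRL_PrioritizedDQN_CartPole`)  # builds the Experiment above and runs it to the stop condition
```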


@@ -10,6 +10,7 @@ const EXPERIMENTS_DIR = joinpath(@__DIR__, "experiments")
# end
include(joinpath(EXPERIMENTS_DIR, "JuliaRL_BasicDQN_CartPole.jl"))
include(joinpath(EXPERIMENTS_DIR, "JuliaRL_DQN_CartPole.jl"))
include(joinpath(EXPERIMENTS_DIR, "JuliaRL_PrioritizedDQN_CartPole.jl"))

# dynamic loading environments
function __init__() end
2 changes: 1 addition & 1 deletion src/ReinforcementLearningExperiments/test/runtests.jl
@@ -5,8 +5,8 @@ CUDA.allowscalar(false)

run(E`JuliaRL_BasicDQN_CartPole`)
run(E`JuliaRL_DQN_CartPole`)
run(E`JuliaRL_PrioritizedDQN_CartPole`)
# run(E`JuliaRL_BC_CartPole`)
# run(E`JuliaRL_PrioritizedDQN_CartPole`)
# run(E`JuliaRL_Rainbow_CartPole`)
# run(E`JuliaRL_QRDQN_CartPole`)
# run(E`JuliaRL_REMDQN_CartPole`)
2 changes: 1 addition & 1 deletion src/ReinforcementLearningZoo/Project.toml
@@ -5,6 +5,7 @@ version = "0.6.0-dev"
[deps]
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReinforcementLearningBase = "e575027e-6cd6-5018-9292-cdc6200d2b44"
ReinforcementLearningCore = "de1b191a-4ae0-4afa-a27b-92d07f46b2d6"
@@ -13,7 +14,6 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
ReinforcementLearningBase = "0.10"
ReinforcementLearningCore = "0.9"
julia = "1.6"

[extras]
2 changes: 1 addition & 1 deletion src/ReinforcementLearningZoo/src/algorithms/dqns/dqns.jl
@@ -1,6 +1,6 @@
include("basic_dqn.jl")
include("dqn.jl")
# include("prioritized_dqn.jl")
include("prioritized_dqn.jl")
# include("qr_dqn.jl")
# include("rem_dqn.jl")
# include("rainbow.jl")
180 changes: 54 additions & 126 deletions src/ReinforcementLearningZoo/src/algorithms/dqns/prioritized_dqn.jl
@@ -1,150 +1,78 @@
export PrioritizedDQNLearner

"""
PrioritizedDQNLearner(;kwargs...)

See paper: [Prioritized Experience Replay](https://arxiv.org/abs/1511.05952)
And also https://danieltakeshi.github.io/2019/07/14/per/

# Keywords

- `approximator`::[`AbstractApproximator`](@ref): used to get Q-values of a state.
- `target_approximator`::[`AbstractApproximator`](@ref): similar to `approximator`, but used to estimate the target (the next state).
- `loss_func`: the loss function.
- `γ::Float32=0.99f0`: discount rate.
- `batch_size::Int=32`
- `update_horizon::Int=1`: length of update ('n' in n-step update).
- `min_replay_history::Int=32`: number of transitions that should be experienced before updating the `approximator`.
- `update_freq::Int=4`: the frequency of updating the `approximator`.
- `target_update_freq::Int=100`: the frequency of syncing `target_approximator`.
- `stack_size::Union{Int, Nothing}=4`: use the recent `stack_size` frames to form a stacked state.
- `default_priority::Float64=100.`: the default priority for newly added transitions.
- `rng = Random.GLOBAL_RNG`

!!! note
Our implementation is slightly different from the original paper. But it should be aligned with the version in [dopamine](https://github.com/google/dopamine/blob/90527f4eaad4c574b92df556c02dea45853ffd2e/dopamine/jax/agents/rainbow/rainbow_agent.py#L26-L30).
"""
mutable struct PrioritizedDQNLearner{Tq,Tt,Tf,R<:AbstractRNG} <: Any
approximator::Tq
target_approximator::Tt
loss_func::Tf
sampler::NStepBatchSampler
min_replay_history::Int
update_freq::Int
target_update_freq::Int
update_step::Int
default_priority::Float32
β_priority::Float32
rng::R
using Setfield: @set
using Random: AbstractRNG, GLOBAL_RNG
import Functors
using LinearAlgebra: dot

Base.@kwdef mutable struct PrioritizedDQNLearner{A<:Approximator{<:TwinNetwork}} <: AbstractLearner
approximator::A
loss_func::Any # !!! here the loss func must return the loss before reducing over the batch dimension
n::Int = 1
γ::Float32 = 0.99f0
β_priority::Float32 = 0.5f0
is_enable_double_DQN::Bool = true
rng::AbstractRNG = GLOBAL_RNG
# for logging
loss::Float32
end

function PrioritizedDQNLearner(;
approximator::Tq,
target_approximator::Tt,
loss_func::Tf,
stack_size::Union{Int,Nothing} = nothing,
γ::Float32 = 0.99f0,
batch_size::Int = 32,
update_horizon::Int = 1,
min_replay_history::Int = 32,
update_freq::Int = 1,
target_update_freq::Int = 100,
update_step::Int = 0,
default_priority::Float32 = 100.0f0,
β_priority::Float32 = 0.5f0,
traces = SARTS,
rng = Random.GLOBAL_RNG,
) where {Tq,Tt,Tf}
copyto!(approximator, target_approximator)
sampler = NStepBatchSampler{traces}(;
γ = γ,
n = update_horizon,
stack_size = stack_size,
batch_size = batch_size,
)
PrioritizedDQNLearner(
approximator,
target_approximator,
loss_func,
sampler,
min_replay_history,
update_freq,
target_update_freq,
update_step,
default_priority,
β_priority,
rng,
0.0f0,
)
end


Functors.functor(x::PrioritizedDQNLearner) =
(Q = x.approximator, Qₜ = x.target_approximator),
y -> begin
x = @set x.approximator = y.Q
x = @set x.target_approximator = y.Qₜ
x
end

"""

!!! note
The state of the observation is assumed to have been stacked,
if `!isnothing(stack_size)`.
"""
function (learner::PrioritizedDQNLearner)(env)
env |>
state |>
x ->
Flux.unsqueeze(x, ndims(x) + 1) |>
x ->
send_to_device(device(learner), x) |>
learner.approximator |>
vec |>
send_to_host
loss::Float32 = 0.0f0
end

function RLBase.update!(learner::PrioritizedDQNLearner, batch::NamedTuple)
Q = learner.approximator
Qₜ = learner.target_approximator
γ = learner.sampler.γ
(L::PrioritizedDQNLearner)(s::AbstractArray) = L.approximator(s)

Functors.functor(x::PrioritizedDQNLearner) = (; approximator=x.approximator), y -> @set x.approximator = y.approximator

function RLBase.optimise!(
learner::PrioritizedDQNLearner,
batch::Union{
NamedTuple{(:key, :priority, SS′ART...)},
NamedTuple{(:key, :priority, SS′L′ART...)}
}
)
A = learner.approximator
Q = A.model.source
Qₜ = A.model.target
γ = learner.γ
β = learner.β_priority
loss_func = learner.loss_func
n = learner.sampler.n
batch_size = learner.sampler.batch_size
n = learner.n

D = device(Q)
s, a, r, t, s′ = (send_to_device(D, batch[x]) for x in SARTS)
s, s′, a, r, t = map(x -> batch[x], SS′ART)
batch_size = length(a)
a = CartesianIndex.(a, 1:batch_size)
k, p = batch.key, batch.priority
p′ = similar(p)

updated_priorities = Vector{Float32}(undef, batch_size)
w = 1.0f0 ./ ((batch.priority .+ 1.0f-10) .^ β)
w = 1.0f0 ./ ((p .+ 1.0f-10) .^ β)
w ./= maximum(w)
w = send_to_device(D, w)

target_q = Qₜ(s′)
q′ = learner.is_enable_double_DQN ? Q(s′) : Qₜ(s′)

if haskey(batch, :next_legal_actions_mask)
l′ = send_to_device(D, batch[:next_legal_actions_mask])
target_q .+= ifelse.(l′, 0.0f0, typemin(Float32))
q′ .+= ifelse.(batch[:next_legal_actions_mask], 0.0f0, typemin(Float32))
end

q′ = dropdims(maximum(target_q; dims = 1), dims = 1)
G = r .+ γ^n .* (1 .- t) .* q′
q′ₐ = learner.is_enable_double_DQN ? Qₜ(s′)[dropdims(argmax(q′, dims=1), dims=1)] : dropdims(maximum(q′; dims=1), dims=1)

G = r .+ γ^n .* (1 .- t) .* q′ₐ

gs = gradient(params(Q)) do
q = Q(s)[a]
batch_losses = loss_func(G, q)
gs = gradient(params(A)) do
qₐ = Q(s)[a]
batch_losses = loss_func(G, qₐ)
loss = dot(vec(w), vec(batch_losses)) * 1 // batch_size
ignore() do
updated_priorities .= send_to_host(vec((batch_losses .+ 1.0f-10) .^ β))
p′ .= vec((batch_losses .+ 1.0f-10) .^ β)
learner.loss = loss
end
loss
end

update!(Q, gs)
updated_priorities
optimise!(A, gs)
k => p′
end

function RLBase.optimise!(policy::QBasedPolicy{<:PrioritizedDQNLearner}, trajectory::Trajectory)
for batch in trajectory
k, p = optimise!(policy, batch) |> send_to_host
trajectory[:priority, k] = p
end
end
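
For reference, the priority bookkeeping in `optimise!` above reduces to a small amount of arithmetic; here is a standalone sketch with illustrative numbers (β and the 1.0f-10 stabiliser match the code, everything else is made up):

```julia
β = 0.5f0                               # β_priority
p = Float32[100.0, 100.0, 42.0]         # priorities sampled from the trajectory
batch_losses = Float32[0.3, 1.2, 0.05]  # per-sample Huber losses (before reduction)

# importance-sampling weights, normalised by their maximum, as in the learner
w = 1.0f0 ./ ((p .+ 1.0f-10) .^ β)
w ./= maximum(w)

# weighted scalar loss that the gradient is taken through
loss = sum(w .* batch_losses) / length(batch_losses)

# updated priorities that the `key => priority` pair hands back to the trajectory
p′ = (batch_losses .+ 1.0f-10) .^ β
```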