WIP: PETS algorithm from facebook/mbrl #531

Closed · wants to merge 11 commits
158 changes: 158 additions & 0 deletions docs/experiments/experiments/Model Based/JuliaRL_PETS_CartPole.jl
@@ -0,0 +1,158 @@
using ReinforcementLearning
using StableRNGs
using Flux
using IntervalSets

# A wrapper that redefines the reward and also allows querying reward and
# termination for an arbitrary state/action/number of future steps.
# PETS needs both `is_terminated` and `reward` to accept keywords in this fashion.
struct CartPoleWrapper{E<:AbstractEnv} <: AbstractEnvWrapper
env::E
end

function RLBase.is_terminated(env::CartPoleWrapper; current_state=state(env[!]), future_steps=0)
x, xdot, theta, thetadot = current_state
done = abs(x) > env[!].params.xthreshold ||
abs(theta) > env[!].params.thetathreshold ||
env[!].t + future_steps > env[!].params.max_steps
return done
end

function RLBase.reward(env::CartPoleWrapper; last_action=env[!].last_action, current_state=state(env[!]))
arm_length = 2 * env[!].params.halflength
x, xdot, theta, thetadot = current_state
target_dist = (x - arm_length * sin(theta))^2 + (arm_length + arm_length * cos(theta))^2
cost_pos = exp(-target_dist)
cost_act = 0.01 * sum(abs2, last_action)
return -(cost_pos + cost_act)
end
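
# Not part of the diff: a hedged sketch of why the keyword variants above matter.
# During planning, PETS scores imagined rollouts without stepping the real
# environment, so reward and termination must be computable for an arbitrary
# state and action. `wrapped`, `imagined_state` and `planned_action` are
# hypothetical names used only for this illustration.
wrapped = CartPoleWrapper(CartPoleEnv(T = Float32, continuous = true))
imagined_state = Float32[0.1, 0.0, 0.05, 0.0]  # x, xdot, theta, thetadot
planned_action = [0.3]
r = reward(wrapped; last_action = planned_action, current_state = imagined_state)
done = is_terminated(wrapped; current_state = imagined_state, future_steps = 10)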

function RL.Experiment(
::Val{:JuliaRL},
::Val{:PETS},
::Val{:CartPole},
::Nothing;
save_dir = nothing,
seed = 123,
)
rng = StableRNG(seed)

inner_env = CartPoleEnv(T = Float32, continuous = true, rng = rng)
A = action_space(inner_env)
low = A.left
high = A.right

env = StateTransformedEnv(
ActionTransformedEnv(
CartPoleWrapper(inner_env);
action_mapping = x -> low + (x[1] + 1) * 0.5 * (high - low),
);
state_mapping = x -> [x[1:2]; sin(x[3]); cos(x[3]); x[4]], # TODO: this does not seem very efficient?
)
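# Note (not in the PR): action_mapping rescales the policy output from [-1, 1]
# to the environment's force range [low, high]; state_mapping replaces the raw
# pole angle with (sin, cos) so the dynamics model sees a continuous encoding.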

ns = length(state(env))
na = 1

init = glorot_uniform(rng)

T = 5000
hidden = 200

agent = Agent(
policy = PETSPolicy(
optimizer = CEMTrajectoryOptimizer(
lower_bound = [-1.0],
upper_bound = [1.0],
population = 500,
elite_ratio = 0.1,
iterations = 5,
horizon = 15,
α = 0.1,
rng = rng,
),
ensamble = [
NeuralNetworkApproximator(
model = GaussianNetwork(
pre = Chain(
Dense(ns + na, hidden, leakyrelu, init = init),
Dense(hidden, hidden, leakyrelu, init = init),
# Dense(200, 200, leakyrelu, init = init),
),
μ = Chain(Dense(hidden, ns, init = init)),
logσ = Chain(Dense(hidden, ns, init = init)),
clampfun = softclamp,
min_σ = 1f-5,
max_σ = 1f2,
),
optimizer = ADAM(7.5e-4),
) for _ in 1:5
],
batch_size = 256, # Can this be larger than update_after?
start_steps = 200,
start_policy = RandomPolicy(Space([-1.0..1.0 for _ in 1:na]); rng = rng),
update_after = 200,
update_freq = 50,
predict_reward = false,
rng = rng,
),
trajectory = CircularArraySARTTrajectory(
capacity = T,
state = Vector{Float32} => (ns,),
action = Vector{Float32} => (na,),
),
)

stop_condition = StopAfterStep(T, is_show_progress=!haskey(ENV, "CI"))
hook = ComposedHook(
TotalRewardPerEpisode(),
)
Experiment(agent, env, stop_condition, hook, "# Play CartPole with PETS")
end
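
# Not part of the diff: a generic cross-entropy-method sketch of what the
# CEMTrajectoryOptimizer hyper-parameters above (population, elite_ratio,
# iterations, horizon, α) control. This is only an illustration, not the actual
# optimizer; `rollout_return` stands in for scoring an action sequence with the
# learned ensemble.
using Statistics, Random

function cem_plan(rollout_return; horizon = 15, population = 500, elite_ratio = 0.1,
                  iterations = 5, α = 0.1, lb = -1.0, ub = 1.0, rng = Random.GLOBAL_RNG)
    μ = zeros(horizon)                # mean of the action-sequence distribution
    σ = fill((ub - lb) / 2, horizon)  # standard deviation, start broad
    n_elite = ceil(Int, elite_ratio * population)
    for _ in 1:iterations
        # Sample `population` candidate action sequences, one per column
        candidates = clamp.(μ .+ σ .* randn(rng, horizon, population), lb, ub)
        returns = [rollout_return(candidates[:, i]) for i in 1:population]
        elites = candidates[:, partialsortperm(returns, 1:n_elite, rev = true)]
        # α smooths the distribution update between CEM iterations
        μ = (1 - α) .* vec(mean(elites, dims = 2)) .+ α .* μ
        σ = (1 - α) .* vec(std(elites, dims = 2)) .+ α .* σ
    end
    return μ[1]  # only the first planned action is executed (MPC-style)
end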

#################################################################
########### TEMPORARY DEBUGGING HELP ############################
#################################################################
using Plots, Statistics

ex = E`JuliaRL_PETS_CartPole`

mutable struct DataCollectionHook <: AbstractHook
states::Vector{Vector{Float64}}
actions::Vector{Float64}
losses::Vector{Float64}
end
DataCollectionHook() = DataCollectionHook(Vector{Vector{Float64}}(undef, 0), Vector{Float64}(undef, 0), Vector{Float64}(undef, 0))

function (hook::DataCollectionHook)(::PreActStage, agent, env, action)
push!(hook.states, state(env[!]))
push!(hook.actions, Float64(action[1]))
push!(hook.losses, Float64(mean(agent.policy.model_loss)))
end

ex.hook = ComposedHook(ex.hook.hooks..., DataCollectionHook(), StepsPerEpisode())

run(ex)

plot(ex.hook.hooks[end].steps)

actions = ex.hook.hooks[end-1].actions
plot(actions)

states = hcat(ex.hook.hooks[end-1].states...)
plot(states')

losses = ex.hook.hooks[end-1].losses;
plot(losses)
#################################################################
#################################################################
#################################################################

#+ tangle=false
using Plots
pyplot() #hide
ex = E`JuliaRL_PETS_CartPole`
run(ex)
plot(ex.hook.hooks[1].rewards)
savefig("assets/JuliaRL_PETS_CartPole.png") #hide

# ![](assets/JuliaRL_PETS_CartPole.png)
8 changes: 7 additions & 1 deletion src/ReinforcementLearningCore/src/extensions/Flux.jl
@@ -1,4 +1,4 @@
export glorot_uniform, glorot_normal, orthogonal
export glorot_uniform, glorot_normal, orthogonal, softclamp

import Flux: glorot_uniform, glorot_normal

@@ -29,3 +29,9 @@ function batch!(data, xs)
end
data
end

function softclamp(x, xmin, xmax)
x = xmax - softplus(xmax - x)
x = xmin + softplus(x - xmin)
return x
end
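
A quick sanity check of the new `softclamp` (an illustration, not part of the diff): unlike a hard `clamp`, it saturates smoothly, so gradients do not vanish abruptly at the bounds. `softplus` is assumed to come from Flux/NNlib, and the one-liner below simply mirrors the definition above.

using Flux: softplus
softclamp(x, xmin, xmax) = xmin + softplus(xmax - softplus(xmax - x) - xmin)
softclamp(0.0, -2.0, 2.0)  # ≈ 0.02, close to the input, slightly smoothed
softclamp(5.0, -2.0, 2.0)  # ≈ 1.97, smoothly saturated just below the upper bound
clamp(5.0, -2.0, 2.0)      # = 2.0, hard cut-off with zero gradient past the bound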
@@ -69,18 +69,19 @@
#####

"""
GaussianNetwork(;pre=identity, μ, logσ, min_σ=0f0, max_σ=Inf32)
GaussianNetwork(;pre=identity, μ, logσ, min_σ=0f0, max_σ=Inf32, clampfun=clamp)

Returns `μ` and `logσ` when called. Create a distribution to sample from using
`Normal.(μ, exp.(logσ))`. `min_σ` and `max_σ` are used to clip the output from
`logσ`.
`logσ` using `clampfun` (typically `clamp` or `softclamp`).
"""
Base.@kwdef struct GaussianNetwork{P,U,S}
Base.@kwdef struct GaussianNetwork{P,U,S,F}
pre::P = identity
μ::U
logσ::S
min_σ::Float32 = 0f0
max_σ::Float32 = Inf32
clampfun::F = clamp
end

Flux.@functor GaussianNetwork
@@ -94,8 +95,8 @@ This function is compatible with a multidimensional action space. When outputting
"""
function (model::GaussianNetwork)(rng::AbstractRNG, state; is_sampling::Bool=false, is_return_log_prob::Bool=false)
x = model.pre(state)
μ, raw_logσ = model.μ(x), model.logσ(x)
logσ = clamp.(raw_logσ, log(model.min_σ), log(model.max_σ))
μ, raw_logσ = model.μ(x), model.logσ(x)
logσ = model.clampfun.(raw_logσ, log(model.min_σ), log(model.max_σ))
if is_sampling
σ = exp.(logσ)
z = μ .+ σ .* send_to_device(device(model), randn(rng, Float32, size(μ)))
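A hedged usage sketch for the extended `GaussianNetwork` (not from the diff): the keyword constructor gains a `clampfun` field, so `softclamp` can be swapped in for the default hard `clamp` when bounding `logσ`. The layer sizes below are arbitrary.

using Flux, Random

net = GaussianNetwork(
    pre = Dense(4, 16, relu),
    μ = Dense(16, 2),
    logσ = Dense(16, 2),
    min_σ = 1f-5,
    max_σ = 1f2,
    clampfun = softclamp,  # smooth alternative to the default `clamp`
)

s = rand(Float32, 4)                               # dummy 4-dimensional state
μ, logσ = net(Random.GLOBAL_RNG, s)                # mean and clamped log-std
a = net(Random.GLOBAL_RNG, s; is_sampling = true)  # reparameterized sample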
@@ -19,10 +19,12 @@ Base.show(io::IO, params::CartPoleEnvParams) = print(
join(["$p=$(getfield(params, p))" for p in fieldnames(CartPoleEnvParams)], ","),
)

mutable struct CartPoleEnv{T,R<:AbstractRNG} <: AbstractEnv
mutable struct CartPoleEnv{A,T,R<:AbstractRNG} <: AbstractEnv
params::CartPoleEnvParams{T}
action_space::A
state_space::Space{Vector{ClosedInterval{T}}}
state::Array{T,1}
action::Int
last_action::T
done::Bool
t::Int
rng::R
@@ -50,6 +52,7 @@ function CartPoleEnv(;
halflength = 0.5,
forcemag = 10.0,
max_steps = 200,
continuous::Bool = false,
dt = 0.02,
rng = Random.GLOBAL_RNG,
)
@@ -66,41 +69,53 @@
2.4,
max_steps,
)
cp = CartPoleEnv(params, zeros(T, 4), 2, false, 0, rng)
reset!(cp)
cp
high = T.([2 * params.xthreshold, 1e38, 2 * params.thetathreshold, 1e38])
action_space = continuous ? -forcemag..forcemag : Base.OneTo(2)
env = CartPoleEnv(
params,
action_space,
Space(ClosedInterval{T}.(-high, high)),
zeros(T, 4),
zero(T),
false,
0,
rng,
)
reset!(env)
env
end

CartPoleEnv{T}(; kwargs...) where {T} = CartPoleEnv(; T = T, kwargs...)

function RLBase.reset!(env::CartPoleEnv{T}) where {T<:Number}
env.state[:] = T(0.1) * rand(env.rng, T, 4) .- T(0.05)
function RLBase.reset!(env::CartPoleEnv{A,T}) where {A,T<:Number}
env.state .= T(0.1) .* rand(env.rng, T, 4) .- T(0.05)
env.t = 0
env.action = 2
env.done = false
nothing
end

RLBase.action_space(env::CartPoleEnv) = Base.OneTo(2)
RLBase.action_space(env::CartPoleEnv) = env.action_space

RLBase.state_space(env::CartPoleEnv{T}) where {T} = Space(
ClosedInterval{T}[
(-2*env.params.xthreshold)..(2*env.params.xthreshold),
-1e38..1e38,
(-2*env.params.thetathreshold)..(2*env.params.thetathreshold),
-1e38..1e38,
],
)
RLBase.state_space(env::CartPoleEnv) = env.state_space

RLBase.reward(env::CartPoleEnv{T}) where {T} = env.done ? zero(T) : one(T)
RLBase.reward(env::CartPoleEnv{A,T}) where {A,T<:Number} = env.done ? zero(T) : one(T)
RLBase.is_terminated(env::CartPoleEnv) = env.done
RLBase.state(env::CartPoleEnv) = env.state

function (env::CartPoleEnv)(a)
@assert a in (1, 2)
env.action = a
env.t += 1
function (env::CartPoleEnv{<:ClosedInterval})(a)
@assert a in env.action_space
_step!(env, a)
end

function (env::CartPoleEnv{<:Base.OneTo})(a)
@assert a in env.action_space
force = a == 2 ? env.params.forcemag : -env.params.forcemag
_step!(env, force)
end

function _step!(env::CartPoleEnv, force)
env.last_action = force
env.t += 1
x, xdot, theta, thetadot = env.state
costheta = cos(theta)
sintheta = sin(theta)
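To illustrate the new dual action-space support (a sketch, not part of the diff): the discrete constructor keeps the old `Base.OneTo(2)` interface, while `continuous = true` exposes a force interval of roughly ±forcemag, and both action types dispatch to the shared `_step!`.

using ReinforcementLearning

env_d = CartPoleEnv(T = Float32)                     # discrete: action_space(env_d) == Base.OneTo(2)
env_d(2)                                             # action 2 applies +forcemag

env_c = CartPoleEnv(T = Float32, continuous = true)  # continuous: action_space(env_c) ≈ -10.0..10.0
env_c(3.5f0)                                         # apply a force of 3.5 directly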
4 changes: 2 additions & 2 deletions src/ReinforcementLearningZoo/src/ReinforcementLearningZoo.jl
@@ -21,10 +21,10 @@ using Zygote: ignore, @ignore
using Flux
using Flux: onehot, normalise
using StatsBase
using StatsBase: sample, Weights, mean
using StatsBase: sample, Weights, mean, std
using LinearAlgebra: dot
using MacroTools
using Distributions: Categorical, Normal, logpdf
using Distributions: Categorical, Normal, TruncatedNormal, logpdf
using StructArrays


1 change: 1 addition & 0 deletions src/ReinforcementLearningZoo/src/algorithms/algorithms.jl
@@ -6,3 +6,4 @@ include("cfr/cfr.jl")
include("offline_rl/offline_rl.jl")
include("nfsp/abstract_nfsp.jl")
include("exploitability_descent/exploitability_descent.jl")
include("model_based/model_based.jl")