Merged
2 changes: 1 addition & 1 deletion Project.toml
@@ -40,7 +40,7 @@ Flux = "0.11"
IntervalSets = "0.5"
MacroTools = "0.5"
ReinforcementLearningBase = "0.9"
ReinforcementLearningCore = "0.7"
ReinforcementLearningCore = "0.7.2"
Requires = "1"
Setfield = "0.6, 0.7"
StableRNGs = "1.0"
2 changes: 2 additions & 0 deletions README.md
@@ -27,6 +27,7 @@ This project aims to provide some implementations of the most typical reinforcement
- SAC
- CFR/OS-MCCFR/ES-MCCFR/DeepCFR
- Minimax
- Behavior Cloning

If you are looking for tabular reinforcement learning algorithms, you may refer to [ReinforcementLearningAnIntroduction.jl](https://github.com/JuliaReinforcementLearning/ReinforcementLearningAnIntroduction.jl).

@@ -58,6 +59,7 @@ Some built-in experiments are exported to help new users to easily run benchmark
- ``E`JuliaRL_TabularCFR_OpenSpiel(kuhn_poker)` ``
- ``E`JuliaRL_DeepCFR_OpenSpiel(leduc_poker)` ``
- ``E`JuliaRL_DQN_SnakeGame` ``
- ``E`JuliaRL_BC_CartPole` ``
- ``E`Dopamine_DQN_Atari(pong)` ``
- ``E`Dopamine_Rainbow_Atari(pong)` ``
- ``E`Dopamine_IQN_Atari(pong)` ``
1 change: 1 addition & 0 deletions src/algorithms/algorithms.jl
@@ -3,3 +3,4 @@ include("dqns/dqns.jl")
include("policy_gradient/policy_gradient.jl")
include("searching/searching.jl")
include("cfr/cfr.jl")
include("offline_rl/offline_rl.jl")
33 changes: 33 additions & 0 deletions src/algorithms/offline_rl/behavior_cloning.jl
@@ -0,0 +1,33 @@
export BehaviorCloningPolicy

"""
BehaviorCloningPolicy(;kw...)

# Keyword Arguments

- `approximator`: calculate the logits of possible actions directly
- `explorer=GreedyExplorer()`

"""
Base.@kwdef struct BehaviorCloningPolicy{A} <: AbstractPolicy
    approximator::A
    explorer::Any = GreedyExplorer()
end

function (p::BehaviorCloningPolicy)(env::AbstractEnv)
    s = state(env)
    s_batch = Flux.unsqueeze(s, ndims(s) + 1)
    logits = p.approximator(s_batch) |> vec # drop dimension
    p.explorer(logits)
end

function RLBase.update!(p::BehaviorCloningPolicy, batch::NamedTuple{(:state, :action)})
    s, a = batch.state, batch.action
    m = p.approximator
    gs = gradient(params(m)) do
        ŷ = m(s)
        y = Flux.onehotbatch(a, axes(ŷ, 1))
        logitcrossentropy(ŷ, y)
    end
    update!(m, gs)
end
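For orientation, here is a minimal usage sketch of the policy added above; it is not part of this PR's diff and only exercises the API defined in this file (construct a `BehaviorCloningPolicy` around a `NeuralNetworkApproximator`, then feed it `(state, action)` batches through `RLBase.update!`). The `using` lines, network size, and the randomly generated batch are illustrative assumptions.

```julia
# Illustrative sketch only — assumes `using ReinforcementLearning` re-exports the
# RLBase/RLCore/RLZoo symbols used below and that Flux is installed; the state/action
# dimensions and the batch contents are made up for the example.
using ReinforcementLearning, Flux

ns, na = 4, 2  # assumed state and action sizes (CartPole-like)

bc = BehaviorCloningPolicy(
    approximator = NeuralNetworkApproximator(
        model = Chain(Dense(ns, 64, relu), Dense(64, na)),
        optimizer = ADAM(),
    ),
)

# `update!` expects a NamedTuple: states stored column-wise, actions as integer indices.
batch = (state = rand(Float32, ns, 32), action = rand(1:na, 32))
RLBase.update!(bc, batch)

# Acting on an environment goes through the functor defined above:
# `bc(env)` unsqueezes the current state, computes logits, and lets the explorer pick.
```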
1 change: 1 addition & 0 deletions src/algorithms/offline_rl/offline_rl.jl
@@ -0,0 +1 @@
include("behavior_cloning.jl")
85 changes: 85 additions & 0 deletions src/experiments/rl_envs/JuliaRL_BC_CartPole.jl
@@ -0,0 +1,85 @@
Base.@kwdef struct RecordStateAction <: AbstractHook
    records::Any = VectorSATrajectory(; state = Vector{Float32})
end

function (h::RecordStateAction)(::PreActStage, policy, env, action)
    push!(h.records; state = copy(state(env)), action = action)
end

function RLCore.Experiment(
    ::Val{:JuliaRL},
    ::Val{:BC},
    ::Val{:CartPole},
    ::Nothing;
    seed = 123,
    save_dir = nothing,
)
    rng = StableRNG(seed)

    env = CartPoleEnv(; T = Float32, rng = rng)
    ns, na = length(state(env)), length(action_space(env))
    agent = Agent(
        policy = QBasedPolicy(
            learner = BasicDQNLearner(
                approximator = NeuralNetworkApproximator(
                    model = Chain(
                        Dense(ns, 128, relu; initW = glorot_uniform(rng)),
                        Dense(128, 128, relu; initW = glorot_uniform(rng)),
                        Dense(128, na; initW = glorot_uniform(rng)),
                    ) |> cpu,
                    optimizer = ADAM(),
                ),
                batch_size = 32,
                min_replay_history = 100,
                loss_func = huber_loss,
                rng = rng,
            ),
            explorer = EpsilonGreedyExplorer(
                kind = :exp,
                ϵ_stable = 0.01,
                decay_steps = 500,
                rng = rng,
            ),
        ),
        trajectory = CircularArraySARTTrajectory(
            capacity = 1000,
            state = Vector{Float32} => (ns,),
        ),
    )

    stop_condition = StopAfterStep(10_000)
    hook = RecordStateAction()
    run(agent, env, stop_condition, hook)

    bc = BehaviorCloningPolicy(
        approximator = NeuralNetworkApproximator(
            model = Chain(
                Dense(ns, 128, relu; initW = glorot_uniform(rng)),
                Dense(128, 128, relu; initW = glorot_uniform(rng)),
                Dense(128, na; initW = glorot_uniform(rng)),
            ) |> cpu,
            optimizer = ADAM(),
        )
    )

    s = BatchSampler{(:state, :action)}(32)

    for i in 1:300
        _, batch = s(hook.records)
        RLBase.update!(bc, batch)
    end

    description = """
    # Behavior Cloning with CartPole

    This experiment uses the transitions collected while running
    `JuliaRL_BasicDQN_CartPole` to train a behavior cloning policy.
    """

    hook = ComposedHook(
        TotalRewardPerEpisode(),
        TimePerStep(),
    )

    Experiment(bc, env, StopAfterEpisode(100), hook, description)
end
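Once this file is included, the new experiment should be runnable through the entry points the package already exposes; the snippet below is a hedged sketch, not part of the diff, mirroring the ``E`...` `` shorthand listed in the README and the explicit `Experiment` constructor form the updated test suite below relies on.

```julia
# Hedged sketch — assumes `using ReinforcementLearning` brings the E`...` macro
# and the experiment registry into scope; keyword defaults follow the definition above.
using ReinforcementLearning

run(E`JuliaRL_BC_CartPole`)  # README-style shorthand

# Equivalent explicit form (the shape used in test/runtests.jl):
res = run(Experiment(Val(:JuliaRL), Val(:BC), Val(:CartPole), nothing))
```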
2 changes: 1 addition & 1 deletion test/runtests.jl
@@ -32,7 +32,7 @@ end

@testset "training" begin
    mktempdir() do dir
        for method in (:BasicDQN, :DQN, :PrioritizedDQN, :Rainbow, :IQN, :VPG)
        for method in (:BasicDQN, :BC, :DQN, :PrioritizedDQN, :Rainbow, :IQN, :VPG)
            res = run(Experiment(
                Val(:JuliaRL),
                Val(method),