This repository was archived by the owner on May 6, 2021. It is now read-only.
PG policy #87
Merged
Changes from all commits
27 commits (author: norci):

c48f0a0  Implemented Reinforce policy gradient.
79aa893  refactor
590ac95  bug fix
3947c39  many changes
3e4e1e9  bug fix in pg.jl
2267097  updated logging.
e4e9a9b  in PGPolicy, added step size.
e82c6fc  minor change
c396bd7  rename PG to VPG
a324266  minor change
d9b0285  added baseline in vpg.
3daf6ee  fixed a bug in vpg's baseline update
e497134  bug fix in baseline update
1b68bcf  bug fix. send data to dev.
6258ee0  refactor. for baseline
597a239  .
fed721f  added mini batch in vpg
1f824d0  updated vpg for continuous action space
d1d31cc  updated test
1cc3097  added debug log in vpg Pendulum experiment
d69e0b0  added discrete action Pendulum experiment
9f7d9a4  refactor vpg. added docs.
bb161d6  added action_space into vpg
ffc1760  removed @views. fixed bugs for gpu.
b121b90  updated experiment Pendulum
3ad1c5f  update doc
cd1d5bd  updated comments
The algorithm include list gains vpg.jl:

@@ -3,3 +3,4 @@ include("ppo.jl")
 include("A2CGAE.jl")
 include("ddpg.jl")
 include("sac.jl")
+include("vpg.jl")
vpg.jl (new file, 140 lines):

@@ -0,0 +1,140 @@
using Flux: normalise
using Random: shuffle

using ReinforcementLearningBase
using ReinforcementLearningCore

export VPGPolicy, GaussianNetwork

struct GaussianNetwork
    pre::Chain
    μ::Chain
    σ::Chain
end
Flux.@functor GaussianNetwork
function (m::GaussianNetwork)(S)
    x = m.pre(S)
    m.μ(x), m.σ(x) .|> exp
end

"""
Vanilla Policy Gradient

    VPGPolicy(;kwargs)

# Keyword arguments
- `approximator`,
- `baseline`,
- `dist`, the distribution function of the action
- `γ`, discount factor
- `α_θ`, step size of the policy parameter
- `α_w`, step size of the baseline parameter
- `batch_size`,
- `rng`,
- `loss`,
- `baseline_loss`,

If the action space is continuous, the env should transform the action value
(e.g. with `tanh`) to make sure that `low ≤ value ≤ high`.
"""
Base.@kwdef mutable struct VPGPolicy{
    A<:NeuralNetworkApproximator,
    B<:Union{NeuralNetworkApproximator,Nothing},
    S<:AbstractSpace,
    R<:AbstractRNG,
} <: AbstractPolicy
    approximator::A
    baseline::B = nothing
    action_space::S
    dist::Any
    γ::Float32 = 0.99f0 # discount factor
    α_θ = 1.0f0 # step size of policy
    α_w = 1.0f0 # step size of baseline
    batch_size::Int = 1024
    rng::R = Random.GLOBAL_RNG
    loss::Float32 = 0.0f0
    baseline_loss::Float32 = 0.0f0
end

"""
On continuous action spaces, see
* [Diagonal Gaussian Policies](https://spinningup.openai.com/en/latest/spinningup/rl_intro.html#stochastic-policies)
* [Clipped Action Policy Gradient](https://arxiv.org/pdf/1802.07564.pdf)
"""

function (π::VPGPolicy)(env::AbstractEnv)
    to_dev(x) = send_to_device(device(π.approximator), x)

    logits = env |> get_state |> to_dev |> π.approximator

    if π.action_space isa DiscreteSpace
        dist = logits |> softmax |> π.dist
        action = π.action_space[rand(π.rng, dist)]
    elseif π.action_space isa ContinuousSpace
        dist = π.dist.(logits...)
        action = rand.(π.rng, dist)[1]
    else
        error("not implemented")
    end
    action
end

function (π::VPGPolicy)(env::MultiThreadEnv)
    error("not implemented")
    # TODO: can PG support multiple envs? PG is only updated at the end of an episode.
end

function RLBase.update!(π::VPGPolicy, traj::ElasticCompactSARTSATrajectory)
    (length(traj[:terminal]) > 0 && traj[:terminal][end]) || return

    model = π.approximator
    to_dev(x) = send_to_device(device(model), x)

    states = traj[:state]
    actions = traj[:action] |> Array # convert the ElasticArray to an Array, or the code fails on GPU at `log_prob[CartesianIndex.(A, 1:length(A))]`
    gains = traj[:reward] |> x -> discount_rewards(x, π.γ)

    for idx in Iterators.partition(shuffle(1:length(traj[:terminal])), π.batch_size)
        S = select_last_dim(states, idx) |> to_dev
        A = actions[idx]
        G = gains[idx] |> x -> Flux.unsqueeze(x, 1) |> to_dev
        # gains is a 1-column array, but the output of the Flux model is a 1-row, n_batch-column array, so unsqueeze it.

        if π.baseline isa NeuralNetworkApproximator
            gs = gradient(Flux.params(π.baseline)) do
                δ = G - π.baseline(S)
                loss = mean(δ .^ 2) * π.α_w # mse
                ignore() do
                    π.baseline_loss = loss
                end
                loss
            end
            update!(π.baseline, gs)
        elseif π.baseline isa Nothing
            # Normalization. See
            # (http://rail.eecs.berkeley.edu/deeprlcourse-fa17/f17docs/hw2_final.pdf)
            # (https://web.stanford.edu/class/cs234/assignment3/solution.pdf)
            # normalise should not be used together with a baseline, or the policy loss becomes too small.
            δ = G |> x -> normalise(x; dims = 2)
        end

        gs = gradient(Flux.params(model)) do
            if π.action_space isa DiscreteSpace
                log_prob = S |> model |> logsoftmax
                log_probₐ = log_prob[CartesianIndex.(A, 1:length(A))]
            elseif π.action_space isa ContinuousSpace
                dist = π.dist.(model(S)...) # TODO: this part does not work on GPU. See: https://github.com/JuliaStats/Distributions.jl/issues/1183
                log_probₐ = logpdf.(dist, A)
            end
            loss = -mean(log_probₐ .* δ) * π.α_θ
            ignore() do
                π.loss = loss
            end
            loss
        end
        update!(model, gs)
    end
    empty!(traj)
end
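For orientation, a minimal construction sketch for the discrete-action case (not part of this diff). The layer sizes, optimiser, and the 4-observation/2-action shapes are assumptions for illustration, and it presumes the keyword constructor of `NeuralNetworkApproximator` and `DiscreteSpace` from ReinforcementLearningCore/Base of this era; `Categorical` comes from Distributions.jl.

# Sketch only: constructing a VPGPolicy for a discrete-action environment.
using Flux
using Distributions: Categorical
using ReinforcementLearningBase, ReinforcementLearningCore

policy = VPGPolicy(
    approximator = NeuralNetworkApproximator(
        model = Chain(Dense(4, 128, relu), Dense(128, 2)),  # assumed 4 observations, 2 actions
        optimizer = ADAM(),
    ),
    action_space = DiscreteSpace(2),
    dist = Categorical,   # the policy call does `logits |> softmax |> π.dist`
    γ = 0.99f0,
    batch_size = 512,
)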
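Similarly, a sketch of how the continuous-action branch uses `GaussianNetwork`: the shared trunk produces features, the two heads produce the mean and the (exponentiated) standard deviation, and `π.dist.(logits...)` broadcasts `Normal` over them. The layer sizes below are made up.

# Sketch only: GaussianNetwork forward pass and sampling, mirroring the ContinuousSpace branch.
using Flux
using Distributions: Normal

gn = GaussianNetwork(
    Chain(Dense(3, 64, relu)),  # shared trunk (assumed 3-dimensional observation)
    Chain(Dense(64, 1)),        # mean head
    Chain(Dense(64, 1)),        # head whose output is exponentiated so that σ > 0
)

s = rand(Float32, 3)            # a single observation
μ, σ = gn(s)                    # both are length-1 vectors
dist = Normal.(μ, σ)            # as in `π.dist.(logits...)` with dist = Normal
a = rand.(dist)[1]              # scalar action; the env is expected to squash it (e.g. tanh)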
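Finally, `update!` turns the reward trace into returns with `discount_rewards(x, π.γ)`. A minimal sketch of the behaviour assumed from that helper (the actual implementation lives in ReinforcementLearningCore); `discounted_returns` is a hypothetical stand-in:

# Sketch only: G_t = r_t + γ r_{t+1} + γ² r_{t+2} + ..., computed right-to-left.
function discounted_returns(r::AbstractVector{<:Real}, γ)
    G = similar(r, Float32)
    running = 0.0f0
    for t in length(r):-1:1
        running = r[t] + γ * running
        G[t] = running
    end
    G
end

discounted_returns(Float32[1, 1, 1], 0.99f0)  # ≈ [2.9701, 1.99, 1.0]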