Skip to content

Commit

Permalink
Creation of Policy Guided submodule (#34)
Browse files Browse the repository at this point in the history
  • Loading branch information
romainljsimon authored Feb 13, 2025
1 parent 1ef0500 commit c1513b6
Show file tree
Hide file tree
Showing 6 changed files with 71 additions and 7 deletions.
9 changes: 2 additions & 7 deletions src/MonteCarlo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,8 @@ export Metropolis, callback_acceptance, StoreParameters
export build_schedule, StoreCallbacks, StoreTrajectories, StoreLastFrames, PrintTimeSteps
export Simulation, run!

using Enzyme: autodiff, ReverseWithPrimal, Const, Duplicated
using Zygote: withgradient

include("pgmc/gradients.jl")
include("pgmc/learning.jl")
include("pgmc/pgmc.jl")

include("PolicyGuided/PolicyGuided.jl")
using .PolicyGuided: Static, VPG, BLPG, BLAPG, NPG, ANPG, BLANPG, reward, PolicyGradientEstimator, PolicyGradientUpdate
export Static, VPG, BLPG, BLAPG, NPG, ANPG, BLANPG, reward
export PolicyGradientEstimator, PolicyGradientUpdate

Expand Down
13 changes: 13 additions & 0 deletions src/PolicyGuided/PolicyGuided.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# PolicyGuided: submodule bundling the policy-guided Monte Carlo machinery
# (gradient computation, learning steps, the estimator and update algorithms).
module PolicyGuided

# Core types from the parent MonteCarlo package.
using ..MonteCarlo: Action, Policy, Algorithm, Simulation
using Random
# Automatic-differentiation backends — presumably consumed by gradients.jl
# (TODO confirm; their usage is not visible from this file).
using Enzyme: autodiff, ReverseWithPrimal, Const, Duplicated
using Zygote: withgradient

include("gradients.jl")   # gradient computations
include("learning.jl")    # learning-step / optimiser definitions
include("estimator.jl")   # PolicyGradientEstimator algorithm
include("update.jl")      # PolicyGradientUpdate algorithm

end
File renamed without changes.
File renamed without changes.
File renamed without changes.
56 changes: 56 additions & 0 deletions src/PolicyGuided/update.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
"""
    PolicyGradientUpdate{P,O,VPR,VG} <: Algorithm

Algorithm that applies accumulated policy-gradient information to the move
parameters shared by a set of parallel chains.

Type parameters: `P` is the pool type (one pool per system), `O` the
collection of optimisers (one per move), `VPR` the array of per-move
parameter arrays, and `VG` the array of per-move gradient-data objects.
"""
struct PolicyGradientUpdate{P,O,VPR<:AbstractArray,VG<:AbstractArray} <: Algorithm
    pools::Vector{P} # Vector of independent pools (one for each system)
    optimisers::O # List of optimisers (one for each move)
    learn_ids::Vector{Int} # List of learnable moves
    parameters_list::VPR # List of current parameters values (one array for each move)
    gradients_data::VG # Gradient information (one for each move)

    function PolicyGradientUpdate(chains::Vector{S}, pools::Vector{P}, optimisers::O, gradients_data::VG) where {S,P,O,VG}
        # Safety checks
        # One pool per chain, and for each move index k every chain's k-th move
        # must carry identical parameters and identical weight.
        @assert length(chains) == length(pools)
        @assert all(k -> all(move -> move.parameters == getindex.(pools, k)[1].parameters, getindex.(pools, k)), eachindex(pools[1]))
        @assert all(k -> all(move -> move.weight == getindex.(pools, k)[1].weight, getindex.(pools, k)), eachindex(pools[1]))
        @assert length(optimisers) == length(pools[1])
        # Find learnable actions
        # A move is learnable iff its optimiser is not the Static placeholder.
        learn_ids = [k for k in eachindex(optimisers) if !isa(optimisers[k], Static)]
        #Make sure that all policies and parameters across chains refer to the same objects
        # (aliasing them lets a single in-place parameter update affect every chain;
        # note this relies on the move objects being mutable).
        policy_list = [move.policy for move in pools[1]]
        parameters_list = [move.parameters for move in pools[1]]
        for pool in pools
            for k in eachindex(policy_list)
                pool[k].policy = policy_list[k]
                pool[k].parameters = parameters_list[k]
            end
        end
        return new{P,O,typeof(parameters_list),VG}(pools, optimisers, learn_ids, parameters_list, gradients_data)
    end

end

"""
    PolicyGradientUpdate(chains, path, steps; dependencies=missing)

Scheduler-facing constructor. `path` and `steps` are part of the common
algorithm-constructor signature and are not used here. The single required
dependency is the `PolicyGradientEstimator` whose pools, optimisers and
gradient data this update step shares.

Throws `ArgumentError` if `dependencies` is missing, does not contain
exactly one element, or that element is not a `PolicyGradientEstimator`.
"""
function PolicyGradientUpdate(chains, path, steps; dependencies=missing)
    # Explicit validation instead of `@assert`: asserts may be disabled at
    # higher optimisation levels, and with the `missing` default the original
    # `length(missing) == 1` raised a confusing TypeError instead of a
    # clear error message.
    dependencies === missing && throw(ArgumentError(
        "PolicyGradientUpdate requires a PolicyGradientEstimator dependency"))
    length(dependencies) == 1 || throw(ArgumentError(
        "PolicyGradientUpdate expects exactly one dependency, got $(length(dependencies))"))
    pge = dependencies[1]
    pge isa PolicyGradientEstimator || throw(ArgumentError(
        "PolicyGradientUpdate dependency must be a PolicyGradientEstimator, got $(typeof(pge))"))
    return PolicyGradientUpdate(chains, pge.pools, pge.optimisers, pge.gradients_data)
end

"""
    make_step!(::Simulation, algorithm::PolicyGradientUpdate)

Apply one learning step for every learnable move: average the gradient data
accumulated for that move, feed it to the move's optimiser to update the
shared parameter array in place, then reset the accumulator.
"""
function make_step!(::Simulation, algorithm::PolicyGradientUpdate)
    for (slot, move_id) in enumerate(algorithm.learn_ids)
        # `slot` indexes the gradient accumulators (one per learnable move),
        # `move_id` indexes the full per-move lists.
        averaged = average(algorithm.gradients_data[slot])
        learning_step!(algorithm.parameters_list[move_id], averaged, algorithm.optimisers[move_id])
        # Start a fresh accumulator for the next estimation window.
        algorithm.gradients_data[slot] = initialise_gradient_data(algorithm.parameters_list[move_id])
    end
    return nothing
end

"""
    write_algorithm(io, algorithm::PolicyGradientUpdate, scheduler)

Write a human-readable summary of the update algorithm to `io`: the number
of scheduled calls within the run, the learnable move indices, and the
optimiser assigned to each move.
"""
function write_algorithm(io, algorithm::PolicyGradientUpdate, scheduler)
    println(io, "\tPolicyGradientUpdate")
    # Count scheduled steps in (0, scheduler[end]]. The original expression
    # `0 < x scheduler[end]` was missing the `<=` operator and did not parse;
    # `count` also avoids materialising the intermediate `filter` array.
    println(io, "\t\tCalls: $(count(x -> 0 < x <= scheduler[end], scheduler))")
    println(io, "\t\tLearnable moves: $(algorithm.learn_ids)")
    println(io, "\t\tOptimisers:")
    for (k, opt) in enumerate(algorithm.optimisers)
        # Strip type parameters (e.g. `{Float64}`) from the printed optimiser type.
        println(io, "\t\t\tMove $k: " * replace(string(opt), r"\{[^\{\}]*\}" => ""))
    end
end

nothing

0 comments on commit c1513b6

Please sign in to comment.