-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Creation of Policy Guided submodule (#34)
- Loading branch information
1 parent
1ef0500
commit c1513b6
Showing
6 changed files
with
71 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# Submodule for policy-guided Monte Carlo: estimates gradients of move
# policies and updates their parameters online.
# NOTE(review): purpose inferred from the included file names and the AD
# imports below — confirm against the parent MonteCarlo module.
module PolicyGuided

# Core abstractions shared with the parent module.
using ..MonteCarlo: Action, Policy, Algorithm, Simulation
using Random
# Reverse-mode automatic-differentiation backends used for gradient estimation.
using Enzyme: autodiff, ReverseWithPrimal, Const, Duplicated
using Zygote: withgradient

include("gradients.jl")   # presumably gradient data structures — verify
include("learning.jl")    # presumably learning_step! / optimiser rules — verify
include("estimator.jl")   # presumably PolicyGradientEstimator — verify
include("update.jl")      # presumably PolicyGradientUpdate — verify

end
File renamed without changes.
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
"""
    PolicyGradientUpdate{P,O,VPR,VG} <: Algorithm

Algorithm that applies accumulated policy-gradient information to the policy
parameters of every learnable move. All chains' pools are forced to alias the
same policy/parameter objects, so one parameter update propagates to every
chain at once.
"""
struct PolicyGradientUpdate{P,O,VPR<:AbstractArray,VG<:AbstractArray} <: Algorithm
    pools::Vector{P}         # Vector of independent pools (one for each system)
    optimisers::O            # List of optimisers (one for each move)
    learn_ids::Vector{Int}   # List of learnable moves
    parameters_list::VPR     # List of current parameters values (one array for each move)
    gradients_data::VG       # Gradient information (one for each move)

    function PolicyGradientUpdate(chains::Vector{S}, pools::Vector{P}, optimisers::O, gradients_data::VG) where {S,P,O,VG}
        # Safety checks: one pool per chain, and for each move index k the
        # parameters and weights must agree across all pools (they are assumed
        # to describe the same move in every chain).
        @assert length(chains) == length(pools)
        @assert all(k -> all(move -> move.parameters == getindex.(pools, k)[1].parameters, getindex.(pools, k)), eachindex(pools[1]))
        @assert all(k -> all(move -> move.weight == getindex.(pools, k)[1].weight, getindex.(pools, k)), eachindex(pools[1]))
        @assert length(optimisers) == length(pools[1])
        # Find learnable actions: every move whose optimiser is not Static.
        learn_ids = [k for k in eachindex(optimisers) if !isa(optimisers[k], Static)]
        # Make sure that all policies and parameters across chains refer to the
        # same objects (the first pool's objects become canonical), so updating
        # parameters_list[k] in place updates every chain simultaneously.
        # NOTE(review): requires the move type to be mutable — confirm.
        policy_list = [move.policy for move in pools[1]]
        parameters_list = [move.parameters for move in pools[1]]
        for pool in pools
            for k in eachindex(policy_list)
                pool[k].policy = policy_list[k]
                pool[k].parameters = parameters_list[k]
            end
        end
        return new{P,O,typeof(parameters_list),VG}(pools, optimisers, learn_ids, parameters_list, gradients_data)
    end

end
|
||
"""
    PolicyGradientUpdate(chains, path, steps; dependencies=missing)

Scheduler-facing constructor. `dependencies` must contain exactly one
`PolicyGradientEstimator`, whose pools, optimisers and gradient buffers are
shared with the returned algorithm. `path` and `steps` are accepted for
signature compatibility with other algorithm constructors and are unused here.

# Throws
- `ArgumentError` if `dependencies` is missing, has length ≠ 1, or its sole
  element is not a `PolicyGradientEstimator`.
"""
function PolicyGradientUpdate(chains, path, steps; dependencies=missing)
    # Explicit errors instead of `@assert`: with the default `dependencies=missing`,
    # `@assert length(dependencies) == 1` evaluates `@assert missing` and throws a
    # confusing TypeError; `@assert` may also be disabled at high optimisation levels.
    ismissing(dependencies) && throw(ArgumentError(
        "PolicyGradientUpdate requires a PolicyGradientEstimator dependency"))
    length(dependencies) == 1 || throw(ArgumentError(
        "PolicyGradientUpdate expects exactly one dependency, got $(length(dependencies))"))
    isa(dependencies[1], PolicyGradientEstimator) || throw(ArgumentError(
        "PolicyGradientUpdate dependency must be a PolicyGradientEstimator"))
    pge = dependencies[1]
    # Share the estimator's state so the update step consumes exactly the
    # gradient information the estimator accumulates.
    return PolicyGradientUpdate(chains, pge.pools, pge.optimisers, pge.gradients_data)
end
|
||
# Apply one optimisation step per learnable move: average the accumulated
# gradient signal, step the shared parameters with the move's optimiser, then
# reset the accumulator for the next accumulation window.
function make_step!(::Simulation, algorithm::PolicyGradientUpdate)
    ids = algorithm.learn_ids
    for k in eachindex(ids)
        lid = ids[k]
        averaged = average(algorithm.gradients_data[k])
        learning_step!(algorithm.parameters_list[lid], averaged, algorithm.optimisers[lid])
        # Fresh, zeroed gradient buffer for this move.
        algorithm.gradients_data[k] = initialise_gradient_data(algorithm.parameters_list[lid])
    end
    return nothing
end
|
||
# Write a human-readable summary of this algorithm to `io` for the
# simulation log: number of scheduled calls, learnable move ids, and the
# optimiser assigned to each move.
function write_algorithm(io, algorithm::PolicyGradientUpdate, scheduler)
    println(io, "\tPolicyGradientUpdate")
    calls = count(t -> 0 < t ≤ scheduler[end], scheduler)
    println(io, "\t\tCalls: $calls")
    println(io, "\t\tLearnable moves: $(algorithm.learn_ids)")
    println(io, "\t\tOptimisers:")
    for (idx, optimiser) in enumerate(algorithm.optimisers)
        # Strip type parameters ({...}) from the printed optimiser type name.
        label = replace(string(optimiser), r"\{[^\{\}]*\}" => "")
        println(io, "\t\t\tMove $idx: " * label)
    end
end
|
||
nothing |