Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unify the definition of AbstractTraces #14

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
coverage/
*.jl.*.cov
*.jl.cov
*.jl.mem
/Manifest.toml

.DS_Store
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ CircularArrayBuffers = "9de3a189-e0c0-4e15-ba3b-b14b9fb0aec1"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StackViews = "cae243ae-269e-4f55-b966-ac2d0dc13c15"
Term = "22787eb5-b846-44ae-b979-8e399b8463ab"

[compat]
Expand Down
74 changes: 65 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,78 @@

## Design

A typical example of `Trajectory`:
The relationship of several concepts provided in this package:

![](https://user-images.githubusercontent.com/5612003/167291629-0e2d4f0f-7c54-460c-a94f-9eb4148cdca0.png)
```
┌───────────────────────────────────┐
│ Trajectory │
│ ┌───────────────────────────────┐ │
│ │ AbstractTraces │ │
│ │ ┌───────────────┐ │ │
│ │ :trace_A => │ AbstractTrace │ │ │
│ │ └───────────────┘ │ │
│ │ │ │
│ │ ┌───────────────┐ │ │
│ │ :trace_B => │ AbstractTrace │ │ │
│ │ └───────────────┘ │ │
│ │ ... ... │ │
│ └───────────────────────────────┘ │
│ ┌───────────┐ │
│ │ Sampler │ │
│ └───────────┘ │
│ ┌────────────┐ │
│ │ Controller │ │
│ └────────────┘ │
└───────────────────────────────────┘
```

## `Trajectory`

A `Trajectory` contains 3 parts:

Exported APIs are:
- A `container` to store data. (Usually an `AbstractTraces`)
- A `sampler` to determine how to sample a batch from `container`
- A `controller` to decide when to sample a new batch from the `container`

Typical usage:

```julia
push!(trajectory; [trace_name=value]...)
append!(trajectory; [trace_name=value]...)
julia> t = Trajectory(Traces(a=Int[], b=Bool[]), BatchSampler(3), InsertSampleRatioControler(1.0, 3));

julia> for i in 1:5
push!(t, (a=i, b=iseven(i)))
end

for sample in trajectory
# consume samples from the trajectory
end
julia> for batch in t
println(batch)
end
(a = [4, 5, 1], b = Bool[1, 0, 0])
(a = [3, 2, 4], b = Bool[0, 1, 1])
(a = [4, 1, 2], b = Bool[1, 0, 1])
```

A wide variety of `container`s, `sampler`s, and `controler`s are provided. For the full list, please read the doc.
**Traces**

- `Traces`
- `MultiplexTraces`
- `CircularSARTTraces`
- `Episode`
- `Episodes`

**Samplers**

- `BatchSampler`
findmyway marked this conversation as resolved.
Show resolved Hide resolved
- `MetaSampler`
- `MultiBatchSampler`

**Controllers**

- `InsertSampleRatioController`
- `InsertSampleController`
- `AsyncInsertSampleRatioController`


Please refer tests for common usage. (TODO: generate docs and add links to above data structures)

## Acknowledgement

Expand Down
6 changes: 3 additions & 3 deletions src/Trajectories.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
module Trajectories

include("patch.jl")

include("traces.jl")
include("samplers.jl")
include("controllers.jl")
include("traces.jl")
include("episodes.jl")
include("trajectory.jl")
include("rendering.jl")
include("common/common.jl")

end
33 changes: 5 additions & 28 deletions src/common/CircularArraySARTTraces.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
export CircularArraySARTTraces

const CircularArraySARTTraces = Traces{
SART,
SSAART,
<:Tuple{
<:MultiplexTraces{SS,<:Trace{<:CircularArrayBuffer}},
<:MultiplexTraces{AA,<:Trace{<:CircularArrayBuffer}},
<:Trace{<:CircularArrayBuffer},
<:Trace{<:CircularArrayBuffer},
<:Trace{<:CircularArrayBuffer},
<:Trace{<:CircularArrayBuffer}
}
}


function CircularArraySARTTraces(;
capacity::Int,
state=Int => (),
Expand All @@ -23,32 +22,10 @@ function CircularArraySARTTraces(;
reward_eltype, reward_size = reward
terminal_eltype, terminal_size = terminal

MultiplexTraces{SS}(CircularArrayBuffer{state_eltype}(state_size..., capacity + 1)) +
MultiplexTraces{AA}(CircularArrayBuffer{action_eltype}(action_size..., capacity + 1)) +
Traces(
state=CircularArrayBuffer{state_eltype}(state_size..., capacity + 1), # !!! state is one step longer
action=CircularArrayBuffer{action_eltype}(action_size..., capacity + 1), # !!! action is one step longer
reward=CircularArrayBuffer{reward_eltype}(reward_size..., capacity),
terminal=CircularArrayBuffer{terminal_eltype}(terminal_size..., capacity),
)
end

function Random.rand(s::BatchSampler, t::CircularArraySARTTraces)
inds = rand(s.rng, 1:length(t), s.batch_size)
inds′ = inds .+ 1
(
state=t[:state][inds],
action=t[:action][inds],
reward=t[:reward][inds],
terminal=t[:terminal][inds],
next_state=t[:state][inds′],
next_action=t[:state][inds′]
) |> s.transformer
end

function Base.push!(t::CircularArraySARTTraces, x::NamedTuple{SA})
if length(t[:state]) == length(t[:terminal]) + 1
pop!(t[:state])
pop!(t[:action])
end
push!(t[:state], x[:state])
push!(t[:action], x[:action])
end
43 changes: 8 additions & 35 deletions src/common/CircularArraySLARTTraces.jl
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
export CircularArraySLARTTraces

const CircularArraySLARTTraces = Traces{
SLART,
SSLLAART,
<:Tuple{
<:MultiplexTraces{SS,<:Trace{<:CircularArrayBuffer}},
<:MultiplexTraces{LL,<:Trace{<:CircularArrayBuffer}},
<:MultiplexTraces{AA,<:Trace{<:CircularArrayBuffer}},
<:Trace{<:CircularArrayBuffer},
<:Trace{<:CircularArrayBuffer},
<:Trace{<:CircularArrayBuffer},
<:Trace{<:CircularArrayBuffer},
<:Trace{<:CircularArrayBuffer}
}
}


function CircularArraySLARTTraces(;
capacity::Int,
state=Int => (),
Expand All @@ -26,37 +25,11 @@ function CircularArraySLARTTraces(;
reward_eltype, reward_size = reward
terminal_eltype, terminal_size = terminal

MultiplexTraces{SS}(CircularArrayBuffer{state_eltype}(state_size..., capacity + 1)) +
MultiplexTraces{LL}(CircularArrayBuffer{legal_actions_mask_eltype}(legal_actions_mask_size..., capacity + 1)) +
MultiplexTraces{AA}(CircularArrayBuffer{action_eltype}(action_size..., capacity + 1)) +
Traces(
state=CircularArrayBuffer{state_eltype}(state_size..., capacity + 1), # !!! state is one step longer
legal_actions_mask=CircularArrayBuffer{legal_actions_mask_eltype}(legal_actions_mask_size..., capacity + 1), # !!! legal_actions_mask is one step longer
action=CircularArrayBuffer{action_eltype}(action_size..., capacity + 1), # !!! action is one step longer
reward=CircularArrayBuffer{reward_eltype}(reward_size..., capacity),
terminal=CircularArrayBuffer{terminal_eltype}(terminal_size..., capacity),
)
end

function sample(s::BatchSampler, t::CircularArraySLARTTraces)
inds = rand(s.rng, 1:length(t), s.batch_size)
inds′ = inds .+ 1
(
state=t[:state][inds],
legal_actions_mask=t[:legal_actions_mask][inds],
action=t[:action][inds],
reward=t[:reward][inds],
terminal=t[:terminal][inds],
next_state=t[:state][inds′],
next_legal_actions_mask=t[:legal_actions_mask][inds′],
next_action=t[:state][inds′]
) |> s.transformer
end

function Base.push!(t::CircularArraySLARTTraces, x::NamedTuple{SLA})
if length(t[:state]) == length(t[:terminal]) + 1
pop!(t[:state])
pop!(t[:legal_actions_mask])
pop!(t[:action])
end
push!(t[:state], x[:state])
push!(t[:legal_actions_mask], x[:legal_actions_mask])
push!(t[:action], x[:action])
end
end
11 changes: 5 additions & 6 deletions src/common/common.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
using CircularArrayBuffers

const SA = (:state, :action)
const SLA = (:state, :legal_actions_mask, :action)
const SS = (:state, :next_state)
const LL = (:legal_actions_mask, :next_legal_actions_mask)
const AA = (:action, :next_action)
const RT = (:reward, :terminal)
const SART = (:state, :action, :reward, :terminal)
const SARTSA = (:state, :action, :reward, :terminal, :next_state, :next_action)
const SLART = (:state, :legal_actions_mask, :action, :reward, :terminal)
const SLARTSLA = (:state, :legal_actions_mask, :action, :reward, :terminal, :next_state, :next_legal_actions_mask, :next_action)
const SSAART = (SS..., AA..., RT...)
const SSLLAART = (SS..., LL..., AA..., RT...)

include("sum_tree.jl")
include("CircularArraySARTTraces.jl")
Expand Down
107 changes: 0 additions & 107 deletions src/episodes.jl

This file was deleted.

3 changes: 3 additions & 0 deletions src/patch.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
import MLUtils

MLUtils.batch(x::AbstractArray{<:Number}) = x
Loading