Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unify the definition of AbstractTraces #14

Merged
Merged
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ CircularArrayBuffers = "9de3a189-e0c0-4e15-ba3b-b14b9fb0aec1"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StackViews = "cae243ae-269e-4f55-b966-ac2d0dc13c15"
Term = "22787eb5-b846-44ae-b979-8e399b8463ab"

[compat]
Expand Down
71 changes: 62 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,75 @@

## Design

A typical example of `Trajectory`:
The relationship of several concepts provided in this package:

![](https://user-images.githubusercontent.com/5612003/167291629-0e2d4f0f-7c54-460c-a94f-9eb4148cdca0.png)
```
┌───────────────────────────────────┐
│ Trajectory │
│ ┌───────────────────────────────┐ │
│ │ AbstractTraces │ │
│ │ ┌───────────────┐ │ │
│ │ :trace_A => │ AbstractTrace │ │ │
│ │ └───────────────┘ │ │
│ │ │ │
│ │ ┌───────────────┐ │ │
│ │ :trace_B => │ AbstractTrace │ │ │
│ │ └───────────────┘ │ │
│ │ ... ... │ │
│ └───────────────────────────────┘ │
│ ┌───────────┐ │
│ │ Sampler │ │
│ └───────────┘ │
│ ┌────────────┐ │
│ │ Controller │ │
│ └────────────┘ │
└───────────────────────────────────┘
```

## `Trajectory`

A `Trajectory` contains 3 parts:

Exported APIs are:
- A `container` to store data. (Usually an `AbstractTraces`)
- A `sampler` to determine how to sample a batch from `container`
- A `controller` to decide when to sample a new batch from the `container`

Typical usage:

```julia
push!(trajectory; [trace_name=value]...)
append!(trajectory; [trace_name=value]...)
julia> t = Trajectory(Traces(a=Int[], b=Bool[]), BatchSampler(3), InsertSampleRatioControler(1.0, 3));

julia> for i in 1:5
push!(t, (a=i, b=iseven(i)))
end

for sample in trajectory
# consume samples from the trajectory
end
julia> for batch in t
println(batch)
end
(a = [4, 5, 1], b = Bool[1, 0, 0])
(a = [3, 2, 4], b = Bool[0, 1, 1])
(a = [4, 1, 2], b = Bool[1, 0, 1])
```

A wide variety of `container`s, `sampler`s, and `controler`s are provided. For the full list, please read the doc.
**Traces**

- `Traces`
- `MultiplexTraces`
- `CircularSARTTraces`
- `Episode`
- `Episodes`

**Samplers**

- `BatchSampler`
findmyway marked this conversation as resolved.
Show resolved Hide resolved

**Controllers**

- `InsertSampleRatioController`
- `AsyncInsertSampleRatioController`
findmyway marked this conversation as resolved.
Show resolved Hide resolved


Please refer tests for common usage. (TODO: generate docs and add links to above data structures)

## Acknowledgement

Expand Down
5 changes: 0 additions & 5 deletions src/LastDimSlices.jl

This file was deleted.

3 changes: 2 additions & 1 deletion src/Trajectories.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
module Trajectories

include("patch.jl")

include("traces.jl")
include("episodes.jl")
include("samplers.jl")
include("controlers.jl")
include("trajectory.jl")
Expand Down
14 changes: 12 additions & 2 deletions src/common/CircularArraySARTTraces.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,15 @@
export CircularArraySARTTraces

const CircularArraySARTTraces = Traces{
SSAART,
<:Tuple{
<:MultiplexTraces{SS,<:Trace{<:CircularArrayBuffer}},
<:MultiplexTraces{AA,<:Trace{<:CircularArrayBuffer}},
<:Trace{<:CircularArrayBuffer},
<:Trace{<:CircularArrayBuffer},
}
}

function CircularArraySARTTraces(;
capacity::Int,
state=Int => (),
Expand All @@ -12,8 +22,8 @@ function CircularArraySARTTraces(;
reward_eltype, reward_size = reward
terminal_eltype, terminal_size = terminal

MultiplexTraces{(:state, :next_state)}(CircularArrayBuffer{state_eltype}(state_size..., capacity + 1)) +
MultiplexTraces{(:action, :next_action)}(CircularArrayBuffer{action_eltype}(action_size..., capacity + 1)) +
MultiplexTraces{SS}(CircularArrayBuffer{state_eltype}(state_size..., capacity + 1)) +
MultiplexTraces{AA}(CircularArrayBuffer{action_eltype}(action_size..., capacity + 1)) +
Traces(
reward=CircularArrayBuffer{reward_eltype}(reward_size..., capacity),
terminal=CircularArrayBuffer{terminal_eltype}(terminal_size..., capacity),
Expand Down
17 changes: 14 additions & 3 deletions src/common/CircularArraySLARTTraces.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,16 @@
export CircularArraySLARTTraces

const CircularArraySLARTTraces = Traces{
SSLLAART,
<:Tuple{
<:MultiplexTraces{SS,<:Trace{<:CircularArrayBuffer}},
<:MultiplexTraces{LL,<:Trace{<:CircularArrayBuffer}},
<:MultiplexTraces{AA,<:Trace{<:CircularArrayBuffer}},
<:Trace{<:CircularArrayBuffer},
<:Trace{<:CircularArrayBuffer},
}
}

function CircularArraySLARTTraces(;
capacity::Int,
state=Int => (),
Expand All @@ -14,9 +25,9 @@ function CircularArraySLARTTraces(;
reward_eltype, reward_size = reward
terminal_eltype, terminal_size = terminal

MultiplexTraces{(:state, :next_state)}(CircularArrayBuffer{state_eltype}(state_size..., capacity + 1)) +
MultiplexTraces{(:legal_actions_mask, :next_legal_actions_mask)}(CircularArrayBuffer{legal_actions_mask_eltype}(legal_actions_mask_size..., capacity + 1)) +
MultiplexTraces{(:action, :next_action)}(CircularArrayBuffer{action_eltype}(action_size..., capacity + 1)) +
MultiplexTraces{SS}(CircularArrayBuffer{state_eltype}(state_size..., capacity + 1)) +
MultiplexTraces{LL}(CircularArrayBuffer{legal_actions_mask_eltype}(legal_actions_mask_size..., capacity + 1)) +
MultiplexTraces{AA}(CircularArrayBuffer{action_eltype}(action_size..., capacity + 1)) +
Traces(
reward=CircularArrayBuffer{reward_eltype}(reward_size..., capacity),
terminal=CircularArrayBuffer{terminal_eltype}(terminal_size..., capacity),
Expand Down
7 changes: 7 additions & 0 deletions src/common/common.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
using CircularArrayBuffers

const SS = (:state, :next_state)
const LL = (:legal_actions_mask, :next_legal_actions_mask)
const AA = (:action, :next_action)
const RT = (:reward, :terminal)
const SSAART = (SS..., AA..., RT...)
const SSLLAART = (SS..., LL..., AA..., RT...)

include("sum_tree.jl")
include("CircularArraySARTTraces.jl")
include("CircularArraySLARTTraces.jl")
92 changes: 0 additions & 92 deletions src/episodes.jl

This file was deleted.

3 changes: 3 additions & 0 deletions src/patch.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
import MLUtils

MLUtils.batch(x::AbstractArray{<:Number}) = x
9 changes: 2 additions & 7 deletions src/samplers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,9 @@ Uniformly sample a batch of examples for each trace.

See also [`sample`](@ref).
"""
BatchSampler(batch_size; rng=Random.GLOBAL_RNG, transformer=identity) = BatchSampler(batch_size, rng, transformer)
BatchSampler(batch_size; rng=Random.GLOBAL_RNG, transformer=batch) = BatchSampler(batch_size, rng, transformer)

function sample(s::BatchSampler, t::AbstractTraces)
inds = rand(s.rng, 1:length(t), s.batch_size)
@view t[inds]
map(s.transformer, t[inds])
end

function sample(s::BatchSampler, e::Episodes)
inds = rand(s.rng, 1:length(t), s.batch_size)
batch([@view(s.episodes[e.inds[i][1]][e.inds[i][2]]) for i in inds]) |> s.transformer
end
Loading