Skip to content

Commit

Permalink
Added:
Browse files Browse the repository at this point in the history
- lots of fixes, and compat updates to the datapipe lines
- pickleable cache holder so that envs get reset when pickled
Removed:
- usages of blocks. They just seem to overly obfuscate the code
  • Loading branch information
josiahls committed Jul 29, 2023
1 parent 1c9b742 commit ec1cd4a
Show file tree
Hide file tree
Showing 20 changed files with 1,756 additions and 1,149 deletions.
44 changes: 38 additions & 6 deletions fastrl/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,30 @@
'fastrl/dataloader2_ext.py'),
'fastrl.dataloader2_ext.item_input_pipe_type': ( 'dataloader2_ext.html#item_input_pipe_type',
'fastrl/dataloader2_ext.py')},
'fastrl.datapipes.cacheholder': { 'fastrl.datapipes.cacheholder.PickleableInMemoryCacheHolderIterDataPipe': ( '01_DataPipes/cachholder.html#pickleableinmemorycacheholderiterdatapipe',
'fastrl/datapipes/cacheholder.py'),
'fastrl.datapipes.cacheholder.PickleableInMemoryCacheHolderIterDataPipe.__getstate__': ( '01_DataPipes/cachholder.html#pickleableinmemorycacheholderiterdatapipe.__getstate__',
'fastrl/datapipes/cacheholder.py'),
'fastrl.datapipes.cacheholder.PickleableInMemoryCacheHolderIterDataPipe.__init__': ( '01_DataPipes/cachholder.html#pickleableinmemorycacheholderiterdatapipe.__init__',
'fastrl/datapipes/cacheholder.py'),
'fastrl.datapipes.cacheholder.PickleableInMemoryCacheHolderIterDataPipe.__iter__': ( '01_DataPipes/cachholder.html#pickleableinmemorycacheholderiterdatapipe.__iter__',
'fastrl/datapipes/cacheholder.py'),
'fastrl.datapipes.cacheholder.PickleableInMemoryCacheHolderIterDataPipe.__len__': ( '01_DataPipes/cachholder.html#pickleableinmemorycacheholderiterdatapipe.__len__',
'fastrl/datapipes/cacheholder.py'),
'fastrl.datapipes.cacheholder.PickleableInMemoryCacheHolderIterDataPipe.__setstate__': ( '01_DataPipes/cachholder.html#pickleableinmemorycacheholderiterdatapipe.__setstate__',
'fastrl/datapipes/cacheholder.py')},
'fastrl.datapipes.pipes.iter.cacheholder': { 'fastrl.datapipes.pipes.iter.cacheholder.PickleableInMemoryCacheHolderIterDataPipe': ( '01_DataPipes/pipes.iter.cacheholder.html#pickleableinmemorycacheholderiterdatapipe',
'fastrl/datapipes/pipes/iter/cacheholder.py'),
'fastrl.datapipes.pipes.iter.cacheholder.PickleableInMemoryCacheHolderIterDataPipe.__getstate__': ( '01_DataPipes/pipes.iter.cacheholder.html#pickleableinmemorycacheholderiterdatapipe.__getstate__',
'fastrl/datapipes/pipes/iter/cacheholder.py'),
'fastrl.datapipes.pipes.iter.cacheholder.PickleableInMemoryCacheHolderIterDataPipe.__init__': ( '01_DataPipes/pipes.iter.cacheholder.html#pickleableinmemorycacheholderiterdatapipe.__init__',
'fastrl/datapipes/pipes/iter/cacheholder.py'),
'fastrl.datapipes.pipes.iter.cacheholder.PickleableInMemoryCacheHolderIterDataPipe.__iter__': ( '01_DataPipes/pipes.iter.cacheholder.html#pickleableinmemorycacheholderiterdatapipe.__iter__',
'fastrl/datapipes/pipes/iter/cacheholder.py'),
'fastrl.datapipes.pipes.iter.cacheholder.PickleableInMemoryCacheHolderIterDataPipe.__len__': ( '01_DataPipes/pipes.iter.cacheholder.html#pickleableinmemorycacheholderiterdatapipe.__len__',
'fastrl/datapipes/pipes/iter/cacheholder.py'),
'fastrl.datapipes.pipes.iter.cacheholder.PickleableInMemoryCacheHolderIterDataPipe.__setstate__': ( '01_DataPipes/pipes.iter.cacheholder.html#pickleableinmemorycacheholderiterdatapipe.__setstate__',
'fastrl/datapipes/pipes/iter/cacheholder.py')},
'fastrl.envs.gym': { 'fastrl.envs.gym.GymDataPipe': ('03_Environment/envs.gym.html#gymdatapipe', 'fastrl/envs/gym.py'),
'fastrl.envs.gym.GymStepper': ('03_Environment/envs.gym.html#gymstepper', 'fastrl/envs/gym.py'),
'fastrl.envs.gym.GymStepper.__init__': ( '03_Environment/envs.gym.html#gymstepper.__init__',
Expand Down Expand Up @@ -684,12 +708,8 @@
'fastrl/loggers/vscode_visualizers.py'),
'fastrl.loggers.vscode_visualizers.SimpleVSCodeVideoPlayer.show': ( '05_Logging/loggers.vscode_visualizers.html#simplevscodevideoplayer.show',
'fastrl/loggers/vscode_visualizers.py'),
'fastrl.loggers.vscode_visualizers.VSCodeTransformBlock': ( '05_Logging/loggers.vscode_visualizers.html#vscodetransformblock',
'fastrl/loggers/vscode_visualizers.py'),
'fastrl.loggers.vscode_visualizers.VSCodeTransformBlock.__call__': ( '05_Logging/loggers.vscode_visualizers.html#vscodetransformblock.__call__',
'fastrl/loggers/vscode_visualizers.py'),
'fastrl.loggers.vscode_visualizers.VSCodeTransformBlock.__init__': ( '05_Logging/loggers.vscode_visualizers.html#vscodetransformblock.__init__',
'fastrl/loggers/vscode_visualizers.py')},
'fastrl.loggers.vscode_visualizers.VSCodePipeline': ( '05_Logging/loggers.vscode_visualizers.html#vscodepipeline',
'fastrl/loggers/vscode_visualizers.py')},
'fastrl.memory.experience_replay': { 'fastrl.memory.experience_replay.ExperienceReplay': ( '04_Memory/memory.experience_replay.html#experiencereplay',
'fastrl/memory/experience_replay.py'),
'fastrl.memory.experience_replay.ExperienceReplay.__init__': ( '04_Memory/memory.experience_replay.html#experiencereplay.__init__',
Expand Down Expand Up @@ -721,6 +741,18 @@
'fastrl/pipes/core.py'),
'fastrl.pipes.core.find_dp': ('01_DataPipes/pipes.core.html#find_dp', 'fastrl/pipes/core.py'),
'fastrl.pipes.core.find_dps': ('01_DataPipes/pipes.core.html#find_dps', 'fastrl/pipes/core.py')},
'fastrl.pipes.iter.cacheholder': { 'fastrl.pipes.iter.cacheholder.PickleableInMemoryCacheHolderIterDataPipe': ( '01_DataPipes/pipes.iter.cacheholder.html#pickleableinmemorycacheholderiterdatapipe',
'fastrl/pipes/iter/cacheholder.py'),
'fastrl.pipes.iter.cacheholder.PickleableInMemoryCacheHolderIterDataPipe.__getstate__': ( '01_DataPipes/pipes.iter.cacheholder.html#pickleableinmemorycacheholderiterdatapipe.__getstate__',
'fastrl/pipes/iter/cacheholder.py'),
'fastrl.pipes.iter.cacheholder.PickleableInMemoryCacheHolderIterDataPipe.__init__': ( '01_DataPipes/pipes.iter.cacheholder.html#pickleableinmemorycacheholderiterdatapipe.__init__',
'fastrl/pipes/iter/cacheholder.py'),
'fastrl.pipes.iter.cacheholder.PickleableInMemoryCacheHolderIterDataPipe.__iter__': ( '01_DataPipes/pipes.iter.cacheholder.html#pickleableinmemorycacheholderiterdatapipe.__iter__',
'fastrl/pipes/iter/cacheholder.py'),
'fastrl.pipes.iter.cacheholder.PickleableInMemoryCacheHolderIterDataPipe.__len__': ( '01_DataPipes/pipes.iter.cacheholder.html#pickleableinmemorycacheholderiterdatapipe.__len__',
'fastrl/pipes/iter/cacheholder.py'),
'fastrl.pipes.iter.cacheholder.PickleableInMemoryCacheHolderIterDataPipe.__setstate__': ( '01_DataPipes/pipes.iter.cacheholder.html#pickleableinmemorycacheholderiterdatapipe.__setstate__',
'fastrl/pipes/iter/cacheholder.py')},
'fastrl.pipes.iter.firstlast': { 'fastrl.pipes.iter.firstlast.FirstLastMerger': ( '01_DataPipes/pipes.iter.firstlast.html#firstlastmerger',
'fastrl/pipes/iter/firstlast.py'),
'fastrl.pipes.iter.firstlast.FirstLastMerger.__init__': ( '01_DataPipes/pipes.iter.firstlast.html#firstlastmerger.__init__',
Expand Down
9 changes: 5 additions & 4 deletions fastrl/agents/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# %% ../../nbs/07_Agents/01_Discrete/12b_agents.discrete.ipynb 2
# Python native modules
import os
from typing import Union
# Third party libs
# from fastcore.all import *
import torchdata.datapipes as dp
Expand All @@ -16,9 +17,9 @@
import numpy as np
# Local modules
# from fastrl.core import *
# from fastrl.pipes.core import *
# from fastrl.agents.core import *
# from fastrl.loggers.core import *
from ..pipes.core import find_dp
from .core import AgentBase
from ..loggers.core import LogCollector,Record
# from fastrl.torch_core import *

# %% ../../nbs/07_Agents/01_Discrete/12b_agents.discrete.ipynb 4
Expand Down Expand Up @@ -81,7 +82,7 @@ def __init__(self,
self.decrement_on_val = decrement_on_val
self.select_on_val = select_on_val
self.ret_mask = ret_mask
self.agent_base = find_dp(traverse(self.source_datapipe,only_datapipe=True),AgentBase)
self.agent_base = find_dp(traverse_dps(self.source_datapipe),AgentBase)
self.step = 0
self.device = torch.device(device)

Expand Down
73 changes: 37 additions & 36 deletions fastrl/agents/dqn/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,55 +4,56 @@
__all__ = ['DataPipeAugmentationFn', 'DQN', 'DQNAgent', 'QCalc', 'TargetCalc', 'LossCalc', 'ModelLearnCalc', 'LossCollector',
'DQNLearner']

# %% ../../../nbs/07_Agents/01_Discrete/12g_agents.dqn.basic.ipynb 3
# %% ../../../nbs/07_Agents/01_Discrete/12g_agents.dqn.basic.ipynb 2
# Python native modules
import os
from collections import deque
from typing import Callable
from typing import Callable,Optional,List
# Third party libs
from fastcore.all import *
from fastcore.all import ifnone
import torchdata.datapipes as dp
from torchdata.dataloader2 import DataLoader2
from torch.utils.data.datapipes._typing import _DataPipeMeta, _IterDataPipeMeta
from torchdata.dataloader2.graph import find_dps,traverse,DataPipe
from torchdata.dataloader2.graph import traverse_dps,DataPipe
import torch
from torch.nn import *
import torch.nn.functional as F
from torch.optim import *
from torch import optim
from torch import nn
import numpy as np
# Local modules
from ...core import *
from ..core import *
from ...pipes.core import *
from ...data.block import *
from ...dataloader2_ext import *
from ...memory.experience_replay import *
from ..core import *
from ..discrete import *
from ...loggers.core import *
from ...loggers.vscode_visualizers import *
from ...learner.core import *
from ...torch_core import *
from ...data.dataloader2 import *
# from fastrl.core import *
from ..core import AgentHead,AgentBase
from ...pipes.core import find_dp
# from fastrl.data.block import *
# from fastrl.dataloader2_ext import *
from ...memory.experience_replay import ExperienceReplay
from ..core import StepFieldSelector,SimpleModelRunner,NumpyConverter
from ..discrete import EpsilonCollector,PyPrimativeConverter,ArgMaxer,EpsilonSelector
from fastrl.loggers.core import (
CacheLoggerBase,LogCollector,Record,LoggerBasePassThrough,BatchCollector,EpocherCollector,RollingTerminatedRewardCollector,EpisodeCollector
)
# from fastrl.loggers.vscode_visualizers import *
from ...learner.core import LearnerBase,LearnerHead,StepBatcher
from ...torch_core import Module
# from fastrl.data.dataloader2 import *

# %% ../../../nbs/07_Agents/01_Discrete/12g_agents.dqn.basic.ipynb 6
# %% ../../../nbs/07_Agents/01_Discrete/12g_agents.dqn.basic.ipynb 5
class DQN(Module):
def __init__(self,
state_sz:int, # The input dim of the state
action_sz:int, # The output dim of the actions
hidden=512, # Number of neurons connected between the 2 input/output layers
head_layer:Module=Linear, # DQN extensions such as Dueling DQNs have custom heads
activition_fn:Module=ReLU # The activiation fn used by `DQN`
head_layer:Module=nn.Linear, # DQN extensions such as Dueling DQNs have custom heads
activition_fn:Module=nn.ReLU # The activiation fn used by `DQN`
):
self.layers=Sequential(
Linear(state_sz,hidden),
self.layers=nn.Sequential(
nn.Linear(state_sz,hidden),
activition_fn(),
head_layer(hidden,action_sz),
)
def forward(self,x): return self.layers(x)


# %% ../../../nbs/07_Agents/01_Discrete/12g_agents.dqn.basic.ipynb 8
# %% ../../../nbs/07_Agents/01_Discrete/12g_agents.dqn.basic.ipynb 7
DataPipeAugmentationFn = Callable[[DataPipe],Optional[DataPipe]]

def DQNAgent(
Expand All @@ -66,7 +67,7 @@ def DQNAgent(
)->AgentHead:
agent_base = AgentBase(model,logger_bases=ifnone(logger_bases,[CacheLoggerBase()]))
agent = StepFieldSelector(agent_base,field='state')
agent = InputInjester(agent)
# agent = InputInjester(agent)
agent = SimpleModelRunner(agent)
agent = ArgMaxer(agent)
agent = EpsilonSelector(agent,min_epsilon=min_epsilon,max_epsilon=max_epsilon,max_steps=max_steps,device=device)
Expand All @@ -88,7 +89,7 @@ def __init__(self,source_datapipe):
self.source_datapipe = source_datapipe

def __iter__(self):
self.learner = find_dp(traverse(self),LearnerBase)
self.learner = find_dp(traverse_dps(self),LearnerBase)
for batch in self.source_datapipe:
self.learner.done_mask = batch.terminated.reshape(-1,)
self.learner.next_q = self.learner.model(batch.next_state)
Expand All @@ -105,7 +106,7 @@ def __init__(self,source_datapipe,discount=0.99,nsteps=1):
self.learner = None

def __iter__(self):
self.learner = find_dp(traverse(self),LearnerBase)
self.learner = find_dp(traverse_dps(self),LearnerBase)
for batch in self.source_datapipe:
self.learner.targets = batch.reward+self.learner.next_q*(self.discount**self.nsteps)
self.learner.pred = self.learner.model(batch.state)
Expand All @@ -119,7 +120,7 @@ def __init__(self,source_datapipe,discount=0.99,nsteps=1):
self.source_datapipe = source_datapipe
self.discount = discount
self.nsteps = nsteps
self.learner = find_dp(traverse(self),LearnerBase)
self.learner = find_dp(traverse_dps(self),LearnerBase)

def __iter__(self):
for batch in self.source_datapipe:
Expand All @@ -132,7 +133,7 @@ def __init__(self,source_datapipe):
self.source_datapipe = source_datapipe

def __iter__(self):
self.learner = find_dp(traverse(self),LearnerBase)
self.learner = find_dp(traverse_dps(self),LearnerBase)
for batch in self.source_datapipe:
self.learner.loss_grad.backward()
self.learner.opt.step()
Expand All @@ -151,7 +152,7 @@ def __init__(self,
self.main_buffers = None

def __iter__(self):
self.learner = find_dp(traverse(self),LearnerBase)
self.learner = find_dp(traverse_dps(self),LearnerBase)
for i,steps in enumerate(self.source_datapipe):
# if i==0: self.push_header('loss')
for q in self.main_buffers: q.append(Record('loss',self.learner.loss.cpu().detach().numpy()))
Expand All @@ -161,9 +162,9 @@ def __iter__(self):
def DQNLearner(
model,
dls,
logger_bases=None,
loss_func=MSELoss(),
opt=AdamW,
logger_bases=(),
loss_func=nn.MSELoss(),
opt=optim.AdamW,
lr=0.005,
bs=128,
max_sz=10000,
Expand All @@ -176,7 +177,7 @@ def DQNLearner(
learner = LoggerBasePassThrough(learner,logger_bases)
learner = BatchCollector(learner,batch_on_pipe=LearnerBase)
learner = EpocherCollector(learner)
for logger_base in L(logger_bases): learner = logger_base.connect_source_datapipe(learner)
for logger_base in logger_bases: learner = logger_base.connect_source_datapipe(learner)
if logger_bases:
learner = RollingTerminatedRewardCollector(learner)
learner = EpisodeCollector(learner)
Expand Down
Empty file added fastrl/datapipes/__init__.py
Empty file.
Loading

0 comments on commit ec1cd4a

Please sign in to comment.