Skip to content

Commit

Permalink
Lots of fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
josiahls committed Jul 29, 2023
1 parent a2a8004 commit 1c9b742
Show file tree
Hide file tree
Showing 19 changed files with 1,150 additions and 1,447 deletions.
18 changes: 4 additions & 14 deletions fastrl/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,8 @@
'fastrl/dataloader2_ext.py'),
'fastrl.dataloader2_ext.item_input_pipe_type': ( 'dataloader2_ext.html#item_input_pipe_type',
'fastrl/dataloader2_ext.py')},
'fastrl.envs.gym': { 'fastrl.envs.gym.GymStepper': ('03_Environment/envs.gym.html#gymstepper', 'fastrl/envs/gym.py'),
'fastrl.envs.gym': { 'fastrl.envs.gym.GymDataPipe': ('03_Environment/envs.gym.html#gymdatapipe', 'fastrl/envs/gym.py'),
'fastrl.envs.gym.GymStepper': ('03_Environment/envs.gym.html#gymstepper', 'fastrl/envs/gym.py'),
'fastrl.envs.gym.GymStepper.__init__': ( '03_Environment/envs.gym.html#gymstepper.__init__',
'fastrl/envs/gym.py'),
'fastrl.envs.gym.GymStepper.__iter__': ( '03_Environment/envs.gym.html#gymstepper.__iter__',
Expand Down Expand Up @@ -583,13 +584,7 @@
'fastrl/learner/core.py'),
'fastrl.learner.core.StepBatcher.vstack_by_fld': ( '06_Learning/learner.core.html#stepbatcher.vstack_by_fld',
'fastrl/learner/core.py')},
'fastrl.loggers.core': { 'fastrl.loggers.core.ActionPublish': ( '05_Logging/loggers.core.html#actionpublish',
'fastrl/loggers/core.py'),
'fastrl.loggers.core.ActionPublish.__init__': ( '05_Logging/loggers.core.html#actionpublish.__init__',
'fastrl/loggers/core.py'),
'fastrl.loggers.core.ActionPublish.__iter__': ( '05_Logging/loggers.core.html#actionpublish.__iter__',
'fastrl/loggers/core.py'),
'fastrl.loggers.core.BatchCollector': ( '05_Logging/loggers.core.html#batchcollector',
'fastrl.loggers.core': { 'fastrl.loggers.core.BatchCollector': ( '05_Logging/loggers.core.html#batchcollector',
'fastrl/loggers/core.py'),
'fastrl.loggers.core.BatchCollector.__init__': ( '05_Logging/loggers.core.html#batchcollector.__init__',
'fastrl/loggers/core.py'),
Expand Down Expand Up @@ -664,12 +659,7 @@
'fastrl.loggers.core.RollingTerminatedRewardCollector.reward_detach': ( '05_Logging/loggers.core.html#rollingterminatedrewardcollector.reward_detach',
'fastrl/loggers/core.py'),
'fastrl.loggers.core.RollingTerminatedRewardCollector.step2terminated': ( '05_Logging/loggers.core.html#rollingterminatedrewardcollector.step2terminated',
'fastrl/loggers/core.py'),
'fastrl.loggers.core.TestSync': ('05_Logging/loggers.core.html#testsync', 'fastrl/loggers/core.py'),
'fastrl.loggers.core.TestSync.__init__': ( '05_Logging/loggers.core.html#testsync.__init__',
'fastrl/loggers/core.py'),
'fastrl.loggers.core.TestSync.__iter__': ( '05_Logging/loggers.core.html#testsync.__iter__',
'fastrl/loggers/core.py')},
'fastrl/loggers/core.py')},
'fastrl.loggers.jupyter_visualizers': { 'fastrl.loggers.jupyter_visualizers.ImageCollector': ( '05_Logging/loggers.jupyter_visualizers.html#imagecollector',
'fastrl/loggers/jupyter_visualizers.py'),
'fastrl.loggers.jupyter_visualizers.ImageCollector.__iter__': ( '05_Logging/loggers.jupyter_visualizers.html#imagecollector.__iter__',
Expand Down
22 changes: 11 additions & 11 deletions fastrl/agents/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# %% auto 0
__all__ = ['AgentBase', 'AgentHead', 'SimpleModelRunner', 'StepFieldSelector', 'StepModelFeeder', 'NumpyConverter']

# %% ../../nbs/07_Agents/12a_agents.core.ipynb 3
# %% ../../nbs/07_Agents/12a_agents.core.ipynb 2
# Python native modules
import os
from typing import List
Expand All @@ -12,13 +12,13 @@
import torchdata.datapipes as dp
import torch
from torch import nn
from torchdata.dataloader2.graph import find_dps,traverse_dps
from torchdata.dataloader2.graph import traverse_dps
# Local modules
from ..core import *
from ..torch_core import *
from ..pipes.core import *
from ..core import StepType,SimpleStep
from ..torch_core import evaluating,Module
from ..pipes.core import find_dps,find_dp

# %% ../../nbs/07_Agents/12a_agents.core.ipynb 5
# %% ../../nbs/07_Agents/12a_agents.core.ipynb 4
class AgentBase(dp.iter.IterDataPipe):
def __init__(self,
model:nn.Module, # The base NN that we getting raw action values out of.
Expand Down Expand Up @@ -52,7 +52,7 @@ def __iter__(self):
to=torch.Tensor.to.__doc__
)

# %% ../../nbs/07_Agents/12a_agents.core.ipynb 6
# %% ../../nbs/07_Agents/12a_agents.core.ipynb 5
class AgentHead(dp.iter.IterDataPipe):
def __init__(self,source_datapipe):
self.source_datapipe = source_datapipe
Expand Down Expand Up @@ -90,7 +90,7 @@ def create_step(self,**kwargs): return SimpleStep(**kwargs)
create_step="Creates the step used by the env for running, and used by the model for training."
)

# %% ../../nbs/07_Agents/12a_agents.core.ipynb 7
# %% ../../nbs/07_Agents/12a_agents.core.ipynb 6
class SimpleModelRunner(dp.iter.IterDataPipe):
"Takes input from `source_datapipe` and pushes through the agent bases model assuming there is only one model field."
def __init__(self,
Expand All @@ -111,7 +111,7 @@ def __iter__(self):
res = self.agent_base.model(x)
yield res

# %% ../../nbs/07_Agents/12a_agents.core.ipynb 13
# %% ../../nbs/07_Agents/12a_agents.core.ipynb 12
class StepFieldSelector(dp.iter.IterDataPipe):
"Grabs `field` from `source_datapipe` to push to the rest of the pipeline."
def __init__(self,
Expand All @@ -128,7 +128,7 @@ def __iter__(self):
raise Exception(f'Expected typing.NamedTuple object got {type(step)}\n{step}')
yield getattr(step,self.field)

# %% ../../nbs/07_Agents/12a_agents.core.ipynb 23
# %% ../../nbs/07_Agents/12a_agents.core.ipynb 22
class StepModelFeeder(dp.iter.IterDataPipe):
def __init__(self,
source_datapipe, # next() must produce a `StepType`,
Expand Down Expand Up @@ -158,7 +158,7 @@ def __iter__(self):
)


# %% ../../nbs/07_Agents/12a_agents.core.ipynb 24
# %% ../../nbs/07_Agents/12a_agents.core.ipynb 23
class NumpyConverter(dp.iter.IterDataPipe):
debug=False

Expand Down
26 changes: 13 additions & 13 deletions fastrl/agents/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,25 @@
# %% auto 0
__all__ = ['ArgMaxer', 'EpsilonSelector', 'EpsilonCollector', 'PyPrimativeConverter']

# %% ../../nbs/07_Agents/01_Discrete/12b_agents.discrete.ipynb 3
# %% ../../nbs/07_Agents/01_Discrete/12b_agents.discrete.ipynb 2
# Python native modules
import os
# Third party libs
from fastcore.all import *
# from fastcore.all import *
import torchdata.datapipes as dp
import torch
from torch.nn import *
# from torch.nn import *
import torch.nn.functional as F
from torchdata.dataloader2.graph import find_dps,traverse
from torchdata.dataloader2.graph import traverse_dps
import numpy as np
# Local modules
from ..core import *
from ..pipes.core import *
from .core import *
from ..loggers.core import *
from ..torch_core import *
# from fastrl.core import *
# from fastrl.pipes.core import *
# from fastrl.agents.core import *
# from fastrl.loggers.core import *
# from fastrl.torch_core import *

# %% ../../nbs/07_Agents/01_Discrete/12b_agents.discrete.ipynb 5
# %% ../../nbs/07_Agents/01_Discrete/12b_agents.discrete.ipynb 4
class ArgMaxer(dp.iter.IterDataPipe):
debug=False

Expand Down Expand Up @@ -49,7 +49,7 @@ def __iter__(self) -> torch.LongTensor:
yield step.long()


# %% ../../nbs/07_Agents/01_Discrete/12b_agents.discrete.ipynb 9
# %% ../../nbs/07_Agents/01_Discrete/12b_agents.discrete.ipynb 8
class EpsilonSelector(dp.iter.IterDataPipe):
debug=False
"Given input `Tensor` from `source_datapipe`."
Expand Down Expand Up @@ -116,7 +116,7 @@ def __iter__(self):

yield ((action,mask) if self.ret_mask else action)

# %% ../../nbs/07_Agents/01_Discrete/12b_agents.discrete.ipynb 23
# %% ../../nbs/07_Agents/01_Discrete/12b_agents.discrete.ipynb 22
class EpsilonCollector(LogCollector):
header:str='epsilon'
# def __init__(self,
Expand All @@ -133,7 +133,7 @@ def __iter__(self):
q.append(Record('epsilon',self.source_datapipe.epsilon))
yield action

# %% ../../nbs/07_Agents/01_Discrete/12b_agents.discrete.ipynb 24
# %% ../../nbs/07_Agents/01_Discrete/12b_agents.discrete.ipynb 23
class PyPrimativeConverter(dp.iter.IterDataPipe):
debug=False

Expand Down
69 changes: 67 additions & 2 deletions fastrl/envs/gym.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/03_Environment/05b_envs.gym.ipynb.

# %% auto 0
__all__ = ['GymStepper']
__all__ = ['GymStepper', 'GymDataPipe']

# %% ../../nbs/03_Environment/05b_envs.gym.ipynb 2
# Python native modules
import os
import warnings
from functools import partial
from typing import Callable, Any, Union, Iterable, Optional
# Third party libs
import gymnasium as gym
Expand All @@ -15,7 +16,7 @@
from fastcore.all import add_docs
import torchdata.datapipes as dp
from torchdata.dataloader2.graph import find_dps,DataPipeGraph,DataPipe,traverse_dps
from torchdata.dataloader2 import MultiProcessingReadingService
from torchdata.dataloader2 import MultiProcessingReadingService,DataLoader2
from torchdata.datapipes.iter import IterDataPipe
from torchdata.datapipes.map import MapDataPipe
# Local modules
Expand Down Expand Up @@ -146,3 +147,67 @@ def __iter__(self) -> SimpleStep:
env_reset="Resets a env given the env_id.",
no_agent_create_step="If there is no agent for creating the step output, then `GymStepper` will create its own"
)

# %% ../../nbs/03_Environment/05b_envs.gym.ipynb 52
def GymDataPipe(
agent:DataPipe, # An AgentHead
seed:Optional[int]=None, # The seed for the gym to use
# Used by `NStepper`, outputs tuples / chunks of assiciated steps
nsteps:int=1,
# Used by `NSkipper` to skip a certain number of steps (agent still gets called for each)
nskips:int=1,
# Whether when nsteps>1 to merge it into a single `StepType`
firstlast:bool=False,
# The batch size, which is different from `nsteps` in that firstlast will be
# run prior to batching, and a batch of steps might come from multiple envs,
# where nstep is associated with a single env
bs:int=1,
# The prefered default is for the pipeline to be infinate, and the learner
# decides how much to iter. If this is not None, then the pipeline will run for
# that number of `n`
n:Optional[int]=None,
# Whether to reset all the envs at the same time as opposed to reseting them
# the moment an episode ends.
synchronized_reset:bool=False,
# Should be used only for validation / logging, will grab a render of the gym
# and assign to the `StepType` image field. This data should not be used for training.
# If it images are needed for training, then you should wrap the env instead.
include_images:bool=False,
# If an environment truncates, terminate it.
terminate_on_truncation:bool=True
) -> Callable:
"Basic `gymnasium` `DataPipeGraph` with first-last, nstep, and nskip capability"

def pipe_init(source,as_dataloader=False,num_workers=0):
"This is the function that is actually run by `DataBlock`"
pipe = dp.map.Mapper(source)
if include_images:
pipe = pipe.map(partial(gym.make,render_mode='rgb_array'))
else:
pipe = pipe.map(gym.make)
pipe = dp.iter.MapToIterConverter(pipe)
pipe = dp.iter.InMemoryCacheHolder(pipe)
pipe = pipe.cycle() # Cycle through the envs inf
pipe = GymStepper(pipe,agent=agent,seed=seed,
include_images=include_images,
terminate_on_truncation=terminate_on_truncation,
synchronized_reset=synchronized_reset)
if nskips!=1: pipe = NSkipper(pipe,n=nskips)
if nsteps!=1:
pipe = NStepper(pipe,n=nsteps)
if firstlast:
pipe = FirstLastMerger(pipe)
else:
pipe = NStepFlattener(pipe) # We dont want to flatten if using FirstLastMerger
if n is not None: pipe = pipe.header(limit=n)
pipe = pipe.batch(batch_size=bs)

if as_dataloader:
pipe = DataLoader2(
datapipe=pipe,
reading_service=MultiProcessingReadingService(
num_workers = num_workers
) if num_workers > 0 else None
)
return pipe
return pipe_init
42 changes: 23 additions & 19 deletions fastrl/learner/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,29 +3,30 @@
# %% auto 0
__all__ = ['LearnerBase', 'LearnerHead', 'StepBatcher']

# %% ../../nbs/06_Learning/10a_learner.core.ipynb 3
# %% ../../nbs/06_Learning/10a_learner.core.ipynb 2
# Python native modules
import os
from contextlib import contextmanager
from typing import *
from typing import List,Union
# Third party libs
from fastcore.all import *
from fastcore.all import add_docs
import torchdata.datapipes as dp
import torch
from ..torch_core import *
from torch import nn
from ..torch_core import evaluating
from torchdata.dataloader2 import DataLoader2
from torchdata.dataloader2.graph import find_dps,traverse,DataPipeGraph,Type,DataPipe
from torchdata.dataloader2.graph import traverse_dps,DataPipeGraph,DataPipe
# Local modules
from ..core import *
from ..torch_core import *
from ..pipes.core import *
from ..loggers.core import *
from ..data.dataloader2 import *
# from fastrl.core import *
# from fastrl.torch_core import *
from ..pipes.core import find_dps
from ..loggers.core import Record,EpocherCollector
# from fastrl.data.dataloader2 import *

# %% ../../nbs/06_Learning/10a_learner.core.ipynb 5
# %% ../../nbs/06_Learning/10a_learner.core.ipynb 4
class LearnerBase(dp.iter.IterDataPipe):
def __init__(self,
model:Module, # The base NN that we getting raw action values out of.
model:nn.Module, # The base NN that we getting raw action values out of.
dls:List[DataLoader2], # The dataloaders to read data from for training
device=None,
loss_func=None, # The loss function to use
Expand Down Expand Up @@ -57,7 +58,7 @@ def __init__(self,
self.batches = batches
self.infinite_dls = True
else:
self.batches = find_dp(traverse(dls[0].datapipe,only_datapipe=True),dp.iter.Header).limit
self.batches = find_dps(traverse_dps(dls[0].datapipe,only_datapipe=True),dp.iter.Header).limit

def __getstate__(self):
state = super().__getstate__()
Expand All @@ -77,8 +78,11 @@ def reset(self):

def increment_batch(self,value):
return not isinstance(value,
(Record,GetInputItemResponse)
(Record,)
)
# return not isinstance(value,
# (Record,GetInputItemResponse)
# )

def __iter__(self):
self.reset()
Expand Down Expand Up @@ -122,16 +126,16 @@ def __iter__(self):
increment_batch="Decides when a single batch is actually 'complete'."
)

# %% ../../nbs/06_Learning/10a_learner.core.ipynb 6
# %% ../../nbs/06_Learning/10a_learner.core.ipynb 5
class LearnerHead(dp.iter.IterDataPipe):
def __init__(self,source_datapipe):
self.source_datapipe = source_datapipe
self.learner_base = find_dp(traverse(self.source_datapipe),LearnerBase)
self.learner_base = find_dps(traverse_dps(self.source_datapipe),LearnerBase)

def __iter__(self): yield from self.source_datapipe

def fit(self,epochs):
epocher = find_dp(traverse(self),EpocherCollector)
epocher = find_dps(traverse_dps(self),EpocherCollector)
epocher.epochs = epochs

for iteration in self:
Expand All @@ -143,7 +147,7 @@ def validate(self,epochs=1,dl_idx=1) -> DataPipe:
for el in self.learner_base.iterable[dl_idx]:pass

pipe = self.learner_base.iterable[dl_idx].datapipe
return pipe.show() if hasattr(pipe,'show') else pip
return pipe.show() if hasattr(pipe,'show') else pipe

add_docs(
LearnerHead,
Expand All @@ -154,7 +158,7 @@ def validate(self,epochs=1,dl_idx=1) -> DataPipe:
`dl_idx` and returns the original datapipe for displaying."""
)

# %% ../../nbs/06_Learning/10a_learner.core.ipynb 14
# %% ../../nbs/06_Learning/10a_learner.core.ipynb 13
class StepBatcher(dp.iter.IterDataPipe):
def __init__(self,
source_datapipe,
Expand Down
Loading

0 comments on commit 1c9b742

Please sign in to comment.