feature(zjow): add middleware for ape-x structure pipeline #696

Merged
merged 14 commits on Aug 11, 2023
5 changes: 4 additions & 1 deletion ding/data/buffer/middleware/priority.py
@@ -54,7 +54,10 @@ def __init__(

    def push(self, chain: Callable, data: Any, meta: Optional[dict] = None, *args, **kwargs) -> BufferedData:
        if meta is None:
            meta = {'priority': self.max_priority}
            if 'priority' in data:
                meta = {'priority': data['priority'].item()}

Member: do the item() operation in the original priority producer.
Collaborator Author: fixed

            else:
Member: add unittest for this new if-else branch
Collaborator Author: fixed

                meta = {'priority': self.max_priority}
        else:
            if 'priority' not in meta:
                meta['priority'] = self.max_priority
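For context, a minimal standalone sketch of the branching this hunk introduces (the helper name resolve_priority_meta and the sample data are illustrative, not part of DI-engine): when a pushed sample already carries a 'priority' field, that value becomes the buffer meta priority; otherwise the buffer falls back to its current max priority.

from typing import Optional

import torch


def resolve_priority_meta(data: dict, meta: Optional[dict], max_priority: float) -> dict:
    # mirrors the new push() branch: reuse a per-sample priority if the data carries one
    if meta is None:
        if 'priority' in data:
            # data['priority'] is assumed to be a one-element tensor here
            meta = {'priority': data['priority'].item()}
        else:
            meta = {'priority': max_priority}
    elif 'priority' not in meta:
        meta['priority'] = max_priority
    return meta


print(resolve_priority_meta({'obs': torch.zeros(4), 'priority': torch.tensor([0.5])}, None, 1.0))  # {'priority': 0.5}
print(resolve_priority_meta({'obs': torch.zeros(4)}, None, 1.0))  # {'priority': 1.0}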
2 changes: 1 addition & 1 deletion ding/framework/middleware/__init__.py
@@ -2,5 +2,5 @@
from .collector import StepCollector, EpisodeCollector, PPOFStepCollector
from .learner import OffPolicyLearner, HERLearner
from .ckpt_handler import CkptSaver
from .distributer import ContextExchanger, ModelExchanger
from .distributer import ContextExchanger, ModelExchanger, PeriodicalModelExchanger
from .barrier import Barrier, BarrierRuntime
138 changes: 138 additions & 0 deletions ding/framework/middleware/distributer.py
@@ -1,3 +1,4 @@
import numpy as np
from time import sleep, time
from dataclasses import fields
from typing import TYPE_CHECKING, List, Dict, Any, Optional, Union
@@ -287,3 +288,140 @@ def _send_callback(self, storage: Storage):
    def __del__(self):
        if self._model_loader:
            self._model_loader.shutdown()


class PeriodicalModelExchanger:

    def __init__(
        self,
        model: "Module",
        mode: str,
        period: int = 1,
        delay_toleration: float = np.inf,
        stale_toleration: int = 1,
        event_name: str = "model_exchanger",
        model_loader: Optional[ModelLoader] = None
    ) -> None:
        """
        Overview:
            Exchange model between processes periodically.
        Arguments:
            - model (:obj:`torch.nn.Module`): Pytorch module.
            - mode (:obj:`str`): "send" or "receive".
            - period (:obj:`int`): Period of model exchange.
            - delay_toleration (:obj:`float`): Delay toleration of model exchange.
            - stale_toleration (:obj:`int`): Stale toleration of model exchange.
            - event_name (:obj:`str`): Event name of model exchange.
            - model_loader (:obj:`ModelLoader`): Encode model in subprocess.
        """
        self._model = model
        self._model_loader = model_loader
        self._event_name = event_name
        self._period = period
        self.mode = mode

Member: why not use an underscore prefix (self._mode) here?
Collaborator Author: fixed

        if self.mode == "receive":
            self._id_counter = -1
            self._model_id = -1
        else:
            self._id_counter = 0
        self.stale_toleration = stale_toleration
        self.model_stale = stale_toleration
        self.delay_toleration = delay_toleration
        self._state_dict_cache: Optional[Union[object, Storage]] = None

        if self.mode == "receive":
            task.on(self._event_name, self._cache_state_dict)
        if model_loader:
            task.once("finish", lambda _: model_loader.shutdown())

    def _cache_state_dict(self, msg: Dict[str, Any]):
        # msg: Dict {'id': id, 'model': state_dict: Union[object, Storage]}
        print(f"node_id[{task.router.node_id}] get model msg")

Member: use logging
Collaborator Author: fixed

        if msg['id'] % self._period == 0:
            self._state_dict_cache = msg['model']
            self._id_counter = msg['id']
            self._time = msg['time']
        else:
            print(f"node_id[{task.router.node_id}] skip save cache")

    def __new__(cls, *args, **kwargs):
        return super(PeriodicalModelExchanger, cls).__new__(cls)

    def __call__(self, ctx: "Context") -> Any:
        if self._model_loader:
            self._model_loader.start()

        if self.mode == "receive":
            print(f"node_id[{task.router.node_id}] try receive model")
            if ctx.total_step != 0:  # Skip first iteration
                self._update_model()
            else:
                print(f"node_id[{task.router.node_id}] skip first iteration")
        elif self.mode == "send":
            yield
            print(f"node_id[{task.router.node_id}] try send model")
            if self._id_counter % self._period == 0:
                self._send_model(id=self._id_counter)
                print(f"node_id[{task.router.node_id}] model send [{self._id_counter}]")
            self._id_counter += 1
        else:
            raise NotImplementedError

    def _update_model(self):
        start = time()
        while True:
            if task.finish:
                return
            if time() - start > 60:
                logging.warning("Timeout when waiting for new model! Node id: {}".format(task.router.node_id))
                self.model_stale += 1
                break
            if self._state_dict_cache is None:
                if self.model_stale < self.stale_toleration and time() - self._time < self.delay_toleration:
                    self.model_stale += 1
                    break
                else:
                    sleep(0.01)
            else:
                # print(f"node_id[{task.router.node_id}] time diff {time()-self._time}")
                if self._id_counter > self._model_id and time() - self._time < self.delay_toleration:
                    print(f"node_id[{task.router.node_id}] begin update")
                    if isinstance(self._state_dict_cache, Storage) and self._model_loader is not None:
                        try:
                            self._model.load_state_dict(self._model_loader.load(self._state_dict_cache))
                            self._state_dict_cache = None
                            self._model_id = self._id_counter
                            self.model_stale = 1
                            break
                        except FileNotFoundError as e:
                            logging.warning(
                                "Model file has been deleted on node {}, maybe you can increase the ttl.".format(
                                    task.router.node_id
                                )
                            )
                            self._state_dict_cache = None
                            continue
                    else:
                        self._model.load_state_dict(self._state_dict_cache)
                        self._state_dict_cache = None
                        self._model_id = self._id_counter
                        print(f"node_id[{task.router.node_id}] model updated")
                        self.model_stale = 1
                        break
                else:
                    print(f"node_id[{task.router.node_id}] same id skip update")
                    self.model_stale += 1

    def _send_model(self, id: int):
        if self._model_loader:
            self._model_loader.save(self._send_callback)
        else:
            task.emit(self._event_name, {'id': id, 'model': self._model.state_dict(), 'time': time()}, only_remote=True)

    def _send_callback(self, storage: Storage):
        if task.running:
            task.emit(self._event_name, storage, only_remote=True)

    def __del__(self):
        if self._model_loader:
            self._model_loader.shutdown()
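To make the exchange cadence above easier to follow, here is a hedged, self-contained sketch under simplified assumptions: an in-process deque stands in for the task.emit/task.on event bus, the Storage/ModelLoader path and the staleness bookkeeping are omitted, and the helper names and periods are illustrative. The sender broadcasts {'id': ..., 'model': ..., 'time': ...} every period iterations; the receiver only accepts a message whose id matches its own period, is newer than the model it already holds, and is not older than delay_toleration seconds.

from collections import deque
from time import time

bus = deque()  # stands in for the task event bus


def send_step(model_state: dict, step: int, period: int = 3) -> None:
    # sender side: broadcast the state_dict every `period` iterations
    if step % period == 0:
        bus.append({'id': step, 'model': model_state, 'time': time()})


def recv_step(current_id: int, period: int = 1, delay_toleration: float = float('inf')):
    # receiver side: accept only sufficiently fresh messages with a newer id
    while bus:
        msg = bus.popleft()
        if msg['id'] % period == 0 and msg['id'] > current_id and time() - msg['time'] < delay_toleration:
            return msg['id'], msg['model']
    return current_id, None  # nothing new: keep the (possibly stale) model


for step in range(6):
    send_step({'w': step}, step)
print(recv_step(current_id=-1))  # (0, {'w': 0})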
2 changes: 1 addition & 1 deletion ding/framework/middleware/functional/__init__.py
@@ -11,5 +11,5 @@
from .explorer import eps_greedy_handler, eps_greedy_masker
from .advantage_estimator import gae_estimator, ppof_adv_estimator
from .enhancer import reward_estimator, her_data_enhancer, nstep_reward_enhancer

from .priority import priority_calculator
from .timer import epoch_timer
2 changes: 1 addition & 1 deletion ding/framework/middleware/functional/enhancer.py
@@ -73,7 +73,7 @@ def _fetch_and_enhance(ctx: "OnlineRLContext"):

def nstep_reward_enhancer(cfg: EasyDict) -> Callable:

    if task.router.is_active and not task.has_role(task.role.LEARNER):
    if task.router.is_active and (not task.has_role(task.role.LEARNER) and not task.has_role(task.role.COLLECTOR)):
        return task.void()

    def _enhance(ctx: "OnlineRLContext"):
25 changes: 25 additions & 0 deletions ding/framework/middleware/functional/priority.py
@@ -0,0 +1,25 @@
from typing import Callable
import torch
from ding.framework import task
from ding.framework import OnlineRLContext


def priority_calculator(func_for_priority_calculation: Callable) -> Callable:

Member: add unittest for this file

    """
    Overview:
        The middleware that calculates the priority of the collected data.
    Arguments:
        - func_for_priority_calculation (:obj:`Callable`): The function that calculates \
            the priority of the collected data.
    """

    if task.router.is_active and not task.has_role(task.role.COLLECTOR):
        return task.void()

    def _priority_calculator(ctx: "OnlineRLContext") -> None:

        priority = func_for_priority_calculation(ctx.trajectories)

Member: use the name priority_calculation_fn.
Collaborator Author: fixed

        for i in range(len(priority)):
            ctx.trajectories[i]['priority'] = torch.tensor(priority[i], dtype=torch.float32).unsqueeze(-1)

Member: don't transform it to tensor
Collaborator Author: fixed


    return _priority_calculator
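A hedged wiring sketch of how this middleware could sit in an Ape-X style collector task (cfg, policy, env and buffer_ are assumed to be prepared by the caller; collating the trajectory list before calling the calculate_priority method added to DQNPolicy in this PR is an assumption about the data layout, not something the middleware itself requires):

from ding.framework import task
from ding.framework.middleware import StepCollector
from ding.framework.middleware.functional import data_pusher, priority_calculator
from ding.utils.data import default_collate


def apex_style_collector(cfg, policy, env, buffer_):
    def priority_fn(trajectories):
        # assumption: batch the list of transitions, then reuse the policy's
        # n-step TD-error priority (see the DQNPolicy.calculate_priority diff below)
        batch = default_collate(trajectories)
        return policy.calculate_priority(batch)['priority']

    task.use(StepCollector(cfg, policy.collect_mode, env))
    task.use(priority_calculator(priority_fn))
    task.use(data_pusher(cfg, buffer_))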
2 changes: 1 addition & 1 deletion ding/framework/middleware/learner.py
@@ -43,7 +43,7 @@ def __init__(
        """
        self.cfg = cfg
        self._fetcher = task.wrap(offpolicy_data_fetcher(cfg, buffer_))
        self._trainer = task.wrap(trainer(cfg, policy))
        self._trainer = task.wrap(trainer(cfg, policy, log_freq=log_freq))
        if reward_model is not None:
            self._reward_estimator = task.wrap(reward_estimator(cfg, reward_model))
        else:
53 changes: 52 additions & 1 deletion ding/framework/middleware/tests/test_distributer.py
@@ -9,7 +9,7 @@
from ding.data.storage_loader import FileStorageLoader
from ding.framework import task
from ding.framework.context import OnlineRLContext
from ding.framework.middleware.distributer import ContextExchanger, ModelExchanger
from ding.framework.middleware.distributer import ContextExchanger, ModelExchanger, PeriodicalModelExchanger
from ding.framework.parallel import Parallel
from ding.utils.default_helper import set_pkg_seed
from os import path
@@ -221,3 +221,54 @@ def pred(ctx):
@pytest.mark.tmp
def test_model_exchanger_with_model_loader():
    Parallel.runner(n_parallel_workers=2, startup_interval=0)(model_exchanger_main_with_model_loader)


def periodical_model_exchanger_main():
    with task.start(ctx=OnlineRLContext()):
        set_pkg_seed(0, use_cuda=False)
        policy = MockPolicy()
        X = torch.rand(10)
        y = torch.rand(10)

        if task.router.node_id == 0:
            task.add_role(task.role.LEARNER)
            task.use(PeriodicalModelExchanger(policy._model, mode="send", period=3))
        else:
            task.add_role(task.role.COLLECTOR)
            task.use(PeriodicalModelExchanger(policy._model, mode="receive", period=1, stale_toleration=3))

        if task.has_role(task.role.LEARNER):

            def train(ctx):
                policy.train(X, y)
                sleep(0.3)

            task.use(train)
        else:
            y_pred1 = policy.predict(X)
            print("y_pred1: ", y_pred1)
            stale = 1

            def pred(ctx):
                nonlocal stale
                y_pred2 = policy.predict(X)
                print("y_pred2: ", y_pred2)
                stale += 1
                assert stale <= 3 or all(y_pred1 == y_pred2)
                if any(y_pred1 != y_pred2):
                    stale = 1

                sleep(0.3)

            task.use(pred)
        task.run(8)


@pytest.mark.tmp
def test_periodical_model_exchanger():
    Parallel.runner(n_parallel_workers=2, startup_interval=0)(periodical_model_exchanger_main)


if __name__ == "__main__":

Member: unittest doesn't need this entry
Collaborator Author: fixed

    # test_model_exchanger()
    test_periodical_model_exchanger()
48 changes: 48 additions & 0 deletions ding/policy/dqn.py
@@ -418,6 +418,54 @@ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:
        output = default_decollate(output)
        return {i: d for i, d in zip(data_id, output)}

    def calculate_priority(self, data: Dict[int, Any], update_target_model: bool = False) -> Dict[str, Any]:
        """
        Overview:
            Calculate priority for replay buffer.
        Arguments:
            - data (:obj:`Dict[str, Any]`): Dict type data, a batch of data for training.
        Returns:
            - priority (:obj:`Dict[str, Any]`): Dict type priority data, values are python scalar or a list of scalars.
        ArgumentsKeys:
            - necessary: ``obs``, ``action``, ``reward``, ``next_obs``, ``done``
            - optional: ``value_gamma``
        ReturnsKeys:
            - necessary: ``priority``
        """

        if update_target_model:
            self._target_model.load_state_dict(self._learn_model.state_dict())

        data = default_preprocess_learn(
            data,
            use_priority=False,
            use_priority_IS_weight=False,
            ignore_done=self._cfg.learn.ignore_done,
            use_nstep=True
        )
        if self._cuda:
            data = to_device(data, self._device)
        # ====================
        # Q-learning forward
        # ====================
        self._learn_model.eval()
        self._target_model.eval()
        with torch.no_grad():
            # Current q value (main model)
            q_value = self._learn_model.forward(data['obs'])['logit']
            # Target q value
            target_q_value = self._target_model.forward(data['next_obs'])['logit']
            # Max q value action (main model), i.e. Double DQN
            target_q_action = self._learn_model.forward(data['next_obs'])['action']
            data_n = q_nstep_td_data(
                q_value, target_q_value, data['action'], target_q_action, data['reward'], data['done'], data['weight']
            )
            value_gamma = data.get('value_gamma')
            loss, td_error_per_sample = q_nstep_td_error(
                data_n, self._gamma, nstep=self._nstep, value_gamma=value_gamma
            )
        return {'priority': td_error_per_sample.abs().tolist()}


@POLICY_REGISTRY.register('dqn_stdim')
class DQNSTDIMPolicy(DQNPolicy):
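For readers unfamiliar with the priority definition used by calculate_priority, here is a hedged stand-alone sketch of the same quantity (the absolute n-step TD error with a Double-DQN target) on plain tensors. The function name and toy numbers are illustrative; the method above delegates the real computation to ding.rl_utils.q_nstep_td_data and q_nstep_td_error.

import torch


def nstep_double_dqn_priority(q, q_next_target, q_next_online, action, reward, done, gamma, nstep):
    # q, q_next_*: [B, A]; action, done: [B]; reward: [nstep, B]
    q_sa = q.gather(1, action.unsqueeze(1)).squeeze(1)                 # Q(s_t, a_t), online net
    a_star = q_next_online.argmax(dim=1)                               # argmax action from the online net
    q_next = q_next_target.gather(1, a_star.unsqueeze(1)).squeeze(1)   # evaluated by the target net
    ret = torch.zeros_like(q_sa)
    for k in reversed(range(nstep)):                                   # discounted n-step return
        ret = reward[k] + gamma * ret
    target = ret + (gamma ** nstep) * q_next * (1 - done.float())
    return (q_sa - target).abs()                                       # used as the replay priority


# toy check: batch size 1, two actions, nstep = 2
priority = nstep_double_dqn_priority(
    q=torch.tensor([[1.0, 2.0]]),
    q_next_target=torch.tensor([[0.5, 1.5]]),
    q_next_online=torch.tensor([[0.2, 0.9]]),
    action=torch.tensor([1]),
    reward=torch.tensor([[1.0], [0.0]]),
    done=torch.tensor([0]),
    gamma=0.99,
    nstep=2,
)
print(priority)  # approximately tensor([0.4701]), i.e. |2.0 - (1.0 + 0.99 ** 2 * 1.5)|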