-
Notifications
You must be signed in to change notification settings - Fork 370
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
feature(zjow): add middleware for ape-x structure pipeline #696
Changes from 4 commits
88cfb0c
c22e8ee
455134d
980f18e
b921ea2
fdb8030
0501145
9e31feb
be3e60c
213e41e
37ec816
32a27db
f6dda29
e1a7f84
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -54,7 +54,10 @@ def __init__( | |
|
||
def push(self, chain: Callable, data: Any, meta: Optional[dict] = None, *args, **kwargs) -> BufferedData: | ||
if meta is None: | ||
meta = {'priority': self.max_priority} | ||
if 'priority' in data: | ||
meta = {'priority': data['priority'].item()} | ||
else: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. add unittest for this new if-else branch There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. fixed |
||
meta = {'priority': self.max_priority} | ||
else: | ||
if 'priority' not in meta: | ||
meta['priority'] = self.max_priority | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
import numpy as np | ||
from time import sleep, time | ||
from dataclasses import fields | ||
from typing import TYPE_CHECKING, List, Dict, Any, Optional, Union | ||
|
@@ -287,3 +288,140 @@ def _send_callback(self, storage: Storage): | |
def __del__(self): | ||
if self._model_loader: | ||
self._model_loader.shutdown() | ||
|
||
|
||
class PeriodicalModelExchanger:
    """
    Overview:
        Middleware that exchanges a model between processes periodically.
        A "send" instance emits its model state dict (tagged with an id and a
        timestamp) every ``period`` iterations; a "receive" instance caches the
        incoming state dict and loads it into the local model, tolerating a
        bounded amount of staleness and transmission delay.
    """

    def __init__(
            self,
            model: "Module",
            mode: str,
            period: int = 1,
            delay_toleration: float = np.inf,
            stale_toleration: int = 1,
            event_name: str = "model_exchanger",
            model_loader: Optional[ModelLoader] = None
    ) -> None:
        """
        Overview:
            Exchange model between processes periodically.
        Arguments:
            - model (:obj:`torch.nn.Module`): Pytorch module.
            - mode (:obj:`str`): "send" or "receive".
            - period (:obj:`int`): Period of model exchange.
            - delay_toleration (:obj:`float`): Delay toleration of model exchange, \
                i.e. the max age (in seconds) of a cached model that may still be loaded.
            - stale_toleration (:obj:`int`): Stale toleration of model exchange, \
                i.e. how many iterations the receiver may keep using an old model.
            - event_name (:obj:`str`): Event name of model exchange.
            - model_loader (:obj:`ModelLoader`): Encode model in subprocess.
        """
        self._model = model
        self._model_loader = model_loader
        self._event_name = event_name
        self._period = period
        self.mode = mode
        if self.mode == "receive":
            # Start below any real message id so the first received model always wins.
            self._id_counter = -1
            self._model_id = -1
        else:
            self._id_counter = 0
        self.stale_toleration = stale_toleration
        self.model_stale = stale_toleration
        self.delay_toleration = delay_toleration
        self._state_dict_cache: Optional[Union[object, Storage]] = None
        # Fix: ``_time`` used to be assigned only in ``_cache_state_dict``, so
        # ``_update_model`` raised AttributeError whenever it ran before the
        # first model message arrived. Treat construction time as the last
        # "received" time instead.
        self._time = time()

        if self.mode == "receive":
            task.on(self._event_name, self._cache_state_dict)
        if model_loader:
            task.once("finish", lambda _: model_loader.shutdown())

    def _cache_state_dict(self, msg: Dict[str, Any]):
        # msg: Dict {'id':id,'model':state_dict: Union[object, Storage]}
        logging.debug("node_id[%s] get model msg", task.router.node_id)
        if msg['id'] % self._period == 0:
            # Only cache every ``period``-th model version; others are dropped.
            self._state_dict_cache = msg['model']
            self._id_counter = msg['id']
            self._time = msg['time']
        else:
            logging.debug("node_id[%s] skip save cache", task.router.node_id)

    def __new__(cls, *args, **kwargs):
        # No-op override kept for interface compatibility with the framework.
        return super(PeriodicalModelExchanger, cls).__new__(cls)

    def __call__(self, ctx: "Context") -> Any:
        # NOTE: because of the ``yield`` below this is a generator function; the
        # framework is expected to drive it so that the "send" branch runs its
        # emit step after the rest of the pipeline iteration.
        if self._model_loader:
            self._model_loader.start()

        if self.mode == "receive":
            logging.debug("node_id[%s] try receive model", task.router.node_id)
            if ctx.total_step != 0:  # Skip first iteration
                self._update_model()
            else:
                logging.debug("node_id[%s] skip first iteration", task.router.node_id)
        elif self.mode == "send":
            yield
            logging.debug("node_id[%s] try send model", task.router.node_id)
            if self._id_counter % self._period == 0:
                self._send_model(id=self._id_counter)
                logging.debug("node_id[%s] model send [%s]", task.router.node_id, self._id_counter)
            self._id_counter += 1
        else:
            raise NotImplementedError

    def _update_model(self):
        """
        Overview:
            Block until a fresh state dict has been loaded into the local model,
            or until the staleness/delay tolerations (or a 60s timeout) allow
            continuing with the current one.
        """
        start = time()
        while True:
            if task.finish:
                return
            if time() - start > 60:
                logging.warning("Timeout when waiting for new model! Node id: {}".format(task.router.node_id))
                self.model_stale += 1
                break
            if self._state_dict_cache is None:
                # No new model cached yet: reuse the stale one if still within
                # budget, otherwise keep polling for an incoming message.
                if self.model_stale < self.stale_toleration and time() - self._time < self.delay_toleration:
                    self.model_stale += 1
                    break
                else:
                    sleep(0.01)
            else:
                # Load only strictly newer models that are not too old.
                if self._id_counter > self._model_id and time() - self._time < self.delay_toleration:
                    logging.debug("node_id[%s] begin update", task.router.node_id)
                    if isinstance(self._state_dict_cache, Storage) and self._model_loader is not None:
                        try:
                            self._model.load_state_dict(self._model_loader.load(self._state_dict_cache))
                            self._state_dict_cache = None
                            self._model_id = self._id_counter
                            self.model_stale = 1
                            break
                        except FileNotFoundError:
                            # The storage file may be garbage-collected on the
                            # sender side before we manage to load it.
                            logging.warning(
                                "Model file has been deleted on node {}, maybe you can increase the ttl.".format(
                                    task.router.node_id
                                )
                            )
                            self._state_dict_cache = None
                            continue
                    else:
                        self._model.load_state_dict(self._state_dict_cache)
                        self._state_dict_cache = None
                        self._model_id = self._id_counter
                        logging.debug("node_id[%s] model updated", task.router.node_id)
                        self.model_stale = 1
                        break
                else:
                    logging.debug("node_id[%s] same id skip update", task.router.node_id)
                    self.model_stale += 1

    def _send_model(self, id: int):
        # NOTE(review): when a model loader is used, ``_send_callback`` emits a
        # raw Storage object, while ``_cache_state_dict`` expects a dict with
        # 'id'/'model'/'time' keys — confirm the loader path against receivers.
        if self._model_loader:
            self._model_loader.save(self._send_callback)
        else:
            task.emit(self._event_name, {'id': id, 'model': self._model.state_dict(), 'time': time()}, only_remote=True)

    def _send_callback(self, storage: Storage):
        # Emit the encoded model only while the task is still running.
        if task.running:
            task.emit(self._event_name, storage, only_remote=True)

    def __del__(self):
        if self._model_loader:
            self._model_loader.shutdown()
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
from typing import Callable | ||
import torch | ||
from ding.framework import task | ||
from ding.framework import OnlineRLContext | ||
|
||
|
||
def priority_calculator(func_for_priority_calculation: Callable) -> Callable:
    """
    Overview:
        Build a collector-side middleware that computes a priority for each
        collected trajectory and attaches it to the trajectory under the
        'priority' key as a float32 tensor of shape ``(1,)``.
    Arguments:
        - func_for_priority_calculation (:obj:`Callable`): Function mapping \
            ``ctx.trajectories`` to a sequence of per-trajectory priorities.
    """
    # Only collector processes need this middleware in distributed mode.
    if task.router.is_active and not task.has_role(task.role.COLLECTOR):
        return task.void()

    def _priority_calculator(ctx: "OnlineRLContext") -> None:
        priorities = func_for_priority_calculation(ctx.trajectories)
        for idx, value in enumerate(priorities):
            ctx.trajectories[idx]['priority'] = torch.tensor(value, dtype=torch.float32).unsqueeze(-1)

    return _priority_calculator
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,7 +9,7 @@ | |
from ding.data.storage_loader import FileStorageLoader | ||
from ding.framework import task | ||
from ding.framework.context import OnlineRLContext | ||
from ding.framework.middleware.distributer import ContextExchanger, ModelExchanger | ||
from ding.framework.middleware.distributer import ContextExchanger, ModelExchanger, PeriodicalModelExchanger | ||
from ding.framework.parallel import Parallel | ||
from ding.utils.default_helper import set_pkg_seed | ||
from os import path | ||
|
@@ -221,3 +221,54 @@ def pred(ctx): | |
@pytest.mark.tmp | ||
def test_model_exchanger_with_model_loader(): | ||
Parallel.runner(n_parallel_workers=2, startup_interval=0)(model_exchanger_main_with_model_loader) | ||
|
||
|
||
def periodical_model_exchanger_main():
    """
    Overview:
        Two-process scenario for ``PeriodicalModelExchanger``: node 0 acts as a
        learner that trains a mock policy and sends its model every 3rd version;
        the other node acts as a collector that receives every message with a
        staleness budget of 3 iterations and checks its predictions via assertions.
    """
    with task.start(ctx=OnlineRLContext()):
        set_pkg_seed(0, use_cuda=False)
        policy = MockPolicy()
        X = torch.rand(10)
        y = torch.rand(10)

        if task.router.node_id == 0:
            task.add_role(task.role.LEARNER)
            # Sender: only every 3rd model version is emitted.
            task.use(PeriodicalModelExchanger(policy._model, mode="send", period=3))
        else:
            task.add_role(task.role.COLLECTOR)
            # Receiver: accepts every message, tolerates up to 3 stale steps.
            task.use(PeriodicalModelExchanger(policy._model, mode="receive", period=1, stale_toleration=3))

        if task.has_role(task.role.LEARNER):

            def train(ctx):
                # Mutate the model each iteration so receivers can observe changes.
                policy.train(X, y)
                sleep(0.3)

            task.use(train)
        else:
            y_pred1 = policy.predict(X)
            print("y_pred1: ", y_pred1)
            # Counts iterations since the last observed model change.
            stale = 1

            def pred(ctx):
                nonlocal stale
                y_pred2 = policy.predict(X)
                print("y_pred2: ", y_pred2)
                stale += 1
                # NOTE(review): once the staleness budget (3) is exceeded, the
                # predictions are asserted unchanged — confirm this matches the
                # intended exchange semantics.
                assert stale <= 3 or all(y_pred1 == y_pred2)
                if any(y_pred1 != y_pred2):
                    stale = 1
                sleep(0.3)

            task.use(pred)
        task.run(8)
|
||
|
||
@pytest.mark.tmp
def test_periodical_model_exchanger():
    # Run the send/receive scenario across two parallel workers.
    Parallel.runner(n_parallel_workers=2, startup_interval=0)(periodical_model_exchanger_main)
|
||
|
||
if __name__ == "__main__":
    # Manual entry point for running the exchanger scenario outside pytest.
    # (Removed the commented-out ``test_model_exchanger()`` dead code.)
    test_periodical_model_exchanger()
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do item operation in original priority producer
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed