Skip to content

Commit 680bb52

Browse files
ysqyangysqyangyaqiu
authored
Rl v3 load save (#463)
* added load/save feature * fixed some bugs * reverted unwanted changes * lint * fixed PR comments Co-authored-by: ysqyang <v-yangqi@microsoft.com> Co-authored-by: yaqiu <v-yaqiu@microsoft.com>
1 parent 4faa8f1 commit 680bb52

File tree

24 files changed

+288
-152
lines changed

24 files changed

+288
-152
lines changed

examples/rl/cim/algorithms/ac.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,13 +60,13 @@ def apply_gradients(self, grad: dict) -> None:
6060
param.grad = grad[name]
6161
self._optim.step()
6262

63-
def get_net_state(self) -> dict:
63+
def get_state(self) -> dict:
6464
return {
6565
"network": self.state_dict(),
6666
"optim": self._optim.state_dict()
6767
}
6868

69-
def set_net_state(self, net_state: dict) -> None:
69+
def set_state(self, net_state: dict) -> None:
7070
self.load_state_dict(net_state["network"])
7171
self._optim.load_state_dict(net_state["optim"])
7272

@@ -95,13 +95,13 @@ def apply_gradients(self, grad: dict) -> None:
9595
param.grad = grad[name]
9696
self._optim.step()
9797

98-
def get_net_state(self) -> dict:
98+
def get_state(self) -> dict:
9999
return {
100100
"network": self.state_dict(),
101101
"optim": self._optim.state_dict()
102102
}
103103

104-
def set_net_state(self, net_state: dict) -> None:
104+
def set_state(self, net_state: dict) -> None:
105105
self.load_state_dict(net_state["network"])
106106
self._optim.load_state_dict(net_state["optim"])
107107

examples/rl/cim/algorithms/dqn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,10 @@ def apply_gradients(self, grad: dict) -> None:
4747
param.grad = grad[name]
4848
self._optim.step()
4949

50-
def get_net_state(self) -> object:
50+
def get_state(self) -> object:
5151
return {"network": self.state_dict(), "optim": self._optim.state_dict()}
5252

53-
def set_net_state(self, net_state: object) -> None:
53+
def set_state(self, net_state: object) -> None:
5454
assert isinstance(net_state, dict)
5555
self.load_state_dict(net_state["network"])
5656
self._optim.load_state_dict(net_state["optim"])

examples/rl/cim/algorithms/maddpg.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,13 +62,13 @@ def apply_gradients(self, grad: dict) -> None:
6262
param.grad = grad[name]
6363
self._optim.step()
6464

65-
def get_net_state(self) -> dict:
65+
def get_state(self) -> dict:
6666
return {
6767
"network": self.state_dict(),
6868
"optim": self._optim.state_dict()
6969
}
7070

71-
def set_net_state(self, net_state: dict) -> None:
71+
def set_state(self, net_state: dict) -> None:
7272
self.load_state_dict(net_state["network"])
7373
self._optim.load_state_dict(net_state["optim"])
7474

@@ -97,13 +97,13 @@ def apply_gradients(self, grad: dict) -> None:
9797
param.grad = grad[name]
9898
self._optim.step()
9999

100-
def get_net_state(self) -> dict:
100+
def get_state(self) -> dict:
101101
return {
102102
"network": self.state_dict(),
103103
"optim": self._optim.state_dict()
104104
}
105105

106-
def set_net_state(self, net_state: dict) -> None:
106+
def set_state(self, net_state: dict) -> None:
107107
self.load_state_dict(net_state["network"])
108108
self._optim.load_state_dict(net_state["optim"])
109109

examples/rl/vm_scheduling/algorithms/ac.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,13 +64,13 @@ def apply_gradients(self, grad: dict) -> None:
6464
param.grad = grad[name]
6565
self._optim.step()
6666

67-
def get_net_state(self) -> dict:
67+
def get_state(self) -> dict:
6868
return {
6969
"network": self.state_dict(),
7070
"optim": self._optim.state_dict()
7171
}
7272

73-
def set_net_state(self, net_state: dict) -> None:
73+
def set_state(self, net_state: dict) -> None:
7474
self.load_state_dict(net_state["network"])
7575
self._optim.load_state_dict(net_state["optim"])
7676

@@ -102,13 +102,13 @@ def apply_gradients(self, grad: dict) -> None:
102102
param.grad = grad[name]
103103
self._optim.step()
104104

105-
def get_net_state(self) -> dict:
105+
def get_state(self) -> dict:
106106
return {
107107
"network": self.state_dict(),
108108
"optim": self._optim.state_dict()
109109
}
110110

111-
def set_net_state(self, net_state: dict) -> None:
111+
def set_state(self, net_state: dict) -> None:
112112
self.load_state_dict(net_state["network"])
113113
self._optim.load_state_dict(net_state["optim"])
114114

examples/rl/vm_scheduling/algorithms/dqn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,10 @@ def apply_gradients(self, grad: dict) -> None:
5555
param.grad = grad[name]
5656
self._optim.step()
5757

58-
def get_net_state(self) -> object:
58+
def get_state(self) -> object:
5959
return {"network": self.state_dict(), "optim": self._optim.state_dict()}
6060

61-
def set_net_state(self, net_state: object) -> None:
61+
def set_state(self, net_state: object) -> None:
6262
assert isinstance(net_state, dict)
6363
self.load_state_dict(net_state["network"])
6464
self._optim.load_state_dict(net_state["optim"])

examples/rl/vm_scheduling/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,4 +41,4 @@
4141

4242
test_seed = 1024
4343

44-
algorithm = "dqn" # "dqn" or "ac"
44+
algorithm = "ac" # "dqn" or "ac"

maro/cli/local/utils.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -158,8 +158,6 @@ def get_docker_compose_yml(config: dict, context: str, dockerfile_path: str, ima
158158
}
159159
for component, env in config_parser.get_rl_component_env_vars(config, containerized=True).items()
160160
}
161-
# if config["mode"] != "single":
162-
# manifest["services"]["redis"] = {"image": "redis", "container_name": redis_host}
163161

164162
return manifest
165163

maro/rl/distributed/abs_worker.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,20 +10,21 @@
1010
from zmq.eventloop.zmqstream import ZMQStream
1111

1212
from maro.rl.utils.common import string_to_bytes
13-
from maro.utils import Logger
13+
from maro.utils import DummyLogger, Logger
1414

1515

1616
class AbsWorker(object):
1717
def __init__(
1818
self,
1919
idx: int,
2020
router_host: str,
21-
router_port: int = 10001
21+
router_port: int = 10001,
22+
logger: Logger = None
2223
) -> None:
2324
super(AbsWorker, self).__init__()
2425

2526
self._id = f"worker.{idx}"
26-
self._logger = Logger(self._id)
27+
self._logger = DummyLogger() if logger is None else logger
2728

2829
# ZMQ sockets and streams
2930
self._context = Context.instance()

maro/rl/model/abs_net.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,14 +48,14 @@ def _forward_unimplemented(self, *input: Any) -> None: # TODO
4848
pass
4949

5050
@abstractmethod
51-
def get_net_state(self) -> object:
51+
def get_state(self) -> object:
5252
"""
5353
Get the net's state.
5454
"""
5555
raise NotImplementedError
5656

5757
@abstractmethod
58-
def set_net_state(self, net_state: object) -> None:
58+
def set_state(self, net_state: object) -> None:
5959
"""
6060
Set the net's state.
6161
"""

maro/rl/policy/continuous_rl_policy.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,10 +100,10 @@ def train(self) -> None:
100100
self._policy_net.train()
101101

102102
def get_state(self) -> object:
103-
return self._policy_net.get_net_state()
103+
return self._policy_net.get_state()
104104

105105
def set_state(self, policy_state: object) -> None:
106-
self._policy_net.set_net_state(policy_state)
106+
self._policy_net.set_state(policy_state)
107107

108108
def soft_update(self, other_policy: RLPolicy, tau: float) -> None:
109109
assert isinstance(other_policy, ContinuousRLPolicy)

0 commit comments

Comments
 (0)