update dppo
quantumiracle committed May 26, 2021
1 parent d08778b commit 0fa8199
Showing 7 changed files with 355 additions and 199 deletions.
89 changes: 63 additions & 26 deletions dppo_clip_distributed/dppo_global_manager.py
@@ -1,53 +1,90 @@
from rlzoo.common.policy_networks import StochasticPolicyNetwork
from rlzoo.common.value_networks import ValueNetwork
import numpy as np
from rlzoo.common.utils import *
import pickle
import queue


def write_log(text: str):
pass
# print('global manager: '+text)
# with open('global_manager_log.txt', 'a') as f:
# f.write(str(text) + '\n')


class DPPOGlobalManager:
def __init__(self, net_builder, opt_builder, param_pipe_list, name='DPPO_CLIP'):
networks = net_builder()
optimizers_list = opt_builder()
def __init__(self, net_builder, opt_builder, name='DPPO_CLIP'):
self.net_builder, self.opt_builder = net_builder, opt_builder
self.name = name
self.critic, self.actor = None, None
self.critic_opt, self.actor_opt = None, None

def init_components(self):
networks = self.net_builder()
optimizers_list = self.opt_builder()
assert len(networks) == 2
assert len(optimizers_list) == 2
self.critic, self.actor = networks
assert isinstance(self.critic, ValueNetwork)
assert isinstance(self.actor, StochasticPolicyNetwork)
self.critic_opt, self.actor_opt = optimizers_list
self.param_pipe_list = param_pipe_list
self.name = name

def run(self, traj_queue, grad_queue, should_stop, should_update, barrier,
def run(self, traj_queue, grad_queue, should_stop, should_update, barrier, param_pipe_list,
max_update_num=1000, update_interval=100, save_interval=10, env_name='CartPole-v0'):

self.init_components()

if should_update.is_set():
write_log('syn model')
self.send_param(param_pipe_list)
write_log('wait for barrier')
barrier.wait()
should_update.clear()

update_cnt = 0
batch_a_grad, batch_c_grad = [], []
while update_cnt < max_update_num:
batch_a_grad, batch_c_grad = [], []
for _ in range(update_interval):
a_grad, c_grad = grad_queue.get()
# print('\rupdate cnt {}, traj_que {}, grad_que {}'.format(
# update_cnt, traj_queue.qsize(), grad_queue[0].qsize()), end='')
print('update cnt {}, traj_que {}, grad_que {}'.format(
update_cnt, traj_queue.qsize(), grad_queue[0].qsize()))
try:
a_grad, c_grad = [q.get(timeout=1) for q in grad_queue]
batch_a_grad.append(a_grad)
batch_c_grad.append(c_grad)
write_log('got grad')
except queue.Empty:
continue

# update
should_update.set()
self.update_model(batch_a_grad, batch_c_grad)
self.send_param()
if len(batch_a_grad) > update_interval and len(batch_c_grad) > update_interval:
# write_log('ready to update')
# update
should_update.set()
write_log('update model')
self.update_model(batch_a_grad, batch_c_grad)
write_log('send_param')
self.send_param(param_pipe_list)

traj_queue.empty()
for q in grad_queue: q.empty()

barrier.wait()
should_update.clear()
write_log('empty queue')
traj_queue.empty()
for q in grad_queue:
q.empty()
batch_a_grad.clear()
batch_c_grad.clear()

update_cnt += 1
if update_cnt // save_interval == 0:
self.save_model(env_name)
write_log('wait for barrier')
barrier.wait()
should_update.clear()
barrier.reset()
update_cnt += 1
if update_cnt // save_interval == 0:
self.save_model(env_name)
should_stop.set()

def send_param(self):
def send_param(self, param_pipe_list):
params = self.critic.trainable_weights + self.actor.trainable_weights
for pipe_connection in self.param_pipe_list:
pipe_connection.send(params)
params = [p.numpy() for p in params]
for i, pipe_connection in enumerate(param_pipe_list):
pipe_connection.put(params)

def update_model(self, batch_a_grad, batch_c_grad):
a_grad = np.mean(batch_a_grad, axis=0)
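
A minimal wiring sketch (not part of this commit) showing how the reworked DPPOGlobalManager.run signature could be driven. The builders, learning rates, and barrier party count are illustrative assumptions; note that despite the name, the entries of param_pipe_list are consumed with .put() in send_param, so they are queues rather than pipe connections.

import multiprocessing as mp
import tensorflow as tf
from gym.spaces import Box, Discrete
from rlzoo.common.value_networks import ValueNetwork
from rlzoo.common.policy_networks import StochasticPolicyNetwork
from dppo_clip_distributed.dppo_global_manager import DPPOGlobalManager

def build_networks():
    # returns (critic, actor), shaped for CartPole-v0 as in the demo code removed below
    hidden = [64, 64]
    critic = ValueNetwork(Box(0, 1, (4,)), hidden, name='DPPO_CLIP_value')
    actor = StochasticPolicyNetwork(Box(0, 1, (4,)), Discrete(2), hidden,
                                    trainable=True, name='DPPO_CLIP_policy')
    return critic, actor

def build_optimizers():
    # returns (critic_opt, actor_opt); the learning rates are assumptions
    return tf.optimizers.Adam(1e-4), tf.optimizers.Adam(1e-4)

if __name__ == '__main__':
    n_infer_servers = 1
    traj_queue = mp.Queue(maxsize=10000)
    grad_queue = (mp.Queue(maxsize=10000), mp.Queue(maxsize=10000))  # (actor grads, critic grads)
    param_pipe_list = [mp.Queue() for _ in range(n_infer_servers)]   # one per parameter consumer
    should_stop, should_update = mp.Event(), mp.Event()
    barrier = mp.Barrier(n_infer_servers + 1)  # assumed parties: consumers + global manager

    manager = DPPOGlobalManager(build_networks, build_optimizers)
    mp.Process(target=manager.run,
               args=(traj_queue, grad_queue, should_stop, should_update,
                     barrier, param_pipe_list),
               kwargs=dict(max_update_num=1000, update_interval=100,
                           save_interval=10, env_name='CartPole-v0')).start()

Deferring network construction to init_components(), as this commit does, keeps the manager picklable: it can be built in the parent and only instantiates its TensorFlow objects inside the child process that runs it.
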
136 changes: 56 additions & 80 deletions dppo_clip_distributed/dppo_infer_server.py
@@ -5,21 +5,31 @@
import pickle


def write_log(text: str):
pass
# print('infer server: '+text)
# with open('infer_server_log.txt', 'a') as f:
# f.write(str(text) + '\n')


class DPPOInferServer:
def __init__(self, net_builder, net_param_pipe, n_step=1000, gamma=0.9):
networks = net_builder()
assert len(networks) == 2
self.critic, self.actor = networks
assert isinstance(self.critic, ValueNetwork)
assert isinstance(self.actor, StochasticPolicyNetwork)
def __init__(self, net_builder, n_step=100, gamma=0.9):
self.critic, self.actor = None, None
self.net_builder = net_builder
self.state_buffer = []
self.action_buffer = []
self.reward_buffer = []
self.done_buffer = []
self.logp_buffer = []
self.gamma = gamma
self.n_step = n_step
self.net_param_pipe = net_param_pipe

def init_components(self):
networks = self.net_builder()
assert len(networks) == 2
self.critic, self.actor = networks
assert isinstance(self.critic, ValueNetwork)
assert isinstance(self.actor, StochasticPolicyNetwork)

def _cal_adv(self):
dc_r = self._cal_discounted_r()
@@ -42,19 +52,24 @@ def _cal_discounted_r(self):
return discounted_r

def _get_traj(self):
traj = []
traj_list = []
for element in [self.state_buffer, self.action_buffer, self.reward_buffer, self.done_buffer, self._cal_adv(),
self.logp_buffer]:
axes = list(range(len(np.shape(element))))
axes[0], axes[1] = 1, 0
traj.append(np.transpose(element, axes))
traj_list.append(np.transpose(element, axes))
if type(element) == list:
element.clear()
return traj
traj_list = list(zip(*traj_list))
return traj_list

def inference_service(self, batch_s):
print(batch_s)
write_log('get action')
# write_log(self.actor.trainable_weights)
# write_log(batch_s)
batch_s = np.array(batch_s)
batch_a = self.actor(batch_s).numpy()
write_log('get log p')
batch_log_p = self.actor.policy_dist.get_param()
return batch_a, batch_log_p

@@ -66,32 +81,50 @@ def collect_data(self, s, a, r, d, log_p):
self.logp_buffer.append(log_p)

def upload_data(self, que):
traj_data = self._get_traj()
que.put(traj_data)
print('\rupdated, queue size: {}, current data shape: {}'.format(que.qsize(), [np.shape(i) for i in traj_data]))

def run(self, pipe_list, traj_queue, should_stop, should_update, barrier, ):
states, rewards, dones, infos = zip(*[remote.recv() for remote in pipe_list])
traj_list = self._get_traj()
traj = []
for traj in traj_list:
que.put(traj)
# print('\rinfer server: updated, queue size: {}, current data shape: {}'.format(que.qsize(), [np.shape(i) for i in traj]))
write_log('\rupdated, queue size: {}, current data shape: {}'.format(que.qsize(), [np.shape(i) for i in traj]))

def run(self, pipe_list, traj_queue, should_stop, should_update, barrier, param_que):
self.init_components()
data = []
for i, remote_connect in enumerate(pipe_list):
write_log('recv {}'.format(i))
data.append(remote_connect.recv())
write_log('first recved')
states, rewards, dones, infos = zip(*data)
# states, rewards, dones, infos = zip(*[remote.recv() for remote in pipe_list])
states, rewards, dones, infos = np.stack(states), np.stack(rewards), np.stack(dones), np.stack(infos)

write_log('before while')
while not should_stop.is_set():
write_log('into while')
if should_update.is_set():
self.update_model()
write_log('update_model')
self.update_model(param_que)
write_log('barrier.wait')
barrier.wait()
write_log('befor infer')
actions, log_ps = self.inference_service(states)
write_log('before send')
for (remote, a) in zip(pipe_list, actions):
remote.send(a)

write_log('recv from pipe')
states, rewards, dones, infos = zip(*[remote.recv() for remote in pipe_list])
states, rewards, dones, infos = np.stack(states), np.stack(rewards), np.stack(dones), np.stack(infos)
self.collect_data(states, actions, rewards, dones, log_ps)

print('\rsampling, {}'.format(len(self.state_buffer)), end='')
write_log('sampling, {}'.format(len(self.state_buffer)))
# print('\rsampling, {}'.format(len(self.state_buffer)), end='')
if len(self.state_buffer) >= self.n_step:
self.upload_data(traj_queue)

def update_model(self):
params = self.net_param_pipe.recv()
def update_model(self, param_que):
write_log('get from param_que')
params = param_que.get()
write_log('assign param')
for i, j in zip(self.critic.trainable_weights + self.actor.trainable_weights, params):
i.assign(j)
self.state_buffer.clear()
@@ -100,60 +133,3 @@ def update_model(self):
self.done_buffer.clear()
self.logp_buffer.clear()


if __name__ == '__main__':
import multiprocessing as mp

from rlzoo.common.env_wrappers import build_env
from dppo_clip_distributed.dppo_sampler import DPPOSampler
import copy, json, pickle
from gym.spaces.box import Box
from gym.spaces.discrete import Discrete
import cloudpickle

should_stop_event = mp.Event()
should_stop_event.clear()

# build_sampler
nenv = 3


def build_func():
return build_env('CartPole-v0', 'classic_control')


pipe_list = []
for _ in range(nenv):
sampler = DPPOSampler(build_func)
remote_a, remote_b = mp.Pipe()
p = mp.Process(target=sampler.run, args=(remote_a, should_stop_event))
p.daemon = True # todo: dependency between daemon processes
p.start()
pipe_list.append(remote_b)

traj_queue = mp.Queue(maxsize=10000)
grad_queue = mp.Queue(maxsize=10000), mp.Queue(maxsize=10000),
should_update_event = mp.Event()
should_update_event.clear()
barrier = mp.Barrier(1) # sampler + updater

""" build networks for the algorithm """
name = 'DPPO_CLIP'
hidden_dim = 64
num_hidden_layer = 2
critic = ValueNetwork(Box(0, 1, (4,)), [hidden_dim] * num_hidden_layer, name=name + '_value')
actor = StochasticPolicyNetwork(Box(0, 1, (4,)), Discrete(2),
[hidden_dim] * num_hidden_layer,
trainable=True,
name=name + '_policy')

actor = copy.deepcopy(actor)
global_nets = critic, actor

global_nets = cloudpickle.dumps(global_nets)
# p = mp.Process(
# target=DPPOInferServer(global_nets).run,
# args=(traj_queue, should_stop_event, should_update_event, barrier)
# )
# p.start()
DPPOInferServer(global_nets).run(pipe_list, traj_queue, should_stop_event, should_update_event, barrier)
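
Similarly, a minimal sketch (not part of this commit) of launching the reworked DPPOInferServer, whose run() now takes a parameter queue (param_que) instead of the removed net_param_pipe argument. The sampler wiring mirrors the __main__ block deleted above; build_cartpole, the network shapes, and the barrier size are assumptions.

import multiprocessing as mp
from gym.spaces import Box, Discrete
from rlzoo.common.env_wrappers import build_env
from rlzoo.common.value_networks import ValueNetwork
from rlzoo.common.policy_networks import StochasticPolicyNetwork
from dppo_clip_distributed.dppo_sampler import DPPOSampler
from dppo_clip_distributed.dppo_infer_server import DPPOInferServer

def build_cartpole():
    return build_env('CartPole-v0', 'classic_control')

def build_networks():
    # returns (critic, actor), mirroring the shapes in the removed demo block
    hidden = [64, 64]
    critic = ValueNetwork(Box(0, 1, (4,)), hidden, name='DPPO_CLIP_value')
    actor = StochasticPolicyNetwork(Box(0, 1, (4,)), Discrete(2), hidden,
                                    trainable=True, name='DPPO_CLIP_policy')
    return critic, actor

if __name__ == '__main__':
    n_env = 3
    should_stop, should_update = mp.Event(), mp.Event()
    pipe_list = []
    for _ in range(n_env):
        remote_a, remote_b = mp.Pipe()
        p = mp.Process(target=DPPOSampler(build_cartpole).run, args=(remote_a, should_stop))
        p.daemon = True
        p.start()
        pipe_list.append(remote_b)

    traj_queue = mp.Queue(maxsize=10000)
    param_que = mp.Queue()   # filled by DPPOGlobalManager.send_param
    barrier = mp.Barrier(1)  # as in the removed demo; in a full run the global manager is another party
    server = DPPOInferServer(build_networks, n_step=100, gamma=0.9)
    server.run(pipe_list, traj_queue, should_stop, should_update, barrier, param_que)
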