Skip to content

Commit a759eb2

Browse files
committed
Moved the actor and critic networks into the models directory
1 parent fc7b9d3 commit a759eb2

File tree

2 files changed

+58
-39
lines changed

2 files changed

+58
-39
lines changed

python/ray/rllib/ddpg/models.py

Lines changed: 9 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44

55
import numpy as np
66
import tensorflow as tf
7-
import tensorflow.contrib.slim as slim
87

98
from ray.experimental.tfutils import TensorFlowVariables
9+
from ray.rllib.models.ddpgnet import DDPGActor, DDPGCritic
1010
from ray.rllib.ddpg.random_process import OrnsteinUhlenbeckProcess
1111

1212

@@ -202,50 +202,20 @@ def _setup_critic_loss(self, action_space):
202202
def _setup_critic_network(self, obs_space, ac_space):
203203
"""Sets up Q network."""
204204
with tf.variable_scope("critic", reuse=tf.AUTO_REUSE):
205-
self.critic_eval = self._create_critic_network(
206-
self.obs, self.act)
205+
self.critic_network = DDPGCritic((self.obs, self.act), 1, {})
206+
self.critic_eval = self.critic_network.outputs
207207

208208
with tf.variable_scope("critic", reuse=True):
209-
tf.get_variable_scope().reuse_variables()
210-
self.cn_for_loss = self._create_critic_network(
211-
self.obs, self.output_action)
212-
213-
def _create_critic_network(self, obs, action):
214-
"""Network for critic."""
215-
w_normal = tf.truncated_normal_initializer()
216-
w_init = tf.random_uniform_initializer(minval=-0.0003, maxval=0.0003)
217-
net = slim.fully_connected(
218-
obs, 400, activation_fn=tf.nn.relu, weights_initializer=w_normal)
219-
t1 = slim.fully_connected(
220-
net, 300, activation_fn=None, biases_initializer=None,
221-
weights_initializer=w_normal)
222-
t2 = slim.fully_connected(
223-
action, 300, activation_fn=None, weights_initializer=w_normal)
224-
net = tf.nn.relu(tf.add(t1, t2))
225-
226-
out = slim.fully_connected(
227-
net, 1, activation_fn=None, weights_initializer=w_init)
228-
return out
209+
self.cn_for_loss = DDPGCritic(
210+
(self.obs, self.output_action), 1, {}).outputs
229211

230212
def _setup_actor_network(self, obs_space, ac_space):
231213
"""Sets up actor network."""
232214
with tf.variable_scope("actor", reuse=tf.AUTO_REUSE):
233-
self.output_action = self._create_actor_network(self.obs)
234-
235-
def _create_actor_network(self, obs):
236-
"""Network for actor."""
237-
w_normal = tf.truncated_normal_initializer()
238-
w_init = tf.random_uniform_initializer(minval=-0.003, maxval=0.003)
239-
240-
net = slim.fully_connected(
241-
obs, 400, activation_fn=tf.nn.relu, weights_initializer=w_normal)
242-
net = slim.fully_connected(
243-
net, 300, activation_fn=tf.nn.relu, weights_initializer=w_normal)
244-
out = slim.fully_connected(
245-
net, self.ac_size, activation_fn=tf.nn.tanh,
246-
weights_initializer=w_init)
247-
scaled_out = tf.multiply(out, self.action_bound)
248-
return scaled_out
215+
self.actor_network = DDPGActor(
216+
self.obs, self.ac_size,
217+
options={"action_bound": self.action_bound})
218+
self.output_action = self.actor_network.outputs
249219

250220
def get_weights(self):
251221
"""Returns critic weights, actor weights."""

python/ray/rllib/models/ddpgnet.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
from __future__ import absolute_import
2+
from __future__ import division
3+
from __future__ import print_function
4+
5+
import tensorflow as tf
6+
import tensorflow.contrib.slim as slim
7+
8+
from ray.rllib.models.model import Model
9+
10+
11+
class DDPGActor(Model):
    """Actor network for DDPG.

    Stacks two ReLU hidden layers (400 then 300 units) and a tanh output
    layer, then multiplies the tanh output by ``options["action_bound"]``.
    """

    def _init(self, inputs, num_outputs, options):
        """Build the actor graph.

        Args:
            inputs: Tensor fed to the first fully connected layer.
            num_outputs: Width of the tanh output layer.
            options: Mapping that must contain ``"action_bound"``, the
                factor applied to the tanh output.

        Returns:
            A ``(scaled_action, last_hidden)`` tuple: the bounded output
            and the 300-unit hidden activation that feeds it.
            NOTE(review): presumably the base ``Model`` exposes the first
            element as ``.outputs`` -- confirm against ``Model``.
        """
        hidden_init = tf.truncated_normal_initializer()
        # The output layer starts with small weights, so the tanh is
        # initially close to zero.
        output_init = tf.random_uniform_initializer(
            minval=-0.003, maxval=0.003)
        action_bound = options["action_bound"]

        # Hidden layers are created in the same order as before so that
        # slim assigns identical variable scope names.
        hidden = inputs
        for width in (400, 300):
            hidden = slim.fully_connected(
                hidden, width, activation_fn=tf.nn.relu,
                weights_initializer=hidden_init)

        squashed = slim.fully_connected(
            hidden, num_outputs, activation_fn=tf.nn.tanh,
            weights_initializer=output_init)
        return tf.multiply(squashed, action_bound), hidden
29+
30+
31+
class DDPGCritic(Model):
    """Critic network for DDPG.

    The observation passes through a 400-unit ReLU layer and a linear
    300-unit projection (no bias). The action is projected linearly to 300
    units. The two projections are summed, passed through ReLU, and mapped
    to the output by a final linear layer.
    """

    def _init(self, inputs, num_outputs, options):
        """Build the critic graph.

        Args:
            inputs: A pair ``(obs, action)`` of tensors.
            num_outputs: Width of the final linear layer. Callers in
                ``ddpg/models.py`` pass 1.
            options: Not read by this network.

        Returns:
            A ``(q_out, last_hidden)`` tuple: the linear output and the
            300-unit ReLU activation that feeds it.
            NOTE(review): presumably the base ``Model`` exposes the first
            element as ``.outputs`` -- confirm against ``Model``.
        """
        obs, action = inputs
        w_normal = tf.truncated_normal_initializer()
        # The final layer starts with very small weights.
        w_init = tf.random_uniform_initializer(minval=-0.0003, maxval=0.0003)

        net = slim.fully_connected(
            obs, 400, activation_fn=tf.nn.relu, weights_initializer=w_normal)
        # The observation branch carries no bias; the action branch below
        # keeps its bias, so the sum has a single bias term.
        t1 = slim.fully_connected(
            net, 300, activation_fn=None, biases_initializer=None,
            weights_initializer=w_normal)
        t2 = slim.fully_connected(
            action, 300, activation_fn=None, weights_initializer=w_normal)
        net = tf.nn.relu(tf.add(t1, t2))

        # Honor num_outputs instead of a hard-coded width of 1, matching
        # DDPGActor. Existing callers pass 1, so behavior is unchanged.
        out = slim.fully_connected(
            net, num_outputs, activation_fn=None, weights_initializer=w_init)
        return out, net

0 commit comments

Comments
 (0)