Speed up dual policy evaluation functions.
dmorrill10 committed Oct 27, 2018
1 parent 3165358 commit 5eb33a5
Showing 3 changed files with 83 additions and 128 deletions.
3 changes: 2 additions & 1 deletion setup.py
@@ -10,7 +10,8 @@
# tensorflow or tensorflow-gpu >= v1.8
'fire',
'numpy',
'scipy' # For environments/inventory.
'scipy', # For environments/inventory.
'deprecation'
],
tests_require=['pytest', 'pytest-cov'],
setup_requires=['pytest-runner'],
45 changes: 15 additions & 30 deletions test/discounted_mdp_test.py
@@ -25,9 +25,9 @@ def setUp(self):
tf.set_random_seed(10)

def test_row_normalize_op(self):
root = tf.reshape(normalized(tf.constant([1, 2, 3.0])), [3, 1])
v = tf.reshape(tf.constant([1.0, 2, 3]), [3, 1])
self.assertAllClose([[2.3333335]], root_value_op(root, v))
root = normalized([1, 2, 3.0])
v = [1.0, 2, 3]
self.assertAllClose(2.3333335, root_value_op(root, v))

def test_state_action_successor_policy_evaluation_op(self):
gamma = 0.9
@@ -97,25 +97,15 @@ def test_dual_and_primal_policy_evaluation_agree(self):
threshold=threshold,
max_num_iterations=max_num_iterations),
dual_action_value_policy_evaluation_op(
transitions,
policy,
r,
gamma=gamma,
threshold=threshold,
max_num_iterations=max_num_iterations))
transitions, policy, r, gamma=gamma))

with self.subTest('two reward functions'):
r_both = tf.stack(
[r, tf.random_normal(shape=[num_states, num_actions])],
axis=-1)

patient = dual_action_value_policy_evaluation_op(
transitions,
policy,
r_both,
gamma=gamma,
threshold=threshold,
max_num_iterations=max_num_iterations)
transitions, policy, r_both, gamma=gamma)

self.assertAllClose(
primal_action_value_policy_evaluation_op(
@@ -166,10 +156,10 @@ def test_gpi_value(self):
threshold=threshold,
max_num_iterations=max_num_iterations)

mu = normalized(tf.ones([num_states, 1]))
mu = normalized(tf.ones([num_states]))

v = tf.reduce_sum(policy_1_op * q_op, axis=-1, keepdims=True)
self.assertAllClose(-2.354447, tf.squeeze(root_value_op(mu, v)))
v = tf.reduce_sum(policy_1_op * q_op, axis=-1)
self.assertAllClose(-2.354447, root_value_op(mu, v))

policy_5_op = generalized_policy_iteration_op(
transitions,
threshold=threshold,
max_num_iterations=max_num_iterations)

v = tf.reduce_sum(policy_5_op * q_op, axis=-1, keepdims=True)
self.assertAllClose(-2.354447, tf.squeeze(root_value_op(mu, v)))
v = tf.reduce_sum(policy_5_op * q_op, axis=-1)
self.assertAllClose(-2.354447, root_value_op(mu, v))

dual_state_values = dual_state_value_policy_evaluation_op(
transitions,
policy_5_op,
r,
gamma=gamma,
threshold=threshold,
max_num_iterations=max_num_iterations)
transitions, policy_5_op, r, gamma=gamma)

self.assertAllClose(
-2.354438,
tf.squeeze(root_value_op(mu, dual_state_values)),
root_value_op(mu, dual_state_values),
rtol=1e-04,
atol=1e-04)

@@ -244,9 +229,9 @@ def test_recover_state_distribution_from_state_action_distribution(self):
sum_M_op = tf.reduce_sum(M_op, axis=1)
self.assertAllClose(tf.ones_like(sum_M_op), sum_M_op)
self.assertAllClose(A_op, tf.matmul(M_op, Pi_op))
self.assertAllClose(M_op,
state_successor_policy_evaluation_op(
transitions, policy_op, gamma))
self.assertAllClose(
M_op, (1.0 - gamma) * state_successor_policy_evaluation_op(
transitions, policy_op, gamma))


if __name__ == '__main__':
163 changes: 66 additions & 97 deletions tf_kofn_robust_policy_optimization/discounted_mdp.py
@@ -1,4 +1,5 @@
import tensorflow as tf
from deprecation import deprecated
from tf_contextual_prediction_with_expert_advice import \
l1_projection_to_simplex, \
indmax
@@ -52,54 +53,30 @@ def cond(d, H_d, H_dp1):
parallel_iterations=1)[-1]


def dual_action_value_policy_evaluation_op(transitions,
policy,
r,
gamma=0.9,
threshold=1e-15,
max_num_iterations=-1):
def dual_action_value_policy_evaluation_op(transitions, policy, r, gamma=0.9):
transitions = tf.convert_to_tensor(transitions)
policy = tf.convert_to_tensor(policy)
r = tf.convert_to_tensor(r)
extra_dims = r.shape[2:]
shape = policy.shape.concatenate(extra_dims)
H = state_action_successor_policy_evaluation_op(
transitions,
policy,
gamma=gamma,
threshold=threshold,
max_num_iterations=max_num_iterations)
action_values = tf.reshape(
tf.tensordot(
H,
tf.reshape(r, H.shape[0:1].concatenate(extra_dims)),
axes=[[1], [0]]), shape)
if gamma < 1:
action_values = action_values / (1.0 - gamma)
return action_values


def dual_state_value_policy_evaluation_op(transitions,
policy,
r,
gamma=0.9,
threshold=1e-15,
max_num_iterations=-1):

v = gamma * dual_state_value_policy_evaluation_op(
transitions, policy, r, gamma=gamma)
v = tf.reshape(v, [1, 1] + [dim.value for dim in v.shape])
if len(v.shape) > len(transitions.shape):
transitions = tf.reshape(
transitions, [dim.value for dim in transitions.shape] + [1] *
(len(v.shape) - len(transitions.shape)))
return r + tf.reduce_sum(transitions * v, axis=2)


def dual_state_value_policy_evaluation_op(transitions, policy, r, gamma=0.9):
policy = tf.convert_to_tensor(policy)
r = tf.convert_to_tensor(r)
M = state_successor_policy_evaluation_op(
transitions,
policy,
gamma=gamma,
threshold=threshold,
max_num_iterations=max_num_iterations)
M = state_successor_policy_evaluation_op(transitions, policy, gamma=gamma)
if len(r.shape) > 2:
M = tf.expand_dims(M, axis=-1)
policy = tf.expand_dims(policy, axis=-1)
weighted_rewards = tf.reduce_sum(r * policy, axis=1, keepdims=True)
state_values = tf.tensordot(M, weighted_rewards, axes=[[1], [0]])
if gamma < 1:
state_values = state_values / (1.0 - gamma)
return state_values
weighted_rewards = tf.expand_dims(tf.reduce_sum(r * policy, axis=1), 0)
return tf.reduce_sum(M * weighted_rewards, axis=1)


def primal_action_value_policy_evaluation_op(transitions,
@@ -194,9 +171,51 @@ def cond(d, q_d, q_dp1):


def root_value_op(mu, v):
return tf.transpose(tf.matmul(mu, v, transpose_a=True))
'''
If mu and v are two dimensional, this function assumes the first
dimension of both mu and v is a batch dimension.
'''
return tf.reduce_sum(mu * v, axis=-1)
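
A minimal numeric sketch of the new root_value_op contract (NumPy here purely for illustration; not part of the commit): the root value is the expectation of v under the root distribution mu, reduced over the last axis, which is what the updated test_row_normalize_op case asserts.

import numpy as np

mu = np.array([1.0, 2.0, 3.0])
mu = mu / mu.sum()                     # normalized([1, 2, 3.0]) -> [1/6, 2/6, 3/6]
v = np.array([1.0, 2.0, 3.0])
print(np.sum(mu * v, axis=-1))         # ~2.3333335, the value asserted in the test

# With two-dimensional inputs, the first dimension is treated as a batch:
mu_batch = np.stack([mu, [1.0, 0.0, 0.0]])
v_batch = np.stack([v, v])
print(np.sum(mu_batch * v_batch, axis=-1))   # [2.3333335, 1.0], one root value per batch row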


def state_successor_policy_evaluation_op(transitions, policy, gamma=0.9):
'''
The discounted unnormalized successor representation for the given
transitions and policy.
If gamma is less than 1, multiplying each element by 1 - gamma recovers
the row-normalized version.
'''
weighted_transitions = transitions * tf.expand_dims(policy, axis=-1)
negative_state_to_state = (
-gamma * tf.reduce_sum(weighted_transitions, axis=1))
eye_minus_gamma_state_to_state = tf.linalg.set_diag(
negative_state_to_state, 1.0 + tf.diag_part(negative_state_to_state))

return tf.matrix_inverse(eye_minus_gamma_state_to_state)
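
For reference, a small NumPy sketch (an illustration, not part of the commit) of what the rewritten op computes and why it is faster: instead of the previous fixed-point while loop, the successor representation comes out in one shot as the matrix inverse M = (I - gamma * P_pi)^{-1}, where P_pi is the policy-weighted state-to-state transition matrix. Multiplying by 1 - gamma recovers the row-normalized version, which is what the updated test_recover_state_distribution_from_state_action_distribution assertion checks.

import numpy as np

gamma = 0.9
# transitions[s, a, s'] and policy[s, a] for a toy 2-state, 2-action MDP
transitions = np.array([[[0.9, 0.1], [0.2, 0.8]],
                        [[0.5, 0.5], [0.0, 1.0]]])
policy = np.array([[0.5, 0.5],
                   [1.0, 0.0]])

P_pi = np.einsum('sa,sat->st', policy, transitions)   # policy-weighted state-to-state matrix
M = np.linalg.inv(np.eye(2) - gamma * P_pi)           # unnormalized successor representation
print(((1.0 - gamma) * M).sum(axis=1))                # ~[1., 1.]: rows of (1 - gamma) * M sum to one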


def state_distribution(state_successor_rep, state_probs):
'''
Probability of terminating in each state.
|States| by 1 Tensor
Parameters:
- state_successor_rep: |States| by |States| successor representation.
- state_probs: (m by) |States| vector of initial state probabilities.
'''
state_probs = tf.convert_to_tensor(state_probs)
if len(state_probs.shape) < 2:
state_probs = tf.expand_dims(state_probs, axis=0)
return tf.matmul(state_probs, state_successor_rep)
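
A hedged usage sketch for state_distribution (NumPy stand-in, assuming the row-normalized successor matrix (1 - gamma) * M is what gets passed in): the op is a single matmul that mixes the successor rows by the initial state probabilities.

import numpy as np

gamma = 0.9
P_pi = np.array([[0.55, 0.45],
                 [0.5, 0.5]])                  # policy-weighted transition matrix
M_normalized = (1.0 - gamma) * np.linalg.inv(np.eye(2) - gamma * P_pi)
mu = np.array([[1.0, 0.0]])                    # start in state 0

d = mu @ M_normalized                          # the same matmul state_distribution performs
print(d, d.sum())                              # a distribution over states; sums to ~1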


@deprecated(
details=(
'Outdated and poorly named. Use state and action policy evaluation methods directly instead.'
)
) # yapf:disable
def value_ops(Pi, root_op, transition_model_op, reward_model_op, **kwargs):
action_values_op = dual_action_value_policy_evaluation_op(
transition_model_op, Pi, reward_model_op, **kwargs)
@@ -207,6 +226,11 @@ def value_ops(Pi, root_op, transition_model_op, reward_model_op, **kwargs):
return action_values_op, state_values_op, ev_op


@deprecated(
details=(
'Outdated and poorly named. Use state and action policy evaluation methods directly instead.'
)
) # yapf:disable
def associated_ops(action_weights,
root_op,
transition_model_op,
Pi, root_op, transition_model_op, reward_model_op, **kwargs)

return Pi, action_values_op, state_values_op, ev_op
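
A sketch of the direct calls the deprecation messages point to, in place of the value_ops/associated_ops wrappers. Here transitions, policy, r, and mu stand for whatever tensors the caller already has; only the function names and signatures come from this file.

from tf_kofn_robust_policy_optimization.discounted_mdp import (
    dual_action_value_policy_evaluation_op,
    dual_state_value_policy_evaluation_op,
    root_value_op)

q = dual_action_value_policy_evaluation_op(transitions, policy, r, gamma=0.9)
v = dual_state_value_policy_evaluation_op(transitions, policy, r, gamma=0.9)
ev = root_value_op(mu, v)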


def state_successor_policy_evaluation_op(transitions,
policy,
gamma=0.9,
threshold=1e-15,
max_num_iterations=-1,
M_0=None):
transitions = tf.convert_to_tensor(transitions)
num_states = transitions.shape[0].value

if M_0 is None:
M_0 = tf.eye(num_states)

weighted_transitions = (
transitions * tf.expand_dims(gamma * policy, axis=-1))

state_to_state = tf.reduce_sum(weighted_transitions, axis=1)

def M_dp1_op(M_d):
future_return = M_d @ state_to_state
return tf.linalg.set_diag(future_return,
tf.diag_part(future_return) + 1.0 - gamma)

def error_above_threshold(M_d, M_dp1):
return tf.greater(tf.reduce_sum(tf.abs(M_dp1 - M_d)), threshold)

def cond(d, M_d, M_dp1):
error_is_high = True if threshold is None else error_above_threshold(
M_d, M_dp1)
return tf.logical_or(
tf.logical_and(tf.less(max_num_iterations, 1), error_is_high),
tf.logical_and(tf.less(d, max_num_iterations), error_is_high))

return tf.while_loop(
cond,
lambda d, _, M_d: [d + 1, M_d, M_dp1_op(M_d)],
[1, M_0, M_dp1_op(M_0)],
parallel_iterations=1)[-1]


def state_distribution(state_successor_rep, state_probs):
'''
Probability of terminating in each state.
|States| by 1 Tensor
Parameters:
- state_successor_rep: |States| by |States| successor representation.
- state_probs: (m by) |States| vector of initial state probabilities.
'''
state_probs = tf.convert_to_tensor(state_probs)
if len(state_probs.shape) < 2:
state_probs = tf.expand_dims(state_probs, axis=0)
return tf.matmul(state_probs, state_successor_rep)
