diff --git a/setup.py b/setup.py
index a125443..00790a9 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,8 @@
         # tensorflow or tensorflow-gpu >= v1.8
         'fire',
         'numpy',
-        'scipy'  # For environments/inventory.
+        'scipy',  # For environments/inventory.
+        'deprecation'
     ],
     tests_require=['pytest', 'pytest-cov'],
     setup_requires=['pytest-runner'],
diff --git a/test/discounted_mdp_test.py b/test/discounted_mdp_test.py
index 149588a..e9b5fd0 100644
--- a/test/discounted_mdp_test.py
+++ b/test/discounted_mdp_test.py
@@ -25,9 +25,9 @@ def setUp(self):
         tf.set_random_seed(10)
 
     def test_row_normalize_op(self):
-        root = tf.reshape(normalized(tf.constant([1, 2, 3.0])), [3, 1])
-        v = tf.reshape(tf.constant([1.0, 2, 3]), [3, 1])
-        self.assertAllClose([[2.3333335]], root_value_op(root, v))
+        root = normalized([1, 2, 3.0])
+        v = [1.0, 2, 3]
+        self.assertAllClose(2.3333335, root_value_op(root, v))
 
     def test_state_action_successor_policy_evaluation_op(self):
         gamma = 0.9
@@ -97,12 +97,7 @@ def test_dual_and_primal_policy_evaluation_agree(self):
                 threshold=threshold,
                 max_num_iterations=max_num_iterations),
             dual_action_value_policy_evaluation_op(
-                transitions,
-                policy,
-                r,
-                gamma=gamma,
-                threshold=threshold,
-                max_num_iterations=max_num_iterations))
+                transitions, policy, r, gamma=gamma))
 
         with self.subTest('two reward functions'):
             r_both = tf.stack(
@@ -110,12 +105,7 @@
                 axis=-1)
 
             patient = dual_action_value_policy_evaluation_op(
-                transitions,
-                policy,
-                r_both,
-                gamma=gamma,
-                threshold=threshold,
-                max_num_iterations=max_num_iterations)
+                transitions, policy, r_both, gamma=gamma)
 
             self.assertAllClose(
                 primal_action_value_policy_evaluation_op(
@@ -166,10 +156,10 @@ def test_gpi_value(self):
             threshold=threshold,
             max_num_iterations=max_num_iterations)
 
-        mu = normalized(tf.ones([num_states, 1]))
+        mu = normalized(tf.ones([num_states]))
 
-        v = tf.reduce_sum(policy_1_op * q_op, axis=-1, keepdims=True)
-        self.assertAllClose(-2.354447, tf.squeeze(root_value_op(mu, v)))
+        v = tf.reduce_sum(policy_1_op * q_op, axis=-1)
+        self.assertAllClose(-2.354447, root_value_op(mu, v))
 
         policy_5_op = generalized_policy_iteration_op(
             transitions,
@@ -185,20 +175,15 @@
             threshold=threshold,
             max_num_iterations=max_num_iterations)
 
-        v = tf.reduce_sum(policy_5_op * q_op, axis=-1, keepdims=True)
-        self.assertAllClose(-2.354447, tf.squeeze(root_value_op(mu, v)))
+        v = tf.reduce_sum(policy_5_op * q_op, axis=-1)
+        self.assertAllClose(-2.354447, root_value_op(mu, v))
 
         dual_state_values = dual_state_value_policy_evaluation_op(
-            transitions,
-            policy_5_op,
-            r,
-            gamma=gamma,
-            threshold=threshold,
-            max_num_iterations=max_num_iterations)
+            transitions, policy_5_op, r, gamma=gamma)
 
         self.assertAllClose(
             -2.354438,
-            tf.squeeze(root_value_op(mu, dual_state_values)),
+            root_value_op(mu, dual_state_values),
             rtol=1e-04,
             atol=1e-04)
 
@@ -244,9 +229,9 @@ def test_recover_state_distribution_from_state_action_distribution(self):
         sum_M_op = tf.reduce_sum(M_op, axis=1)
         self.assertAllClose(tf.ones_like(sum_M_op), sum_M_op)
         self.assertAllClose(A_op, tf.matmul(M_op, Pi_op))
-        self.assertAllClose(M_op,
-                            state_successor_policy_evaluation_op(
-                                transitions, policy_op, gamma))
+        self.assertAllClose(
+            M_op, (1.0 - gamma) * state_successor_policy_evaluation_op(
+                transitions, policy_op, gamma))
 
 
 if __name__ == '__main__':
diff --git a/tf_kofn_robust_policy_optimization/discounted_mdp.py b/tf_kofn_robust_policy_optimization/discounted_mdp.py
index e20e590..9b3ad51 100644
--- a/tf_kofn_robust_policy_optimization/discounted_mdp.py
+++ b/tf_kofn_robust_policy_optimization/discounted_mdp.py
@@ -1,4 +1,5 @@
 import tensorflow as tf
+from deprecation import deprecated
 from tf_contextual_prediction_with_expert_advice import \
     l1_projection_to_simplex, \
     indmax
@@ -52,54 +53,30 @@ def cond(d, H_d, H_dp1):
         parallel_iterations=1)[-1]
 
 
-def dual_action_value_policy_evaluation_op(transitions,
-                                           policy,
-                                           r,
-                                           gamma=0.9,
-                                           threshold=1e-15,
-                                           max_num_iterations=-1):
+def dual_action_value_policy_evaluation_op(transitions, policy, r, gamma=0.9):
+    transitions = tf.convert_to_tensor(transitions)
     policy = tf.convert_to_tensor(policy)
     r = tf.convert_to_tensor(r)
-    extra_dims = r.shape[2:]
-    shape = policy.shape.concatenate(extra_dims)
-    H = state_action_successor_policy_evaluation_op(
-        transitions,
-        policy,
-        gamma=gamma,
-        threshold=threshold,
-        max_num_iterations=max_num_iterations)
-    action_values = tf.reshape(
-        tf.tensordot(
-            H,
-            tf.reshape(r, H.shape[0:1].concatenate(extra_dims)),
-            axes=[[1], [0]]), shape)
-    if gamma < 1:
-        action_values = action_values / (1.0 - gamma)
-    return action_values
-
-
-def dual_state_value_policy_evaluation_op(transitions,
-                                          policy,
-                                          r,
-                                          gamma=0.9,
-                                          threshold=1e-15,
-                                          max_num_iterations=-1):
+
+    v = gamma * dual_state_value_policy_evaluation_op(
+        transitions, policy, r, gamma=gamma)
+    v = tf.reshape(v, [1, 1] + [dim.value for dim in v.shape])
+    if len(v.shape) > len(transitions.shape):
+        transitions = tf.reshape(
+            transitions, [dim.value for dim in transitions.shape] + [1] *
+            (len(v.shape) - len(transitions.shape)))
+    return r + tf.reduce_sum(transitions * v, axis=2)
+
+
+def dual_state_value_policy_evaluation_op(transitions, policy, r, gamma=0.9):
     policy = tf.convert_to_tensor(policy)
     r = tf.convert_to_tensor(r)
-    M = state_successor_policy_evaluation_op(
-        transitions,
-        policy,
-        gamma=gamma,
-        threshold=threshold,
-        max_num_iterations=max_num_iterations)
+    M = state_successor_policy_evaluation_op(transitions, policy, gamma=gamma)
     if len(r.shape) > 2:
         M = tf.expand_dims(M, axis=-1)
         policy = tf.expand_dims(policy, axis=-1)
-    weighted_rewards = tf.reduce_sum(r * policy, axis=1, keepdims=True)
-    state_values = tf.tensordot(M, weighted_rewards, axes=[[1], [0]])
-    if gamma < 1:
-        state_values = state_values / (1.0 - gamma)
-    return state_values
+    weighted_rewards = tf.expand_dims(tf.reduce_sum(r * policy, axis=1), 0)
+    return tf.reduce_sum(M * weighted_rewards, axis=1)
 
 
 def primal_action_value_policy_evaluation_op(transitions,
@@ -194,9 +171,51 @@ def cond(d, q_d, q_dp1):
 
 
 def root_value_op(mu, v):
-    return tf.transpose(tf.matmul(mu, v, transpose_a=True))
+    '''
+    If mu and v are two dimensional, this function assumes the first
+    dimension of both mu and v is a batch dimension.
+    '''
+    return tf.reduce_sum(mu * v, axis=-1)
+
+
+def state_successor_policy_evaluation_op(transitions, policy, gamma=0.9):
+    '''
+    The discounted unnormalized successor representation for the given
+    transitions and policy.
+
+    If gamma is less than 1, multiplying each element by 1 - gamma recovers
+    the row-normalized version.
+    '''
+    weighted_transitions = transitions * tf.expand_dims(policy, axis=-1)
+    negative_state_to_state = (
+        -gamma * tf.reduce_sum(weighted_transitions, axis=1))
+    eye_minus_gamma_state_to_state = tf.linalg.set_diag(
+        negative_state_to_state, 1.0 + tf.diag_part(negative_state_to_state))
+
+    return tf.matrix_inverse(eye_minus_gamma_state_to_state)
+
+
+def state_distribution(state_successor_rep, state_probs):
+    '''
+    Probability of terminating in each state.
+
+    |States| by 1 Tensor
+
+    Parameters:
+    - state_successor_rep: |States| by |States| successor representation.
+    - state_probs: (m by) |States| vector of initial state probabilities.
+    '''
+    state_probs = tf.convert_to_tensor(state_probs)
+    if len(state_probs.shape) < 2:
+        state_probs = tf.expand_dims(state_probs, axis=0)
+    return tf.matmul(state_probs, state_successor_rep)
 
 
+@deprecated(
+    details=(
+        'Outdated and poorly named. Use state and action policy evaluation methods directly instead.'
+    )
+)  # yapf:disable
 def value_ops(Pi, root_op, transition_model_op, reward_model_op, **kwargs):
     action_values_op = dual_action_value_policy_evaluation_op(
         transition_model_op, Pi, reward_model_op, **kwargs)
@@ -207,6 +226,11 @@ def value_ops(Pi, root_op, transition_model_op, reward_model_op, **kwargs):
     return action_values_op, state_values_op, ev_op
 
 
+@deprecated(
+    details=(
+        'Outdated and poorly named. Use state and action policy evaluation methods directly instead.'
+    )
+)  # yapf:disable
 def associated_ops(action_weights,
                    root_op,
                    transition_model_op,
@@ -223,58 +247,3 @@ def associated_ops(action_weights,
         Pi, root_op, transition_model_op, reward_model_op, **kwargs)
 
     return Pi, action_values_op, state_values_op, ev_op
-
-
-def state_successor_policy_evaluation_op(transitions,
-                                         policy,
-                                         gamma=0.9,
-                                         threshold=1e-15,
-                                         max_num_iterations=-1,
-                                         M_0=None):
-    transitions = tf.convert_to_tensor(transitions)
-    num_states = transitions.shape[0].value
-
-    if M_0 is None:
-        M_0 = tf.eye(num_states)
-
-    weighted_transitions = (
-        transitions * tf.expand_dims(gamma * policy, axis=-1))
-
-    state_to_state = tf.reduce_sum(weighted_transitions, axis=1)
-
-    def M_dp1_op(M_d):
-        future_return = M_d @ state_to_state
-        return tf.linalg.set_diag(future_return,
-                                  tf.diag_part(future_return) + 1.0 - gamma)
-
-    def error_above_threshold(M_d, M_dp1):
-        return tf.greater(tf.reduce_sum(tf.abs(M_dp1 - M_d)), threshold)
-
-    def cond(d, M_d, M_dp1):
-        error_is_high = True if threshold is None else error_above_threshold(
-            M_d, M_dp1)
-        return tf.logical_or(
-            tf.logical_and(tf.less(max_num_iterations, 1), error_is_high),
-            tf.logical_and(tf.less(d, max_num_iterations), error_is_high))
-
-    return tf.while_loop(
-        cond,
-        lambda d, _, M_d: [d + 1, M_d, M_dp1_op(M_d)],
-        [1, M_0, M_dp1_op(M_0)],
-        parallel_iterations=1)[-1]
-
-
-def state_distribution(state_successor_rep, state_probs):
-    '''
-    Probability of terminating in each state.
-
-    |States| by 1 Tensor
-
-    Parameters:
-    - state_successor_rep: |States| by |States| successor representation.
-    - state_probs: (m by) |States| vector of initial state probabilities.
-    '''
-    state_probs = tf.convert_to_tensor(state_probs)
-    if len(state_probs.shape) < 2:
-        state_probs = tf.expand_dims(state_probs, axis=0)
-    return tf.matmul(state_probs, state_successor_rep)
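
As a quick illustration of the closed form these ops now compute, here is a NumPy sketch under assumed shapes and names (the arrays and variable names below are hypothetical, not code from this repository): with dense transitions[s, a, s'] and policy[s, a], the successor representation is the matrix inverse (I - gamma * P_pi)^-1, scaling it by 1 - gamma row-normalizes it, state values are M @ r_pi, and action values are the one-step backup r + gamma * E[v(s')], so no 1 / (1 - gamma) rescaling or iteration parameters are involved.

import numpy as np

# Hypothetical small dense MDP: transitions[s, a, s'] and policy[s, a].
num_states, num_actions, gamma = 3, 2, 0.9
rng = np.random.RandomState(10)
transitions = rng.dirichlet(np.ones(num_states), (num_states, num_actions))
policy = rng.dirichlet(np.ones(num_actions), num_states)
r = rng.randn(num_states, num_actions)

# Policy-weighted state-to-state matrix P_pi[s, s'].
P_pi = np.einsum('saz,sa->sz', transitions, policy)

# Unnormalized discounted successor representation: M = (I - gamma * P_pi)^-1.
M = np.linalg.inv(np.eye(num_states) - gamma * P_pi)

# Scaling by 1 - gamma row-normalizes M, which is what the updated
# test_recover_state_distribution_from_state_action_distribution asserts.
assert np.allclose(((1.0 - gamma) * M).sum(axis=1), 1.0)

# State values v = M @ r_pi satisfy the usual fixed point
# v = r_pi + gamma * P_pi @ v.
r_pi = (r * policy).sum(axis=1)
v = M @ r_pi
assert np.allclose(v, r_pi + gamma * P_pi @ v)

# Action values are the one-step backup q = r + gamma * E[v(s')],
# and they are consistent with v under the policy.
q = r + gamma * np.einsum('saz,z->sa', transitions, v)
assert np.allclose((policy * q).sum(axis=1), v)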