diff --git a/examples/visualization_demo.py b/examples/visualization_demo.py index 862ae25..a5b00d0 100644 --- a/examples/visualization_demo.py +++ b/examples/visualization_demo.py @@ -192,7 +192,14 @@ def main(_): jitted_run_demo = jax.jit(_run_demo) print("Starting search.") policy_output = jitted_run_demo(rng_key) - print("Selected action:", policy_output.action[0]) + batch_index = 0 + selected_action = policy_output.action[batch_index] + q_value = policy_output.search_tree.summary().qvalues[ + batch_index, selected_action] + print("Selected action:", selected_action) + # To estimate the value of the root state, use the Q-value of the selected + # action. The Q-value is not affected by the exploration at the root node. + print("Selected action Q-value:", q_value) graph = convert_tree_to_graph(policy_output.search_tree) print("Saving tree diagram to:", FLAGS.output_file) graph.draw(FLAGS.output_file, prog="dot")