diff --git a/python/ray/rllib/ppo/runner.py b/python/ray/rllib/ppo/runner.py index 47c89eccb7c8b..46b738ae53c46 100644 --- a/python/ray/rllib/ppo/runner.py +++ b/python/ray/rllib/ppo/runner.py @@ -155,7 +155,7 @@ def load_data(self, trajectories, full_trace): [trajectories["observations"], trajectories["value_targets"] if use_gae else dummy, trajectories["advantages"], - trajectories["actions"].squeeze(), + trajectories["actions"], trajectories["logprobs"], trajectories["vf_preds"] if use_gae else dummy], full_trace=full_trace)