Skip to content

Commit

Permalink
reformat
Browse files Browse the repository at this point in the history
  • Loading branch information
‘whl’ committed Sep 10, 2024
1 parent f529857 commit 8a47348
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions ding/policy/prompt_awr.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,9 @@ def _forward_collect(self, data: Dict[int, Any]) -> Dict[int, Any]:
# Prepare train_sample (the question to be answered) and the candidate_samples (the prompts to be selected)
for ii in range(len(data['candidate_samples'])):
data['candidate_samples'][ii] = data['candidate_samples'][ii][0]
output = self._collect_model.forward(self._cfg.shot_number, data['train_sample'],
data['candidate_samples'], mode="compute_actor_critic")
output = self._collect_model.forward(
self._cfg.shot_number, data['train_sample'], data['candidate_samples'], mode="compute_actor_critic"
)
if self._cuda:
output = to_device(output, 'cpu')
output = default_decollate(output)
Expand Down

0 comments on commit 8a47348

Please sign in to comment.