
Commit b229216
polish
whl committed Aug 1, 2023
1 parent 286d976 commit b229216
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions ding/policy/prompt_pg.py
@@ -4,7 +4,7 @@
 
 from ding.rl_utils import get_train_sample
 from ding.torch_utils import Adam, to_device
-from ding.utils import POLICY_REGISTRY
+from ding.utils import POLICY_REGISTRY, split_data_generator
 from ding.utils.data import default_collate, default_decollate
 from .base_policy import Policy

@@ -85,15 +85,17 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]:
             data = to_device(data, self._device)
 
         return_infos = []
-        for i in range(0, len(data), self._cfg.learn.batch_size):
-            batch = default_collate(data[i:i + self._cfg.learn.batch_size])
+        for batch in split_data_generator(data, self._cfg.learn.batch_size):
             # Prepare train_sample (the question to be answered) and the candidate_samples (the prompts to be selected)
             train_samples, cand_samples = batch["obs"]["train_sample"], batch["obs"]["candidate_samples"]
             for ii in range(len(cand_samples)):
                 cand_samples[ii] = cand_samples[ii][0]
             output = self._learn_model.forward(train_samples, cand_samples)
             return_ = batch['return']
+
+            if self._cuda:
+                return_ = to_device(return_, self._device)
 
             # calculate PG loss
             real_act = []
             for b in range(batch['action'].shape[0]):
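Note on the change above: the learn loop now delegates mini-batching to ding.utils.split_data_generator instead of slicing the sample list and calling default_collate by hand. This diff does not show that helper's implementation, so the snippet below is only an illustrative stand-in (the name split_minibatches and its exact behaviour are assumptions); it sketches the slice-and-collate pattern that the two removed lines performed explicitly.

from typing import Any, Dict, Iterator, List

import torch


def split_minibatches(data: List[Dict[str, Any]], batch_size: int) -> Iterator[Dict[str, Any]]:
    # Hypothetical stand-in for ding.utils.split_data_generator (not part of this diff):
    # slice the list of per-sample dicts into chunks and collate each chunk into a single
    # batched dict, mirroring the removed `default_collate(data[i:i + batch_size])` call.
    for i in range(0, len(data), batch_size):
        chunk = data[i:i + batch_size]
        batch: Dict[str, Any] = {}
        for key in chunk[0]:
            values = [sample[key] for sample in chunk]
            # Stack same-shaped tensors along a new batch dimension; keep other values as lists.
            # (A real collation, like default_collate, would also recurse into nested dicts
            # such as the `obs` field used in _forward_learn.)
            batch[key] = torch.stack(values) if isinstance(values[0], torch.Tensor) else values
        yield batch

Such a helper would be consumed the same way as in the updated loop, e.g. `for batch in split_minibatches(data, self._cfg.learn.batch_size): ...`.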
