Skip to content

Commit

Permalink
Fixed multiprocess GAIL and cleaned up model_rewards
Browse files: browse the repository at this point in the history
  • Loading branch information
jdchang1 committed Nov 11, 2023
1 parent 5fe7cfc commit 37fcdda
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 4 deletions.
3 changes: 1 addition & 2 deletions src/tril/algorithms/gail.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def update_buffer(self):
terminal_rewards = torch.cat(all_scores, dim=0)
seq_lens = self.buffer.masks.sum(axis=-1)
self.buffer.rewards = torch.zeros(
(self.buffer.total_num_traj, self.max_gen_len), dtype=torch.float32
(self.trajectories_per_update, self.max_gen_len), dtype=torch.float32
)
for reward, length in zip(terminal_rewards, seq_lens):
self.buffer.rewards[:, int(length - 1)] = reward
Expand Down Expand Up @@ -179,7 +179,6 @@ def discriminator_step(self):
) as pbar:
for batch_ix, rollout_data in enumerate(self.buffer_dataloader):
with self.accelerator.accumulate():
# expert_data = next(self.expert_sampler)
# NOTE: we could just grab it from rollout_data.target_ids

chosen_tokens = rollout_data.observations.to(
Expand Down
3 changes: 1 addition & 2 deletions src/tril/rewards/model_rewards.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@ def __init__(
self._metric_tokenizer.truncation_side = "left"
self._metric_model = AutoModelForSequenceClassification.from_pretrained(
model_name
) # .to(self._accelerator.device)
# self._accelerator.prepare(self._metric_model)
)
self._label_ix = label_ix
self._include_prompt_for_eval = include_prompt_for_eval

Expand Down

0 comments on commit 37fcdda

Please sign in to comment.