Skip to content

Commit

Permalink
Fixed wrong logs prefixes in KTOTrainer (#1641)
Browse files: browse the repository at this point in the history
* Fixed wrong logs prefixes in KTOTrainer

* Pre-commit formatting
  • Loading branch information
bartoszzuk authored May 14, 2024
1 parent 5aeb752 commit d632a5b
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions trl/trainer/kto_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1525,26 +1525,26 @@ def log(self, logs: Dict[str, float]) -> None:
"""
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
# train metrics should have no prefix, eval should have 'eval_'
prefix = "eval_" if train_eval == "eval" else ""
# accumulate average metrics from sums and lengths
for split in ["chosen", "rejected"]:
if f"count/{split}" in self._stored_metrics[train_eval]:
count_sum = torch.Tensor(self._stored_metrics[train_eval][f"count/{split}"]).sum().item()
logs[f"{train_eval}/rewards/{split}"] = (
logs[f"{prefix}rewards/{split}"] = (
torch.Tensor(self._stored_metrics[train_eval][f"rewards/{split}_sum"]).sum().item() / count_sum
)
logs[f"{train_eval}/logps/{split}"] = (
logs[f"{prefix}logps/{split}"] = (
torch.Tensor(self._stored_metrics[train_eval][f"logps/{split}_sum"]).sum().item() / count_sum
)
for key in [f"count/{split}", f"rewards/{split}_sum", f"logps/{split}_sum"]:
del self._stored_metrics[train_eval][key]
# calculate reward margin
if f"{train_eval}/rewards/chosen" in logs and f"{train_eval}/rewards/rejected" in logs:
logs[f"{train_eval}/rewards/margins"] = (
logs[f"{train_eval}/rewards/chosen"] - logs[f"{train_eval}/rewards/rejected"]
)
if f"{prefix}rewards/chosen" in logs and f"{prefix}rewards/rejected" in logs:
logs[f"{prefix}rewards/margins"] = logs[f"{prefix}rewards/chosen"] - logs[f"{prefix}rewards/rejected"]
# Add averaged stored metrics to logs
for key, metrics in self._stored_metrics[train_eval].items():
logs[f"{train_eval}/{key}"] = torch.Tensor(metrics).mean().item()
logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
return super().log(logs)

Expand Down

0 comments on commit d632a5b

Please sign in to comment.