Padding free dpo #2437

Closed
wants to merge 58 commits into from
Changes from 1 commit
58 commits
ca99954
added eos token for ppotrainer
dame-cell Nov 30, 2024
fe1d5f6
remove the unnecessary stuff
dame-cell Nov 30, 2024
b15c635
Update ppo_config.py
dame-cell Nov 30, 2024
1bcb3a4
remove redundant EOS token fallback
dame-cell Dec 1, 2024
2ef2b24
remove redundant EOS token fallback
dame-cell Dec 1, 2024
42a0f73
remove some unnecessary tests stuff
dame-cell Dec 1, 2024
6130a91
added tests and update concatenated_inputs
dame-cell Dec 4, 2024
6732ed2
return only list and also a lot to do
dame-cell Dec 4, 2024
ce67292
padding free not tested but getting closer
dame-cell Dec 10, 2024
91e40aa
rebase and also reevaluate my approach
dame-cell Dec 10, 2024
8a34cb5
merge main
dame-cell Dec 10, 2024
1d38632
fix indentation
dame-cell Dec 10, 2024
814d69e
better tests
dame-cell Dec 10, 2024
7dae607
concatenated_forward now supports padding_free
dame-cell Dec 10, 2024
1a59d74
collator now does not return attention masks
dame-cell Dec 10, 2024
562c52e
position ids and no attention mask works
dame-cell Dec 11, 2024
d194054
update concatenated forward to support padding_free
dame-cell Dec 11, 2024
3855851
update concatenated forward to support padding_free
dame-cell Dec 11, 2024
d9adbfb
Merge branch 'main' into padding_free_dpo
dame-cell Dec 11, 2024
1145006
grad accumulation tests
dame-cell Dec 13, 2024
24f73a4
Merge branch 'padding_free_dpo' of https://github.com/dame-cell/trl i…
dame-cell Dec 13, 2024
f6bd9e1
Resolved merge conflict in ppo_trainer.py
dame-cell Dec 13, 2024
bbd99cf
Resolved merge conflict in ppo_trainer.py
dame-cell Dec 13, 2024
ba4969d
Resolved merge conflict in ppo_trainer.py
dame-cell Dec 13, 2024
187b1e5
Resolved merge conflict in ppo_trainer.py
dame-cell Dec 13, 2024
1d9ce3e
fix indentation
dame-cell Dec 13, 2024
8a974cc
comments update
dame-cell Dec 13, 2024
7f0298b
fix some small issue
dame-cell Dec 13, 2024
2900275
fix some small issue
dame-cell Dec 13, 2024
58e779a
fix some small issue
dame-cell Dec 13, 2024
6a1e251
fix some small issue
dame-cell Dec 13, 2024
f92e056
update concatenated_forward to support padding_free
dame-cell Dec 13, 2024
457e3a1
fix some small issue
dame-cell Dec 13, 2024
f1789f4
fix some small issue
dame-cell Dec 13, 2024
0103728
So we need to make sure to correctly handle the list
dame-cell Dec 13, 2024
5837faa
by correctly updating concatenated_forward it works now
dame-cell Dec 13, 2024
0321c1d
refactoring concatenated_forward and batched same length seq for padd…
dame-cell Dec 14, 2024
a6e2163
update
dame-cell Dec 14, 2024
1328fc3
padding_free in concatenated_forward and update_test
dame-cell Dec 17, 2024
51a2cc6
Merge branch 'main' into padding_free_dpo
dame-cell Dec 17, 2024
986ed71
padding_free in concatenated_forward and update_test
dame-cell Dec 17, 2024
570b79a
padding_free in concatenated_forward and update_test
dame-cell Dec 17, 2024
9dd9564
Merge branch 'padding_free_dpo' of https://github.com/dame-cell/trl i…
dame-cell Dec 17, 2024
b781876
Merge branch 'main' into padding_free_dpo
dame-cell Dec 18, 2024
525ecb2
Merge branch 'main' into padding_free_dpo
dame-cell Dec 19, 2024
c8ce9c8
Merge branch 'main' into padding_free_dpo
dame-cell Dec 19, 2024
b7fad73
Reverted PPO trainer to original version and updated DPO files
dame-cell Dec 19, 2024
55cd219
Merge branch 'padding_free_dpo' of https://github.com/dame-cell/trl i…
dame-cell Dec 19, 2024
5e8df69
Updated DPO files
dame-cell Dec 19, 2024
ba1ded1
Merge branch 'main' into padding_free_dpo
dame-cell Dec 20, 2024
0784202
Merge branch 'main' into padding_free_dpo
dame-cell Dec 21, 2024
ddfed7c
update test_dpo_trainer.py
dame-cell Dec 21, 2024
955c7e8
update dpo_trainer.py
dame-cell Dec 21, 2024
ba4356e
update dpo_trainer.py
dame-cell Dec 21, 2024
c61abb5
update dpo_trainer.py
dame-cell Dec 21, 2024
64e9909
update dpo_trainer.py
dame-cell Dec 22, 2024
68186e7
Merge branch 'main' into padding_free_dpo
dame-cell Dec 22, 2024
90f50e4
Merge branch 'main' into padding_free_dpo
dame-cell Jan 2, 2025
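
The commits above move the DPO data path off padded batches: the collator stops returning attention masks, and the trainer tracks per-sequence position ids instead. Below is a minimal sketch of that padding-free batching idea using toy token ids; the helper names `pad_based` and `padding_free` are illustrative only and are not part of TRL or this PR.

```python
import torch

# Two prompt/completion pairs of different lengths (toy token ids).
prompt_ids = [torch.tensor([101, 7592]), torch.tensor([101, 2129, 2024, 2017])]
completion_ids = [torch.tensor([2023, 2003, 102]), torch.tensor([2003, 102])]


def pad_based(prompts, completions, pad_token_id=0):
    """Pad every sequence to the batch maximum and return an attention mask."""
    seqs = [torch.cat((p, c)) for p, c in zip(prompts, completions)]
    max_len = max(len(s) for s in seqs)
    input_ids = torch.full((len(seqs), max_len), pad_token_id)
    attention_mask = torch.zeros((len(seqs), max_len), dtype=torch.long)
    for i, s in enumerate(seqs):
        input_ids[i, : len(s)] = s
        attention_mask[i, : len(s)] = 1
    return input_ids, attention_mask


def padding_free(prompts, completions):
    """Keep each sequence at its own length; positions replace the attention mask."""
    input_ids = [torch.cat((p, c)) for p, c in zip(prompts, completions)]
    position_ids = [torch.arange(len(ids)) for ids in input_ids]
    return input_ids, position_ids


padded_ids, mask = pad_based(prompt_ids, completion_ids)
free_ids, free_pos = padding_free(prompt_ids, completion_ids)
print(padded_ids.shape, int(mask.sum()))  # torch.Size([2, 6]) 11 -> one padded slot wasted
print([len(x) for x in free_ids])         # [5, 6] -> no pad tokens at all
```

In the padding-free path there are no pad tokens to mask out, which is what allows concatenated_forward in the diff below to drop the attention mask and rely on position ids alone.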
fix some small issue
dame-cell committed Dec 13, 2024
commit f1789f46f5aec9fdc9f625fe5c3a5998d823fd79
260 changes: 123 additions & 137 deletions trl/trainer/dpo_trainer.py
@@ -1220,36 +1220,35 @@ def concatenated_forward(self, model: nn.Module, batch: dict[str, Union[list, to
)
else:
# Padding-free processing
if self.padding_free:
input_ids = [torch.cat((p, c), dim=0) for p, c in zip(prompt_input_ids, completion_input_ids)]
position_ids = [torch.cat((p, c), dim=0) for p, c in zip(concatenated_batch['prompt_position_ids'], concatenated_batch['completion_position_ids'])]
loss_mask = [(pos >= 0).long() for pos in position_ids]

model_kwargs["position_ids"] = position_ids

# Apply max length truncation
if self.args.max_length is not None:
input_ids = [ids[:self.args.max_length] for ids in input_ids]
model_kwargs["position_ids"] = [pos[:self.args.max_length] for pos in model_kwargs["position_ids"]]
loss_mask = [mask[:self.args.max_length] for mask in loss_mask]

# Process each sequence individually
all_logits = []
all_labels = []
all_loss_masks = []

for ids, pos_ids in zip(input_ids, model_kwargs["position_ids"]):
single_output = model(
input_ids=ids.unsqueeze(0),
position_ids=pos_ids.unsqueeze(0),
)
all_logits.append(single_output.logits.squeeze(0))
all_labels.append(ids[1:])
all_loss_masks.append(torch.ones_like(ids[1:]).bool())
input_ids = [torch.cat((p, c), dim=0) for p, c in zip(prompt_input_ids, completion_input_ids)]
position_ids = [torch.cat((p, c), dim=0) for p, c in zip(concatenated_batch['prompt_position_ids'], concatenated_batch['completion_position_ids'])]
loss_mask = [(pos >= 0).long() for pos in position_ids]

model_kwargs["position_ids"] = position_ids

# Apply max length truncation
if self.args.max_length is not None:
input_ids = [ids[:self.args.max_length] for ids in input_ids]
model_kwargs["position_ids"] = [pos[:self.args.max_length] for pos in model_kwargs["position_ids"]]
loss_mask = [mask[:self.args.max_length] for mask in loss_mask]

logits = torch.cat(all_logits, dim=0)
labels = torch.cat(all_labels, dim=0)
loss_mask = torch.cat(all_loss_masks, dim=0)
# Process each sequence individually
all_logits = []
all_labels = []
all_loss_masks = []

for ids, pos_ids in zip(input_ids, model_kwargs["position_ids"]):
single_output = model(
input_ids=ids.unsqueeze(0),
position_ids=pos_ids.unsqueeze(0),
)
all_logits.append(single_output.logits.squeeze(0))
all_labels.append(ids[1:])
all_loss_masks.append(torch.ones_like(ids[1:]).bool())

logits = torch.cat(all_logits, dim=0)
labels = torch.cat(all_labels, dim=0)
loss_mask = torch.cat(all_loss_masks, dim=0)

else:
if self.padding_free:
@@ -1275,6 +1274,7 @@ def concatenated_forward(self, model: nn.Module, batch: dict[str, Union[list, to
single_output = model(
input_ids=ids.unsqueeze(0),
position_ids=pos_ids.unsqueeze(0),
**{k: v for k, v in model_kwargs.items() if k != "position_ids"}
)
all_logits.append(single_output.logits.squeeze(0))
all_labels.append(ids[1:])
@@ -1283,6 +1283,7 @@ def concatenated_forward(self, model: nn.Module, batch: dict[str, Union[list, to
logits = torch.cat(all_logits, dim=0)
labels = torch.cat(all_labels, dim=0)
loss_mask = torch.cat(all_loss_masks, dim=0)

else:
# Non-padding-free processing
input_ids = torch.cat((prompt_input_ids, completion_input_ids), dim=1)
@@ -1293,117 +1294,102 @@ def concatenated_forward(self, model: nn.Module, batch: dict[str, Union[list, to
(torch.zeros_like(prompt_attention_mask), completion_attention_mask),
dim=1)

# Flush left to reduce memory usage (only for non-padding-free case)
if not self.padding_free:
for i in range(attention_mask.size(0)):
first_one_idx = torch.nonzero(attention_mask[i])[0].item()
input_ids[i] = torch.roll(input_ids[i], shifts=-first_one_idx)
attention_mask[i] = torch.roll(attention_mask[i], shifts=-first_one_idx)
loss_mask[i] = torch.roll(loss_mask[i], shifts=-first_one_idx)

# Remove empty columns (only for non-padding-free case)
empty_cols = torch.sum(attention_mask, dim=0) == 0
first_empty_col = torch.nonzero(empty_cols)[0].item() if empty_cols.any() else attention_mask.size(1)
input_ids = input_ids[:, :first_empty_col]
attention_mask = attention_mask[:, :first_empty_col]
loss_mask = loss_mask[:, :first_empty_col]

#Apply max length truncation
if self.args.max_length is not None:
input_ids = input_ids[:, :self.args.max_length]
attention_mask = attention_mask[:, :self.args.max_length]
loss_mask = loss_mask[:, :self.args.max_length]

#forward pass
if self.padding_free:
all_logits = []
for ids, pos_ids in zip(input_ids, model_kwargs["position_ids"]):
single_output = model(
input_ids=ids.unsqueeze(0),
position_ids=pos_ids.unsqueeze(0),
**{k: v for k, v in model_kwargs.items() if k != "position_ids"}
)
all_logits.append(single_output.logits.squeeze(0))
logits = torch.cat(all_logits, dim=0)
else:
outputs = model(input_ids=input_ids, attention_mask=attention_mask, **model_kwargs)
logits = outputs.logits

if self.padding_free:
labels = torch.cat([ids[1:] for ids in input_ids])
loss_mask = torch.cat([mask[1:] for mask in loss_mask]).bool()
else:
labels = input_ids[:, 1:].clone()
loss_mask = loss_mask[:, 1:].bool()


if self.use_num_logits_to_keep:
first_compute_index = loss_mask.nonzero(as_tuple=True)[1].min()
num_logits_to_keep = loss_mask.shape[1] - first_compute_index
model_kwargs["num_logits_to_keep"] = num_logits_to_keep.item() + 1

# Common processing for both padding-free and non-padding-free cases
if self.use_num_logits_to_keep:
labels = labels[:, -num_logits_to_keep:]
loss_mask = loss_mask[:, -num_logits_to_keep:]

# Handle LLaVA-style models
if logits.shape[:2] != labels.shape[:2]:
seq_len = labels.shape[1]
logits = logits[:, -seq_len:]

# Compute log probabilities
labels[~loss_mask] = 0
per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
per_token_logps[~loss_mask] = 0
all_logps = per_token_logps.sum(-1)

output = {}

# Handle weighting if enabled
if self.use_weighting:
with torch.no_grad():
logprobs = F.log_softmax(logits, dim=-1)
weights_adjustment_factor = torch.logsumexp(2 * logprobs, dim=-1)
per_token_logps_adjusted = per_token_logps - weights_adjustment_factor
all_weights = (per_token_logps_adjusted * loss_mask).sum(-1) / loss_mask.sum(-1)
chosen_weights = all_weights[:num_examples]
rejected_weights = all_weights[num_examples:]
output["policy_weights"] = torch.clamp(torch.exp(chosen_weights + rejected_weights), max=1)

#Handle RPO loss if alpha is set
if self.args.rpo_alpha is not None:
chosen_logits = logits[:num_examples]
chosen_labels = labels[:num_examples]
output["nll_loss"] = F.cross_entropy(
torch.flatten(chosen_logits, end_dim=1),
torch.flatten(chosen_labels, end_dim=1),
ignore_index=0
)

# Normalize logps for IPO loss type
if self.loss_type == "ipo":
all_logps = all_logps / loss_mask.sum(-1)

# Split logps into chosen and rejected
output["chosen_logps"] = all_logps[:num_examples]
output["rejected_logps"] = all_logps[num_examples:]

# Calculate mean logits
output["mean_chosen_logits"] = logits[:num_examples][loss_mask[:num_examples]].mean()
output["mean_rejected_logits"] = logits[num_examples:][loss_mask[num_examples:]].mean()
# Flush left to reduce memory usage
for i in range(attention_mask.size(0)):
first_one_idx = torch.nonzero(attention_mask[i])[0].item()
input_ids[i] = torch.roll(input_ids[i], shifts=-first_one_idx)
attention_mask[i] = torch.roll(attention_mask[i], shifts=-first_one_idx)
loss_mask[i] = torch.roll(loss_mask[i], shifts=-first_one_idx)

# Add auxiliary loss if enabled
if self.aux_loss_enabled:
output["aux_loss"] = outputs.aux_loss
# Remove empty columns
empty_cols = torch.sum(attention_mask, dim=0) == 0
first_empty_col = torch.nonzero(empty_cols)[0].item() if empty_cols.any() else attention_mask.size(1)
input_ids = input_ids[:, :first_empty_col]
attention_mask = attention_mask[:, :first_empty_col]
loss_mask = loss_mask[:, :first_empty_col]

# Ensure all tensors that need gradients have them
for key in output:
if isinstance(output[key], torch.Tensor):
if not output[key].requires_grad and output[key].is_floating_point():
output[key] = output[key].requires_grad_(True)
# Apply max length truncation
if self.args.max_length is not None:
input_ids = input_ids[:, :self.args.max_length]
attention_mask = attention_mask[:, :self.args.max_length]
loss_mask = loss_mask[:, :self.args.max_length]

# Forward pass
outputs = model(input_ids=input_ids, attention_mask=attention_mask, **model_kwargs)
logits = outputs.logits

labels = input_ids[:, 1:].clone()
loss_mask = loss_mask[:, 1:].bool()

# Common processing for both padding-free and non-padding-free cases
if self.use_num_logits_to_keep:
first_compute_index = loss_mask.nonzero(as_tuple=True)[1].min()
num_logits_to_keep = loss_mask.shape[1] - first_compute_index
model_kwargs["num_logits_to_keep"] = num_logits_to_keep.item() + 1

# Truncate labels and loss mask
labels = labels[:, -num_logits_to_keep:]
loss_mask = loss_mask[:, -num_logits_to_keep:]

# Handle LLaVA-style models
if logits.shape[:2] != labels.shape[:2]:
seq_len = labels.shape[1]
logits = logits[:, -seq_len:]

# Compute log probabilities
labels[~loss_mask] = 0
per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
per_token_logps[~loss_mask] = 0
all_logps = per_token_logps.sum(-1)

output = {}

# Handle weighting if enabled
if self.use_weighting:
with torch.no_grad():
logprobs = F.log_softmax(logits, dim=-1)
weights_adjustment_factor = torch.logsumexp(2 * logprobs, dim=-1)
per_token_logps_adjusted = per_token_logps - weights_adjustment_factor
all_weights = (per_token_logps_adjusted * loss_mask).sum(-1) / loss_mask.sum(-1)
chosen_weights = all_weights[:num_examples]
rejected_weights = all_weights[num_examples:]
output["policy_weights"] = torch.clamp(torch.exp(chosen_weights + rejected_weights), max=1)

# Handle RPO loss if alpha is set
if self.args.rpo_alpha is not None:
chosen_logits = logits[:num_examples]
chosen_labels = labels[:num_examples]
output["nll_loss"] = F.cross_entropy(
torch.flatten(chosen_logits, end_dim=1),
torch.flatten(chosen_labels, end_dim=1),
ignore_index=0
)

# Normalize logps for IPO loss type
if self.loss_type == "ipo":
all_logps = all_logps / loss_mask.sum(-1)

# Split logps into chosen and rejected
output["chosen_logps"] = all_logps[:num_examples]
output["rejected_logps"] = all_logps[num_examples:]

# Calculate mean logits
output["mean_chosen_logits"] = logits[:num_examples][loss_mask[:num_examples]].mean()
output["mean_rejected_logits"] = logits[num_examples:][loss_mask[num_examples:]].mean()

# Add auxiliary loss if enabled
if self.aux_loss_enabled:
output["aux_loss"] = outputs.aux_loss

# Ensure all tensors that need gradients have them
for key in output:
if isinstance(output[key], torch.Tensor):
if not output[key].requires_grad and output[key].is_floating_point():
output[key] = output[key].requires_grad_(True)

return output

return output



def get_batch_loss_metrics(
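
For reference, here is a minimal sketch of the per-sequence, attention-mask-free forward pass that the padding-free branch of concatenated_forward performs in the diff above: each ragged sequence is run through the model with its own explicit position_ids, and per-token log-probabilities are gathered against next-token labels. `TinyLM` is a stand-in model for illustration only; it is not TRL code, and the real trainer uses the policy model together with the prompt/completion loss masking shown in the diff.

```python
import torch
import torch.nn as nn


class TinyLM(nn.Module):
    """Tiny stand-in language model: token + position embeddings and a vocab head."""

    def __init__(self, vocab_size=128, hidden=16, max_pos=512):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)
        self.pos = nn.Embedding(max_pos, hidden)
        self.head = nn.Linear(hidden, vocab_size)

    def forward(self, input_ids, position_ids):
        hidden = self.embed(input_ids) + self.pos(position_ids)
        return self.head(hidden)  # (batch, seq_len, vocab_size)


model = TinyLM()
sequences = [torch.randint(0, 128, (n,)) for n in (5, 9, 7)]  # ragged batch, no padding

all_logps = []
for ids in sequences:
    position_ids = torch.arange(len(ids))                # positions restart per sequence
    logits = model(ids.unsqueeze(0), position_ids.unsqueeze(0)).squeeze(0)
    labels = ids[1:]                                      # next-token targets
    token_logps = logits[:-1].log_softmax(-1)             # logits at t predict token t + 1
    per_token = token_logps.gather(1, labels.unsqueeze(1)).squeeze(1)
    all_logps.append(per_token.sum())                     # summed log-prob per sequence

print(torch.stack(all_logps))  # shape (3,), one log-probability per sequence
```

Running each sequence individually trades some batching efficiency for the ability to skip pad tokens and attention masks entirely, which is the trade-off this PR's padding-free mode makes inside concatenated_forward.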