Open
Hey, nice work!
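Is this intended or a bug? I would expect `min` and `mean` to be applied to the variance directly, as is done for `sum`. If I read the snippet correctly, those two branches pool the stacked logits element-wise, so the second channel (which the `sum` branch treats as log-sigma) is averaged or minimized in log space.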
EditReward/EditReward/model/qwen2_5_vl_trainer.py
Lines 384 to 397 in 49ea57a
```python
if self.pooling_strategy == "min":
    final_logits = torch.min(torch.stack(logits_per_head, dim=0), dim=0).values
elif self.pooling_strategy == "mean":
    final_logits = torch.mean(torch.stack(logits_per_head, dim=0), dim=0)
elif self.pooling_strategy == "sum":
    means = stacked[:, :, 0]  # [num_heads, B]
    sigmas = torch.exp(stacked[:, :, 1])  # [num_heads, B]
    final_mean = means.sum(dim=0)  # [B]
    final_var = (sigmas ** 2).sum(dim=0)  # [B]
    final_sigma = torch.sqrt(final_var)  # [B]
    final_logits = torch.stack([final_mean, torch.log(final_sigma)], dim=-1)  # [B, 2]
else:
    final_logits = stacked.mean(dim=0)
```
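To make the question concrete, here is a minimal sketch in plain Python (no torch, one sample, two heads) that compares what the `mean` and `min` branches appear to compute with one possible variance-based reading. The function names and the example values are hypothetical and not from the repository. The variance-based `mean` assumes independent heads, and "keep the sigma of the head with the smallest mean" is only one way `min` could be interpreted.

```python
import math

# Each head is a (mean, log_sigma) pair, matching the [.., 2] layout in the snippet.

def pool_current_mean(heads):
    """Element-wise mean over heads, as the quoted `mean` branch appears to do."""
    k = len(heads)
    return (sum(m for m, _ in heads) / k, sum(s for _, s in heads) / k)

def pool_variance_mean(heads):
    """Hypothetical alternative: average of independent Gaussians,
    so the variance is sum(sigma^2) / K^2."""
    k = len(heads)
    mean = sum(m for m, _ in heads) / k
    var = sum(math.exp(s) ** 2 for _, s in heads) / k**2
    return (mean, math.log(math.sqrt(var)))

def pool_current_min(heads):
    """Element-wise min: the mean and log_sigma can come from different heads."""
    return (min(m for m, _ in heads), min(s for _, s in heads))

def pool_select_min(heads):
    """Hypothetical alternative: pick the head with the smallest mean
    and keep that head's own log_sigma."""
    return min(heads, key=lambda h: h[0])

# Assumed example values: head 0 has sigma = 2.0, head 1 has sigma = 0.5.
heads = [(1.0, math.log(2.0)), (3.0, math.log(0.5))]

cur_mean = pool_current_mean(heads)
var_mean = pool_variance_mean(heads)
print("mean, log-space sigma:     ", math.exp(cur_mean[1]))  # geometric mean of the sigmas
print("mean, variance-based sigma:", math.exp(var_mean[1]))
print("min, element-wise:", pool_current_min(heads))  # sigma taken from head 1
print("min, per-head:    ", pool_select_min(heads))   # sigma taken from head 0
```

With these values the log-space mean yields the geometric mean of the sigmas, while the element-wise min pairs the mean of head 0 with the log-sigma of head 1. Whether either behavior is intended is exactly the question.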
Happy to contribute a change from my fork: https://github.com/affromero/EditReward/, which also includes a pyproject.toml for easy installation with uv.