
Commit

allow for belief state wrapper to have different loss weight for reverse autoregressive loss
lucidrains committed Mar 7, 2025
1 parent c8d6913 commit 51139ef
Showing 4 changed files with 40 additions and 11 deletions.
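
For context, a minimal usage sketch of the two new options, adapted from the test added in this commit (the decoder hyperparameters and sample sequence below are illustrative, not taken verbatim from the diff):

import torch
from x_transformers import TransformerWrapper, Decoder
from x_transformers.belief_state_wrapper import BeliefStateWrapper

# forward and backward decoders; the wrapper asserts they share the same embedding dimension
forward_model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)

backward_model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)

model = BeliefStateWrapper(
    forward_decoder = forward_model,
    backward_decoder = backward_model,
    backward_ar_loss_weight = 0.5  # new: weigh the backwards autoregressive loss at half the forward loss
)

seq = torch.randint(0, 20000, (2, 16))

# backward = False (also new) skips the internal backward pass,
# so the caller invokes .backward() on the returned loss instead
loss = model(seq, backward = False)
loss.backward()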
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "x-transformers"
version = "2.1.5"
version = "2.1.6"
description = "X-Transformers"
authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
8 changes: 5 additions & 3 deletions tests/test_x_transformers.py
@@ -695,7 +695,7 @@ def test_lime(
model(x)

def test_belief_state_wrapper():
from x_transformers.belief_state import BeliefStateWrapper
from x_transformers.belief_state_wrapper import BeliefStateWrapper

forward_model = TransformerWrapper(
num_tokens = 20000,
@@ -721,9 +721,11 @@ def test_belief_state_wrapper():

model = BeliefStateWrapper(
forward_decoder = forward_model,
backward_decoder = backward_model
backward_decoder = backward_model,
backward_ar_loss_weight = 0.5
)

seq = torch.randint(0, 20000, (2, 16))

loss = model(seq)
loss = model(seq, backward = False)
loss.backward()
1 change: 1 addition & 0 deletions x_transformers/__init__.py
@@ -14,6 +14,7 @@

from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
from x_transformers.nonautoregressive_wrapper import NonAutoregressiveWrapper
from x_transformers.belief_state_wrapper import BeliefStateWrapper

from x_transformers.continuous import (
ContinuousTransformerWrapper,
40 changes: 33 additions & 7 deletions x_transformers/belief_state_wrapper.py
@@ -6,7 +6,7 @@
import torch
from torch.autograd import Function
from torch.nn import Module, ModuleList
from torch import nn, cat, stack, arange, cartesian_prod
from torch import nn, cat, stack, tensor, arange, cartesian_prod
import torch.nn.functional as F

from x_transformers.x_transformers import (
@@ -36,7 +36,8 @@ def __init__(
self,
forward_decoder: TransformerWrapper,
backward_decoder: TransformerWrapper,
train_frac_forward_backward_pairs: float = 1.
train_frac_forward_backward_pairs: float = 1.,
backward_ar_loss_weight: float = 1. # can weigh the training of the backwards decoder differently, e.g. if the forward and backward decoders share a backbone
):
super().__init__()
assert forward_decoder.emb_dim == backward_decoder.emb_dim, 'forward and backwards model must have the same embedding dimension'
@@ -70,9 +71,17 @@ def __init__(
self.train_frac_fb_pairs = train_frac_forward_backward_pairs
self.needs_subsample_fb_pairs = train_frac_forward_backward_pairs < 1.

# loss weighting

self.backward_ar_loss_weight = backward_ar_loss_weight
self.needs_loss_weight = backward_ar_loss_weight != 1.

self.register_buffer('loss_weights', tensor([1., self.backward_ar_loss_weight]))

def forward(
self,
seq
seq,
backward = True
):
batch, seq_len, device = *seq.shape, seq.device

@@ -149,14 +158,31 @@ def forward(

fb_loss = F.cross_entropy(
rearrange(logits, 'b n (fb l) -> b l (fb n)', fb = 2),
rearrange(labels, 'b n fb -> b (fb n)')
rearrange(labels, 'b n fb -> b (fb n)'),
reduction = 'none' if self.needs_loss_weight else 'mean'
)

# maybe loss weighting

if self.needs_loss_weight:
fb_loss = rearrange(fb_loss, 'b (fb n) -> b fb n', fb = 2)
fb_loss = fb_loss * rearrange(self.loss_weights, 'fb -> fb 1')
fb_loss = fb_loss.mean()

# backwards

fb_loss.backward()
orig_backward = getattr(fb_loss, 'backward')

def patched_backward_fn(*args, **kwargs):
orig_backward(*args, **kwargs)
orig_forward_embeds.backward(forward_embeds.grad)
orig_backward_embeds.backward(backward_embeds.grad)

# can allow the researcher to call .backward from the outside

orig_forward_embeds.backward(forward_embeds.grad)
orig_backward_embeds.backward(backward_embeds.grad)
if backward:
patched_backward_fn()
else:
setattr(fb_loss, 'backward', patched_backward_fn)

return fb_loss
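
As an aside, a small standalone sketch of how the loss weighting above behaves: with reduction = 'none' the cross entropy keeps a per-position loss over the concatenated forward / backward targets, which is split back into the two streams, scaled by the registered loss_weights, and then averaged. All shapes and values below are illustrative only:

import torch
import torch.nn.functional as F
from einops import rearrange

b, n, num_tokens = 2, 8, 100               # illustrative sizes
loss_weights = torch.tensor([1., 0.5])     # [forward weight, backward weight]

logits = torch.randn(b, 2 * n, num_tokens)         # forward and backward predictions, concatenated
labels = torch.randint(0, num_tokens, (b, 2 * n))

per_position = F.cross_entropy(
    rearrange(logits, 'b fn l -> b l fn'),
    labels,
    reduction = 'none'
)                                                   # shape (b, 2n)

per_position = rearrange(per_position, 'b (fb n) -> b fb n', fb = 2)
weighted = per_position * rearrange(loss_weights, 'fb -> fb 1')
loss = weighted.mean()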
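
Similarly, a tiny sketch (plain torch, not the library API) of the detach-and-resume gradient pattern that patched_backward_fn relies on: the wrapper back-propagates the detached embeddings' gradients into the original graph via orig_forward_embeds.backward(forward_embeds.grad), which in isolation looks like this:

import torch

x = torch.randn(3, requires_grad = True)
orig_embeds = x * 2                              # still connected to x in the autograd graph

embeds = orig_embeds.detach().requires_grad_()   # cut the graph; gradients accumulate here
loss = embeds.sum()

loss.backward()                                  # fills embeds.grad, stops at the detach point
orig_embeds.backward(embeds.grad)                # resume backprop into the original graph

print(x.grad)                                    # tensor([2., 2., 2.])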
