Fix pix2struct #34374

Merged: 5 commits merged on Oct 28, 2024

Changes from 1 commit
27 changes: 16 additions & 11 deletions src/transformers/models/pix2struct/modeling_pix2struct.py
@@ -809,16 +809,16 @@ def forward(
             is_updated = past_key_value.is_updated.get(self.layer_idx)
             if is_cross_attention:
                 # after the first generated id, we can subsequently re-use all key/value_states from cache
-                past_key_value = past_key_value.cross_attention_cache
+                curr_past_key_value = past_key_value.cross_attention_cache
             else:
-                past_key_value = past_key_value.self_attention_cache
+                curr_past_key_value = past_key_value.self_attention_cache

         # get key/value states
         current_states = key_value_states if is_cross_attention else hidden_states
         if is_cross_attention and past_key_value and is_updated:
             # reuse k,v, cross_attentions
-            key_states = past_key_value.key_cache[self.layer_idx]
-            value_states = past_key_value.value_cache[self.layer_idx]
+            key_states = curr_past_key_value.key_cache[self.layer_idx]
+            value_states = curr_past_key_value.value_cache[self.layer_idx]
         else:
             key_states = self.key(current_states).contiguous()
             value_states = self.value(current_states).contiguous()
@@ -827,7 +827,7 @@ def forward(
             if past_key_value is not None:
                 # save all key/value_states to cache to be re-used for fast auto-regressive generation
                 cache_position = cache_position if not is_cross_attention else None
-                key_states, value_states = past_key_value.update(
+                key_states, value_states = curr_past_key_value.update(
                     key_states, value_states, self.layer_idx, {"cache_position": cache_position}
                 )
                 # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
@@ -836,21 +836,23 @@

         # compute scores
         scores = torch.matmul(query_states, key_states.transpose(3, 2))

         if position_bias is None:
-            real_seq_length = cache_position[-1] + 1 if query_length is None else query_length
-            key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]
+            key_length = key_states.shape[-2]
+            # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past)
+            real_seq_length = query_length if query_length is not None else cache_position[-1] + 1
             if not self.has_relative_attention_bias:
                 position_bias = torch.zeros(
-                    (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype
+                    (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype
                 )
                 if self.gradient_checkpointing and self.training:
                     position_bias.requires_grad = True
             else:
                 position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device)
+                position_bias = position_bias[:, :, -seq_length:, :]

             if mask is not None:
-                position_bias = position_bias + mask  # (batch_size, n_heads, seq_length, key_length)
+                causal_mask = mask[:, :, :, : key_states.shape[-2]]
+                position_bias = position_bias + causal_mask

         if self.pruned_heads:
             mask = torch.ones(position_bias.shape[1])
@@ -969,7 +971,10 @@ def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None):
             layer_idx=layer_idx,
         )

-        self.encoder_decoder_attention = Pix2StructTextLayerCrossAttention(config)
+        self.encoder_decoder_attention = Pix2StructTextLayerCrossAttention(
+            config,
+            layer_idx=layer_idx,
+        )

         self.mlp = Pix2StructTextLayerFF(config)

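For context, a minimal sketch (not code from the diff) of why the change keeps the outer cache under its original name and binds the sub-cache to `curr_past_key_value`: the outer `EncoderDecoderCache` owns both sub-caches plus the per-layer `is_updated` flags, so shadowing `past_key_value` with a sub-cache would lose the handle needed for the later `is_updated` bookkeeping. The sketch assumes the `EncoderDecoderCache`/`DynamicCache` API from `transformers.cache_utils` around the version this PR targets, and uses toy tensor shapes.

```python
import torch
from transformers.cache_utils import DynamicCache, EncoderDecoderCache

# Outer cache: holds the self-attention cache, the cross-attention cache,
# and the per-layer `is_updated` flags used by the modeling code.
past_key_value = EncoderDecoderCache(DynamicCache(), DynamicCache())
layer_idx, is_cross_attention = 0, True

# Pick the relevant sub-cache under a *new* name; overwriting `past_key_value`
# here (the old behaviour) would drop the reference to the outer cache.
curr_past_key_value = (
    past_key_value.cross_attention_cache if is_cross_attention else past_key_value.self_attention_cache
)

# Toy key/value states: (batch, n_heads, seq_len, head_dim)
key_states = torch.randn(1, 12, 4, 64)
value_states = torch.randn(1, 12, 4, 64)
key_states, value_states = curr_past_key_value.update(key_states, value_states, layer_idx)

# Still possible because the outer cache object was never shadowed.
if is_cross_attention:
    past_key_value.is_updated[layer_idx] = True
```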
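Similarly, a toy-shapes sketch (assumptions only, not code from the PR) of the position-bias part of the fix: during cached decoding only `seq_length` new queries are scored against `key_length` cached-plus-new keys, so the relative-attention bias computed for all `real_seq_length` queries is cropped to its last `seq_length` rows, and the attention mask is cropped to the actual key length before being added.

```python
import torch

n_heads = 12
past_length, seq_length = 5, 1              # decoding one new token with 5 cached tokens
real_seq_length = past_length + seq_length  # i.e. cache_position[-1] + 1
key_length = real_seq_length                # self-attention: cached + new keys

# Bias over all queries/keys, then keep only the rows for the current queries.
full_bias = torch.randn(1, n_heads, real_seq_length, key_length)
position_bias = full_bias[:, :, -seq_length:, :]   # (1, n_heads, seq_length, key_length)

# The mask may be pre-allocated longer than the current key length; crop it to match.
mask = torch.zeros(1, 1, seq_length, key_length + 3)
causal_mask = mask[:, :, :, :key_length]           # (1, 1, seq_length, key_length)

scores_bias = position_bias + causal_mask          # broadcasts over the head dimension
print(position_bias.shape, causal_mask.shape, scores_bias.shape)
```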