Fix pix2struct #34374

Merged · 5 commits · Oct 28, 2024
src/transformers/models/pix2struct/modeling_pix2struct.py (60 changes: 35 additions & 25 deletions)
@@ -762,11 +762,14 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets
        return relative_buckets

    # Adapted from transformers.models.t5.modeling_t5.T5Attention.compute_bias
-    def compute_bias(self, query_length, key_length, device=None):
+    def compute_bias(self, query_length, key_length, device=None, cache_position=None):
        """Compute binned relative position bias"""
        if device is None:
            device = self.relative_attention_bias.weight.device
-        context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
+        if cache_position is None:
+            context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
+        else:
+            context_position = cache_position[:, None].to(device)
        memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
        relative_position = memory_position - context_position  # shape (query_length, key_length)
        relative_position_bucket = self._relative_position_bucket(
@@ -779,6 +782,7 @@ def compute_bias(self, query_length, key_length, device=None):
        values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
        return values

+    # Adapted from transformers.models.t5.modeling_t5.T5Attention.forward
    def forward(
        self,
        hidden_states,
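The compute_bias change above is the core of the fix: during cached decoding the decoder processes one new token per step, so the query positions can no longer be derived from query_length alone; cache_position carries the absolute position of each query token. A minimal standalone sketch of the idea (bucketing and the bias embedding omitted; this is not the library code):

import torch

def relative_positions(query_length, key_length, cache_position=None):
    # Rows index query tokens, columns index (possibly cached) key tokens.
    if cache_position is None:
        # No cache: queries occupy positions 0 .. query_length - 1.
        context_position = torch.arange(query_length, dtype=torch.long)[:, None]
    else:
        # Cached decoding: queries sit at the absolute positions in cache_position.
        context_position = cache_position[:, None]
    memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
    return memory_position - context_position  # shape (query_length, key_length)

# Decoding step 4: one query token at absolute position 3, four cached keys.
print(relative_positions(1, 4))                                    # tensor([[0, 1, 2, 3]]), as if the query sat at position 0
print(relative_positions(1, 4, cache_position=torch.tensor([3])))  # tensor([[-3, -2, -1,  0]]), the correct distances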
@@ -796,61 +800,66 @@ def forward(
        Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
        """
        # Input is (batch_size, seq_length, dim)
-        # Mask is (batch_size, 1, 1, key_length) (non-causal) or (batch_size, 1, query_length, key_length)
+        # Mask is (batch_size, 1, 1, key_length) (non-causal) or (batch_size, 1, seq_length, key_length) (causal decoder)
        batch_size, seq_length = hidden_states.shape[:2]

        # if key_value_states are provided this layer is used as a cross-attention layer for the decoder
        is_cross_attention = key_value_states is not None

-        query_states = self.query(hidden_states).contiguous()
+        query_states = self.query(hidden_states)
        query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)

        if past_key_value is not None:
            is_updated = past_key_value.is_updated.get(self.layer_idx)
            if is_cross_attention:
                # after the first generated id, we can subsequently re-use all key/value_states from cache
-                past_key_value = past_key_value.cross_attention_cache
+                curr_past_key_value = past_key_value.cross_attention_cache
            else:
-                past_key_value = past_key_value.self_attention_cache
+                curr_past_key_value = past_key_value.self_attention_cache

        # get key/value states
        current_states = key_value_states if is_cross_attention else hidden_states
        if is_cross_attention and past_key_value and is_updated:
            # reuse k,v, cross_attentions
-            key_states = past_key_value.key_cache[self.layer_idx]
-            value_states = past_key_value.value_cache[self.layer_idx]
+            key_states = curr_past_key_value.key_cache[self.layer_idx]
+            value_states = curr_past_key_value.value_cache[self.layer_idx]
        else:
-            key_states = self.key(current_states).contiguous()
-            value_states = self.value(current_states).contiguous()
+            key_states = self.key(current_states)
+            value_states = self.value(current_states)
            key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
            value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)

            if past_key_value is not None:
                # save all key/value_states to cache to be re-used for fast auto-regressive generation
                cache_position = cache_position if not is_cross_attention else None
-                key_states, value_states = past_key_value.update(
+                key_states, value_states = curr_past_key_value.update(
                    key_states, value_states, self.layer_idx, {"cache_position": cache_position}
                )
                # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
                if is_cross_attention:
                    past_key_value.is_updated[self.layer_idx] = True

-        # compute scores
+        # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
        scores = torch.matmul(query_states, key_states.transpose(3, 2))

        if position_bias is None:
-            real_seq_length = cache_position[-1] + 1 if query_length is None else query_length
-            key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]
+            key_length = key_states.shape[-2]
+            # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past)
+            real_seq_length = query_length if query_length is not None else cache_position[-1] + 1
            if not self.has_relative_attention_bias:
                position_bias = torch.zeros(
-                    (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype
+                    (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype
                )
                if self.gradient_checkpointing and self.training:
                    position_bias.requires_grad = True
            else:
-                position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device)
+                position_bias = self.compute_bias(
+                    real_seq_length, key_length, device=scores.device, cache_position=cache_position
+                )
+                position_bias = position_bias[:, :, -seq_length:, :]

            if mask is not None:
-                position_bias = position_bias + mask  # (batch_size, n_heads, seq_length, key_length)
+                causal_mask = mask[:, :, :, : key_states.shape[-2]]
+                position_bias = position_bias + causal_mask

        if self.pruned_heads:
            mask = torch.ones(position_bias.shape[1])
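Two bookkeeping details in the hunk above are easy to miss: the relative-position bias is computed for all real_seq_length query positions and then sliced down to the seq_length rows belonging to the tokens in this forward pass, and the mask is trimmed to the live key length because it can be pre-allocated wider (for example by a static cache). A shapes-only sketch with invented sizes:

import torch

n_heads, real_seq_length, key_length, seq_length = 2, 5, 5, 1  # fifth decoding step, one new token

# Bias over every query position seen so far; keep only the rows for the current tokens.
position_bias = torch.zeros(1, n_heads, real_seq_length, key_length)
position_bias = position_bias[:, :, -seq_length:, :]
assert position_bias.shape == (1, n_heads, seq_length, key_length)

# Mask pre-allocated to an assumed maximum length of 8, trimmed to the keys actually present.
mask = torch.zeros(1, 1, seq_length, 8)
causal_mask = mask[:, :, :, :key_length]
assert (position_bias + causal_mask).shape == (1, n_heads, seq_length, key_length)  # broadcasts over heads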
@@ -860,23 +869,22 @@ def forward(
            position_bias_masked = position_bias

        scores += position_bias_masked
-        # (batch_size, n_heads, seq_length, key_length)
-        attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)

+        # (batch_size, n_heads, seq_length, key_length)
+        attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
        attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        # Mask heads if we want to
        if layer_head_mask is not None:
            attn_weights = attn_weights * layer_head_mask

        attn_output = torch.matmul(attn_weights, value_states)
-        # (batch_size, seq_length, dim)
-        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)

+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.view(batch_size, -1, self.inner_dim)
        attn_output = self.output(attn_output)

-        outputs = (attn_output,) + (past_key_value,) + (position_bias,)
+        outputs = (attn_output, past_key_value, position_bias)

        if output_attentions:
            outputs = outputs + (attn_weights,)
@@ -969,7 +977,10 @@ def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optiona
            layer_idx=layer_idx,
        )

-        self.encoder_decoder_attention = Pix2StructTextLayerCrossAttention(config)
+        self.encoder_decoder_attention = Pix2StructTextLayerCrossAttention(
+            config,
+            layer_idx=layer_idx,
+        )

        self.mlp = Pix2StructTextLayerFF(config)
@@ -1019,7 +1030,6 @@ def forward(
                query_length=cache_position[-1] + 1,
                use_cache=use_cache,
                output_attentions=output_attentions,
-                cache_position=cache_position,
            )
            hidden_states, past_key_value = cross_attention_outputs[:2]
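Taken together, the changes to this file are meant to make cached and uncached decoding produce identical ids. A hedged end-to-end sketch of the behavior being fixed (the checkpoint and image URL are illustrative choices, not taken from the PR):

import requests
import torch
from PIL import Image
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor

processor = Pix2StructProcessor.from_pretrained("google/pix2struct-textcaps-base")
model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-textcaps-base").eval()

url = "https://www.ilankelman.org/stopsigns/australia.jpg"  # any test image works
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(images=image, return_tensors="pt")

with torch.no_grad():
    # Force the same number of new tokens so the two outputs are directly comparable.
    out_cached = model.generate(**inputs, use_cache=True, min_new_tokens=10, max_new_tokens=10)
    out_uncached = model.generate(**inputs, use_cache=False, min_new_tokens=10, max_new_tokens=10)

torch.testing.assert_close(out_cached, out_uncached)  # diverged before this fix
print(processor.batch_decode(out_cached, skip_special_tokens=True))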
tests/models/pix2struct/test_modeling_pix2struct.py (11 changes: 11 additions & 0 deletions)
@@ -419,6 +419,7 @@ def prepare_config_and_inputs_for_common(self):
@require_torch
class Pix2StructModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
    all_model_classes = (Pix2StructForConditionalGeneration,) if is_torch_available() else ()
+    all_generative_model_classes = (Pix2StructForConditionalGeneration,) if is_torch_available() else {}
    pipeline_model_mapping = {"image-to-text": Pix2StructForConditionalGeneration} if is_torch_available() else {}
    fx_compatible = False
    test_head_masking = False
@@ -445,6 +446,16 @@ def test_model(self):
            ),
        )

+    def test_generative_model(self):
+        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        for model_class in self.all_generative_model_classes:
+            model = model_class(config).eval().to(torch_device)
+
+            output = model.generate(**input_dict, use_cache=False, min_new_tokens=10, max_new_tokens=10)
+            output_use_cache = model.generate(**input_dict, use_cache=True, min_new_tokens=10, max_new_tokens=10)
+
+            torch.testing.assert_close(output, output_use_cache)
+
    @unittest.skip(reason="Hidden_states is tested in individual model tests")
    def test_hidden_states_output(self):
        pass
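To run just this check locally, something like python -m pytest tests/models/pix2struct/test_modeling_pix2struct.py -k test_generative_model from the repository root should select it; the test asserts that greedy generation yields identical ids with and without the KV cache.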