Fix pix2struct (#34374)
* fix

* fix and test use_cache test

* style

* remove atol
IlyasMoutawwakil authored Oct 28, 2024
1 parent 1d06379 commit fddbd3c
Showing 2 changed files with 46 additions and 25 deletions.
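
The core of the fix is threading `cache_position` into `compute_bias`, so that during cached generation the relative position bias is built only for the query positions currently being decoded rather than for the full sequence. Below is a minimal sketch of that shape logic with the bucketing and learned embedding stubbed out; the function name and defaults are illustrative, not part of the library.

```python
# Minimal sketch of the cache-aware bias computation introduced by this commit.
# Only the cache_position handling mirrors the diff; the real method feeds the
# relative positions through _relative_position_bucket and a learned embedding.
import torch

def compute_bias_sketch(query_length, key_length, num_heads=4, cache_position=None, device="cpu"):
    if cache_position is None:
        # Full prefill: query positions are simply 0..query_length-1.
        context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
    else:
        # Cached decoding: only the positions of the tokens being processed in this step.
        context_position = cache_position[:, None].to(device)
    memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
    relative_position = memory_position - context_position  # (num_queries, key_length)
    # Stand-in for the learned bias; shape (1, num_heads, num_queries, key_length).
    return torch.zeros(1, num_heads, relative_position.shape[0], key_length)

# Prefill over 5 tokens, then one cached decoding step for the 6th token only.
prefill_bias = compute_bias_sketch(5, 5)                                 # (1, 4, 5, 5)
step_bias = compute_bias_sketch(6, 6, cache_position=torch.tensor([5]))  # (1, 4, 1, 6)
```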
60 changes: 35 additions & 25 deletions src/transformers/models/pix2struct/modeling_pix2struct.py
@@ -762,11 +762,14 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets
return relative_buckets

# Adapted from transformers.models.t5.modeling_t5.T5Attention.compute_bias
def compute_bias(self, query_length, key_length, device=None):
def compute_bias(self, query_length, key_length, device=None, cache_position=None):
"""Compute binned relative position bias"""
if device is None:
device = self.relative_attention_bias.weight.device
context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
if cache_position is None:
context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
else:
context_position = cache_position[:, None].to(device)
memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
relative_position = memory_position - context_position # shape (query_length, key_length)
relative_position_bucket = self._relative_position_bucket(
@@ -779,6 +782,7 @@ def compute_bias(self, query_length, key_length, device=None):
values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length)
return values

# Adapted from transformers.models.t5.modeling_t5.T5Attention.forward
def forward(
self,
hidden_states,
@@ -796,61 +800,66 @@ def forward(
Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
"""
# Input is (batch_size, seq_length, dim)
# Mask is (batch_size, 1, 1, key_length) (non-causal) or (batch_size, 1, query_length, key_length)
# Mask is (batch_size, 1, 1, key_length) (non-causal) or (batch_size, 1, seq_length, key_length) (causal decoder)
batch_size, seq_length = hidden_states.shape[:2]

# if key_value_states are provided this layer is used as a cross-attention layer for the decoder
is_cross_attention = key_value_states is not None

query_states = self.query(hidden_states).contiguous()
query_states = self.query(hidden_states)
query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)

if past_key_value is not None:
is_updated = past_key_value.is_updated.get(self.layer_idx)
if is_cross_attention:
# after the first generated id, we can subsequently re-use all key/value_states from cache
past_key_value = past_key_value.cross_attention_cache
curr_past_key_value = past_key_value.cross_attention_cache
else:
past_key_value = past_key_value.self_attention_cache
curr_past_key_value = past_key_value.self_attention_cache

# get key/value states
current_states = key_value_states if is_cross_attention else hidden_states
if is_cross_attention and past_key_value and is_updated:
# reuse k,v, cross_attentions
key_states = past_key_value.key_cache[self.layer_idx]
value_states = past_key_value.value_cache[self.layer_idx]
key_states = curr_past_key_value.key_cache[self.layer_idx]
value_states = curr_past_key_value.value_cache[self.layer_idx]
else:
key_states = self.key(current_states).contiguous()
value_states = self.value(current_states).contiguous()
key_states = self.key(current_states)
value_states = self.value(current_states)
key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)

if past_key_value is not None:
# save all key/value_states to cache to be re-used for fast auto-regressive generation
cache_position = cache_position if not is_cross_attention else None
key_states, value_states = past_key_value.update(
key_states, value_states = curr_past_key_value.update(
key_states, value_states, self.layer_idx, {"cache_position": cache_position}
)
# set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
if is_cross_attention:
past_key_value.is_updated[self.layer_idx] = True

# compute scores
# compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
scores = torch.matmul(query_states, key_states.transpose(3, 2))

if position_bias is None:
real_seq_length = cache_position[-1] + 1 if query_length is None else query_length
key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]
key_length = key_states.shape[-2]
# cache position is 0-indexed so we add 1 to get the real length of queries (aka with past)
real_seq_length = query_length if query_length is not None else cache_position[-1] + 1
if not self.has_relative_attention_bias:
position_bias = torch.zeros(
(1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype
(1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype
)
if self.gradient_checkpointing and self.training:
position_bias.requires_grad = True
else:
position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device)
position_bias = self.compute_bias(
real_seq_length, key_length, device=scores.device, cache_position=cache_position
)
position_bias = position_bias[:, :, -seq_length:, :]

if mask is not None:
position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length)
causal_mask = mask[:, :, :, : key_states.shape[-2]]
position_bias = position_bias + causal_mask

if self.pruned_heads:
mask = torch.ones(position_bias.shape[1])
@@ -860,23 +869,22 @@ def forward(
position_bias_masked = position_bias

scores += position_bias_masked
# (batch_size, n_heads, seq_length, key_length)
attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)

# (batch_size, n_heads, seq_length, key_length)
attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

# Mask heads if we want to
if layer_head_mask is not None:
attn_weights = attn_weights * layer_head_mask

attn_output = torch.matmul(attn_weights, value_states)
# (batch_size, seq_length, dim)
attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)

attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(batch_size, -1, self.inner_dim)
attn_output = self.output(attn_output)

outputs = (attn_output,) + (past_key_value,) + (position_bias,)
outputs = (attn_output, past_key_value, position_bias)

if output_attentions:
outputs = outputs + (attn_weights,)
@@ -969,7 +977,10 @@ def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optiona
layer_idx=layer_idx,
)

self.encoder_decoder_attention = Pix2StructTextLayerCrossAttention(config)
self.encoder_decoder_attention = Pix2StructTextLayerCrossAttention(
config,
layer_idx=layer_idx,
)

self.mlp = Pix2StructTextLayerFF(config)

@@ -1019,7 +1030,6 @@ def forward(
query_length=cache_position[-1] + 1,
use_cache=use_cache,
output_attentions=output_attentions,
cache_position=cache_position,
)
hidden_states, past_key_value = cross_attention_outputs[:2]

11 changes: 11 additions & 0 deletions tests/models/pix2struct/test_modeling_pix2struct.py
@@ -419,6 +419,7 @@ def prepare_config_and_inputs_for_common(self):
@require_torch
class Pix2StructModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (Pix2StructForConditionalGeneration,) if is_torch_available() else ()
all_generative_model_classes = (Pix2StructForConditionalGeneration,) if is_torch_available() else {}
pipeline_model_mapping = {"image-to-text": Pix2StructForConditionalGeneration} if is_torch_available() else {}
fx_compatible = False
test_head_masking = False
@@ -445,6 +446,16 @@ def test_model(self):
),
)

def test_generative_model(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_generative_model_classes:
model = model_class(config).eval().to(torch_device)

output = model.generate(**input_dict, use_cache=False, min_new_tokens=10, max_new_tokens=10)
output_use_cache = model.generate(**input_dict, use_cache=True, min_new_tokens=10, max_new_tokens=10)

torch.testing.assert_close(output, output_use_cache)

@unittest.skip(reason="Hidden_states is tested in individual model tests")
def test_hidden_states_output(self):
pass
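
The new `test_generative_model` asserts that greedy generation produces identical tokens with and without the KV cache, which is exactly the property the `cache_position` fix restores. A similar check can be reproduced outside the test suite roughly as follows; the checkpoint and the blank input image are placeholders, not part of the commit.

```python
# Rough standalone version of the use_cache consistency check.
import torch
from PIL import Image
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor

checkpoint = "google/pix2struct-textcaps-base"  # any Pix2Struct checkpoint should do
processor = Pix2StructProcessor.from_pretrained(checkpoint)
model = Pix2StructForConditionalGeneration.from_pretrained(checkpoint).eval()

# A blank image is enough to exercise the decoder cache path.
inputs = processor(images=Image.new("RGB", (512, 512), "white"), return_tensors="pt")

with torch.no_grad():
    out_no_cache = model.generate(**inputs, use_cache=False, min_new_tokens=10, max_new_tokens=10)
    out_cache = model.generate(**inputs, use_cache=True, min_new_tokens=10, max_new_tokens=10)

torch.testing.assert_close(out_no_cache, out_cache)  # passes once the bias is cache-aware
```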