
Commit

deprecate in models
zucchini-nlp committed Sep 6, 2024
1 parent 2b88f4b commit 82b6193
Showing 17 changed files with 26 additions and 26 deletions.
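This commit migrates model-level call sites from `Cache.get_max_length()` to `Cache.get_max_cache_shape()`. The deprecation of the old accessor itself lives in the cache utilities rather than in this diff; the sketch below is a hedged illustration of what such a backward-compatible shim typically looks like (names and warning text are assumptions, not taken from this commit).

```python
# Hypothetical sketch of the deprecation shim assumed to accompany this commit
# (the real change would live in src/transformers/cache_utils.py, which is not
# part of this diff).
import warnings
from typing import Optional


class Cache:
    def get_max_cache_shape(self) -> Optional[int]:
        """Maximum sequence length the cache can hold; None for unbounded (dynamic) caches."""
        return None

    def get_max_length(self) -> Optional[int]:
        # Deprecated alias: keeps old call sites working while emitting a warning,
        # which is why every model file below switches to `get_max_cache_shape()`.
        warnings.warn(
            "`get_max_length` is deprecated; use `get_max_cache_shape` instead.",
            FutureWarning,
        )
        return self.get_max_cache_shape()
```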
2 changes: 1 addition & 1 deletion src/transformers/models/bloom/modeling_bloom.py
@@ -819,7 +819,7 @@ def _update_causal_mask(
         min_dtype = torch.finfo(dtype).min
         sequence_length = input_tensor.shape[1]
         if past_key_values is not None and past_key_values.is_static:
-            target_length = past_key_values.get_max_length()
+            target_length = past_key_values.get_max_cache_shape()
         else:
             target_length = (
                 attention_mask.shape[-1]
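In hunks like the one above, `target_length` must span the full pre-allocated length of a static cache. Below is a minimal, self-contained usage sketch of the renamed accessor; the `StaticCache` constructor arguments and the `LlamaConfig` values are illustrative assumptions, not part of this commit.

```python
import torch
from transformers import LlamaConfig, StaticCache

# Tiny illustrative config; the exact sizes are arbitrary.
config = LlamaConfig(hidden_size=64, num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=4)

# A fixed-size cache pre-allocates its key/value tensors up front.
cache = StaticCache(config, max_batch_size=1, max_cache_len=512, dtype=torch.float32)

# The renamed accessor reports that pre-allocated length, which the causal
# mask then has to cover as `target_length`.
print(cache.get_max_cache_shape())  # expected: 512
```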
2 changes: 1 addition & 1 deletion src/transformers/models/chameleon/modeling_chameleon.py
@@ -1462,7 +1462,7 @@ def _update_causal_mask(
         min_dtype = torch.finfo(dtype).min
         sequence_length = input_tensor.shape[1]
         if past_key_values is not None and past_key_values.is_static:
-            target_length = past_key_values.get_max_length()
+            target_length = past_key_values.get_max_cache_shape()
         else:
             target_length = (
                 attention_mask.shape[-1]
2 changes: 1 addition & 1 deletion src/transformers/models/cohere/modeling_cohere.py
@@ -942,7 +942,7 @@ def _update_causal_mask(
         min_dtype = torch.finfo(dtype).min
         sequence_length = input_tensor.shape[1]
         if past_key_values is not None and past_key_values.is_static:
-            target_length = past_key_values.get_max_length()
+            target_length = past_key_values.get_max_cache_shape()
         else:
             target_length = (
                 attention_mask.shape[-1]
4 changes: 2 additions & 2 deletions src/transformers/models/gemma/modeling_gemma.py
@@ -956,7 +956,7 @@ def _update_causal_mask(
         min_dtype = torch.finfo(dtype).min
         sequence_length = input_tensor.shape[1]
         if using_static_cache:
-            target_length = past_key_values.get_max_length()
+            target_length = past_key_values.get_max_cache_shape()
         else:
             target_length = (
                 attention_mask.shape[-1]
@@ -1173,7 +1173,7 @@ def prepare_inputs_for_generation(
             attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
                 attention_mask,
                 sequence_length=sequence_length,
-                target_length=past_key_values.get_max_length(),
+                target_length=past_key_values.get_max_cache_shape(),
                 dtype=dtype,
                 device=device,
                 min_dtype=min_dtype,
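The second gemma hunk feeds the value into `_prepare_4d_causal_attention_mask_with_cache_position`, whose remaining arguments are truncated in this view. As a rough, self-contained illustration of the tensor that call ultimately produces (this is not the library helper, just a sketch of the shape and fill convention suggested by the surrounding code):

```python
import torch


def sketch_causal_mask(sequence_length: int, target_length: int, dtype=torch.float32) -> torch.Tensor:
    # Additive mask of shape (batch, 1, query_len, key_len): 0 where attention is
    # allowed, the dtype's minimum (the `min_dtype` seen in the hunks above) elsewhere.
    min_dtype = torch.finfo(dtype).min
    mask = torch.full((sequence_length, target_length), min_dtype, dtype=dtype)
    # Simplest case (prefill from position 0): query i may attend to keys 0..i.
    allowed = torch.arange(target_length) <= torch.arange(sequence_length)[:, None]
    return mask.masked_fill(allowed, 0.0)[None, None, :, :]


# With a static cache, target_length is the full cache size even though only
# `sequence_length` tokens are being processed.
print(sketch_causal_mask(sequence_length=3, target_length=8).shape)  # torch.Size([1, 1, 3, 8])
```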
4 changes: 2 additions & 2 deletions src/transformers/models/gemma2/modeling_gemma2.py
@@ -925,7 +925,7 @@ def _update_causal_mask(
         min_dtype = torch.finfo(dtype).min
         sequence_length = input_tensor.shape[1]
         if isinstance(past_key_values, HybridCache):
-            target_length = past_key_values.get_max_length()
+            target_length = past_key_values.get_max_cache_shape()
         else:
             target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1]

@@ -1141,7 +1141,7 @@ def prepare_inputs_for_generation(
             attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
                 attention_mask,
                 sequence_length=sequence_length,
-                target_length=past_key_values.get_max_length(),
+                target_length=past_key_values.get_max_cache_shape(),
                 dtype=dtype,
                 device=device,
                 min_dtype=min_dtype,
2 changes: 1 addition & 1 deletion src/transformers/models/idefics2/modeling_idefics2.py
@@ -1651,7 +1651,7 @@ def prepare_inputs_for_generation(
         if past_key_values is not None:
             # Past key values are always initialized with a `Cache` object -> no need for if-else anymore
             past_length = past_key_values.get_seq_length()
-            max_cache_length = past_key_values.get_max_length()
+            max_cache_length = past_key_values.get_max_cache_shape()

             # Keep only the unprocessed tokens:
             # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
2 changes: 1 addition & 1 deletion src/transformers/models/jetmoe/modeling_jetmoe.py
@@ -1165,7 +1165,7 @@ def _update_causal_mask(
         min_dtype = torch.finfo(dtype).min
         sequence_length = input_tensor.shape[1]
         if past_key_values is not None and past_key_values.is_static:
-            target_length = past_key_values.get_max_length()
+            target_length = past_key_values.get_max_cache_shape()
         else:
             target_length = (
                 attention_mask.shape[-1]
4 changes: 2 additions & 2 deletions src/transformers/models/llama/modeling_llama.py
@@ -1067,7 +1067,7 @@ def _update_causal_mask(
         min_dtype = torch.finfo(dtype).min
         sequence_length = input_tensor.shape[1]
         if past_key_values is not None and past_key_values.is_static:
-            target_length = past_key_values.get_max_length()
+            target_length = past_key_values.get_max_cache_shape()
         else:
             target_length = (
                 attention_mask.shape[-1]
@@ -1291,7 +1291,7 @@ def prepare_inputs_for_generation(
             attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
                 attention_mask,
                 sequence_length=sequence_length,
-                target_length=past_key_values.get_max_length(),
+                target_length=past_key_values.get_max_cache_shape(),
                 dtype=dtype,
                 device=device,
                 min_dtype=min_dtype,
2 changes: 1 addition & 1 deletion src/transformers/models/mistral/modeling_mistral.py
@@ -892,7 +892,7 @@ def _update_causal_mask(
         sequence_length = input_tensor.shape[1]
         # SlidingWindowCache or StaticCache
         if past_key_values is not None and (past_key_values.is_sliding or past_key_values.is_static):
-            target_length = past_key_values.get_max_length()
+            target_length = past_key_values.get_max_cache_shape()
         # DynamicCache or no cache
         else:
             target_length = (
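The mistral hunk makes the intent of the branch explicit: only bounded caches (sliding-window or static) can report a fixed shape. Under the assumed renamed API, an unbounded `DynamicCache` reports no maximum, which is why the `else` branch falls back to the attention-mask / input length. A hedged two-line check:

```python
from transformers import DynamicCache

cache = DynamicCache()
# Assumed behaviour: a dynamically growing cache has no fixed maximum shape,
# so the accessor returns None and the mask length comes from the inputs instead.
print(cache.get_max_cache_shape())  # expected: None
```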
2 changes: 1 addition & 1 deletion src/transformers/models/mixtral/modeling_mixtral.py
@@ -1150,7 +1150,7 @@ def _update_causal_mask(
         min_dtype = torch.finfo(dtype).min
         sequence_length = input_tensor.shape[1]
         if past_key_values is not None and past_key_values.is_static:
-            target_length = past_key_values.get_max_length()
+            target_length = past_key_values.get_max_cache_shape()
         else:
             target_length = (
                 attention_mask.shape[-1]
2 changes: 1 addition & 1 deletion src/transformers/models/olmo/modeling_olmo.py
@@ -987,7 +987,7 @@ def _update_causal_mask(
         min_dtype = torch.finfo(dtype).min
         sequence_length = input_tensor.shape[1]
         if past_key_values is not None and past_key_values.is_static:
-            target_length = past_key_values.get_max_length()
+            target_length = past_key_values.get_max_cache_shape()
         else:
             target_length = (
                 attention_mask.shape[-1]
4 changes: 2 additions & 2 deletions src/transformers/models/persimmon/modeling_persimmon.py
@@ -798,7 +798,7 @@ def _update_causal_mask(
         min_dtype = torch.finfo(dtype).min
         sequence_length = input_tensor.shape[1]
         if past_key_values is not None and past_key_values.is_static:
-            target_length = past_key_values.get_max_length()
+            target_length = past_key_values.get_max_cache_shape()
         else:
             target_length = (
                 attention_mask.shape[-1]
@@ -1019,7 +1019,7 @@ def prepare_inputs_for_generation(
             attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
                 attention_mask,
                 sequence_length=sequence_length,
-                target_length=past_key_values.get_max_length(),
+                target_length=past_key_values.get_max_cache_shape(),
                 dtype=dtype,
                 device=device,
                 min_dtype=min_dtype,
4 changes: 2 additions & 2 deletions src/transformers/models/phi/modeling_phi.py
@@ -1083,7 +1083,7 @@ def _update_causal_mask(
         min_dtype = torch.finfo(dtype).min
         sequence_length = input_tensor.shape[1]
         if past_key_values is not None and past_key_values.is_static:
-            target_length = past_key_values.get_max_length()
+            target_length = past_key_values.get_max_cache_shape()
         else:
             target_length = (
                 attention_mask.shape[-1]
@@ -1311,7 +1311,7 @@ def prepare_inputs_for_generation(
             attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
                 attention_mask,
                 sequence_length=sequence_length,
-                target_length=past_key_values.get_max_length(),
+                target_length=past_key_values.get_max_cache_shape(),
                 dtype=dtype,
                 device=device,
                 min_dtype=min_dtype,
4 changes: 2 additions & 2 deletions src/transformers/models/phi3/modeling_phi3.py
@@ -1123,7 +1123,7 @@ def _update_causal_mask(
         min_dtype = torch.finfo(dtype).min
         sequence_length = input_tensor.shape[1]
         if past_key_values is not None and past_key_values.is_static:
-            target_length = past_key_values.get_max_length()
+            target_length = past_key_values.get_max_cache_shape()
         else:
             target_length = (
                 attention_mask.shape[-1]
@@ -1351,7 +1351,7 @@ def prepare_inputs_for_generation(
             attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
                 attention_mask,
                 sequence_length=sequence_length,
-                target_length=past_key_values.get_max_length(),
+                target_length=past_key_values.get_max_cache_shape(),
                 dtype=dtype,
                 device=device,
                 min_dtype=min_dtype,
4 changes: 2 additions & 2 deletions src/transformers/models/qwen2/modeling_qwen2.py
@@ -988,7 +988,7 @@ def _update_causal_mask(
         min_dtype = torch.finfo(dtype).min
         sequence_length = input_tensor.shape[1]
         if past_key_values is not None and past_key_values.is_static:
-            target_length = past_key_values.get_max_length()
+            target_length = past_key_values.get_max_cache_shape()
         else:
             target_length = (
                 attention_mask.shape[-1]
@@ -1209,7 +1209,7 @@ def prepare_inputs_for_generation(
             attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
                 attention_mask,
                 sequence_length=sequence_length,
-                target_length=past_key_values.get_max_length(),
+                target_length=past_key_values.get_max_cache_shape(),
                 dtype=dtype,
                 device=device,
                 min_dtype=min_dtype,
4 changes: 2 additions & 2 deletions src/transformers/models/stablelm/modeling_stablelm.py
@@ -1075,7 +1075,7 @@ def _update_causal_mask(
         min_dtype = torch.finfo(dtype).min
         sequence_length = input_tensor.shape[1]
         if past_key_values is not None and past_key_values.is_static:
-            target_length = past_key_values.get_max_length()
+            target_length = past_key_values.get_max_cache_shape()
         else:
             target_length = (
                 attention_mask.shape[-1]
@@ -1297,7 +1297,7 @@ def prepare_inputs_for_generation(
             attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
                 attention_mask,
                 sequence_length=sequence_length,
-                target_length=past_key_values.get_max_length(),
+                target_length=past_key_values.get_max_cache_shape(),
                 dtype=dtype,
                 device=device,
                 min_dtype=min_dtype,
4 changes: 2 additions & 2 deletions src/transformers/models/starcoder2/modeling_starcoder2.py
@@ -962,7 +962,7 @@ def _update_causal_mask(
         min_dtype = torch.finfo(dtype).min
         sequence_length = input_tensor.shape[1]
         if past_key_values is not None and past_key_values.is_static:
-            target_length = past_key_values.get_max_length()
+            target_length = past_key_values.get_max_cache_shape()
         else:
             target_length = (
                 attention_mask.shape[-1]
@@ -1185,7 +1185,7 @@ def prepare_inputs_for_generation(
             attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
                 attention_mask,
                 sequence_length=sequence_length,
-                target_length=past_key_values.get_max_length(),
+                target_length=past_key_values.get_max_cache_shape(),
                 dtype=dtype,
                 device=device,
                 min_dtype=min_dtype,
