Skip to content

Commit ecd60d0

Browse files
authored
[CI] fix update metadata job (#36850)
fix updata_metadata job
1 parent 42c489f commit ecd60d0

File tree

2 files changed

+12
-10
lines changed

2 files changed

+12
-10
lines changed

.github/workflows/update_metdata.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ jobs:
1919
- name: Setup environment
2020
run: |
2121
pip install --upgrade pip
22-
pip install datasets pandas==2.0.3
22+
pip install datasets pandas
2323
pip install .[torch,tf,flax]
2424
2525
- name: Update metadata

src/transformers/cache_utils.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -537,6 +537,7 @@ def batch_select_indices(self, indices: torch.Tensor):
537537
self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...]
538538

539539

540+
# Utilities for `DynamicCache` <> torch.export support
540541
def _flatten_dynamic_cache(
541542
dynamic_cache: DynamicCache,
542543
):
@@ -584,15 +585,16 @@ def _flatten_dynamic_cache_for_fx(cache, spec):
584585
return torch.utils._pytree.tree_flatten(dictionary)[0]
585586

586587

587-
torch.utils._pytree.register_pytree_node(
588-
DynamicCache,
589-
_flatten_dynamic_cache,
590-
_unflatten_dynamic_cache,
591-
serialized_type_name=f"{DynamicCache.__module__}.{DynamicCache.__name__}",
592-
flatten_with_keys_fn=_flatten_with_keys_dynamic_cache,
593-
)
594-
# TODO (tmanlaibaatar) This won't be needed in torch 2.7.
595-
torch.fx._pytree.register_pytree_flatten_spec(DynamicCache, _flatten_dynamic_cache_for_fx)
588+
if is_torch_greater_or_equal("2.2"):
589+
torch.utils._pytree.register_pytree_node(
590+
DynamicCache,
591+
_flatten_dynamic_cache,
592+
_unflatten_dynamic_cache,
593+
serialized_type_name=f"{DynamicCache.__module__}.{DynamicCache.__name__}",
594+
flatten_with_keys_fn=_flatten_with_keys_dynamic_cache,
595+
)
596+
# TODO (tmanlaibaatar) This won't be needed in torch 2.7.
597+
torch.fx._pytree.register_pytree_flatten_spec(DynamicCache, _flatten_dynamic_cache_for_fx)
596598

597599

598600
class OffloadedCache(DynamicCache):

0 commit comments

Comments
 (0)