@@ -537,6 +537,7 @@ def batch_select_indices(self, indices: torch.Tensor):
537537 self .value_cache [layer_idx ] = self .value_cache [layer_idx ][indices , ...]
538538
539539
540+ # Utilities for `DynamicCache` <> torch.export support
540541def _flatten_dynamic_cache (
541542 dynamic_cache : DynamicCache ,
542543):
@@ -584,15 +585,16 @@ def _flatten_dynamic_cache_for_fx(cache, spec):
584585 return torch .utils ._pytree .tree_flatten (dictionary )[0 ]
585586
586587
587- torch .utils ._pytree .register_pytree_node (
588- DynamicCache ,
589- _flatten_dynamic_cache ,
590- _unflatten_dynamic_cache ,
591- serialized_type_name = f"{ DynamicCache .__module__ } .{ DynamicCache .__name__ } " ,
592- flatten_with_keys_fn = _flatten_with_keys_dynamic_cache ,
593- )
594- # TODO (tmanlaibaatar) This won't be needed in torch 2.7.
595- torch .fx ._pytree .register_pytree_flatten_spec (DynamicCache , _flatten_dynamic_cache_for_fx )
588+ if is_torch_greater_or_equal ("2.2" ):
589+ torch .utils ._pytree .register_pytree_node (
590+ DynamicCache ,
591+ _flatten_dynamic_cache ,
592+ _unflatten_dynamic_cache ,
593+ serialized_type_name = f"{ DynamicCache .__module__ } .{ DynamicCache .__name__ } " ,
594+ flatten_with_keys_fn = _flatten_with_keys_dynamic_cache ,
595+ )
596+ # TODO (tmanlaibaatar) This won't be needed in torch 2.7.
597+ torch .fx ._pytree .register_pytree_flatten_spec (DynamicCache , _flatten_dynamic_cache_for_fx )
596598
597599
598600class OffloadedCache (DynamicCache ):
0 commit comments