6 files changed, +6 −0 lines changed

@@ -126,6 +126,7 @@ class LlavaPreTrainedModel(PreTrainedModel):
     _no_split_modules = ["LlavaVisionAttention"]
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = True
+    _supports_cache_class = True

     def _init_weights(self, module):
         # important: this ported version of Llava isn't meant for training from scratch - only

@@ -232,6 +232,7 @@ class LlavaNextPreTrainedModel(PreTrainedModel):
     _no_split_modules = ["LlavaNextVisionAttention"]
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = True
+    _supports_cache_class = True

     def _init_weights(self, module):
         # important: this ported version of LlavaNext isn't meant for training from scratch - only

@@ -272,6 +272,7 @@ class LlavaNextVideoPreTrainedModel(PreTrainedModel):
     _no_split_modules = ["LlavaNextVideoVisionAttention"]
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = True
+    _supports_cache_class = True

     def _init_weights(self, module):
         # important: this ported version of LlavaNextVideo isn't meant for training from scratch - only

@@ -127,6 +127,7 @@ class PaliGemmaPreTrainedModel(PreTrainedModel):
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = False
     _supports_sdpa = True
+    _supports_cache_class = True

     def _init_weights(self, module):
         # important: this ported version of PaliGemma isn't meant for training from scratch - only

@@ -126,6 +126,7 @@ class VideoLlavaPreTrainedModel(PreTrainedModel):
     _no_split_modules = ["VideoLlavaVisionAttention"]
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = True
+    _supports_cache_class = True

     def _init_weights(self, module):
         std = (

@@ -135,6 +135,7 @@ class VipLlavaPreTrainedModel(PreTrainedModel):
     _no_split_modules = ["VipLlavaVisionAttention"]
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = True
+    _supports_cache_class = True

     def _init_weights(self, module):
         # important: this ported version of VipLlava isn't meant for training from scratch - only