Skip to content

Commit 453e748

Browse files
authored
LLaVa: add cache class attribute (#32278)
cache class flag
1 parent 14ee232 commit 453e748

File tree

6 files changed

+6
-0
lines changed

6 files changed

+6
-0
lines changed

src/transformers/models/llava/modeling_llava.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ class LlavaPreTrainedModel(PreTrainedModel):
126126
_no_split_modules = ["LlavaVisionAttention"]
127127
_skip_keys_device_placement = "past_key_values"
128128
_supports_flash_attn_2 = True
129+
_supports_cache_class = True
129130

130131
def _init_weights(self, module):
131132
# important: this ported version of Llava isn't meant for training from scratch - only

src/transformers/models/llava_next/modeling_llava_next.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,7 @@ class LlavaNextPreTrainedModel(PreTrainedModel):
232232
_no_split_modules = ["LlavaNextVisionAttention"]
233233
_skip_keys_device_placement = "past_key_values"
234234
_supports_flash_attn_2 = True
235+
_supports_cache_class = True
235236

236237
def _init_weights(self, module):
237238
# important: this ported version of LlavaNext isn't meant for training from scratch - only

src/transformers/models/llava_next_video/modeling_llava_next_video.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,7 @@ class LlavaNextVideoPreTrainedModel(PreTrainedModel):
272272
_no_split_modules = ["LlavaNextVideoVisionAttention"]
273273
_skip_keys_device_placement = "past_key_values"
274274
_supports_flash_attn_2 = True
275+
_supports_cache_class = True
275276

276277
def _init_weights(self, module):
277278
# important: this ported version of LlavaNextVideo isn't meant for training from scratch - only

src/transformers/models/paligemma/modeling_paligemma.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ class PaliGemmaPreTrainedModel(PreTrainedModel):
127127
_skip_keys_device_placement = "past_key_values"
128128
_supports_flash_attn_2 = False
129129
_supports_sdpa = True
130+
_supports_cache_class = True
130131

131132
def _init_weights(self, module):
132133
# important: this ported version of PaliGemma isn't meant for training from scratch - only

src/transformers/models/video_llava/modeling_video_llava.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ class VideoLlavaPreTrainedModel(PreTrainedModel):
126126
_no_split_modules = ["VideoLlavaVisionAttention"]
127127
_skip_keys_device_placement = "past_key_values"
128128
_supports_flash_attn_2 = True
129+
_supports_cache_class = True
129130

130131
def _init_weights(self, module):
131132
std = (

src/transformers/models/vipllava/modeling_vipllava.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ class VipLlavaPreTrainedModel(PreTrainedModel):
135135
_no_split_modules = ["VipLlavaVisionAttention"]
136136
_skip_keys_device_placement = "past_key_values"
137137
_supports_flash_attn_2 = True
138+
_supports_cache_class = True
138139

139140
def _init_weights(self, module):
140141
# important: this ported version of VipLlava isn't meant for training from scratch - only

0 commit comments

Comments (0)