 import gc
 import math
 import pathlib
-import threading
-import time
 import traceback
 import torch
 import uuid
@@ -57,10 +55,11 @@ class ExllamaV2Container:
     generator: Optional[ExLlamaV2DynamicGeneratorAsync] = None
     prompt_template: Optional[PromptTemplate] = None
     active_loras: List[ExLlamaV2Lora] = []
+    paged: bool = True

     # Internal config vars
     cache_mode: str = "FP16"
-    use_cfg: bool = False
+    max_batch_size: int = 20
     generation_config: Optional[GenerationConfig] = None

     # GPU split vars
@@ -115,10 +114,6 @@ def progress(loaded_modules: int, total_modules: int,
                 available devices (default: True)
             'gpu_split' (list[float]): Allocation for weights and (some)
                 tensors, per device
-            'no_flash_attn' (bool): Turns off flash attention
-                (increases vram usage) (default: False)
-            'use_cfg' (bool): Enables CFG support. Disables flash attention
-                (default: False)
         """

         self.quiet = quiet
@@ -184,18 +179,9 @@ def progress(loaded_modules: int, total_modules: int,
             kwargs.get("rope_alpha"), self.calculate_rope_alpha(base_seq_len)
         )

-        # Enable CFG if present
-        self.use_cfg = unwrap(kwargs.get("use_cfg"), False)
-
         # Enable fasttensors loading if present
         self.config.fasttensors = unwrap(kwargs.get("fasttensors"), False)

-        # Turn off flash attention if CFG is on
-        # Workaround until batched FA2 is fixed in exllamav2 upstream
-        # self.config.no_flash_attn = (
-        #     True if self.use_cfg else unwrap(kwargs.get("no_flash_attention"), False)
-        # )
-
         # Try to set prompt template
         self.prompt_template = self.find_prompt_template(
             kwargs.get("prompt_template"), model_directory
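Most of the option handling above funnels through `unwrap`, the project's small default-fallback helper. The sketch below shows its assumed semantics (return the value unless it is `None`, otherwise the default); it is an illustration for readers of the diff, not the project's actual implementation:

```python
from typing import Optional, TypeVar

T = TypeVar("T")


def unwrap(value: Optional[T], default: T) -> T:
    """Assumed behavior: fall back to the default only when value is None."""
    return value if value is not None else default


# e.g. unwrap(kwargs.get("fasttensors"), False) evaluates to False when the
# key is absent, and to the caller-supplied value otherwise.
```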
@@ -345,7 +331,6 @@ def get_model_parameters(self):
             "cache_mode": self.cache_mode,
             "chunk_size": self.config.max_input_len,
             "num_experts_per_token": self.config.num_experts_per_token,
-            "use_cfg": self.use_cfg,
             "prompt_template": self.prompt_template.name
             if self.prompt_template
             else None,
@@ -420,10 +405,24 @@ async def load_gen(self, progress_callback=None):
         async for value in iterate_in_threadpool(model_load_generator):
             yield value

-        # TODO: Change these!
-        # Set the max batch size and check if paged support is available
-        max_batch_size = 1 if self.config.no_flash_attn else 20
-        paged = not self.config.no_flash_attn
+        # Disable paged mode if the user's lowest GPU is unsupported (paged mode needs Ampere or newer)
+        min_compute_capability = min(
+            set(
+                [
+                    torch.cuda.get_device_capability(device=module.device_idx)[0]
+                    for module in self.model.modules
+                    if module.device_idx >= 0
+                ]
+            )
+        )
+
+        if torch.version.hip or min_compute_capability < 8:
+            logger.warning(
+                "An unsupported GPU was found in this configuration. "
+                "Switching to compatibility mode. This disables parallel batching."
+            )
+            self.paged = False
+            self.max_batch_size = 1

         # Create async generator
         self.generator = ExLlamaV2DynamicGeneratorAsync(
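The new capability probe gates paged (batched) generation on the weakest GPU in the split: paged mode relies on flash attention, which needs an NVIDIA GPU with compute capability 8.x (Ampere) or newer and is treated as unavailable on ROCm builds. Below is a minimal standalone sketch of the same check; the `supports_paged_mode` helper and the device enumeration are illustrative, not part of the diff:

```python
import torch


def supports_paged_mode(device_indices: list[int]) -> bool:
    """Assumed rule: paged attention needs CUDA (not ROCm/HIP) and SM >= 8.0."""
    if torch.version.hip:
        # ROCm builds of torch report a HIP version; treat them as unsupported here.
        return False
    # get_device_capability returns (major, minor); major >= 8 means Ampere or newer.
    return all(
        torch.cuda.get_device_capability(device=idx)[0] >= 8
        for idx in device_indices
    )


if torch.cuda.is_available():
    paged = supports_paged_mode(list(range(torch.cuda.device_count())))
    max_batch_size = 20 if paged else 1  # mirrors the defaults used in the diff
```

Taking the minimum capability across devices matters because a single pre-Ampere card in a multi-GPU split forces the whole generator into compatibility mode.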
@@ -432,8 +431,8 @@ async def load_gen(self, progress_callback=None):
             draft_model=self.draft_model,
             draft_cache=self.draft_cache,
             tokenizer=self.tokenizer,
-            max_batch_size=max_batch_size,
-            paged=paged,
+            max_batch_size=self.max_batch_size,
+            paged=self.paged,
         )

         # Clean up any extra vram usage from torch and cuda
@@ -741,7 +740,7 @@ async def generate_gen(self, prompt: str, **kwargs):
         cfg_scale = unwrap(kwargs.get("cfg_scale"), 1.0)
         negative_prompt = None
         if cfg_scale not in [None, 1.0]:
-            if self.use_cfg:
+            if self.paged:
                 gen_settings.cfg_scale = cfg_scale

                 # If the negative prompt is empty, use the BOS token
@@ -752,8 +751,8 @@ async def generate_gen(self, prompt: str, **kwargs):
                 prompts.append(negative_prompt)
             else:
                 logger.warning(
-                    "CFG is currently disabled. "
-                    "If your GPU is supported, reload your model with use_cfg = True"
+                    "CFG is currently disabled because paged mode is disabled. "
+                    "Please use an Ampere (30 series) or newer GPU for CFG support."
                 )

         gen_settings.token_repetition_penalty = unwrap(
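For context on why CFG now depends on `paged`: classifier-free guidance runs the positive prompt and the negative prompt as a batch of two sequences and blends their outputs per token (hence the `prompts.append(negative_prompt)` above), so a generator capped at `max_batch_size = 1` cannot serve it. A hedged sketch of the usual CFG blend follows; exllamav2's internal mixing may differ in detail (e.g. by operating on log-probabilities):

```python
import torch


def cfg_mix(
    cond_logits: torch.Tensor, uncond_logits: torch.Tensor, cfg_scale: float
) -> torch.Tensor:
    """Standard classifier-free guidance blend of the conditional (positive
    prompt) and unconditional (negative prompt) rows of a 2-sequence batch."""
    return uncond_logits + cfg_scale * (cond_logits - uncond_logits)
```

With `cfg_scale = 1.0` the blend reduces to the conditional logits, which is why values of `None` or `1.0` skip the CFG path entirely.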