diff --git a/OAI/types/model.py b/OAI/types/model.py index 9653324b..a8afa678 100644 --- a/OAI/types/model.py +++ b/OAI/types/model.py @@ -18,6 +18,7 @@ class ModelCardParameters(BaseModel): cache_mode: Optional[str] = "FP16" prompt_template: Optional[str] = None num_experts_per_token: Optional[int] = None + use_cfg: Optional[bool] = None draft: Optional["ModelCard"] = None diff --git a/config_sample.yml b/config_sample.yml index 557d886c..7f88d946 100644 --- a/config_sample.yml +++ b/config_sample.yml @@ -87,7 +87,7 @@ model: # Enables CFG support (default: False) # WARNING: This flag disables Flash Attention! (a stopgap fix until it's fixed in upstream) - use_cfg: False + #use_cfg: False # Options for draft models (speculative decoding). This will use more VRAM! #draft: diff --git a/main.py b/main.py index 9b0e0c77..93793473 100644 --- a/main.py +++ b/main.py @@ -122,6 +122,7 @@ async def get_current_model(): cache_mode="FP8" if MODEL_CONTAINER.cache_fp8 else "FP16", prompt_template=prompt_template.name if prompt_template else None, num_experts_per_token=MODEL_CONTAINER.config.num_experts_per_token, + use_cfg=MODEL_CONTAINER.use_cfg, ), logging=gen_logging.PREFERENCES, )