@@ -88,6 +88,7 @@ class InferenceConfig:
         max_output_len (int): Maximum output length, defaults to 256.
         max_input_len (int): Maximum input length, defaults to 256.
         dtype (Union[str, torch.dtype]): The data type for weights and activations.
+        kv_cache_dtype (Optional[str]): The data type of the KV cache, defaults to None.
         prompt_template (Optional[str]): The prompt template for generation, defaults to None.
         do_sample (bool): Whether to use sampling for generation, defaults to False.
         beam_width (int): The maximum beam width used to initialize KV Cache, defaults to 1.
@@ -122,6 +123,7 @@ class InferenceConfig:
 
     # general configs
     dtype: Union[str, torch.dtype] = torch.float16  # use fp16 by default
+    kv_cache_dtype: Optional[str] = None
 
     # generation configs
     prompt_template: Optional[str] = None
@@ -177,6 +179,12 @@ def _verify_config(self) -> None:
             self.dtype in _ALLOWED_DTYPES
         ), f"Expected dtype to be in {_ALLOWED_DTYPES} but found an unknown dtype: {self.dtype}"
 
+        if self.kv_cache_dtype:
+            assert (
+                self.use_cuda_kernel and self.kv_cache_dtype == "fp8"
+            ), "FP8 KV cache is currently only supported when use_cuda_kernel is enabled"
+            self.kv_cache_dtype = torch.uint8
+
         # skip using casting when the data type is float32
         if self.dtype == torch.float32:
             self.high_precision = False
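
For reference, a minimal usage sketch of the new option. The import path and the assumption that `_verify_config()` runs on construction (e.g. from `__post_init__`) are not shown in this diff and are inferred:

```python
import torch
from colossalai.inference.config import InferenceConfig  # import path is an assumption

# Opt in to the FP8 KV cache; the new check requires the CUDA kernel path.
config = InferenceConfig(
    dtype=torch.float16,
    use_cuda_kernel=True,   # FP8 KV cache is gated on use_cuda_kernel
    kv_cache_dtype="fp8",   # accepted string form; rewritten during verification
)

# Assuming _verify_config() runs on construction, the "fp8" string has been
# replaced with the storage dtype used by the kernels:
assert config.kv_cache_dtype == torch.uint8
```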