diff --git a/examples/offline_inference_semi_structured_sparse.py b/examples/offline_inference_24.py
similarity index 100%
rename from examples/offline_inference_semi_structured_sparse.py
rename to examples/offline_inference_24.py
diff --git a/examples/offline_inference_sparse.py b/examples/offline_inference_sparse.py
new file mode 100644
index 0000000000000..b35ac0080e5ef
--- /dev/null
+++ b/examples/offline_inference_sparse.py
@@ -0,0 +1,7 @@
+from vllm import LLM, SamplingParams
+
+model = LLM("nm-testing/TinyLlama-1.1B-Chat-v1.0-pruned2.4", sparsity="sparse_w16a16")
+
+sampling_params = SamplingParams(max_tokens=100, temperature=0)
+outputs = model.generate("Hello my name is", sampling_params=sampling_params)
+print(outputs[0].outputs[0].text)
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index fb16fec8a9d14..c4084deeb3bba 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -302,8 +302,8 @@ def create_engine_configs(
             self.download_dir, self.load_format, self.dtype, self.seed,
             self.revision, self.code_revision, self.tokenizer_revision,
-            self.max_model_len, self.sparsity,
-            self.quantization, self.enforce_eager,
+            self.max_model_len, self.quantization,
+            self.sparsity, self.enforce_eager,
             self.max_context_len_to_capture)
         cache_config = CacheConfig(self.block_size,
                                    self.gpu_memory_utilization,
diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py
index fd1757e1f97cf..4b470558b1494 100644
--- a/vllm/model_executor/model_loader.py
+++ b/vllm/model_executor/model_loader.py
@@ -63,10 +63,7 @@ def get_model(model_config: ModelConfig,
                 f"{supported_dtypes}")
         linear_method = quant_config.get_linear_method()
     if model_config.sparsity is not None:
-        sparse_config = get_sparse_config(model_config.sparsity,
-                                          model_config.model,
-                                          model_config.hf_config,
-                                          model_config.download_dir)
+        sparse_config = get_sparse_config(model_config)
         capability = torch.cuda.get_device_capability()
         capability = capability[0] * 10 + capability[1]
         if capability < sparse_config.get_min_capability():