From 5e8ca973ebd5584582923b8ed1d3d823769a80a5 Mon Sep 17 00:00:00 2001
From: William Lin
Date: Tue, 23 Jul 2024 18:49:44 -0700
Subject: [PATCH] [Bugfix] fix flashinfer cudagraph capture for PP (#6708)

---
 tests/distributed/test_pipeline_parallel.py | 24 +++++++++++++++++++++
 vllm/worker/model_runner.py                 | 14 ++++++------
 2 files changed, 31 insertions(+), 7 deletions(-)

diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py
index 7f555ed9168a4..d666b8a1d44bd 100644
--- a/tests/distributed/test_pipeline_parallel.py
+++ b/tests/distributed/test_pipeline_parallel.py
@@ -61,3 +61,27 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME,
         tp_args.append("--enforce-eager")
 
     compare_two_settings(MODEL_NAME, pp_args, tp_args)
+
+
+@pytest.mark.parametrize("PP_SIZE, MODEL_NAME", [
+    (2, "JackFram/llama-160m"),
+])
+@pytest.mark.parametrize("ATTN_BACKEND", [
+    "FLASH_ATTN",
+    "FLASHINFER",
+])
+def test_pp_cudagraph(PP_SIZE, MODEL_NAME, ATTN_BACKEND):
+    cudagraph_args = [
+        # use half precision for speed and memory savings in CI environment
+        "--dtype",
+        "float16",
+        "--pipeline-parallel-size",
+        str(PP_SIZE),
+        "--distributed-executor-backend",
+        "ray",
+    ]
+    os.environ["VLLM_ATTENTION_BACKEND"] = ATTN_BACKEND
+
+    eager_args = cudagraph_args + ["--enforce-eager"]
+
+    compare_two_settings(MODEL_NAME, eager_args, cudagraph_args)
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index e63be184af16a..073c5a73f739b 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -1040,9 +1040,9 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
                     self.parallel_config.pipeline_parallel_size):
                 for batch_size in reversed(batch_size_capture_list):
                     if self.attn_backend.get_name() == "flashinfer":
-                        indptr_buffer = indptr_buffer[:batch_size + 1]
-                        last_page_len_buffer = last_page_len_buffer[:
-                                                                    batch_size]
+                        _indptr_buffer = indptr_buffer[:batch_size + 1]
+                        _last_page_len_buffer = last_page_len_buffer[:
+                                                                     batch_size]
 
                         num_qo_heads = (
                             self.model_config.get_num_attention_heads(
@@ -1055,8 +1055,8 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
                             use_tensor_cores = False
                         decode_wrapper = \
                             CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
-                            decode_workspace_buffer, indptr_buffer,
-                            indices_buffer, last_page_len_buffer, "NHD",
+                            decode_workspace_buffer, _indptr_buffer,
+                            indices_buffer, _last_page_len_buffer, "NHD",
                             use_tensor_cores)
                         kv_cache_dtype = get_kv_cache_torch_dtype(
                             self.kv_cache_dtype, self.model_config.dtype)
@@ -1131,10 +1131,10 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
                         self.model, self.attn_backend.get_name())
 
                     if self.attn_backend.get_name() == "flashinfer":
-                        graph_runner.flashinfer_indptr_buffer = indptr_buffer
+                        graph_runner.flashinfer_indptr_buffer = _indptr_buffer
                         graph_runner.flashinfer_indices_buffer = indices_buffer
                         graph_runner.flashinfer_last_page_len_buffer = \
-                            last_page_len_buffer
+                            _last_page_len_buffer
                         graph_runner.flashinfer_decode_workspace_buffer = \
                             decode_workspace_buffer
                         graph_runner.flashinfer_decode_wrapper = \
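
Editor's note: the rename to _indptr_buffer / _last_page_len_buffer suggests the underlying bug is the old code rebinding the shared FlashInfer buffer names to their per-batch slices inside the capture loops, so later iterations (notably the second virtual engine under pipeline parallelism) no longer see the full-size buffers. Below is a minimal, illustrative Python sketch of that rebinding pitfall; it uses plain lists and assumed names (full_buffer, buf, _buf) in place of the real FlashInfer buffers and is not code from this patch.

# Illustrative sketch only (assumed names, not vLLM code).
full_buffer = list(range(9))   # allocated once, sized for the largest batch
batch_sizes = [8, 4, 2, 1]     # capture the largest batch first, as in capture_model

# Buggy pattern: rebinding the shared name to its slice shrinks the buffer,
# so the second virtual engine (PP) never sees the full-size view again.
buf = full_buffer
for virtual_engine in range(2):        # PP size 2 -> two passes over batch sizes
    for batch_size in batch_sizes:
        buf = buf[:batch_size + 1]     # slice of a slice after the first pass
print(len(buf))                        # 2, not 9

# Fixed pattern: slice into a fresh, underscore-prefixed name so every
# (virtual_engine, batch_size) pair slices the original full-size buffer.
for virtual_engine in range(2):
    for batch_size in batch_sizes:
        _buf = full_buffer[:batch_size + 1]
        assert len(_buf) == batch_size + 1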