diff --git a/tests/compile/piecewise/test_toy_llama.py b/tests/compile/piecewise/test_toy_llama.py
index 07c10a3a18c55..d4ede4d2320a7 100644
--- a/tests/compile/piecewise/test_toy_llama.py
+++ b/tests/compile/piecewise/test_toy_llama.py
@@ -7,7 +7,7 @@
 initialized randomly with a fixed seed.
 """
 from dataclasses import dataclass
-from typing import Optional, Tuple
+from typing import Any, List, Optional, Tuple
 
 import torch
 from torch import nn
@@ -54,6 +54,16 @@ class LlamaConfig:
     tractable_init: bool = False
     random_seed: int = 0
 
+    def compute_hash(self) -> str:
+        factors: List[Any] = []
+        for k, v in self.__dict__.items():
+            if k == "random_seed":
+                continue
+            factors.append((k, v))
+        factors.sort()
+        import hashlib
+        return hashlib.md5(str(factors).encode()).hexdigest()
+
     def __post_init__(self):
         assert self.mlp_size >= self.hidden_size
 
@@ -263,7 +273,8 @@ def run_model(llama_config,
         compilation_config = CompilationConfig(
             level=CompilationLevel.NO_COMPILATION,
         )
-    vllm_config = VllmConfig(compilation_config=compilation_config)
+    vllm_config = VllmConfig(compilation_config=compilation_config,
+                             additional_config=llama_config)
     with set_current_vllm_config(vllm_config):
         model = LlamaModel(config=llama_config,
                            vllm_config=vllm_config,
diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py
index 4f960b441f21d..a8dd628b9cd6f 100644
--- a/vllm/compilation/backends.py
+++ b/vllm/compilation/backends.py
@@ -619,8 +619,10 @@ def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
         # the entries for different shapes that we need to either
         # compile or capture cudagraph
         self.concrete_size_entries: Dict[int, ConcreteSizeEntry] = {}
-        self.to_be_compiled_sizes: Set[int] = self.compile_sizes.union(
-            self.capture_sizes)
+
+        # to_be_compiled_sizes tracks the remaining sizes to compile,
+        # and updates during the compilation process, so we need to copy it
+        self.to_be_compiled_sizes: Set[int] = self.compile_sizes.copy()
         for shape in self.compile_sizes.union(self.capture_sizes):
             self.concrete_size_entries[shape] = ConcreteSizeEntry(
                 runtime_shape=shape,
@@ -628,12 +630,17 @@ def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
                 use_cudagraph=shape in self.capture_sizes,
             )
 
+    def check_for_ending_compilation(self):
+        if self.is_last_graph and not self.to_be_compiled_sizes:
+            # no specific sizes to compile
+            # save the hash of the inductor graph for the next run
+            self.compilation_config.inductor_hash_cache.save_to_file()
+            end_monitoring_torch_compile(self.vllm_config)
+
     def __call__(self, *args) -> Any:
         if not self.first_run_finished:
             self.first_run_finished = True
-            # no specific sizes to compile
-            if self.is_last_graph and not self.to_be_compiled_sizes:
-                end_monitoring_torch_compile(self.vllm_config)
+            self.check_for_ending_compilation()
             return self.compiled_graph_for_general_shape(*args)
 
         runtime_shape = args[self.sym_shape_indices[0]]
@@ -662,10 +669,7 @@ def __call__(self, *args) -> Any:
 
             # finished compilations for all required shapes
             if self.is_last_graph and not self.to_be_compiled_sizes:
-
-                # save the hash of the inductor graph for the next run
-                self.compilation_config.inductor_hash_cache.save_to_file()
-                end_monitoring_torch_compile(self.vllm_config)
+                self.check_for_ending_compilation()
 
         if not entry.use_cudagraph:
             return entry.runnable(*args)
diff --git a/vllm/config.py b/vllm/config.py
index 8e556743c8528..765a46e6aeee3 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -9,8 +9,8 @@
 from dataclasses import dataclass, field, replace
 from pathlib import Path
 from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Counter, Dict,
-                    Final, List, Literal, Mapping, Optional, Set, Tuple, Type,
-                    Union)
+                    Final, List, Literal, Mapping, Optional, Protocol, Set,
+                    Tuple, Type, Union)
 
 import torch
 from pydantic import BaseModel, Field, PrivateAttr
@@ -75,6 +75,12 @@
                                              PretrainedConfig]]
 
 
+class SupportsHash(Protocol):
+
+    def compute_hash(self) -> str:
+        ...
+
+
 class ModelConfig:
     """Configuration for the model.
 
@@ -2969,6 +2975,10 @@ class VllmConfig:
                                                   init=True)  # type: ignore
     kv_transfer_config: KVTransferConfig = field(default=None,
                                                  init=True)  # type: ignore
+    # some opaque config, only used to provide additional information
+    # for the hash computation, mainly used for testing and debugging.
+    additional_config: SupportsHash = field(default=None,
+                                            init=True)  # type: ignore
     instance_id: str = ""
 
     def compute_hash(self) -> str:
@@ -3000,33 +3010,62 @@ def compute_hash(self) -> str:
         vllm_factors.append(__version__)
         if self.model_config:
             vllm_factors.append(self.model_config.compute_hash())
+        else:
+            vllm_factors.append("None")
         if self.cache_config:
             vllm_factors.append(self.cache_config.compute_hash())
+        else:
+            vllm_factors.append("None")
         if self.parallel_config:
             vllm_factors.append(self.parallel_config.compute_hash())
+        else:
+            vllm_factors.append("None")
         if self.scheduler_config:
             vllm_factors.append(self.scheduler_config.compute_hash())
+        else:
+            vllm_factors.append("None")
         if self.device_config:
             vllm_factors.append(self.device_config.compute_hash())
+        else:
+            vllm_factors.append("None")
         if self.load_config:
             vllm_factors.append(self.load_config.compute_hash())
+        else:
+            vllm_factors.append("None")
         if self.lora_config:
             vllm_factors.append(self.lora_config.compute_hash())
+        else:
+            vllm_factors.append("None")
         if self.speculative_config:
             vllm_factors.append(self.speculative_config.compute_hash())
+        else:
+            vllm_factors.append("None")
         if self.decoding_config:
             vllm_factors.append(self.decoding_config.compute_hash())
+        else:
+            vllm_factors.append("None")
         if self.observability_config:
             vllm_factors.append(self.observability_config.compute_hash())
+        else:
+            vllm_factors.append("None")
         if self.prompt_adapter_config:
             vllm_factors.append(self.prompt_adapter_config.compute_hash())
+        else:
+            vllm_factors.append("None")
         if self.quant_config:
             pass  # should be captured by model_config.quantization
         if self.compilation_config:
             vllm_factors.append(self.compilation_config.compute_hash())
+        else:
+            vllm_factors.append("None")
         if self.kv_transfer_config:
             vllm_factors.append(self.kv_transfer_config.compute_hash())
-
+        else:
+            vllm_factors.append("None")
+        if self.additional_config:
+            vllm_factors.append(self.additional_config.compute_hash())
+        else:
+            vllm_factors.append("None")
         factors.append(vllm_factors)
 
         hash_str = hashlib.md5(str(factors).encode()).hexdigest()[:10]
diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py
index 0000b09bfaa36..af438f7d5820c 100644
--- a/vllm/v1/worker/gpu_worker.py
+++ b/vllm/v1/worker/gpu_worker.py
@@ -48,6 +48,7 @@ def __init__(
         self.prompt_adapter_config = vllm_config.prompt_adapter_config
         self.observability_config = vllm_config.observability_config
 
+        self.parallel_config.rank = rank
         self.local_rank = local_rank
         self.rank = rank
         self.distributed_init_method = distributed_init_method
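
Note on the hashing pattern: the diff threads an opaque additional_config (anything implementing the SupportsHash protocol) into VllmConfig.compute_hash(), and appends a literal "None" for every unset sub-config so the factor list keeps a stable shape. The sketch below is a minimal standalone illustration of that pattern, assuming only plain Python and the compute_hash convention shown above; ToyConfig and combined_hash are illustrative names, not part of the vLLM API.

import hashlib
from dataclasses import dataclass
from typing import Any, List, Optional, Protocol


class SupportsHash(Protocol):
    """Anything that can contribute a stable digest to a compilation cache key."""

    def compute_hash(self) -> str:
        ...


@dataclass
class ToyConfig:
    """Illustrative stand-in for the LlamaConfig used in the toy test."""
    hidden_size: int = 128
    mlp_size: int = 256
    random_seed: int = 0  # excluded from the hash, as in the diff

    def compute_hash(self) -> str:
        factors: List[Any] = []
        for k, v in self.__dict__.items():
            if k == "random_seed":
                continue
            factors.append((k, v))
        factors.sort()
        return hashlib.md5(str(factors).encode()).hexdigest()


def combined_hash(additional_config: Optional[SupportsHash]) -> str:
    """Mimic the VllmConfig.compute_hash structure for a single factor slot."""
    vllm_factors: List[Any] = []
    if additional_config:
        vllm_factors.append(additional_config.compute_hash())
    else:
        # The "None" placeholder keeps the factor list positionally stable,
        # so later enabling the config changes the overall hash.
        vllm_factors.append("None")
    return hashlib.md5(str(vllm_factors).encode()).hexdigest()[:10]


if __name__ == "__main__":
    # Configs differing only in random_seed hash identically, so they can
    # share cached compiled artifacts.
    assert combined_hash(ToyConfig(random_seed=0)) == combined_hash(
        ToyConfig(random_seed=1))
    # Changing a factor that matters produces a different cache key.
    assert combined_hash(ToyConfig(hidden_size=256)) != combined_hash(ToyConfig())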