Commit
[Bugfix]: serialize config by value for --trust-remote-code (vllm-project#6751)

Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Signed-off-by: Amit Garg <mitgarg17495@gmail.com>
2 people authored and garg-amit committed Oct 28, 2024
1 parent cce2a8b commit 378812e
Showing 4 changed files with 103 additions and 28 deletions.
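
For context (not part of the commit), a minimal self-contained sketch of the failure this change works around: plain pickle stores classes by reference, so a config class that lives only in one process's dynamically generated module cannot be unpickled in a spawned worker. The module name "fake_modules" is a made-up stand-in for the HF modules cache.

import pickle
import sys
import types

# Simulate a module that exists only in this process, like the HF modules cache.
mod = types.ModuleType("fake_modules")
exec("class RemoteConfig:\n    pass", mod.__dict__)
sys.modules["fake_modules"] = mod

blob = pickle.dumps(mod.RemoteConfig())  # succeeds: only a reference to the class is stored

del sys.modules["fake_modules"]  # simulate a spawned worker that cannot import the module
try:
    pickle.loads(blob)
except ModuleNotFoundError as err:
    print("unpickling failed:", err)  # the kind of error hit with --trust-remote-code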
63 changes: 35 additions & 28 deletions tests/distributed/test_pipeline_parallel.py
@@ -28,19 +28,25 @@ class ParallelSetup(NamedTuple):
chunked_prefill: bool


class PPTestOptions(NamedTuple):
multi_node_only: bool
trust_remote_code: bool
tokenizer_mode: Optional[str]


@dataclass
class PPTestSettings:
parallel_setups: List[ParallelSetup]
distributed_backends: List[str]
task: TaskOption
trust_remote_code: bool
tokenizer_mode: Optional[str]
test_options: PPTestOptions

@staticmethod
def detailed(
*,
tp_base: int = 1,
pp_base: int = 2,
multi_node_only: bool = False,
task: TaskOption = "auto",
trust_remote_code: bool = False,
tokenizer_mode: Optional[str] = None,
@@ -70,8 +76,9 @@ def detailed(
],
distributed_backends=["mp", "ray"],
task=task,
trust_remote_code=trust_remote_code,
tokenizer_mode=tokenizer_mode,
test_options=PPTestOptions(multi_node_only=multi_node_only,
trust_remote_code=trust_remote_code,
tokenizer_mode=tokenizer_mode),
)

@staticmethod
@@ -80,6 +87,7 @@ def fast(
tp_base: int = 1,
pp_base: int = 2,
task: TaskOption = "auto",
multi_node_only: bool = False,
trust_remote_code: bool = False,
tokenizer_mode: Optional[str] = None,
):
@@ -92,15 +100,18 @@ def fast(
],
distributed_backends=["mp"],
task=task,
trust_remote_code=trust_remote_code,
tokenizer_mode=tokenizer_mode,
test_options=PPTestOptions(multi_node_only=multi_node_only,
trust_remote_code=trust_remote_code,
tokenizer_mode=tokenizer_mode),
)

def iter_params(self, model_name: str):
opts = self.test_options

for parallel_setup in self.parallel_setups:
for distributed_backend in self.distributed_backends:
yield (model_name, parallel_setup, distributed_backend,
self.task, self.trust_remote_code, self.tokenizer_mode)
self.task, opts)


# NOTE: You can adjust tp_base and/or pp_base locally to fit the model in GPU
@@ -110,6 +121,7 @@ def iter_params(self, model_name: str):
GENERATION_MODEL_SETTINGS = {
# [DETAILED TESTS]
"meta-llama/Meta-Llama-3-8B": PPTestSettings.detailed(),
"microsoft/Phi-3-mini-4k-instruct": PPTestSettings.detailed(trust_remote_code=True, multi_node_only=True), # noqa: E501
# [FAST TESTS]
# Uses Llama
# "BAAI/AquilaChat-7B": PPTestSettings.fast(),
@@ -151,10 +163,8 @@ def iter_params(self, model_name: str):
"facebook/opt-iml-max-1.3b": PPTestSettings.fast(),
"OrionStarAI/Orion-14B-Chat": PPTestSettings.fast(trust_remote_code=True),
"microsoft/phi-2": PPTestSettings.fast(),
"microsoft/Phi-3-mini-4k-instruct": PPTestSettings.fast(),
"microsoft/Phi-3-small-8k-instruct": PPTestSettings.fast(trust_remote_code=True), # noqa: E501
# FIXME: https://github.com/vllm-project/vllm/issues/8553
# "microsoft/Phi-3.5-MoE-instruct": PPTestSettings.fast(trust_remote_code=True), # noqa: E501
"microsoft/Phi-3.5-MoE-instruct": PPTestSettings.fast(trust_remote_code=True), # noqa: E501
"adept/persimmon-8b-chat": PPTestSettings.fast(),
"Qwen/Qwen-7B-Chat": PPTestSettings.fast(trust_remote_code=True),
"Qwen/Qwen2-beta-7B-Chat": PPTestSettings.fast(),
@@ -205,6 +215,7 @@ def iter_params(self, model_name: str):
# [LANGUAGE GENERATION]
"meta-llama/Meta-Llama-3-8B",
"ibm/PowerLM-3b",
"microsoft/Phi-3-mini-4k-instruct",
# [LANGUAGE EMBEDDING]
"intfloat/e5-mistral-7b-instruct",
"BAAI/bge-multilingual-gemma2",
@@ -220,19 +231,21 @@ def _compare_tp(
parallel_setup: ParallelSetup,
distributed_backend: str,
task: TaskOption,
trust_remote_code: bool,
tokenizer_mode: Optional[str],
test_options: PPTestOptions,
num_gpus_available: int,
*,
method: Literal["generate", "encode"] = "encode",
method: Literal["generate", "encode"],
):
tp_size, pp_size, eager_mode, chunked_prefill = parallel_setup
multi_node_only, trust_remote_code, tokenizer_mode = test_options

if num_gpus_available < tp_size * pp_size:
pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")
if VLLM_MULTI_NODE and distributed_backend == "mp":
pytest.skip("Skipping multi-node pipeline parallel test for "
"multiprocessing distributed backend")
if multi_node_only and not VLLM_MULTI_NODE:
pytest.skip("Not in multi-node setting")

common_args = [
# use half precision for speed and memory savings in CI environment
@@ -307,7 +320,7 @@ def _compare_tp(

@pytest.mark.parametrize(
("model_name", "parallel_setup", "distributed_backend", "task",
"trust_remote_code", "tokenizer_mode"),
"test_options"),
[
params for model_name, settings in GENERATION_MODEL_SETTINGS.items()
for params in settings.iter_params(model_name)
@@ -320,23 +333,21 @@ def test_tp_language_generation(
parallel_setup: ParallelSetup,
distributed_backend: str,
task: TaskOption,
trust_remote_code: bool,
tokenizer_mode: Optional[str],
test_options: PPTestOptions,
num_gpus_available,
):
_compare_tp(model_name,
parallel_setup,
distributed_backend,
task,
trust_remote_code,
tokenizer_mode,
test_options,
num_gpus_available,
method="generate")


@pytest.mark.parametrize(
("model_name", "parallel_setup", "distributed_backend", "task",
"trust_remote_code", "tokenizer_mode"),
"test_options"),
[
params for model_name, settings in EMBEDDING_MODEL_SETTINGS.items()
for params in settings.iter_params(model_name)
@@ -349,23 +360,21 @@ def test_tp_language_embedding(
parallel_setup: ParallelSetup,
distributed_backend: str,
task: TaskOption,
trust_remote_code: bool,
tokenizer_mode: Optional[str],
test_options: PPTestOptions,
num_gpus_available,
):
_compare_tp(model_name,
parallel_setup,
distributed_backend,
task,
trust_remote_code,
tokenizer_mode,
test_options,
num_gpus_available,
method="encode")


@pytest.mark.parametrize(
("model_name", "parallel_setup", "distributed_backend", "task",
"trust_remote_code", "tokenizer_mode"),
"test_options"),
[
params for model_name, settings in MULTIMODAL_MODEL_SETTINGS.items()
for params in settings.iter_params(model_name)
@@ -378,15 +387,13 @@ def test_tp_multimodal_generation(
parallel_setup: ParallelSetup,
distributed_backend: str,
task: TaskOption,
trust_remote_code: bool,
tokenizer_mode: Optional[str],
test_options: PPTestOptions,
num_gpus_available,
):
_compare_tp(model_name,
parallel_setup,
distributed_backend,
task,
trust_remote_code,
tokenizer_mode,
test_options,
num_gpus_available,
method="generate")
4 changes: 4 additions & 0 deletions vllm/engine/arg_utils.py
@@ -16,6 +16,8 @@
from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
from vllm.transformers_utils.utils import check_gguf_file
from vllm.utils import FlexibleArgumentParser

@@ -924,6 +926,8 @@ def create_engine_config(self) -> EngineConfig:
"supported for multimodal models and has been disabled.")
self.enable_prefix_caching = False

maybe_register_config_serialize_by_value(self.trust_remote_code)

cache_config = CacheConfig(
# neuron needs block_size = max_model_len
block_size=self.block_size if self.device != "neuron" else
62 changes: 62 additions & 0 deletions vllm/transformers_utils/config.py
@@ -232,6 +232,68 @@ def get_config(
return config


def maybe_register_config_serialize_by_value(trust_remote_code: bool) -> None:
"""Try to register HF model configuration class to serialize by value
With trust_remote_code, the config class is typically an instance of a
custom class imported from the HF modules cache. The class will not be
importable in spawned workers by default (and won't exist at all on
other nodes), which breaks serialization of the config.
In this function we tell the cloudpickle serialization library to pass
instances of these generated classes by value instead of by reference,
i.e. the class definition is serialized along with its data so that the
class module does not need to be importable on the receiving end. This
registration only works if the modules cache has already been
initialized.
See: https://github.com/cloudpipe/cloudpickle?tab=readme-ov-file#overriding-pickles-serialization-mechanism-for-importable-constructs
"""
if not trust_remote_code:
return

try:
import transformers_modules
except ImportError:
logger.debug("Could not import transformers_modules used for remote"
" code. If remote code is not needed remove"
" `--trust-remote-code`.")
return

try:
import cloudpickle
cloudpickle.register_pickle_by_value(transformers_modules)

# ray vendors its own version of cloudpickle
from vllm.executor.ray_utils import ray
if ray:
ray.cloudpickle.register_pickle_by_value(transformers_modules)

# multiprocessing uses pickle to serialize arguments when using spawn
# Here we get pickle to use cloudpickle to serialize ModelConfig objects
# that contain instances of the custom config class to avoid
# serialization problems if the generated module (and model) has a `.`
# in its name
import multiprocessing
import pickle

from vllm.config import ModelConfig

def _reduce_modelconfig(mc: ModelConfig):
return (pickle.loads, (cloudpickle.dumps(mc), ))

multiprocessing.reducer.register(ModelConfig, _reduce_modelconfig)

except Exception as e:
logger.warning(
"Unable to register remote classes used by"
" trust_remote_code with by-value serialization. This may"
" lead to a later error. If remote code is not needed"
" remove `--trust-remote-code`",
exc_info=e)


def load_params_config(model, revision) -> PretrainedConfig:
# This function loads a params.json config which
# should be used when loading models in mistral format
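
To illustrate the mechanism described in the docstring of maybe_register_config_serialize_by_value (this sketch is not part of the commit), cloudpickle.register_pickle_by_value makes classes from a registered module travel with their definition, so the receiving process does not need to import that module. The module name "dyn_config" is a made-up stand-in for transformers_modules; behavior follows the cloudpickle README linked above.

import pickle
import sys
import types

import cloudpickle

# Simulate a dynamically generated module, like the HF modules cache.
dyn = types.ModuleType("dyn_config")
exec("class RemoteConfig:\n    def __init__(self, n):\n        self.n = n", dyn.__dict__)
sys.modules["dyn_config"] = dyn

cloudpickle.register_pickle_by_value(dyn)

payload = cloudpickle.dumps(dyn.RemoteConfig(4))  # class definition is embedded in the payload

del sys.modules["dyn_config"]  # pretend we are a worker that cannot import the module
restored = pickle.loads(payload)
print(type(restored).__name__, restored.n)  # RemoteConfig 4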
2 changes: 2 additions & 0 deletions vllm/utils.py
@@ -968,6 +968,8 @@ def flatten_2d_lists(lists: List[List[T]]) -> List[T]:
return [item for sublist in lists for item in sublist]


# TODO: This function can be removed if transformer_modules classes are
# serialized by value when communicating between processes
def init_cached_hf_modules() -> None:
"""
Lazy initialization of the Hugging Face modules.
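
Relatedly (and again not part of the commit), the ModelConfig reducer registered in config.py follows a standard pattern: give multiprocessing's pickler a reduce function that wraps the object in a cloudpickle payload, so anything cloudpickle can serialize by value survives the trip to a spawned worker. The sketch below uses multiprocessing.reduction.ForkingPickler.register, which appears to back multiprocessing.reducer.register, and a made-up Cfg class in place of ModelConfig.

import pickle
from multiprocessing.reduction import ForkingPickler

import cloudpickle

class Cfg:  # stand-in for ModelConfig
    def __init__(self, hf_config):
        self.hf_config = hf_config  # may hold a class plain pickle cannot import

def _reduce_cfg(cfg):
    # On load, pickle.loads is applied to the cloudpickle payload.
    return (pickle.loads, (cloudpickle.dumps(cfg), ))

ForkingPickler.register(Cfg, _reduce_cfg)

# ForkingPickler is what spawn uses to serialize call arguments.
payload = bytes(ForkingPickler.dumps(Cfg({"hidden_size": 64})))
print(pickle.loads(payload).hf_config)  # {'hidden_size': 64}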
