
[Bug]: 'dict' object has no attribute 'is_kv_transfer_instance' #19259

Closed
@maobaolong

Description

Your current environment

The output of `python collect_env.py` was not provided.

🐛 Describe the bug

  • Run the vLLM container:
docker run -d -it --rm --privileged --entrypoint /bin/bash --network host --name poolv1-mbl-test-2  --shm-size 512g --gpus all   -v /:/disc       vllm/vllm-openai:v0.9.0.1

docker exec -it poolv1-mbl-test-2 bash
pip install lmcache
  • Start vLLM using the LMCache example.

The following Python script is copied from examples/others/lmcache/cpu_offload_lmcache.py, with some minor changes to run the model locally.

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This file demonstrates example usage of CPU offloading
with LMCache in vLLM v1 or v0.

Usage:

    Specify vLLM version

    -v v0 : Use LMCacheConnector
            model = mistralai/Mistral-7B-Instruct-v0.2
            (Includes enable_chunked_prefill = True)

    -v v1 : Use LMCacheConnectorV1 (default)
            model = meta-llama/Meta-Llama-3.1-8B-Instruct
            (Without enable_chunked_prefill)

Note that `lmcache` is needed to run this example.
Requirements: Linux, Python: 3.10 or higher, CUDA: 12.1
To learn more about LMCache environment setup, please refer to:
https://docs.lmcache.ai/getting_started/installation.html
"""

import argparse
import contextlib
import os
import time
from dataclasses import asdict

from lmcache.v1.cache_engine import LMCacheEngineBuilder
from lmcache.integration.vllm.utils import ENGINE_NAME

from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig
from vllm.engine.arg_utils import EngineArgs


def setup_environment_variables(vllm_version: str):
    # LMCache-related environment variables
    # Use experimental features in LMCache
    os.environ["LMCACHE_USE_EXPERIMENTAL"] = "True"
    # LMCache is set to use 256 tokens per chunk
    os.environ["LMCACHE_CHUNK_SIZE"] = "256"
    # Enable local CPU backend in LMCache
    os.environ["LMCACHE_LOCAL_CPU"] = "True"
    # Set local CPU memory limit to 5.0 GB
    os.environ["LMCACHE_MAX_LOCAL_CPU_SIZE"] = "5.0"
    if vllm_version == "v0":
        os.environ["VLLM_USE_V1"] = "0"


@contextlib.contextmanager
def build_llm_with_lmcache(lmcache_connector: str, model: str, vllm_version: str):
    ktc = KVTransferConfig(
        kv_connector=lmcache_connector,
        kv_role="kv_both",
    )
    # Set GPU memory utilization to 0.8 for an A40 GPU with 48GB
    # memory. Reduce the value if your GPU has less memory.
    # Note: LMCache supports chunked prefill (see vLLM#14505, LMCache#392).
    if vllm_version == "v0":
        llm_args = EngineArgs(
            model=model,
            kv_transfer_config=ktc,
            max_model_len=8000,
            gpu_memory_utilization=0.8,
            enable_chunked_prefill=True,  # Only in v0
            trust_remote_code=True,
        )
    else:
        llm_args = EngineArgs(
            model=model,
            kv_transfer_config=ktc,
            max_model_len=8000,
            gpu_memory_utilization=0.8,
            trust_remote_code=True,
        )

    llm = LLM(**asdict(llm_args))
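    # NOTE (suspected cause): dataclasses.asdict() recursively converts nested
    # dataclasses, so the KVTransferConfig built above arrives at LLM() as a
    # plain dict rather than a KVTransferConfig instance.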
    try:
        yield llm
    finally:
        # Clean up lmcache backend
        LMCacheEngineBuilder.destroy(ENGINE_NAME)


def print_output(
    llm: LLM,
    prompt: list[str],
    sampling_params: SamplingParams,
    req_str: str,
):
    # Should be able to see logs like the following:
    # `LMCache INFO: Storing KV cache for 6006 out of 6006 tokens for request 0`
    # This indicates that the KV cache has been stored in LMCache.
    start = time.time()
    outputs = llm.generate(prompt, sampling_params)
    print("-" * 50)
    for output in outputs:
        generated_text = output.outputs[0].text
        print(f"Generated text: {generated_text!r}")
    print(f"Generation took {time.time() - start:.2f} seconds, {req_str} request done.")
    print("-" * 50)


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-v",
        "--version",
        choices=["v0", "v1"],
        default="v1",
        help="Specify vLLM version (default: v1)",
    )
    return parser.parse_args()


def main():
    args = parse_args()

    if args.version == "v0":
        lmcache_connector = "LMCacheConnector"
        model = "/disc/data1/deepseek/DeepSeek-V2-Lite-Chat"
    else:
        lmcache_connector = "LMCacheConnectorV1"
        model = "/disc/data1/deepseek/DeepSeek-V2-Lite-Chat"

    setup_environment_variables(args.version)

    with build_llm_with_lmcache(lmcache_connector, model, args.version) as llm:
        # This example script runs two requests with a shared prefix.
        # Define the shared prompt and specific prompts
        shared_prompt = "Hello, how are you?" * 1000
        first_prompt = [
            shared_prompt + "Hello, my name is",
        ]
        second_prompt = [
            shared_prompt + "Tell me a very long story",
        ]

        sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)

        # Print the first output
        print_output(llm, first_prompt, sampling_params, "first")

        time.sleep(1)

        # print the second output
        print_output(llm, second_prompt, sampling_params, "second")


if __name__ == "__main__":
    main()
  • Execute the Python example:
LMCACHE_REMOTE_SERDE="naive"  python3 cpu_offload_lmcache.py
  • Output:
ERROR 06-06 00:28:27 [core.py:500] EngineCore failed to start.
ERROR 06-06 00:28:27 [core.py:500] Traceback (most recent call last):
ERROR 06-06 00:28:27 [core.py:500]   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 491, in run_engine_core
ERROR 06-06 00:28:27 [core.py:500]     engine_core = EngineCoreProc(*args, **kwargs)
ERROR 06-06 00:28:27 [core.py:500]                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-06 00:28:27 [core.py:500]   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 390, in __init__
ERROR 06-06 00:28:27 [core.py:500]     super().__init__(vllm_config, executor_class, log_stats,
ERROR 06-06 00:28:27 [core.py:500]   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 71, in __init__
ERROR 06-06 00:28:27 [core.py:500]     self.model_executor = executor_class(vllm_config)
ERROR 06-06 00:28:27 [core.py:500]                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-06 00:28:27 [core.py:500]   File "/usr/local/lib/python3.12/dist-packages/vllm/executor/executor_base.py", line 52, in __init__
ERROR 06-06 00:28:27 [core.py:500]     self._init_executor()
ERROR 06-06 00:28:27 [core.py:500]   File "/usr/local/lib/python3.12/dist-packages/vllm/executor/uniproc_executor.py", line 46, in _init_executor
ERROR 06-06 00:28:27 [core.py:500]     self.collective_rpc("init_device")
ERROR 06-06 00:28:27 [core.py:500]   File "/usr/local/lib/python3.12/dist-packages/vllm/executor/uniproc_executor.py", line 56, in collective_rpc
ERROR 06-06 00:28:27 [core.py:500]     answer = run_method(self.driver_worker, method, args, kwargs)
ERROR 06-06 00:28:27 [core.py:500]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-06 00:28:27 [core.py:500]   File "/usr/local/lib/python3.12/dist-packages/vllm/utils.py", line 2605, in run_method
ERROR 06-06 00:28:27 [core.py:500]     return func(*args, **kwargs)
ERROR 06-06 00:28:27 [core.py:500]            ^^^^^^^^^^^^^^^^^^^^^
ERROR 06-06 00:28:27 [core.py:500]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/worker_base.py", line 604, in init_device
ERROR 06-06 00:28:27 [core.py:500]     self.worker.init_device()  # type: ignore
ERROR 06-06 00:28:27 [core.py:500]     ^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-06 00:28:27 [core.py:500]   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/worker/gpu_worker.py", line 137, in init_device
ERROR 06-06 00:28:27 [core.py:500]     init_worker_distributed_environment(self.vllm_config, self.rank,
ERROR 06-06 00:28:27 [core.py:500]   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/worker/gpu_worker.py", line 353, in init_worker_distributed_environment
ERROR 06-06 00:28:27 [core.py:500]     ensure_kv_transfer_initialized(vllm_config)
ERROR 06-06 00:28:27 [core.py:500]   File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/kv_transfer/kv_transfer_state.py", line 61, in ensure_kv_transfer_initialized
ERROR 06-06 00:28:27 [core.py:500]     if (vllm_config.kv_transfer_config.is_kv_transfer_instance
ERROR 06-06 00:28:27 [core.py:500]         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 06-06 00:28:27 [core.py:500] AttributeError: 'dict' object has no attribute 'is_kv_transfer_instance'
Process EngineCore_0:
Traceback (most recent call last):
  File "/usr/lib/python3.12/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/usr/lib/python3.12/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 504, in run_engine_core
    raise e
  File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 491, in run_engine_core
    engine_core = EngineCoreProc(*args, **kwargs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 390, in __init__
    super().__init__(vllm_config, executor_class, log_stats,
  File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 71, in __init__
    self.model_executor = executor_class(vllm_config)
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/vllm/executor/executor_base.py", line 52, in __init__
    self._init_executor()
  File "/usr/local/lib/python3.12/dist-packages/vllm/executor/uniproc_executor.py", line 46, in _init_executor
    self.collective_rpc("init_device")
  File "/usr/local/lib/python3.12/dist-packages/vllm/executor/uniproc_executor.py", line 56, in collective_rpc
    answer = run_method(self.driver_worker, method, args, kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/vllm/utils.py", line 2605, in run_method
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/vllm/worker/worker_base.py", line 604, in init_device
    self.worker.init_device()  # type: ignore
    ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/vllm/v1/worker/gpu_worker.py", line 137, in init_device
    init_worker_distributed_environment(self.vllm_config, self.rank,
  File "/usr/local/lib/python3.12/dist-packages/vllm/v1/worker/gpu_worker.py", line 353, in init_worker_distributed_environment
    ensure_kv_transfer_initialized(vllm_config)
  File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/kv_transfer/kv_transfer_state.py", line 61, in ensure_kv_transfer_initialized
    if (vllm_config.kv_transfer_config.is_kv_transfer_instance
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'dict' object has no attribute 'is_kv_transfer_instance'
Traceback (most recent call last):
  File "/disc/data1/baoloongmao/lmcache_whl/20250606/blend_kv_v1/cpu_offload_lmcache.py", line 154, in <module>
    main()
  File "/disc/data1/baoloongmao/lmcache_whl/20250606/blend_kv_v1/cpu_offload_lmcache.py", line 131, in main
    with build_llm_with_lmcache(lmcache_connector, model, args.version) as llm:
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/contextlib.py", line 137, in __enter__
    return next(self.gen)
           ^^^^^^^^^^^^^^
  File "/disc/data1/baoloongmao/lmcache_whl/20250606/blend_kv_v1/cpu_offload_lmcache.py", line 80, in build_llm_with_lmcache
    llm = LLM(**asdict(llm_args))
          ^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/vllm/utils.py", line 1183, in inner
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/vllm/entrypoints/llm.py", line 255, in __init__
    self.llm_engine = LLMEngine.from_engine_args(
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/vllm/engine/llm_engine.py", line 501, in from_engine_args
    return engine_cls.from_vllm_config(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/llm_engine.py", line 123, in from_vllm_config
    return cls(vllm_config=vllm_config,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/llm_engine.py", line 100, in __init__
    self.engine_core = EngineCoreClient.make_client(
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core_client.py", line 75, in make_client
    return SyncMPClient(vllm_config, executor_class, log_stats)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core_client.py", line 580, in __init__
    super().__init__(
  File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core_client.py", line 418, in __init__
    self._wait_for_engine_startup(output_address, parallel_config)
  File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core_client.py", line 484, in _wait_for_engine_startup
    raise RuntimeError("Engine core initialization failed. "
RuntimeError: Engine core initialization failed. See root cause above. Failed core proc(s): {}
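For context, the failure appears to come from `dataclasses.asdict()`: it recursively converts nested dataclasses, so the `KVTransferConfig` built in `build_llm_with_lmcache` reaches `LLM(...)` as a plain dict, which matches the `'dict' object has no attribute 'is_kv_transfer_instance'` error above. A minimal workaround sketch (an assumption, not a confirmed fix; it reuses the model path and arguments from the script above) is to pass the keyword arguments to `LLM(...)` directly, so the config object is never flattened:

from vllm import LLM
from vllm.config import KVTransferConfig

ktc = KVTransferConfig(
    kv_connector="LMCacheConnectorV1",
    kv_role="kv_both",
)

# Passing kwargs straight to LLM() (which forwards them to EngineArgs)
# keeps kv_transfer_config a KVTransferConfig instance instead of a dict.
llm = LLM(
    model="/disc/data1/deepseek/DeepSeek-V2-Lite-Chat",
    kv_transfer_config=ktc,
    max_model_len=8000,
    gpu_memory_utilization=0.8,
    trust_remote_code=True,
)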
