[Bug]: Cannot unpickle PostGradPassManager #15223

Closed
@aarnphm

Your current environment

The output of `python collect_env.py`

vLLM nightly, CUDA 12.4, torch 2.6

🐛 Describe the bug

This issue doesn't occur when running `vllm serve`. However, when I tried to extend the OpenAI-compatible server with additional logic via the Python SDK, it ran into this pickle issue in the compiler:

```
EngineCore hit an exception: Traceback (most recent call last):
  File "/home/paperspace/workspace/BentoVLLM/.venv/lib/python3.11/site-packages/vllm/v1/engine/core.py", line 332, in run_engine_core
    engine_core = EngineCoreProc(*args, **kwargs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/paperspace/workspace/BentoVLLM/.venv/lib/python3.11/site-packages/vllm/v1/engine/core.py", line 287, in __init__
    super().__init__(vllm_config, executor_class, log_stats)
  File "/home/paperspace/workspace/BentoVLLM/.venv/lib/python3.11/site-packages/vllm/v1/engine/core.py", line 62, in __init__
    num_gpu_blocks, num_cpu_blocks = self._initialize_kv_caches(
                                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/paperspace/workspace/BentoVLLM/.venv/lib/python3.11/site-packages/vllm/v1/engine/core.py", line 121, in _initialize_kv_caches
    available_gpu_memory = self.model_executor.determine_available_memory()
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/paperspace/workspace/BentoVLLM/.venv/lib/python3.11/site-packages/vllm/v1/executor/abstract.py", line 66, in determine_available_memory
    output = self.collective_rpc("determine_available_memory")
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/paperspace/workspace/BentoVLLM/.venv/lib/python3.11/site-packages/vllm/executor/uniproc_executor.py", line 56, in collective_rpc
    answer = run_method(self.driver_worker, method, args, kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/paperspace/workspace/BentoVLLM/.venv/lib/python3.11/site-packages/vllm/utils.py", line 2216, in run_method
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/paperspace/workspace/BentoVLLM/.venv/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/paperspace/workspace/BentoVLLM/.venv/lib/python3.11/site-packages/vllm/v1/worker/gpu_worker.py", line 157, in determine_available_memory
    self.model_runner.profile_run()
  File "/home/paperspace/workspace/BentoVLLM/.venv/lib/python3.11/site-packages/vllm/v1/worker/gpu_model_runner.py", line 1464, in profile_run
    hidden_states = self._dummy_run(self.max_num_tokens)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/paperspace/workspace/BentoVLLM/.venv/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/paperspace/workspace/BentoVLLM/.venv/lib/python3.11/site-packages/vllm/v1/worker/gpu_model_runner.py", line 1301, in _dummy_run
    hidden_states = model(
                    ^^^^^^
  File "/home/paperspace/workspace/BentoVLLM/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/paperspace/workspace/BentoVLLM/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/paperspace/workspace/BentoVLLM/.venv/lib/python3.11/site-packages/vllm/model_executor/models/llama.py", line 529, in forward
    model_output = self.model(input_ids, positions, intermediate_tensors,
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/paperspace/workspace/BentoVLLM/.venv/lib/python3.11/site-packages/vllm/compilation/decorators.py", line 238, in __call__
    output = self.compiled_callable(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/paperspace/workspace/BentoVLLM/.venv/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 574, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/paperspace/workspace/BentoVLLM/.venv/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 1380, in __call__
    return self._torchdynamo_orig_callable(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/paperspace/workspace/BentoVLLM/.venv/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 547, in __call__
    return _compile(
           ^^^^^^^^^
  File "/home/paperspace/workspace/BentoVLLM/.venv/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 986, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/paperspace/workspace/BentoVLLM/.venv/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 715, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/paperspace/workspace/BentoVLLM/.venv/lib/python3.11/site-packages/torch/_utils_internal.py", line 95, in wrapper_function
    return function(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/paperspace/workspace/BentoVLLM/.venv/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 750, in _compile_inner
    out_code = transform_code_object(code, transform)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/paperspace/workspace/BentoVLLM/.venv/lib/python3.11/site-packages/torch/_dynamo/bytecode_transformation.py", line 1361, in transform_code_object
    transformations(instructions, code_options)
  File "/home/paperspace/workspace/BentoVLLM/.venv/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 231, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/paperspace/workspace/BentoVLLM/.venv/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 662, in transform
    tracer.run()
  File "/home/paperspace/workspace/BentoVLLM/.venv/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2868, in run
    super().run()
  File "/home/paperspace/workspace/BentoVLLM/.venv/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
    while self.step():
          ^^^^^^^^^^^
  File "/home/paperspace/workspace/BentoVLLM/.venv/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/paperspace/workspace/BentoVLLM/.venv/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3048, in RETURN_VALUE
    self._return(inst)
  File "/home/paperspace/workspace/BentoVLLM/.venv/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3033, in _return
    self.output.compile_subgraph(
  File "/home/paperspace/workspace/BentoVLLM/.venv/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1101, in compile_subgraph
    self.compile_and_call_fx_graph(
  File "/home/paperspace/workspace/BentoVLLM/.venv/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1382, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/paperspace/workspace/BentoVLLM/.venv/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1432, in call_user_compiler
    return self._call_user_compiler(gm)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/paperspace/workspace/BentoVLLM/.venv/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1483, in _call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
  File "/home/paperspace/workspace/BentoVLLM/.venv/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1462, in _call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/paperspace/workspace/BentoVLLM/.venv/lib/python3.11/site-packages/torch/_dynamo/repro/after_dynamo.py", line 130, in __call__
    compiled_gm = compiler_fn(gm, example_inputs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/paperspace/workspace/BentoVLLM/.venv/lib/python3.11/site-packages/torch/_dynamo/repro/after_dynamo.py", line 130, in __call__
    compiled_gm = compiler_fn(gm, example_inputs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/paperspace/workspace/BentoVLLM/.venv/lib/python3.11/site-packages/torch/__init__.py", line 2385, in __call__
    return self.compiler_fn(model_, inputs_, **self.kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/paperspace/workspace/BentoVLLM/.venv/lib/python3.11/site-packages/vllm/compilation/backends.py", line 448, in __call__
    PiecewiseCompileInterpreter(self.split_gm, submod_names_to_compile,
  File "/home/paperspace/workspace/BentoVLLM/.venv/lib/python3.11/site-packages/vllm/compilation/backends.py", line 245, in run
    return super().run(*fake_args)
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/paperspace/workspace/BentoVLLM/.venv/lib/python3.11/site-packages/torch/fx/interpreter.py", line 167, in run
    self.env[node] = self.run_node(node)
                     ^^^^^^^^^^^^^^^^^^^
  File "/home/paperspace/workspace/BentoVLLM/.venv/lib/python3.11/site-packages/torch/fx/interpreter.py", line 230, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/paperspace/workspace/BentoVLLM/.venv/lib/python3.11/site-packages/vllm/compilation/backends.py", line 261, in call_module
    compiler_manager.compile(
                     ^^^^^^^^
  File "/home/paperspace/workspace/BentoVLLM/.venv/lib/python3.11/site-packages/vllm/compilation/backends.py", line 121, in compile
    compiled_graph, handle = self.compiler.compile(
                             ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/paperspace/workspace/BentoVLLM/.venv/lib/python3.11/site-packages/vllm/compilation/compiler_interface.py", line 274, in compile
    compiled_graph = compile_fx(
                     ^^^^^^^^^^^
  File "/home/paperspace/workspace/BentoVLLM/.venv/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 1552, in compile_fx
    return compile_fx(
           ^^^^^^^^^^^
  File "/home/paperspace/workspace/BentoVLLM/.venv/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 1863, in compile_fx
    return aot_autograd(
           ^^^^^^^^^^^^^
  File "/home/paperspace/workspace/BentoVLLM/.venv/lib/python3.11/site-packages/torch/_dynamo/backends/common.py", line 83, in __call__
    cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/paperspace/workspace/BentoVLLM/.venv/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 1155, in aot_module_simplified
    compiled_fn = dispatch_and_compile()
                  ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/paperspace/workspace/BentoVLLM/.venv/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 1131, in dispatch_and_compile
    compiled_fn, _ = create_aot_dispatcher_function(
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/paperspace/workspace/BentoVLLM/.venv/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 580, in create_aot_dispatcher_function
    return _create_aot_dispatcher_function(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/paperspace/workspace/BentoVLLM/.venv/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 830, in _create_aot_dispatcher_function
    compiled_fn, fw_metadata = compiler_fn(
                               ^^^^^^^^^^^^
  File "/home/paperspace/workspace/BentoVLLM/.venv/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 203, in aot_dispatch_base
    compiled_fw = compiler(fw_module, updated_flat_args)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/paperspace/workspace/BentoVLLM/.venv/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 489, in __call__
    return self.compiler_fn(gm, example_inputs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/paperspace/workspace/BentoVLLM/.venv/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 1741, in fw_compiler_base
    return inner_compile(
           ^^^^^^^^^^^^^^
  File "/home/paperspace/.local/share/uv/python/cpython-3.11.10-linux-x86_64-gnu/lib/python3.11/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/paperspace/workspace/BentoVLLM/.venv/lib/python3.11/site-packages/vllm/compilation/compiler_interface.py", line 225, in hijacked_compile_fx_inner
    output = torch._inductor.compile_fx.compile_fx_inner(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/paperspace/workspace/BentoVLLM/.venv/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 569, in compile_fx_inner
    return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/paperspace/workspace/BentoVLLM/.venv/lib/python3.11/site-packages/torch/_dynamo/repro/after_aot.py", line 102, in debug_wrapper
    inner_compiled_fn = compiler_fn(gm, example_inputs)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/paperspace/workspace/BentoVLLM/.venv/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 685, in _compile_fx_inner
    mb_compiled_graph = fx_codegen_and_compile(
                        ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/paperspace/workspace/BentoVLLM/.venv/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 1129, in fx_codegen_and_compile
    return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/paperspace/workspace/BentoVLLM/.venv/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 857, in codegen_and_compile
    torch._logging.trace_structured(
  File "/home/paperspace/workspace/BentoVLLM/.venv/lib/python3.11/site-packages/torch/_logging/_internal.py", line 1202, in trace_structured
    payload = payload_fn()
              ^^^^^^^^^^^^
  File "/home/paperspace/workspace/BentoVLLM/.venv/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 863, in <lambda>
    payload_fn=lambda: log_graph_runnable(),
                       ^^^^^^^^^^^^^^^^^^^^
  File "/home/paperspace/workspace/BentoVLLM/.venv/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 852, in log_graph_runnable
    torch._dynamo.repro.after_aot.save_graph_repro(
  File "/home/paperspace/workspace/BentoVLLM/.venv/lib/python3.11/site-packages/torch/_dynamo/repro/after_aot.py", line 341, in save_graph_repro
    generate_compiler_repro_string(
  File "/home/paperspace/workspace/BentoVLLM/.venv/lib/python3.11/site-packages/torch/_dynamo/repro/after_aot.py", line 274, in generate_compiler_repro_string
    {generate_config_string(stable_output=stable_output)}
     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/paperspace/workspace/BentoVLLM/.venv/lib/python3.11/site-packages/torch/_dynamo/debug_utils.py", line 267, in generate_config_string
    {torch._inductor.config.codegen_config()}
     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/paperspace/workspace/BentoVLLM/.venv/lib/python3.11/site-packages/torch/utils/_config_module.py", line 377, in codegen_config
    for k, v in self._get_dict(
                ^^^^^^^^^^^^^^^
  File "/home/paperspace/workspace/BentoVLLM/.venv/lib/python3.11/site-packages/torch/utils/_config_module.py", line 351, in _get_dict
    config[key] = copy.deepcopy(getattr(self, key))
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/paperspace/.local/share/uv/python/cpython-3.11.10-linux-x86_64-gnu/lib/python3.11/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/paperspace/.local/share/uv/python/cpython-3.11.10-linux-x86_64-gnu/lib/python3.11/copy.py", line 273, in _reconstruct
    y.__setstate__(state)
  File "/home/paperspace/workspace/BentoVLLM/.venv/lib/python3.11/site-packages/vllm/compilation/pass_manager.py", line 93, in __setstate__
    raise ValueError("Cannot unpickle PostGradPassManager")
torch._dynamo.exc.BackendCompilerFailed: backend='<vllm.compilation.backends.VllmBackend object at 0x7fa45211fb50>' raised:
ValueError: Cannot unpickle PostGradPassManager
```
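
For context on the mechanism (my reading of the traceback, not of the vLLM internals): `copy.deepcopy` falls back to pickle semantics via `__reduce_ex__`, so `copy._reconstruct` ends up calling the copied object's `__setstate__`. Any object whose `__setstate__` raises therefore breaks `deepcopy` as well as pickling. A minimal sketch (the class name and state dict are illustrative, not vLLM's actual implementation):

```python
import copy


class PassManagerLike:
    """Stand-in for an object that forbids unpickling, the way
    vllm/compilation/pass_manager.py does in __setstate__."""

    def __getstate__(self):
        # Non-empty state, so copy._reconstruct will call __setstate__.
        return {"passes": []}

    def __setstate__(self, state):
        raise ValueError("Cannot unpickle PostGradPassManager")


copy.deepcopy(PassManagerLike())  # raises the same ValueError as above
```

So the failure isn't in vLLM's own IPC; it's the `copy.deepcopy(getattr(self, key))` inside torch's config snapshotting hitting an object that deliberately refuses to round-trip.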

The following is the service definition:

```python
import contextlib

import bentoml
import fastapi

ENGINE_CONFIG = {
    'tokenizer_mode': 'mistral',
    'config_format': 'mistral',
    'load_format': 'mistral',
    'max_model_len': 4096,
    'enable_prefix_caching': False,
}

openai_api_app = fastapi.FastAPI()


@bentoml.asgi_app(openai_api_app, path='/v1')
@bentoml.service()
class VLLM:
    model = bentoml.models.HuggingFaceModel('mistralai/Ministral-8B-Instruct-2410', exclude=['model*', '*.pth', '*.pt', 'original/**/*'])

    def __init__(self):
        self.exit_stack = contextlib.AsyncExitStack()

    @bentoml.on_startup
    async def init_engine(self) -> None:
        import vllm.entrypoints.openai.api_server as vllm_api_server

        from vllm.utils import FlexibleArgumentParser
        from vllm.entrypoints.openai.cli_args import make_arg_parser

        args = make_arg_parser(FlexibleArgumentParser()).parse_args([])
        args.model = self.model
        args.disable_log_requests = True
        args.max_log_len = 1000
        args.request_logger = None
        args.disable_log_stats = True
        args.use_tqdm_on_load = False
        for key, value in ENGINE_CONFIG.items():
            setattr(args, key, value)

        router = fastapi.APIRouter(lifespan=vllm_api_server.lifespan)
        OPENAI_ENDPOINTS = [
            ['/chat/completions', vllm_api_server.create_chat_completion, ['POST']],
            ['/models', vllm_api_server.show_available_models, ['GET']],
        ]

        for route, endpoint, methods in OPENAI_ENDPOINTS:
            router.add_api_route(path=route, endpoint=endpoint, methods=methods, include_in_schema=True)
        openai_api_app.include_router(router)

        self.engine = await self.exit_stack.enter_async_context(vllm_api_server.build_async_engine_client(args))
        self.model_config = await self.engine.get_model_config()
        self.tokenizer = await self.engine.get_tokenizer()
        args.tool_call_parser = 'mistral'
        args.enable_auto_tool_choice = True

        await vllm_api_server.init_app_state(self.engine, self.model_config, openai_api_app.state, args)

    @bentoml.on_shutdown
    async def teardown_engine(self):
        await self.exit_stack.aclose()
```

A BentoML service is run in a subprocess, which I suspect is what triggers this error.
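
Note also where the deepcopy comes from in the traceback: `torch._logging.trace_structured` in `compile_fx.py` builds a repro payload via `codegen_config()`, which deep-copies every inductor config entry, including the `PostGradPassManager` that vLLM installs as a post-grad custom pass. If structured compile tracing happens to be enabled in the subprocess environment, disabling it might sidestep this path; this is an assumption on my part, not something I've verified:

```python
# Assumption: TORCH_TRACE / TORCH_LOGS were set in the subprocess environment,
# which enables trace_structured() and hence codegen_config()'s deepcopy.
# Clearing them before torch/vLLM are imported may avoid the crash path.
import os

os.environ.pop("TORCH_TRACE", None)
os.environ.pop("TORCH_LOGS", None)
```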

With the v0 MQ engine, this works (code: link).
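
As a stopgap, forcing the v0 engine from the environment may help while this is being fixed (hedged: `VLLM_USE_V1` is the usual toggle at time of writing, but nightly behavior may differ):

```python
# Possible stopgap: opt out of the v1 engine before vLLM is imported,
# since the v0 MQ path does not hit this deepcopy.
import os

os.environ["VLLM_USE_V1"] = "0"
```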

Before submitting a new issue...

- [x] Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.
