Skip to content

[Bug]: Colocate example leads to errors #13515

Closed
@fingertap

Description

@fingertap

Your current environment

The output of `python collect_env.py`

Image

Driver Version : 525.125.06
CUDA Version : 12.0

Attached GPUs : 8
GPU 00000000:26:00.0
Product Name : NVIDIA A800-SXM4-80GB

🐛 Describe the bug

Error message:
Image

I tried to use a similar method to run LLM:

import ray
from functools import cached_property
from vllm import LLM
from vllm.worker.worker import Worker


class MyLLM(LLM):
    def __init__(self, *args, **kwargs):
        import os

        os.environ.pop("CUDA_VISIBLE_DEVICES", None)
        super().__init__(*args, **kwargs)
        

class Runner:
    model_path: str = "/checkpoints/Qwen2.5-32B"
    data_path: str = "/data/math-train_7500.jsonl"

    num_gpus: int = 8
    tensor_parallel_size: int = 4
    num_prompts_per_iter: int = 128
    num_rollouts_per_iter: int = 64
    dtype: str = "bfloat16"
    gpu_memory_utilization: float = 0.9
    
    def run(self):
        from vllm import SamplingParams
        sampling_params = SamplingParams(n=self.num_rollouts_per_iter)

        for batch in self.debug_data:
            futures = []
            num_prompts_per_worker = len(batch) // len(self.workers)
            for indx, vllm_engine in enumerate(self.vllm_engines):
                prompts = batch[indx * num_prompts_per_worker: (indx + 1) * num_prompts_per_worker]
                futures.append(
                    vllm_engine.generate.remote(
                        prompts,
                        sampling_params=sampling_params
                    )
                )
            results = ray.get(futures)
            import pdb; pdb.set_trace()
    
    @cached_property
    def workers(self):
        from ray.util.placement_group import PlacementGroupSchedulingStrategy
        
        pg = ray.util.placement_group([{"GPU": 1, "CPU": 0}] * self.num_gpus)
        ray.get(pg.ready())
        
        workers = []
        for i in range(self.num_gpus // self.tensor_parallel_size):
            bundle_ids = list(
                range(i * self.tensor_parallel_size, (i + 1) * self.tensor_parallel_size)
            )
            workers.append(
                ray.remote(MyLLM).options(
                    num_gpus=0,
                    num_cpus=0,
                    scheduling_strategy=PlacementGroupSchedulingStrategy(
                        placement_group=pg,
                        placement_group_capture_child_tasks=True,
                        placement_group_bundle_index=bundle_ids
                    ),
                    runtime_env={
                        "env_vars": {
                            "VLLM_RAY_PER_WORKER_GPUS": str(0.1),
                            "VLLM_RAY_BUNDLE_INDICES": ",".join(map(str, bundle_ids)),
                        }
                    },
                ).remote(
                    self.model_path,
                    enforce_eager=True,
                    tensor_parallel_size=self.tensor_parallel_size,
                    distributed_executor_backend="ray",
                    gpu_memory_utilization=self.gpu_memory_utilization,
                    enable_sleep_mode=True
                )
            )
        return workers

    @cached_property
    def tokenizer(self):
        from transformers import AutoTokenizer
        
        return AutoTokenizer.from_pretrained(self.model_path)

    @cached_property
    def debug_data(self):
        import json
        from torch.utils.data import DataLoader
        
        data = []

        with open(self.data_path, "r") as f:
            lines = f.read().splitlines()
            
        for line in lines:
            question = json.loads(line)["question"]
            data.append(
                self.tokenizer.apply_chat_template(
                    [{"role": "user", "content": question}],
                    add_generation_prompt=True,
                    tokenize=False,
                )
            )
        return DataLoader(data, batch_size=self.num_prompts_per_iter)
    

if __name__ == "__main__":
    runner = Runner()
    runner.run()

This leads to this error:

Traceback (most recent call last):
  File "/test/vllm_0.7.2_colocate.py", line 112, in <module>
    runner.run()
  File "/test/vllm_0.7.2_colocate.py", line 32, in run
    num_prompts_per_worker = len(batch) // len(self.workers)
  File "/data/miniconda3/envs/vllm7/lib/python3.10/functools.py", line 981, in __get__
    val = self.func(instance)
  File "/test/vllm_0.7.2_colocate.py", line 57, in workers
    ray.remote(MyLLM).options(
  File "/data/miniconda3/envs/vllm7/lib/python3.10/site-packages/ray/actor.py", line 869, in remote
    return actor_cls._remote(args=args, kwargs=kwargs, **updated_options)
  File "/data/miniconda3/envs/vllm7/lib/python3.10/site-packages/ray/_private/auto_init_hook.py", line 21, in auto_init_wrapper
    return fn(*args, **kwargs)
  File "/data/miniconda3/envs/vllm7/lib/python3.10/site-packages/ray/util/tracing/tracing_helper.py", line 384, in _invocation_actor_class_remote_span
    return method(self, args, kwargs, *_args, **_kwargs)
  File "/data/miniconda3/envs/vllm7/lib/python3.10/site-packages/ray/actor.py", line 1142, in _remote
    placement_group = _configure_placement_group_based_on_context(
  File "/data/miniconda3/envs/vllm7/lib/python3.10/site-packages/ray/util/placement_group.py", line 547, in _configure_placement_group_based_on_context
    check_placement_group_index(placement_group, bundle_index)
  File "/data/miniconda3/envs/vllm7/lib/python3.10/site-packages/ray/util/placement_group.py", line 335, in check_placement_group_index
    elif bundle_index >= placement_group.bundle_count or bundle_index < -1:
TypeError: '>=' not supported between instances of 'list' and 'int'

Here bundle_index is [0,1,2,3].

Before submitting a new issue...

  • Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.

Metadata

Metadata

Assignees

No one assigned

    Labels

    bugSomething isn't working

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions