
[Model] Add user-configurable task for models that support both generation and embedding #9424

Merged: 20 commits into main, Oct 18, 2024

Conversation

DarkLight1337 (Member) commented Oct 16, 2024

Follow-up to #9303

This PR adds a --task option that determines which model runner (for generation or embedding) to create when initializing vLLM. The default (auto) selects the first task that the model supports. For embedding models that share the same model architecture as the base model, users can explicitly specify which task to initialize vLLM for, e.g.:

# Both models use Phi3VForCausalLM
vllm serve microsoft/Phi-3.5-vision-instruct --task generate <extra_args>
vllm serve TIGER-Lab/VLM2Vec-Full --task embedding <extra_args>

(Also addresses #6282 (comment).)

Since the new task option is semantically related to the model argument, I've placed it right after model, before tokenizer. To avoid incompatibilities resulting from this, I have added backwards compatibility and deprecated the use of positional arguments other than model in LLM.__init__.
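
The same choice applies to the offline Python API, since task is now a keyword argument of LLM.__init__ (see the review thread on this signature below). A minimal sketch of the intended usage, assuming the existing encode()/generate() entry points; the prompt text is a placeholder:

from vllm import LLM

# Keyword arguments only: this PR deprecates positional arguments
# other than `model` in LLM.__init__, so `task` is passed by name.
embedder = LLM(model="TIGER-Lab/VLM2Vec-Full", task="embedding")
outputs = embedder.encode(["Describe the image in one sentence."])

# The generation counterpart keeps the default (auto -> generate).
generator = LLM(model="microsoft/Phi-3.5-vision-instruct")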

Note: This introduces a breaking change for VLM2Vec users, since they currently do not have to pass --task embedding thanks to the hardcoded embedding-model detection logic. Nevertheless, requiring the user to set the task explicitly is more maintainable in the long run as the number of embedding models increases.


👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add the ready label to the PR
  • Enable auto-merge.

🚀

DarkLight1337 (Member, Author) commented Oct 16, 2024

Mainly looking for review from @robertgshaw2-neuralmagic since you're also working on embedding models. Also @simon-mo or @njhill for signing off on this CLI/API change.

Review comment on vllm/config.py (outdated diff):
@@ -33,14 +33,21 @@
_EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
_MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120

Task = Literal["generate", "embed"]

Contributor:

What's your thought about using enums instead of strings?

DarkLight1337 (Member, Author):

I would use strings unless there is a need to attach additional data for each task. Otherwise, we will have to spend additional effort converting between string and enum at the API level.
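
To make the trade-off concrete, here is a minimal sketch (illustrative only, not code from this PR): with a Literal type the string coming from the CLI/API is used directly, while an Enum requires an explicit conversion step at the boundary.

from enum import Enum
from typing import Literal

# String-based: the value parsed from --task is already a valid Task.
Task = Literal["generate", "embedding"]

def resolve_with_literal(task: Task) -> Task:
    return task  # no conversion needed

# Enum-based: an extra parsing step (and its error handling) lives at the API level.
class TaskEnum(Enum):
    GENERATE = "generate"
    EMBEDDING = "embedding"

def resolve_with_enum(task_str: str) -> TaskEnum:
    return TaskEnum(task_str)  # raises ValueError for unknown strings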

mgoin (Member) left a comment:

I think auto should default to generate, even if embed is also possible. We can put a warning that this is ambiguous, but I'm wary of this growing to most models. This seems like it would break current flows for users that are using models like microsoft/Phi-3-vision-128k-instruct.

Ultimately my stance is that generation models should be first-class citizens and the expansion of other tasks should be opt-in.

DarkLight1337 (Member, Author) commented Oct 16, 2024

> I think auto should default to generate, even if embed is also possible. We can put a warning that this is ambiguous, but I'm wary of this growing to most models. This seems like it would break current flows for users that are using models like microsoft/Phi-3-vision-128k-instruct.
>
> Ultimately my stance is that generation models should be first-class citizens and the expansion of other tasks should be opt-in.

Good point, I have updated this so that --task auto defaults to the first task supported by the model (i.e. generate) instead of throwing an error.

DarkLight1337 added the ready label (ONLY add when PR is ready to merge/full CI is needed) on Oct 16, 2024
simon-mo (Collaborator) left a comment:

+1 on default to generate

Comment on lines 124 to 125
*args: Never,
task: TaskOption = "auto",

Collaborator:

Hmmm, this is a bigger change and a bit risky. Can we revert this?

DarkLight1337 (Member, Author) replied on Oct 17, 2024:

I have turned the deprecation into a decorator (vllm.utils.deprecate_args) that passes the original arguments straight through to the function. This should eliminate any chance of a breaking change.
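
For illustration, here is a rough sketch of what such a pass-through deprecation decorator could look like. This is not the actual vllm.utils.deprecate_args implementation; the name deprecate_positional_args and its details are only indicative.

import functools
import warnings

def deprecate_positional_args(max_positional: int = 1):
    """Warn when more than `max_positional` positional arguments
    (besides `self`) are passed, while forwarding them unchanged."""
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            # args[0] is `self`; allow `self` plus `model` positionally.
            if len(args) > max_positional + 1:
                warnings.warn(
                    f"Positional arguments other than 'model' are deprecated "
                    f"for {fn.__qualname__}; please use keyword arguments.",
                    DeprecationWarning,
                    stacklevel=2)
            return fn(*args, **kwargs)
        return wrapper
    return decorator

# Hypothetical usage: place @deprecate_positional_args() above LLM.__init__.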

@DarkLight1337 force-pushed the model-task branch 2 times, most recently from 9b7746d to ecad240, on October 17, 2024 11:03
DarkLight1337 (Member, Author):

@mgoin does this look good to you now?

mgoin (Member) left a comment:

Thanks for the update, this looks pretty good to me now.
I do have one question: could we default task = "auto" in ModelConfig and SchedulerConfig as well? It's okay if you feel strongly about this; it just seems a bit unnecessary to require explicit specification.

Resolved review threads: vllm/config.py (two threads), tests/test_config.py (outdated), tests/distributed/test_pipeline_parallel.py (outdated)
simon-mo merged commit 051eaf6 into main on Oct 18, 2024
66 checks passed
comaniac (Collaborator):

@DarkLight1337 somehow this PR causes some speculative decoding tests to fail. For example:

pytest spec_decode/e2e/test_mlp_correctness.py -k "test_mqa_scorer[1-32-5-test_llm_kwargs0-baseline_llm_kwargs0-per_test_common_llm_kwargs0-common_llm_kwargs0]"

Before this PR was merged (7dbe738 on main):

1 passed, 20 deselected, 1 warning in 29.35s

After this PR was merged (051eaf6 on main):

_______________________ test_mqa_scorer[1-32-5-test_llm_kwargs0-baseline_llm_kwargs0-per_test_common_llm_kwargs0-common_llm_kwargs0] ________________________

vllm_runner = <class 'tests.conftest.VllmRunner'>
common_llm_kwargs = {'enforce_eager': True, 'model_name': 'JackFram/llama-160m', 'speculative_model': 'ibm-fms/llama-160m-accelerator'}
per_test_common_llm_kwargs = {}, baseline_llm_kwargs = {}, test_llm_kwargs = {'speculative_disable_mqa_scorer': True}, batch_size = 5, output_len = 32
seed = 1

    @pytest.mark.parametrize(
        "common_llm_kwargs",
        [{
            "model_name": MAIN_MODEL,

            # Skip cuda graph recording for fast test.
            "enforce_eager": True,
            "speculative_model": SPEC_MODEL,
        }])
    @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
    @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
    @pytest.mark.parametrize("test_llm_kwargs",
                             [{
                                 "speculative_disable_mqa_scorer": True,
                             }])
    @pytest.mark.parametrize("batch_size", [1, 5])
    @pytest.mark.parametrize(
        "output_len",
        [
            # Use smaller output len for fast test.
            32,
        ])
    @pytest.mark.parametrize("seed", [1])
    def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
                        baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
                        output_len: int, seed: int):
        """Verify that speculative decoding generates the same output
        with batch expansion scorer and mqa scorer.
        """
>       run_equality_correctness_test(vllm_runner,
                                      common_llm_kwargs,
                                      per_test_common_llm_kwargs,
                                      baseline_llm_kwargs,
                                      test_llm_kwargs,
                                      batch_size,
                                      max_output_len=output_len,
                                      seed=seed,
                                      temperature=0.0)

spec_decode/e2e/test_mlp_correctness.py:470:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
spec_decode/e2e/conftest.py:202: in run_equality_correctness_test
    with vllm_runner(**org_args) as vllm_model:
conftest.py:636: in __init__
    self.model = LLM(
../vllm/utils.py:1073: in inner
    return fn(*args, **kwargs)
../vllm/entrypoints/llm.py:193: in __init__
    self.llm_engine = LLMEngine.from_engine_args(
../vllm/engine/llm_engine.py:570: in from_engine_args
    engine_config = engine_args.create_engine_config()
../vllm/engine/arg_utils.py:984: in create_engine_config
    speculative_config = SpeculativeConfig.maybe_create_spec_config(
../vllm/config.py:1270: in maybe_create_spec_config
    draft_model_config = ModelConfig(
../vllm/config.py:219: in __init__
    supported_tasks, task = self._resolve_task(task, self.hf_config)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = <vllm.config.ModelConfig object at 0x769724c36f70>, task_option = 'generate'
hf_config = MLPSpeculatorConfig {
  "architectures": [
    "MLPSpeculatorPreTrainedModel"
  ],
  "emb_dim": 768,
  "inner_dim": 10...d": [
    5,
    3,
    2
  ],
  "torch_dtype": "float16",
  "transformers_version": "4.45.2",
  "vocab_size": 32000
}


    def _resolve_task(
        self,
        task_option: TaskOption,
        hf_config: PretrainedConfig,
    ) -> Tuple[Set[Task], Task]:
        architectures = getattr(hf_config, "architectures", [])

        task_support: Dict[Task, bool] = {
            # NOTE: Listed from highest to lowest priority,
            # in case the model supports multiple of them
            "generate": ModelRegistry.is_text_generation_model(architectures),
            "embedding": ModelRegistry.is_embedding_model(architectures),
        }
        supported_tasks_lst: List[Task] = [
            task for task, is_supported in task_support.items() if is_supported
        ]
        supported_tasks = set(supported_tasks_lst)

        if task_option == "auto":
            selected_task = next(iter(supported_tasks_lst))

            if len(supported_tasks) > 1:
                logger.info(
                    "This model supports multiple tasks: %s. "
                    "Defaulting to '%s'.", supported_tasks, selected_task)
        else:
            if task_option not in supported_tasks:
                msg = (
                    f"This model does not support the '{task_option}' task. "
                    f"Supported tasks: {supported_tasks}")
>               raise ValueError(msg)
E               ValueError: This model does not support the 'generate' task. Supported tasks: set()

../vllm/config.py:286: ValueError
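
For reference, the root cause visible in the traceback is that the speculative draft model's MLPSpeculatorPreTrainedModel architecture is registered as neither a text-generation model nor an embedding model, so the supported task set is empty and the explicit 'generate' check raises. A self-contained sketch of the same resolution logic, with the ModelRegistry lookups replaced by stub sets (the architecture sets below are illustrative stand-ins, not the real registry):

from typing import Dict, List

# Stand-ins for ModelRegistry lookups (illustrative only).
_GENERATION_ARCHS = {"LlamaForCausalLM", "Phi3VForCausalLM"}
_EMBEDDING_ARCHS = {"Phi3VForCausalLM"}

def resolve_task(task_option: str, architectures: List[str]) -> str:
    task_support: Dict[str, bool] = {
        "generate": any(a in _GENERATION_ARCHS for a in architectures),
        "embedding": any(a in _EMBEDDING_ARCHS for a in architectures),
    }
    supported: List[str] = [t for t, ok in task_support.items() if ok]
    if task_option == "auto":
        return supported[0]  # would also fail here if nothing is supported
    if task_option not in supported:
        raise ValueError(f"This model does not support the '{task_option}' "
                         f"task. Supported tasks: {set(supported)}")
    return task_option

# Reproduces the error above: the draft architecture is in neither registry.
# resolve_task("generate", ["MLPSpeculatorPreTrainedModel"])  -> ValueError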

DarkLight1337 deleted the model-task branch on October 19, 2024 02:43
Commits referencing this pull request were later pushed to several forks (charlifu/vllm, vrdn-23/vllm, compressa-ai/vllm, garg-amit/vllm, FerdinandZhong/vllm, sumitd2/vllm, KuntaiDu/vllm, mfournioux/vllm, and neuralmagic/vllm) between Oct 23 and Nov 23, 2024.