
How to accelerate evaluate when using vLLM engine #2058

Open
@MaximeSongIdris

Description

[x] I checked the documentation and related resources and couldn't find an answer to my question.

Your Question
I had an issue with langchain_community.llms.VLLM, so instead of creating a BaseRagasLLM through the LangchainLLMWrapper class, I wrote a new subclass of BaseRagasLLM that works with vllm.AsyncLLMEngine.
The evaluation works, but my main issue is that it takes too much time. For each sample, the judge LLM produces 5 outputs (one per retrieved context chunk), but instead of being generated in parallel they are generated one after the other.
I can't find in the source code how to change this behavior.
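
To make the timing gap concrete, here is a toy illustration of the behavior I observe versus the behavior I would expect from an async engine (none of this is Ragas code; asyncio.sleep stands in for a single judge-LLM call):

import asyncio
import time

async def fake_llm_call(i: int) -> int:
    await asyncio.sleep(1.0)  # stand-in for one agenerate_text call taking ~1 s
    return i

async def sequential(n: int = 5):
    # what the evaluation currently seems to do: ~n seconds in total
    return [await fake_llm_call(i) for i in range(n)]

async def parallel(n: int = 5):
    # what I would expect from an async engine: ~1 second in total
    return await asyncio.gather(*(fake_llm_call(i) for i in range(n)))

t0 = time.perf_counter(); asyncio.run(sequential()); print("sequential:", time.perf_counter() - t0)
t0 = time.perf_counter(); asyncio.run(parallel()); print("parallel:", time.perf_counter() - t0)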

Code Examples

import os
import argparse
import typing as t
import uuid

from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_core.callbacks import Callbacks
from langchain_core.outputs import LLMResult, Generation
from langchain_core.prompt_values import PromptValue
from ragas.cache import CacheInterface
from ragas.embeddings import LangchainEmbeddingsWrapper
from ragas.llms import BaseRagasLLM
from ragas.metrics import LLMContextPrecisionWithReference
from ragas.run_config import RunConfig
from ragas import evaluate
from vllm import AsyncLLMEngine, LLM, SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs


# Load Judge LLM with its sampling parameters
sampling_params = SamplingParams(
    repetition_penalty=1.05,
    temperature=0.7,
    top_p=0.8,
    top_k=20,
    max_tokens=512,
)  # Base sampling params from: https://huggingface.co/Qwen/Qwen2.5-72B-Instruct/blob/main/generation_config.json

judge_llm = AsyncLLMEngine.from_engine_args(
    AsyncEngineArgs(
        model='Qwen/Qwen2.5-72B-Instruct',
        tensor_parallel_size=4,
        trust_remote_code=True,
        dtype="bfloat16"
    )
)  # Load the judge LLM with tensor_parallel_size=4 and bfloat16 precision (expects 4 GPUs)
# Create the vLLM engine before any other CUDA usage to avoid conflicts between vLLM's spawned workers and an existing CUDA context

embedding = HuggingFaceEmbeddings(
    model_name='NovaSearch/stella_en_400M_v5',
    model_kwargs={"device": "cuda", "trust_remote_code": "True"},
)  # Load embedding model on 1 GPU

metrics = [  # these metrics rely on either an LLM or an embedding model (EMB)
    LLMContextPrecisionWithReference(),  # Do we have all the relevant contexts? (LLM)
]

dataset_ragas = [
    {
        'user_input': 'Question ? ...',
        'retrieved_contexts': ['Context 1...', 'Context 2..'],
        'response': 'Answer...',
        'reference': 'Ground truth...',
    }
]

class vLLMWrapper(BaseRagasLLM):
    """
    A wrapper class that adapts vLLM's inference engine to the Ragas-compatible BaseRagasLLM interface.

    This class enables using vLLM for scoring and evaluation tasks within the Ragas framework by implementing
    the `generate_text` and `agenerate_text` methods, which produce LangChain-compatible `LLMResult` objects.
    Source: https://github.com/explodinggradients/ragas/blob/main/ragas/src/ragas/llms/base.py#L123

    Attributes:
        llm: The vLLM engine instance (`vllm.LLM` for the sync path, `vllm.AsyncLLMEngine` for the async path).
        sampling_params: A `SamplingParams` object defining temperature, top_p, etc.
        run_config: Optional configuration for controlling how evaluations are executed.
        cache: Optional cache for storing/reusing model outputs.

    """

    def __init__(
        self,
        vllm_model,
        sampling_params,
        run_config: t.Optional[RunConfig] = None,
        cache: t.Optional[CacheInterface] = None,
    ):
        super().__init__(cache=cache)
        self.llm = vllm_model
        self.sampling_params = sampling_params
        
        if run_config is None:  # legacy code 
            run_config = RunConfig()
        self.set_run_config(run_config)

    def generate_text(
        self,
        prompt: PromptValue,
        n: int = 1,
        temperature: t.Optional[float] = None,
        stop: t.Optional[t.List[str]] = None,
        callbacks: Callbacks = None,
    ) -> LLMResult:
        """
        Generates a LangChain-compatible LLMResult from a PromptValue using vLLM.

        This method is designed to be compatible with the BaseRagasLLM interface. It uses the
        preconfigured vLLM engine and sampling parameters to produce completions for a given prompt.

        Args:
            prompt (PromptValue): The input prompt wrapped in a LangChain PromptValue.
            n (int): The number of outputs for the prompt.

        Returns:
            LLMResult: A LangChain LLMResult containing one Generation per prompt.
        """
        # n, temperature, stop and callbacks come from the BaseRagasLLM signature;
        # they are accepted for API compatibility but ignored here
        temperature = None
        stop = None
        callbacks = None

        prompt = prompt.to_string()   # vLLM expects plain text as input
        self.sampling_params.n = 5    # generate 5 outputs per input (hardcoded; ignores the n argument)
        self.sampling_params.best_of = 5

        # A synchronous vllm.LLM.generate call returns a list[vllm.outputs.RequestOutput];
        # since we only pass 1 prompt, the list has exactly 1 entry.
        # (Note: this sync path assumes self.llm is a synchronous vllm.LLM;
        # with AsyncLLMEngine only agenerate_text can be used.)
        vllm_result = self.llm.generate(prompt, self.sampling_params)[0]
        
        # LangChain's LLMResult expects a list of lists:
        # - The outer list corresponds to each input prompt
        # - The inner list contains one or more Generations per prompt (e.g. multiple outputs for a single input)
        # Here every output of the single prompt becomes its own Generation.
        generations = [
            [Generation(text=output.text.strip()) for output in vllm_result.outputs]
        ]
        ragas_expected_result = LLMResult(generations=generations)

        return ragas_expected_result

    async def agenerate_text(
        self,
        prompt: PromptValue,
        n: int = 1,
        temperature: t.Optional[float] = None,
        stop: t.Optional[t.List[str]] = None,
        callbacks: Callbacks = None,
    ) -> LLMResult:
        # n, temperature, stop and callbacks come from the BaseRagasLLM signature;
        # they are accepted for API compatibility but ignored here
        temperature = None
        stop = None
        callbacks = None

        prompt = prompt.to_string()          # vLLM expects plain text as input
        print(n)                             # debug output
        print(prompt[3800:4100])             # debug output
        print('\n\n\n')
        self.sampling_params.n = 5           # generate 5 outputs per input (hardcoded; ignores the n argument)
        self.sampling_params.best_of = 5
        request_id = str(uuid.uuid4())  # id used for tracking purposes
        # non-blocking call: submit the request to the vLLM engine
        results_generator = self.llm.generate(prompt, self.sampling_params, request_id=request_id)
        
        # drain the async generator; the last yielded item holds the final RequestOutput
        vllm_result = None
        async for request_output in results_generator:
            vllm_result = request_output
        
        # LangChain's LLMResult expects a list of lists:
        # - The outer list corresponds to each input prompt
        # - The inner list contains one or more Generations per prompt (e.g. multiple outputs for a single input)
        generations = [
            [Generation(text=output.text.strip()) for output in vllm_result.outputs]
        ]
        ragas_expected_result = LLMResult(generations=generations)

        return ragas_expected_result

    def set_run_config(self, run_config: RunConfig):
        self.run_config = run_config

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(llm={self.llm.__class__.__name__}(...))"

results = evaluate(dataset=dataset_ragas, 
                   metrics=metrics,
                   llm=vLLMWrapper(judge_llm, sampling_params),
                   embeddings=LangchainEmbeddingsWrapper(embedding),
                   raise_exceptions=True,
                   run_config=RunConfig(timeout=60, max_retries=2, max_wait=30, max_workers=1),
                   show_progress=True,
                   batch_size=2)
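
For reference, this is the kind of concurrent call pattern I would expect to work against the wrapper. It is only a minimal sketch: StringPromptValue and the _concurrency_probe helper are my own additions for illustration, and I assume the engine can be driven from a plain asyncio.run() loop.

import asyncio
from langchain_core.prompt_values import StringPromptValue

async def _concurrency_probe(wrapper: vLLMWrapper, prompts: t.List[str]) -> t.List[LLMResult]:
    # submit all prompts at once; AsyncLLMEngine should be able to interleave the requests
    tasks = [wrapper.agenerate_text(StringPromptValue(text=p)) for p in prompts]
    return await asyncio.gather(*tasks)

# Example (commented out so it does not interfere with the evaluation above):
# probe_results = asyncio.run(_concurrency_probe(
#     vLLMWrapper(judge_llm, sampling_params),
#     ["Prompt A...", "Prompt B...", "Prompt C..."],
# ))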


Labels: module-metrics, question
