drop generation config and use hf prompt template by default
Signed-off-by: Gene Su <e870252314@gmail.com>
GeneDer committed Feb 16, 2025
1 parent f63a808 commit f6c9541
Showing 9 changed files with 32 additions and 127 deletions.
3 changes: 0 additions & 3 deletions python/ray/llm/_internal/serve/configs/constants.py
@@ -36,9 +36,6 @@
# Sentinel object used to indicate that a LoRA adapter config file is missing.
LORA_ADAPTER_CONFIG_NAME = "adapter_config.json"

# Names of files in the fine-tuning checkpoint.
GENERATION_CONFIG_NAME = "rayllm_generation_config.json"

DEFAULT_HEALTH_CHECK_PERIOD_S = int(
os.getenv("RAYLLM_DEFAULT_HEALTH_CHECK_PERIOD_S", "10")
)
3 changes: 3 additions & 0 deletions python/ray/llm/_internal/serve/configs/prompt_formats.py
@@ -358,6 +358,9 @@ def generate_prompt(self, messages: Union[Prompt, List[Message]]) -> EngineInput

if isinstance(messages, Prompt):
if isinstance(messages.prompt, str):
if not messages.use_prompt_format:
return EngineInput(text=self.bos + messages.prompt)

raise ValueError("String prompts are not supported.")
messages = messages.prompt

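The new branch above lets completion-style requests through without any per-checkpoint generation config. A minimal, self-contained sketch of that behavior (the Prompt and EngineInput models and the BOS handling are simplified stand-ins for the real prompt_formats classes):

from pydantic import BaseModel


class EngineInput(BaseModel):
    text: str


class Prompt(BaseModel):
    prompt: str
    use_prompt_format: bool = True


def generate_prompt(bos: str, messages: Prompt) -> EngineInput:
    # Mirrors the added branch: only raw prompts with use_prompt_format=False
    # bypass chat-template formatting; formatted string prompts still raise.
    if isinstance(messages.prompt, str):
        if not messages.use_prompt_format:
            return EngineInput(text=bos + messages.prompt)
        raise ValueError("String prompts are not supported.")
    raise TypeError("This sketch only handles string prompts.")


print(generate_prompt("<s>", Prompt(prompt="2 + 2 =", use_prompt_format=False)).text)
# <s>2 + 2 =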
79 changes: 16 additions & 63 deletions python/ray/llm/_internal/serve/configs/server_models.py
@@ -1,5 +1,3 @@
import copy

import yaml
import pydantic
import os
@@ -47,7 +45,11 @@
FALLBACK_MAX_ONGOING_REQUESTS,
MAX_NUM_STOPPING_SEQUENCES,
)
from ray.llm._internal.serve.configs.prompt_formats import Prompt
from ray.llm._internal.serve.configs.prompt_formats import (
Prompt,
HuggingFacePromptFormat,
PromptFormat,
)

GPUType = Enum("GPUType", vars(accelerators))
ModelT = TypeVar("ModelT", bound=BaseModel)
@@ -360,6 +362,7 @@ class LLMConfig(BaseModelExtended):
)

_supports_vision: bool = PrivateAttr(False)
_prompt_format: PromptFormat = PrivateAttr(default_factory=HuggingFacePromptFormat)

def _infer_supports_vision(self, model_id_or_path: str) -> None:
"""Called in llm node initializer together with other transformers calls. It
@@ -378,6 +381,10 @@ def apply_checkpoint_info(self, model_id_or_path: str) -> None:
def supports_vision(self) -> bool:
return self._supports_vision

@property
def prompt_format(self) -> PromptFormat:
return self._prompt_format

@property
def input_modality(self) -> str:
"""Returns the input modality of the model. There could be more types in the
@@ -686,46 +693,10 @@ def from_vllm_finish_reason(
return cls.STOP


# TODO (genesu): remove GenerationConfig
class GenerationConfig(BaseModelExtended):
# prompt_format: Optional[
# Union[HuggingFacePromptFormat]
# ] = Field(
# default=HuggingFacePromptFormat(use_hugging_face_chat_template=True),
# description="Handles chat template formatting and tokenization. If None, prompt formatting will be disabled and the model can be only queried in the completion mode.",
# )
generate_kwargs: Dict[str, Any] = Field(
default={},
description="Extra generation kwargs that needs to be passed into the sampling stage for the deployment (this includes things like temperature, etc.)",
)
stopping_sequences: Optional[List[str]] = Field(
default=None,
description="Stopping sequences (applied after detokenization) to propagate for inference.",
)
stopping_tokens: Optional[List[int]] = Field(
default=[],
description="Stopping tokens (applied before detokenization) to propagate for inference. By default, we use EOS/UNK tokens at inference.",
)

# @field_validator("prompt_format")
# @classmethod
# def default_prompt_format(cls, prompt_format):
# return prompt_format if prompt_format is not None else DisabledPromptFormat()
#
# @property
# def all_generate_kwargs(self) -> Dict[str, Any]:
# return {
# "stopping_sequences": self.stopping_sequences,
# "stopping_tokens": self.stopping_tokens,
# **self.generate_kwargs,
# }


class LoraMirrorConfig(BaseModelExtended):
lora_model_id: str
bucket_uri: str
max_total_tokens: Optional[int]
generation: Optional[GenerationConfig]
sync_args: Optional[List[str]] = None

@field_validator("bucket_uri")
@@ -755,7 +726,6 @@ def bucket_path(self) -> str:

class DiskMultiplexConfig(BaseModelExtended):
model_id: str
generation: Optional[GenerationConfig]
max_total_tokens: Optional[int]
local_path: str

@@ -1076,31 +1046,14 @@ def validate_stopping_sequences(cls, values):
return unique_val

@classmethod
def merge_generation_params(
cls: Type[ModelT], prompt: Prompt, generation: GenerationConfig
) -> ModelT:
def from_prompt(cls: Type[ModelT], prompt: Prompt) -> ModelT:
# Extract parameters object from prompt
parameters = prompt.parameters or {}
if not isinstance(parameters, dict):
parameters = parameters.model_dump(exclude_unset=True)

# Merge in the generate kwargs
generate_kwargs_copy = copy.deepcopy(generation.generate_kwargs)
generate_kwargs = merge_dicts(
generate_kwargs_copy,
parameters,
)
generate_kwargs = prompt.parameters or {}
if not isinstance(generate_kwargs, dict):
generate_kwargs = generate_kwargs.model_dump(exclude_unset=True)

# The stopping sequences need to be merged manually
generate_kwargs["stop"] = list(
set((parameters.get("stop") or []) + (generation.stopping_sequences or []))
)
generate_kwargs["stop_tokens"] = list(
set(
(parameters.get("stop_tokens") or [])
+ (generation.stopping_tokens or [])
)
)
generate_kwargs["stop"] = set(generate_kwargs.get("stop", []))
generate_kwargs["stop_tokens"] = set(generate_kwargs.get("stop_tokens", []))

return cls.model_validate(generate_kwargs)

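For reference, a small sketch of what from_prompt now amounts to, assuming only what the diff shows (the real VLLMSamplingParams model validates many more fields): sampling parameters come straight from the request, and stop/stop_tokens are simply de-duplicated with sets rather than merged against a GenerationConfig.

from typing import Any, Dict, Optional, Union

from pydantic import BaseModel


def build_sampling_kwargs(parameters: Optional[Union[Dict[str, Any], BaseModel]]) -> Dict[str, Any]:
    # Hypothetical helper mirroring from_prompt; `parameters` plays the role
    # of `prompt.parameters` in the diff.
    generate_kwargs = parameters or {}
    if not isinstance(generate_kwargs, dict):
        generate_kwargs = generate_kwargs.model_dump(exclude_unset=True)
    generate_kwargs["stop"] = set(generate_kwargs.get("stop") or [])
    generate_kwargs["stop_tokens"] = set(generate_kwargs.get("stop_tokens") or [])
    return generate_kwargs


print(build_sampling_kwargs({"temperature": 0.2, "stop": ["</s>", "</s>"]}))
# {'temperature': 0.2, 'stop': {'</s>'}, 'stop_tokens': set()}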
@@ -171,7 +171,6 @@ def _load_model_sync(
{
"model_id": lora_mirror_config.lora_model_id,
"max_total_tokens": lora_mirror_config.max_total_tokens,
"generation": lora_mirror_config.generation,
"local_path": local_path,
"lora_assigned_int_id": global_id_manager.next(),
}
42 changes: 3 additions & 39 deletions python/ray/llm/_internal/serve/deployments/llm/multiplex/utils.py
@@ -18,15 +18,13 @@
list_subfolders_s3,
)
from ray.llm._internal.serve.configs.server_models import (
GenerationConfig,
LLMConfig,
LoraMirrorConfig,
)
from ray.llm._internal.serve.configs.constants import (
CLOUD_OBJECT_MISSING_EXPIRE_S,
CLOUD_OBJECT_EXISTS_EXPIRE_S,
LORA_ADAPTER_CONFIG_NAME,
GENERATION_CONFIG_NAME,
)
from ray.llm._internal.serve.deployments.server_utils import make_async

@@ -241,29 +239,6 @@ async def get_lora_finetuned_context_length(bucket_uri: str):
return adapter_config.get("context_length")


async def get_lora_generation_config(bucket_uri: str) -> Optional[GenerationConfig]:
"""Gets the generation config used to tune the LoRA adapter.
Return: Returns the generation config for the adapter, if it
exists. Returns None otherwise.
"""

if bucket_uri.endswith("/"):
bucket_uri = bucket_uri.rstrip("/")
object_uri = f"{bucket_uri}/{GENERATION_CONFIG_NAME}"

object_str_or_missing_message = await get_object_from_cloud(object_uri)

# TODO (shrekris): add deduped logs here to tell whether the correct
# generation config is actually being used.
if object_str_or_missing_message is CLOUD_OBJECT_MISSING:
return None
else:
generation_config_str = object_str_or_missing_message
generation_config = GenerationConfig.model_validate_json(generation_config_str)
return generation_config


def get_lora_model_ids(
dynamic_lora_loading_path: str,
base_model_id: str,
@@ -316,25 +291,22 @@ def get_lora_model_ids(

async def download_multiplex_config_info(
model_id: str, base_path: str
) -> Tuple[str, int, Optional[GenerationConfig]]:
) -> Tuple[str, int]:
"""Downloads info needed to create a multiplex config.
Downloads objects using cloud storage provider APIs.
Returns: 3-tuple containing
Returns: 2-tuple containing
1. A bucket_uri for the bucket containing LoRA weights and config.
2. The maximum LoRA sequence length.
3. The generation config from the LoRA checkpoint, if the config
exists. Otherwise, this is None.
Raises: HTTPException if the LoRA adapter config file isn't available
in the cloud storage repository.
"""

bucket_uri = f"{base_path}/{model_id}"
ft_context_length = await get_lora_finetuned_context_length(bucket_uri)
generation_config = await get_lora_generation_config(bucket_uri)
return bucket_uri, ft_context_length, generation_config
return bucket_uri, ft_context_length


async def get_lora_model_metadata(
@@ -357,18 +329,11 @@ async def get_lora_model_metadata(
(
bucket_uri,
ft_context_length,
generation_config,
) = await download_multiplex_config_info(lora_id, base_path)

if generation_config is None:
generation = llm_config.generation_config
else:
generation = generation_config

return {
"model_id": model_id,
"base_model_id": base_model_id,
"generation": generation.model_dump(),
"max_request_context_length": ft_context_length,
# Note (genesu): `bucket_uri` affects where the lora weights are downloaded
# from remote location.
@@ -386,5 +351,4 @@ async def get_lora_mirror_config(
lora_model_id=model_id,
bucket_uri=metadata["bucket_uri"],
max_total_tokens=metadata["max_request_context_length"],
generation=metadata["generation"],
)
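A hedged, self-contained sketch of the slimmed-down flow above (the context length is hard-coded where the real helper reads adapter_config.json from cloud storage): download_multiplex_config_info now returns a 2-tuple, and the LoRA metadata no longer carries a "generation" entry.

import asyncio
from typing import Tuple


async def get_lora_finetuned_context_length(bucket_uri: str) -> int:
    # Stand-in for the real helper defined earlier in utils.py, which reads
    # adapter_config.json from cloud storage. Hard-coded for illustration.
    return 4096


async def download_multiplex_config_info(model_id: str, base_path: str) -> Tuple[str, int]:
    # Now a 2-tuple: the per-adapter rayllm_generation_config.json lookup is gone.
    bucket_uri = f"{base_path}/{model_id}"
    ft_context_length = await get_lora_finetuned_context_length(bucket_uri)
    return bucket_uri, ft_context_length


print(asyncio.run(download_multiplex_config_info("my-lora", "s3://my-bucket/loras")))
# ('s3://my-bucket/loras/my-lora', 4096)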
@@ -453,16 +453,12 @@ async def _predict(
self._llm_config.lora_config is not None
), "Must setup lora config for multiplexed requests."
disk_lora_model = await self._disk_lora_model(multiplexed_model_id)
generation = disk_lora_model.generation
else:
disk_lora_model = None
generation = self._llm_config.generation_config

prompt_output = generation.prompt_format.generate_prompt(prompt)
prompt_output = self._llm_config.prompt_format.generate_prompt(prompt)

sampling_params = VLLMSamplingParams.merge_generation_params(
prompt, generation
)
sampling_params = VLLMSamplingParams.from_prompt(prompt)
prompt_text = prompt_output.text
image_input = prompt_output.image
image = []
@@ -12,11 +12,13 @@
from vllm.lora.request import LoRARequest

from ray.llm._internal.serve.observability.logging import get_logger
from ray.llm._internal.serve.configs.prompt_formats import (
PromptFormat,
)
from ray.llm._internal.serve.configs.server_models import (
BaseModelExtended,
DiskMultiplexConfig,
GCSMirrorConfig,
GenerationConfig,
GenerationRequest,
GPUType,
LLMConfig,
@@ -56,7 +58,7 @@ class VLLMEngineConfig(BaseModelExtended):
description="The type of accelerator to use. This is used to determine the placement group strategy.",
)
runtime_env: Optional[Dict[str, Any]] = None
generation: GenerationConfig
prompt_format: PromptFormat
engine_kwargs: Dict[str, Any] = {}

@property
@@ -111,7 +113,7 @@ def from_llm_config(cls, llm_config: LLMConfig) -> "VLLMEngineConfig":
s3_mirror_config=s3_mirror_config,
gcs_mirror_config=gcs_mirror_config,
accelerator_type=llm_config.accelerator_type,
generation=llm_config.generation_config,
prompt_format=llm_config.prompt_format,
engine_kwargs=llm_config.engine_kwargs,
runtime_env=llm_config.runtime_env,
)
13 changes: 3 additions & 10 deletions python/ray/llm/_internal/serve/deployments/llm_node_initializer.py
@@ -194,20 +194,13 @@ async def initialize_node(llm_config: LLMConfig) -> InitializeNodeOutput:
Downloads model, tokenizer, and extra files as necessary.
If an Anytensor fast loading config is provided, model files will not be downloaded.
If the placement strategy is STRICT_PACK, all of the initialization will be run locally
(as all of the workers must be colocated with this process). Else, the initialization
will be run across the placement group bundles.
"""
anytensor_config = llm_config.model_loading_config.anytensor_config
local_node_download_model = NodeModelDownloadable.TOKENIZER_ONLY
if anytensor_config is None:
worker_node_download_model = NodeModelDownloadable.MODEL_AND_TOKENIZER
extra_init_kwargs = {}
else:
worker_node_download_model = NodeModelDownloadable.TOKENIZER_ONLY
extra_init_kwargs = {"anytensor_config": anytensor_config}
worker_node_download_model = NodeModelDownloadable.MODEL_AND_TOKENIZER
extra_init_kwargs = {}

engine_config = llm_config.get_engine_config()
pg = engine_config.get_or_create_pg()
@@ -271,7 +264,7 @@ def _initialize_local_node(
engine_config.actual_hf_model_id,
trust_remote_code=engine_config.trust_remote_code,
)
prompt_format = engine_config.generation.prompt_format
prompt_format = engine_config.prompt_format
# Note (genesu): The prompt format is always loaded from HF.
prompt_format.set_processor(
engine_config.actual_hf_model_id,
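"Use the HF prompt template by default" ultimately means formatting chat messages with the tokenizer's own chat template, which set_processor loads from the Hugging Face model id above. A standalone sketch of that idea using the public transformers API (the model id is illustrative; the real code wraps this inside HuggingFacePromptFormat):

from transformers import AutoTokenizer

# Any model that ships a chat template works; this id is only for illustration.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is 2 + 2?"},
]

# The chat template produces the exact prompt string the model was trained on,
# so no per-checkpoint rayllm_generation_config.json is needed.
prompt_text = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
print(prompt_text)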
2 changes: 0 additions & 2 deletions python/ray/llm/_internal/serve/deployments/openai_utils.py
@@ -2,7 +2,6 @@

from ray.llm._internal.serve.observability.logging import get_logger

# TODO (genesu): double check if LLMConfig, LLMRawResponse need to be lazy imported
from ray.llm._internal.serve.configs.server_models import (
ModelData,
LLMConfig,
@@ -54,7 +53,6 @@ def to_model_metadata(
metadata = {
"model_id": model_config.model_id,
"input_modality": model_config.input_modality,
"generation": model_config.generation_config.model_dump(),
"max_request_context_length": model_config.max_request_context_length,
}

Expand Down
