|
1 | 1 | import enum
|
2 | 2 | import json
|
3 | 3 | import os
|
4 |
| -from dataclasses import dataclass |
| 4 | +from dataclasses import dataclass, fields |
5 | 5 | from typing import TYPE_CHECKING, ClassVar, Optional, Union
|
6 | 6 |
|
7 | 7 | import torch
|
@@ -617,6 +617,159 @@ def __init__(self, device: str = "auto") -> None:
|
617 | 617 | self.device = torch.device(self.device_type)
|
618 | 618 |
|
619 | 619 |
|
| 620 | +class SpeculativeConfig: |
| 621 | + """Configuration for speculative decoding. |
| 622 | +
|
| 623 | + The configuration is currently specialized to draft-model speculative |
| 624 | + decoding with top-1 proposals. |
| 625 | + """ |
| 626 | + |
| 627 | + @staticmethod |
| 628 | + def maybe_create_spec_config( |
| 629 | + target_model_config: ModelConfig, |
| 630 | + target_parallel_config: ParallelConfig, |
| 631 | + target_dtype: str, |
| 632 | + speculative_model: Optional[str], |
| 633 | + num_speculative_tokens: Optional[int], |
| 634 | + ) -> Optional["SpeculativeConfig"]: |
| 635 | + """Create a SpeculativeConfig if possible, else return None. |
| 636 | +
|
| 637 | + This function attempts to create a SpeculativeConfig object based on the |
| 638 | + provided parameters. If the necessary conditions are met, it returns an |
| 639 | + instance of SpeculativeConfig. Otherwise, it returns None. |
| 640 | +
|
| 641 | + Args: |
| 642 | + target_model_config (ModelConfig): The configuration of the target |
| 643 | + model. |
| 644 | + target_parallel_config (ParallelConfig): The parallel configuration |
| 645 | + for the target model. |
| 646 | + target_dtype (str): The data type used for the target model. |
| 647 | + speculative_model (Optional[str]): The name of the speculative |
| 648 | + model, if provided. |
| 649 | + num_speculative_tokens (Optional[int]): The number of speculative |
| 650 | + tokens, if provided. |
| 651 | +
|
| 652 | + Returns: |
| 653 | + Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if |
| 654 | + the necessary conditions are met, else None. |
| 655 | + """ |
| 656 | + |
| 657 | + if (speculative_model is None and num_speculative_tokens is None): |
| 658 | + return None |
| 659 | + |
| 660 | + if speculative_model is not None and num_speculative_tokens is None: |
| 661 | + raise ValueError( |
| 662 | + "Expected both speculative_model and " |
| 663 | + "num_speculative_tokens to be provided, but found " |
| 664 | + f"{speculative_model=} and {num_speculative_tokens=}.") |
| 665 | + |
| 666 | + # TODO: The user should be able to specify revision/quantization/max |
| 667 | + # model len for the draft model. It is not currently supported. |
| 668 | + draft_revision = None |
| 669 | + draft_code_revision = None |
| 670 | + draft_quantization = None |
| 671 | + draft_max_model_len = None |
| 672 | + |
| 673 | + draft_model_config = ModelConfig( |
| 674 | + model=speculative_model, |
| 675 | + tokenizer=target_model_config.tokenizer, |
| 676 | + tokenizer_mode=target_model_config.tokenizer_mode, |
| 677 | + trust_remote_code=target_model_config.trust_remote_code, |
| 678 | + download_dir=target_model_config.download_dir, |
| 679 | + load_format=target_model_config.load_format, |
| 680 | + dtype=target_model_config.dtype, |
| 681 | + seed=target_model_config.seed, |
| 682 | + revision=draft_revision, |
| 683 | + code_revision=draft_code_revision, |
| 684 | + tokenizer_revision=target_model_config.tokenizer_revision, |
| 685 | + max_model_len=draft_max_model_len, |
| 686 | + quantization=draft_quantization, |
| 687 | + enforce_eager=target_model_config.enforce_eager, |
| 688 | + max_context_len_to_capture=target_model_config. |
| 689 | + max_context_len_to_capture, |
| 690 | + max_logprobs=target_model_config.max_logprobs, |
| 691 | + ) |
| 692 | + |
| 693 | + draft_parallel_config = ( |
| 694 | + SpeculativeConfig.create_draft_parallel_config( |
| 695 | + target_parallel_config)) |
| 696 | + |
| 697 | + return SpeculativeConfig( |
| 698 | + draft_model_config, |
| 699 | + draft_parallel_config, |
| 700 | + num_speculative_tokens, |
| 701 | + ) |
| 702 | + |
| 703 | + @staticmethod |
| 704 | + def create_draft_parallel_config( |
| 705 | + target_parallel_config: ParallelConfig) -> ParallelConfig: |
| 706 | + """Create a parallel config for use by the draft worker. |
| 707 | +
|
| 708 | + This is mostly a copy of the target parallel config. In the future the |
| 709 | + draft worker can have a different parallel strategy, e.g. TP=1. |
| 710 | + """ |
| 711 | + draft_parallel_config = ParallelConfig( |
| 712 | + pipeline_parallel_size=target_parallel_config. |
| 713 | + pipeline_parallel_size, |
| 714 | + tensor_parallel_size=target_parallel_config.tensor_parallel_size, |
| 715 | + worker_use_ray=target_parallel_config.worker_use_ray, |
| 716 | + max_parallel_loading_workers=target_parallel_config. |
| 717 | + max_parallel_loading_workers, |
| 718 | + disable_custom_all_reduce=target_parallel_config. |
| 719 | + disable_custom_all_reduce, |
| 720 | + tokenizer_pool_config=target_parallel_config.tokenizer_pool_config, |
| 721 | + ray_workers_use_nsight=target_parallel_config. |
| 722 | + ray_workers_use_nsight, |
| 723 | + placement_group=target_parallel_config.placement_group, |
| 724 | + ) |
| 725 | + |
| 726 | + return draft_parallel_config |
| 727 | + |
| 728 | + def __init__( |
| 729 | + self, |
| 730 | + draft_model_config: ModelConfig, |
| 731 | + draft_parallel_config: ParallelConfig, |
| 732 | + num_speculative_tokens: int, |
| 733 | + ): |
| 734 | + """Create a SpeculativeConfig object. |
| 735 | +
|
| 736 | + Args: |
| 737 | + draft_model_config: ModelConfig for the draft model. |
| 738 | + draft_parallel_config: ParallelConfig for the draft model. |
| 739 | + num_speculative_tokens: The number of tokens to sample from the |
| 740 | + draft model before scoring with the target model. |
| 741 | + """ |
| 742 | + self.draft_model_config = draft_model_config |
| 743 | + self.draft_parallel_config = draft_parallel_config |
| 744 | + self.num_speculative_tokens = num_speculative_tokens |
| 745 | + |
| 746 | + self._verify_args() |
| 747 | + |
| 748 | + def _verify_args(self) -> None: |
| 749 | + if self.num_speculative_tokens <= 0: |
| 750 | + raise ValueError("Expected num_speculative_tokens to be greater " |
| 751 | + f"than zero ({self.num_speculative_tokens}).") |
| 752 | + |
| 753 | + if self.draft_model_config: |
| 754 | + self.draft_model_config.verify_with_parallel_config( |
| 755 | + self.draft_parallel_config) |
| 756 | + |
| 757 | + @property |
| 758 | + def num_lookahead_slots(self) -> int: |
| 759 | + """The number of additional slots the scheduler should allocate per |
| 760 | + step, in addition to the slots allocated for each known token. |
| 761 | +
|
| 762 | + This is equal to the number of speculative tokens, as each speculative |
| 763 | + token must be scored. |
| 764 | + """ |
| 765 | + return self.num_speculative_tokens |
| 766 | + |
| 767 | + def __repr__(self) -> str: |
| 768 | + draft_model = self.draft_model_config.model |
| 769 | + num_spec_tokens = self.num_speculative_tokens |
| 770 | + return f"SpeculativeConfig({draft_model=}, {num_spec_tokens=})" |
| 771 | + |
| 772 | + |
620 | 773 | @dataclass
|
621 | 774 | class LoRAConfig:
|
622 | 775 | max_lora_rank: int
|
@@ -838,3 +991,36 @@ def _get_and_verify_max_len(
|
838 | 991 | "to incorrect model outputs or CUDA errors. Make sure the "
|
839 | 992 | "value is correct and within the model context size.")
|
840 | 993 | return int(max_model_len)
|
| 994 | + |
| 995 | + |
| 996 | +@dataclass(frozen=True) |
| 997 | +class EngineConfig: |
| 998 | + """Dataclass which contains all engine-related configuration. This |
| 999 | + simplifies passing around the distinct configurations in the codebase. |
| 1000 | + """ |
| 1001 | + |
| 1002 | + model_config: ModelConfig |
| 1003 | + cache_config: CacheConfig |
| 1004 | + parallel_config: ParallelConfig |
| 1005 | + scheduler_config: SchedulerConfig |
| 1006 | + device_config: DeviceConfig |
| 1007 | + lora_config: Optional[LoRAConfig] |
| 1008 | + vision_language_config: Optional[VisionLanguageConfig] |
| 1009 | + speculative_config: Optional[SpeculativeConfig] |
| 1010 | + |
| 1011 | + def __post_init__(self): |
| 1012 | + """Verify configs are valid & consistent with each other. |
| 1013 | + """ |
| 1014 | + self.model_config.verify_with_parallel_config(self.parallel_config) |
| 1015 | + self.cache_config.verify_with_parallel_config(self.parallel_config) |
| 1016 | + |
| 1017 | + if self.lora_config: |
| 1018 | + self.lora_config.verify_with_model_config(self.model_config) |
| 1019 | + self.lora_config.verify_with_scheduler_config( |
| 1020 | + self.scheduler_config) |
| 1021 | + |
| 1022 | + def to_dict(self): |
| 1023 | + """Return the configs as a dictionary, for use in **kwargs. |
| 1024 | + """ |
| 1025 | + return dict( |
| 1026 | + (field.name, getattr(self, field.name)) for field in fields(self)) |
0 commit comments