 import torch
 import torch.nn.functional as F
 from torch.nn.parameter import Parameter
-from vllm.liquid.sharded_parameter import ShardedParameter, QKVShardedParameter
+from vllm.liquid.sharded_parameter import ShardedParameter, QKVShardedParameter, GateUpShardedParameter

 from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
@@ -94,15 +94,28 @@ def create_weights(self, layer: torch.nn.Module,
                        shard_dim: int = -1,
                        param_class = ShardedParameter,
                        **extra_weight_attrs):
-        weight = param_class(
-            data=torch.empty(sum(output_partition_sizes),
-                             input_size_per_partition,
-                             dtype=params_dtype),
-            num_shards=len(shard_ids),
-            shard_dim=shard_dim,
-            shard_ids=shard_ids,
-            requires_grad=False,
-        )
+        if param_class == QKVShardedParameter:
+            weight = QKVShardedParameter(
+                data=torch.empty(sum(output_partition_sizes),
+                                 input_size_per_partition,
+                                 dtype=params_dtype),
+                num_shards=len(shard_ids),
+                shard_dim=shard_dim,
+                shard_ids=shard_ids,
+                requires_grad=False,
+                num_heads_ratio=extra_weight_attrs['num_heads_ratio'],
+                num_kv_heads_ratio=extra_weight_attrs['num_kv_heads_ratio'],
+            )
+        else:
+            weight = param_class(
+                data=torch.empty(sum(output_partition_sizes),
+                                 input_size_per_partition,
+                                 dtype=params_dtype),
+                num_shards=len(shard_ids),
+                shard_dim=shard_dim,
+                shard_ids=shard_ids,
+                requires_grad=False,
+            )
         set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
         layer.register_parameter("weight", weight)
         set_weight_attrs(weight, extra_weight_attrs)
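Note (not part of the diff): with the branch above, create_weights reads num_heads_ratio and num_kv_heads_ratio out of **extra_weight_attrs whenever param_class is QKVShardedParameter, so callers must supply both keys. In this patch they arrive through the new keyword arguments added to ColumnParallelLinear.__init__ in the hunks below (defaulting to 1), and QKVParallelLinear overrides them with its head counts. A minimal sketch of the resulting call, with argument names taken from this hunk:

    quant_method.create_weights(layer,
                                input_size_per_partition,
                                output_partition_sizes,
                                input_size, output_size, params_dtype,
                                shard_ids=shard_ids,
                                shard_dim=shard_dim,
                                param_class=QKVShardedParameter,
                                num_heads_ratio=num_heads,
                                num_kv_heads_ratio=num_kv_heads)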
@@ -276,6 +289,8 @@ def __init__(self,
                  shard_ids: List[int] = [0],
                  total_num_shards: int = 1,
                  param_class = ShardedParameter,
+                 num_heads_ratio: int = 1,
+                 num_kv_heads_ratio: int = 1,
                  ):
         super().__init__(input_size, output_size, skip_bias_add, params_dtype,
                          quant_config)
@@ -310,6 +325,8 @@ def __init__(self,
             shard_ids=shard_ids,
             shard_dim=shard_dim,
             param_class=param_class,
+            num_heads_ratio=num_heads_ratio,
+            num_kv_heads_ratio=num_kv_heads_ratio,
         )
         if bias:
             self.bias = param_class(
@@ -446,6 +463,8 @@ def __init__(self,
                          shard_ids=shard_ids,
                          total_num_shards=total_num_shards,
                          param_class=QKVShardedParameter,
+                         num_heads_ratio=self.num_heads,
+                         num_kv_heads_ratio=self.num_kv_heads,
                          )

     def weight_loader(self,
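Note (not part of the diff): QKVParallelLinear forwards its per-partition head counts as the two ratio arguments, which is what allows a QKVShardedParameter to split the packed output dimension back into Q, K and V. An illustrative calculation (the head counts and head size are assumptions, not values from this patch):

    num_heads, num_kv_heads, head_size = 32, 8, 128
    q_size = num_heads * head_size           # 4096 output rows for Q
    kv_size = num_kv_heads * head_size       # 1024 output rows each for K and V
    packed_output = q_size + 2 * kv_size     # 6144 rows in the packed weight

Recovering the Q/K/V boundaries from the 6144 packed rows requires the head counts (or their ratio), not just the total size, which is why the parameter now receives them.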
@@ -737,3 +756,175 @@ def extra_repr(self) -> str:
         s += f", tp_size={self.tp_size}"
         s += f", reduce_results={self.reduce_results}"
         return s
+
+
+class MergedColumnParallelLinear(ColumnParallelLinear):
+    """Packed linear layers with column parallelism.
+
+    Similar to ColumnParallelLinear, but the weight matrix is concatenated
+    along the output dimension. When the weight matrix is loaded, the
+    different partitions are sharded separately.
+
+    Args:
+        input_size: input dimension of the linear layer.
+        output_sizes: list of output dimensions of the linear layer.
+        bias: If true, add bias.
+        gather_output: If true, call all-gather on output and make the output
+                       available to all GPUs, otherwise, every GPU will have
+                       its own output.
+        skip_bias_add: This was added to enable performance optimizations where
+                       bias can be fused with other element-wise operations. We
+                       skip adding bias but instead return it.
+        params_dtype: Data type for the parameters.
+        quant_config: Quantization configuration.
+    """
+
+    def __init__(self,
+                 input_size: int,
+                 output_sizes: List[int],
+                 bias: bool = True,
+                 gather_output: bool = False,
+                 skip_bias_add: bool = False,
+                 params_dtype: Optional[torch.dtype] = None,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 shard_ids: List[int] = [0],
+                 total_num_shards: int = 1):
+        self.output_sizes = output_sizes
+        # tp_size = get_tensor_model_parallel_world_size()
+        # assert all(output_size % tp_size == 0 for output_size in output_sizes)
+        super().__init__(input_size=input_size,
+                         output_size=sum(output_sizes),
+                         bias=bias,
+                         gather_output=gather_output,
+                         skip_bias_add=skip_bias_add,
+                         params_dtype=params_dtype,
+                         quant_config=quant_config,
+                         shard_ids=shard_ids,
+                         total_num_shards=total_num_shards,
+                         param_class=GateUpShardedParameter,
+                         )
+
+    def weight_loader(self,
+                      param: Parameter,
+                      loaded_weight: torch.Tensor,
+                      loaded_shard_id: Optional[int] = None):
+
+        param_data = param.data
+        output_dim = getattr(param, "output_dim", None)
+        # Special case for AQLM codebooks.
+        is_metadata = getattr(param, "is_metadata", False)
+
+        param_shard_splitter = getattr(param, "shard_splitter", None)
+
+        if output_dim is not None and param_shard_splitter is not None:
+            raise NotImplementedError(
+                "We do not currently support output_dim != None and "
+                "shard_splitter != None for a parameter. Please open an issue."
+            )
+        # If a parameter has defined a shard_splitter to be used for
+        # the weight, it should be applied before the weight is
+        # loaded/copied to the parameter. The shard_splitter applies
+        # logic by using the loaded_shard_id to ensure that the loaded
+        # param is loaded to the correct location
+        # within the parameter defined by the linear method.
+        if loaded_shard_id is None and param_shard_splitter is not None:
+            raise NotImplementedError(
+                "We do not currently support loaded_shard_id == None and "
+                "shard_splitter != None for a parameter. Please open an issue."
+            )
+
+        # Special case for Fp8 scales.
+        fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
+                                           None)
+
+        if loaded_shard_id is None:
+            # Loaded weight is already packed.
+            if output_dim is None:
+                assert param_data.shape == loaded_weight.shape
+                param_data.copy_(loaded_weight)
+                return
+            current_shard_offset = 0
+            shard_offsets = []
+            for i, output_size in enumerate(self.output_sizes):
+                shard_offsets.append((i, current_shard_offset, output_size))
+                current_shard_offset += output_size
+            packed_dim = getattr(param, "packed_dim", None)
+            for shard_id, shard_offset, shard_size in shard_offsets:
+                # Special case for Quantization.
+                # If quantized, we need to adjust the offset and size to account
+                # for the packing.
+                if packed_dim == output_dim:
+                    shard_size = shard_size // param.pack_factor
+                    shard_offset = shard_offset // param.pack_factor
+                    # Special case for Marlin.
+                    shard_size, shard_offset = adjust_marlin_shard(
+                        param, shard_size, shard_offset)
+
+                loaded_weight_shard = loaded_weight.narrow(
+                    output_dim, shard_offset, shard_size)
+                self.weight_loader(param, loaded_weight_shard, shard_id)
+            return
+
+        assert loaded_shard_id < len(self.output_sizes)
+        tp_rank = get_tensor_model_parallel_rank()
+        tp_size = get_tensor_model_parallel_world_size()
+        if output_dim is not None:
+            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
+            shard_size = self.output_sizes[loaded_shard_id] // tp_size
+            # Special case for quantization.
+            # If quantized, we need to adjust the offset and size to account
+            # for the packing.
+            packed_dim = getattr(param, "packed_dim", None)
+            if packed_dim == output_dim:
+                shard_size = shard_size // param.pack_factor
+                shard_offset = shard_offset // param.pack_factor
+                # Special case for Marlin.
+                shard_size, shard_offset = adjust_marlin_shard(
+                    param, shard_size, shard_offset)
+
+            use_bitsandbytes = getattr(param, "use_bitsandbytes", False)
+            if use_bitsandbytes:
+                shard_size = loaded_weight.shape[output_dim]
+                shard_offset = loaded_weight.shape[output_dim] * \
+                    loaded_shard_id
+
+            param_data = param_data.narrow(output_dim, shard_offset,
+                                           shard_size)
+            start_idx = tp_rank * shard_size
+            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
+                                                 shard_size)
+        # Special case for AQLM codebooks.
+        elif is_metadata:
+            # metadata indicates fixed size concatenated along dim 0
+            shard_size = loaded_weight.shape[0]
+            shard_offset = loaded_shard_id * shard_size
+            param_data = param_data.narrow(0, shard_offset, shard_size)
+
+        # If a param_shard_splitter is defined by the LinearMethod, use it.
+        elif param_shard_splitter is not None:
+            logical_widths = getattr(param, "logical_widths", None)
+            param_data, loaded_weight = param_shard_splitter(
+                param_data, loaded_weight, loaded_shard_id, logical_widths)
+
+        # Special case for Fp8 scales.
+        elif fp8_scales_shard_indexer is not None:
+            param_data, loaded_weight = fp8_scales_shard_indexer(
+                param_data, loaded_weight, loaded_shard_id)
+
+        else:
+            ignore_warning = getattr(param, "ignore_warning", False)
+            if not ignore_warning:
+                logger.warning(
+                    "Loading a weight without `output_dim` attribute in "
+                    "MergedColumnParallelLinear, assume the weight is "
+                    "the same for all partitions.")
+
+        if fp8_scales_shard_indexer is None:
+            if len(param_data.shape) == 0:
+                param_data = param_data.reshape(1)
+
+            if len(loaded_weight.shape) == 0:
+                loaded_weight = loaded_weight.reshape(1)
+
+        assert param_data.shape == loaded_weight.shape
+        param_data.copy_(loaded_weight)
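Note (not part of the diff): a minimal usage sketch of the new MergedColumnParallelLinear, with illustrative sizes (4096 and 11008 are assumptions, not values from this patch). The GateUpShardedParameter param_class suggests the intended use is the packed gate/up projection of a SwiGLU MLP:

    gate_up_proj = MergedColumnParallelLinear(
        input_size=4096,              # model hidden size
        output_sizes=[11008, 11008],  # gate_proj and up_proj, concatenated
        bias=False,
        shard_ids=[0],
        total_num_shards=1,
    )
    # During checkpoint loading, each sub-matrix is passed with its shard id:
    #   gate_up_proj.weight_loader(gate_up_proj.weight, gate_weight, loaded_shard_id=0)
    #   gate_up_proj.weight_loader(gate_up_proj.weight, up_weight, loaded_shard_id=1)

Passing loaded_shard_id=None instead takes the already-packed path in weight_loader, which slices the packed tensor according to self.output_sizes and recurses once per sub-matrix.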