@@ -4,173 +4,167 @@
 
 import torch
 
-import_exc = None
-
-try:
-    import vllm._punica_C as punica_kernels
-except ImportError as e:
-    import_exc = e
-
-if import_exc is None:
-
-    def bgmv(
-        y: torch.Tensor,
-        x: torch.Tensor,
-        w_t_all: torch.Tensor,
-        indicies: torch.LongTensor,
-        layer_idx: int,
-        scale: float,
-    ):
-        """
-        Semantics:
-          y[i] += (
-              x[i].unsqueeze(0)
-              @ w_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
-              * scale
-            ).squeeze(0)
-
-        Args:
-          y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
-          x: Shape: `[B, H1]`. Input vectors.
-          w_t_all: Shape: `[None, L, H2, H1]`. All of the transposed weight
-            matrices.
-          indicies: Shape: `[B]`. Indices of the weight matrices.
-          layer_idx: Layer index of the weight matrices.
-          scale: Scaling factor.
-        """
-        punica_kernels.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx, scale)
-
-    def add_lora(y: torch.Tensor,
-                 x: torch.Tensor,
-                 wa_t_all: torch.Tensor,
-                 wb_t_all: torch.Tensor,
-                 indicies: torch.LongTensor,
-                 layer_idx: int,
-                 scale: float,
-                 *,
-                 buffer: Optional[torch.Tensor] = None):
-        """
-        Semantics:
-          y[i] += (
-              x[i].unsqueeze(0)
-              @ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
-              @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
-              * scale
-            ).squeeze(0)
-
-        Args:
-          y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
-          x: Shape: `[B, H1]`. Input vectors.
-          wa_t_all: Shape: `[None, L, R, H1]`. All of the transposed
-            LoRA A matrices.
-          wb_t_all: Shape: `[None, L, H2, R]`. All of the transposed
-            LoRA B matrices.
-          indicies: Shape: `[B]`. Indices of the LoRA weights.
-          layer_idx: Layer index of LoRA weights.
-          scale: Scaling factor.
-          buffer: Optional. Shape: `[B, R]`. Temporary buffer.
-        """
-        r = wb_t_all.size(-1)
-        if buffer is None:
-            # We set the buffer to be float32 by default to avoid
-            # numerical innacuracies that would otherwise happen
-            # due to downcasting.
-            buffer = torch.zeros((x.size(0), r),
-                                 dtype=torch.float32,
-                                 device=x.device)
-        punica_kernels.dispatch_bgmv(buffer, x, wa_t_all, indicies, layer_idx,
-                                     1.0)
-        punica_kernels.dispatch_bgmv(y, buffer, wb_t_all, indicies, layer_idx,
-                                     scale)
-
-    def add_lora_slice(y: torch.Tensor,
-                       x: torch.Tensor,
-                       wa_t_all: torch.Tensor,
-                       wb_t_all: torch.Tensor,
-                       indicies: torch.LongTensor,
-                       layer_idx: int,
-                       scale: float,
-                       y_offset: int,
-                       y_slice_size: int,
-                       *,
-                       buffer: Optional[torch.Tensor] = None):
-        """
-        Same as `add_lora` but you can operate on slices of y.
-        Pass whole y, define y_offset and y_slice_size.
-
-        Semantics:
-          y[i] += (
-              x[i].unsqueeze(0)
-              @ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
-              @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
-              * scale
-            ).squeeze(0)
-
-        Args:
-          y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
-          x: Shape: `[B, H1]`. Input vectors.
-          wa_t_all: Shape: `[None, L, R, H1]`. All of the transposed
-            LoRA A matrices.
-          wb_t_all: Shape: `[None, L, H2, R]`. All of the transposed
-            LoRA B matrices.
-          indicies: Shape: `[B]`. Indices of the LoRA weights.
-          layer_idx: Layer index of LoRA weights.
-          scale: Scaling factor.
-          y_offset: Offset to apply to the starting column of y.
-          y_slice_size: Size of the y column slice.
-        """
-        r = wb_t_all.size(-1)
-        if buffer is None:
-            # We set the buffer to be float32 by default to avoid
-            # numerical inaccuracies that would otherwise happen
-            # due to downcasting.
-            buffer = torch.zeros((x.size(0), r),
-                                 dtype=torch.float32,
-                                 device=x.device)
-        punica_kernels.dispatch_bgmv_low_level(
-            buffer,
-            x,
-            wa_t_all,
-            indicies,
-            layer_idx,
-            1.0,
-            x.size(1),
-            buffer.size(1),
-            0,
-        )
-        punica_kernels.dispatch_bgmv_low_level(
-            y,
-            buffer,
-            wb_t_all,
-            indicies,
-            layer_idx,
-            scale,
-            buffer.size(1),
-            y_slice_size,
-            y_offset,
-        )
-
-else:
-
-    def _raise_exc(
-            *args,  # pylint: disable=unused-argument
-            **kwargs  # pylint: disable=unused-argument
-    ):
-        if torch.cuda.get_device_capability() < (8, 0):
-            raise ImportError("punica LoRA kernels require compute "
-                              "capability>=8.0") from import_exc
-        else:
-            raise ImportError(
-                "punica LoRA kernels could not be imported. If you built vLLM "
-                "from source, make sure VLLM_INSTALL_PUNICA_KERNELS=1 env var "
-                "was set.") from import_exc
-
-    bgmv = _raise_exc
-    add_lora = _raise_exc
-    add_lora_slice = _raise_exc
-
-__all__ = [
-    "bgmv",
-    "add_lora",
-    "add_lora_slice",
-]
+
+def _raise_import_error(e):
+    if torch.cuda.get_device_capability() < (8, 0):
+        raise ImportError(
+            "punica LoRA kernels require compute capability >= 8.0") from e
+    else:
+        raise ImportError(
+            "punica LoRA kernels could not be imported. If you built vLLM "
+            "from source, make sure VLLM_INSTALL_PUNICA_KERNELS=1 env var "
+            "was set.") from e
+
+
+def bgmv(
+    y: torch.Tensor,
+    x: torch.Tensor,
+    w_t_all: torch.Tensor,
+    indicies: torch.LongTensor,
+    layer_idx: int,
+    scale: float,
+):
+    """
+    Semantics:
+      y[i] += (
+          x[i].unsqueeze(0)
+          @ w_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
+          * scale
+        ).squeeze(0)
+
+    Args:
+      y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
+      x: Shape: `[B, H1]`. Input vectors.
+      w_t_all: Shape: `[None, L, H2, H1]`. All of the transposed weight
+        matrices.
+      indicies: Shape: `[B]`. Indices of the weight matrices.
+      layer_idx: Layer index of the weight matrices.
+      scale: Scaling factor.
+    """
+    try:
+        import vllm._punica_C as punica_kernels
+    except ImportError as e:
+        _raise_import_error(e)
+
+    punica_kernels.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx, scale)
+
+
+def add_lora(y: torch.Tensor,
+             x: torch.Tensor,
+             wa_t_all: torch.Tensor,
+             wb_t_all: torch.Tensor,
+             indicies: torch.LongTensor,
+             layer_idx: int,
+             scale: float,
+             *,
+             buffer: Optional[torch.Tensor] = None):
+    """
+    Semantics:
+      y[i] += (
+          x[i].unsqueeze(0)
+          @ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
+          @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
+          * scale
+        ).squeeze(0)
+
+    Args:
+      y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
+      x: Shape: `[B, H1]`. Input vectors.
+      wa_t_all: Shape: `[None, L, R, H1]`. All of the transposed
+        LoRA A matrices.
+      wb_t_all: Shape: `[None, L, H2, R]`. All of the transposed
+        LoRA B matrices.
+      indicies: Shape: `[B]`. Indices of the LoRA weights.
+      layer_idx: Layer index of LoRA weights.
+      scale: Scaling factor.
+      buffer: Optional. Shape: `[B, R]`. Temporary buffer.
+    """
+    try:
+        import vllm._punica_C as punica_kernels
+    except ImportError as e:
+        _raise_import_error(e)
+
+    r = wb_t_all.size(-1)
+    if buffer is None:
+        # We set the buffer to be float32 by default to avoid
+        # numerical inaccuracies that would otherwise happen
+        # due to downcasting.
+        buffer = torch.zeros((x.size(0), r),
+                             dtype=torch.float32,
+                             device=x.device)
+    punica_kernels.dispatch_bgmv(buffer, x, wa_t_all, indicies, layer_idx, 1.0)
+    punica_kernels.dispatch_bgmv(y, buffer, wb_t_all, indicies, layer_idx,
+                                 scale)
+
+
+def add_lora_slice(y: torch.Tensor,
+                   x: torch.Tensor,
+                   wa_t_all: torch.Tensor,
+                   wb_t_all: torch.Tensor,
+                   indicies: torch.LongTensor,
+                   layer_idx: int,
+                   scale: float,
+                   y_offset: int,
+                   y_slice_size: int,
+                   *,
+                   buffer: Optional[torch.Tensor] = None):
+    """
+    Same as `add_lora` but you can operate on slices of y.
+    Pass whole y, define y_offset and y_slice_size.
+
+    Semantics:
+      y[i] += (
+          x[i].unsqueeze(0)
+          @ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
+          @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
+          * scale
+        ).squeeze(0)
+
+    Args:
+      y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
+      x: Shape: `[B, H1]`. Input vectors.
+      wa_t_all: Shape: `[None, L, R, H1]`. All of the transposed
+        LoRA A matrices.
+      wb_t_all: Shape: `[None, L, H2, R]`. All of the transposed
+        LoRA B matrices.
+      indicies: Shape: `[B]`. Indices of the LoRA weights.
+      layer_idx: Layer index of LoRA weights.
+      scale: Scaling factor.
+      y_offset: Offset to apply to the starting column of y.
+      y_slice_size: Size of the y column slice.
+    """
+    try:
+        import vllm._punica_C as punica_kernels
+    except ImportError as e:
+        _raise_import_error(e)
+
+    r = wb_t_all.size(-1)
+    if buffer is None:
+        # We set the buffer to be float32 by default to avoid
+        # numerical inaccuracies that would otherwise happen
+        # due to downcasting.
+        buffer = torch.zeros((x.size(0), r),
+                             dtype=torch.float32,
+                             device=x.device)
+    punica_kernels.dispatch_bgmv_low_level(
+        buffer,
+        x,
+        wa_t_all,
+        indicies,
+        layer_idx,
+        1.0,
+        x.size(1),
+        buffer.size(1),
+        0,
+    )
+    punica_kernels.dispatch_bgmv_low_level(
+        y,
+        buffer,
+        wb_t_all,
+        indicies,
+        layer_idx,
+        scale,
+        buffer.size(1),
+        y_slice_size,
+        y_offset,
+    )
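
For reference, the `Semantics:` block in the `bgmv` docstring corresponds to the eager PyTorch loop below. This is a minimal sketch for sanity-checking the kernel, not code from this commit; `bgmv_ref` is a hypothetical name, and it assumes `y`, `x`, and `w_t_all` share a floating-point dtype.

```python
import torch


def bgmv_ref(y: torch.Tensor, x: torch.Tensor, w_t_all: torch.Tensor,
             indicies: torch.LongTensor, layer_idx: int,
             scale: float) -> None:
    # Replay the documented semantics one batch element at a time:
    # y[i] += (x[i] @ w_t_all[indicies[i], layer_idx].T) * scale
    for i in range(x.size(0)):
        # [H2, H1] -> [H1, H2], so [1, H1] @ [H1, H2] yields [1, H2].
        w = w_t_all[indicies[i], layer_idx, :, :].transpose(-1, -2)
        y[i] += (x[i].unsqueeze(0) @ w * scale).squeeze(0)
```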
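`add_lora` is the same pattern applied twice: a shrink through the rank-R LoRA A matrix into a small buffer, then an expand through LoRA B. Below is a hedged eager equivalent that keeps the intermediate in float32, mirroring the default buffer dtype and the comment about downcasting; `add_lora_ref` is an assumed name, not part of the diff.

```python
import torch


def add_lora_ref(y: torch.Tensor, x: torch.Tensor, wa_t_all: torch.Tensor,
                 wb_t_all: torch.Tensor, indicies: torch.LongTensor,
                 layer_idx: int, scale: float) -> None:
    for i in range(x.size(0)):
        wa = wa_t_all[indicies[i], layer_idx, :, :].transpose(-1, -2)  # [H1, R]
        wb = wb_t_all[indicies[i], layer_idx, :, :].transpose(-1, -2)  # [R, H2]
        # Shrink in float32, matching the float32 default buffer above.
        tmp = x[i].unsqueeze(0).float() @ wa.float()  # [1, R]
        # Expand, scale, and accumulate back into y's dtype.
        y[i] += (tmp @ wb.float() * scale).squeeze(0).to(y.dtype)
```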
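`add_lora_slice` exists so the same update can land in a column window of a fused output (for example packed projections) without allocating a separate tensor per slice. A usage sketch under stated assumptions: the module is importable as `vllm.lora.punica`, a CUDA device with the compiled punica kernels is available, and the shapes are purely illustrative (whether a given hidden size is supported depends on how the kernels were built).

```python
import torch

from vllm.lora.punica import add_lora_slice  # assumed module path

B, H1, R, W = 4, 4096, 16, 4096  # batch, input dim, LoRA rank, slice width
y = torch.zeros(B, 3 * W, dtype=torch.float16, device="cuda")  # fused output
x = torch.randn(B, H1, dtype=torch.float16, device="cuda")
wa_t_all = torch.randn(1, 1, R, H1, dtype=torch.float16, device="cuda")
wb_t_all = torch.randn(1, 1, W, R, dtype=torch.float16, device="cuda")
indicies = torch.zeros(B, dtype=torch.long, device="cuda")  # all rows use LoRA 0

# Accumulate the LoRA delta only into columns [W, 2*W) of y; the first and
# last thirds of the fused output are left untouched.
add_lora_slice(y, x, wa_t_all, wb_t_all, indicies,
               layer_idx=0, scale=1.0, y_offset=W, y_slice_size=W)
```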