
Commit f99784c

WoosukKwon authored and jimpang committed
[BugFix] Fix GC bug for LLM class (vllm-project#2882)
1 parent 211242f commit f99784c

2 files changed: +182, -170 lines changed

tests/test_regression.py

Lines changed: 18 additions & 0 deletions

@@ -4,6 +4,10 @@
 will never happen again.
 
 """
+import gc
+
+import torch
+
 from vllm import LLM, SamplingParams
 
 
@@ -35,6 +39,20 @@ def test_max_tokens_none():
     assert len(prompts) == len(outputs)
 
 
+def test_gc():
+    llm = LLM("facebook/opt-125m", enforce_eager=True)
+    del llm
+
+    gc.collect()
+    torch.cuda.empty_cache()
+
+    # The memory allocated for model and KV cache should be released.
+    # The memory allocated for PyTorch and others should be less than 50MB.
+    # Usually, it's around 10MB.
+    allocated = torch.cuda.memory_allocated()
+    assert allocated < 50 * 1024 * 1024
+
+
 if __name__ == "__main__":
     import pytest
     pytest.main([__file__])
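
The new test exercises the fix directly: once the last reference to the LLM object is dropped, a garbage-collection pass plus a CUDA cache flush should leave the GPU nearly empty. A minimal standalone sketch of the same check, assuming a CUDA-capable GPU and the same facebook/opt-125m model as the test above:

import gc

import torch

from vllm import LLM

llm = LLM("facebook/opt-125m", enforce_eager=True)
del llm  # Drop the last reference so the engine becomes collectable.

gc.collect()              # Break any reference cycles keeping tensors alive.
torch.cuda.empty_cache()  # Return cached, now-unused blocks to the driver.

# With the fix, model weights and KV cache are released; what remains is
# PyTorch bookkeeping, typically around 10MB.
print(f"Still allocated: {torch.cuda.memory_allocated() / 1024**2:.1f} MiB")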

vllm/lora/punica.py

Lines changed: 164 additions & 170 deletions

@@ -4,173 +4,167 @@
 
 import torch
 
-import_exc = None
-
-try:
-    import vllm._punica_C as punica_kernels
-except ImportError as e:
-    import_exc = e
-
-if import_exc is None:
-
-    def bgmv(
-        y: torch.Tensor,
-        x: torch.Tensor,
-        w_t_all: torch.Tensor,
-        indicies: torch.LongTensor,
-        layer_idx: int,
-        scale: float,
-    ):
-        """
-        Semantics:
-          y[i] += (
-              x[i].unsqueeze(0)
-              @ w_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
-              * scale
-            ).squeeze(0)
-
-        Args:
-          y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
-          x: Shape: `[B, H1]`. Input vectors.
-          w_t_all: Shape: `[None, L, H2, H1]`. All of the transposed weight
-            matrices.
-          indicies: Shape: `[B]`. Indices of the weight matrices.
-          layer_idx: Layer index of the weight matrices.
-          scale: Scaling factor.
-        """
-        punica_kernels.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx, scale)
-
-    def add_lora(y: torch.Tensor,
-                 x: torch.Tensor,
-                 wa_t_all: torch.Tensor,
-                 wb_t_all: torch.Tensor,
-                 indicies: torch.LongTensor,
-                 layer_idx: int,
-                 scale: float,
-                 *,
-                 buffer: Optional[torch.Tensor] = None):
-        """
-        Semantics:
-          y[i] += (
-              x[i].unsqueeze(0)
-              @ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
-              @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
-              * scale
-            ).squeeze(0)
-
-        Args:
-          y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
-          x: Shape: `[B, H1]`. Input vectors.
-          wa_t_all: Shape: `[None, L, R, H1]`. All of the transposed
-            LoRA A matrices.
-          wb_t_all: Shape: `[None, L, H2, R]`. All of the transposed
-            LoRA B matrices.
-          indicies: Shape: `[B]`. Indices of the LoRA weights.
-          layer_idx: Layer index of LoRA weights.
-          scale: Scaling factor.
-          buffer: Optional. Shape: `[B, R]`. Temporary buffer.
-        """
-        r = wb_t_all.size(-1)
-        if buffer is None:
-            # We set the buffer to be float32 by default to avoid
-            # numerical innacuracies that would otherwise happen
-            # due to downcasting.
-            buffer = torch.zeros((x.size(0), r),
-                                 dtype=torch.float32,
-                                 device=x.device)
-        punica_kernels.dispatch_bgmv(buffer, x, wa_t_all, indicies, layer_idx,
-                                     1.0)
-        punica_kernels.dispatch_bgmv(y, buffer, wb_t_all, indicies, layer_idx,
-                                     scale)
-
-    def add_lora_slice(y: torch.Tensor,
-                       x: torch.Tensor,
-                       wa_t_all: torch.Tensor,
-                       wb_t_all: torch.Tensor,
-                       indicies: torch.LongTensor,
-                       layer_idx: int,
-                       scale: float,
-                       y_offset: int,
-                       y_slice_size: int,
-                       *,
-                       buffer: Optional[torch.Tensor] = None):
-        """
-        Same as `add_lora` but you can operate on slices of y.
-        Pass whole y, define y_offset and y_slice_size.
-
-        Semantics:
-          y[i] += (
-              x[i].unsqueeze(0)
-              @ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
-              @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
-              * scale
-            ).squeeze(0)
-
-        Args:
-          y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
-          x: Shape: `[B, H1]`. Input vectors.
-          wa_t_all: Shape: `[None, L, R, H1]`. All of the transposed
-            LoRA A matrices.
-          wb_t_all: Shape: `[None, L, H2, R]`. All of the transposed
-            LoRA B matrices.
-          indicies: Shape: `[B]`. Indices of the LoRA weights.
-          layer_idx: Layer index of LoRA weights.
-          scale: Scaling factor.
-          y_offset: Offset to apply to the starting column of y.
-          y_slice_size: Size of the y column slice.
-        """
-        r = wb_t_all.size(-1)
-        if buffer is None:
-            # We set the buffer to be float32 by default to avoid
-            # numerical inaccuracies that would otherwise happen
-            # due to downcasting.
-            buffer = torch.zeros((x.size(0), r),
-                                 dtype=torch.float32,
-                                 device=x.device)
-        punica_kernels.dispatch_bgmv_low_level(
-            buffer,
-            x,
-            wa_t_all,
-            indicies,
-            layer_idx,
-            1.0,
-            x.size(1),
-            buffer.size(1),
-            0,
-        )
-        punica_kernels.dispatch_bgmv_low_level(
-            y,
-            buffer,
-            wb_t_all,
-            indicies,
-            layer_idx,
-            scale,
-            buffer.size(1),
-            y_slice_size,
-            y_offset,
-        )
-
-else:
-
-    def _raise_exc(
-        *args,  # pylint: disable=unused-argument
-        **kwargs  # pylint: disable=unused-argument
-    ):
-        if torch.cuda.get_device_capability() < (8, 0):
-            raise ImportError("punica LoRA kernels require compute "
-                              "capability>=8.0") from import_exc
-        else:
-            raise ImportError(
-                "punica LoRA kernels could not be imported. If you built vLLM "
-                "from source, make sure VLLM_INSTALL_PUNICA_KERNELS=1 env var "
-                "was set.") from import_exc
-
-    bgmv = _raise_exc
-    add_lora = _raise_exc
-    add_lora_slice = _raise_exc
-
-__all__ = [
-    "bgmv",
-    "add_lora",
-    "add_lora_slice",
-]
+
+def _raise_import_error(e):
+    if torch.cuda.get_device_capability() < (8, 0):
+        raise ImportError(
+            "punica LoRA kernels require compute capability >= 8.0") from e
+    else:
+        raise ImportError(
+            "punica LoRA kernels could not be imported. If you built vLLM "
+            "from source, make sure VLLM_INSTALL_PUNICA_KERNELS=1 env var "
+            "was set.") from e
+
+
+def bgmv(
+    y: torch.Tensor,
+    x: torch.Tensor,
+    w_t_all: torch.Tensor,
+    indicies: torch.LongTensor,
+    layer_idx: int,
+    scale: float,
+):
+    """
+    Semantics:
+      y[i] += (
+          x[i].unsqueeze(0)
+          @ w_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
+          * scale
+        ).squeeze(0)
+
+    Args:
+      y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
+      x: Shape: `[B, H1]`. Input vectors.
+      w_t_all: Shape: `[None, L, H2, H1]`. All of the transposed weight
+        matrices.
+      indicies: Shape: `[B]`. Indices of the weight matrices.
+      layer_idx: Layer index of the weight matrices.
+      scale: Scaling factor.
+    """
+    try:
+        import vllm._punica_C as punica_kernels
+    except ImportError as e:
+        _raise_import_error(e)
+
+    punica_kernels.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx, scale)
+
+
+def add_lora(y: torch.Tensor,
+             x: torch.Tensor,
+             wa_t_all: torch.Tensor,
+             wb_t_all: torch.Tensor,
+             indicies: torch.LongTensor,
+             layer_idx: int,
+             scale: float,
+             *,
+             buffer: Optional[torch.Tensor] = None):
+    """
+    Semantics:
+      y[i] += (
+          x[i].unsqueeze(0)
+          @ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
+          @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
+          * scale
+        ).squeeze(0)
+
+    Args:
+      y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
+      x: Shape: `[B, H1]`. Input vectors.
+      wa_t_all: Shape: `[None, L, R, H1]`. All of the transposed
+        LoRA A matrices.
+      wb_t_all: Shape: `[None, L, H2, R]`. All of the transposed
+        LoRA B matrices.
+      indicies: Shape: `[B]`. Indices of the LoRA weights.
+      layer_idx: Layer index of LoRA weights.
+      scale: Scaling factor.
+      buffer: Optional. Shape: `[B, R]`. Temporary buffer.
+    """
+    try:
+        import vllm._punica_C as punica_kernels
+    except ImportError as e:
+        _raise_import_error(e)
+
+    r = wb_t_all.size(-1)
+    if buffer is None:
+        # We set the buffer to be float32 by default to avoid
+        # numerical innacuracies that would otherwise happen
+        # due to downcasting.
+        buffer = torch.zeros((x.size(0), r),
+                             dtype=torch.float32,
+                             device=x.device)
+    punica_kernels.dispatch_bgmv(buffer, x, wa_t_all, indicies, layer_idx, 1.0)
+    punica_kernels.dispatch_bgmv(y, buffer, wb_t_all, indicies, layer_idx,
+                                 scale)
+
+
+def add_lora_slice(y: torch.Tensor,
+                   x: torch.Tensor,
+                   wa_t_all: torch.Tensor,
+                   wb_t_all: torch.Tensor,
+                   indicies: torch.LongTensor,
+                   layer_idx: int,
+                   scale: float,
+                   y_offset: int,
+                   y_slice_size: int,
+                   *,
+                   buffer: Optional[torch.Tensor] = None):
+    """
+    Same as `add_lora` but you can operate on slices of y.
+    Pass whole y, define y_offset and y_slice_size.
+
+    Semantics:
+      y[i] += (
+          x[i].unsqueeze(0)
+          @ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
+          @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
+          * scale
+        ).squeeze(0)
+
+    Args:
+      y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
+      x: Shape: `[B, H1]`. Input vectors.
+      wa_t_all: Shape: `[None, L, R, H1]`. All of the transposed
+        LoRA A matrices.
+      wb_t_all: Shape: `[None, L, H2, R]`. All of the transposed
+        LoRA B matrices.
+      indicies: Shape: `[B]`. Indices of the LoRA weights.
+      layer_idx: Layer index of LoRA weights.
+      scale: Scaling factor.
+      y_offset: Offset to apply to the starting column of y.
+      y_slice_size: Size of the y column slice.
+    """
+    try:
+        import vllm._punica_C as punica_kernels
+    except ImportError as e:
+        _raise_import_error(e)
+
+    r = wb_t_all.size(-1)
+    if buffer is None:
+        # We set the buffer to be float32 by default to avoid
+        # numerical inaccuracies that would otherwise happen
+        # due to downcasting.
+        buffer = torch.zeros((x.size(0), r),
+                             dtype=torch.float32,
+                             device=x.device)
+    punica_kernels.dispatch_bgmv_low_level(
+        buffer,
+        x,
+        wa_t_all,
+        indicies,
+        layer_idx,
+        1.0,
+        x.size(1),
+        buffer.size(1),
+        0,
+    )
+    punica_kernels.dispatch_bgmv_low_level(
+        y,
+        buffer,
+        wb_t_all,
+        indicies,
+        layer_idx,
+        scale,
+        buffer.size(1),
+        y_slice_size,
+        y_offset,
+    )
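
The punica refactor supports the same fix: the old module imported vllm._punica_C at import time and, on failure, stored the caught ImportError in the module-level global import_exc. A stored exception keeps its traceback, and with it the frames it references, alive for the life of the process. The new code imports the kernels inside each wrapper and raises through the shared _raise_import_error helper, so nothing is cached at module scope and the error still surfaces only when a kernel is actually called. A minimal sketch of this lazy-import pattern, where my_ext is a hypothetical compiled extension rather than a real vLLM module:

import torch


def _raise_import_error(e: ImportError) -> None:
    # Chain with `from e` so the original failure stays visible in the
    # traceback without being stored in a global.
    raise ImportError("optional kernels could not be imported") from e


def fused_op(x: torch.Tensor) -> torch.Tensor:
    # Importing at call time means importing *this* module never fails,
    # and no caught exception object lingers after the call returns.
    try:
        import my_ext  # hypothetical extension module
    except ImportError as e:
        _raise_import_error(e)
    return my_ext.run(x)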
