
Commit b269150

garbage collection for model

1 parent 8170fae

File tree

6 files changed: +169 −30

examples/star/inference.py

Lines changed: 2 additions & 26 deletions
@@ -1,30 +1,9 @@
-import gc
 from typing import List
 from datasets import Dataset
 from vllm import LLM, SamplingParams
-from utils import generate_prompt
+from utils import generate_prompt, cleanup
 
 
-def cleanup(model):
-    try:
-        import torch
-        import contextlib
-        if torch.cuda.is_available():
-            from vllm.distributed.parallel_state import (
-                destroy_model_parallel, destroy_distributed_environment
-            )
-            destroy_model_parallel()
-            destroy_distributed_environment()
-            del model.llm_engine.model_executor
-            del model
-            with contextlib.suppress(AssertionError):
-                torch.distributed.destroy_process_group()
-            gc.collect()
-            torch.cuda.empty_cache()
-            torch.cuda.synchronize()
-    except ImportError:
-        del model
-
 def generate_predictions(
     model_name: str, dataset: Dataset, temperature: float = 1.0, n: int = 1
 ) -> List[List[str]]:
@@ -62,8 +41,5 @@ def generate_predictions(
     for output in outputs:
         generated_texts = [one.text for one in output.outputs]
         results.append(generated_texts)
-    cleanup(llm)
+    cleanup(llm, vllm=True)
     return results
-    # out_name = dataset_name.split("/")[-1]
-    # out_name = f"wentingzhao/{out_name}_predictions_{n}"
-    # ds.push_to_hub(out_name)
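With cleanup moved into utils, generate_predictions now tears down its vLLM engine before returning. A minimal sketch of the intended effect; the dataset fields and model paths below are illustrative assumptions, since the real schema is whatever generate_prompt consumes:

    from datasets import Dataset
    from inference import generate_predictions

    # Illustrative dataset; the real fields are defined by generate_prompt.
    ds = Dataset.from_dict({"prompt": ["def add(a, b):"]})

    # Loads a vLLM engine, samples n completions per example, then calls
    # cleanup(llm, vllm=True) so the engine's GPU memory is released.
    first = generate_predictions("meta-llama/Llama-3.1-8B", ds, temperature=1.0, n=4)

    # Because the first engine was destroyed, a later call (e.g. the next STaR
    # iteration with a fine-tuned checkpoint) can load a fresh model without
    # inheriting the old engine's allocations and running out of GPU memory.
    second = generate_predictions("./checkpoints/iter1", ds, temperature=1.0, n=4)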

examples/star/star.py

Lines changed: 1 addition & 0 deletions
@@ -21,6 +21,7 @@ def main():
         ds[split] = ds[split].add_column(name="text", column=texts)
 
     model_name = args.model_name_or_path
+    ds["train"] = ds["train"].select(range(10))
     for i in range(args.iteration):
         # sample
         all_samples = generate_predictions(

examples/star/train.py

Lines changed: 5 additions & 0 deletions
@@ -49,6 +49,8 @@
     get_scheduler,
 )
 
+from utils import cleanup
+
 
 logger = get_logger(__name__)
 
@@ -404,6 +406,9 @@ def tokenize_function(examples):
         )
     with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
         json.dump({"perplexity": perplexity}, f)
+    cleanup(model)
+    #cleanup(optimizer)
+    #cleanup(lr_scheduler)
 
 
 if __name__ == "__main__":
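For a plain Hugging Face model, the non-vLLM path of cleanup moves the weights off the GPU and empties the CUDA cache. A minimal sketch, assuming a standard transformers checkpoint (the model name is illustrative):

    import torch
    from transformers import AutoModelForCausalLM
    from utils import cleanup

    model = AutoModelForCausalLM.from_pretrained("gpt2").cuda()
    # ... training loop runs here ...

    cleanup(model)  # model.cpu() moves the weights off the GPU in place, then
                    # gc.collect() and torch.cuda.empty_cache() free the memory
    print(torch.cuda.memory_allocated())  # should now be at or near zero

The commented-out cleanup(optimizer) and cleanup(lr_scheduler) calls would not work as written: the non-vLLM path calls model.cpu(), and torch optimizers and schedulers have no .cpu() method, so freeing their state would need a separate code path.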

examples/star/utils.py

Lines changed: 33 additions & 0 deletions
@@ -1,4 +1,5 @@
 import argparse
+import gc
 import subprocess
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from datasets import Dataset
@@ -275,3 +276,35 @@ def parse_args():
     )
 
     return args
+
+
+def cleanup(model, vllm=False):
+    """
+    Clean up resources associated with the given model.
+
+    Parameters
+    ----------
+    model : Any
+        The model object whose resources are to be cleaned up.
+    """
+    try:
+        import torch
+        import contextlib
+        if torch.cuda.is_available():
+            if vllm:
+                from vllm.distributed.parallel_state import (
+                    destroy_model_parallel, destroy_distributed_environment
+                )
+                destroy_model_parallel()
+                destroy_distributed_environment()
+                del model.llm_engine.model_executor
+            if not vllm:
+                model = model.cpu()
+            del model
+            with contextlib.suppress(AssertionError):
+                torch.distributed.destroy_process_group()
+            gc.collect()
+            torch.cuda.empty_cache()
+            torch.cuda.synchronize()
+    except ImportError:
+        del model
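One subtlety at the call sites: del model inside cleanup only removes the function's local name, so the caller's reference keeps the object alive. GPU memory is still released because nn.Module.cpu() moves parameters in place (non-vLLM path) and deleting llm_engine.model_executor drops the object that actually holds the weights (vLLM path). A minimal sketch of both call patterns, with illustrative model names:

    import torch
    from vllm import LLM
    from utils import cleanup

    # vLLM path: destroys distributed state and deletes the model executor,
    # which is the object holding the weights on the GPU.
    llm = LLM(model="facebook/opt-125m")
    cleanup(llm, vllm=True)

    # Plain PyTorch path: .cpu() moves parameters off the GPU in place, so the
    # memory is freed even though this scope still holds a reference.
    net = torch.nn.Linear(4096, 4096).cuda()
    cleanup(net)

    # Optionally drop the callers' references too, so the now-CPU-resident
    # objects can be garbage-collected as well.
    del llm, net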

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -43,6 +43,7 @@ example = [
     "transformers",
     "setuptools",
     "accelerate",
+    "wandb>=0.19.0",
 ]
 
 [project.urls]
