forked from Shilin-LU/MACE
-
Notifications
You must be signed in to change notification settings - Fork 0
/
inference.py
87 lines (76 loc) · 3.44 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import os, gc
import torch
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
from omegaconf import OmegaConf
import argparse
def _generate_and_save(pipe, prompt, num_images, steps, generator, output_folder, prefix=""):
    """Run the pipeline once for ``prompt`` and write each image as a JPEG.

    Files are named ``{prefix}{prompt with spaces replaced by '-'}_{index}.jpg``
    inside ``output_folder``.
    """
    images = pipe(
        prompt,
        num_inference_steps=steps,
        guidance_scale=7.5,
        num_images_per_prompt=num_images,
        generator=generator,
    ).images
    file_stem = prompt.replace(' ', '-')
    for i, im in enumerate(images):
        im.save(f"{output_folder}/{prefix}{file_stem}_{i}.jpg")


def main(args):
    """Generate images with a Stable Diffusion checkpoint and save them to disk.

    Reads the following attributes from ``args``:

    * ``pretrained_model_name_or_path`` -- passed to
      ``StableDiffusionPipeline.from_pretrained``.
    * ``device`` -- device the pipeline and the random generator are placed on.
    * ``steps`` -- number of inference steps per pipeline call.
    * ``output_dir`` -- root directory for the saved images.
    * ``generate_training_data`` -- selects one of the two modes below.

    When ``generate_training_data`` is truthy, ``args.multi_concept`` is
    iterated as groups of ``(concept, type)`` pairs; for every pair 8 images
    are generated from a fixed prompt template and saved under
    ``{output_dir}/{concept}`` (dashes in the concept replaced by spaces).
    Otherwise ``args.num_images`` images are generated for ``args.prompt`` and
    saved under ``{output_dir}/generated_images`` with an ``o_`` file prefix.

    Raises:
        ValueError: if a concept type is neither ``"object"`` nor ``"style"``,
            or if ``args.prompt`` is ``None`` in single-prompt mode.
    """
    # Fail before the (expensive) model load when there is nothing to generate.
    if not args.generate_training_data and args.prompt is None:
        raise ValueError("a prompt is required when generate_training_data is False.")

    model_id = args.pretrained_model_name_or_path
    pipe = StableDiffusionPipeline.from_pretrained(model_id).to(args.device)
    pipe.safety_checker = None
    pipe.requires_safety_checker = False
    # Both modes use the same scheduler, so it is configured once here.
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)

    # Keep the seeded generator and hand it to every pipeline call; creating
    # it without passing it on leaves the sampling unseeded.
    generator = torch.Generator(device=args.device).manual_seed(42)

    if args.generate_training_data:
        prompt_templates = {
            "object": "a photo of the {}",
            "style": "a photo in the style of {}",
        }
        num_images = 8
        count = 0
        for single_concept in args.multi_concept:
            for c, t in single_concept:
                count += 1
                print(f"Generating training data for concept {count}: {c}...")
                c = c.replace('-', ' ')
                output_folder = f"{args.output_dir}/{c}"
                os.makedirs(output_folder, exist_ok=True)

                template = prompt_templates.get(t)
                if template is None:
                    raise ValueError("unknown concept type.")
                prompt = template.format(c)
                print(f'Inferencing: {prompt}')
                _generate_and_save(
                    pipe, prompt, num_images, args.steps, generator, output_folder
                )

                # The images went out of scope with the helper call; release
                # cached GPU memory before the next concept.
                torch.cuda.empty_cache()
                gc.collect()
    else:
        num_images = args.num_images
        output_folder = f"{args.output_dir}/generated_images"
        os.makedirs(output_folder, exist_ok=True)
        print(f"Inference using {args.pretrained_model_name_or_path}...")
        _generate_and_save(
            pipe, args.prompt, num_images, args.steps, generator, output_folder, prefix="o_"
        )
        torch.cuda.empty_cache()
        gc.collect()

    # Drop the pipeline itself so its weights can be freed as well.
    del pipe
    torch.cuda.empty_cache()
    gc.collect()
if __name__ == "__main__":
    # Command-line entry point: generate images for a single prompt.
    parser = argparse.ArgumentParser(
        description='Generate images for a text prompt with a Stable Diffusion checkpoint.'
    )
    parser.add_argument('--num_images', type=int, default=3,
                        help='number of images to generate for the prompt')
    # These three were optional with a None default, which either failed late
    # inside the pipeline or wrote into a directory literally named "None".
    parser.add_argument('--prompt', type=str, required=True,
                        help='text prompt to generate images from')
    parser.add_argument('--model_path', type=str, required=True,
                        help='model name or path passed to from_pretrained')
    parser.add_argument('--save_path', type=str, required=True,
                        help='output root; images go to <save_path>/generated_images')
    args = parser.parse_args()

    main(OmegaConf.create({
        "pretrained_model_name_or_path": args.model_path,
        "generate_training_data": False,
        "device": 'cuda' if torch.cuda.is_available() else 'cpu',
        "steps": 30,
        "output_dir": args.save_path,
        "num_images": args.num_images,
        "prompt": args.prompt,
    }))