
Commit 99a8c6f

Compile VAE V2 (huggingface#46)
Co-authored-by: leaves-zwx <kunta0932@gmail.com>
1 parent d9f76de commit 99a8c6f

4 files changed: +179 −44 lines changed
Lines changed: 113 additions & 0 deletions
@@ -0,0 +1,113 @@
+from collections import deque
+from timeit import default_timer as timer
+from .utils import logging
+
+logger = logging.get_logger(__name__)
+
+
+class OneFlowGraph(object):
+    def __init__(self, graph_class, *args, **kwargs):
+        self.graph_ = graph_class(*args, **kwargs)
+        self.is_compiled_ = False
+
+    @property
+    def is_compiled(self):
+        return self.is_compiled_
+
+    def compile(self, *args, **kwargs):
+        if self.is_compiled_:
+            return
+
+        global_class_name = self.graph_.__class__.__name__
+        logger.info(
+            f"[oneflow] compiling {global_class_name} beforehand to make sure the progress bar is more accurate",
+        )
+        compilation_start = timer()
+        compilation_time = 0
+        self.graph_._compile(*args, **kwargs)
+        compilation_time = timer() - compilation_start
+        logger.info(f"[oneflow] [elapsed(s)] [{global_class_name} compilation] {compilation_time:.3f}")
+
+        self.is_compiled_ = True
+
+    def __call__(self, *args, **kwargs):
+        if not self.is_compiled_:
+            self.compile(*args, **kwargs)
+
+        return self.graph_(*args, **kwargs)
+
+
+class LRUCache(object):
+    def __init__(self, cache_size):
+        self.cache_size = cache_size
+        self.queue = deque()
+        self.hash_map = dict()
+
+    def is_queue_full(self):
+        return len(self.queue) == self.cache_size
+
+    def pop(self):
+        pop_key = self.queue.pop()
+        value = self.hash_map.pop(pop_key)
+        del value
+        return pop_key
+
+    def set(self, key, value):
+        if key in self.hash_map:
+            return None
+
+        pop_key = None
+        while self.is_queue_full():
+            pop_key = self.pop()
+
+        self.queue.appendleft(key)
+        self.hash_map[key] = value
+        return pop_key if pop_key is not None else key
+
+    def get(self, key):
+        if key in self.hash_map:
+            self.queue.remove(key)
+            self.queue.appendleft(key)
+            return self.hash_map[key]
+
+        return None
+
+
+class OneFlowGraphCompileCache(object):
+    def __init__(self, cache_size=1):
+        self.cache_size_ = cache_size
+        self.cache_bucket_ = dict()
+
+    def set_cache_size(self, cache_size):
+        self.cache_size_ = cache_size
+
+        for cache in self.cache_bucket_.values():
+            cache.cache_size = cache_size
+
+    def get_graph(self, graph_class, cache_key, *args, **kwargs):
+        graph_class_name = graph_class.__name__
+        if graph_class_name not in self.cache_bucket_:
+            self.cache_bucket_[graph_class_name] = LRUCache(self.cache_size_)
+
+        compile_cache = self.cache_bucket_[graph_class_name]
+
+        graph = compile_cache.get(cache_key)
+        if graph is None:
+            graph = OneFlowGraph(graph_class, *args, **kwargs)
+            ret = compile_cache.set(cache_key, graph)
+            assert ret is not None
+
+            if ret != cache_key:
+                logger.info(
+                    f"[oneflow] a {graph_class_name} with cache key {ret} "
+                    "is deleted from cache according to the LRU policy",
+                )
+                if self.cache_size_ == 1:
+                    logger.info("[oneflow] cache size can be changed by `set_cache_size`")
+
+            logger.info(
+                f"[oneflow] a {graph_class_name} with cache key {cache_key} is appending to "
+                f"cache (cache_size={compile_cache.cache_size})",
+            )
+
+        return graph
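
For orientation (not part of the commit), here is a minimal usage sketch of the new compile cache. It assumes the module is importable as diffusers.oneflow_graph_compile_cache; DummyGraph, the linear layer, and the cache size are made up for illustration. get_graph returns a cached OneFlowGraph for the given key, compiling it lazily on first call and evicting the least recently used entry once the per-class cache is full.

# Hypothetical usage sketch; only OneFlowGraphCompileCache/OneFlowGraph come from this commit.
import oneflow as flow

from diffusers.oneflow_graph_compile_cache import OneFlowGraphCompileCache  # assumed import path


class DummyGraph(flow.nn.Graph):
    def __init__(self, module):
        super().__init__()
        self.module = module

    def build(self, x):
        return self.module(x)


cache = OneFlowGraphCompileCache(cache_size=2)
linear = flow.nn.Linear(8, 8).to("cuda")

# One compiled graph per input shape; the shape tuple is the cache key.
for batch_size in (1, 4, 1):
    x = flow.randn(batch_size, 8, device="cuda")
    graph = cache.get_graph(DummyGraph, (batch_size, 8), linear)
    y = graph(x)  # compiles on first use for this key, then reuses the compiled graph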

src/diffusers/pipeline_oneflow_utils.py

Lines changed: 7 additions & 0 deletions
@@ -50,6 +50,7 @@
     is_transformers_available,
     logging,
 )
+from .oneflow_graph_compile_cache import OneFlowGraphCompileCache


 if is_transformers_available():
@@ -159,6 +160,12 @@ class OneFlowDiffusionPipeline(ConfigMixin):
     config_name = "model_index.json"
     _optional_components = []

+    def init_graph_compile_cache(self, cache_size):
+        self.graph_compile_cache = OneFlowGraphCompileCache(cache_size)
+
+    def set_graph_compile_cache_size(self, cache_size):
+        self.graph_compile_cache.set_cache_size(cache_size)
+
     def register_modules(self, **kwargs):
         # import it here to avoid circular import
         from diffusers import pipelines
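
Roughly, the intent of the two new helpers (a sketch, not from the diff; `pipe` stands for any pipeline deriving from OneFlowDiffusionPipeline):

# init_graph_compile_cache is normally called once in the subclass __init__ (see the
# stable diffusion pipeline below); set_graph_compile_cache_size resizes it afterwards.
pipe.init_graph_compile_cache(cache_size=1)
pipe.set_graph_compile_cache_size(2)  # keep up to 2 compiled graphs per graph class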

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

Lines changed: 1 addition & 4 deletions
@@ -312,10 +312,7 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr
         else:
             attention_mask = None

-        text_embeddings = self.text_encoder(
-            text_input_ids.to(device),
-            attention_mask=attention_mask,
-        )
+        text_embeddings = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
         text_embeddings = text_embeddings[0]

         # duplicate text embeddings for each generation per prompt, using mps friendly method

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_oneflow.py

Lines changed: 58 additions & 40 deletions
@@ -37,12 +37,13 @@
 from . import StableDiffusionPipelineOutput
 from .safety_checker_oneflow import OneFlowStableDiffusionSafetyChecker as StableDiffusionSafetyChecker

-
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

 from timeit import default_timer as timer
 import os
 import oneflow as flow
+
+
 class UNetGraph(flow.nn.Graph):
     def __init__(self, unet):
         super().__init__()
@@ -55,6 +56,37 @@ def build(self, latent_model_input, t, text_embeddings):
         text_embeddings = torch._C.amp_white_identity(text_embeddings)
         return self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

+
+class VaePostProcess(flow.nn.Module):
+    def __init__(self, vae) -> None:
+        super().__init__()
+        self.vae = vae
+
+    def forward(self, latents):
+        latents = 1 / 0.18215 * latents
+        image = self.vae.decode(latents).sample
+        image = (image / 2 + 0.5).clamp(0, 1)
+        return image
+
+
+class VaeGraph(flow.nn.Graph):
+    def __init__(self, vae_post_process) -> None:
+        super().__init__()
+        self.vae_post_process = vae_post_process
+
+    def build(self, latents):
+        return self.vae_post_process(latents)
+
+
+class TextEncoderGraph(flow.nn.Graph):
+    def __init__(self, text_encoder) -> None:
+        super().__init__()
+        self.text_encoder = text_encoder
+
+    def build(self, text_input, attention_mask):
+        return self.text_encoder(text_input, attention_mask)[0]
+
+
 class OneFlowStableDiffusionPipeline(DiffusionPipeline):
     r"""
     Pipeline for text-to-image generation using Stable Diffusion.
@@ -189,9 +221,7 @@ def __init__(
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.register_to_config(requires_safety_checker=requires_safety_checker)
-        self.unet_graphs = dict()
-        self.unet_graphs_cache_size = 1
-        self.unet_graphs_lru_cache_time = 0
+        self.init_graph_compile_cache(1)

     def enable_xformers_memory_efficient_attention(self):
         r"""
@@ -288,9 +318,6 @@ def _execution_device(self):
         `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
         hooks.
         """
-        '''
-        if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
-        '''
         if not hasattr(self.unet, "_hf_hook"):
             return self.device
         for module in self.unet.modules():
@@ -345,10 +372,7 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr
         else:
             attention_mask = None

-        text_embeddings = self.text_encoder(
-            text_input_ids.to(device),
-            attention_mask=attention_mask,
-        )
+        text_embeddings = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
         text_embeddings = text_embeddings[0]

         # duplicate text embeddings for each generation per prompt, using mps friendly method
@@ -480,14 +504,13 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
     def set_unet_graphs_cache_size(self, cache_size: int):
         r"""
         Set the cache size of compiled unet graphs.
-
         This option is designed to control the GPU memory size.
-
         Args:
             cache_size ([`int`]):
                 New cache size, i.e., the maximum number of unet graphs.
         """
-        self.unet_graphs_cache_size = cache_size
+        logger.warning(f"`set_unet_graphs_cache_size` is deprecated, please use `set_graph_compile_cache_size` instead.")
+        self.set_graph_compile_cache_size(cache_size)

     @torch.no_grad()
     def __call__(
@@ -507,6 +530,7 @@ def __call__(
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
         callback_steps: Optional[int] = 1,
         compile_unet: bool = True,
+        compile_vae: bool = True,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -599,35 +623,25 @@ def __call__(
             latents,
         )

-        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
-        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+        # compile vae graph
+        if compile_vae:
+            cache_key = (height, width, num_images_per_prompt)
+            vae_post_process = VaePostProcess(self.vae)
+            vae_post_process.eval()
+            vae_post_process_graph = self.graph_compile_cache.get_graph(VaeGraph, cache_key, vae_post_process)
+            vae_post_process_graph.compile(latents)

-        compilation_start = timer()
-        compilation_time = 0
+        # compile unet graph
         if compile_unet:
-            self.unet_graphs_lru_cache_time += 1
-            if (height, width) in self.unet_graphs:
-                _, unet_graph = self.unet_graphs[height, width]
-                self.unet_graphs[height, width] = (self.unet_graphs_lru_cache_time, unet_graph)
-            else:
-                while len(self.unet_graphs) >= self.unet_graphs_cache_size:
-                    shape_to_del = min(self.unet_graphs.keys(), key=lambda shape: self.unet_graphs[shape][0])
-                    print("[oneflow]", f"a compiled unet (height={shape_to_del[0]}, width={shape_to_del[1]}) "
-                        "is deleted according to the LRU policy")
-                    print("[oneflow]", "cache size can be changed by `pipeline.set_unet_graphs_cache_size`")
-                    del self.unet_graphs[shape_to_del]
-                print("[oneflow]", "compiling unet beforehand to make sure the progress bar is more accurate")
-                i, t = list(enumerate(self.scheduler.timesteps))[0]
-
+            cache_key = (height, width, num_images_per_prompt)
+            unet_graph = self.graph_compile_cache.get_graph(UNetGraph, cache_key, self.unet)
+            if unet_graph.is_compiled is False:
                 latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+                _, t = list(enumerate(self.scheduler.timesteps))[0]
+                unet_graph.compile(latent_model_input, t, text_embeddings)

-                unet_graph = UNetGraph(self.unet)
-                unet_graph._compile(latent_model_input, t, text_embeddings)
-                unet_graph(latent_model_input, t, text_embeddings)  # warmup
-                compilation_time = timer() - compilation_start
-                print("[oneflow]", "[elapsed(s)]", "[unet compilation]", compilation_time)
-                self.unet_graphs[height, width] = (self.unet_graphs_lru_cache_time, unet_graph)
+        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

         # 7. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
@@ -660,7 +674,11 @@ def __call__(
                        callback(i, t, latents)

         # 8. Post-processing
-        image = self.decode_latents(latents)
+        if compile_vae:
+            image = vae_post_process_graph(latents)
+            image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+        else:
+            image = self.decode_latents(latents)

         # 9. Run safety checker
         image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
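
Taken together, a hedged end-to-end sketch of what this commit enables (not part of the diff): both the UNet and the VAE post-processing now run as cached OneFlow graphs keyed by (height, width, num_images_per_prompt). The model id, prompt, and the assumption that OneFlowStableDiffusionPipeline is exported from the package root are placeholders for illustration.

from diffusers import OneFlowStableDiffusionPipeline  # assumed export path

# Hypothetical model id and settings.
pipe = OneFlowStableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
pipe = pipe.to("cuda")
pipe.set_graph_compile_cache_size(2)  # optional: keep graphs for two shapes per graph class

# The first call compiles a UNetGraph and a VaeGraph for this
# (height, width, num_images_per_prompt) key; compile_unet / compile_vae default to True.
images = pipe(
    "a photo of an astronaut riding a horse",
    height=512,
    width=512,
    compile_unet=True,
    compile_vae=True,
).images

# A second call with the same shape reuses both compiled graphs.
images = pipe("a watercolor landscape", height=512, width=512).images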
