 from . import StableDiffusionPipelineOutput
 from .safety_checker_oneflow import OneFlowStableDiffusionSafetyChecker as StableDiffusionSafetyChecker
 
-
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 from timeit import default_timer as timer
 import os
 import oneflow as flow
+
+
 class UNetGraph(flow.nn.Graph):
     def __init__(self, unet):
         super().__init__()
@@ -55,6 +56,37 @@ def build(self, latent_model_input, t, text_embeddings):
         text_embeddings = torch._C.amp_white_identity(text_embeddings)
         return self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
 
+
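+# VaePostProcess pairs VAE decoding with output denormalization so the whole
+# post-processing step can be compiled as one static graph (see VaeGraph below).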
+class VaePostProcess(flow.nn.Module):
+    def __init__(self, vae) -> None:
+        super().__init__()
+        self.vae = vae
+
+    def forward(self, latents):
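+        # 0.18215 is the Stable Diffusion VAE latent scaling factor; undo it before decoding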
+        latents = 1 / 0.18215 * latents
+        image = self.vae.decode(latents).sample
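+        # map the decoder output from [-1, 1] to [0, 1]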
+        image = (image / 2 + 0.5).clamp(0, 1)
+        return image
+
+
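+# flow.nn.Graph wrapper: build() defines the computation that OneFlow traces and
+# compiles into a static graph on the first compile()/call.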
+class VaeGraph(flow.nn.Graph):
+    def __init__(self, vae_post_process) -> None:
+        super().__init__()
+        self.vae_post_process = vae_post_process
+
+    def build(self, latents):
+        return self.vae_post_process(latents)
+
+
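+# Graph wrapper for the text encoder; indexing [0] keeps only the hidden-state
+# tensor from the encoder's output tuple.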
+class TextEncoderGraph(flow.nn.Graph):
+    def __init__(self, text_encoder) -> None:
+        super().__init__()
+        self.text_encoder = text_encoder
+
+    def build(self, text_input, attention_mask):
+        return self.text_encoder(text_input, attention_mask)[0]
+
+
 class OneFlowStableDiffusionPipeline(DiffusionPipeline):
     r"""
     Pipeline for text-to-image generation using Stable Diffusion.
@@ -189,9 +221,7 @@ def __init__(
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.register_to_config(requires_safety_checker=requires_safety_checker)
-        self.unet_graphs = dict()
-        self.unet_graphs_cache_size = 1
-        self.unet_graphs_lru_cache_time = 0
+        self.init_graph_compile_cache(1)
 
     def enable_xformers_memory_efficient_attention(self):
         r"""
@@ -288,9 +318,6 @@ def _execution_device(self):
         `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
         hooks.
         """
-        '''
-        if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
-        '''
         if not hasattr(self.unet, "_hf_hook"):
             return self.device
         for module in self.unet.modules():
@@ -345,10 +372,7 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr
         else:
             attention_mask = None
 
-        text_embeddings = self.text_encoder(
-            text_input_ids.to(device),
-            attention_mask=attention_mask,
-        )
+        text_embeddings = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
         text_embeddings = text_embeddings[0]
 
         # duplicate text embeddings for each generation per prompt, using mps friendly method
@@ -480,14 +504,13 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
     def set_unet_graphs_cache_size(self, cache_size: int):
         r"""
         Set the cache size of compiled unet graphs.
-
         This option is designed to control the GPU memory size.
-
         Args:
             cache_size ([`int`]):
                 New cache size, i.e., the maximum number of unet graphs.
         """
-        self.unet_graphs_cache_size = cache_size
+        logger.warning("`set_unet_graphs_cache_size` is deprecated, please use `set_graph_compile_cache_size` instead.")
+        self.set_graph_compile_cache_size(cache_size)
 
     @torch.no_grad()
     def __call__(
@@ -507,6 +530,7 @@ def __call__(
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
         callback_steps: Optional[int] = 1,
         compile_unet: bool = True,
+        compile_vae: bool = True,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -599,35 +623,25 @@ def __call__(
             latents,
         )
 
-        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
-        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+        # compile vae graph
+        if compile_vae:
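+            # compiled graphs are specialized to tensor shapes, so the cache is keyed
+            # on (height, width, num_images_per_prompt)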
+            cache_key = (height, width, num_images_per_prompt)
+            vae_post_process = VaePostProcess(self.vae)
+            vae_post_process.eval()
+            vae_post_process_graph = self.graph_compile_cache.get_graph(VaeGraph, cache_key, vae_post_process)
+            vae_post_process_graph.compile(latents)
 
-        compilation_start = timer()
-        compilation_time = 0
+        # compile unet graph
         if compile_unet:
-            self.unet_graphs_lru_cache_time += 1
-            if (height, width) in self.unet_graphs:
-                _, unet_graph = self.unet_graphs[height, width]
-                self.unet_graphs[height, width] = (self.unet_graphs_lru_cache_time, unet_graph)
-            else:
-                while len(self.unet_graphs) >= self.unet_graphs_cache_size:
-                    shape_to_del = min(self.unet_graphs.keys(), key=lambda shape: self.unet_graphs[shape][0])
-                    print("[oneflow]", f"a compiled unet (height={shape_to_del[0]}, width={shape_to_del[1]}) "
-                          "is deleted according to the LRU policy")
-                    print("[oneflow]", "cache size can be changed by `pipeline.set_unet_graphs_cache_size`")
-                    del self.unet_graphs[shape_to_del]
-                print("[oneflow]", "compiling unet beforehand to make sure the progress bar is more accurate")
-                i, t = list(enumerate(self.scheduler.timesteps))[0]
-
+            cache_key = (height, width, num_images_per_prompt)
+            unet_graph = self.graph_compile_cache.get_graph(UNetGraph, cache_key, self.unet)
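+            # compile before entering the denoising loop so the per-step progress bar stays accurate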
+            if unet_graph.is_compiled is False:
                 latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+                _, t = list(enumerate(self.scheduler.timesteps))[0]
+                unet_graph.compile(latent_model_input, t, text_embeddings)
 
-                unet_graph = UNetGraph(self.unet)
-                unet_graph._compile(latent_model_input, t, text_embeddings)
-                unet_graph(latent_model_input, t, text_embeddings)  # warmup
-                compilation_time = timer() - compilation_start
-                print("[oneflow]", "[elapsed(s)]", "[unet compilation]", compilation_time)
-                self.unet_graphs[height, width] = (self.unet_graphs_lru_cache_time, unet_graph)
+        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
         # 7. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
@@ -660,7 +674,11 @@ def __call__(
                         callback(i, t, latents)
 
         # 8. Post-processing
-        image = self.decode_latents(latents)
+        if compile_vae:
+            image = vae_post_process_graph(latents)
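+            # NCHW tensor -> NHWC numpy array, matching what decode_latents returns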
+            image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+        else:
+            image = self.decode_latents(latents)
 
         # 9. Run safety checker
         image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
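
A minimal usage sketch of the changed API (assumptions: the model id and fp16 settings are illustrative; `set_graph_compile_cache_size` is inferred from the deprecation shim above; `.images` follows the stock diffusers pipeline output):

    import oneflow as flow
    from diffusers import OneFlowStableDiffusionPipeline

    pipe = OneFlowStableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",  # illustrative model id
        torch_dtype=flow.float16,
    ).to("cuda")
    pipe.set_graph_compile_cache_size(2)  # keep compiled graphs for two shapes resident

    # The first call compiles unet and vae graphs keyed on
    # (height, width, num_images_per_prompt); same-shape calls reuse them.
    image = pipe(
        "a photo of an astronaut riding a horse",
        height=512,
        width=512,
        compile_unet=True,
        compile_vae=True,
    ).images[0]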