fix: PatchVAE related problems #86

Merged: 7 commits, Jun 24, 2024
1 change: 1 addition & 0 deletions .gitignore
@@ -6,4 +6,5 @@ pipefuser.egg-info
*.log
*.txt
results/
profile/
pipefusion.egg-info/
17 changes: 14 additions & 3 deletions README.md
@@ -145,9 +145,20 @@ To conduct the FID experiment, follow the detailed instructions provided in the

1. Memory Efficient VAE:

The VAE decode implementation from diffusers can not be applied on high resolution images (8192px).
It has CUDA memory spike issue, [diffusers/issues/5924](https://github.com/huggingface/diffusers/issues/5924).
We fixed the issue by splitting a conv operator into multiple small ones and executing them sequentially to reduce the peak memory.
The VAE decoder implementation in the diffusers library cannot handle high-resolution images (8192px and above): it suffers from a CUDA memory spike, documented in [diffusers/issues/5924](https://github.com/huggingface/diffusers/issues/5924).

To address this limitation, we developed [PatchVAE](https://github.com/PipeFusion/PatchVAE), which enables efficient VAE decoding of high-resolution images through two key strategies:

* Patch Parallelization: we divide the latent-space feature maps into multiple patches and decode them in parallel across devices, reducing the peak memory for intermediate activations to 1/N, where N is the number of devices.

* Sequential Patch Processing: building on [previous research](https://hanlab.mit.edu/blog/patch-conv), each device processes portions of its patch sequentially, minimizing temporary memory consumption.

Combining these two methods dramatically expands the reach of VAE decoding: our implementation handles resolutions up to 10240 × 10240 pixels, an 11-fold increase compared to the conventional VAE.

This makes VAE decoding practical at resolutions that were previously out of reach for high-resolution image generation pipelines; both ideas are sketched below.
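
As an illustration only (not the actual PatchVAE code or API), the following sketch shows both strategies in plain PyTorch. `sequential_conv2d` and `patch_parallel_decode` are hypothetical helpers, odd "same"-padded kernels are assumed, and the inter-rank halo exchange that a real distributed decoder needs is omitted:

```python
import torch
import torch.nn.functional as F

def sequential_conv2d(x, weight, bias=None, n_chunks=4):
    """Run one 'same'-padded conv2d as several row bands executed
    sequentially, so temporary buffers only ever cover one band."""
    kh, kw = weight.shape[2], weight.shape[3]
    ph, pw = kh // 2, kw // 2          # halo size for odd kernels
    x = F.pad(x, (0, 0, ph, ph))       # pad the row dimension once
    h_out = x.shape[2] - 2 * ph
    step = -(-h_out // n_chunks)       # ceil(h_out / n_chunks)
    bands = []
    for top in range(0, h_out, step):
        rows = x[:, :, top : top + min(step, h_out - top) + 2 * ph, :]
        # vertical context comes from the halo rows, so pad columns only
        bands.append(F.conv2d(rows, weight, bias, padding=(0, pw)))
    return torch.cat(bands, dim=2)

def patch_parallel_decode(latents, decode_fn, rank, world_size):
    """Each rank decodes one height-wise slice of the latent, cutting
    activation memory per device to roughly 1/world_size. A real
    implementation must also exchange boundary (halo) rows between
    ranks so convolutions see their neighbours; omitted here."""
    patches = latents.chunk(world_size, dim=2)
    return decode_fn(patches[rank])    # outputs are gathered afterwards

# Sanity check: the banded conv matches the one-shot conv.
x = torch.randn(1, 8, 64, 64)
w = torch.randn(8, 8, 3, 3)
assert torch.allclose(sequential_conv2d(x, w),
                      F.conv2d(x, w, padding=1), atol=1e-5)
```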


## Cite Us
2 changes: 1 addition & 1 deletion pipefuser/modules/patchvae
9 changes: 6 additions & 3 deletions pipefuser/pipelines/pixartalpha.py
@@ -34,10 +34,10 @@


class DistriPixArtAlphaPipeline:
def __init__(self, pipeline: PixArtAlphaPipeline, module_config: DistriConfig, enable_parallel_vae: bool = False):
def __init__(self, pipeline: PixArtAlphaPipeline, module_config: DistriConfig, enable_parallel_vae: bool = False, use_profiler: bool = False):
self.pipeline = pipeline
if enable_parallel_vae:
self.pipeline.vae.decoder = DecoderAdapter(self.pipeline.vae.decoder)
self.pipeline.vae.decoder = DecoderAdapter(self.pipeline.vae.decoder, use_profiler=use_profiler)

# assert module_config.do_classifier_free_guidance == False
assert module_config.split_batch == False
@@ -56,6 +56,7 @@ def from_pretrained(distri_config: DistriConfig, **kwargs):
)
torch_dtype = kwargs.pop("torch_dtype", torch.float16)
enable_parallel_vae = kwargs.pop("enable_parallel_vae", False)
use_profiler = kwargs.pop("use_profiler", False)
transformer = Transformer2DModel.from_pretrained(
pretrained_model_name_or_path,
torch_dtype=torch_dtype,
@@ -88,6 +89,8 @@ def from_pretrained(distri_config: DistriConfig, **kwargs):
scheduler = DDIMSchedulerPiP.from_pretrained(
pretrained_model_name_or_path, subfolder="scheduler"
)
else:
raise ValueError(f"scheduler {distri_config.scheduler} is not supported in pipefusion parallelism")
scheduler.init(distri_config)

if distri_config.parallelism == "pipefusion":
@@ -110,7 +113,7 @@ def from_pretrained(distri_config: DistriConfig, **kwargs):
peak_memory = torch.cuda.max_memory_allocated(device="cuda")
print(f"DistriPixArtAlphaPipeline from pretrain stage 2 {peak_memory/1e9} GB")

ret = DistriPixArtAlphaPipeline(pipeline, distri_config, enable_parallel_vae=enable_parallel_vae)
ret = DistriPixArtAlphaPipeline(pipeline, distri_config, enable_parallel_vae=enable_parallel_vae, use_profiler=use_profiler)

peak_memory = torch.cuda.max_memory_allocated(device="cuda")
print(f"DistriPixArtAlphaPipeline from pretrain stage 3 {peak_memory/1e9} GB")
10 changes: 7 additions & 3 deletions scripts/pixart_example.py
@@ -141,8 +141,7 @@ def main():
distri_config=distri_config,
pretrained_model_name_or_path=args.model_id,
enable_parallel_vae=enable_parallel_vae,
# variant="fp16",
# use_safetensors=True,
use_profiler=args.use_profiler,
)

pipeline.set_progress_bar_config(disable=distri_config.rank != 0)
@@ -162,11 +161,16 @@
case_name = f"{args.parallelism}_hw_{args.height}_sync_{args.sync_mode}_u{args.ulysses_degree}_w{distri_config.world_size}"
if args.output_file:
case_name = args.output_file + "_" + case_name
if enable_parallel_vae:
case_name += "_patchvae"

if args.use_profiler:
start_time = time.time()
with profile(
activities=[ProfilerActivity.CUDA],
activities=[
ProfilerActivity.CPU,
ProfilerActivity.CUDA
],
on_trace_ready=torch.profiler.tensorboard_trace_handler(
f"./profile/{case_name}"
),
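Note: with `torch.profiler.tensorboard_trace_handler`, traces land under `./profile/<case_name>` and can be inspected with TensorBoard's profiler plugin, e.g. `tensorboard --logdir ./profile`, assuming the `torch-tb-profiler` package is installed.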
12 changes: 7 additions & 5 deletions setup.py
@@ -14,18 +14,20 @@
author_email="muyangli@mit.edu",
packages=find_packages(),
install_requires=[
"torch>=2.2",
"diffusers==0.27.2",
"transformers",
"tqdm",
"torch>=2.2",
"diffusers==0.27.2",
"transformers",
"tqdm",
"sentencepiece",
"accelerate",
"beautifulsoup4",
"ftfy",
f"patchvae @ file://localhost/{os.path.join(os.getcwd(), 'pipefuser/modules/patchvae')}#egg=patchvae",
],
dependency_links=[
"file://"
+ os.path.join(
os.getcwd(), "pipefuser/modules/patchvae#egg=patchvae-0.0.0b1"
os.getcwd(), "pipefuser/modules/patchvae#egg=patchvae-0.0.0b3"
)
],
url="https://github.com/PipeFusion/PipeFusion.",