diff --git a/.gitignore b/.gitignore index 574baaa..40479c7 100644 --- a/.gitignore +++ b/.gitignore @@ -159,5 +159,5 @@ cython_debug/ # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ -resources/ +scripts/ outputs/ \ No newline at end of file diff --git a/LICENSE b/LICENSE deleted file mode 100644 index 261eeb9..0000000 --- a/LICENSE +++ /dev/null @@ -1,201 +0,0 @@ - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." 
- - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. 
- - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. 
We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright [yyyy] [name of copyright owner] - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. diff --git a/README.md b/README.md index 679fe5b..cdfa01d 100644 --- a/README.md +++ b/README.md @@ -1 +1,104 @@ -# MotionInversion \ No newline at end of file + + +## Motion Inversion for Video Customization + +[Luozhou Wang](https://wileewang.github.io/), [Guibao Shen](), [Yixun Liang](https://yixunliang.github.io/), [Xin Tao](http://www.xtao.website/), Pengfei Wan, Di Zhang, [Yijun Li](https://yijunmaverick.github.io/), [Yingcong Chen](https://www.yingcong.me) + +HKUST(GZ), HKUST, Kuaishou Technology, Adobe Research. + + + We present a novel approach to motion customization in video generation, addressing the largely unexplored problem of motion representation within video generative models. Recognizing the unique challenges posed by video's spatiotemporal nature, our method introduces **Motion Embeddings**, a set of explicit, temporally coherent one-dimensional embeddings derived from a given video. These embeddings are designed to integrate seamlessly with the temporal transformer modules of video diffusion models, modulating self-attention computations across frames without compromising spatial integrity. Furthermore, we identify the **Temporal Discrepancy** in video generative models, which refers to variations in how different motion modules process temporal relationships between frames. We leverage this understanding to optimize the integration of our motion embeddings. + + +

Customizing the motion of your video with fewer than 1M parameters in under 10 minutes.

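To make the idea above concrete, the sketch below shows one way such an embedding could plug into a temporal transformer block: a learnable, per-frame one-dimensional embedding is added to the hidden states right before temporal self-attention, so it modulates attention across frames while leaving the spatial layers untouched. This is a minimal illustration of the description above, not the repository's actual module; the class name, tensor layout, and default sizes are assumptions.

```python
import torch
import torch.nn as nn


class MotionEmbedding(nn.Module):
    """Minimal sketch of a per-frame 1D motion embedding (names/shapes are assumptions)."""

    def __init__(self, max_frames: int = 16, dim: int = 320):
        super().__init__()
        # One learnable vector per frame, shared across all spatial positions.
        self.embed = nn.Parameter(torch.zeros(1, max_frames, dim))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # hidden_states: (batch * height * width, num_frames, dim) -- the layout
        # temporal attention typically operates on, with frames as the sequence.
        num_frames = hidden_states.shape[1]
        return hidden_states + self.embed[:, :num_frames]
```

Only such embeddings (plus, optionally, lightweight spatial LoRA layers, as used by the hybrid loss further down in this diff) would receive gradients during training, which is consistent with the sub-1M-parameter claim in the teaser above.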
+ +## 📰 News + +* **[2024.03.31]** We have released the project page, arXiv paper, and training code. + +## 🚧 Todo List +* [x] Released code for the UNet3D model (ZeroScope, ModelScope). + +* [ ] Release code for the Sora-like model (Open-Sora, Latte). + + + +## Contents + +* [Installation](#installation) +* [Training](#training) +* [Inference](#inference) +* [Acknowledgement](#acknowledgement) +* [Citation](#citation) + + + +## Installation + +```bash +# install torch +pip install torch torchvision + +# install diffusers and transformers +pip install diffusers==0.26.3 transformers==4.27.4 +``` + + +## Training + +To start training, first download the [ZeroScope](https://huggingface.co/cerspense/zeroscope_v2_576w) weights and specify the path in the config file. Then run the following command to begin training: + +```bash +python train.py --config ./configs/train_config.yaml +``` + +A short snippet for sanity-checking the downloaded weights is provided after the Acknowledgement section. Stay tuned for training support for other models and advanced usage! + +## Inference + +```bash +python inference.py --config ./configs/inference_config.yaml +``` + +We will also provide a Gradio application in this repository. + + +## Acknowledgement + +* [MotionDirector](https://github.com/showlab/MotionDirector): We follow their loss design and their techniques for reducing computational cost. +* [ZeroScope](https://huggingface.co/cerspense/zeroscope_v2_576w): The pretrained video checkpoint we used in our main paper. +* [AnimateDiff](https://github.com/guoyww/animatediff/): The pretrained video checkpoint we used in our main paper. +* [Latte](https://github.com/Vchitect/Latte): A video generation model with a similar architecture to Sora. +* [Open-Sora](https://github.com/hpcaitech/Open-Sora): A video generation model with a similar architecture to Sora. + +We are grateful for their exceptional work and generous contributions to the open-source community. 
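As a quick check that the ZeroScope weights downloaded for the Training section load correctly under the pinned diffusers version, the following minimal snippet runs the stock diffusers text-to-video pipeline. The local checkpoint path is a placeholder, and the exact handling of the returned frames may vary slightly between diffusers releases.

```python
import torch
from diffusers import DiffusionPipeline
from diffusers.utils import export_to_video

# Placeholder path: point this at the ZeroScope checkpoint set in the training config.
pipe = DiffusionPipeline.from_pretrained(
    "./checkpoints/zeroscope_v2_576w", torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()  # reduce peak GPU memory

frames = pipe(
    "a train running over a bridge, aerial view",
    num_frames=16,
    num_inference_steps=25,
).frames[0]

export_to_video(frames, "zeroscope_smoke_test.mp4")
```

If this produces a short clip, the base model is set up correctly and `train.py` should be able to load it from the same path.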
+ +## Citation + + + + \ No newline at end of file diff --git a/dataset/__init__.py b/dataset/__init__.py new file mode 100644 index 0000000..299b402 --- /dev/null +++ b/dataset/__init__.py @@ -0,0 +1,5 @@ +from .cached_dataset import CachedDataset +from .image_dataset import ImageDataset +from .single_video_dataset import SingleVideoDataset +from .video_folder_dataset import VideoFolderDataset +from .video_json_dataset import VideoJsonDataset \ No newline at end of file diff --git a/dataset/cached_dataset.py b/dataset/cached_dataset.py new file mode 100644 index 0000000..ac4d52c --- /dev/null +++ b/dataset/cached_dataset.py @@ -0,0 +1,17 @@ +from utils.dataset_utils import * + +class CachedDataset(Dataset): + def __init__(self,cache_dir: str = ''): + self.cache_dir = cache_dir + self.cached_data_list = self.get_files_list() + + def get_files_list(self): + tensors_list = [f"{self.cache_dir}/{x}" for x in os.listdir(self.cache_dir) if x.endswith('.pt')] + return sorted(tensors_list) + + def __len__(self): + return len(self.cached_data_list) + + def __getitem__(self, index): + cached_latent = torch.load(self.cached_data_list[index], map_location='cuda:0') + return cached_latent \ No newline at end of file diff --git a/dataset/image_dataset.py b/dataset/image_dataset.py new file mode 100644 index 0000000..2a06168 --- /dev/null +++ b/dataset/image_dataset.py @@ -0,0 +1,95 @@ +from utils.dataset_utils import * + +class ImageDataset(Dataset): + + def __init__( + self, + tokenizer = None, + width: int = 256, + height: int = 256, + base_width: int = 256, + base_height: int = 256, + use_caption: bool = False, + image_dir: str = '', + single_img_prompt: str = '', + use_bucketing: bool = False, + fallback_prompt: str = '', + **kwargs + ): + self.tokenizer = tokenizer + self.img_types = (".png", ".jpg", ".jpeg", '.bmp') + self.use_bucketing = use_bucketing + + self.image_dir = self.get_images_list(image_dir) + self.fallback_prompt = fallback_prompt + + self.use_caption = use_caption + self.single_img_prompt = single_img_prompt + + self.width = width + self.height = height + + def get_images_list(self, image_dir): + if os.path.exists(image_dir): + imgs = [x for x in os.listdir(image_dir) if x.endswith(self.img_types)] + full_img_dir = [] + + for img in imgs: + full_img_dir.append(f"{image_dir}/{img}") + + return sorted(full_img_dir) + + return [''] + + def image_batch(self, index): + train_data = self.image_dir[index] + img = train_data + + try: + img = torchvision.io.read_image(img, mode=torchvision.io.ImageReadMode.RGB) + except: + img = T.transforms.PILToTensor()(Image.open(img).convert("RGB")) + + width = self.width + height = self.height + + if self.use_bucketing: + _, h, w = img.shape + width, height = sensible_buckets(width, height, w, h) + + resize = T.transforms.Resize((height, width), antialias=True) + + img = resize(img) + img = repeat(img, 'c h w -> f c h w', f=16) + + prompt = get_text_prompt( + file_path=train_data, + text_prompt=self.single_img_prompt, + fallback_prompt=self.fallback_prompt, + ext_types=self.img_types, + use_caption=True + ) + prompt_ids = get_prompt_ids(prompt, self.tokenizer) + + return img, prompt, prompt_ids + + @staticmethod + def __getname__(): return 'image' + + def __len__(self): + # Image directory + if os.path.exists(self.image_dir[0]): + return len(self.image_dir) + else: + return 0 + + def __getitem__(self, index): + img, prompt, prompt_ids = self.image_batch(index) + example = { + "pixel_values": (img / 127.5 - 1.0), + "prompt_ids": prompt_ids[0], + 
"text_prompt": prompt, + 'dataset': self.__getname__() + } + + return example \ No newline at end of file diff --git a/dataset/single_video_dataset.py b/dataset/single_video_dataset.py new file mode 100644 index 0000000..023a6a4 --- /dev/null +++ b/dataset/single_video_dataset.py @@ -0,0 +1,102 @@ +from utils.dataset_utils import * + +class SingleVideoDataset(Dataset): + def __init__( + self, + tokenizer = None, + width: int = 256, + height: int = 256, + n_sample_frames: int = 4, + frame_step: int = 1, + single_video_path: str = "", + single_video_prompt: str = "", + use_caption: bool = False, + use_bucketing: bool = False, + **kwargs + ): + self.tokenizer = tokenizer + self.use_bucketing = use_bucketing + self.frames = [] + self.index = 1 + + self.vid_types = (".mp4", ".avi", ".mov", ".webm", ".flv", ".mjpeg") + self.n_sample_frames = n_sample_frames + self.frame_step = frame_step + + self.single_video_path = single_video_path + self.single_video_prompt = single_video_prompt + + self.width = width + self.height = height + + def create_video_chunks(self): + vr = decord.VideoReader(self.single_video_path) + vr_range = range(0, len(vr), self.frame_step) + + self.frames = list(self.chunk(vr_range, self.n_sample_frames)) + return self.frames + + def chunk(self, it, size): + it = iter(it) + return iter(lambda: tuple(islice(it, size)), ()) + + def get_frame_batch(self, vr, resize=None): + index = self.index + frames = vr.get_batch(self.frames[self.index]) + video = rearrange(frames, "f h w c -> f c h w") + + if resize is not None: video = resize(video) + return video + + def get_frame_buckets(self, vr): + h, w, c = vr[0].shape + width, height = sensible_buckets(self.width, self.height, w, h) + resize = T.transforms.Resize((height, width), antialias=True) + + return resize + + def process_video_wrapper(self, vid_path): + video, vr = process_video( + vid_path, + self.use_bucketing, + self.width, + self.height, + self.get_frame_buckets, + self.get_frame_batch + ) + + return video, vr + + def single_video_batch(self, index): + train_data = self.single_video_path + self.index = index + + if train_data.endswith(self.vid_types): + video, _ = self.process_video_wrapper(train_data) + + prompt = self.single_video_prompt + prompt_ids = get_prompt_ids(prompt, self.tokenizer) + + return video, prompt, prompt_ids + else: + raise ValueError(f"Single video is not a video type. 
Types: {self.vid_types}") + + @staticmethod + def __getname__(): return 'single_video' + + def __len__(self): + + return len(self.create_video_chunks()) + + def __getitem__(self, index): + + video, prompt, prompt_ids = self.single_video_batch(index) + + example = { + "pixel_values": (video / 127.5 - 1.0), + "prompt_ids": prompt_ids[0], + "text_prompt": prompt, + 'dataset': self.__getname__() + } + + return example \ No newline at end of file diff --git a/dataset/video_folder_dataset.py b/dataset/video_folder_dataset.py new file mode 100644 index 0000000..2d049f4 --- /dev/null +++ b/dataset/video_folder_dataset.py @@ -0,0 +1,90 @@ +from utils.dataset_utils import * + +class VideoFolderDataset(Dataset): + def __init__( + self, + tokenizer=None, + width: int = 256, + height: int = 256, + n_sample_frames: int = 16, + fps: int = 8, + path: str = "./data", + fallback_prompt: str = "", + use_bucketing: bool = False, + **kwargs + ): + self.tokenizer = tokenizer + self.use_bucketing = use_bucketing + + self.fallback_prompt = fallback_prompt + + self.video_files = glob(f"{path}/*.mp4") + + self.width = width + self.height = height + + self.n_sample_frames = n_sample_frames + self.fps = fps + + def get_frame_buckets(self, vr): + h, w, c = vr[0].shape + width, height = sensible_buckets(self.width, self.height, w, h) + resize = T.transforms.Resize((height, width), antialias=True) + + return resize + + def get_frame_batch(self, vr, resize=None): + n_sample_frames = self.n_sample_frames + native_fps = vr.get_avg_fps() + + every_nth_frame = max(1, round(native_fps / self.fps)) + every_nth_frame = min(len(vr), every_nth_frame) + + effective_length = len(vr) // every_nth_frame + if effective_length < n_sample_frames: + n_sample_frames = effective_length + + effective_idx = random.randint(0, (effective_length - n_sample_frames)) + idxs = every_nth_frame * np.arange(effective_idx, effective_idx + n_sample_frames) + + video = vr.get_batch(idxs) + video = rearrange(video, "f h w c -> f c h w") + + if resize is not None: video = resize(video) + return video, vr + + def process_video_wrapper(self, vid_path): + video, vr = process_video( + vid_path, + self.use_bucketing, + self.width, + self.height, + self.get_frame_buckets, + self.get_frame_batch + ) + return video, vr + + def get_prompt_ids(self, prompt): + return self.tokenizer( + prompt, + truncation=True, + padding="max_length", + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ).input_ids + + @staticmethod + def __getname__(): return 'folder' + + def __len__(self): + return len(self.video_files) + + def __getitem__(self, index): + + video, _ = self.process_video_wrapper(self.video_files[index]) + + prompt = self.fallback_prompt + + prompt_ids = self.get_prompt_ids(prompt) + + return {"pixel_values": (video[0] / 127.5 - 1.0), "prompt_ids": prompt_ids[0], "text_prompt": prompt, 'dataset': self.__getname__()} \ No newline at end of file diff --git a/dataset/video_json_dataset.py b/dataset/video_json_dataset.py new file mode 100644 index 0000000..9b693c2 --- /dev/null +++ b/dataset/video_json_dataset.py @@ -0,0 +1,183 @@ +from utils.dataset_utils import * + +# https://github.com/ExponentialML/Video-BLIP2-Preprocessor +class VideoJsonDataset(Dataset): + def __init__( + self, + tokenizer = None, + width: int = 256, + height: int = 256, + n_sample_frames: int = 4, + sample_start_idx: int = 1, + frame_step: int = 1, + json_path: str ="", + json_data = None, + vid_data_key: str = "video_path", + preprocessed: bool = False, + use_bucketing: bool = 
False, + **kwargs + ): + self.vid_types = (".mp4", ".avi", ".mov", ".webm", ".flv", ".mjpeg") + self.use_bucketing = use_bucketing + self.tokenizer = tokenizer + self.preprocessed = preprocessed + + self.vid_data_key = vid_data_key + self.train_data = self.load_from_json(json_path, json_data) + + self.width = width + self.height = height + + self.n_sample_frames = n_sample_frames + self.sample_start_idx = sample_start_idx + self.frame_step = frame_step + + def build_json(self, json_data): + extended_data = [] + for data in json_data['data']: + for nested_data in data['data']: + self.build_json_dict( + data, + nested_data, + extended_data + ) + json_data = extended_data + return json_data + + def build_json_dict(self, data, nested_data, extended_data): + clip_path = nested_data['clip_path'] if 'clip_path' in nested_data else None + + extended_data.append({ + self.vid_data_key: data[self.vid_data_key], + 'frame_index': nested_data['frame_index'], + 'prompt': nested_data['prompt'], + 'clip_path': clip_path + }) + + def load_from_json(self, path, json_data): + try: + with open(path) as jpath: + print(f"Loading JSON from {path}") + json_data = json.load(jpath) + + return self.build_json(json_data) + + except: + self.train_data = [] + print("Non-existant JSON path. Skipping.") + + def validate_json(self, base_path, path): + return os.path.exists(f"{base_path}/{path}") + + def get_frame_range(self, vr): + return get_video_frames( + vr, + self.sample_start_idx, + self.frame_step, + self.n_sample_frames + ) + + def get_vid_idx(self, vr, vid_data=None): + frames = self.n_sample_frames + + if vid_data is not None: + idx = vid_data['frame_index'] + else: + idx = self.sample_start_idx + + return idx + + def get_frame_buckets(self, vr): + _, h, w = vr[0].shape + width, height = sensible_buckets(self.width, self.height, h, w) + # width, height = self.width, self.height + resize = T.transforms.Resize((height, width), antialias=True) + + return resize + + def get_frame_batch(self, vr, resize=None): + frame_range = self.get_frame_range(vr) + frames = vr.get_batch(frame_range) + video = rearrange(frames, "f h w c -> f c h w") + + if resize is not None: video = resize(video) + return video + + def process_video_wrapper(self, vid_path): + video, vr = process_video( + vid_path, + self.use_bucketing, + self.width, + self.height, + self.get_frame_buckets, + self.get_frame_batch + ) + + return video, vr + + def train_data_batch(self, index): + + # If we are training on individual clips. + if 'clip_path' in self.train_data[index] and \ + self.train_data[index]['clip_path'] is not None: + + vid_data = self.train_data[index] + + clip_path = vid_data['clip_path'] + + # Get video prompt + prompt = vid_data['prompt'] + + video, _ = self.process_video_wrapper(clip_path) + + prompt_ids = get_prompt_ids(prompt, self.tokenizer) + + return video, prompt, prompt_ids + + # Assign train data + train_data = self.train_data[index] + + # Get the frame of the current index. 
+ self.sample_start_idx = train_data['frame_index'] + + # Initialize resize + resize = None + + video, vr = self.process_video_wrapper(train_data[self.vid_data_key]) + + # Get video prompt + prompt = train_data['prompt'] + vr.seek(0) + + prompt_ids = get_prompt_ids(prompt, self.tokenizer) + + return video, prompt, prompt_ids + + @staticmethod + def __getname__(): return 'json' + + def __len__(self): + if self.train_data is not None: + return len(self.train_data) + else: + return 0 + + def __getitem__(self, index): + + # Initialize variables + video = None + prompt = None + prompt_ids = None + + # Use default JSON training + if self.train_data is not None: + video, prompt, prompt_ids = self.train_data_batch(index) + + example = { + "pixel_values": (video / 127.5 - 1.0), + "prompt_ids": prompt_ids[0], + "text_prompt": prompt, + 'dataset': self.__getname__() + } + + return example \ No newline at end of file diff --git a/loss/__init__.py b/loss/__init__.py new file mode 100644 index 0000000..5ecbbf5 --- /dev/null +++ b/loss/__init__.py @@ -0,0 +1,4 @@ +from .base_loss import BaseLoss +from .debiashybrid_loss import DebiasHybridLoss +from .debias_loss import DebiasLoss + diff --git a/loss/base_loss.py b/loss/base_loss.py new file mode 100644 index 0000000..65a38c1 --- /dev/null +++ b/loss/base_loss.py @@ -0,0 +1,73 @@ +from utils.func_utils import * + +def BaseLoss( + train_loss_temporal, + accelerator, + optimizers, + lr_schedulers, + unet, + vae, + text_encoder, + noise_scheduler, + batch, + step, + config + ): + cache_latents = config.train.cache_latents + + if not cache_latents: + latents = tensor_to_vae_latent(batch["pixel_values"], vae) + else: + latents = batch["latents"] + + # Sample noise that we'll add to the latents + # use_offset_noise = use_offset_noise and not rescale_schedule + + noise = sample_noise(latents, 0.1, False) + bsz = latents.shape[0] + + # Sample a random timestep for each video + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # *Potentially* Fixes gradient checkpointing training. 
+ # See: https://github.com/prigoyal/pytorch_memonger/blob/master/tutorial/Checkpointing_for_PyTorch_models.ipynb + # if kwargs.get('eval_train', False): + # unet.eval() + # text_encoder.eval() + + # Encode text embeddings + token_ids = batch['prompt_ids'] + encoder_hidden_states = text_encoder(token_ids)[0] + detached_encoder_state = encoder_hidden_states.clone().detach() + + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + encoder_hidden_states = detached_encoder_state + + + # optimization + model_pred = unet(noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states).sample + loss_temporal = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + + avg_loss_temporal = accelerator.gather(loss_temporal.repeat(config.train.train_batch_size)).mean() + train_loss_temporal += avg_loss_temporal.item() / config.train.gradient_accumulation_steps + + accelerator.backward(loss_temporal) + optimizers[0].step() + lr_schedulers[0].step() + + return loss_temporal, train_loss_temporal + diff --git a/loss/debias_loss.py b/loss/debias_loss.py new file mode 100644 index 0000000..7cb3872 --- /dev/null +++ b/loss/debias_loss.py @@ -0,0 +1,84 @@ +from utils.lora import extract_lora_child_module +from utils.func_utils import * + +def DebiasLoss( + train_loss_temporal, + accelerator, + optimizers, + lr_schedulers, + unet, + vae, + text_encoder, + noise_scheduler, + batch, + step, + config + ): + cache_latents = config.train.cache_latents + + + + if not cache_latents: + latents = tensor_to_vae_latent(batch["pixel_values"], vae) + else: + latents = batch["latents"] + + # Sample noise that we'll add to the latents + # use_offset_noise = use_offset_noise and not rescale_schedule + + noise = sample_noise(latents, 0.1, False) + bsz = latents.shape[0] + + # Sample a random timestep for each video + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # *Potentially* Fixes gradient checkpointing training. 
+ # See: https://github.com/prigoyal/pytorch_memonger/blob/master/tutorial/Checkpointing_for_PyTorch_models.ipynb + # if kwargs.get('eval_train', False): + # unet.eval() + # text_encoder.eval() + + # Encode text embeddings + token_ids = batch['prompt_ids'] + encoder_hidden_states = text_encoder(token_ids)[0] + detached_encoder_state = encoder_hidden_states.clone().detach() + + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + encoder_hidden_states = detached_encoder_state + + + # optimization + model_pred = unet(noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states).sample + loss_temporal = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + + beta = 1 + alpha = (beta ** 2 + 1) ** 0.5 + ran_idx = torch.randint(0, model_pred.shape[2], (1,)).item() + model_pred_decent = alpha * model_pred - beta * model_pred[:, :, ran_idx, :, :].unsqueeze(2) + target_decent = alpha * target - beta * target[:, :, ran_idx, :, :].unsqueeze(2) + loss_ad_temporal = F.mse_loss(model_pred_decent.float(), target_decent.float(), reduction="mean") + loss_temporal = loss_temporal + loss_ad_temporal + + avg_loss_temporal = accelerator.gather(loss_temporal.repeat(config.train.train_batch_size)).mean() + train_loss_temporal += avg_loss_temporal.item() / config.train.gradient_accumulation_steps + + accelerator.backward(loss_temporal) + optimizers[0].step() + + lr_schedulers[0].step() + + return loss_temporal, train_loss_temporal \ No newline at end of file diff --git a/loss/debiashybrid_loss.py b/loss/debiashybrid_loss.py new file mode 100644 index 0000000..ea9d260 --- /dev/null +++ b/loss/debiashybrid_loss.py @@ -0,0 +1,144 @@ +from utils.lora import extract_lora_child_module +from utils.func_utils import * + +def DebiasHybridLoss( + train_loss_temporal, + accelerator, + optimizers, + lr_schedulers, + unet, + vae, + text_encoder, + noise_scheduler, + batch, + step, + config + ): + mask_spatial_lora = random.uniform(0, 1) < 0.2 + random_hflip_img = config.loss.random_hflip_img + spatial_lora_num = config.loss.spatial_lora_num + cache_latents = config.train.cache_latents + + + + if not cache_latents: + latents = tensor_to_vae_latent(batch["pixel_values"], vae) + else: + latents = batch["latents"] + + # Sample noise that we'll add to the latents + # use_offset_noise = use_offset_noise and not rescale_schedule + + noise = sample_noise(latents, 0.1, False) + bsz = latents.shape[0] + + # Sample a random timestep for each video + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # *Potentially* Fixes gradient checkpointing training. 
+ # See: https://github.com/prigoyal/pytorch_memonger/blob/master/tutorial/Checkpointing_for_PyTorch_models.ipynb + # if kwargs.get('eval_train', False): + # unet.eval() + # text_encoder.eval() + + # Encode text embeddings + token_ids = batch['prompt_ids'] + encoder_hidden_states = text_encoder(token_ids)[0] + detached_encoder_state = encoder_hidden_states.clone().detach() + + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + encoder_hidden_states = detached_encoder_state + + + # optimization + if mask_spatial_lora: + loras = extract_lora_child_module(unet, target_replace_module=["Transformer2DModel"]) + for lora_i in loras: + lora_i.scale = 0. + loss_spatial = None + else: + loras = extract_lora_child_module(unet, target_replace_module=["Transformer2DModel"]) + + if spatial_lora_num == 1: + for lora_i in loras: + lora_i.scale = 1. + else: + for lora_i in loras: + lora_i.scale = 0. + + for lora_idx in range(0, len(loras), spatial_lora_num): + loras[lora_idx + step].scale = 1. + + loras = extract_lora_child_module(unet, target_replace_module=["TransformerTemporalModel"]) + if len(loras) > 0: + for lora_i in loras: + lora_i.scale = 0. + + ran_idx = torch.randint(0, noisy_latents.shape[2], (1,)).item() + + if random.uniform(0, 1) < random_hflip_img: + pixel_values_spatial = transforms.functional.hflip( + batch["pixel_values"][:, ran_idx, :, :, :]).unsqueeze(1) + latents_spatial = tensor_to_vae_latent(pixel_values_spatial, vae) + noise_spatial = sample_noise(latents_spatial, 0.1, False) + noisy_latents_input = noise_scheduler.add_noise(latents_spatial, noise_spatial, timesteps) + target_spatial = noise_spatial + model_pred_spatial = unet(noisy_latents_input, timesteps, + encoder_hidden_states=encoder_hidden_states).sample + loss_spatial = F.mse_loss(model_pred_spatial[:, :, 0, :, :].float(), + target_spatial[:, :, 0, :, :].float(), reduction="mean") + else: + noisy_latents_input = noisy_latents[:, :, ran_idx, :, :] + target_spatial = target[:, :, ran_idx, :, :] + model_pred_spatial = unet(noisy_latents_input.unsqueeze(2), timesteps, + encoder_hidden_states=encoder_hidden_states).sample + loss_spatial = F.mse_loss(model_pred_spatial[:, :, 0, :, :].float(), + target_spatial.float(), reduction="mean") + + + model_pred = unet(noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states).sample + loss_temporal = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + + beta = 1 + alpha = (beta ** 2 + 1) ** 0.5 + ran_idx = torch.randint(0, model_pred.shape[2], (1,)).item() + model_pred_decent = alpha * model_pred - beta * model_pred[:, :, ran_idx, :, :].unsqueeze(2) + target_decent = alpha * target - beta * target[:, :, ran_idx, :, :].unsqueeze(2) + loss_ad_temporal = F.mse_loss(model_pred_decent.float(), target_decent.float(), reduction="mean") + loss_temporal = loss_temporal + loss_ad_temporal + + avg_loss_temporal = accelerator.gather(loss_temporal.repeat(config.train.train_batch_size)).mean() + train_loss_temporal += avg_loss_temporal.item() / config.train.gradient_accumulation_steps + + if not mask_spatial_lora: + accelerator.backward(loss_spatial, retain_graph=True) + if spatial_lora_num == 1: + optimizers[1].step() + else: + optimizers[step+1].step() + + 
accelerator.backward(loss_temporal) + optimizers[0].step() + + if spatial_lora_num == 1: + lr_schedulers[1].step() + else: + lr_schedulers[1 + step].step() + + lr_schedulers[0].step() + + return loss_temporal, train_loss_temporal \ No newline at end of file diff --git a/main.ipynb b/main.ipynb deleted file mode 100644 index 283623b..0000000 --- a/main.ipynb +++ /dev/null @@ -1,1220 +0,0 @@ -{ - "cells": [ - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# 1. Prepare model" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "The config attributes {'motion_activation_fn': 'geglu', 'motion_attention_bias': False, 'motion_cross_attention_dim': None} were passed to MotionAdapter, but are not expected and will be ignored. Please verify your config.json configuration file.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "http_proxy: http://oversea-squid5.sgp.txyun:11080\n", - "https_proxy: http://oversea-squid5.sgp.txyun:11080\n", - "no_proxy: localhost,127.0.0.1,localaddress,localdomain.com,internal,corp.kuaishou.com,test.gifshow.com,staging.kuaishou.com\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Loading pipeline components...: 100%|██████████| 5/5 [00:33<00:00, 6.61s/it]\n" - ] - } - ], - "source": [ - "import os\n", - "import torch\n", - "from pipelines.pipeline_animatediff import *\n", - "from diffusers.schedulers import DDIMInverseScheduler\n", - "from diffusers.utils import export_to_gif, export_to_video, load_image\n", - "from utils.attn_utils import *\n", - "\n", - "# Set proxy environment variables\n", - "os.environ['http_proxy'] = 'http://oversea-squid5.sgp.txyun:11080'\n", - "os.environ['https_proxy'] = 'http://oversea-squid5.sgp.txyun:11080'\n", - "os.environ['no_proxy'] = 'localhost,127.0.0.1,localaddress,localdomain.com,internal,corp.kuaishou.com,test.gifshow.com,staging.kuaishou.com'\n", - "\n", - "# Verify the setting\n", - "print(\"http_proxy:\", os.environ.get('http_proxy'))\n", - "print(\"https_proxy:\", os.environ.get('https_proxy'))\n", - "print(\"no_proxy:\", os.environ.get('no_proxy'))\n", - "\n", - "\n", - "# Load the motion adapter\n", - "adapter = MotionAdapter.from_pretrained(\"/home/wangluozhou/projects/AnimateDiff/models/Motion_Module/animatediff-motion-adapter-v1-5-2\", torch_dtype=torch.float32)\n", - "# Load the controlnet\n", - "# controlnet = ControlNetModel.from_pretrained('/home/wangluozhou/pretrained_models/sd-controlnet-depth', torch_dtype=torch.float16)\n", - "# load SD 1.5 based finetuned model\n", - "model_id = \"/home/wangluozhou/pretrained_models/zeroscope_v2_576w\"\n", - "pipe = VideoDiffPipeline.from_pretrained(\n", - " model_id, \n", - " motion_adapter=None, \n", - " controlnet=None, \n", - " use_motion_mid_block=True,\n", - " torch_dtype=torch.float32\n", - ")\n", - "\n", - "pipe.scheduler = DDIMScheduler.from_pretrained(model_id, subfolder='scheduler')\n", - "device = torch.device('cuda')\n", - "\n", - "# enable memory savings\n", - "pipe.enable_vae_slicing()\n", - "\n", - "pipe.enable_model_cpu_offload()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "import torch\n", - "from pipelines.pipeline_animatediff import *\n", - "from diffusers.schedulers import DDIMInverseScheduler\n", - "from diffusers.utils import export_to_gif, export_to_video, load_image\n", - 
"from utils.attn_utils import *\n", - "\n", - "model_id = \"/home/wangluozhou/pretrained_models/Realistic_Vision_V6.0_B1_noVAE\"\n", - "\n", - "tokenizer = CLIPTokenizer.from_pretrained(\n", - " model_id, subfolder=\"tokenizer\", revision=None)\n", - " \n", - "motion_adapter = MotionAdapter.from_pretrained(\n", - " \"/home/wangluozhou/projects/AnimateDiff/models/Motion_Module/animatediff-motion-adapter-v1-5-2\",\n", - " variant=\"fp16\",\n", - " torch_dtype=torch.float16\n", - ")\n", - "text_encoder = CLIPTextModel.from_pretrained(\n", - " model_id, subfolder=\"text_encoder\", revision=None\n", - ")\n", - "vae = AutoencoderKL.from_pretrained(\n", - " model_id, subfolder=\"vae\", revision=None)\n", - "\n", - "# unet = UNet2DConditionModel.from_pretrained(\n", - "# model_id,\n", - "# subfolder=\"unet\",\n", - "# low_cpu_mem_usage=True,\n", - "# )\n", - "unet = UNetMotionModel.from_unet2d(UNet2DConditionModel.from_pretrained(\n", - " model_id,\n", - " subfolder=\"unet\",\n", - " low_cpu_mem_usage=True,\n", - "), motion_adapter)\n", - "\n", - "\n", - "pipe = AnimateDiffPipeline.from_pretrained(\n", - " model_id, \n", - " motion_adapter=None, \n", - " controlnet=None, \n", - " use_motion_mid_block=True,\n", - " use_safetensors=True,\n", - " torch_dtype=torch.float16)\n", - "pipe.unet = unet.to(device='cuda',dtype=torch.float16)\n", - "# pipe.scheduler = DDIMScheduler.from_pretrained(model_id, subfolder='scheduler')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "pipe.unet.config['use_motion_mid_block']" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# 2. Text-to-Video Generation" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "output = pipe(\n", - " prompt=(\n", - " # \"a man\"\n", - " \"orange sky, warm lighting, fishing boats, ocean waves seagulls, \"\n", - " \"rippling water, wharf, silhouette, serene atmosphere, dusk, evening glow, \"\n", - " \"golden hour, coastal landscape, seaside scenery\"\n", - " ),\n", - " negative_prompt=\"bad quality, worse quality\",\n", - " height=256,\n", - " width=256,\n", - " num_frames=16,\n", - " guidance_scale=7.5,\n", - " num_inference_steps=25,\n", - " generator=torch.Generator(\"cpu\").manual_seed(42),\n", - ")\n", - "frames = output.frames[0]\n", - "export_to_gif(frames, \"animation_16.gif\")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# 3. Video Editing" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 3.1 Load Source Video" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# load video and into frames\n", - "frames = load_video('/home/wangluozhou/projects/VideoDiffusion_Playground/resources/locomotive_run.mp4')\n", - "\n", - "# 1. 
encode frames into batch of latents\n", - "latents_frames = pipe.encode_frames(frames, device=device)" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 3.1.1 Inverse with noise" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "inv_latent, _ = pipe.add_noise_to_latents(\n", - " init_latents=latents_frames, \n", - " strength=0.8,\n", - " generator=torch.Generator(\"cpu\").manual_seed(42),\n", - " num_inference_steps=25,\n", - ")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 3.1.2 Inverse with DDIM" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "pipe.scheduler = DDIMInverseScheduler.from_pretrained(model_id, subfolder='scheduler')\n", - "inv_latent = pipe(\n", - " prompt=\"\", \n", - " negative_prompt=\"\",\n", - " num_frames=16,\n", - " guidance_scale=7.5,\n", - " output_type='latent', \n", - " num_inference_steps=25,\n", - " strength=0.8, \n", - " latents=latents_frames,\n", - " inverse=True,\n", - " ).frames" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 3.2 Generation" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "pipe.scheduler = DDIMScheduler.from_pretrained(model_id, subfolder='scheduler')\n", - "output = pipe(\n", - " prompt=\"a pretty girl, white singlet, dark pants, on the stage\",\n", - " negative_prompt=\"\",\n", - " num_frames=16,\n", - " guidance_scale=7.5,\n", - " num_inference_steps=25,\n", - " latents=inv_latent,\n", - " # frames=frames_controlnet,\n", - " strength=0.8,\n", - " # generator=torch.Generator(\"cpu\").manual_seed(42),\n", - ")\n", - "frames = output.frames[0]\n", - "export_to_gif(frames, \"/home/wangluozhou/projects/VideoDiffusion_Playground/resources/Human/sample_3_edit.gif\")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# 4. 
Text-to-Video Generation with ControlNets" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 4.1 Load Control Signal" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from controlnet_aux.processor import Processor\n", - "processor = Processor(\"depth_midas\")\n", - "\n", - "# load video and into frames\n", - "frames = load_video('/home/wangluozhou/projects/VideoDiffusion_Playground/resources/Animals/sample_0_src.mp4')\n", - "\n", - "frames_controlnet = []\n", - "for frame in frames:\n", - " frames_controlnet.append(processor(frame, to_pil=True))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "frames[0].size" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 4.2 Generation" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "pipe.scheduler = DDIMScheduler.from_pretrained(model_id, subfolder='scheduler')\n", - "output = pipe(\n", - " prompt=\"a sea lion, lying on the ice, winter, snow\",\n", - " negative_prompt=\"\",\n", - " num_frames=16,\n", - " height=320,\n", - " width=512,\n", - " guidance_scale=7.5,\n", - " num_inference_steps=25,\n", - " frames_controlnet=frames_controlnet,\n", - " strength=1.0,\n", - " generator=torch.Generator(\"cpu\").manual_seed(42),\n", - ")\n", - "frames = output.frames[0]\n", - "export_to_gif(frames, \"/home/wangluozhou/projects/VideoDiffusion_Playground/resources/Animals/sample_0_edit_2.gif\")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# 5. Image-to-Video Generation" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def build_curve_tensor(max_value, min_value, length, frames, strategy='linear'):\n", - " \"\"\"\n", - " Build a curve based on the given strategy and return it as a PyTorch tensor.\n", - " The curve starts from the min_value and increases to the max_value.\n", - "\n", - " Parameters:\n", - " max_value (float): The maximum value of the curve.\n", - " min_value (float): The minimum value of the curve.\n", - " length (int): The length over which the curve changes from min to max.\n", - " frames (int): The total number of frames in the curve.\n", - " strategy (str): The strategy for building the curve. 
Options: 'linear', 'exponential', 'logarithmic'.\n", - "\n", - " Returns:\n", - " torch.Tensor: A tensor representing the curve.\n", - " \"\"\"\n", - "\n", - " if strategy == 'linear':\n", - " # Linear increase from min_value to max_value over 'length' frames, then constant\n", - " curve = np.linspace(max_value, min_value, length)\n", - " curve = np.pad(curve, (0, frames - length), mode='constant', constant_values=min_value)\n", - "\n", - " elif strategy == 'exponential':\n", - " # Exponential increase from min_value to max_value\n", - " curve = np.geomspace(max_value, min_value, length)\n", - " curve = np.pad(curve, (0, frames - length), mode='constant', constant_values=min_value)\n", - "\n", - " elif strategy == 'logarithmic':\n", - " # Logarithmic increase from min_value to max_value\n", - " log_space = np.linspace(1, length + 1, length)\n", - " curve = (np.log(log_space) / np.log(length + 1)) * (min_value - max_value) + min_value\n", - " curve = np.pad(curve, (0, frames - length), mode='constant', constant_values=min_value)\n", - "\n", - " else:\n", - " raise ValueError(\"Unknown strategy: Choose from 'linear', 'exponential', 'logarithmic'\")\n", - "\n", - " # Convert the numpy array to a PyTorch tensor\n", - " return torch.from_numpy(curve)\n", - "\n", - "# Example usage with reversed curve\n", - "# curve_tensor_reversed = build_curve_tensor_reversed(1, 0.5, 3, 16, strategy='linear')\n", - "# curve_tensor_reversed # Display the generated tensor curve\n", - "# Result\n", - "# tensor([0.5000, 0.7500, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,\n", - "# 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],\n", - "# dtype=torch.float64)" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 5.1 load source image" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# from video\n", - "frames = load_video('/home/wangluozhou/projects/VideoDiffusion_Playground/resources/sample_5_src.mp4')\n", - "frames_inpaint = [frames[0]] * 16\n", - "latents_frames_inpaint = pipe.encode_frames(frames_inpaint, device=device)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# from image\n", - "frames = load_image('/home/wangluozhou/projects/VideoDiffusion_Playground/resources/4.png')\n", - "frames_inpaint = [frames] * 16\n", - "latents_frames_inpaint = pipe.encode_frames(frames_inpaint, device=device)" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 5.2 Prepare inputs" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# prepare inital latents\n", - "latents = pipe.prepare_latents(\n", - " batch_size=1,\n", - " num_channels_latents=4,\n", - " num_frames=16,\n", - " height=frames_inpaint[0].size[1],\n", - " width=frames_inpaint[0].size[0],\n", - " dtype=torch.float16,\n", - " device=device,\n", - " generator=torch.Generator(\"cpu\").manual_seed(42)\n", - ")\n", - "\n", - "mask_inpaint = torch.zeros_like(latents_frames_inpaint)\n", - "\n", - "# # Values to assign along the frames dimension\n", - "# frame_values = build_curve_tensor(\n", - "# max_value=1.0,\n", - "# min_value=0.5,\n", - "# length=8,\n", - "# frames=16,\n", - "# )\n", - "# frame_values[-1]=1\n", - "\n", - "mask_inpaint[:,:,0,:,:]=1\n", - "mask_inpaint[:,:,-1,:,:]=1\n", - "\n", - "# # Assign the values to each frame in the mask\n", - "# for i, 
value in enumerate(frame_values):\n", - "# mask_inpaint[:, :, i, :, :] = value" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 5.3 generation" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "pipe.scheduler = DDIMScheduler.from_pretrained(model_id, subfolder='scheduler')\n", - "output = pipe(\n", - " prompt=\"Sunny seaside with blue sky\",\n", - " negative_prompt=\"\",\n", - " num_frames=16,\n", - " guidance_scale=7.5,\n", - " num_inference_steps=25,\n", - " latents=latents,\n", - " frames_inpaint=latents_frames_inpaint,\n", - " noise_inpaint=latents,\n", - " mask_inpaint=mask_inpaint,\n", - " strength=1.0,\n", - " generator=torch.Generator(\"cpu\").manual_seed(42),\n", - ")\n", - "frames = output.frames[0]\n", - "export_to_gif(frames, \"sample_4_animation.gif\")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# 6. Image Animation - Noise Rectification" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# from image\n", - "frames = load_image('/home/wangluozhou/projects/VideoDiffusion_Playground/resources/4.png')\n", - "frames_inpaint = [frames] * 16\n", - "latents_frames_inpaint = pipe.encode_frames(frames_inpaint, device=device)\n", - "\n", - "generator = torch.Generator(\"cpu\").manual_seed(42)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "inv_latent, init_noise = pipe.add_noise_to_latents(\n", - " init_latents=latents_frames_inpaint, \n", - " strength=1.0,\n", - " generator=generator,\n", - " num_inference_steps=25,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "mask_inpaint = torch.ones_like(inv_latent)\n", - "\n", - "# Values to assign along the frames dimension\n", - "frame_values, curves = build_curve_tensor(\n", - " max_value=1.0,\n", - " min_value=0.5,\n", - " length=8,\n", - " frames=16,\n", - ")\n", - "\n", - "plt.show()\n", - "# Assign the values to each frame in the mask\n", - "for i, value in enumerate(frame_values):\n", - " mask_inpaint[:, :, i, :, :] = value" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "pipe.scheduler = DDIMScheduler.from_pretrained(model_id, subfolder='scheduler')\n", - "output = pipe(\n", - " prompt=\"Sunny seaside with blue sky\",\n", - " negative_prompt=\"\",\n", - " num_frames=16,\n", - " guidance_scale=7.5,\n", - " num_inference_steps=25,\n", - " # rect_scheduled_sampling_beta=0.6,\n", - " latents=inv_latent,\n", - " # noise_rect=init_noise,\n", - " # mask_inpaint=mask_inpaint,\n", - " strength=1.0,\n", - " generator=generator,\n", - ")\n", - "frames = output.frames[0]\n", - "export_to_gif(frames, \"sample_4_animation_noise1.0.gif\")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# 7. 
Video Outpainting" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 7.1 Load source video" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# from video\n", - "frames_inpaint = load_video('/home/wangluozhou/projects/VideoDiffusion_Playground/resources/Outpainting/sample_7_src.mp4')\n", - "\n", - "latents_frames_inpaint = pipe.encode_frames(frames_inpaint, device=device)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "frames_inpaint[0].size[1]" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 7.2 Prepare noise and mask" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# prepare inital latents\n", - "latents = pipe.prepare_latents(\n", - " batch_size=1,\n", - " num_channels_latents=4,\n", - " num_frames=16,\n", - " height=frames_inpaint[0].size[1],\n", - " width=frames_inpaint[0].size[0],\n", - " dtype=torch.float16,\n", - " device=device,\n", - " generator=torch.Generator(\"cpu\").manual_seed(42)\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# w/o motion adapter\n", - "\n", - "# [bs, channels, frames, height, width] -> -> [bs * frames, channels, height, width]\n", - "frames_inpaint = frames_inpaint.permute(0,2,1,3,4).reshape((latents.shape[0] * num_frames, -1) + frames_inpaint.shape[3:])\n", - "\n", - "# [bs * frames, channels, height, width]\n", - "mask_inpaint = torch.zeros_like(frames_inpaint)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# with motion adapter\n", - "# [bs, channels frames, height, width]\n", - "\n", - "mask_inpaint = torch.ones_like(latents_frames_inpaint)\n", - "mask_inpaint[:, :, :, mask_inpaint.shape[3]//4:mask_inpaint.shape[3]//4 * 3, :] = 0" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 7.3 Generation" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "pipe.scheduler = DDIMScheduler.from_pretrained(model_id, subfolder='scheduler')\n", - "output = pipe(\n", - " prompt=\"a pretty girl, grey t-shirt\",\n", - " negative_prompt=\"\",\n", - " num_frames=16,\n", - " guidance_scale=7.5,\n", - " num_inference_steps=25,\n", - " latents=latents,\n", - " frames_inpaint=latents_frames_inpaint,\n", - " noise_inpaint=latents,\n", - " mask_inpaint=mask_inpaint,\n", - " strength=1.0,\n", - " generator=torch.Generator(\"cpu\").manual_seed(42),\n", - ")\n", - "frames = output.frames[0]\n", - "export_to_gif(frames, \"/home/wangluozhou/projects/VideoDiffusion_Playground/resources/Outpainting/sample_7_edit.gif\")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# 8. 
Frames Attention Analysis" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 8.1 Prepare Inputs" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "latents = pipe.prepare_latents(\n", - " batch_size=1,\n", - " num_channels_latents=4,\n", - " num_frames=16,\n", - " height=512,\n", - " width=512,\n", - " dtype=torch.float16,\n", - " device=device,\n", - " generator=torch.Generator(\"cpu\").manual_seed(42)\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "latents_2 = pipe.prepare_latents(\n", - " batch_size=1,\n", - " num_channels_latents=4,\n", - " num_frames=8,\n", - " height=512,\n", - " width=512,\n", - " dtype=torch.float16,\n", - " device=device,\n", - " generator=torch.Generator(\"cpu\").manual_seed(0)\n", - ")\n", - "latents = torch.cat([latents, latents_2], dim=2)" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 8.2 Prepare Attention Controller" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "controller = AttentionStore()\n", - "register_attention_control(pipe, controller=controller)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "controller.target_keys = ('down_self',)\n", - "controller.target_resolutions = [16]\n", - "# prompts = [\"a man is surfing\", \"a cat is climbing\", \"a dog is running\"]" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 8.3 Generation" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# controller.reset()\n", - "output = pipe(\n", - " prompt=[\n", - " (\n", - " \"a spiderman is surfing\"\n", - " ),\n", - " # (\n", - " # \"masterpiece, bestquality, highlydetailed, ultradetailed, sunset, \"\n", - " # \"orange sky, warm lighting, fishing boats, ocean waves seagulls, \"\n", - " # \"rippling water, wharf, silhouette, serene atmosphere, dusk, evening glow, \"\n", - " # \"golden hour, coastal landscape, seaside scenery\"\n", - " # )\n", - " # (\n", - " # \"a man is surfing\"\n", - " # )\n", - " ],\n", - " negative_prompt=[\n", - " \"bad quality, worse quality\",\n", - " # \"bad quality, worse quality\"\n", - " ],\n", - " guidance_scale=7.5,\n", - " num_inference_steps=25,\n", - " latents=latents\n", - ")\n", - "# frames = output.frames[1]\n", - "# export_to_gif(frames, \"animation_24_animatediff.gif\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "export_to_gif(output.frames[1], \"outputs/animation_16_ad_seed42_bs1_attn_down_16_32.gif\")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 8.4 Visualization" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "controller.attention_store['down_self'][0].shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "compute_average_map(controller.attention_store['down_self'], frames=16, pixel_size=5, reduction='spatial')[1]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - 
"metadata": {}, - "outputs": [], - "source": [ - "build_image_grid(controller.attention_store['up_self'], frames=16, pixel_size=20)" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 8.5 Batch Run" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from itertools import chain, combinations\n", - "import random\n", - "\n", - "def all_combinations(lst):\n", - " return chain(*map(lambda x: combinations(lst, x), range(0, len(lst) + 1)))\n", - "\n", - "controller_target_keys = ['down_self', 'mid_self', 'up_self']\n", - "controller_target_resolutions = [16, 32, 64, 128]\n", - "prompts = [\"a man is surfing\", \"a cat is climbing\", \"a dog is running\"]\n", - "\n", - "key_combinations = list(all_combinations(controller_target_keys))\n", - "resolution_combinations = [controller_target_resolutions[:i + 1] for i in range(len(controller_target_resolutions))]\n", - "\n", - "parameter_combinations = []\n", - "for keys in key_combinations:\n", - " if not keys:\n", - " resolutions_combinations = [[]] # Skip resolution combinations if no key is selected\n", - " else:\n", - " resolutions_combinations = resolution_combinations\n", - "\n", - " for resolutions in resolutions_combinations:\n", - " for prompt in prompts:\n", - " combination = (keys, resolutions, prompt)\n", - " parameter_combinations.append(combination)\n", - "\n", - "def generate_name(combination):\n", - " keys, resolutions, prompt = combination\n", - " keys_name = '_'.join(keys) if keys else 'nokey'\n", - " resolutions_name = '_'.join(map(str, resolutions)) if keys else 'noresolution'\n", - " prompt_name = prompt.replace(' ', '_')\n", - " return f\"{keys_name}_{resolutions_name}_{prompt_name}\"\n", - "\n", - "# output_names = [generate_name(combination) for combination in parameter_combinations]\n", - "\n", - "# # Example output names\n", - "# print(output_names[:5]) # Displaying first 5 names for brevity\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "for combination in parameter_combinations:\n", - " (keys, resolutions, prompt) = combination\n", - " controller.target_keys = keys\n", - " controller.target_resolutions = resolutions\n", - " \n", - "\n", - " output = pipe(\n", - " prompt=[\n", - " (\n", - " \"a spiderman is surfing\"\n", - " ),\n", - " # (\n", - " # \"masterpiece, bestquality, highlydetailed, ultradetailed, sunset, \"\n", - " # \"orange sky, warm lighting, fishing boats, ocean waves seagulls, \"\n", - " # \"rippling water, wharf, silhouette, serene atmosphere, dusk, evening glow, \"\n", - " # \"golden hour, coastal landscape, seaside scenery\"\n", - " # )\n", - " (\n", - " prompt\n", - " )\n", - " ],\n", - " negative_prompt=[\n", - " \"bad quality, worse quality\",\n", - " \"bad quality, worse quality\"\n", - " ],\n", - " guidance_scale=7.5,\n", - " num_inference_steps=25,\n", - " latents=latents\n", - " )\n", - " export_to_gif(output.frames[1], f\"outputs/{generate_name(combination)}.gif\")\n", - " controller.reset()" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# 9. 
PE Inversion Testing" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [], - "source": [ - "from utils.pe_utils import *" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "replace_positional_embedding_unet3d(pipe.unet, target_size=[320, 640, 1280], target_module=['down','mid','up'])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "pe_path = '/home/wangluozhou/projects/VideoDiffusion_Playground/outputs/size_1280_unet3d/pos_embed.pt'\n", - "load_positional_embedding(pipe.unet, pe_path)" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/wangluozhou/projects/VideoDiffusion_Playground/pipelines/pipeline_animatediff.py:743: FutureWarning: Accessing config attribute `in_channels` directly via 'UNet3DConditionModel' object attribute is deprecated. Please access 'in_channels' over 'UNet3DConditionModel's config object instead, e.g. 'unet.config.in_channels'.\n", - " num_channels_latents = self.unet.in_channels\n" - ] - } - ], - "source": [ - "pipe.init_filter(\n", - " video_length=16,\n", - " height=320,\n", - " width=512\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [], - "source": [ - "src_frames = load_video('/home/wangluozhou/projects/VideoDiffusion_Playground/resources/Objects/sample_0_src.mp4')" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [], - "source": [ - "latents_frames = pipe.encode_frames(src_frames, device=device)" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([1, 4, 16, 40, 64])" - ] - }, - "execution_count": 15, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "latents_frames.shape" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|██████████| 50/50 [01:05<00:00, 1.32s/it]\n" - ] - }, - { - "data": { - "text/plain": [ - "'/home/wangluozhou/projects/VideoDiffusion_Playground/outputs/test/animation.gif'" - ] - }, - "execution_count": 16, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "output = pipe(\n", - " prompt=(\n", - " \"Amazing quality, masterpiece, a man rides a bicycle in the snow field\"\n", - " # \"orange sky, warm lighting, fishing boats, ocean waves seagulls, \"\n", - " # \"rippling water, wharf, silhouette, serene atmosphere, dusk, evening glow, \"\n", - " # \"golden hour, coastal landscape, seaside scenery\"\n", - " ),\n", - " negative_prompt=\"bad quality, distortions, unrealistic, distorted image, watermark, signature\",\n", - " height=320,\n", - " width=512,\n", - " num_frames=16,\n", - " guidance_scale=10,\n", - " num_inference_steps=50,\n", - " generator=torch.Generator(\"cuda\").manual_seed(0),\n", - " freeinit=True,\n", - " frames_video=latents_frames,\n", - ")\n", - "frames = output.frames[0]\n", - "export_to_gif(frames, \"/home/wangluozhou/projects/VideoDiffusion_Playground/outputs/test/animation.gif\")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# 10. 
Mischelleos" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "import torch\n", - "from pipelines.pipeline_animatediff import *\n", - "from diffusers.schedulers import DDIMInverseScheduler\n", - "from diffusers.utils import export_to_gif, export_to_video, load_image\n", - "from utils.attn_utils import *" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "frames = load_video('/home/wangluozhou/projects/VideoDiffusion_Playground/resources/dog_jump_water.mp4')\n", - "len(frames)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "for idx, frame in enumerate(frames):\n", - " frame.save(os.path.join('/home/wangluozhou/projects/diffusion-motion-transfer/data/car',f'{str(idx).zfill(4)}.png'))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "llava", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.13" - }, - "orig_nbformat": 4 - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/models/dit/latte_t2v.py b/models/dit/latte_t2v.py new file mode 100644 index 0000000..fc96a30 --- /dev/null +++ b/models/dit/latte_t2v.py @@ -0,0 +1,990 @@ +import torch + +import os +import json + +from dataclasses import dataclass +from einops import rearrange, repeat +from typing import Any, Dict, Optional, Tuple +from diffusers.models import Transformer2DModel +from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate +from diffusers.models.embeddings import get_1d_sincos_pos_embed_from_grid, ImagePositionalEmbeddings, CaptionProjection, PatchEmbed, CombinedTimestepSizeEmbeddings +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.attention import BasicTransformerBlock +from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear +from diffusers.utils.torch_utils import maybe_allow_in_graph +from diffusers.models.embeddings import SinusoidalPositionalEmbedding +from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormZero +from diffusers.models.attention_processor import Attention +from diffusers.models.activations import GEGLU, GELU, ApproximateGELU + +from dataclasses import dataclass + +import torch +import torch.nn.functional as F +from torch import nn + +@maybe_allow_in_graph +class GatedSelfAttentionDense(nn.Module): + r""" + A gated self-attention dense layer that combines visual features and object features. + + Parameters: + query_dim (`int`): The number of channels in the query. + context_dim (`int`): The number of channels in the context. + n_heads (`int`): The number of heads to use for attention. + d_head (`int`): The number of channels in each head. 
+ """ + + def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int): + super().__init__() + + # we need a linear projection since we need cat visual feature and obj feature + self.linear = nn.Linear(context_dim, query_dim) + + self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head) + self.ff = FeedForward(query_dim, activation_fn="geglu") + + self.norm1 = nn.LayerNorm(query_dim) + self.norm2 = nn.LayerNorm(query_dim) + + self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0))) + self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0))) + + self.enabled = True + + def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor: + if not self.enabled: + return x + + n_visual = x.shape[1] + objs = self.linear(objs) + + x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :] + x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x)) + + return x + +class FeedForward(nn.Module): + r""" + A feed-forward layer. + + Parameters: + dim (`int`): The number of channels in the input. + dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. + mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. + """ + + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + mult: int = 4, + dropout: float = 0.0, + activation_fn: str = "geglu", + final_dropout: bool = False, + ): + super().__init__() + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear + + if activation_fn == "gelu": + act_fn = GELU(dim, inner_dim) + if activation_fn == "gelu-approximate": + act_fn = GELU(dim, inner_dim, approximate="tanh") + elif activation_fn == "geglu": + act_fn = GEGLU(dim, inner_dim) + elif activation_fn == "geglu-approximate": + act_fn = ApproximateGELU(dim, inner_dim) + + self.net = nn.ModuleList([]) + # project in + self.net.append(act_fn) + # project dropout + self.net.append(nn.Dropout(dropout)) + # project out + self.net.append(linear_cls(inner_dim, dim_out)) + # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout + if final_dropout: + self.net.append(nn.Dropout(dropout)) + + def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor: + compatible_cls = (GEGLU,) if USE_PEFT_BACKEND else (GEGLU, LoRACompatibleLinear) + for module in self.net: + if isinstance(module, compatible_cls): + hidden_states = module(hidden_states, scale) + else: + hidden_states = module(hidden_states) + return hidden_states + +@maybe_allow_in_graph +class BasicTransformerBlock_(nn.Module): + r""" + A basic Transformer block. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. 
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm (: + obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used. + double_self_attention (`bool`, *optional*): + Whether to use two self-attention layers. In this case no cross attention layers are used. + upcast_attention (`bool`, *optional*): + Whether to upcast the attention computation to float32. This is useful for mixed precision training. + norm_elementwise_affine (`bool`, *optional*, defaults to `True`): + Whether to use learnable elementwise affine parameters for normalization. + norm_type (`str`, *optional*, defaults to `"layer_norm"`): + The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`. + final_dropout (`bool` *optional*, defaults to False): + Whether to apply a final dropout after the last feed-forward layer. + attention_type (`str`, *optional*, defaults to `"default"`): + The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`. + positional_embeddings (`str`, *optional*, defaults to `None`): + The type of positional embeddings to apply to. + num_positional_embeddings (`int`, *optional*, defaults to `None`): + The maximum number of positional embeddings to apply. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single' + norm_eps: float = 1e-5, + final_dropout: bool = False, + attention_type: str = "default", + positional_embeddings: Optional[str] = None, + num_positional_embeddings: Optional[int] = None, + ): + super().__init__() + self.only_cross_attention = only_cross_attention + + self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" + self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" + self.use_ada_layer_norm_single = norm_type == "ada_norm_single" + self.use_layer_norm = norm_type == "layer_norm" + + if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: + raise ValueError( + f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" + f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." + ) + + if positional_embeddings and (num_positional_embeddings is None): + raise ValueError( + "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined." + ) + + if positional_embeddings == "sinusoidal": + self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings) + else: + self.pos_embed = None + + # Define 3 blocks. Each block has its own normalization layer. + # 1. 
Self-Attn + if self.use_ada_layer_norm: + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif self.use_ada_layer_norm_zero: + self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) + else: + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + ) + + # # 2. Cross-Attn + # if cross_attention_dim is not None or double_self_attention: + # # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. + # # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during + # # the second cross attention block. + # self.norm2 = ( + # AdaLayerNorm(dim, num_embeds_ada_norm) + # if self.use_ada_layer_norm + # else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + # ) + # self.attn2 = Attention( + # query_dim=dim, + # cross_attention_dim=cross_attention_dim if not double_self_attention else None, + # heads=num_attention_heads, + # dim_head=attention_head_dim, + # dropout=dropout, + # bias=attention_bias, + # upcast_attention=upcast_attention, + # ) # is self-attn if encoder_hidden_states is none + # else: + # self.norm2 = None + # self.attn2 = None + + # 3. Feed-forward + # if not self.use_ada_layer_norm_single: + # self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + + self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout) + + # 4. Fuser + if attention_type == "gated" or attention_type == "gated-text-image": + self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim) + + # 5. Scale-shift for PixArt-Alpha. + if self.use_ada_layer_norm_single: + self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5) + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + timestep: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[torch.LongTensor] = None, + ) -> torch.FloatTensor: + # Notice that normalization is always applied before the real computation in the following blocks. + # 0. 
Self-Attention + batch_size = hidden_states.shape[0] + + if self.use_ada_layer_norm: + norm_hidden_states = self.norm1(hidden_states, timestep) + elif self.use_ada_layer_norm_zero: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + elif self.use_layer_norm: + norm_hidden_states = self.norm1(hidden_states) + elif self.use_ada_layer_norm_single: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) + ).chunk(6, dim=1) + norm_hidden_states = self.norm1(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa + norm_hidden_states = norm_hidden_states.squeeze(1) + else: + raise ValueError("Incorrect norm used") + + if self.pos_embed is not None: + norm_hidden_states = self.pos_embed(norm_hidden_states) + + # 1. Retrieve lora scale. + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + + # 2. Prepare GLIGEN inputs + cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} + gligen_kwargs = cross_attention_kwargs.pop("gligen", None) + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + if self.use_ada_layer_norm_zero: + attn_output = gate_msa.unsqueeze(1) * attn_output + elif self.use_ada_layer_norm_single: + attn_output = gate_msa * attn_output + + hidden_states = attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + # 2.5 GLIGEN Control + if gligen_kwargs is not None: + hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"]) + + # # 3. Cross-Attention + # if self.attn2 is not None: + # if self.use_ada_layer_norm: + # norm_hidden_states = self.norm2(hidden_states, timestep) + # elif self.use_ada_layer_norm_zero or self.use_layer_norm: + # norm_hidden_states = self.norm2(hidden_states) + # elif self.use_ada_layer_norm_single: + # # For PixArt norm2 isn't applied here: + # # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103 + # norm_hidden_states = hidden_states + # else: + # raise ValueError("Incorrect norm") + + # if self.pos_embed is not None and self.use_ada_layer_norm_single is False: + # norm_hidden_states = self.pos_embed(norm_hidden_states) + + # attn_output = self.attn2( + # norm_hidden_states, + # encoder_hidden_states=encoder_hidden_states, + # attention_mask=encoder_attention_mask, + # **cross_attention_kwargs, + # ) + # hidden_states = attn_output + hidden_states + + # 4. 
Feed-forward + # if not self.use_ada_layer_norm_single: + # norm_hidden_states = self.norm3(hidden_states) + + if self.use_ada_layer_norm_zero: + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + if self.use_ada_layer_norm_single: + # norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = self.norm3(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp + + if self._chunk_size is not None: + # "feed_forward_chunk_size" can be used to save memory + if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0: + raise ValueError( + f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." + ) + + num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size + ff_output = torch.cat( + [ + self.ff(hid_slice, scale=lora_scale) + for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim) + ], + dim=self._chunk_dim, + ) + else: + ff_output = self.ff(norm_hidden_states, scale=lora_scale) + + if self.use_ada_layer_norm_zero: + ff_output = gate_mlp.unsqueeze(1) * ff_output + elif self.use_ada_layer_norm_single: + ff_output = gate_mlp * ff_output + + hidden_states = ff_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + return hidden_states + +class AdaLayerNormSingle(nn.Module): + r""" + Norm layer adaptive layer norm single (adaLN-single). + + As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3). + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + use_additional_conditions (`bool`): To use additional conditions for normalization or not. + """ + + def __init__(self, embedding_dim: int, use_additional_conditions: bool = False): + super().__init__() + + self.emb = CombinedTimestepSizeEmbeddings( + embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions + ) + + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True) + + def forward( + self, + timestep: torch.Tensor, + added_cond_kwargs: Dict[str, torch.Tensor] = None, + batch_size: int = None, + hidden_dtype: Optional[torch.dtype] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + # No modulation happening here. + embedded_timestep = self.emb(timestep, batch_size=batch_size, hidden_dtype=hidden_dtype, resolution=None, aspect_ratio=None) + return self.linear(self.silu(embedded_timestep)), embedded_timestep + +@dataclass +class Transformer3DModelOutput(BaseOutput): + """ + The output of [`Transformer2DModel`]. + + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): + The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability + distributions for the unnoised latent pixels. + """ + + sample: torch.FloatTensor + + +class LatteT2V(ModelMixin, ConfigMixin): + _supports_gradient_checkpointing = True + + """ + A 2D Transformer model for image-like data. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. 
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + The number of channels in the input and output (specify if the input is **continuous**). + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). + This is fixed during training since it is used to learn a number of position embeddings. + num_vector_embeds (`int`, *optional*): + The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**). + Includes the class for the masked latent pixel. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): + The number of diffusion steps used during training. Pass if at least one of the norm_layers is + `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are + added to the hidden states. + + During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`. + attention_bias (`bool`, *optional*): + Configure if the `TransformerBlocks` attention should contain a bias parameter. + """ + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + num_vector_embeds: Optional[int] = None, + patch_size: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_type: str = "layer_norm", + norm_elementwise_affine: bool = True, + norm_eps: float = 1e-5, + attention_type: str = "default", + caption_channels: int = None, + video_length: int = 16, + ): + super().__init__() + self.use_linear_projection = use_linear_projection + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + self.video_length = video_length + + conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv + linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear + + # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` + # Define whether input is continuous or discrete depending on configuration + self.is_input_continuous = (in_channels is not None) and (patch_size is None) + self.is_input_vectorized = num_vector_embeds is not None + self.is_input_patches = in_channels is not None and patch_size is not None + + if norm_type == "layer_norm" and num_embeds_ada_norm is not None: + deprecation_message = ( + f"The configuration file of this model: {self.__class__} is outdated. 
`norm_type` is either not set or" + " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config." + " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect" + " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it" + " would be very nice if you could open a Pull request for the `transformer/config.json` file" + ) + deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False) + norm_type = "ada_norm" + + if self.is_input_continuous and self.is_input_vectorized: + raise ValueError( + f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make" + " sure that either `in_channels` or `num_vector_embeds` is None." + ) + elif self.is_input_vectorized and self.is_input_patches: + raise ValueError( + f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make" + " sure that either `num_vector_embeds` or `num_patches` is None." + ) + elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches: + raise ValueError( + f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:" + f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None." + ) + + # 2. Define input layers + if self.is_input_continuous: + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + if use_linear_projection: + self.proj_in = linear_cls(in_channels, inner_dim) + else: + self.proj_in = conv_cls(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + elif self.is_input_vectorized: + assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size" + assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed" + + self.height = sample_size + self.width = sample_size + self.num_vector_embeds = num_vector_embeds + self.num_latent_pixels = self.height * self.width + + self.latent_image_embedding = ImagePositionalEmbeddings( + num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width + ) + elif self.is_input_patches: + assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size" + + self.height = sample_size + self.width = sample_size + + self.patch_size = patch_size + interpolation_scale = self.config.sample_size // 64 # => 64 (= 512 pixart) has interpolation scale 1 + interpolation_scale = max(interpolation_scale, 1) + self.pos_embed = PatchEmbed( + height=sample_size, + width=sample_size, + patch_size=patch_size, + in_channels=in_channels, + embed_dim=inner_dim, + interpolation_scale=interpolation_scale, + ) + + # 3. 
Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + double_self_attention=double_self_attention, + upcast_attention=upcast_attention, + norm_type=norm_type, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + attention_type=attention_type, + ) + for d in range(num_layers) + ] + ) + + # Define temporal transformers blocks + self.temporal_transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock_( # one attention + inner_dim, + num_attention_heads, # num_attention_heads + attention_head_dim, # attention_head_dim 72 + dropout=dropout, + cross_attention_dim=None, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + double_self_attention=False, + upcast_attention=upcast_attention, + norm_type=norm_type, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + attention_type=attention_type, + ) + for d in range(num_layers) + ] + ) + + + # 4. Define output layers + self.out_channels = in_channels if out_channels is None else out_channels + if self.is_input_continuous: + # TODO: should use out_channels for continuous projections + if use_linear_projection: + self.proj_out = linear_cls(inner_dim, in_channels) + else: + self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + elif self.is_input_vectorized: + self.norm_out = nn.LayerNorm(inner_dim) + self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1) + elif self.is_input_patches and norm_type != "ada_norm_single": + self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim) + self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) + elif self.is_input_patches and norm_type == "ada_norm_single": + self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) + self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5) + self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) + + # 5. PixArt-Alpha blocks. 
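+ # NOTE: with `norm_type == "ada_norm_single"` (the PixArt-Alpha style conditioning), the
+ # timestep is embedded once by `AdaLayerNormSingle` and each transformer block derives its
+ # shift/scale/gate modulation from that shared embedding through its own `scale_shift_table`;
+ # `caption_projection` maps the text-encoder states from `caption_channels` to `inner_dim`.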
+ self.adaln_single = None + self.use_additional_conditions = False + if norm_type == "ada_norm_single": + self.use_additional_conditions = self.config.sample_size == 128 # False, 128 -> 1024 + # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use + # additional conditions until we find better name + self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions) + + self.caption_projection = None + if caption_channels is not None: + self.caption_projection = CaptionProjection(in_features=caption_channels, hidden_size=inner_dim) + + self.gradient_checkpointing = False + + # define temporal positional embedding + temp_pos_embed = self.get_1d_sincos_temp_embed(inner_dim, video_length) # 1152 hidden size + self.register_buffer("temp_pos_embed", torch.from_numpy(temp_pos_embed).float().unsqueeze(0), persistent=False) + + + def _set_gradient_checkpointing(self, module, value=False): + self.gradient_checkpointing = value + + + def forward( + self, + hidden_states: torch.Tensor, + timestep: Optional[torch.LongTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + added_cond_kwargs: Dict[str, torch.Tensor] = None, + class_labels: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + use_image_num: int = 0, + enable_temporal_attentions: bool = True, + return_dict: bool = True, + ): + """ + The [`Transformer2DModel`] forward method. + + Args: + hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, frame, channel, height, width)` if continuous): + Input `hidden_states`. + encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.LongTensor`, *optional*): + Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. + class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): + Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in + `AdaLayerZeroNorm`. + cross_attention_kwargs ( `Dict[str, Any]`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + attention_mask ( `torch.Tensor`, *optional*): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + encoder_attention_mask ( `torch.Tensor`, *optional*): + Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: + + * Mask `(batch, sequence_length)` True = keep, False = discard. + * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard. + + If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format + above. This bias will be added to the cross-attention scores. 
+ return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + input_batch_size, c, frame, h, w = hidden_states.shape + frame = frame - use_image_num + hidden_states = rearrange(hidden_states, 'b c f h w -> (b f) c h w').contiguous() + + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. + # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. + # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None and attention_mask.ndim == 2: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: # ndim == 2 means no image joint + encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + encoder_attention_mask = repeat(encoder_attention_mask, 'b 1 l -> (b f) 1 l', f=frame).contiguous() + elif encoder_attention_mask is not None and encoder_attention_mask.ndim == 3: # ndim == 3 means image joint + encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 + encoder_attention_mask_video = encoder_attention_mask[:, :1, ...] + encoder_attention_mask_video = repeat(encoder_attention_mask_video, 'b 1 l -> b (1 f) l', f=frame).contiguous() + encoder_attention_mask_image = encoder_attention_mask[:, 1:, ...] + encoder_attention_mask = torch.cat([encoder_attention_mask_video, encoder_attention_mask_image], dim=1) + encoder_attention_mask = rearrange(encoder_attention_mask, 'b n l -> (b n) l').contiguous().unsqueeze(1) + + + # Retrieve lora scale. + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + + # 1. Input + if self.is_input_patches: # here + height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size + num_patches = height * width + + hidden_states = self.pos_embed(hidden_states) # alrady add positional embeddings + + if self.adaln_single is not None: + if self.use_additional_conditions and added_cond_kwargs is None: + raise ValueError( + "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`." + ) + # batch_size = hidden_states.shape[0] + batch_size = input_batch_size + timestep, embedded_timestep = self.adaln_single( + timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype + ) + + # 2. 
Blocks + if self.caption_projection is not None: + batch_size = hidden_states.shape[0] + encoder_hidden_states = self.caption_projection(encoder_hidden_states) # 3 120 1152 + + if use_image_num != 0 and self.training: + encoder_hidden_states_video = encoder_hidden_states[:, :1, ...] + encoder_hidden_states_video = repeat(encoder_hidden_states_video, 'b 1 t d -> b (1 f) t d', f=frame).contiguous() + encoder_hidden_states_image = encoder_hidden_states[:, 1:, ...] + encoder_hidden_states = torch.cat([encoder_hidden_states_video, encoder_hidden_states_image], dim=1) + encoder_hidden_states_spatial = rearrange(encoder_hidden_states, 'b f t d -> (b f) t d').contiguous() + else: + encoder_hidden_states_spatial = repeat(encoder_hidden_states, 'b t d -> (b f) t d', f=frame).contiguous() + + # prepare timesteps for spatial and temporal block + timestep_spatial = repeat(timestep, 'b d -> (b f) d', f=frame + use_image_num).contiguous() + timestep_temp = repeat(timestep, 'b d -> (b p) d', p=num_patches).contiguous() + + for i, (spatial_block, temp_block) in enumerate(zip(self.transformer_blocks, self.temporal_transformer_blocks)): + + if self.training and self.gradient_checkpointing: + hidden_states = torch.utils.checkpoint.checkpoint( + spatial_block, + hidden_states, + attention_mask, + encoder_hidden_states_spatial, + encoder_attention_mask, + timestep_spatial, + cross_attention_kwargs, + class_labels, + use_reentrant=False, + ) + + if enable_temporal_attentions: + hidden_states = rearrange(hidden_states, '(b f) t d -> (b t) f d', b=input_batch_size).contiguous() + + if use_image_num != 0: # image-video joitn training + hidden_states_video = hidden_states[:, :frame, ...] + hidden_states_image = hidden_states[:, frame:, ...] + + if i == 0: + hidden_states_video = hidden_states_video + self.temp_pos_embed + + hidden_states_video = torch.utils.checkpoint.checkpoint( + temp_block, + hidden_states_video, + None, # attention_mask + None, # encoder_hidden_states + None, # encoder_attention_mask + timestep_temp, + cross_attention_kwargs, + class_labels, + use_reentrant=False, + ) + + hidden_states = torch.cat([hidden_states_video, hidden_states_image], dim=1) + hidden_states = rearrange(hidden_states, '(b t) f d -> (b f) t d', b=input_batch_size).contiguous() + + else: + if i == 0: + hidden_states = hidden_states + self.temp_pos_embed + + hidden_states = torch.utils.checkpoint.checkpoint( + temp_block, + hidden_states, + None, # attention_mask + None, # encoder_hidden_states + None, # encoder_attention_mask + timestep_temp, + cross_attention_kwargs, + class_labels, + use_reentrant=False, + ) + + hidden_states = rearrange(hidden_states, '(b t) f d -> (b f) t d', b=input_batch_size).contiguous() + else: + hidden_states = spatial_block( + hidden_states, + attention_mask, + encoder_hidden_states_spatial, + encoder_attention_mask, + timestep_spatial, + cross_attention_kwargs, + class_labels, + ) + + if enable_temporal_attentions: + + hidden_states = rearrange(hidden_states, '(b f) t d -> (b t) f d', b=input_batch_size).contiguous() + + if use_image_num != 0 and self.training: + hidden_states_video = hidden_states[:, :frame, ...] + hidden_states_image = hidden_states[:, frame:, ...] 
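+ # Only the video frames go through the temporal block below; the image frames used for
+ # joint image-video training bypass it and are concatenated back afterwards.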
+ + hidden_states_video = temp_block( + hidden_states_video, + None, # attention_mask + None, # encoder_hidden_states + None, # encoder_attention_mask + timestep_temp, + cross_attention_kwargs, + class_labels, + ) + + hidden_states = torch.cat([hidden_states_video, hidden_states_image], dim=1) + hidden_states = rearrange(hidden_states, '(b t) f d -> (b f) t d', b=input_batch_size).contiguous() + + else: + if i == 0: + hidden_states = hidden_states + self.temp_pos_embed + + hidden_states = temp_block( + hidden_states, + None, # attention_mask + None, # encoder_hidden_states + None, # encoder_attention_mask + timestep_temp, + cross_attention_kwargs, + class_labels, + ) + + hidden_states = rearrange(hidden_states, '(b t) f d -> (b f) t d', b=input_batch_size).contiguous() + + + if self.is_input_patches: + if self.config.norm_type != "ada_norm_single": + conditioning = self.transformer_blocks[0].norm1.emb( + timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1) + hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None] + hidden_states = self.proj_out_2(hidden_states) + elif self.config.norm_type == "ada_norm_single": + embedded_timestep = repeat(embedded_timestep, 'b d -> (b f) d', f=frame + use_image_num).contiguous() + shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1) + hidden_states = self.norm_out(hidden_states) + # Modulation + hidden_states = hidden_states * (1 + scale) + shift + hidden_states = self.proj_out(hidden_states) + + # unpatchify + if self.adaln_single is None: + height = width = int(hidden_states.shape[1] ** 0.5) + hidden_states = hidden_states.reshape( + shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels) + ) + hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) + output = hidden_states.reshape( + shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size) + ) + output = rearrange(output, '(b f) c h w -> b c f h w', b=input_batch_size).contiguous() + + if not return_dict: + return (output,) + + return Transformer3DModelOutput(sample=output) + + def get_1d_sincos_temp_embed(self, embed_dim, length): + pos = torch.arange(0, length).unsqueeze(1) + return get_1d_sincos_pos_embed_from_grid(embed_dim, pos) + + @classmethod + def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, **kwargs): + if subfolder is not None: + pretrained_model_path = os.path.join(pretrained_model_path, subfolder) + + + config_file = os.path.join(pretrained_model_path, 'config.json') + if not os.path.isfile(config_file): + raise RuntimeError(f"{config_file} does not exist") + with open(config_file, "r") as f: + config = json.load(f) + + model = cls.from_config(config, **kwargs) + + # model_files = [ + # os.path.join(pretrained_model_path, 'diffusion_pytorch_model.bin'), + # os.path.join(pretrained_model_path, 'diffusion_pytorch_model.safetensors') + # ] + + # model_file = None + + # for fp in model_files: + # if os.path.exists(fp): + # model_file = fp + + # if not model_file: + # raise RuntimeError(f"{model_file} does not exist") + + # if model_file.split(".")[-1] == "safetensors": + # from safetensors import safe_open + # state_dict = {} + # with safe_open(model_file, framework="pt", device="cpu") as f: + # for key in f.keys(): + # state_dict[key] = f.get_tensor(key) + # else: + # state_dict = torch.load(model_file, map_location="cpu") + + # for k, v in model.state_dict().items(): + # if 
'temporal_transformer_blocks' in k: + # state_dict.update({k: v}) + + # model.load_state_dict(state_dict) + + return model \ No newline at end of file diff --git a/models/unet/motion_embeddings.py b/models/unet/motion_embeddings.py new file mode 100644 index 0000000..20b4eed --- /dev/null +++ b/models/unet/motion_embeddings.py @@ -0,0 +1,122 @@ +import re +import torch +import torch.nn as nn +import torch.nn.functional as F + +class MotionEmbedding(nn.Module): + + def __init__(self, embed_dim: int = None, max_seq_length: int = 32): + super().__init__() + self.embed = nn.Parameter(torch.zeros(1, max_seq_length, embed_dim)) + self.scale = 1.0 + self.trained_length = -1 + + def set_scale(self, scale: float): + self.scale = scale + + def set_lengths(self, trained_length: int): + if trained_length > self.embed.shape[1] or trained_length <= 0: + raise ValueError("Trained length is out of bounds") + self.trained_length = trained_length + + def forward(self, x): + _, seq_length, _ = x.shape # seq_length here is the target sequence length for x + + # Assuming self.embed is [batch, frames, dim] + embeddings = self.embed[:, :seq_length] # Initial slice, may not be necessary depending on the interpolation logic + + # Check if interpolation is needed + if self.trained_length != -1 and seq_length != self.trained_length: + # Interpolate embeddings to match x's sequence length + # Ensure embeddings is [batch, dim, frames] for 1D interpolation across frames + embeddings = embeddings.permute(0, 2, 1) # Now [batch, dim, frames] + embeddings = F.interpolate(embeddings, size=(seq_length,), mode='linear', align_corners=False) + embeddings = embeddings.permute(0, 2, 1) # Revert to [batch, frames, dim] + + # Ensure the interpolated embeddings match the sequence length of x + if embeddings.shape[1] != seq_length: + raise ValueError(f"Interpolated embeddings sequence length {embeddings.shape[1]} does not match x's sequence length {seq_length}") + + # Now embeddings should have the shape [batch, seq_length, dim] matching x + x = x + embeddings * self.scale # Assuming broadcasting is desired over the batch and dim dimensions + + return x + +def inject_motion_embeddings(model, sizes=[320, 640, 1280], modules=['up','down','mid']): + replacement_dict = {} + + for name, module in model.named_modules(): + if 'temp_attention' in name and re.search(r'transformer_blocks\.\d+$', name): + replacement_dict[f'{name}.pos_embed'] = MotionEmbedding(embed_dim=module.norm1.normalized_shape[0]).to(dtype=model.dtype, device=model.device) + + for name, new_module in replacement_dict.items(): + parent_name = name.rsplit('.', 1)[0] if '.' 
in name else '' + module_name = name.rsplit('.', 1)[-1] + parent_module = model + if parent_name: + parent_module = dict(model.named_modules())[parent_name] + + if new_module.embed.shape[-1] in sizes and parent_name.split('_')[0] in modules: + setattr(parent_module, module_name, new_module) + + parameters_list = [] + for name, para in model.named_parameters(): + if 'pos_embed' in name: + parameters_list.append(para) + para.requires_grad = True + else: + para.requires_grad = False + + return parameters_list + +def save_motion_embeddings(model, file_path): + # Extract motion embedding from all instances of MotionEmbedding + motion_embeddings = { + name: module.embed + for name, module in model.named_modules() + if isinstance(module, MotionEmbedding) + } + # Save the motion embeddings to the specified file path + torch.save(motion_embeddings, file_path) + +def load_motion_embeddings(model, saved_embeddings): + for key, embedding in saved_embeddings.items(): + # Extract parent module and module name from the key + parent_name = key.rsplit('.', 1)[0] if '.' in key else '' + module_name = key.rsplit('.', 1)[-1] + + # Retrieve the parent module + parent_module = model + if parent_name: + parent_module = dict(model.named_modules())[parent_name] + + # Create a new MotionEmbedding instance with the correct dimensions + new_module = MotionEmbedding(embed_dim=embedding.shape[-1], max_seq_length=embedding.shape[-2]) + + # Properly assign the loaded embeddings to the 'embed' parameter wrapped in nn.Parameter + # Ensure the embedding is on the correct device and has the correct dtype + new_module.embed = nn.Parameter(embedding.to(dtype=model.dtype, device=model.device)) + + # Replace the corresponding module in the model with the new MotionEmbedding instance + setattr(parent_module, module_name, new_module) + +def set_motion_embedding_scale(model, scale_value): + # Iterate over all modules in the model + for _, module in model.named_modules(): + # Check if the module is an instance of MotionEmbedding + if isinstance(module, MotionEmbedding): + # Set the scale attribute to the specified value + module.scale = scale_value + +def set_motion_embedding_length(model, trained_length): + # Iterate over all modules in the model + for _, module in model.named_modules(): + # Check if the module is an instance of MotionEmbedding + if isinstance(module, MotionEmbedding): + # Set the length to the specified value + module.trained_length = trained_length + + + + + diff --git a/models/unet/unet_3d_blocks.py b/models/unet/unet_3d_blocks.py new file mode 100644 index 0000000..8e246db --- /dev/null +++ b/models/unet/unet_3d_blocks.py @@ -0,0 +1,842 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
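+#
+# NOTE: this file groups the gradient-checkpointing helpers (`g_c`, `custom_checkpoint`,
+# `transformer_g_c`, `cross_attn_g_c`, `up_down_g_c`) with the `get_down_block` /
+# `get_up_block` factories for the 3D UNet blocks (`DownBlock3D`, `CrossAttnDownBlock3D`,
+# `UpBlock3D`, `CrossAttnUpBlock3D`).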
+ +import torch +import torch.utils.checkpoint as checkpoint +from torch import nn +from diffusers.models.resnet import Downsample2D, ResnetBlock2D, TemporalConvLayer, Upsample2D +from diffusers.models.transformer_2d import Transformer2DModel +from diffusers.models.transformer_temporal import TransformerTemporalModel + +# Assign gradient checkpoint function to simple variable for readability. +g_c = checkpoint.checkpoint + +def use_temporal(module, num_frames, x): + if num_frames == 1: + if isinstance(module, TransformerTemporalModel): + return {"sample": x} + else: + return x + +def custom_checkpoint(module, mode=None): + if mode == None: raise ValueError('Mode for gradient checkpointing cannot be none.') + custom_forward = None + + if mode == 'resnet': + def custom_forward(hidden_states, temb): + inputs = module(hidden_states, temb) + return inputs + + if mode == 'attn': + def custom_forward( + hidden_states, + encoder_hidden_states=None, + cross_attention_kwargs=None + ): + inputs = module( + hidden_states, + encoder_hidden_states, + cross_attention_kwargs + ) + return inputs + + if mode == 'temp': + def custom_forward(hidden_states, num_frames=None): + inputs = use_temporal(module, num_frames, hidden_states) + if inputs is None: inputs = module( + hidden_states, + num_frames=num_frames + ) + return inputs + + return custom_forward + +def transformer_g_c(transformer, sample, num_frames): + sample = g_c(custom_checkpoint(transformer, mode='temp'), + sample, num_frames, use_reentrant=False + )['sample'] + + return sample + +def cross_attn_g_c( + attn, + temp_attn, + resnet, + temp_conv, + hidden_states, + encoder_hidden_states, + cross_attention_kwargs, + temb, + num_frames, + inverse_temp=False + ): + + def ordered_g_c(idx): + + # Self and CrossAttention + if idx == 0: return g_c(custom_checkpoint(attn, mode='attn'), + hidden_states, encoder_hidden_states,cross_attention_kwargs, use_reentrant=False + )['sample'] + + # Temporal Self and CrossAttention + if idx == 1: return g_c(custom_checkpoint(temp_attn, mode='temp'), + hidden_states, num_frames, use_reentrant=False)['sample'] + + # Resnets + if idx == 2: return g_c(custom_checkpoint(resnet, mode='resnet'), + hidden_states, temb, use_reentrant=False) + + # Temporal Convolutions + if idx == 3: return g_c(custom_checkpoint(temp_conv, mode='temp'), + hidden_states, num_frames, use_reentrant=False + ) + + # Here we call the function depending on the order in which they are called. + # For some layers, the orders are different, so we access the appropriate one by index. 
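+    # Mid blocks call this with the default order [0, 1, 2, 3] (spatial attention, temporal
+    # attention, ResNet, temporal conv); the down/up blocks pass inverse_temp=True so the
+    # ResNet/temporal-conv pair runs first, matching their non-checkpointed forward passes.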
+ + if not inverse_temp: + for idx in [0,1,2,3]: hidden_states = ordered_g_c(idx) + else: + for idx in [2,3,0,1]: hidden_states = ordered_g_c(idx) + + return hidden_states + +def up_down_g_c(resnet, temp_conv, hidden_states, temb, num_frames): + hidden_states = g_c(custom_checkpoint(resnet, mode='resnet'), hidden_states, temb, use_reentrant=False) + hidden_states = g_c(custom_checkpoint(temp_conv, mode='temp'), + hidden_states, num_frames, use_reentrant=False + ) + return hidden_states + +def get_down_block( + down_block_type, + num_layers, + in_channels, + out_channels, + temb_channels, + add_downsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + cross_attention_dim=None, + downsample_padding=None, + dual_cross_attention=False, + use_linear_projection=True, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", +): + if down_block_type == "DownBlock3D": + return DownBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "CrossAttnDownBlock3D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D") + return CrossAttnDownBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + raise ValueError(f"{down_block_type} does not exist.") + + +def get_up_block( + up_block_type, + num_layers, + in_channels, + out_channels, + prev_output_channel, + temb_channels, + add_upsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + cross_attention_dim=None, + dual_cross_attention=False, + use_linear_projection=True, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", +): + if up_block_type == "UpBlock3D": + return UpBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "CrossAttnUpBlock3D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D") + return CrossAttnUpBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + 
use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + raise ValueError(f"{up_block_type} does not exist.") + + +class UNetMidBlock3DCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + output_scale_factor=1.0, + cross_attention_dim=1280, + dual_cross_attention=False, + use_linear_projection=True, + upcast_attention=False, + ): + super().__init__() + + self.gradient_checkpointing = False + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + temp_convs = [ + TemporalConvLayer( + in_channels, + in_channels, + dropout=0.1 + ) + ] + attentions = [] + temp_attentions = [] + + for _ in range(num_layers): + attentions.append( + Transformer2DModel( + in_channels // attn_num_head_channels, + attn_num_head_channels, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + ) + temp_attentions.append( + TransformerTemporalModel( + in_channels // attn_num_head_channels, + attn_num_head_channels, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append( + TemporalConvLayer( + in_channels, + in_channels, + dropout=0.1 + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + self.attentions = nn.ModuleList(attentions) + self.temp_attentions = nn.ModuleList(temp_attentions) + + def forward( + self, + hidden_states, + temb=None, + encoder_hidden_states=None, + attention_mask=None, + num_frames=1, + cross_attention_kwargs=None, + ): + if self.gradient_checkpointing: + hidden_states = up_down_g_c( + self.resnets[0], + self.temp_convs[0], + hidden_states, + temb, + num_frames + ) + else: + hidden_states = self.resnets[0](hidden_states, temb) + hidden_states = self.temp_convs[0](hidden_states, num_frames=num_frames) + + for attn, temp_attn, resnet, temp_conv in zip( + self.attentions, self.temp_attentions, self.resnets[1:], self.temp_convs[1:] + ): + if self.gradient_checkpointing: + hidden_states = cross_attn_g_c( + attn, + temp_attn, + resnet, + temp_conv, + hidden_states, + encoder_hidden_states, + cross_attention_kwargs, + temb, + num_frames + ) + else: + hidden_states = attn( + hidden_states, + 
encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + if num_frames > 1: + hidden_states = temp_attn(hidden_states, num_frames=num_frames).sample + + hidden_states = resnet(hidden_states, temb) + + if num_frames > 1: + hidden_states = temp_conv(hidden_states, num_frames=num_frames) + + return hidden_states + + +class CrossAttnDownBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + downsample_padding=1, + add_downsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + ): + super().__init__() + resnets = [] + attentions = [] + temp_attentions = [] + temp_convs = [] + + self.gradient_checkpointing = False + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append( + TemporalConvLayer( + out_channels, + out_channels, + dropout=0.1 + ) + ) + attentions.append( + Transformer2DModel( + out_channels // attn_num_head_channels, + attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + ) + ) + temp_attentions.append( + TransformerTemporalModel( + out_channels // attn_num_head_channels, + attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + self.attentions = nn.ModuleList(attentions) + self.temp_attentions = nn.ModuleList(temp_attentions) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + def forward( + self, + hidden_states, + temb=None, + encoder_hidden_states=None, + attention_mask=None, + num_frames=1, + cross_attention_kwargs=None, + ): + # TODO(Patrick, William) - attention mask is not used + output_states = () + + for resnet, temp_conv, attn, temp_attn in zip( + self.resnets, self.temp_convs, self.attentions, self.temp_attentions + ): + + if self.gradient_checkpointing: + hidden_states = cross_attn_g_c( + attn, + temp_attn, + resnet, + temp_conv, + hidden_states, + encoder_hidden_states, + cross_attention_kwargs, + temb, + num_frames, + inverse_temp=True + ) + else: + hidden_states = resnet(hidden_states, temb) + + if num_frames > 1: + hidden_states = temp_conv(hidden_states, num_frames=num_frames) + + hidden_states = attn( + hidden_states, + 
encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + if num_frames > 1: + hidden_states = temp_attn(hidden_states, num_frames=num_frames).sample + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class DownBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_downsample=True, + downsample_padding=1, + ): + super().__init__() + resnets = [] + temp_convs = [] + + self.gradient_checkpointing = False + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append( + TemporalConvLayer( + out_channels, + out_channels, + dropout=0.1 + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + def forward(self, hidden_states, temb=None, num_frames=1): + output_states = () + + for resnet, temp_conv in zip(self.resnets, self.temp_convs): + if self.gradient_checkpointing: + hidden_states = up_down_g_c(resnet, temp_conv, hidden_states, temb, num_frames) + else: + hidden_states = resnet(hidden_states, temb) + + if num_frames > 1: + hidden_states = temp_conv(hidden_states, num_frames=num_frames) + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class CrossAttnUpBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + add_upsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + ): + super().__init__() + resnets = [] + temp_convs = [] + attentions = [] + temp_attentions = [] + + self.gradient_checkpointing = False + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + 
temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append( + TemporalConvLayer( + out_channels, + out_channels, + dropout=0.1 + ) + ) + attentions.append( + Transformer2DModel( + out_channels // attn_num_head_channels, + attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + ) + ) + temp_attentions.append( + TransformerTemporalModel( + out_channels // attn_num_head_channels, + attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + self.attentions = nn.ModuleList(attentions) + self.temp_attentions = nn.ModuleList(temp_attentions) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + def forward( + self, + hidden_states, + res_hidden_states_tuple, + temb=None, + encoder_hidden_states=None, + upsample_size=None, + attention_mask=None, + num_frames=1, + cross_attention_kwargs=None, + ): + # TODO(Patrick, William) - attention mask is not used + for resnet, temp_conv, attn, temp_attn in zip( + self.resnets, self.temp_convs, self.attentions, self.temp_attentions + ): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.gradient_checkpointing: + hidden_states = cross_attn_g_c( + attn, + temp_attn, + resnet, + temp_conv, + hidden_states, + encoder_hidden_states, + cross_attention_kwargs, + temb, + num_frames, + inverse_temp=True + ) + else: + hidden_states = resnet(hidden_states, temb) + + if num_frames > 1: + hidden_states = temp_conv(hidden_states, num_frames=num_frames) + + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + if num_frames > 1: + hidden_states = temp_attn(hidden_states, num_frames=num_frames).sample + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +class UpBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_upsample=True, + ): + super().__init__() + resnets = [] + temp_convs = [] + self.gradient_checkpointing = False + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + 
dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append( + TemporalConvLayer( + out_channels, + out_channels, + dropout=0.1 + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, num_frames=1): + for resnet, temp_conv in zip(self.resnets, self.temp_convs): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.gradient_checkpointing: + hidden_states = up_down_g_c(resnet, temp_conv, hidden_states, temb, num_frames) + else: + hidden_states = resnet(hidden_states, temb) + + if num_frames > 1: + hidden_states = temp_conv(hidden_states, num_frames=num_frames) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states diff --git a/models/unet/unet_3d_condition.py b/models/unet/unet_3d_condition.py new file mode 100644 index 0000000..d1bf3c6 --- /dev/null +++ b/models/unet/unet_3d_condition.py @@ -0,0 +1,500 @@ +# Copyright 2023 Alibaba DAMO-VILAB and The HuggingFace Team. All rights reserved. +# Copyright 2023 The ModelScope Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.utils import BaseOutput, logging +from diffusers.models.embeddings import TimestepEmbedding, Timesteps +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.transformer_temporal import TransformerTemporalModel +from .unet_3d_blocks import ( + CrossAttnDownBlock3D, + CrossAttnUpBlock3D, + DownBlock3D, + UNetMidBlock3DCrossAttn, + UpBlock3D, + get_down_block, + get_up_block, + transformer_g_c +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class UNet3DConditionOutput(BaseOutput): + """ + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`): + Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model. + """ + + sample: torch.FloatTensor + + +class UNet3DConditionModel(ModelMixin, ConfigMixin): + r""" + UNet3DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep + and returns sample shaped output. + + This model inherits from [`ModelMixin`]. 
Check the superclass documentation for the generic methods the library + implements for all the models (such as downloading or saving, etc.) + + Parameters: + sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): + Height and width of input/output sample. + in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): The number of channels in the output. + down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): + The tuple of downsample blocks to use. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`): + The tuple of upsample blocks to use. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. + downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. + mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. + If `None`, it will skip the normalization and activation layers in post-processing + norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. + cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features. + attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 4, + out_channels: int = 4, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "DownBlock3D", + ), + up_block_types: Tuple[str] = ("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D"), + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: int = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: Optional[int] = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int = 1024, + attention_head_dim: Union[int, Tuple[int]] = 64, + ): + super().__init__() + + self.sample_size = sample_size + self.gradient_checkpointing = False + # Check inputs + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." 
+ ) + + # input + conv_in_kernel = 3 + conv_out_kernel = 3 + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = nn.Conv2d( + in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + ) + + # time + time_embed_dim = block_out_channels[0] * 4 + self.time_proj = Timesteps(block_out_channels[0], True, 0) + timestep_input_dim = block_out_channels[0] + + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + ) + + self.transformer_in = TransformerTemporalModel( + num_attention_heads=8, + attention_head_dim=attention_head_dim, + in_channels=block_out_channels[0], + num_layers=1, + ) + + # class embedding + self.down_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[i], + downsample_padding=downsample_padding, + dual_cross_attention=False, + ) + self.down_blocks.append(down_block) + + # mid + self.mid_block = UNetMidBlock3DCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=False, + ) + + # count how many layers upsample the images + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_attention_head_dim = list(reversed(attention_head_dim)) + + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=layers_per_block + 1, + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=reversed_attention_head_dim[i], + dual_cross_attention=False, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + if norm_num_groups is not None: + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps + ) + self.conv_act = nn.SiLU() + else: + self.conv_norm_out = None + self.conv_act = None + + conv_out_padding = (conv_out_kernel - 1) // 2 + self.conv_out = 
nn.Conv2d(
+            block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
+        )
+
+    def set_attention_slice(self, slice_size):
+        r"""
+        Enable sliced attention computation.
+
+        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
+        in several steps. This is useful to save some memory in exchange for a small speed decrease.
+
+        Args:
+            slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
+                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
+                `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
+                provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
+                must be a multiple of `slice_size`.
+        """
+        sliceable_head_dims = []
+
+        def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
+            if hasattr(module, "set_attention_slice"):
+                sliceable_head_dims.append(module.sliceable_head_dim)
+
+            for child in module.children():
+                fn_recursive_retrieve_slicable_dims(child)
+
+        # retrieve number of attention layers
+        for module in self.children():
+            fn_recursive_retrieve_slicable_dims(module)
+
+        num_slicable_layers = len(sliceable_head_dims)
+
+        if slice_size == "auto":
+            # half the attention head size is usually a good trade-off between
+            # speed and memory
+            slice_size = [dim // 2 for dim in sliceable_head_dims]
+        elif slice_size == "max":
+            # make smallest slice possible
+            slice_size = num_slicable_layers * [1]
+
+        slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
+
+        if len(slice_size) != len(sliceable_head_dims):
+            raise ValueError(
+                f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
+                f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
+            )
+
+        for i in range(len(slice_size)):
+            size = slice_size[i]
+            dim = sliceable_head_dims[i]
+            if size is not None and size > dim:
+                raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
+
+        # Recursively walk through all the children.
+ # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def _set_gradient_checkpointing(self, value=False): + self.gradient_checkpointing = value + self.mid_block.gradient_checkpointing = value + for module in self.down_blocks + self.up_blocks: + if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)): + module.gradient_checkpointing = value + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + mid_block_additional_residual: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[UNet3DConditionOutput, Tuple]: + r""" + Args: + sample (`torch.FloatTensor`): (batch, num_frames, channel, height, width) noisy inputs tensor + timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps + encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition.UNet3DConditionOutput`] instead of a plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + + Returns: + [`~models.unet_2d_condition.UNet3DConditionOutput`] or `tuple`: + [`~models.unet_2d_condition.UNet3DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # prepare attention_mask + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. 
So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + num_frames = sample.shape[2] + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=self.dtype) + + emb = self.time_embedding(t_emb, timestep_cond) + emb = emb.repeat_interleave(repeats=num_frames, dim=0) + encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0) + + # 2. pre-process + sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:]) + sample = self.conv_in(sample) + + if num_frames > 1: + if self.gradient_checkpointing: + sample = transformer_g_c(self.transformer_in, sample, num_frames) + else: + sample = self.transformer_in(sample, num_frames=num_frames).sample + + # 3. down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb, num_frames=num_frames) + + down_block_res_samples += res_samples + + if down_block_additional_residuals is not None: + new_down_block_res_samples = () + + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = down_block_res_sample + down_block_additional_residual + new_down_block_res_samples += (down_block_res_sample,) + + down_block_res_samples = new_down_block_res_samples + + # 4. mid + if self.mid_block is not None: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + ) + + if mid_block_additional_residual is not None: + sample = sample + mid_block_additional_residual + + # 5. 
up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + upsample_size=upsample_size, + attention_mask=attention_mask, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + upsample_size=upsample_size, + num_frames=num_frames, + ) + + # 6. post-process + if self.conv_norm_out: + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + + sample = self.conv_out(sample) + + # reshape to (batch, channel, framerate, width, height) + sample = sample[None, :].reshape((-1, num_frames) + sample.shape[1:]).permute(0, 2, 1, 3, 4) + + if not return_dict: + return (sample,) + + return UNet3DConditionOutput(sample=sample) diff --git a/noise_init/__init__.py b/noise_init/__init__.py new file mode 100644 index 0000000..2afbcb7 --- /dev/null +++ b/noise_init/__init__.py @@ -0,0 +1,3 @@ +from .blend_init import initialize_noise_with_blend +from .dmt_init import initialize_noise_with_dmt +from .fft_init import initialize_noise_with_fft \ No newline at end of file diff --git a/noise_init/blend_init.py b/noise_init/blend_init.py new file mode 100644 index 0000000..7c82d5a --- /dev/null +++ b/noise_init/blend_init.py @@ -0,0 +1,20 @@ +import torch + +from utils.ddim_utils import inverse_video + +def initialize_noise_with_blend(noisy_latent, noise=None, seed=0, noise_prior=0.5): + + # noisy_latent = inverse_video(pipe, latents, 50) + shape = noisy_latent.shape + if noise is None: + noise = torch.randn( + shape, + device=noisy_latent.device, + generator=torch.Generator(noisy_latent.device).manual_seed(seed) + ).to(noisy_latent.dtype) + + + latents = (noise_prior) ** 0.5 * noisy_latent + ( + 1-noise_prior) ** 0.5 * noise + + return latents \ No newline at end of file diff --git a/noise_init/dmt_init.py b/noise_init/dmt_init.py new file mode 100644 index 0000000..7fc7861 --- /dev/null +++ b/noise_init/dmt_init.py @@ -0,0 +1,47 @@ +import math + +import torch +import torch.fft as fft +import torch.nn.functional as F + +from einops import rearrange + +from utils.ddim_utils import inverse_video + +def initialize_noise_with_dmt(noisy_latent, noise=None, seed=0, downsample_factor=4, num_frames=24): + + # noisy_latent = inverse_video(pipe, latents, 50) + + shape = noisy_latent.shape + if noise is None: + noise = torch.randn( + shape, + device=noisy_latent.device, + generator=torch.Generator(noisy_latent.device).manual_seed(seed) + ).to(noisy_latent.dtype) + + new_h, new_w = ( + noisy_latent.shape[-2] // downsample_factor, + noisy_latent.shape[-1] // downsample_factor, + ) + noise = rearrange(noise, "b c f h w -> (b f) c h w") + noise_down = F.interpolate(noise, size=(new_h, new_w), mode="bilinear", align_corners=True, antialias=True) + noise_up = F.interpolate( + noise_down, size=(noise.shape[-2], noise.shape[-1]), 
mode="bilinear", align_corners=True, antialias=True + ) + high_freqs = noise - noise_up + noisy_latent = rearrange(noisy_latent, "b c f h w -> (b f) c h w") + noisy_latent_down = F.interpolate( + noisy_latent, size=(new_h, new_w), mode="bilinear", align_corners=True, antialias=True + ) + low_freqs = F.interpolate( + noisy_latent_down, + size=(noisy_latent.shape[-2], noisy_latent.shape[-1]), + mode="bilinear", + align_corners=True, + antialias=True, + ) + noisy_latent = low_freqs + high_freqs + noisy_latent = rearrange(noisy_latent, "(b f) c h w -> b c f h w", f=num_frames) + + return noisy_latent \ No newline at end of file diff --git a/utils/freeinit_utils.py b/noise_init/fft_init.py similarity index 72% rename from utils/freeinit_utils.py rename to noise_init/fft_init.py index a55cc29..bd5f515 100644 --- a/utils/freeinit_utils.py +++ b/noise_init/fft_init.py @@ -1,7 +1,8 @@ -import torch -import torch.fft as fft import math +import torch +import torch.fft as fft +import torch.nn.functional as F def freq_mix_3d(x, noise, LPF): """ @@ -30,7 +31,6 @@ def freq_mix_3d(x, noise, LPF): return x_mixed - def get_freq_filter(shape, device, filter_type, n, d_s, d_t): """ Form the frequency filter for noise reinitialization. @@ -73,7 +73,6 @@ def gaussian_low_pass_filter(shape, d_s=0.25, d_t=0.25): mask[..., t,h,w] = math.exp(-1/(2*d_s**2) * d_square) return mask - def butterworth_low_pass_filter(shape, n=4, d_s=0.25, d_t=0.25): """ Compute the butterworth low pass filter mask. @@ -95,7 +94,6 @@ def butterworth_low_pass_filter(shape, n=4, d_s=0.25, d_t=0.25): mask[..., t,h,w] = 1 / (1 + (d_square / d_s**2)**n) return mask - def ideal_low_pass_filter(shape, d_s=0.25, d_t=0.25): """ Compute the ideal low pass filter mask. @@ -116,7 +114,6 @@ def ideal_low_pass_filter(shape, d_s=0.25, d_t=0.25): mask[..., t,h,w] = 1 if d_square <= d_s*2 else 0 return mask - def box_low_pass_filter(shape, d_s=0.25, d_t=0.25): """ Compute the ideal low pass filter mask (approximated version). 
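A minimal sketch of calling the blend and DMT initializers from `noise_init/` above (not part of the patch): the latent tensor is stand-in data in the UNet's `(batch, channels, frames, height, width)` layout, whereas in the repo it would come from DDIM inversion of a source video.

```python
# Sketch only: random stand-in latents replace the DDIM-inverted video latents.
import torch

from noise_init import initialize_noise_with_blend, initialize_noise_with_dmt

noisy_latent = torch.randn(1, 4, 16, 40, 64)  # (batch, channels, frames, height, width)

# Blend: sqrt(noise_prior) * inverted latent + sqrt(1 - noise_prior) * fresh noise.
blended = initialize_noise_with_blend(noisy_latent, seed=0, noise_prior=0.5)

# DMT: keep the low-frequency content of the inverted latent and the high-frequency
# residual of fresh noise, split via per-frame bilinear down/up-sampling.
dmt_latent = initialize_noise_with_dmt(noisy_latent, seed=0, downsample_factor=4, num_frames=16)

assert blended.shape == dmt_latent.shape == noisy_latent.shape
```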
@@ -137,4 +134,55 @@ def box_low_pass_filter(shape, d_s=0.25, d_t=0.25): cframe, crow, ccol = T // 2, H // 2, W //2 mask[..., cframe - threshold_t:cframe + threshold_t, crow - threshold_s:crow + threshold_s, ccol - threshold_s:ccol + threshold_s] = 1.0 - return mask \ No newline at end of file + return mask + +@torch.no_grad() +def init_filter(video_length, height, width, filter_params_method="gaussian", filter_params_n=4, filter_params_d_s=0.25, filter_params_d_t=0.25, num_channels_latents=4, device='cpu'): + # initialize frequency filter for noise reinitialization + batch_size = 1 + num_channels_latents = num_channels_latents + filter_shape = [ + batch_size, + num_channels_latents, + video_length, + height, + width, + ] + freq_filter = get_freq_filter( + filter_shape, + device=device, + filter_type=filter_params_method, + n=filter_params_n if filter_params_method=="butterworth" else None, + d_s=filter_params_d_s, + d_t=filter_params_d_t + ) + return freq_filter + +def initialize_noise_with_fft(pipe, latents, noise=None, seed=0): + + shape = latents.shape + if noise is None: + noise = torch.randn( + shape, + device=latents.device, + generator=torch.Generator(latents.device).manual_seed(seed) + ).to(latents.dtype) + + pipe.scheduler.set_timesteps(30, device=latents.device) + timesteps = pipe.scheduler.timesteps + + noisy_latents = pipe.scheduler.add_noise(latents, noise, timesteps[:1]) + + dtype = noisy_latents.dtype + freq_filter = init_filter( + video_length=noisy_latents.shape[2], + height=noisy_latents.shape[3], + width=noisy_latents.shape[4], + device=noisy_latents.device + ) + + # make it float32 to accept any kinds of resolution + latents = freq_mix_3d(noisy_latents.to(dtype=torch.float32), noise.to(dtype=torch.float32), LPF=freq_filter) + latents = latents.to(dtype) + + return latents \ No newline at end of file diff --git a/pe_inversion_unet3d.py b/pe_inversion_unet3d.py deleted file mode 100644 index 351c929..0000000 --- a/pe_inversion_unet3d.py +++ /dev/null @@ -1,1003 +0,0 @@ -#!/usr/bin/env python -# coding=utf-8 -# Copyright 2023 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
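For the frequency-domain variant, a minimal sketch of the pieces in `noise_init/fft_init.py` above (not part of the patch): stand-in latents again replace DDIM-inverted ones, and the scheduler `add_noise` step that `initialize_noise_with_fft` performs before mixing is skipped here.

```python
# Sketch only: build a Gaussian low-pass filter and mix latents with fresh noise in
# the 3D frequency domain, as initialize_noise_with_fft does after add_noise.
import torch

from noise_init.fft_init import freq_mix_3d, init_filter

latents = torch.randn(1, 4, 16, 40, 64)  # (batch, channels, frames, height, width)
noise = torch.randn_like(latents)

# Gaussian low-pass filter over the (frames, height, width) frequency volume.
lpf = init_filter(video_length=16, height=40, width=64, device=latents.device)

# Keep the low-frequency band of the latents and the high-frequency band of the
# fresh noise; done in float32, matching initialize_noise_with_fft.
mixed = freq_mix_3d(latents.float(), noise.float(), LPF=lpf)
```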
- -"""Script to fine-tune Stable Video Diffusion.""" -import argparse -import random -import logging -import math -import os -import cv2 -import shutil -from pathlib import Path -from urllib.parse import urlparse - -import accelerate -import numpy as np -import PIL -from PIL import Image, ImageDraw -import torch -import torch.nn.functional as F -import torch.utils.checkpoint -from torch.utils.data import RandomSampler -import transformers -from accelerate import Accelerator -from accelerate.logging import get_logger -from accelerate.utils import ProjectConfiguration, set_seed -from huggingface_hub import create_repo, upload_folder -from packaging import version -from tqdm.auto import tqdm -from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection -from einops import rearrange - -import diffusers -from diffusers import DDPMScheduler, DDIMScheduler -from diffusers.models.lora import LoRALinearLayer,adjust_lora_scale_text_encoder -from diffusers.models import AutoencoderKL, UNet3DConditionModel, UNetMotionModel, UNet2DConditionModel -from diffusers.models.unet_motion_model import MotionAdapter -from diffusers.optimization import get_scheduler -from diffusers.training_utils import EMAModel -from diffusers.utils import check_min_version, deprecate, is_wandb_available, load_image -from diffusers.utils.import_utils import is_xformers_available - -from torch.utils.data import Dataset -from diffusers.models.embeddings import SinusoidalPositionalEmbedding -from utils.pe_utils import * -from pipelines.pipeline_animatediff import load_video, VideoDiffPipeline -from diffusers.utils import export_to_gif - -# Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.24.0.dev0") - -logger = get_logger(__name__, log_level="INFO") - - -class DummyDataset(Dataset): - def __init__(self, base_folder=None, width=512, height=320, sample_frames=16): - """ - Args: - num_samples (int): Number of samples in the dataset. - channels (int): Number of channels, default is 3 for RGB. - """ - # Define the path to the folder containing video frames - self.base_folder = base_folder - self.channels = 3 - self.width = width - self.height = height - self.sample_frames = sample_frames - - def __len__(self): - return 1 - - def __getitem__(self, idx): - """ - Args: - idx (int): Index of the sample to return. - - Returns: - dict: A dictionary containing the 'pixel_values' tensor of shape (16, channels, 320, 512). 
- """ - # Randomly select a folder (representing a video) from the base folder - # chosen_folder = random.choice(self.folders) - # video = os.listdir(self.base_folder)[idx] - video_path = self.base_folder - # Sort the frames by name - # frames.sort() - - # # Ensure the selected folder has at least `sample_frames`` frames - # if len(frames) < self.sample_frames: - # raise ValueError( - # f"The selected folder '{chosen_folder}' contains fewer than `{self.sample_frames}` frames.") - - # Randomly select a start index for frame sequence - # start_idx = random.randint(0, len(frames) - self.sample_frames) - # selected_frames = frames[start_idx:start_idx + self.sample_frames] - - # Initialize a tensor to store the pixel values - pixel_values = torch.empty((self.sample_frames, self.channels, self.height, self.width)) - frames = load_video(video_path) - - # Load and process each frame - for i, img in enumerate(frames): - # frame_path = os.path.join(folder_path, frame_name) - # with Image.open(frame_path) as img: - # Resize the image and convert it to a tensor - img_resized = img.resize((self.width, self.height)) - img_tensor = torch.from_numpy(np.array(img_resized)).float() - - # Normalize the image by scaling pixel values to [-1, 1] - img_normalized = img_tensor / 127.5 - 1 - - # Rearrange channels if necessary - if self.channels == 3: - img_normalized = img_normalized.permute( - 2, 0, 1) # For RGB images - elif self.channels == 1: - img_normalized = img_normalized.mean( - dim=2, keepdim=True) # For grayscale images - - pixel_values[i] = img_normalized - return {'pixel_values': pixel_values, 'text':""} - -# resizing utils - -def read_text_prompts(file_path): - """ - Reads text prompts from a text file, where each line represents a text prompt. - - Args: - file_path (str): The file path to the text file containing the prompts. - - Returns: - List[str]: A list of text prompts. - """ - prompts = [] - with open(file_path, 'r', encoding='utf-8') as file: - for line in file: - # Strip removes leading/trailing whitespace, including newlines - prompts.append(line.strip()) - return prompts - - -def tensor_to_vae_latent(t, vae): - video_length = t.shape[1] - - t = rearrange(t, "b f c h w -> (b f) c h w") - latents = vae.encode(t).latent_dist.sample() - latents = rearrange(latents, "(b f) c h w -> b f c h w", f=video_length).permute(0,2,1,3,4) - latents = latents * vae.config.scaling_factor - - return latents - - -def parse_args(): - parser = argparse.ArgumentParser( - description="Script to train Stable Diffusion XL for InstructPix2Pix." 
- ) - parser.add_argument( - "--pretrained_model_name_or_path", - type=str, - default=None, - required=True, - help="Path to pretrained model or model identifier from huggingface.co/models.", - ) - parser.add_argument( - "--revision", - type=str, - default=None, - required=False, - help="Revision of pretrained model identifier from huggingface.co/models.", - ) - parser.add_argument( - "--validation_file", - type=str, - default=None, - help="A file that contain many prompts during training for inference.", - ) - parser.add_argument( - "--video_path", - type=str, - default=None, - help="A folder that contain your video" - ) - parser.add_argument( - "--num_frames", - type=int, - default=16, - ) - parser.add_argument( - "--width", - type=int, - default=1024, - ) - parser.add_argument( - "--height", - type=int, - default=576, - ) - parser.add_argument( - "--pe_size", - type=int, - nargs='+', - default=[320, 640, 1280], - help="Number of images that should be generated during validation with `validation_prompt`.", - ) - parser.add_argument( - "--pe_module", - type=str, - nargs='+', - default=["up", "down", "mid"], - help="Number of images that should be generated during validation with `validation_prompt`.", - ) - parser.add_argument( - "--num_validation_videos", - type=int, - default=1, - help="Number of images that should be generated during validation with `validation_prompt`.", - ) - parser.add_argument( - "--validation_steps", - type=int, - default=500, - help=( - "Run fine-tuning validation every X epochs. The validation process consists of running the text/image prompt" - " multiple times: `args.num_validation_images`." - ), - ) - parser.add_argument( - "--output_dir", - type=str, - default="/home/wangluozhou/projects/VideoDiffusion_Playground/outputs", - help="The output directory where the model predictions and checkpoints will be written.", - ) - parser.add_argument( - "--seed", type=int, default=None, help="A seed for reproducible training." - ) - parser.add_argument( - "--per_gpu_batch_size", - type=int, - default=1, - help="Batch size (per device) for the training dataloader.", - ) - parser.add_argument("--num_train_epochs", type=int, default=100) - parser.add_argument( - "--max_train_steps", - type=int, - default=None, - help="Total number of training steps to perform. If provided, overrides num_train_epochs.", - ) - parser.add_argument( - "--gradient_accumulation_steps", - type=int, - default=1, - help="Number of updates steps to accumulate before performing a backward/update pass.", - ) - parser.add_argument( - "--gradient_checkpointing", - action="store_true", - help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", - ) - parser.add_argument( - "--learning_rate", - type=float, - default=1e-4, - help="Initial learning rate (after the potential warmup period) to use.", - ) - parser.add_argument( - "--scale_lr", - action="store_true", - default=False, - help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", - ) - parser.add_argument( - "--lr_scheduler", - type=str, - default="constant", - help=( - 'The scheduler type to use. 
Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' - ' "constant", "constant_with_warmup"]' - ), - ) - parser.add_argument( - "--lr_warmup_steps", - type=int, - default=500, - help="Number of steps for the warmup in the lr scheduler.", - ) - - parser.add_argument( - "--use_8bit_adam", - action="store_true", - help="Whether or not to use 8-bit Adam from bitsandbytes.", - ) - parser.add_argument( - "--allow_tf32", - action="store_true", - help=( - "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" - " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" - ), - ) - parser.add_argument( - "--use_ema", action="store_true", help="Whether to use EMA model." - ) - parser.add_argument( - "--non_ema_revision", - type=str, - default=None, - required=False, - help=( - "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or" - " remote repository specified with --pretrained_model_name_or_path." - ), - ) - parser.add_argument( - "--num_workers", - type=int, - default=8, - help=( - "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." - ), - ) - parser.add_argument( - "--adam_beta1", - type=float, - default=0.9, - help="The beta1 parameter for the Adam optimizer.", - ) - parser.add_argument( - "--adam_beta2", - type=float, - default=0.999, - help="The beta2 parameter for the Adam optimizer.", - ) - parser.add_argument( - "--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use." - ) - parser.add_argument( - "--adam_epsilon", - type=float, - default=1e-08, - help="Epsilon value for the Adam optimizer", - ) - parser.add_argument( - "--max_grad_norm", default=1.0, type=float, help="Max gradient norm." - ) - parser.add_argument( - "--push_to_hub", - action="store_true", - help="Whether or not to push the model to the Hub.", - ) - parser.add_argument( - "--hub_token", - type=str, - default=None, - help="The token to use to push to the Model Hub.", - ) - parser.add_argument( - "--hub_model_id", - type=str, - default=None, - help="The name of the repository to keep in sync with the local `output_dir`.", - ) - parser.add_argument( - "--logging_dir", - type=str, - default="logs", - help=( - "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" - " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." - ), - ) - parser.add_argument( - "--mixed_precision", - type=str, - default=None, - choices=["no", "fp16", "bf16"], - help=( - "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" - " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" - " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." - ), - ) - parser.add_argument( - "--report_to", - type=str, - default="tensorboard", - help=( - 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' - ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' - ), - ) - parser.add_argument( - "--local_rank", - type=int, - default=-1, - help="For distributed training: local_rank", - ) - parser.add_argument( - "--checkpointing_steps", - type=int, - default=500, - help=( - "Save a checkpoint of the training state every X updates. 
These checkpoints are only suitable for resuming" - " training using `--resume_from_checkpoint`." - ), - ) - parser.add_argument( - "--checkpoints_total_limit", - type=int, - default=2, - help=("Max number of checkpoints to store."), - ) - parser.add_argument( - "--resume_from_checkpoint", - type=str, - default=None, - help=( - "Whether training should be resumed from a previous checkpoint. Use a path saved by" - ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' - ), - ) - parser.add_argument( - "--enable_xformers_memory_efficient_attention", - action="store_true", - help="Whether or not to use xformers.", - ) - - parser.add_argument( - "--pretrain_unet", - type=str, - default=None, - help="use weight for unet block", - ) - parser.add_argument( - "--rank", - type=int, - default=128, - help=("The dimension of the LoRA update matrices."), - ) - - args = parser.parse_args() - env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) - if env_local_rank != -1 and env_local_rank != args.local_rank: - args.local_rank = env_local_rank - - # default to using the same revision for the non-ema model if not specified - if args.non_ema_revision is None: - args.non_ema_revision = args.revision - - return args - - -def download_image(url): - original_image = ( - lambda image_url_or_path: load_image(image_url_or_path) - if urlparse(image_url_or_path).scheme - else PIL.Image.open(image_url_or_path).convert("RGB") - )(url) - return original_image - - -def main(): - args = parse_args() - - if args.non_ema_revision is not None: - deprecate( - "non_ema_revision!=None", - "0.15.0", - message=( - "Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to" - " use `--variant=non_ema` instead." - ), - ) - logging_dir = os.path.join(args.output_dir, args.logging_dir) - accelerator_project_config = ProjectConfiguration( - project_dir=args.output_dir, logging_dir=logging_dir) - # ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) - accelerator = Accelerator( - gradient_accumulation_steps=args.gradient_accumulation_steps, - mixed_precision=args.mixed_precision, - log_with=args.report_to, - project_config=accelerator_project_config, - # kwargs_handlers=[ddp_kwargs] - ) - - generator = torch.Generator( - device=accelerator.device).manual_seed(args.seed) - - if args.report_to == "wandb": - if not is_wandb_available(): - raise ImportError( - "Make sure to install wandb if you want to use it for logging during training.") - import wandb - - # Make one log on every process with the configuration for debugging. - logging.basicConfig( - format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", - datefmt="%m/%d/%Y %H:%M:%S", - level=logging.INFO, - ) - logger.info(accelerator.state, main_process_only=False) - if accelerator.is_local_main_process: - transformers.utils.logging.set_verbosity_warning() - diffusers.utils.logging.set_verbosity_info() - else: - transformers.utils.logging.set_verbosity_error() - diffusers.utils.logging.set_verbosity_error() - - # If passed along, set the training seed now. - if args.seed is not None: - set_seed(args.seed) - - # Handle the repository creation - if accelerator.is_main_process: - if args.output_dir is not None: - os.makedirs(args.output_dir, exist_ok=True) - - if args.push_to_hub: - repo_id = create_repo( - repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token - ).repo_id - - # Load scheduler, tokenizer and models. 
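Before the models are loaded below, it may help to see how the batch-size and learning-rate arguments parsed above interact once training starts. A minimal arithmetic sketch with hypothetical values (two GPUs assumed; the rescaling applies only when `--scale_lr` is passed):

```python
# Hypothetical values standing in for the parsed CLI arguments.
per_gpu_batch_size = 1           # --per_gpu_batch_size
gradient_accumulation_steps = 4  # --gradient_accumulation_steps
num_processes = 2                # number of GPUs launched via accelerate (assumed)
learning_rate = 1e-4             # --learning_rate

# Effective number of samples behind one optimizer update.
total_batch_size = per_gpu_batch_size * num_processes * gradient_accumulation_steps
print(total_batch_size)  # 8

# Rescaling performed later in the script only when --scale_lr is set.
scaled_lr = learning_rate * gradient_accumulation_steps * per_gpu_batch_size * num_processes
print(scaled_lr)  # 0.0008
```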
- noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") - tokenizer = CLIPTokenizer.from_pretrained( - args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision - ) - text_encoder = CLIPTextModel.from_pretrained( - args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision - ) - vae = AutoencoderKL.from_pretrained( - args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision) - - unet = UNet3DConditionModel.from_pretrained( - args.pretrained_model_name_or_path if args.pretrain_unet is None else args.pretrain_unet, - subfolder="unet", - low_cpu_mem_usage=True, - ) - - replace_positional_embedding_unet3d(unet, target_size=args.pe_size, target_module=args.pe_module) - - # Freeze vae and image_encoder - vae.requires_grad_(False) - text_encoder.requires_grad_(False) - unet.requires_grad_(False) - - # For mixed precision training we cast the text_encoder and vae weights to half-precision - # as these models are only used for inference, keeping weights in full precision is not required. - weight_dtype = torch.float32 - if accelerator.mixed_precision == "fp16": - weight_dtype = torch.float16 - elif accelerator.mixed_precision == "bf16": - weight_dtype = torch.bfloat16 - - # Move image_encoder and vae to gpu and cast to weight_dtype - # image_encoder.to(accelerator.device, dtype=weight_dtype) - vae.to(accelerator.device, dtype=weight_dtype) - text_encoder.to(accelerator.device, dtype=weight_dtype) - # unet.to(accelerator.device, dtype=weight_dtype) - - # Create EMA for the unet. - if args.use_ema: - ema_unet = EMAModel(unet.parameters( - ), model_cls=UNetSpatioTemporalConditionModel, model_config=unet.config) - - if args.enable_xformers_memory_efficient_attention: - if is_xformers_available(): - import xformers - - xformers_version = version.parse(xformers.__version__) - if xformers_version == version.parse("0.0.16"): - logger.warn( - "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." - ) - unet.enable_xformers_memory_efficient_attention() - else: - raise ValueError( - "xformers is not available. 
Make sure it is installed correctly") - - # `accelerate` 0.16.0 will have better support for customized saving - if version.parse(accelerate.__version__) >= version.parse("0.16.0"): - # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format - def save_model_hook(models, weights, output_dir): - if args.use_ema: - ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema")) - - for i, model in enumerate(models): - # if isinstance(model, UNetMotionModel): - save_positional_embeddings(model, os.path.join(output_dir, "pos_embed.pt")) - # model.save_pretrained(os.path.join(output_dir, "unet")) - - # make sure to pop weight so that corresponding model is not saved again - weights.pop() - - def load_model_hook(models, input_dir): - if args.use_ema: - load_model = EMAModel.from_pretrained(os.path.join( - input_dir, "unet_ema"), UNetSpatioTemporalConditionModel) - ema_unet.load_state_dict(load_model.state_dict()) - ema_unet.to(accelerator.device) - del load_model - - for i in range(len(models)): - # pop models so that they are not loaded again - model = models.pop() - - # load diffusers style into model - load_model = UNet2DConditionModel.from_pretrained( - input_dir, subfolder="unet") - model.register_to_config(**load_model.config) - - model.load_state_dict(load_model.state_dict()) - del load_model - - accelerator.register_save_state_pre_hook(save_model_hook) - accelerator.register_load_state_pre_hook(load_model_hook) - - if args.gradient_checkpointing: - unet.enable_gradient_checkpointing() - - # Enable TF32 for faster training on Ampere GPUs, - # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices - if args.allow_tf32: - torch.backends.cuda.matmul.allow_tf32 = True - - if args.scale_lr: - args.learning_rate = ( - args.learning_rate * args.gradient_accumulation_steps * - args.per_gpu_batch_size * accelerator.num_processes - ) - - # Initialize the optimizer - if args.use_8bit_adam: - try: - import bitsandbytes as bnb - except ImportError: - raise ImportError( - "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`" - ) - - optimizer_cls = bnb.optim.AdamW8bit - else: - optimizer_cls = torch.optim.AdamW - - unet.requires_grad_(True) - parameters_list = [] - - # Customize the parameters that need to be trained; if necessary, you can uncomment them yourself. 
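The loop that follows keeps only parameters whose name contains `pos_embed` trainable and freezes everything else. Purely as an illustrative variant (not part of the original script), the same filtering could be packaged as a reusable helper:

```python
import torch.nn as nn

def select_trainable(module: nn.Module, keywords=("pos_embed",)):
    """Freeze every parameter except those whose name contains one of `keywords`."""
    trainable = []
    for name, param in module.named_parameters():
        param.requires_grad = any(k in name for k in keywords)
        if param.requires_grad:
            trainable.append(param)
    return trainable

# Hypothetical usage in place of the loop below:
# parameters_list = select_trainable(unet, keywords=("pos_embed",))
```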
- for name, para in unet.named_parameters(): - if 'pos_embed' in name: - parameters_list.append(para) - para.requires_grad = True - else: - para.requires_grad = False - - - optimizer = optimizer_cls( - parameters_list, - lr=args.learning_rate, - betas=(args.adam_beta1, args.adam_beta2), - weight_decay=args.adam_weight_decay, - eps=args.adam_epsilon, - ) - - # check parameters - if accelerator.is_main_process: - rec_txt1 = open('rec_para.txt', 'w') - rec_txt2 = open('rec_para_train.txt', 'w') - for name, para in unet.named_parameters(): - if para.requires_grad is False: - rec_txt1.write(f'{name}\n') - else: - rec_txt2.write(f'{name}\n') - rec_txt1.close() - rec_txt2.close() - - # DataLoaders creation: - args.global_batch_size = args.per_gpu_batch_size * accelerator.num_processes - - train_dataset = DummyDataset(base_folder=args.video_path, width=args.width, height=args.height, sample_frames=args.num_frames) - sampler = RandomSampler(train_dataset) - train_dataloader = torch.utils.data.DataLoader( - train_dataset, - sampler=sampler, - batch_size=args.per_gpu_batch_size, - num_workers=args.num_workers, - ) - - # Scheduler and math around the number of training steps. - overrode_max_train_steps = False - num_update_steps_per_epoch = math.ceil( - len(train_dataloader) / args.gradient_accumulation_steps) - if args.max_train_steps is None: - args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch - overrode_max_train_steps = True - - lr_scheduler = get_scheduler( - args.lr_scheduler, - optimizer=optimizer, - num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, - num_training_steps=args.max_train_steps * accelerator.num_processes, - ) - - # Prepare everything with our `accelerator`. - unet, optimizer, lr_scheduler, train_dataloader = accelerator.prepare( - unet, optimizer, lr_scheduler, train_dataloader - ) - - if args.use_ema: - ema_unet.to(accelerator.device) - - # We need to recalculate our total training steps as the size of the training dataloader may have changed. - num_update_steps_per_epoch = math.ceil( - len(train_dataloader) / args.gradient_accumulation_steps) - if overrode_max_train_steps: - args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch - # Afterwards we recalculate our number of training epochs - args.num_train_epochs = math.ceil( - args.max_train_steps / num_update_steps_per_epoch) - - # We need to initialize the trackers we use, and also store our configuration. - # The trackers initializes automatically on the main process. - if accelerator.is_main_process: - accelerator.init_trackers("PE-Inversion", config=vars(args)) - - # Train! - total_batch_size = args.per_gpu_batch_size * \ - accelerator.num_processes * args.gradient_accumulation_steps - - logger.info("***** Running training *****") - logger.info(f" Num examples = {len(train_dataset)}") - logger.info(f" Num Epochs = {args.num_train_epochs}") - logger.info( - f" Instantaneous batch size per device = {args.per_gpu_batch_size}") - logger.info( - f" Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}") - logger.info( - f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") - logger.info(f" Total optimization steps = {args.max_train_steps}") - global_step = 0 - first_epoch = 0 - - # Potentially load in the weights and states from a previous save - if args.resume_from_checkpoint: - if args.resume_from_checkpoint != "latest": - path = os.path.basename(args.resume_from_checkpoint) - else: - # Get the most recent checkpoint - dirs = os.listdir(args.output_dir) - dirs = [d for d in dirs if d.startswith("checkpoint")] - dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) - path = dirs[-1] if len(dirs) > 0 else None - - if path is None: - accelerator.print( - f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." - ) - args.resume_from_checkpoint = None - else: - accelerator.print(f"Resuming from checkpoint {path}") - accelerator.load_state(os.path.join(args.output_dir, path)) - global_step = int(path.split("-")[1]) - - resume_global_step = global_step * args.gradient_accumulation_steps - first_epoch = global_step // num_update_steps_per_epoch - resume_step = resume_global_step % ( - num_update_steps_per_epoch * args.gradient_accumulation_steps) - - # Only show the progress bar once on each machine. - progress_bar = tqdm(range(global_step, args.max_train_steps), - disable=not accelerator.is_local_main_process) - progress_bar.set_description("Steps") - - for epoch in range(first_epoch, args.num_train_epochs): - unet.train() - train_loss = 0.0 - for step, batch in enumerate(train_dataloader): - # Skip steps until we reach the resumed step - if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: - if step % args.gradient_accumulation_steps == 0: - progress_bar.update(1) - continue - - with accelerator.accumulate(unet): - # first, convert images to latent space. - pixel_values = batch["pixel_values"].to(weight_dtype).to( - accelerator.device, non_blocking=True - ) - latents = tensor_to_vae_latent(pixel_values, vae) - # print(latents.shape) - - # Sample noise that we'll add to the latents - noise = torch.randn_like(latents) - bsz = latents.shape[0] - - timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) - timesteps = timesteps.long() - - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - - # Get the text embedding for conditioning - with torch.no_grad(): - prompt_ids = tokenizer( - batch['text'], max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt" - ).input_ids.to(latents.device) - encoder_hidden_states = text_encoder(prompt_ids)[0] - - if noise_scheduler.config.prediction_type == "epsilon": - target = noise - elif noise_scheduler.config.prediction_type == "v_prediction": - raise NotImplementedError - else: - raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") - - model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample - loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") - - # Gather the losses across all processes for logging (if we use distributed training). 
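The resume bookkeeping a few lines above (`resume_global_step`, `first_epoch`, `resume_step`) is easiest to follow with concrete numbers. A worked example assuming a run resumed from a hypothetical `checkpoint-1200`:

```python
global_step = 1200                # parsed from the "checkpoint-1200" directory name
gradient_accumulation_steps = 2   # --gradient_accumulation_steps
num_update_steps_per_epoch = 500  # optimizer updates per epoch for this dataloader

resume_global_step = global_step * gradient_accumulation_steps  # 2400 dataloader steps already consumed
first_epoch = global_step // num_update_steps_per_epoch         # resume inside epoch 2
resume_step = resume_global_step % (num_update_steps_per_epoch * gradient_accumulation_steps)

# The loop therefore skips the first 400 dataloader batches of epoch 2
# before optimization resumes.
print(first_epoch, resume_step)  # 2 400
```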
- avg_loss = accelerator.gather( - loss.repeat(args.per_gpu_batch_size)).mean() - train_loss += avg_loss.item() / args.gradient_accumulation_steps - - # Backpropagate - accelerator.backward(loss) - # if accelerator.sync_gradients: - # accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm) - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad() - - # Checks if the accelerator has performed an optimization step behind the scenes - if accelerator.sync_gradients: - if args.use_ema: - ema_unet.step(unet.parameters()) - progress_bar.update(1) - global_step += 1 - # accelerator.log({"train_loss": train_loss}, step=global_step) - train_loss = 0.0 - - if accelerator.is_main_process: - # save checkpoints! - if global_step % args.checkpointing_steps == 0: - # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` - if args.checkpoints_total_limit is not None: - checkpoints = os.listdir(args.output_dir) - checkpoints = [ - d for d in checkpoints if d.startswith("checkpoint")] - checkpoints = sorted( - checkpoints, key=lambda x: int(x.split("-")[1])) - - # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints - if len(checkpoints) >= args.checkpoints_total_limit: - num_to_remove = len( - checkpoints) - args.checkpoints_total_limit + 1 - removing_checkpoints = checkpoints[0:num_to_remove] - - logger.info( - f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" - ) - logger.info( - f"removing checkpoints: {', '.join(removing_checkpoints)}") - - for removing_checkpoint in removing_checkpoints: - removing_checkpoint = os.path.join( - args.output_dir, removing_checkpoint) - shutil.rmtree(removing_checkpoint) - - save_path = os.path.join( - args.output_dir, f"checkpoint-{global_step}") - - accelerator.save_state(save_path) - logger.info(f"Saved state to {save_path}") - # sample images! - if ( - (global_step % args.validation_steps == 0) - or (global_step == 1) - ): - logger.info( - f"Running validation... \n Generating {args.num_validation_videos} videos." - ) - # create pipeline - if args.use_ema: - # Store the UNet parameters temporarily and load the EMA parameters to perform inference. - ema_unet.store(unet.parameters()) - ema_unet.copy_to(unet.parameters()) - # The models need unwrapping because for compatibility in distributed training mode. 
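The checkpoint rotation earlier in this block keeps at most `--checkpoints_total_limit` directories, deleting the oldest (by step number) before each new save. A standalone sketch of that policy:

```python
def checkpoints_to_remove(existing_dirs, limit):
    """Return the oldest checkpoint dirs to delete so at most `limit - 1` remain before saving."""
    ckpts = sorted(
        (d for d in existing_dirs if d.startswith("checkpoint")),
        key=lambda d: int(d.split("-")[1]),
    )
    if limit is None or len(ckpts) < limit:
        return []
    return ckpts[: len(ckpts) - limit + 1]

print(checkpoints_to_remove(["checkpoint-500", "checkpoint-1500", "checkpoint-1000"], 2))
# -> ['checkpoint-500', 'checkpoint-1000']
```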
- pipeline = VideoDiffPipeline.from_pretrained( - args.pretrained_model_name_or_path, - unet=accelerator.unwrap_model(unet), - vae=accelerator.unwrap_model(vae), - text_encoder=accelerator.unwrap_model(text_encoder), - tokenizer=tokenizer, - image_encoder=None, - motion_adapter=None, - controlnet=None, - revision=args.revision, - torch_dtype=weight_dtype, - ) - pipeline.scheduler = DDIMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder='scheduler') - pipeline = pipeline.to(device=accelerator.device) - pipeline.set_progress_bar_config(disable=True) - - # run inference - val_save_dir = os.path.join( - args.output_dir, "validation_images") - - if not os.path.exists(val_save_dir): - os.makedirs(val_save_dir) - - with torch.autocast( - str(accelerator.device).replace(":0", ""), enabled=accelerator.mixed_precision =='fp16' - ): - # open a validation prompts - validation_prompts = read_text_prompts(args.validation_file) - validation_video_frames = load_video(args.video_path) - pipeline.init_filter( - video_length=len(validation_video_frames), - height=validation_video_frames[0].size[1], - width=validation_video_frames[0].size[0], - ) - validation_video_latents = pipeline.encode_frames(validation_video_frames, device=accelerator.device) - for val_prompt_idx, val_prompt in enumerate(validation_prompts): - num_frames = args.num_frames - for seed in range(args.num_validation_videos): - video_frames = pipeline( - prompt="Amazing quality, masterpiece, " + val_prompt, - height=validation_video_frames[0].size[1], - negative_prompt="bad quality, distortions, unrealistic, distorted image, watermark, signature", - width=validation_video_frames[0].size[0], - num_frames=len(validation_video_frames), - guidance_scale=10.0, - num_inference_steps=50, - frames_video=validation_video_latents, - freeinit=True, - generator=torch.Generator(device=accelerator.device).manual_seed(seed), - ).frames[0] - - prompt_name = val_prompt.replace(' ', '_') - out_file = os.path.join( - val_save_dir, - f"step_{global_step}_val_img_{prompt_name}_seed_{seed}.gif", - ) - export_to_gif(video_frames, out_file) - - if args.use_ema: - # Switch back to the original UNet parameters. - ema_unet.restore(unet.parameters()) - - del pipeline - torch.cuda.empty_cache() - - logs = {"step_loss": loss.detach().item( - ), "lr": lr_scheduler.get_last_lr()[0]} - progress_bar.set_postfix(**logs) - - if global_step >= args.max_train_steps: - break - - # Create the pipeline using the trained modules and save it. 
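The final save below calls `save_positional_embeddings` from `utils.pe_utils`, which this diff does not include. Purely as a guess at what that helper amounts to (an assumption, not the actual implementation), it likely collects the `pos_embed` tensors into a single state dict:

```python
import torch
import torch.nn as nn

def save_positional_embeddings_sketch(module: nn.Module, path: str) -> None:
    # Illustrative only: gather every parameter whose name contains 'pos_embed'
    # (the same ones unfrozen for training above) and write them to one .pt file.
    state = {
        name: param.detach().cpu()
        for name, param in module.named_parameters()
        if "pos_embed" in name
    }
    torch.save(state, path)
```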
- accelerator.wait_for_everyone() - if accelerator.is_main_process: - unet = accelerator.unwrap_model(unet) - if args.use_ema: - ema_unet.copy_to(unet.parameters()) - - # pipelne = AnimateDiffPipeline.from_pretrained( - # args.pretrained_model_name_or_path, - # unet=accelerator.unwrap_model(unet), - # vae=accelerator.unwrap_model(vae), - # text_encoder=accelerator.unwrap_model(text_encoder), - # tokenizer=tokenizer, - # revision=args.revision, - # torch_dtype=weight_dtype, - # ) - # pipeline.save_pretrained(args.output_dir) - save_positional_embeddings(unet, os.path.join(args.output_dir, 'pos_embed.pt')) - - - # if args.push_to_hub: - # upload_folder( - # repo_id=repo_id, - # folder_path=args.output_dir, - # commit_message="End of training", - # ignore_patterns=["step_*", "epoch_*"], - # ) - accelerator.end_training() - - -if __name__ == "__main__": - main() diff --git a/pe_inversion_unetmotion.py b/pe_inversion_unetmotion.py deleted file mode 100644 index bc1ec2e..0000000 --- a/pe_inversion_unetmotion.py +++ /dev/null @@ -1,1008 +0,0 @@ -#!/usr/bin/env python -# coding=utf-8 -# Copyright 2023 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Script to fine-tune Stable Video Diffusion.""" -import argparse -import random -import logging -import math -import os -import cv2 -import shutil -from pathlib import Path -from urllib.parse import urlparse - -import accelerate -import numpy as np -import PIL -from PIL import Image, ImageDraw -import torch -import torch.nn.functional as F -import torch.utils.checkpoint -from torch.utils.data import RandomSampler -import transformers -from accelerate import Accelerator -from accelerate.logging import get_logger -from accelerate.utils import ProjectConfiguration, set_seed -from huggingface_hub import create_repo, upload_folder -from packaging import version -from tqdm.auto import tqdm -from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection -from einops import rearrange - -import diffusers -from diffusers import DDPMScheduler, DDIMScheduler -from diffusers.models.lora import LoRALinearLayer,adjust_lora_scale_text_encoder -from diffusers.models import AutoencoderKL, UNet3DConditionModel, UNetMotionModel, UNet2DConditionModel -from diffusers.models.unet_motion_model import MotionAdapter -from diffusers.optimization import get_scheduler -from diffusers.training_utils import EMAModel -from diffusers.utils import check_min_version, deprecate, is_wandb_available, load_image -from diffusers.utils.import_utils import is_xformers_available - -from torch.utils.data import Dataset -from diffusers.models.embeddings import SinusoidalPositionalEmbedding -from utils.pe_utils import * -from pipelines.pipeline_animatediff import load_video, AnimateDiffPipeline -from diffusers.utils import export_to_gif - -# Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
-check_min_version("0.24.0.dev0") - -logger = get_logger(__name__, log_level="INFO") - - -class DummyDataset(Dataset): - def __init__(self, base_folder=None, width=512, height=320, sample_frames=16): - """ - Args: - num_samples (int): Number of samples in the dataset. - channels (int): Number of channels, default is 3 for RGB. - """ - # Define the path to the folder containing video frames - self.base_folder = base_folder - self.channels = 3 - self.width = width - self.height = height - self.sample_frames = sample_frames - - def __len__(self): - return 1 - - def __getitem__(self, idx): - """ - Args: - idx (int): Index of the sample to return. - - Returns: - dict: A dictionary containing the 'pixel_values' tensor of shape (16, channels, 320, 512). - """ - # Randomly select a folder (representing a video) from the base folder - # chosen_folder = random.choice(self.folders) - # video = os.listdir(self.base_folder)[idx] - video_path = self.base_folder - # Sort the frames by name - # frames.sort() - - # # Ensure the selected folder has at least `sample_frames`` frames - # if len(frames) < self.sample_frames: - # raise ValueError( - # f"The selected folder '{chosen_folder}' contains fewer than `{self.sample_frames}` frames.") - - # Randomly select a start index for frame sequence - # start_idx = random.randint(0, len(frames) - self.sample_frames) - # selected_frames = frames[start_idx:start_idx + self.sample_frames] - - # Initialize a tensor to store the pixel values - pixel_values = torch.empty((self.sample_frames, self.channels, self.height, self.width)) - frames = load_video(video_path) - - # Load and process each frame - for i, img in enumerate(frames): - # frame_path = os.path.join(folder_path, frame_name) - # with Image.open(frame_path) as img: - # Resize the image and convert it to a tensor - img_resized = img.resize((self.width, self.height)) - img_tensor = torch.from_numpy(np.array(img_resized)).float() - - # Normalize the image by scaling pixel values to [-1, 1] - img_normalized = img_tensor / 127.5 - 1 - - # Rearrange channels if necessary - if self.channels == 3: - img_normalized = img_normalized.permute( - 2, 0, 1) # For RGB images - elif self.channels == 1: - img_normalized = img_normalized.mean( - dim=2, keepdim=True) # For grayscale images - - pixel_values[i] = img_normalized - return {'pixel_values': pixel_values, 'text':""} - -# resizing utils - -def read_text_prompts(file_path): - """ - Reads text prompts from a text file, where each line represents a text prompt. - - Args: - file_path (str): The file path to the text file containing the prompts. - - Returns: - List[str]: A list of text prompts. - """ - prompts = [] - with open(file_path, 'r', encoding='utf-8') as file: - for line in file: - # Strip removes leading/trailing whitespace, including newlines - prompts.append(line.strip()) - return prompts - - -def tensor_to_vae_latent(t, vae): - video_length = t.shape[1] - - t = rearrange(t, "b f c h w -> (b f) c h w") - latents = vae.encode(t).latent_dist.sample() - latents = rearrange(latents, "(b f) c h w -> b f c h w", f=video_length).permute(0,2,1,3,4) - latents = latents * vae.config.scaling_factor - - return latents - - -def parse_args(): - parser = argparse.ArgumentParser( - description="Script to train Stable Diffusion XL for InstructPix2Pix." 
- ) - parser.add_argument( - "--pretrained_model_name_or_path", - type=str, - default=None, - required=True, - help="Path to pretrained model or model identifier from huggingface.co/models.", - ) - parser.add_argument( - "--revision", - type=str, - default=None, - required=False, - help="Revision of pretrained model identifier from huggingface.co/models.", - ) - parser.add_argument( - "--validation_file", - type=str, - default=None, - help="A file that contain many prompts during training for inference.", - ) - parser.add_argument( - "--video_path", - type=str, - default=None, - help="A folder that contain your video" - ) - parser.add_argument( - "--num_frames", - type=int, - default=16, - ) - parser.add_argument( - "--width", - type=int, - default=1024, - ) - parser.add_argument( - "--height", - type=int, - default=576, - ) - parser.add_argument( - "--pe_size", - type=int, - nargs='+', - default=[320, 640, 1280], - help="Number of images that should be generated during validation with `validation_prompt`.", - ) - parser.add_argument( - "--pe_module", - type=str, - nargs='+', - default=["up", "down", "mid"], - help="Number of images that should be generated during validation with `validation_prompt`.", - ) - parser.add_argument( - "--num_validation_images", - type=int, - default=1, - help="Number of images that should be generated during validation with `validation_prompt`.", - ) - parser.add_argument( - "--validation_steps", - type=int, - default=500, - help=( - "Run fine-tuning validation every X epochs. The validation process consists of running the text/image prompt" - " multiple times: `args.num_validation_images`." - ), - ) - parser.add_argument( - "--output_dir", - type=str, - default="/home/wangluozhou/projects/VideoDiffusion_Playground/outputs", - help="The output directory where the model predictions and checkpoints will be written.", - ) - parser.add_argument( - "--seed", type=int, default=None, help="A seed for reproducible training." - ) - parser.add_argument( - "--per_gpu_batch_size", - type=int, - default=1, - help="Batch size (per device) for the training dataloader.", - ) - parser.add_argument("--num_train_epochs", type=int, default=100) - parser.add_argument( - "--max_train_steps", - type=int, - default=None, - help="Total number of training steps to perform. If provided, overrides num_train_epochs.", - ) - parser.add_argument( - "--gradient_accumulation_steps", - type=int, - default=1, - help="Number of updates steps to accumulate before performing a backward/update pass.", - ) - parser.add_argument( - "--gradient_checkpointing", - action="store_true", - help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", - ) - parser.add_argument( - "--learning_rate", - type=float, - default=1e-4, - help="Initial learning rate (after the potential warmup period) to use.", - ) - parser.add_argument( - "--scale_lr", - action="store_true", - default=False, - help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", - ) - parser.add_argument( - "--lr_scheduler", - type=str, - default="constant", - help=( - 'The scheduler type to use. 
Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' - ' "constant", "constant_with_warmup"]' - ), - ) - parser.add_argument( - "--lr_warmup_steps", - type=int, - default=500, - help="Number of steps for the warmup in the lr scheduler.", - ) - - parser.add_argument( - "--use_8bit_adam", - action="store_true", - help="Whether or not to use 8-bit Adam from bitsandbytes.", - ) - parser.add_argument( - "--allow_tf32", - action="store_true", - help=( - "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" - " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" - ), - ) - parser.add_argument( - "--use_ema", action="store_true", help="Whether to use EMA model." - ) - parser.add_argument( - "--non_ema_revision", - type=str, - default=None, - required=False, - help=( - "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or" - " remote repository specified with --pretrained_model_name_or_path." - ), - ) - parser.add_argument( - "--num_workers", - type=int, - default=8, - help=( - "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." - ), - ) - parser.add_argument( - "--adam_beta1", - type=float, - default=0.9, - help="The beta1 parameter for the Adam optimizer.", - ) - parser.add_argument( - "--adam_beta2", - type=float, - default=0.999, - help="The beta2 parameter for the Adam optimizer.", - ) - parser.add_argument( - "--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use." - ) - parser.add_argument( - "--adam_epsilon", - type=float, - default=1e-08, - help="Epsilon value for the Adam optimizer", - ) - parser.add_argument( - "--max_grad_norm", default=1.0, type=float, help="Max gradient norm." - ) - parser.add_argument( - "--push_to_hub", - action="store_true", - help="Whether or not to push the model to the Hub.", - ) - parser.add_argument( - "--hub_token", - type=str, - default=None, - help="The token to use to push to the Model Hub.", - ) - parser.add_argument( - "--hub_model_id", - type=str, - default=None, - help="The name of the repository to keep in sync with the local `output_dir`.", - ) - parser.add_argument( - "--logging_dir", - type=str, - default="logs", - help=( - "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" - " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." - ), - ) - parser.add_argument( - "--mixed_precision", - type=str, - default=None, - choices=["no", "fp16", "bf16"], - help=( - "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" - " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" - " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." - ), - ) - parser.add_argument( - "--report_to", - type=str, - default="tensorboard", - help=( - 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' - ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' - ), - ) - parser.add_argument( - "--local_rank", - type=int, - default=-1, - help="For distributed training: local_rank", - ) - parser.add_argument( - "--checkpointing_steps", - type=int, - default=500, - help=( - "Save a checkpoint of the training state every X updates. 
These checkpoints are only suitable for resuming" - " training using `--resume_from_checkpoint`." - ), - ) - parser.add_argument( - "--checkpoints_total_limit", - type=int, - default=2, - help=("Max number of checkpoints to store."), - ) - parser.add_argument( - "--resume_from_checkpoint", - type=str, - default=None, - help=( - "Whether training should be resumed from a previous checkpoint. Use a path saved by" - ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' - ), - ) - parser.add_argument( - "--enable_xformers_memory_efficient_attention", - action="store_true", - help="Whether or not to use xformers.", - ) - - parser.add_argument( - "--pretrain_unet", - type=str, - default=None, - help="use weight for unet block", - ) - parser.add_argument( - "--rank", - type=int, - default=128, - help=("The dimension of the LoRA update matrices."), - ) - - args = parser.parse_args() - env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) - if env_local_rank != -1 and env_local_rank != args.local_rank: - args.local_rank = env_local_rank - - # default to using the same revision for the non-ema model if not specified - if args.non_ema_revision is None: - args.non_ema_revision = args.revision - - return args - - -def download_image(url): - original_image = ( - lambda image_url_or_path: load_image(image_url_or_path) - if urlparse(image_url_or_path).scheme - else PIL.Image.open(image_url_or_path).convert("RGB") - )(url) - return original_image - - -def main(): - args = parse_args() - - if args.non_ema_revision is not None: - deprecate( - "non_ema_revision!=None", - "0.15.0", - message=( - "Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to" - " use `--variant=non_ema` instead." - ), - ) - logging_dir = os.path.join(args.output_dir, args.logging_dir) - accelerator_project_config = ProjectConfiguration( - project_dir=args.output_dir, logging_dir=logging_dir) - # ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) - accelerator = Accelerator( - gradient_accumulation_steps=args.gradient_accumulation_steps, - mixed_precision=args.mixed_precision, - log_with=args.report_to, - project_config=accelerator_project_config, - # kwargs_handlers=[ddp_kwargs] - ) - - generator = torch.Generator( - device=accelerator.device).manual_seed(args.seed) - - if args.report_to == "wandb": - if not is_wandb_available(): - raise ImportError( - "Make sure to install wandb if you want to use it for logging during training.") - import wandb - - # Make one log on every process with the configuration for debugging. - logging.basicConfig( - format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", - datefmt="%m/%d/%Y %H:%M:%S", - level=logging.INFO, - ) - logger.info(accelerator.state, main_process_only=False) - if accelerator.is_local_main_process: - transformers.utils.logging.set_verbosity_warning() - diffusers.utils.logging.set_verbosity_info() - else: - transformers.utils.logging.set_verbosity_error() - diffusers.utils.logging.set_verbosity_error() - - # If passed along, set the training seed now. - if args.seed is not None: - set_seed(args.seed) - - # Handle the repository creation - if accelerator.is_main_process: - if args.output_dir is not None: - os.makedirs(args.output_dir, exist_ok=True) - - if args.push_to_hub: - repo_id = create_repo( - repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token - ).repo_id - - # Load scheduler, tokenizer and models. 
- noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") - tokenizer = CLIPTokenizer.from_pretrained( - args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision - ) - - motion_adapter = MotionAdapter.from_pretrained( - "/home/wangluozhou/projects/AnimateDiff/models/Motion_Module/animatediff-motion-adapter-v1-5-2", - revision=args.revision - ) - text_encoder = CLIPTextModel.from_pretrained( - args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision - ) - vae = AutoencoderKL.from_pretrained( - args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision) - - unet = UNet2DConditionModel.from_pretrained( - args.pretrained_model_name_or_path if args.pretrain_unet is None else args.pretrain_unet, - subfolder="unet", - low_cpu_mem_usage=True, - ) - unet = UNetMotionModel.from_unet2d(unet, motion_adapter) - - # - replace_positional_embedding(unet, target_size=args.pe_size, target_module=args.pe_module) - - # Freeze vae and image_encoder - vae.requires_grad_(False) - text_encoder.requires_grad_(False) - unet.requires_grad_(False) - - # For mixed precision training we cast the text_encoder and vae weights to half-precision - # as these models are only used for inference, keeping weights in full precision is not required. - weight_dtype = torch.float32 - if accelerator.mixed_precision == "fp16": - weight_dtype = torch.float16 - elif accelerator.mixed_precision == "bf16": - weight_dtype = torch.bfloat16 - - # Move image_encoder and vae to gpu and cast to weight_dtype - # image_encoder.to(accelerator.device, dtype=weight_dtype) - vae.to(accelerator.device, dtype=weight_dtype) - text_encoder.to(accelerator.device, dtype=weight_dtype) - # unet.to(accelerator.device, dtype=weight_dtype) - - # Create EMA for the unet. - if args.use_ema: - ema_unet = EMAModel(unet.parameters( - ), model_cls=UNetSpatioTemporalConditionModel, model_config=unet.config) - - if args.enable_xformers_memory_efficient_attention: - if is_xformers_available(): - import xformers - - xformers_version = version.parse(xformers.__version__) - if xformers_version == version.parse("0.0.16"): - logger.warn( - "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." - ) - unet.enable_xformers_memory_efficient_attention() - else: - raise ValueError( - "xformers is not available. 
Make sure it is installed correctly") - - # `accelerate` 0.16.0 will have better support for customized saving - if version.parse(accelerate.__version__) >= version.parse("0.16.0"): - # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format - def save_model_hook(models, weights, output_dir): - if args.use_ema: - ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema")) - - for i, model in enumerate(models): - # if isinstance(model, UNetMotionModel): - save_positional_embeddings(model, os.path.join(output_dir, "pos_embed.pt")) - # model.save_pretrained(os.path.join(output_dir, "unet")) - - # make sure to pop weight so that corresponding model is not saved again - weights.pop() - - def load_model_hook(models, input_dir): - if args.use_ema: - load_model = EMAModel.from_pretrained(os.path.join( - input_dir, "unet_ema"), UNetSpatioTemporalConditionModel) - ema_unet.load_state_dict(load_model.state_dict()) - ema_unet.to(accelerator.device) - del load_model - - for i in range(len(models)): - # pop models so that they are not loaded again - model = models.pop() - - # load diffusers style into model - load_model = UNet2DConditionModel.from_pretrained( - input_dir, subfolder="unet") - model.register_to_config(**load_model.config) - - model.load_state_dict(load_model.state_dict()) - del load_model - - accelerator.register_save_state_pre_hook(save_model_hook) - accelerator.register_load_state_pre_hook(load_model_hook) - - if args.gradient_checkpointing: - unet.enable_gradient_checkpointing() - - # Enable TF32 for faster training on Ampere GPUs, - # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices - if args.allow_tf32: - torch.backends.cuda.matmul.allow_tf32 = True - - if args.scale_lr: - args.learning_rate = ( - args.learning_rate * args.gradient_accumulation_steps * - args.per_gpu_batch_size * accelerator.num_processes - ) - - # Initialize the optimizer - if args.use_8bit_adam: - try: - import bitsandbytes as bnb - except ImportError: - raise ImportError( - "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`" - ) - - optimizer_cls = bnb.optim.AdamW8bit - else: - optimizer_cls = torch.optim.AdamW - - unet.requires_grad_(True) - parameters_list = [] - - # Customize the parameters that need to be trained; if necessary, you can uncomment them yourself. 
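In this second script the 2D UNet is wrapped into a `UNetMotionModel` and `replace_positional_embedding` (again from `utils.pe_utils`, not shown in this diff) is applied above. Judging by the `pos_embed` parameter filter used below, that helper presumably installs learnable positional embeddings; a purely illustrative drop-in of that kind could look like:

```python
import torch
import torch.nn as nn

class LearnablePositionalEmbedding(nn.Module):
    """Illustrative learnable stand-in for a fixed (e.g. sinusoidal) positional embedding."""

    def __init__(self, embed_dim: int, max_seq_length: int = 32):
        super().__init__()
        # Named 'pos_embed' so a `'pos_embed' in name` filter would select it for training.
        self.pos_embed = nn.Parameter(torch.zeros(1, max_seq_length, embed_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, sequence, embed_dim)
        return x + self.pos_embed[:, : x.shape[1]]
```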
- for name, para in unet.named_parameters(): - if 'pos_embed' in name: - parameters_list.append(para) - para.requires_grad = True - else: - para.requires_grad = False - - - optimizer = optimizer_cls( - parameters_list, - lr=args.learning_rate, - betas=(args.adam_beta1, args.adam_beta2), - weight_decay=args.adam_weight_decay, - eps=args.adam_epsilon, - ) - - # optimizer = optimizer_cls( - # unet.parameters(), - # lr=args.learning_rate, - # betas=(args.adam_beta1, args.adam_beta2), - # weight_decay=args.adam_weight_decay, - # eps=args.adam_epsilon, - # ) - - # check parameters - if accelerator.is_main_process: - rec_txt1 = open('rec_para.txt', 'w') - rec_txt2 = open('rec_para_train.txt', 'w') - for name, para in unet.named_parameters(): - if para.requires_grad is False: - rec_txt1.write(f'{name}\n') - else: - rec_txt2.write(f'{name}\n') - rec_txt1.close() - rec_txt2.close() - - # DataLoaders creation: - args.global_batch_size = args.per_gpu_batch_size * accelerator.num_processes - - train_dataset = DummyDataset(base_folder=args.video_path, width=args.width, height=args.height, sample_frames=args.num_frames) - sampler = RandomSampler(train_dataset) - train_dataloader = torch.utils.data.DataLoader( - train_dataset, - sampler=sampler, - batch_size=args.per_gpu_batch_size, - num_workers=args.num_workers, - ) - - # Scheduler and math around the number of training steps. - overrode_max_train_steps = False - num_update_steps_per_epoch = math.ceil( - len(train_dataloader) / args.gradient_accumulation_steps) - if args.max_train_steps is None: - args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch - overrode_max_train_steps = True - - lr_scheduler = get_scheduler( - args.lr_scheduler, - optimizer=optimizer, - num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, - num_training_steps=args.max_train_steps * accelerator.num_processes, - ) - - # Prepare everything with our `accelerator`. - unet, optimizer, lr_scheduler, train_dataloader = accelerator.prepare( - unet, optimizer, lr_scheduler, train_dataloader - ) - - if args.use_ema: - ema_unet.to(accelerator.device) - - # We need to recalculate our total training steps as the size of the training dataloader may have changed. - num_update_steps_per_epoch = math.ceil( - len(train_dataloader) / args.gradient_accumulation_steps) - if overrode_max_train_steps: - args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch - # Afterwards we recalculate our number of training epochs - args.num_train_epochs = math.ceil( - args.max_train_steps / num_update_steps_per_epoch) - - # We need to initialize the trackers we use, and also store our configuration. - # The trackers initializes automatically on the main process. - if accelerator.is_main_process: - accelerator.init_trackers("SVDXtend", config=vars(args)) - - # Train! - total_batch_size = args.per_gpu_batch_size * \ - accelerator.num_processes * args.gradient_accumulation_steps - - logger.info("***** Running training *****") - logger.info(f" Num examples = {len(train_dataset)}") - logger.info(f" Num Epochs = {args.num_train_epochs}") - logger.info( - f" Instantaneous batch size per device = {args.per_gpu_batch_size}") - logger.info( - f" Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}") - logger.info( - f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") - logger.info(f" Total optimization steps = {args.max_train_steps}") - global_step = 0 - first_epoch = 0 - - # Potentially load in the weights and states from a previous save - if args.resume_from_checkpoint: - if args.resume_from_checkpoint != "latest": - path = os.path.basename(args.resume_from_checkpoint) - else: - # Get the most recent checkpoint - dirs = os.listdir(args.output_dir) - dirs = [d for d in dirs if d.startswith("checkpoint")] - dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) - path = dirs[-1] if len(dirs) > 0 else None - - if path is None: - accelerator.print( - f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." - ) - args.resume_from_checkpoint = None - else: - accelerator.print(f"Resuming from checkpoint {path}") - accelerator.load_state(os.path.join(args.output_dir, path)) - global_step = int(path.split("-")[1]) - - resume_global_step = global_step * args.gradient_accumulation_steps - first_epoch = global_step // num_update_steps_per_epoch - resume_step = resume_global_step % ( - num_update_steps_per_epoch * args.gradient_accumulation_steps) - - # Only show the progress bar once on each machine. - progress_bar = tqdm(range(global_step, args.max_train_steps), - disable=not accelerator.is_local_main_process) - progress_bar.set_description("Steps") - - for epoch in range(first_epoch, args.num_train_epochs): - unet.train() - train_loss = 0.0 - for step, batch in enumerate(train_dataloader): - # Skip steps until we reach the resumed step - if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: - if step % args.gradient_accumulation_steps == 0: - progress_bar.update(1) - continue - - with accelerator.accumulate(unet): - # first, convert images to latent space. - pixel_values = batch["pixel_values"].to(weight_dtype).to( - accelerator.device, non_blocking=True - ) - latents = tensor_to_vae_latent(pixel_values, vae) - # print(latents.shape) - - # Sample noise that we'll add to the latents - noise = torch.randn_like(latents) - bsz = latents.shape[0] - - timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) - timesteps = timesteps.long() - - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - - # Get the text embedding for conditioning - with torch.no_grad(): - prompt_ids = tokenizer( - batch['text'], max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt" - ).input_ids.to(latents.device) - encoder_hidden_states = text_encoder(prompt_ids)[0] - - if noise_scheduler.config.prediction_type == "epsilon": - target = noise - elif noise_scheduler.config.prediction_type == "v_prediction": - raise NotImplementedError - else: - raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") - - model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample - loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") - - # Gather the losses across all processes for logging (if we use distributed training). 
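The loss above is the standard epsilon-prediction objective: the scheduler mixes a clean latent with Gaussian noise at a sampled timestep and the UNet is regressed onto the injected noise. A toy, self-contained sketch of that relationship (random tensors and a made-up alpha schedule, no real model):

```python
import torch
import torch.nn.functional as F

latents = torch.randn(1, 4, 8, 8)                   # stand-in for a clean VAE latent
noise = torch.randn_like(latents)
alphas_cumprod = torch.linspace(0.999, 0.01, 1000)  # toy noise schedule
t = torch.randint(0, 1000, (1,))

a = alphas_cumprod[t].view(-1, 1, 1, 1)
noisy_latents = a.sqrt() * latents + (1 - a).sqrt() * noise  # essentially what add_noise does

model_pred = torch.randn_like(latents)   # stand-in for unet(noisy_latents, t, cond).sample
loss = F.mse_loss(model_pred.float(), noise.float())  # the target is the injected noise
print(loss.item())
```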
- avg_loss = accelerator.gather( - loss.repeat(args.per_gpu_batch_size)).mean() - train_loss += avg_loss.item() / args.gradient_accumulation_steps - - # Backpropagate - accelerator.backward(loss) - # if accelerator.sync_gradients: - # accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm) - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad() - - # Checks if the accelerator has performed an optimization step behind the scenes - if accelerator.sync_gradients: - if args.use_ema: - ema_unet.step(unet.parameters()) - progress_bar.update(1) - global_step += 1 - accelerator.log({"train_loss": train_loss}, step=global_step) - train_loss = 0.0 - - if accelerator.is_main_process: - # save checkpoints! - if global_step % args.checkpointing_steps == 0: - # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` - if args.checkpoints_total_limit is not None: - checkpoints = os.listdir(args.output_dir) - checkpoints = [ - d for d in checkpoints if d.startswith("checkpoint")] - checkpoints = sorted( - checkpoints, key=lambda x: int(x.split("-")[1])) - - # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints - if len(checkpoints) >= args.checkpoints_total_limit: - num_to_remove = len( - checkpoints) - args.checkpoints_total_limit + 1 - removing_checkpoints = checkpoints[0:num_to_remove] - - logger.info( - f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" - ) - logger.info( - f"removing checkpoints: {', '.join(removing_checkpoints)}") - - for removing_checkpoint in removing_checkpoints: - removing_checkpoint = os.path.join( - args.output_dir, removing_checkpoint) - shutil.rmtree(removing_checkpoint) - - save_path = os.path.join( - args.output_dir, f"checkpoint-{global_step}") - - accelerator.save_state(save_path) - logger.info(f"Saved state to {save_path}") - # sample images! - if ( - (global_step % args.validation_steps == 0) - or (global_step == 1) - ): - logger.info( - f"Running validation... \n Generating {args.num_validation_images} videos." - ) - # create pipeline - if args.use_ema: - # Store the UNet parameters temporarily and load the EMA parameters to perform inference. - ema_unet.store(unet.parameters()) - ema_unet.copy_to(unet.parameters()) - # The models need unwrapping because for compatibility in distributed training mode. 
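`ema_unet.step(...)` above maintains an exponential moving average of the trainable weights, and `store`/`copy_to`/`restore` swap that average in for validation and back out again. A toy illustration of the running average itself (the decay value is illustrative):

```python
import torch

param = torch.zeros(3)   # stand-in for a trainable weight tensor
shadow = param.clone()   # the EMA ("shadow") copy kept alongside it
decay = 0.999            # illustrative decay factor

for step in range(1, 4):
    param = param + 0.1                            # stand-in for an optimizer update
    shadow = decay * shadow + (1 - decay) * param  # EMA update after each step
    print(step, param.tolist(), shadow.tolist())
```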
- pipeline = AnimateDiffPipeline.from_pretrained( - args.pretrained_model_name_or_path, - unet=accelerator.unwrap_model(unet), - vae=accelerator.unwrap_model(vae), - text_encoder=accelerator.unwrap_model(text_encoder), - tokenizer=tokenizer, - image_encoder=None, - motion_adapter=None, - controlnet=None, - revision=args.revision, - torch_dtype=weight_dtype, - ) - pipeline.scheduler = DDIMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder='scheduler') - pipeline = pipeline.to(accelerator.device) - pipeline.set_progress_bar_config(disable=True) - - # run inference - val_save_dir = os.path.join( - args.output_dir, "validation_images") - - if not os.path.exists(val_save_dir): - os.makedirs(val_save_dir) - - with torch.autocast( - str(accelerator.device).replace(":0", ""), enabled=accelerator.mixed_precision == "fp16" - ): - # open a validation prompts - validation_prompts = read_text_prompts(args.validation_file) - # validation_prompts=["a man is running", "a cat is running", "a dog is running"] - for val_prompt_idx, val_prompt in enumerate(validation_prompts): - num_frames = args.num_frames - video_frames = pipeline( - prompt=val_prompt, - height=args.height, - width=args.width, - num_frames=num_frames, - guidance_scale=7.5, - num_inference_steps=25, - generator=generator.manual_seed(args.seed), - ).frames[0] - - prompt_name = val_prompt.replace(' ', '_') - out_file = os.path.join( - val_save_dir, - f"step_{global_step}_val_img_{prompt_name}.gif", - ) - export_to_gif(video_frames, out_file) - - if args.use_ema: - # Switch back to the original UNet parameters. - ema_unet.restore(unet.parameters()) - - del pipeline - torch.cuda.empty_cache() - - logs = {"step_loss": loss.detach().item( - ), "lr": lr_scheduler.get_last_lr()[0]} - progress_bar.set_postfix(**logs) - - if global_step >= args.max_train_steps: - break - - # Create the pipeline using the trained modules and save it. - accelerator.wait_for_everyone() - if accelerator.is_main_process: - unet = accelerator.unwrap_model(unet) - if args.use_ema: - ema_unet.copy_to(unet.parameters()) - - # pipelne = AnimateDiffPipeline.from_pretrained( - # args.pretrained_model_name_or_path, - # unet=accelerator.unwrap_model(unet), - # vae=accelerator.unwrap_model(vae), - # text_encoder=accelerator.unwrap_model(text_encoder), - # tokenizer=tokenizer, - # revision=args.revision, - # torch_dtype=weight_dtype, - # ) - # pipeline.save_pretrained(args.output_dir) - save_positional_embeddings(unet, os.path.join(args.output_dir, 'pos_embed.pt')) - - - # if args.push_to_hub: - # upload_folder( - # repo_id=repo_id, - # folder_path=args.output_dir, - # commit_message="End of training", - # ignore_patterns=["step_*", "epoch_*"], - # ) - accelerator.end_training() - - -if __name__ == "__main__": - main() diff --git a/pipelines/pipeline_animatediff.py b/pipelines/pipeline_animatediff.py deleted file mode 100644 index ca02cf5..0000000 --- a/pipelines/pipeline_animatediff.py +++ /dev/null @@ -1,1151 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -import inspect -from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Optional, Union - -import numpy as np -import torch -from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection -from diffusers.image_processor import PipelineImageInput, VaeImageProcessor -from diffusers.loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin -from diffusers.models import AutoencoderKL, UNet3DConditionModel, UNetMotionModel, UNet2DConditionModel -from diffusers.models.lora import adjust_lora_scale_text_encoder -from diffusers.models.unet_motion_model import MotionAdapter -from diffusers.schedulers import ( - DDIMScheduler, - DPMSolverMultistepScheduler, - EulerAncestralDiscreteScheduler, - EulerDiscreteScheduler, - LMSDiscreteScheduler, - PNDMScheduler, -) -from utils.freeinit_utils import * -from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers -from diffusers.utils.torch_utils import randn_tensor, is_compiled_module -# from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionControlNetPipeline -from diffusers import StableDiffusionControlNetPipeline, ControlNetModel -from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - -EXAMPLE_DOC_STRING = """ - Examples: - ```py - >>> import torch - >>> from diffusers import MotionAdapter, AnimateDiffPipeline, DDIMScheduler - >>> from diffusers.utils import export_to_gif - - >>> adapter = MotionAdapter.from_pretrained("diffusers/motion-adapter") - >>> pipe = AnimateDiffPipeline.from_pretrained("frankjoshua/toonyou_beta6", motion_adapter=adapter) - >>> pipe.scheduler = DDIMScheduler(beta_schedule="linear", steps_offset=1, clip_sample=False) - >>> output = pipe(prompt="A corgi walking in the park") - >>> frames = output.frames[0] - >>> export_to_gif(frames, "animation.gif") - ``` -""" - -import os -import cv2 -from typing import Any, Callable, Dict, List, Optional, Tuple, Union -import PIL.Image -import PIL.ImageOps -import numpy as np -import matplotlib.pyplot as plt - -def build_curve_tensor(max_value, min_value, length, frames, strategy='linear'): - """ - Build a curve based on the given strategy and return it as a PyTorch tensor. - The curve starts from the min_value and increases to the max_value. - - Parameters: - max_value (float): The maximum value of the curve. - min_value (float): The minimum value of the curve. - length (int): The length over which the curve changes from min to max. - frames (int): The total number of frames in the curve. - strategy (str): The strategy for building the curve. Options: 'linear', 'exponential', 'logarithmic'. - - Returns: - torch.Tensor: A tensor representing the curve. 
- """ - - if strategy == 'linear': - # Linear increase from min_value to max_value over 'length' frames, then constant - curve = np.linspace(max_value, min_value, length) - curve = np.pad(curve, (0, frames - length), mode='constant', constant_values=min_value) - - elif strategy == 'exponential': - # Exponential increase from min_value to max_value - curve = np.geomspace(max_value, min_value, length) - curve = np.pad(curve, (0, frames - length), mode='constant', constant_values=min_value) - - elif strategy == 'logarithmic': - # Logarithmic increase from min_value to max_value - log_space = np.linspace(1, length + 1, length) - curve = (np.log(log_space) / np.log(length + 1)) * (min_value - max_value) + min_value - curve = np.pad(curve, (0, frames - length), mode='constant', constant_values=min_value) - - else: - raise ValueError("Unknown strategy: Choose from 'linear', 'exponential', 'logarithmic'") - - curve_tensor = torch.from_numpy(curve) - # Plot the curve - fig, ax = plt.subplots() - ax.plot(curve, label=strategy) - ax.set_title('Curve Visualization') - ax.set_xlabel('Frame') - ax.set_ylabel('Value') - ax.legend() - - return curve_tensor, fig - -def load_video(video_path: str) -> List[PIL.Image.Image]: - """ - Loads a video from the given path and returns a list of its frames as PIL Images. - - Args: - video_path (str): - The path to the video file. - Returns: - List[PIL.Image.Image]: - A list of frames as PIL Images. - """ - if not os.path.isfile(video_path): - raise ValueError(f"{video_path} is not a valid path to a video file.") - - # Open the video file - cap = cv2.VideoCapture(video_path) - frames = [] - - while True: - # Read each frame - ret, frame = cap.read() - if not ret: - break - - # Convert the frame to RGB format and then to a PIL Image - frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) - frame = PIL.Image.fromarray(frame) - frames.append(frame) - - cap.release() - return frames - -def tensor2vid(video: torch.Tensor, processor, output_type="np"): - # Based on: - # https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78 - - batch_size, channels, num_frames, height, width = video.shape - outputs = [] - for batch_idx in range(batch_size): - batch_vid = video[batch_idx].permute(1, 0, 2, 3) - batch_output = processor.postprocess(batch_vid, output_type) - - outputs.append(batch_output) - - return outputs - -def retrieve_latents( - encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" - ): - if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": - return encoder_output.latent_dist.sample(generator) - elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": - return encoder_output.latent_dist.mode() - elif hasattr(encoder_output, "latents"): - return encoder_output.latents - else: - raise AttributeError("Could not access latents of provided encoder_output") - -def retrieve_timesteps( - scheduler, - num_inference_steps: Optional[int] = None, - device: Optional[Union[str, torch.device]] = None, - timesteps: Optional[List[int]] = None, - **kwargs, -): - """ - Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles - custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. - - Args: - scheduler (`SchedulerMixin`): - The scheduler to get timesteps from. 
- num_inference_steps (`int`): - The number of diffusion steps used when generating samples with a pre-trained model. If used, - `timesteps` must be `None`. - device (`str` or `torch.device`, *optional*): - The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. - timesteps (`List[int]`, *optional*): - Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default - timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` - must be `None`. - - Returns: - `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the - second element is the number of inference steps. - """ - if timesteps is not None: - accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) - if not accepts_timesteps: - raise ValueError( - f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" - f" timestep schedules. Please check whether you are using the correct scheduler." - ) - scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) - timesteps = scheduler.timesteps - num_inference_steps = len(timesteps) - else: - scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) - timesteps = scheduler.timesteps - return timesteps, num_inference_steps - -@dataclass -class VideoDiffPipelineOutput(BaseOutput): - frames: Union[torch.Tensor, np.ndarray] - -class VideoDiffPipeline(StableDiffusionControlNetPipeline): - r""" - Pipeline for text-to-video generation. - - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods - implemented for all pipelines (downloading, saving, running on a particular device, etc.). - - Args: - vae ([`AutoencoderKL`]): - Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. - text_encoder ([`CLIPTextModel`]): - Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). - tokenizer (`CLIPTokenizer`): - A [`~transformers.CLIPTokenizer`] to tokenize text. - unet ([`UNet2DConditionModel`]): - A [`UNet2DConditionModel`] used to create a UNetMotionModel to denoise the encoded video latents. - motion_adapter ([`MotionAdapter`]): - A [`MotionAdapter`] to be used in combination with `unet` to denoise the encoded video latents. - scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of - [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. 
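As an editor's sketch of how the components listed above are assembled (the checkpoint IDs and the import path are illustrative assumptions, not taken from this repository):

```py
import torch
from diffusers import AutoencoderKL, ControlNetModel, DDIMScheduler, UNet2DConditionModel
from diffusers.models.unet_motion_model import MotionAdapter
from transformers import CLIPTextModel, CLIPTokenizer

from pipelines.pipeline_animation import VideoDiffPipeline  # hypothetical module path

base = "runwayml/stable-diffusion-v1-5"  # illustrative SD 1.5 checkpoint
pipe = VideoDiffPipeline(
    vae=AutoencoderKL.from_pretrained(base, subfolder="vae"),
    text_encoder=CLIPTextModel.from_pretrained(base, subfolder="text_encoder"),
    tokenizer=CLIPTokenizer.from_pretrained(base, subfolder="tokenizer"),
    unet=UNet2DConditionModel.from_pretrained(base, subfolder="unet"),
    motion_adapter=MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2"),
    controlnet=ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny"),
    scheduler=DDIMScheduler.from_pretrained(base, subfolder="scheduler"),
).to("cuda")
```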
- """ - - model_cpu_offload_seq = "text_encoder->unet->vae" - _optional_components = ["feature_extractor", "image_encoder"] - - def __init__( - self, - vae: AutoencoderKL, - text_encoder: CLIPTextModel, - tokenizer: CLIPTokenizer, - unet: UNet2DConditionModel, - motion_adapter: MotionAdapter, - controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel], - scheduler: Union[ - DDIMScheduler, - PNDMScheduler, - LMSDiscreteScheduler, - EulerDiscreteScheduler, - EulerAncestralDiscreteScheduler, - DPMSolverMultistepScheduler, - ], - feature_extractor: CLIPImageProcessor = None, - image_encoder: CLIPVisionModelWithProjection = None, - use_motion_mid_block = False - ): - # super().__init__() - if motion_adapter is not None: - unet.config['use_motion_mid_block'] = use_motion_mid_block - motion_adapter.config['use_motion_mid_block'] = use_motion_mid_block - unet = UNetMotionModel.from_unet2d(unet, motion_adapter) - - self.register_modules( - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - unet=unet, - controlnet=controlnet, - motion_adapter=motion_adapter, - scheduler=scheduler, - feature_extractor=feature_extractor, - image_encoder=image_encoder, - ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) - self.control_image_processor = VaeImageProcessor( - vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False - ) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt - def encode_prompt( - self, - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt=None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - lora_scale: Optional[float] = None, - clip_skip: Optional[int] = None, - ): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is - less than `1`). - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - lora_scale (`float`, *optional*): - A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. - clip_skip (`int`, *optional*): - Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that - the output of the pre-final layer will be used for computing the prompt embeddings. 
- """ - # set lora scale so that monkey patched LoRA - # function of text encoder can correctly access it - if lora_scale is not None and isinstance(self, LoraLoaderMixin): - self._lora_scale = lora_scale - - # dynamically adjust the LoRA scale - if not USE_PEFT_BACKEND: - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) - else: - scale_lora_layers(self.text_encoder, lora_scale) - - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - if prompt_embeds is None: - # textual inversion: procecss multi-vector tokens if necessary - if isinstance(self, TextualInversionLoaderMixin): - prompt = self.maybe_convert_prompt(prompt, self.tokenizer) - - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( - text_input_ids, untruncated_ids - ): - removed_text = self.tokenizer.batch_decode( - untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] - ) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" - ) - - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = text_inputs.attention_mask.to(device) - else: - attention_mask = None - - if clip_skip is None: - prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) - prompt_embeds = prompt_embeds[0] - else: - prompt_embeds = self.text_encoder( - text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True - ) - # Access the `hidden_states` first, that contains a tuple of - # all the hidden states from the encoder layers. Then index into - # the tuple to access the hidden states from the desired layer. - prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] - # We also need to apply the final LayerNorm here to not mess with the - # representations. The `last_hidden_states` that we typically use for - # obtaining the final prompt representations passes through the LayerNorm - # layer. 
- prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) - - if self.text_encoder is not None: - prompt_embeds_dtype = self.text_encoder.dtype - elif self.unet is not None: - prompt_embeds_dtype = self.unet.dtype - else: - prompt_embeds_dtype = prompt_embeds.dtype - - prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) - - bs_embed, seq_len, _ = prompt_embeds.shape - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) - - # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance and negative_prompt_embeds is None: - uncond_tokens: List[str] - if negative_prompt is None: - uncond_tokens = [""] * batch_size - elif prompt is not None and type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif isinstance(negative_prompt, str): - uncond_tokens = [negative_prompt] - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) - else: - uncond_tokens = negative_prompt - - # textual inversion: procecss multi-vector tokens if necessary - if isinstance(self, TextualInversionLoaderMixin): - uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) - - max_length = prompt_embeds.shape[1] - uncond_input = self.tokenizer( - uncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", - ) - - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = uncond_input.attention_mask.to(device) - else: - attention_mask = None - - negative_prompt_embeds = self.text_encoder( - uncond_input.input_ids.to(device), - attention_mask=attention_mask, - ) - negative_prompt_embeds = negative_prompt_embeds[0] - - if do_classifier_free_guidance: - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = negative_prompt_embeds.shape[1] - - negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) - - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder, lora_scale) - - return prompt_embeds, negative_prompt_embeds - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image - def encode_image(self, image, device, num_images_per_prompt): - dtype = next(self.image_encoder.parameters()).dtype - - if not isinstance(image, torch.Tensor): - image = self.feature_extractor(image, return_tensors="pt").pixel_values - - image = image.to(device=device, dtype=dtype) - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - - uncond_image_embeds = torch.zeros_like(image_embeds) - 
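        # (Editor's note) The "unconditional" IP-Adapter embedding is simply an all-zeros tensor with
        # the same shape as `image_embeds`; in __call__ the two are concatenated as
        # torch.cat([negative_image_embeds, image_embeds]) so the classifier-free-guidance split
        # pairs the null image condition with the negative prompt.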
return image_embeds, uncond_image_embeds - - # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents - def decode_latents(self, latents): - latents = 1 / self.vae.config.scaling_factor * latents - - batch_size, channels, num_frames, height, width = latents.shape - latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) - - image = self.vae.decode(latents).sample - video = ( - image[None, :] - .reshape( - ( - batch_size, - num_frames, - -1, - ) - + image.shape[2:] - ) - .permute(0, 2, 1, 3, 4) - ) - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 - video = video.float() - return video - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing - def enable_vae_slicing(self): - r""" - Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to - compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. - """ - self.vae.enable_slicing() - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing - def disable_vae_slicing(self): - r""" - Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to - computing decoding in one step. - """ - self.vae.disable_slicing() - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling - def enable_vae_tiling(self): - r""" - Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to - compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow - processing larger images. - """ - self.vae.enable_tiling() - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling - def disable_vae_tiling(self): - r""" - Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to - computing decoding in one step. - """ - self.vae.disable_tiling() - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu - def enable_freeu(self, s1: float, s2: float, b1: float, b2: float): - r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497. - - The suffixes after the scaling factors represent the stages where they are being applied. - - Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values - that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. - - Args: - s1 (`float`): - Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to - mitigate "oversmoothing effect" in the enhanced denoising process. - s2 (`float`): - Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to - mitigate "oversmoothing effect" in the enhanced denoising process. - b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. - b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. 
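For concreteness, an editor's sketch using the values the FreeU authors recommend for Stable Diffusion v1.5; treat them as a starting point, and assume `pipe` is an instantiated pipeline:

```py
pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.5, b2=1.6)  # SD 1.5 settings from the FreeU repository
# ... run inference ...
pipe.disable_freeu()  # restore the unmodified UNet behaviour
```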
- """ - if not hasattr(self, "unet"): - raise ValueError("The pipeline must have `unet` for using FreeU.") - self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu - def disable_freeu(self): - """Disables the FreeU mechanism if enabled.""" - self.unet.disable_freeu() - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs - def prepare_extra_step_kwargs(self, generator, eta): - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. - # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) - if accepts_generator: - extra_step_kwargs["generator"] = generator - return extra_step_kwargs - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs - def check_inputs( - self, - prompt, - height, - width, - callback_steps, - negative_prompt=None, - prompt_embeds=None, - negative_prompt_embeds=None, - callback_on_step_end_tensor_inputs=None, - ): - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - - if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." - ) - if callback_on_step_end_tensor_inputs is not None and not all( - k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs - ): - raise ValueError( - f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" - ) - - if prompt is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt is None and prompt_embeds is None: - raise ValueError( - "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." - ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if negative_prompt is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" - f" {negative_prompt_embeds}. Please make sure to only forward one of the two." - ) - - if prompt_embeds is not None and negative_prompt_embeds is not None: - if prompt_embeds.shape != negative_prompt_embeds.shape: - raise ValueError( - "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" - f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" - f" {negative_prompt_embeds.shape}." 
- ) - - # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents - def prepare_latents( - self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None - ): - shape = ( - batch_size, - num_channels_latents, - num_frames, - height // self.vae_scale_factor, - width // self.vae_scale_factor, - ) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - - if latents is None: - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - else: - latents = latents.to(device) - - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma - return latents - - def get_timesteps(self, num_inference_steps, strength, device, inverse=False): - # get the original timestep using init_timestep - init_timestep = min(int(num_inference_steps * strength), num_inference_steps) - - t_start = max(num_inference_steps - init_timestep, 0) - if inverse: - if t_start == 0: - timesteps = self.scheduler.timesteps - else: - timesteps = self.scheduler.timesteps[:-t_start * self.scheduler.order] - - else: - timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] - - return timesteps, num_inference_steps - t_start - - @torch.no_grad() - def encode_frames(self, frames, device): - - dtype = next(self.vae.parameters()).dtype - frames = self.image_processor.preprocess(frames).to(device=device, dtype=dtype) - latent_frames = self.vae.encode(frames).latent_dist.sample() * self.vae.config.scaling_factor - num_frames, channels, height, width = latent_frames.shape - latents_frames = latent_frames.reshape(1, num_frames, channels, height, width).permute(0,2,1,3,4) - - return latents_frames - - def add_noise_to_latents(self, init_latents, strength, generator=None, num_inference_steps=50): - # dtype - device = init_latents.device - dtype = init_latents.dtype - - # set timesteps - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps=None) - timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device, inverse=False) - latent_timestep = timesteps[:1].repeat(init_latents.shape[0]) - - # sample noise - shape = init_latents.shape - init_noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - - init_latents = self.scheduler.add_noise(init_latents, init_noise, latent_timestep) - latents = init_latents - return latents, init_noise - - def freeinit(self, latents_video, latents_noise, device, dtype, generator): - - # initial_noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=dtype) - - current_diffuse_timestep = self.scheduler.config.num_train_timesteps - 1 # diffuse to t=999 noise level - - diffuse_timesteps = torch.full((latents_noise.shape[0],),int(current_diffuse_timestep)) - diffuse_timesteps = diffuse_timesteps.long() - - z_T = self.scheduler.add_noise( - original_samples=latents_video.to(device), - noise=latents_noise.to(device), - timesteps=diffuse_timesteps.to(device) - ) - - latents = freq_mix_3d(z_T.to(dtype=torch.float32), latents_noise.to(dtype=torch.float32), LPF=self.freq_filter) - latents = latents.to(dtype) - - return latents - - @torch.no_grad() - def 
init_filter( - self, - video_length, - height, - width, - filter_params_method="gaussian", - filter_params_n=4, - filter_params_d_s=0.25, - filter_params_d_t=0.25 - ): - # initialize frequency filter for noise reinitialization - batch_size = 1 - num_channels_latents = self.unet.in_channels - filter_shape = [ - batch_size, - num_channels_latents, - video_length, - height // self.vae_scale_factor, - width // self.vae_scale_factor - ] - # self.freq_filter = get_freq_filter(filter_shape, device=self._execution_device, params=filter_params) - self.freq_filter = get_freq_filter( - filter_shape, - device=self._execution_device, - filter_type=filter_params_method, - n=filter_params_n if filter_params_method=="butterworth" else None, - d_s=filter_params_d_s, - d_t=filter_params_d_t - ) - - @torch.no_grad() - def __call__( - self, - prompt: Union[str, List[str]] = None, - num_frames: Optional[int] = 16, - height: Optional[int] = None, - width: Optional[int] = None, - num_inference_steps: int = 50, - timesteps: List[int] = None, - guidance_scale: float = 7.5, - strength: float = 1.0, - rect_scheduled_sampling_beta=1.0, - negative_prompt: Optional[Union[str, List[str]]] = None, - num_videos_per_prompt: Optional[int] = 1, - eta: float = 0.0, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.FloatTensor] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - ip_adapter_image: Optional[PipelineImageInput] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: Optional[int] = 1, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - clip_skip: Optional[int] = None, - inverse=False, - frames_video=None, # controlnet_frames - noise_rect=None, # noise rectification - frames_controlnet=None, # controlnet_frames - frames_inpaint=None, # blend_diffusion - noise_inpaint=None, # blend_diffusion - mask_inpaint=None, # mask - controlnet_conditioning_scale: Union[float, List[float]] = 1.0, - guess_mode: bool = False, - control_guidance_start: Union[float, List[float]] = 0.0, - control_guidance_end: Union[float, List[float]] = 1.0, - freeinit=False, - ): - r""" - The call function to the pipeline for generation. - - Args: - prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. - height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): - The height in pixels of the generated video. - width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): - The width in pixels of the generated video. - num_frames (`int`, *optional*, defaults to 16): - The number of video frames that are generated. Defaults to 16 frames which at 8 frames per seconds - amounts to 2 seconds of video. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality videos at the - expense of slower inference. - guidance_scale (`float`, *optional*, defaults to 7.5): - A higher guidance scale value encourages the model to generate images closely linked to the text - `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. 
- negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide what to not include in image generation. If not defined, you need to - pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies - to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. - generator (`torch.Generator` or `List[torch.Generator]`, *optional*): - A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make - generation deterministic. - latents (`torch.FloatTensor`, *optional*): - Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor is generated by sampling using the supplied random `generator`. Latents should be of shape - `(batch_size, num_channel, num_frames, height, width)`. - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not - provided, text embeddings are generated from the `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If - not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. - ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generated video. Choose between `torch.FloatTensor`, `PIL.Image` or - `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] instead - of a plain tuple. - callback (`Callable`, *optional*): - A function that calls every `callback_steps` steps during inference. The function is called with the - following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function is called. If not specified, the callback is called at - every step. - cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in - [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). - clip_skip (`int`, *optional*): - Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that - the output of the pre-final layer will be used for computing the prompt embeddings. - Examples: - - Returns: - [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] or `tuple`: - If `return_dict` is `True`, [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] is - returned, otherwise a `tuple` is returned where the first element is a list with the generated frames. 
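The `Examples:` section above was left empty; here is an editor's sketch of a basic call, mirroring the EXAMPLE_DOC_STRING at the top of the file (`pipe` is assumed to be a VideoDiffPipeline assembled as in the constructor sketch earlier):

```py
>>> import torch
>>> from diffusers.utils import export_to_gif

>>> output = pipe(
...     prompt="A corgi walking in the park",
...     num_frames=16,
...     num_inference_steps=25,
...     guidance_scale=7.5,
...     generator=torch.Generator("cpu").manual_seed(0),
... )
>>> frames = output.frames[0]
>>> export_to_gif(frames, "animation.gif")
```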
- """ - # # - controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet - - # align format for control guidance - if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): - control_guidance_start = len(control_guidance_end) * [control_guidance_start] - elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): - control_guidance_end = len(control_guidance_start) * [control_guidance_end] - elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): - mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 - control_guidance_start, control_guidance_end = ( - mult * [control_guidance_start], - mult * [control_guidance_end], - ) - - # 0. Default height and width to unet - height = height or self.unet.config.sample_size * self.vae_scale_factor - width = width or self.unet.config.sample_size * self.vae_scale_factor - - num_videos_per_prompt = 1 - - # 1. Check inputs. Raise error if not correct - self.check_inputs( - prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds - ) - - # 2. Define call parameters - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - device = self._execution_device - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 - - - if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): - controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) - - # global_pool_conditions = ( - # controlnet.config.global_pool_conditions - # if isinstance(controlnet, ControlNetModel) - # else controlnet.nets[0].config.global_pool_conditions - # ) - # guess_mode = guess_mode or global_pool_conditions - - # 3. Encode input prompt - text_encoder_lora_scale = ( - cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None - ) - prompt_embeds, negative_prompt_embeds = self.encode_prompt( - prompt, - device, - num_videos_per_prompt, - do_classifier_free_guidance, - negative_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - lora_scale=text_encoder_lora_scale, - clip_skip=clip_skip, - ) - # For classifier free guidance, we need to do two forward passes. 
- # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - if do_classifier_free_guidance: - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - - if ip_adapter_image is not None: - image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_videos_per_prompt) - if do_classifier_free_guidance: - image_embeds = torch.cat([negative_image_embeds, image_embeds]) - - # 3.1 Prepare frames for controlnet - if isinstance(controlnet, ControlNetModel) and frames_controlnet is not None: - frames_preprocessed = [] - for frame in frames_controlnet: - frame = self.prepare_image( - image=frame, - width=width, - height=height, - batch_size=batch_size * num_videos_per_prompt, - num_images_per_prompt=num_videos_per_prompt, - device=device, - dtype=controlnet.dtype, - do_classifier_free_guidance=do_classifier_free_guidance, - guess_mode=guess_mode, - ) - frames_preprocessed.append(frame) - frames_controlnet = torch.cat(frames_preprocessed, dim=0) - height, width = frames_controlnet.shape[-2:] - - # elif isinstance(controlnet, MultiControlNetModel): - # images = [] - - # for image_ in image: - # image_ = self.prepare_image( - # image=image_, - # width=width, - # height=height, - # batch_size=batch_size * num_videos_per_prompt, - # num_images_per_prompt=num_videos_per_prompt, - # device=device, - # dtype=controlnet.dtype, - # do_classifier_free_guidance=do_classifier_free_guidance, - # guess_mode=guess_mode, - # ) - - # images.append(image_) - - # image = images - # height, width = image[0].shape[-2:] - else: - # assert False - pass - - # 5. set timesteps - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) - timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device, inverse) - latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt) - - num_rect_steps = int(rect_scheduled_sampling_beta * len(timesteps)) - - # 6. Prepare latent variables - num_channels_latents = self.unet.config.in_channels - latents = self.prepare_latents( - batch_size * num_videos_per_prompt, - num_channels_latents, - num_frames, - height, - width, - prompt_embeds.dtype, - device, - generator, - latents, - ) - - # 6. Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline - extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - # 7 Add image embeds for IP-Adapter - added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None - - # 7.2 Create tensor stating which controlnets to keep - controlnet_keep = [] - for i in range(len(timesteps)): - keeps = [ - 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) - for s, e in zip(control_guidance_start, control_guidance_end) - ] - controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) - - # motion_adatper - if self.unet.__class__.__name__ == 'UNet2DConditionModel': - # [bs, channels, frames, height ,width] -> [bs * frames, channels, height, width] - latents = latents.permute(0, 2, 1, 3, 4).reshape((latents.shape[0] * num_frames, -1) + latents.shape[3:]) - # [bs, 77, 768] -> [bs * frames, channels, height, width] - prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0) - - if freeinit: - latents = self.freeinit( - latents_video=frames_video, - latents_noise=latents, - device=device, - dtype=prompt_embeds.dtype, - generator=generator - ) - - # Denoising loop - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - # # controlnet(s) inference - # reshape latent - if frames_controlnet is not None: - if self.motion_adapter is None: - control_model_input = latent_model_input - controlnet_prompt_embeds = prompt_embeds - else: - control_model_input = latent_model_input.permute(0, 2, 1, 3, 4).reshape((latent_model_input.shape[0] * num_frames, -1) + latent_model_input.shape[3:]) - controlnet_prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0) - - if isinstance(controlnet_keep[i], list): - cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] - else: - controlnet_cond_scale = controlnet_conditioning_scale - if isinstance(controlnet_cond_scale, list): - controlnet_cond_scale = controlnet_cond_scale[0] - cond_scale = controlnet_cond_scale * controlnet_keep[i] - - down_block_res_samples, mid_block_res_sample = self.controlnet( - control_model_input, - t, - encoder_hidden_states=controlnet_prompt_embeds, - controlnet_cond=frames_controlnet, - conditioning_scale=cond_scale, - guess_mode=guess_mode, - return_dict=False, - ) - else: - down_block_res_samples, mid_block_res_sample = None, None - - # if guess_mode and self.do_classifier_free_guidance: - # # Infered ControlNet only for the conditional batch. - # # To apply the output of ControlNet to both the unconditional and conditional batches, - # # add 0 to the unconditional batch to keep it unchanged. 
-                #     down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
-                #     mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
-
-                # predict the noise residual
-                if isinstance(self.unet, UNet3DConditionModel):
-                    noise_pred = self.unet(
-                        latent_model_input,
-                        t,
-                        encoder_hidden_states=prompt_embeds,
-                        cross_attention_kwargs=cross_attention_kwargs,
-                    ).sample
-                else:
-                    noise_pred = self.unet(
-                        latent_model_input,
-                        t,
-                        encoder_hidden_states=prompt_embeds,
-                        cross_attention_kwargs=cross_attention_kwargs,
-                        down_block_additional_residuals=down_block_res_samples,
-                        mid_block_additional_residual=mid_block_res_sample,
-                        added_cond_kwargs=added_cond_kwargs,
-                    ).sample
-
-                # perform guidance
-                if do_classifier_free_guidance:
-                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-
-                # noise rectification
-                if noise_rect is not None and i <= num_rect_steps:
-                    delta = noise_rect - noise_pred
-                    noise_pred = noise_pred + mask_inpaint * delta[:, :, 0, :, :].unsqueeze(2).repeat(1, 1, num_frames, 1, 1) + (1 - mask_inpaint) * delta
-
-                # compute the previous noisy sample x_t -> x_t-1
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
-
-                if frames_inpaint is not None:
-                    init_latents_proper = frames_inpaint  # [batch_size, channels, frames, height, width]
-                    if i < len(timesteps) - 1:
-                        noise_timestep = timesteps[i + 1]
-                        init_latents_proper = self.scheduler.add_noise(
-                            init_latents_proper, noise_inpaint, torch.tensor([noise_timestep])
-                        )
-
-                    latents = mask_inpaint * latents + (1 - mask_inpaint) * init_latents_proper
-
-                # call the callback, if provided
-                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
-                    progress_bar.update()
-                    if callback is not None and i % callback_steps == 0:
-                        callback(i, t, latents)
-
-        if self.unet.__class__.__name__ == 'UNet2DConditionModel':
-            latents = latents.reshape((1, num_frames) + latents.shape[-3:]).permute(0, 2, 1, 3, 4)
-
-        if output_type == "latent":
-            return VideoDiffPipelineOutput(frames=latents)
-
-        # Post-processing
-        video_tensor = self.decode_latents(latents)
-
-        if output_type == "pt":
-            video = video_tensor
-        else:
-            video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)
-
-        # Offload all models
-        self.maybe_free_model_hooks()
-
-        if not return_dict:
-            return (video,)
-
-        return VideoDiffPipelineOutput(frames=video)
diff --git a/pipelines/pipeline_stable_video_diffusion.py b/pipelines/pipeline_stable_video_diffusion.py
deleted file mode 100644
index 0e1d4ca..0000000
--- a/pipelines/pipeline_stable_video_diffusion.py
+++ /dev/null
@@ -1,1080 +0,0 @@
-# Copyright 2023 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
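Editor's note before the second deleted file begins: the noise-rectification step in the denoising loop above pulls the UNet's prediction toward a reference noise tensor, applying the frame-0 correction everywhere inside the mask and each frame's own correction outside it. A minimal, self-contained sketch of that blend (shapes and the mask are illustrative):

```py
import torch

B, C, F, H, W = 1, 4, 16, 32, 32
noise_pred = torch.randn(B, C, F, H, W)   # prediction from the UNet
noise_rect = torch.randn(B, C, F, H, W)   # reference noise to rectify toward
mask = torch.zeros(B, 1, F, H, W)         # 1 where the frame-0 correction is shared across frames
mask[..., 8:24, 8:24] = 1.0

delta = noise_rect - noise_pred
frame0_delta = delta[:, :, 0:1].repeat(1, 1, F, 1, 1)  # broadcast frame-0 correction to all frames
rectified = noise_pred + mask * frame0_delta + (1 - mask) * delta
# Outside the mask the result equals noise_rect exactly; inside it, every frame shares frame 0's correction.
```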
- -import inspect -from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Optional, Union - -import numpy as np -import torch -from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection -from diffusers.image_processor import PipelineImageInput, VaeImageProcessor -from diffusers.loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin -from diffusers.models import AutoencoderKL, UNet3DConditionModel -from diffusers.models.lora import adjust_lora_scale_text_encoder -from diffusers.models.unet_motion_model import MotionAdapter -from diffusers.schedulers import ( - DDIMScheduler, - DPMSolverMultistepScheduler, - EulerAncestralDiscreteScheduler, - EulerDiscreteScheduler, - LMSDiscreteScheduler, - PNDMScheduler, -) -from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers -from diffusers.utils.torch_utils import randn_tensor, is_compiled_module -# from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionControlNetPipeline -from diffusers import StableDiffusionControlNetPipeline, ControlNetModel -from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - -EXAMPLE_DOC_STRING = """ - Examples: - ```py - >>> import torch - >>> from diffusers import MotionAdapter, AnimateDiffPipeline, DDIMScheduler - >>> from diffusers.utils import export_to_gif - - >>> adapter = MotionAdapter.from_pretrained("diffusers/motion-adapter") - >>> pipe = AnimateDiffPipeline.from_pretrained("frankjoshua/toonyou_beta6", motion_adapter=adapter) - >>> pipe.scheduler = DDIMScheduler(beta_schedule="linear", steps_offset=1, clip_sample=False) - >>> output = pipe(prompt="A corgi walking in the park") - >>> frames = output.frames[0] - >>> export_to_gif(frames, "animation.gif") - ``` -""" - -import os -import cv2 -from typing import Any, Callable, Dict, List, Optional, Tuple, Union -import PIL.Image -import PIL.ImageOps -import numpy as np -import matplotlib.pyplot as plt - -def build_curve_tensor(max_value, min_value, length, frames, strategy='linear'): - """ - Build a curve based on the given strategy and return it as a PyTorch tensor. - The curve starts from the min_value and increases to the max_value. - - Parameters: - max_value (float): The maximum value of the curve. - min_value (float): The minimum value of the curve. - length (int): The length over which the curve changes from min to max. - frames (int): The total number of frames in the curve. - strategy (str): The strategy for building the curve. Options: 'linear', 'exponential', 'logarithmic'. - - Returns: - torch.Tensor: A tensor representing the curve. 
- """ - - if strategy == 'linear': - # Linear increase from min_value to max_value over 'length' frames, then constant - curve = np.linspace(max_value, min_value, length) - curve = np.pad(curve, (0, frames - length), mode='constant', constant_values=min_value) - - elif strategy == 'exponential': - # Exponential increase from min_value to max_value - curve = np.geomspace(max_value, min_value, length) - curve = np.pad(curve, (0, frames - length), mode='constant', constant_values=min_value) - - elif strategy == 'logarithmic': - # Logarithmic increase from min_value to max_value - log_space = np.linspace(1, length + 1, length) - curve = (np.log(log_space) / np.log(length + 1)) * (min_value - max_value) + min_value - curve = np.pad(curve, (0, frames - length), mode='constant', constant_values=min_value) - - else: - raise ValueError("Unknown strategy: Choose from 'linear', 'exponential', 'logarithmic'") - - curve_tensor = torch.from_numpy(curve) - # Plot the curve - fig, ax = plt.subplots() - ax.plot(curve, label=strategy) - ax.set_title('Curve Visualization') - ax.set_xlabel('Frame') - ax.set_ylabel('Value') - ax.legend() - - return curve_tensor, fig - -def load_video(video_path: str) -> List[PIL.Image.Image]: - """ - Loads a video from the given path and returns a list of its frames as PIL Images. - - Args: - video_path (str): - The path to the video file. - Returns: - List[PIL.Image.Image]: - A list of frames as PIL Images. - """ - if not os.path.isfile(video_path): - raise ValueError(f"{video_path} is not a valid path to a video file.") - - # Open the video file - cap = cv2.VideoCapture(video_path) - frames = [] - - while True: - # Read each frame - ret, frame = cap.read() - if not ret: - break - - # Convert the frame to RGB format and then to a PIL Image - frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) - frame = PIL.Image.fromarray(frame) - frames.append(frame) - - cap.release() - return frames - -def tensor2vid(video: torch.Tensor, processor, output_type="np"): - # Based on: - # https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78 - - batch_size, channels, num_frames, height, width = video.shape - outputs = [] - for batch_idx in range(batch_size): - batch_vid = video[batch_idx].permute(1, 0, 2, 3) - batch_output = processor.postprocess(batch_vid, output_type) - - outputs.append(batch_output) - - return outputs - -def retrieve_latents( - encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" - ): - if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": - return encoder_output.latent_dist.sample(generator) - elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": - return encoder_output.latent_dist.mode() - elif hasattr(encoder_output, "latents"): - return encoder_output.latents - else: - raise AttributeError("Could not access latents of provided encoder_output") - -def retrieve_timesteps( - scheduler, - num_inference_steps: Optional[int] = None, - device: Optional[Union[str, torch.device]] = None, - timesteps: Optional[List[int]] = None, - **kwargs, -): - """ - Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles - custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. - - Args: - scheduler (`SchedulerMixin`): - The scheduler to get timesteps from. 
- num_inference_steps (`int`): - The number of diffusion steps used when generating samples with a pre-trained model. If used, - `timesteps` must be `None`. - device (`str` or `torch.device`, *optional*): - The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. - timesteps (`List[int]`, *optional*): - Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default - timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` - must be `None`. - - Returns: - `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the - second element is the number of inference steps. - """ - if timesteps is not None: - accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) - if not accepts_timesteps: - raise ValueError( - f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" - f" timestep schedules. Please check whether you are using the correct scheduler." - ) - scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) - timesteps = scheduler.timesteps - num_inference_steps = len(timesteps) - else: - scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) - timesteps = scheduler.timesteps - return timesteps, num_inference_steps - -@dataclass -class VideoDiffPipelineOutput(BaseOutput): - frames: Union[torch.Tensor, np.ndarray] - - -class VideoDiffPipeline(StableDiffusionControlNetPipeline): - r""" - Pipeline for text-to-video generation. - - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods - implemented for all pipelines (downloading, saving, running on a particular device, etc.). - - Args: - vae ([`AutoencoderKL`]): - Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. - text_encoder ([`CLIPTextModel`]): - Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). - tokenizer (`CLIPTokenizer`): - A [`~transformers.CLIPTokenizer`] to tokenize text. - unet ([`UNet2DConditionModel`]): - A [`UNet2DConditionModel`] used to create a UNetMotionModel to denoise the encoded video latents. - motion_adapter ([`MotionAdapter`]): - A [`MotionAdapter`] to be used in combination with `unet` to denoise the encoded video latents. - scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of - [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. 
- """ - - model_cpu_offload_seq = "text_encoder->unet->vae" - _optional_components = ["feature_extractor", "image_encoder"] - - def __init__( - self, - vae: AutoencoderKL, - text_encoder: CLIPTextModel, - tokenizer: CLIPTokenizer, - unet: UNet3DConditionModel, - motion_adapter: MotionAdapter, - controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel], - scheduler: Union[ - DDIMScheduler, - PNDMScheduler, - LMSDiscreteScheduler, - EulerDiscreteScheduler, - EulerAncestralDiscreteScheduler, - DPMSolverMultistepScheduler, - ], - feature_extractor: CLIPImageProcessor = None, - image_encoder: CLIPVisionModelWithProjection = None, - ): - # super().__init__() - if motion_adapter is not None: - unet = UNetMotionModel.from_unet2d(unet, motion_adapter) - - self.register_modules( - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - unet=unet, - controlnet=controlnet, - motion_adapter=motion_adapter, - scheduler=scheduler, - feature_extractor=feature_extractor, - image_encoder=image_encoder, - ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) - self.control_image_processor = VaeImageProcessor( - vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False - ) - - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt - def encode_prompt( - self, - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt=None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - lora_scale: Optional[float] = None, - clip_skip: Optional[int] = None, - ): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is - less than `1`). - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - lora_scale (`float`, *optional*): - A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. - clip_skip (`int`, *optional*): - Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that - the output of the pre-final layer will be used for computing the prompt embeddings. 
- """ - # set lora scale so that monkey patched LoRA - # function of text encoder can correctly access it - if lora_scale is not None and isinstance(self, LoraLoaderMixin): - self._lora_scale = lora_scale - - # dynamically adjust the LoRA scale - if not USE_PEFT_BACKEND: - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) - else: - scale_lora_layers(self.text_encoder, lora_scale) - - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - if prompt_embeds is None: - # textual inversion: procecss multi-vector tokens if necessary - if isinstance(self, TextualInversionLoaderMixin): - prompt = self.maybe_convert_prompt(prompt, self.tokenizer) - - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( - text_input_ids, untruncated_ids - ): - removed_text = self.tokenizer.batch_decode( - untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] - ) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" - ) - - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = text_inputs.attention_mask.to(device) - else: - attention_mask = None - - if clip_skip is None: - prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) - prompt_embeds = prompt_embeds[0] - else: - prompt_embeds = self.text_encoder( - text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True - ) - # Access the `hidden_states` first, that contains a tuple of - # all the hidden states from the encoder layers. Then index into - # the tuple to access the hidden states from the desired layer. - prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] - # We also need to apply the final LayerNorm here to not mess with the - # representations. The `last_hidden_states` that we typically use for - # obtaining the final prompt representations passes through the LayerNorm - # layer. 
- prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) - - if self.text_encoder is not None: - prompt_embeds_dtype = self.text_encoder.dtype - elif self.unet is not None: - prompt_embeds_dtype = self.unet.dtype - else: - prompt_embeds_dtype = prompt_embeds.dtype - - prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) - - bs_embed, seq_len, _ = prompt_embeds.shape - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) - - # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance and negative_prompt_embeds is None: - uncond_tokens: List[str] - if negative_prompt is None: - uncond_tokens = [""] * batch_size - elif prompt is not None and type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif isinstance(negative_prompt, str): - uncond_tokens = [negative_prompt] - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) - else: - uncond_tokens = negative_prompt - - # textual inversion: procecss multi-vector tokens if necessary - if isinstance(self, TextualInversionLoaderMixin): - uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) - - max_length = prompt_embeds.shape[1] - uncond_input = self.tokenizer( - uncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", - ) - - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = uncond_input.attention_mask.to(device) - else: - attention_mask = None - - negative_prompt_embeds = self.text_encoder( - uncond_input.input_ids.to(device), - attention_mask=attention_mask, - ) - negative_prompt_embeds = negative_prompt_embeds[0] - - if do_classifier_free_guidance: - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = negative_prompt_embeds.shape[1] - - negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) - - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder, lora_scale) - - return prompt_embeds, negative_prompt_embeds - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image - def encode_image(self, image, device, num_images_per_prompt): - dtype = next(self.image_encoder.parameters()).dtype - - if not isinstance(image, torch.Tensor): - image = self.feature_extractor(image, return_tensors="pt").pixel_values - - image = image.to(device=device, dtype=dtype) - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - - uncond_image_embeds = torch.zeros_like(image_embeds) - 
return image_embeds, uncond_image_embeds - - # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents - def decode_latents(self, latents): - latents = 1 / self.vae.config.scaling_factor * latents - - batch_size, channels, num_frames, height, width = latents.shape - latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) - - image = self.vae.decode(latents).sample - video = ( - image[None, :] - .reshape( - ( - batch_size, - num_frames, - -1, - ) - + image.shape[2:] - ) - .permute(0, 2, 1, 3, 4) - ) - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 - video = video.float() - return video - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing - def enable_vae_slicing(self): - r""" - Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to - compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. - """ - self.vae.enable_slicing() - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing - def disable_vae_slicing(self): - r""" - Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to - computing decoding in one step. - """ - self.vae.disable_slicing() - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling - def enable_vae_tiling(self): - r""" - Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to - compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow - processing larger images. - """ - self.vae.enable_tiling() - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling - def disable_vae_tiling(self): - r""" - Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to - computing decoding in one step. - """ - self.vae.disable_tiling() - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu - def enable_freeu(self, s1: float, s2: float, b1: float, b2: float): - r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497. - - The suffixes after the scaling factors represent the stages where they are being applied. - - Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values - that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. - - Args: - s1 (`float`): - Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to - mitigate "oversmoothing effect" in the enhanced denoising process. - s2 (`float`): - Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to - mitigate "oversmoothing effect" in the enhanced denoising process. - b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. - b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. 
- """ - if not hasattr(self, "unet"): - raise ValueError("The pipeline must have `unet` for using FreeU.") - self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu - def disable_freeu(self): - """Disables the FreeU mechanism if enabled.""" - self.unet.disable_freeu() - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs - def prepare_extra_step_kwargs(self, generator, eta): - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. - # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) - if accepts_generator: - extra_step_kwargs["generator"] = generator - return extra_step_kwargs - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs - def check_inputs( - self, - prompt, - height, - width, - callback_steps, - negative_prompt=None, - prompt_embeds=None, - negative_prompt_embeds=None, - callback_on_step_end_tensor_inputs=None, - ): - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - - if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." - ) - if callback_on_step_end_tensor_inputs is not None and not all( - k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs - ): - raise ValueError( - f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" - ) - - if prompt is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt is None and prompt_embeds is None: - raise ValueError( - "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." - ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if negative_prompt is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" - f" {negative_prompt_embeds}. Please make sure to only forward one of the two." - ) - - if prompt_embeds is not None and negative_prompt_embeds is not None: - if prompt_embeds.shape != negative_prompt_embeds.shape: - raise ValueError( - "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" - f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" - f" {negative_prompt_embeds.shape}." 
- ) - - # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents - def prepare_latents( - self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None - ): - shape = ( - batch_size, - num_channels_latents, - num_frames, - height // self.vae_scale_factor, - width // self.vae_scale_factor, - ) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - - if latents is None: - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - else: - latents = latents.to(device) - - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma - return latents - - def get_timesteps(self, num_inference_steps, strength, device, inverse=False): - # get the original timestep using init_timestep - init_timestep = min(int(num_inference_steps * strength), num_inference_steps) - - t_start = max(num_inference_steps - init_timestep, 0) - if inverse: - if t_start == 0: - timesteps = self.scheduler.timesteps - else: - timesteps = self.scheduler.timesteps[:-t_start * self.scheduler.order] - - else: - timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] - - return timesteps, num_inference_steps - t_start - - @torch.no_grad() - def encode_frames(self, frames, device): - - dtype = next(self.unet.parameters()).dtype - frames = self.image_processor.preprocess(frames).to(device=device, dtype=dtype) - latent_frames = self.vae.encode(frames).latent_dist.sample() * self.vae.config.scaling_factor - num_frames, channels, height, width = latent_frames.shape - latents_frames = latent_frames.reshape(1, num_frames, channels, height, width).permute(0,2,1,3,4) - - return latents_frames - - - def add_noise_to_latents(self, init_latents, strength, generator=None, num_inference_steps=50): - # dtype - device = init_latents.device - dtype = init_latents.dtype - - # set timesteps - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps=None) - timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device, inverse=False) - latent_timestep = timesteps[:1].repeat(init_latents.shape[0]) - - # sample noise - shape = init_latents.shape - init_noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - - init_latents = self.scheduler.add_noise(init_latents, init_noise, latent_timestep) - latents = init_latents - return latents, init_noise - - @torch.no_grad() - def __call__( - self, - prompt: Union[str, List[str]] = None, - num_frames: Optional[int] = 16, - height: Optional[int] = None, - width: Optional[int] = None, - num_inference_steps: int = 50, - timesteps: List[int] = None, - guidance_scale: float = 7.5, - strength: float = 1.0, - rect_scheduled_sampling_beta=1.0, - negative_prompt: Optional[Union[str, List[str]]] = None, - num_videos_per_prompt: Optional[int] = 1, - eta: float = 0.0, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.FloatTensor] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - ip_adapter_image: Optional[PipelineImageInput] = None, - output_type: 
Optional[str] = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: Optional[int] = 1, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - clip_skip: Optional[int] = None, - inverse=False, - noise_rect=None, # noise rectification - frames_controlnet=None, # controlnet_frames - frames_inpaint=None, # blend_diffusion - noise_inpaint=None, # blend_diffusion - mask_inpaint=None, # mask - controlnet_conditioning_scale: Union[float, List[float]] = 1.0, - guess_mode: bool = False, - control_guidance_start: Union[float, List[float]] = 0.0, - control_guidance_end: Union[float, List[float]] = 1.0, - ): - r""" - The call function to the pipeline for generation. - - Args: - prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. - height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): - The height in pixels of the generated video. - width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): - The width in pixels of the generated video. - num_frames (`int`, *optional*, defaults to 16): - The number of video frames that are generated. Defaults to 16 frames which at 8 frames per seconds - amounts to 2 seconds of video. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality videos at the - expense of slower inference. - guidance_scale (`float`, *optional*, defaults to 7.5): - A higher guidance scale value encourages the model to generate images closely linked to the text - `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide what to not include in image generation. If not defined, you need to - pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies - to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. - generator (`torch.Generator` or `List[torch.Generator]`, *optional*): - A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make - generation deterministic. - latents (`torch.FloatTensor`, *optional*): - Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor is generated by sampling using the supplied random `generator`. Latents should be of shape - `(batch_size, num_channel, num_frames, height, width)`. - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not - provided, text embeddings are generated from the `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If - not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. - ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. 
- output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generated video. Choose between `torch.FloatTensor`, `PIL.Image` or - `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] instead - of a plain tuple. - callback (`Callable`, *optional*): - A function that calls every `callback_steps` steps during inference. The function is called with the - following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function is called. If not specified, the callback is called at - every step. - cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in - [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). - clip_skip (`int`, *optional*): - Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that - the output of the pre-final layer will be used for computing the prompt embeddings. - Examples: - - Returns: - [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] or `tuple`: - If `return_dict` is `True`, [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] is - returned, otherwise a `tuple` is returned where the first element is a list with the generated frames. - """ - # # - controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet - - # align format for control guidance - if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): - control_guidance_start = len(control_guidance_end) * [control_guidance_start] - elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): - control_guidance_end = len(control_guidance_start) * [control_guidance_end] - elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): - mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 - control_guidance_start, control_guidance_end = ( - mult * [control_guidance_start], - mult * [control_guidance_end], - ) - - # 0. Default height and width to unet - height = height or self.unet.config.sample_size * self.vae_scale_factor - width = width or self.unet.config.sample_size * self.vae_scale_factor - - num_videos_per_prompt = 1 - - # 1. Check inputs. Raise error if not correct - self.check_inputs( - prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds - ) - - # 2. Define call parameters - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - device = self._execution_device - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. 
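For reference, the guidance-weight comment above corresponds to the standard classifier-free guidance combination applied later in the denoising loop. A toy sketch of the arithmetic in isolation (tensors are random placeholders in the pipeline's `[batch, channels, frames, height, width]` layout):

import torch

guidance_scale = 7.5
# stand-ins for the two halves of the chunked prediction
noise_pred_uncond = torch.randn(1, 4, 16, 32, 32)
noise_pred_text = torch.randn(1, 4, 16, 32, 32)

# guidance_scale = 1 reduces to the conditional prediction alone (no guidance)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)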
- do_classifier_free_guidance = guidance_scale > 1.0 - - - if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): - controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) - - # global_pool_conditions = ( - # controlnet.config.global_pool_conditions - # if isinstance(controlnet, ControlNetModel) - # else controlnet.nets[0].config.global_pool_conditions - # ) - # guess_mode = guess_mode or global_pool_conditions - - # 3. Encode input prompt - text_encoder_lora_scale = ( - cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None - ) - prompt_embeds, negative_prompt_embeds = self.encode_prompt( - prompt, - device, - num_videos_per_prompt, - do_classifier_free_guidance, - negative_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - lora_scale=text_encoder_lora_scale, - clip_skip=clip_skip, - ) - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - if do_classifier_free_guidance: - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - - if ip_adapter_image is not None: - image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_videos_per_prompt) - if do_classifier_free_guidance: - image_embeds = torch.cat([negative_image_embeds, image_embeds]) - - # 3.1 Prepare frames for controlnet - if isinstance(controlnet, ControlNetModel) and frames_controlnet is not None: - frames_preprocessed = [] - for frame in frames_controlnet: - frame = self.prepare_image( - image=frame, - width=width, - height=height, - batch_size=batch_size * num_videos_per_prompt, - num_images_per_prompt=num_videos_per_prompt, - device=device, - dtype=controlnet.dtype, - do_classifier_free_guidance=do_classifier_free_guidance, - guess_mode=guess_mode, - ) - frames_preprocessed.append(frame) - frames_controlnet = torch.cat(frames_preprocessed, dim=0) - height, width = frames_controlnet.shape[-2:] - - # elif isinstance(controlnet, MultiControlNetModel): - # images = [] - - # for image_ in image: - # image_ = self.prepare_image( - # image=image_, - # width=width, - # height=height, - # batch_size=batch_size * num_videos_per_prompt, - # num_images_per_prompt=num_videos_per_prompt, - # device=device, - # dtype=controlnet.dtype, - # do_classifier_free_guidance=do_classifier_free_guidance, - # guess_mode=guess_mode, - # ) - - # images.append(image_) - - # image = images - # height, width = image[0].shape[-2:] - else: - # assert False - pass - - # 5. set timesteps - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) - timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device, inverse) - latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt) - - num_rect_steps = int(rect_scheduled_sampling_beta * len(timesteps)) - - # 6. Prepare latent variables - num_channels_latents = self.unet.config.in_channels - latents = self.prepare_latents( - batch_size * num_videos_per_prompt, - num_channels_latents, - num_frames, - height, - width, - prompt_embeds.dtype, - device, - generator, - latents, - ) - - # 6. Prepare extra step kwargs. 
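The `get_timesteps` call above keeps only the tail of the scheduler's timestep sequence when `strength < 1.0`. A small worked sketch of the indexing, assuming a first-order scheduler (`scheduler.order == 1`) and an illustrative timestep list:

num_inference_steps = 50
strength = 0.6
order = 1  # e.g. DDIM / Euler

init_timestep = min(int(num_inference_steps * strength), num_inference_steps)  # 30
t_start = max(num_inference_steps - init_timestep, 0)                          # 20

timesteps = list(range(981, -1, -20))      # stand-in for scheduler.timesteps (50 values)
timesteps = timesteps[t_start * order:]    # keep the last 30, i.e. start from a noisier t
# with inverse=True the head of the list is kept instead: timesteps[:-t_start * order]
print(len(timesteps), num_inference_steps - t_start)  # 30 30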
TODO: Logic should ideally just be moved out of the pipeline - extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - # 7 Add image embeds for IP-Adapter - added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None - - # 7.2 Create tensor stating which controlnets to keep - controlnet_keep = [] - for i in range(len(timesteps)): - keeps = [ - 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) - for s, e in zip(control_guidance_start, control_guidance_end) - ] - controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) - - # motion_adatper - if self.motion_adapter is None: - # [bs, channels, frames, height ,width] -> [bs * frames, channels, height, width] - latents = latents.permute(0, 2, 1, 3, 4).reshape((latents.shape[0] * num_frames, -1) + latents.shape[3:]) - # [bs, 77, 768] -> [bs * frames, channels, height, width] - prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0) - - - # Denoising loop - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - # # controlnet(s) inference - # reshape latent - if frames_controlnet is not None: - if self.motion_adapter is None: - control_model_input = latent_model_input - controlnet_prompt_embeds = prompt_embeds - else: - control_model_input = latent_model_input.permute(0, 2, 1, 3, 4).reshape((latent_model_input.shape[0] * num_frames, -1) + latent_model_input.shape[3:]) - controlnet_prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0) - - if isinstance(controlnet_keep[i], list): - cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] - else: - controlnet_cond_scale = controlnet_conditioning_scale - if isinstance(controlnet_cond_scale, list): - controlnet_cond_scale = controlnet_cond_scale[0] - cond_scale = controlnet_cond_scale * controlnet_keep[i] - - down_block_res_samples, mid_block_res_sample = self.controlnet( - control_model_input, - t, - encoder_hidden_states=controlnet_prompt_embeds, - controlnet_cond=frames_controlnet, - conditioning_scale=cond_scale, - guess_mode=guess_mode, - return_dict=False, - ) - else: - down_block_res_samples, mid_block_res_sample = None, None - - # if guess_mode and self.do_classifier_free_guidance: - # # Infered ControlNet only for the conditional batch. - # # To apply the output of ControlNet to both the unconditional and conditional batches, - # # add 0 to the unconditional batch to keep it unchanged. 
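The `controlnet_keep` list built above turns `control_guidance_start` / `control_guidance_end` into a per-step on/off multiplier for the ControlNet residuals. A small sketch of the same expression on its own (step count and guidance window are illustrative):

num_steps = 10
control_guidance_start, control_guidance_end = [0.0], [0.5]  # ControlNet active for the first half only

controlnet_keep = []
for i in range(num_steps):
    keeps = [
        1.0 - float(i / num_steps < s or (i + 1) / num_steps > e)
        for s, e in zip(control_guidance_start, control_guidance_end)
    ]
    controlnet_keep.append(keeps[0])  # single ControlNet -> one scalar per step

print(controlnet_keep)  # [1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0]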
- # down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] - # mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) - - # predict the noise residual - noise_pred = self.unet( - latent_model_input, - t, - encoder_hidden_states=prompt_embeds, - cross_attention_kwargs=cross_attention_kwargs, - down_block_additional_residuals=down_block_res_samples, - mid_block_additional_residual=mid_block_res_sample, - added_cond_kwargs=added_cond_kwargs, - ).sample - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - # noise rectification - if noise_rect is not None and i <= num_rect_steps: - delta = noise_rect - noise_pred - noise_pred = noise_pred + mask_inpaint * delta[:, :, 0, :, :].unsqueeze(2).repeat(1, 1, num_frames, 1, 1) + (1 - mask_inpaint) * delta - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample - - if frames_inpaint is not None: - init_latents_proper = frames_inpaint # [batch_size, channels, frames, height, width] - if i < len(timesteps) - 1: - noise_timestep = timesteps[i + 1] - init_latents_proper = self.scheduler.add_noise( - init_latents_proper, noise_inpaint, torch.tensor([noise_timestep]) - ) - - latents = mask_inpaint * latents + (1 - mask_inpaint) * init_latents_proper - - # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) - - if self.motion_adapter is None: - latents = latents.reshape((1, num_frames) + latents.shape[-3:]).permute(0,2,1,3,4) - - - if output_type == "latent": - return AnimateDiffPipelineOutput(frames=latents) - - # Post-processing - video_tensor = self.decode_latents(latents) - - if output_type == "pt": - video = video_tensor - else: - video = tensor2vid(video_tensor, self.image_processor, output_type=output_type) - - # Offload all models - self.maybe_free_model_hooks() - - if not return_dict: - return (video,) - - return AnimateDiffPipelineOutput(frames=video) diff --git a/scripts/convert.sh b/scripts/convert.sh deleted file mode 100644 index cc9c594..0000000 --- a/scripts/convert.sh +++ /dev/null @@ -1,4 +0,0 @@ -python scripts/convert_original_stable_diffusion_to_diffusers.py \ ---checkpoint_path /home/wangluozhou/pretrained_models/Realistic_Vision_V6.0_B1_noVAE/Realistic_Vision_V6.0_NV_B1_fp16.safetensors \ ---dump_path /home/wangluozhou/pretrained_models/Realistic_Vision_V6.0_B1_noVAE/ \ ---from_safetensors \ No newline at end of file diff --git a/scripts/convert_lora_safetensor_to_diffusers.py b/scripts/convert_lora_safetensor_to_diffusers.py deleted file mode 100644 index e7b525b..0000000 --- a/scripts/convert_lora_safetensor_to_diffusers.py +++ /dev/null @@ -1,128 +0,0 @@ -# coding=utf-8 -# Copyright 2023, Haofan Wang, Qixun Wang, All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
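Stepping back to the denoising loop removed above: the noise-rectification update blends a reference noise `noise_rect` into the prediction through `mask_inpaint`, reusing the first frame's correction everywhere inside the mask and the per-frame correction outside it. A toy sketch of that blend in isolation (all tensors are random placeholders with the pipeline's `[batch, channels, frames, height, width]` layout):

import torch

b, c, f, h, w = 1, 4, 16, 32, 32
noise_pred = torch.randn(b, c, f, h, w)
noise_rect = torch.randn(b, c, f, h, w)   # noise recovered from the reference video
mask_inpaint = torch.zeros(b, c, f, h, w)
mask_inpaint[..., 8:, 8:] = 1.0           # illustrative region

delta = noise_rect - noise_pred
# where mask == 1: broadcast the first frame's correction to all frames;
# where mask == 0: apply the per-frame correction as-is
noise_pred = (
    noise_pred
    + mask_inpaint * delta[:, :, 0, :, :].unsqueeze(2).repeat(1, 1, f, 1, 1)
    + (1 - mask_inpaint) * delta
)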
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" Conversion script for the LoRA's safetensors checkpoints. """ - -import argparse - -import torch -from safetensors.torch import load_file - -from diffusers import StableDiffusionPipeline - - -def convert(base_model_path, checkpoint_path, LORA_PREFIX_UNET="lora_unet", LORA_PREFIX_TEXT_ENCODER="lora_te", alpha=0.6): - # load base model - pipeline = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32) - - # load LoRA weight from .safetensors - state_dict = load_file(checkpoint_path) - - visited = [] - - # directly update weight in diffusers model - for key in state_dict: - # it is suggested to print out the key, it usually will be something like below - # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight" - - # as we have set the alpha beforehand, so just skip - if ".alpha" in key or key in visited: - continue - - if "text" in key: - layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_") - curr_layer = pipeline.text_encoder - else: - layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_") - curr_layer = pipeline.unet - - # find the target layer - temp_name = layer_infos.pop(0) - while len(layer_infos) > -1: - try: - curr_layer = curr_layer.__getattr__(temp_name) - if len(layer_infos) > 0: - temp_name = layer_infos.pop(0) - elif len(layer_infos) == 0: - break - except Exception: - if len(temp_name) > 0: - temp_name += "_" + layer_infos.pop(0) - else: - temp_name = layer_infos.pop(0) - - pair_keys = [] - if "lora_down" in key: - pair_keys.append(key.replace("lora_down", "lora_up")) - pair_keys.append(key) - else: - pair_keys.append(key) - pair_keys.append(key.replace("lora_up", "lora_down")) - - # update weight - if len(state_dict[pair_keys[0]].shape) == 4: - weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32) - weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32) - curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3) - else: - weight_up = state_dict[pair_keys[0]].to(torch.float32) - weight_down = state_dict[pair_keys[1]].to(torch.float32) - curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down) - - # update visited list - for item in pair_keys: - visited.append(item) - - return pipeline - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - - parser.add_argument( - "--base_model_path", default='/home/wangluozhou/pretrained_models/stable-diffusion-v1-5', type=str, required=False, help="Path to the base model in diffusers format." - ) - parser.add_argument( - "--checkpoint_path", default='/home/wangluozhou/projects/AnimateDiff/models/DreamBooth_LoRA/realisticVisionV60B1_v20Novae.safetensors', type=str, required=False, help="Path to the checkpoint to convert." 
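The merge loop in `convert()` above implements W = W0 + alpha * (up @ down), as the `--alpha` help string later in this script puts it. A self-contained sketch of that update for a single linear layer (rank and dimensions are made up for illustration):

import torch

out_dim, in_dim, rank = 320, 768, 4
alpha = 0.75

linear = torch.nn.Linear(in_dim, out_dim, bias=False)  # frozen base weight W0
weight_up = torch.randn(out_dim, rank)                  # lora_up.weight
weight_down = torch.randn(rank, in_dim)                 # lora_down.weight

# W = W0 + alpha * (up @ down); conv weights additionally need the
# squeeze/unsqueeze handling shown in the script above
linear.weight.data += alpha * torch.mm(weight_up, weight_down)

print(linear.weight.shape)  # torch.Size([320, 768])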
- ) - parser.add_argument("--dump_path", default='/home/wangluozhou/projects/AnimateDiff/models/DreamBooth_LoRA/realisticVisionV60B1_v20Novae', type=str, required=False, help="Path to the output model.") - parser.add_argument( - "--lora_prefix_unet", default="lora_unet", type=str, help="The prefix of UNet weight in safetensors" - ) - parser.add_argument( - "--lora_prefix_text_encoder", - default="lora_te", - type=str, - help="The prefix of text encoder weight in safetensors", - ) - parser.add_argument("--alpha", default=0.75, type=float, help="The merging ratio in W = W0 + alpha * deltaW") - parser.add_argument( - "--to_safetensors", action="store_true", help="Whether to store pipeline in safetensors format or not." - ) - parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)") - - args = parser.parse_args() - - base_model_path = args.base_model_path - checkpoint_path = args.checkpoint_path - dump_path = args.dump_path - lora_prefix_unet = args.lora_prefix_unet - lora_prefix_text_encoder = args.lora_prefix_text_encoder - alpha = args.alpha - - pipe = convert(base_model_path, checkpoint_path, lora_prefix_unet, lora_prefix_text_encoder, alpha) - - pipe = pipe.to(args.device) - pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors) diff --git a/scripts/convert_original_stable_diffusion_to_diffusers.py b/scripts/convert_original_stable_diffusion_to_diffusers.py deleted file mode 100644 index 2ca7096..0000000 --- a/scripts/convert_original_stable_diffusion_to_diffusers.py +++ /dev/null @@ -1,188 +0,0 @@ -# coding=utf-8 -# Copyright 2023 The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" Conversion script for the LDM checkpoints. """ - -import argparse -import importlib - -import torch - -from diffusers.pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - - parser.add_argument( - "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert." - ) - # !wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml - parser.add_argument( - "--original_config_file", - default=None, - type=str, - help="The YAML config file corresponding to the original architecture.", - ) - parser.add_argument( - "--config_files", - default=None, - type=str, - help="The YAML config file corresponding to the architecture.", - ) - parser.add_argument( - "--num_in_channels", - default=None, - type=int, - help="The number of input channels. If `None` number of input channels will be automatically inferred.", - ) - parser.add_argument( - "--scheduler_type", - default="pndm", - type=str, - help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim', 'euler', 'euler-ancestral', 'dpm']", - ) - parser.add_argument( - "--pipeline_type", - default=None, - type=str, - help=( - "The pipeline type. 
One of 'FrozenOpenCLIPEmbedder', 'FrozenCLIPEmbedder', 'PaintByExample'" - ". If `None` pipeline will be automatically inferred." - ), - ) - parser.add_argument( - "--image_size", - default=None, - type=int, - help=( - "The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Siffusion v2" - " Base. Use 768 for Stable Diffusion v2." - ), - ) - parser.add_argument( - "--prediction_type", - default=None, - type=str, - help=( - "The prediction type that the model was trained on. Use 'epsilon' for Stable Diffusion v1.X and Stable" - " Diffusion v2 Base. Use 'v_prediction' for Stable Diffusion v2." - ), - ) - parser.add_argument( - "--extract_ema", - action="store_true", - help=( - "Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights" - " or not. Defaults to `False`. Add `--extract_ema` to extract the EMA weights. EMA weights usually yield" - " higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning." - ), - ) - parser.add_argument( - "--upcast_attention", - action="store_true", - help=( - "Whether the attention computation should always be upcasted. This is necessary when running stable" - " diffusion 2.1." - ), - ) - parser.add_argument( - "--from_safetensors", - action="store_true", - help="If `--checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.", - ) - parser.add_argument( - "--to_safetensors", - action="store_true", - help="Whether to store pipeline in safetensors format or not.", - ) - parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") - parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)") - parser.add_argument( - "--stable_unclip", - type=str, - default=None, - required=False, - help="Set if this is a stable unCLIP model. One of 'txt2img' or 'img2img'.", - ) - parser.add_argument( - "--stable_unclip_prior", - type=str, - default=None, - required=False, - help="Set if this is a stable unCLIP txt2img model. Selects which prior to use. If `--stable_unclip` is set to `txt2img`, the karlo prior (https://huggingface.co/kakaobrain/karlo-v1-alpha/tree/main/prior) is selected by default.", - ) - parser.add_argument( - "--clip_stats_path", - type=str, - help="Path to the clip stats file. Only required if the stable unclip model's config specifies `model.params.noise_aug_config.params.clip_stats_path`.", - required=False, - ) - parser.add_argument( - "--controlnet", action="store_true", default=None, help="Set flag if this is a controlnet checkpoint." 
- ) - parser.add_argument("--half", action="store_true", help="Save weights in half precision.") - parser.add_argument( - "--vae_path", - type=str, - default=None, - required=False, - help="Set to a path, hub id to an already converted vae to not convert it again.", - ) - parser.add_argument( - "--pipeline_class_name", - type=str, - default=None, - required=False, - help="Specify the pipeline class name", - ) - - args = parser.parse_args() - - if args.pipeline_class_name is not None: - library = importlib.import_module("diffusers") - class_obj = getattr(library, args.pipeline_class_name) - pipeline_class = class_obj - else: - pipeline_class = None - - pipe = download_from_original_stable_diffusion_ckpt( - checkpoint_path_or_dict=args.checkpoint_path, - original_config_file=args.original_config_file, - config_files=args.config_files, - image_size=args.image_size, - prediction_type=args.prediction_type, - model_type=args.pipeline_type, - extract_ema=args.extract_ema, - scheduler_type=args.scheduler_type, - num_in_channels=args.num_in_channels, - upcast_attention=args.upcast_attention, - from_safetensors=args.from_safetensors, - device=args.device, - stable_unclip=args.stable_unclip, - stable_unclip_prior=args.stable_unclip_prior, - clip_stats_path=args.clip_stats_path, - controlnet=args.controlnet, - vae_path=args.vae_path, - pipeline_class=pipeline_class, - ) - - if args.half: - pipe.to(torch_dtype=torch.float16) - - if args.controlnet: - # only save the controlnet model - pipe.controlnet.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors) - else: - pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors) diff --git a/scripts/train.sh b/scripts/train.sh deleted file mode 100644 index 599b931..0000000 --- a/scripts/train.sh +++ /dev/null @@ -1,20 +0,0 @@ -CUDA_VISIBLE_DEVICES=1 accelerate launch pe_inversion_unet3d.py \ - --pretrained_model_name_or_path='/home/wangluozhou/pretrained_models/zeroscope_v2_576w' \ - --per_gpu_batch_size=1 --gradient_accumulation_steps=1 \ - --max_train_steps=600 \ - --width=288 \ - --height=160 \ - --num_frames=24 \ - --checkpointing_steps=100 --checkpoints_total_limit=3 \ - --learning_rate=1e-1 --lr_warmup_steps=0 \ - --seed=0 \ - --validation_steps=100 \ - --output_dir='/home/wangluozhou/projects/MotionInversion/outputs/0205/05' \ - --validation_file='/home/wangluozhou/projects/MotionInversion/resources/05.txt' \ - --video_path='/home/wangluozhou/projects/MotionInversion/resources/05_cats_play_24.mp4' \ - --pe_size 1280 \ - --pe_module down mid up \ - --mixed_precision="fp16" \ - --enable_xformers_memory_efficient_attention \ - --num_validation_videos 3 - diff --git a/train.py b/train.py new file mode 100644 index 0000000..de6aad7 --- /dev/null +++ b/train.py @@ -0,0 +1,580 @@ +import argparse +import datetime +import logging +import inspect +import math +import os +import random +import gc +import copy + +from typing import Dict, Optional, Tuple +from omegaconf import OmegaConf + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import diffusers +import transformers + +from torchvision import transforms +from tqdm.auto import tqdm + +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import set_seed + +from models.unet.unet_3d_condition import UNet3DConditionModel +from diffusers.models import AutoencoderKL +from diffusers import DDIMScheduler, TextToVideoSDPipeline +from diffusers.utils.import_utils import is_xformers_available +from 
diffusers.models.attention_processor import AttnProcessor2_0, Attention +from diffusers.models.attention import BasicTransformerBlock + +from transformers import CLIPTextModel, CLIPTokenizer +from transformers.models.clip.modeling_clip import CLIPEncoder +from einops import rearrange, repeat +from utils.ddim_utils import inverse_video + +import imageio +import numpy as np + +from dataset import * +from loss import * +from noise_init import * + +from utils.func_utils import * + + +logger = get_logger(__name__, log_level="INFO") + +def log_validation(accelerator, config, batch, global_step, text_prompt, unet, text_encoder, vae, output_dir): + with accelerator.autocast(): + unet.eval() + text_encoder.eval() + unet_and_text_g_c(unet, text_encoder, False, False) + + # handle spatial lora + if 'spatial_scale' in config.val.keys(): + loras = extract_lora_child_module(unet, target_replace_module=["Transformer2DModel"]) + for lora_i in loras: + lora_i.scale = config.val.spatial_scale + + # preset_noise = batch['inversion_noise'] + + pipeline = TextToVideoSDPipeline.from_pretrained( + config.model.pretrained_model_path, + text_encoder=text_encoder, + vae=vae, + unet=unet + ) + + diffusion_scheduler = DDIMScheduler.from_config(pipeline.scheduler.config) + pipeline.scheduler = diffusion_scheduler + + prompt_list = text_prompt if len(config.val.prompt) <= 0 else config.val.prompt + for seed in range(config.val.seed_range[0], config.val.seed_range[1]): + + config.noise_init.seed = seed + # handle different noise initialization strategy + init_func_name = f'initialize_noise_with_{config.noise_init.type}' + # Assuming config.dataset is a DictConfig object + init_params_dict = OmegaConf.to_container(config.noise_init, resolve=True) + # Remove the 'type' key + init_params_dict.pop('type', None) # 'None' ensures no error if 'type' key doesn't exist + + init_func_to_call = globals().get(init_func_name) + init_noise = init_func_to_call(batch['inversion_noise'], **init_params_dict) + + for prompt in prompt_list: + file_name = f"{prompt.replace(' ', '_')}_seed_{seed}.mp4" + file_path = f"{output_dir}/samples_{global_step}/" + if not os.path.exists(file_path): + os.makedirs(file_path) + + with torch.no_grad(): + video_frames = pipeline( + prompt=config.val.prompt_prefix + prompt, + negative_prompt=config.val.negative_prompt, + width=config.val.width, + height=config.val.height, + num_frames=config.val.num_frames, + num_inference_steps=config.val.num_inference_steps, + guidance_scale=config.val.guidance_scale, + latents=init_noise, + ).frames[0] + export_to_video(video_frames, os.path.join(file_path, file_name), config.dataset.fps) + logger.info(f"Saved a new sample to {os.path.join(file_path, file_name)}") + del pipeline + torch.cuda.empty_cache() + +def create_logging(logging, logger, accelerator): + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + +def accelerate_set_verbose(accelerator): + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + +def export_to_video(video_frames, output_video_path, fps): + video_writer = imageio.get_writer(output_video_path, fps=fps) + for img in video_frames: + video_writer.append_data(np.array(img)) + video_writer.close() + +def 
create_output_folders(output_dir, config): + # now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") + out_dir = os.path.join(output_dir) + + os.makedirs(out_dir, exist_ok=True) + # os.makedirs(f"{out_dir}/samples", exist_ok=True) + OmegaConf.save(config, os.path.join(out_dir, 'config.yaml')) + + return out_dir + +def load_primary_models(pretrained_model_path): + noise_scheduler = DDIMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler") + tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer") + text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder") + vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae") + unet = UNet3DConditionModel.from_pretrained(pretrained_model_path, subfolder="unet") + + return noise_scheduler, tokenizer, text_encoder, vae, unet + +def unet_and_text_g_c(unet, text_encoder, unet_enable, text_enable): + unet._set_gradient_checkpointing(value=unet_enable) + text_encoder._set_gradient_checkpointing(CLIPEncoder, value=text_enable) + +def freeze_models(models_to_freeze): + for model in models_to_freeze: + if model is not None: model.requires_grad_(False) + +def is_attn(name): + return ('attn1' or 'attn2' == name.split('.')[-1]) + +def set_processors(attentions): + for attn in attentions: attn.set_processor(AttnProcessor2_0()) + +def set_torch_2_attn(unet): + optim_count = 0 + + for name, module in unet.named_modules(): + if is_attn(name): + if isinstance(module, torch.nn.ModuleList): + for m in module: + if isinstance(m, BasicTransformerBlock): + set_processors([m.attn1, m.attn2]) + optim_count += 1 + if optim_count > 0: + print(f"{optim_count} Attention layers using Scaled Dot Product Attention.") + +def handle_memory_attention(enable_xformers_memory_efficient_attention, enable_torch_2_attn, unet): + try: + is_torch_2 = hasattr(F, 'scaled_dot_product_attention') + enable_torch_2 = is_torch_2 and enable_torch_2_attn + + if enable_xformers_memory_efficient_attention and not enable_torch_2: + if is_xformers_available(): + from xformers.ops import MemoryEfficientAttentionFlashAttentionOp + unet.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp) + else: + raise ValueError("xformers is not available. Make sure it is installed correctly") + + if enable_torch_2: + set_torch_2_attn(unet) + + except: + print("Could not enable memory efficient attention for xformers or Torch 2.0.") + +def negate_params(name, negation): + # We have to do this if we are co-training with LoRA. + # This ensures that parameter groups aren't duplicated. + if negation is None: return False + for n in negation: + if n in name and 'temp' not in name: + return True + return False + +def is_mixed_precision(accelerator): + weight_dtype = torch.float32 + + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + return weight_dtype + +def cast_to_gpu_and_type(model_list, accelerator, weight_dtype): + for model in model_list: + if model is not None: model.to(accelerator.device, dtype=weight_dtype) + +def handle_cache_latents( + should_cache, + output_dir, + train_dataloader, + train_batch_size, + vae, + unet, + pretrained_model_path, + cached_latent_dir=None, +): + # Cache latents by storing them in VRAM. + # Speeds up training and saves memory by not encoding during the train loop. 
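One review note on the `is_attn` helper added above: `('attn1' or 'attn2' == name.split('.')[-1])` is parsed as `'attn1' or ('attn2' == ...)`, so it always returns the truthy string `'attn1'` and every module name passes the check; `set_torch_2_attn` still works because of its later `isinstance` checks, but the filter does nothing. If the intent is to match only the attention submodules, a membership test expresses it directly (a sketch of the presumed intent):

def is_attn(name: str) -> bool:
    # True only for module paths ending in .attn1 or .attn2
    return name.split('.')[-1] in ('attn1', 'attn2')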
+ if not should_cache: return None + vae.to('cuda', dtype=torch.float16) + vae.enable_slicing() + + pipe = TextToVideoSDPipeline.from_pretrained( + pretrained_model_path, + vae=vae, + unet=copy.deepcopy(unet).to('cuda', dtype=torch.float16) + ) + pipe.text_encoder.to('cuda', dtype=torch.float16) + + cached_latent_dir = ( + os.path.abspath(cached_latent_dir) if cached_latent_dir is not None else None + ) + + if cached_latent_dir is None: + cache_save_dir = f"{output_dir}/cached_latents" + os.makedirs(cache_save_dir, exist_ok=True) + + for i, batch in enumerate(tqdm(train_dataloader, desc="Caching Latents.")): + save_name = f"cached_{i}" + full_out_path = f"{cache_save_dir}/{save_name}.pt" + + pixel_values = batch['pixel_values'].to('cuda', dtype=torch.float16) + batch['latents'] = tensor_to_vae_latent(pixel_values, vae) + + batch['inversion_noise'] = inverse_video(pipe, batch['latents'], 50) + for k, v in batch.items(): batch[k] = v[0] + + torch.save(batch, full_out_path) + del pixel_values + del batch + + # We do this to avoid fragmentation from casting latents between devices. + torch.cuda.empty_cache() + else: + cache_save_dir = cached_latent_dir + + return torch.utils.data.DataLoader( + CachedDataset(cache_dir=cache_save_dir), + batch_size=train_batch_size, + shuffle=True, + num_workers=0 + ) + +def enforce_zero_terminal_snr(betas): + """ + Corrects noise in diffusion schedulers. + From: Common Diffusion Noise Schedules and Sample Steps are Flawed + https://arxiv.org/pdf/2305.08891.pdf + """ + # Convert betas to alphas_bar_sqrt + alphas = 1 - betas + alphas_bar = alphas.cumprod(0) + alphas_bar_sqrt = alphas_bar.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / ( + alphas_bar_sqrt_0 - alphas_bar_sqrt_T + ) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt ** 2 + alphas = alphas_bar[1:] / alphas_bar[:-1] + alphas = torch.cat([alphas_bar[0:1], alphas]) + betas = 1 - alphas + + return betas + +def should_sample(global_step, validation_steps, validation_data): + return global_step % validation_steps == 0 and validation_data.sample_preview + +def save_pipe( + path, + global_step, + accelerator, + unet, + text_encoder, + vae, + output_dir, + is_checkpoint=False, + save_pretrained_model=True, + **extra_params +): + if is_checkpoint: + save_path = os.path.join(output_dir, f"checkpoint-{global_step}") + os.makedirs(save_path, exist_ok=True) + else: + save_path = output_dir + + # Save the dtypes so we can continue training at the same precision. + u_dtype, t_dtype, v_dtype = unet.dtype, text_encoder.dtype, vae.dtype + + # Copy the model without creating a reference to it. This allows keeping the state of our lora training if enabled. 
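`enforce_zero_terminal_snr` above follows the rescaling from "Common Diffusion Noise Schedules and Sample Steps are Flawed". A minimal sketch of applying it to a DDIM beta schedule, assuming the function defined above is in scope (scheduler hyperparameters are illustrative; recent diffusers releases also expose a `rescale_betas_zero_snr` flag on `DDIMScheduler` for the same purpose):

import torch
from diffusers import DDIMScheduler

scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")

betas = enforce_zero_terminal_snr(scheduler.betas.clone())  # function defined above in train.py

alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
print(alphas_cumprod[-1])  # ~0: the final timestep now carries no signal (zero terminal SNR)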
+ unet_out = copy.deepcopy(accelerator.unwrap_model(unet.cpu(), keep_fp32_wrapper=False)) + text_encoder_out = copy.deepcopy(accelerator.unwrap_model(text_encoder.cpu(), keep_fp32_wrapper=False)) + + pipeline = TextToVideoSDPipeline.from_pretrained( + path, + unet=unet_out, + text_encoder=text_encoder_out, + vae=vae, + ).to(torch_dtype=torch.float32) + + lora_managers_spatial = extra_params.get('lora_managers_spatial', [None]) + lora_manager_spatial = lora_managers_spatial[-1] + if lora_manager_spatial is not None: + lora_manager_spatial.save_lora_weights(model=copy.deepcopy(pipeline), save_path=save_path+'/spatial', step=global_step) + + save_motion_embeddings(unet_out, os.path.join(save_path, 'motion_embed.pt')) + + if save_pretrained_model: + pipeline.save_pretrained(save_path) + + if is_checkpoint: + unet, text_encoder = accelerator.prepare(unet, text_encoder) + models_to_cast_back = [(unet, u_dtype), (text_encoder, t_dtype), (vae, v_dtype)] + [x[0].to(accelerator.device, dtype=x[1]) for x in models_to_cast_back] + + logger.info(f"Saved model at {save_path} on step {global_step}") + + del pipeline + del unet_out + del text_encoder_out + torch.cuda.empty_cache() + gc.collect() + +def main(config): + # Initialize the Accelerator + accelerator = Accelerator( + gradient_accumulation_steps=config.train.gradient_accumulation_steps, + mixed_precision=config.train.mixed_precision, + log_with=config.train.logger_type, + project_dir=config.train.output_dir + ) + + # Create output directories and set up logging + if accelerator.is_main_process: + output_dir = create_output_folders(config.train.output_dir, config) + create_logging(logging, logger, accelerator) + accelerate_set_verbose(accelerator) + + # Load primary models + noise_scheduler, tokenizer, text_encoder, vae, unet = load_primary_models(config.model.pretrained_model_path) + freeze_models([vae, text_encoder]) + handle_memory_attention(config.train.enable_xformers_memory_efficient_attention, config.train.enable_torch_2_attn, unet) + + train_dataloader, train_dataset = prepare_data(config, tokenizer) + + # Handle latents caching + cached_data_loader = handle_cache_latents( + config.train.cache_latents, + output_dir, + train_dataloader, + config.train.train_batch_size, + vae, + unet, + config.model.pretrained_model_path, + config.train.cached_latent_dir, + ) + if cached_data_loader is not None: + train_dataloader = cached_data_loader + + # Prepare parameters and optimization + params, extra_params = prepare_params(unet, config, train_dataset) + optimizers, lr_schedulers = prepare_optimizers(params, config, **extra_params) + + + # Prepare models and data for training + unet, optimizers, train_dataloader, lr_schedulers, text_encoder = accelerator.prepare( + unet, optimizers, train_dataloader, lr_schedulers, text_encoder + ) + + # Additional model setups + unet_and_text_g_c(unet, text_encoder, config.train.gradient_checkpointing, config.train.text_encoder_gradient_checkpointing) + vae.enable_slicing() + + # Setup for mixed precision training + weight_dtype = is_mixed_precision(accelerator) + cast_to_gpu_and_type([text_encoder, vae], accelerator, weight_dtype) + + # Recalculate training steps and epochs + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / config.train.gradient_accumulation_steps) + num_train_epochs = math.ceil(config.train.max_train_steps / num_update_steps_per_epoch) + + # Initialize trackers and store configuration + if accelerator.is_main_process: + accelerator.init_trackers("text2video-fine-tune") + + # Train! 
+ total_batch_size = config.train.train_batch_size * accelerator.num_processes * config.train.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {config.train.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {config.train.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {config.train.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Only show the progress bar once on each machine. + progress_bar = tqdm(range(global_step, config.train.max_train_steps), disable=not accelerator.is_local_main_process) + progress_bar.set_description("Steps") + + + for epoch in range(first_epoch, num_train_epochs): + train_loss_temporal = 0.0 + + for step, batch in enumerate(train_dataloader): + # Skip steps until we reach the resumed step + if config.train.resume_from_checkpoint and epoch == first_epoch and step < config.train.resume_step: + if step % config.train.gradient_accumulation_steps == 0: + progress_bar.update(1) + continue + + with accelerator.accumulate(unet), accelerator.accumulate(text_encoder): + + text_prompt = batch['text_prompt'][0] + + for optimizer in optimizers: + optimizer.zero_grad(set_to_none=True) + + with accelerator.autocast(): + if global_step == 0: + unet.train() + + loss_func_name = f'{config.loss.type}Loss' + loss_func_to_call = globals().get(loss_func_name) + + loss_temporal, train_loss_temporal = loss_func_to_call( + train_loss_temporal, + accelerator, + optimizers, + lr_schedulers, + unet, + vae, + text_encoder, + noise_scheduler, + batch, + step, + config + ) + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + accelerator.log({"train_loss": train_loss_temporal}, step=global_step) + train_loss_temporal = 0.0 + if global_step % config.train.checkpointing_steps == 0 and global_step > 0: + save_pipe( + config.model.pretrained_model_path, + global_step, + accelerator, + unet, + text_encoder, + vae, + output_dir, + is_checkpoint=True, + save_pretrained_model=config.train.save_pretrained_model, + **extra_params + ) + + if should_sample(global_step, config.train.validation_steps, config.val): + if accelerator.is_main_process: + log_validation( + accelerator, + config, + batch, + global_step, + text_prompt, + unet, + text_encoder, + vae, + output_dir + ) + + + unet_and_text_g_c( + unet, + text_encoder, + config.train.gradient_checkpointing, + config.train.text_encoder_gradient_checkpointing + ) + + if loss_temporal is not None: + accelerator.log({"loss_temporal": loss_temporal.detach().item()}, step=step) + + if global_step >= config.train.max_train_steps: + break + + # Create the pipeline using the trained modules and save it. 
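Both `log_validation` and the loss call in the training loop above resolve a callable by name at runtime (`f'initialize_noise_with_{...}'`, `f'{config.loss.type}Loss'`) via `globals().get(...)`. A tiny illustration of that dispatch pattern with a stand-in function; the real names come from the `noise_init` and `loss` packages star-imported at the top of this file:

def BaseLoss(*args, **kwargs):
    # stand-in: the real implementations are provided by `from loss import *`
    return "called BaseLoss"

loss_type = "Base"  # e.g. config.loss.type
loss_func = globals().get(f"{loss_type}Loss")
if loss_func is None:
    raise ValueError(f"No loss named {loss_type}Loss is defined")
print(loss_func())  # called BaseLoss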
+ accelerator.wait_for_everyone() + if accelerator.is_main_process: + save_pipe( + config.model.pretrained_model_path, + global_step, + accelerator, + unet, + text_encoder, + vae, + output_dir, + is_checkpoint=False, + save_pretrained_model=config.train.save_pretrained_model, + **extra_params + ) + accelerator.end_training() + +# if __name__ == "__main__": +# parser = argparse.ArgumentParser() +# parser.add_argument("--config", type=str, default='/remote-home/lzwang/projects/MotionInversion/configs/config.yaml') +# args = parser.parse_args() + +# # Load and merge configurations +# config = OmegaConf.load(args.config) +# main(config) + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--config", type=str, default='/remote-home/lzwang/projects/MotionInversion/configs/config.yaml') + parser.add_argument("--single_video_path", type=str) + parser.add_argument("--prompts", type=str, help="JSON string of prompts") + args = parser.parse_args() + + # Load and merge configurations + config = OmegaConf.load(args.config) + + # Update the config with the command-line arguments + if args.single_video_path: + config.dataset.single_video_path = args.single_video_path + # Set the output dir + config.train.output_dir = os.path.join(config.train.output_dir, os.path.basename(args.single_video_path).split('.')[0]) + + if args.prompts: + config.val.prompt = json.loads(args.prompts) + + + + main(config) diff --git a/utils/__init__.py b/utils/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/utils/attn_utils.py b/utils/attn_utils.py deleted file mode 100644 index 3b86422..0000000 --- a/utils/attn_utils.py +++ /dev/null @@ -1,356 +0,0 @@ -import abc -import torch -from diffusers.models.attention_processor import Attention, AttnProcessor2_0 - -import torch -from PIL import Image -import numpy as np -import matplotlib.pyplot as plt - -def reduce_batch_size_and_convert_to_image(attention_map, pixel_size=10, frames=16): - # Reducing the batch size dimension by averaging - reduced_map = attention_map.mean(dim=0) - - # Normalizing the attention map to be between 0 and 1 - normalized_map = (reduced_map - reduced_map.min()) / (reduced_map.max() - reduced_map.min()) - - # Converting to 24x24 image - image = normalized_map.view(frames, frames).cpu().detach().numpy() - - # Scaling the normalized image to be between 0 and 255 for uint8 - image_scaled = np.uint8(image * 255) - - # Convert to PIL Image in 'L' mode (grayscale) and resize - pil_image = Image.fromarray(image_scaled, 'L').resize((frames * pixel_size, frames * pixel_size), resample=Image.NEAREST) - - return pil_image - -def build_image_grid(attention_maps_list, pixel_size=10, frames=16): - images = [reduce_batch_size_and_convert_to_image(attention_map, frames=frames) for attention_map in attention_maps_list] - # Calculate grid size - grid_size = int(np.ceil(np.sqrt(len(images)))) - - # Resize each image to make each pixel a larger square - resized_images = [image.resize((frames * pixel_size, frames * pixel_size), resample=Image.NEAREST) for image in images] - width, height = resized_images[0].size - - # Create a new image with white background for the grid - grid_img = Image.new('RGB', size=(grid_size * width, grid_size * height), color=(255, 255, 255)) - - for i, image in enumerate(resized_images): - grid_x = (i % grid_size) * width - grid_y = (i // grid_size) * height - grid_img.paste(image, (grid_x, grid_y)) - - return grid_img - -def compute_average_map(attention_maps_list, pixel_size=20, frames=16, 
reduction='mean', batch_size=2, height=16, width=16): - - dtype = attention_maps_list[0].dtype - device = attention_maps_list[0].device - if reduction == 'temporal': - # Initialize an empty tensor for averaging - average_map = torch.zeros(batch_size, height, width, frames, frames, dtype=dtype, device=device) - - for attention_map in attention_maps_list: - # Restore each attention map back to [batch_size, height, width, num_frames, num_frames] - reshaped_map = attention_map.reshape(batch_size, height, width, frames, frames) - average_map += reshaped_map - - # Compute the average - average_map /= len(attention_maps_list) - - image_batch = [] - for b in range(batch_size): - # Create a grid for each batch - grid = torch.zeros(height * frames, width * frames).to(device, dtype) - - for h in range(height): - for w in range(width): - # Extract each num_frames * num_frames image - img = average_map[b, h, w, :, :] - grid[h*frames:(h+1)*frames, w*frames:(w+1)*frames] = img - - # Normalize and convert to PIL image - grid_normalized = grid.cpu().detach().numpy() - grid_normalized = (grid_normalized - grid_normalized.min()) / (grid_normalized.max() - grid_normalized.min()) * 255 - grid_image = Image.fromarray(grid_normalized.astype(np.uint8), 'L') - resized_image = grid_image.resize((width * frames * pixel_size, height * frames * pixel_size), resample=Image.NEAREST) - image_batch.append(resized_image) - return image_batch - - elif reduction =='spatial': - average_map = torch.zeros(batch_size, frames, frames, height, width, dtype=dtype, device=device) - - for attention_map in attention_maps_list: - # Restore each attention map back to [batch_size, height, width, num_frames, num_frames] - reshaped_map = attention_map.reshape(batch_size, height, width, frames, frames) - average_map += reshaped_map - - # Compute the average - average_map /= len(attention_maps_list) - - # Process the average map to create a batch of frame grid images - image_batch = [] - for b in range(batch_size): - # Create a grid for each batch - grid = torch.zeros(frames * height, frames * width, dtype=dtype, device=device) - - for f1 in range(frames): - for f2 in range(frames): - # Extract each height * width image - img = average_map[b, :, :, f1, f2] - grid[f1*height:(f1+1)*height, f2*width:(f2+1)*width] = img - - # Normalize and convert to PIL image - grid_normalized = grid.cpu().numpy() - grid_normalized = (grid_normalized - grid_normalized.min()) / (grid_normalized.max() - grid_normalized.min()) * 255 - grid_image = Image.fromarray(grid_normalized.astype(np.uint8), 'L') - resized_image = grid_image.resize((width * frames * pixel_size, height * frames * pixel_size), resample=Image.NEAREST) - - image_batch.append(resized_image) - - return image_batch - - elif reduction =='mean': - # Initialize an empty tensor for averaging - average_map = torch.zeros(frames, frames).to(device, dtype) - - for attention_map in attention_maps_list: - # Reduce each attention map and add to the average - reduced_map = attention_map.mean(dim=0) - average_map += reduced_map.view(frames, frames) - - # Compute the average - average_map /= len(attention_maps_list) - - # Convert the average tensor to a numpy array - average_array = average_map.cpu().detach().numpy() - - # Normalize the array to be in the range [0, 255] - average_array_normalized = (average_array - average_array.min()) / (average_array.max() - average_array.min()) * 255 - average_array_normalized = average_array_normalized.astype(np.uint8) - - # Convert to a PIL image in 'L' mode (grayscale) - 
average_image = Image.fromarray(average_array_normalized, 'L') - - # Resize the image to make each pixel a larger square - new_size = (frames * pixel_size, frames * pixel_size) - resized_image = average_image.resize(new_size, resample=Image.NEAREST) - - return resized_image - -def register_attention_control(self, controller): - - attn_procs = {} - temp_attn_count = 0 - - for name in self.unet.attn_processors.keys(): - if 'temp_attentions' in name or 'motion_modules' in name: - if name.endswith("fuser.attn.processor"): - attn_procs[name] = DummyAttnProcessor() - continue - - if name.startswith("mid_block"): - place_in_unet = "mid" - - elif name.startswith("up_blocks"): - place_in_unet = "up" - - elif name.startswith("down_blocks"): - place_in_unet = "down" - - else: - continue - - temp_attn_count += 1 - attn_procs[name] = MyAttnProcessor( - attnstore=controller, place_in_unet=place_in_unet - ) - else: - attn_procs[name] = AttnProcessor2_0() - - - self.unet.set_attn_processor(attn_procs) - controller.num_att_layers = temp_attn_count - -class MyAttnProcessor: - - def __init__(self, attnstore, place_in_unet, hidden_states_store=None): - super().__init__() - self.attnstore = attnstore - self.hidden_states_store = hidden_states_store - self.place_in_unet = place_in_unet - - def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): - batch_size, sequence_length, _ = hidden_states.shape - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size=batch_size) - - query = attn.to_q(hidden_states) - - is_cross = encoder_hidden_states is not None - encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - - query = attn.head_to_batch_dim(query) - key = attn.head_to_batch_dim(key) - value = attn.head_to_batch_dim(value) - - attention_probs = attn.get_attention_scores(query, key, attention_mask) - - attention_probs = self.attnstore(attention_probs, is_cross, self.place_in_unet) # - - hidden_states = torch.bmm(attention_probs, value) - hidden_states = attn.batch_to_head_dim(hidden_states) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - return hidden_states - -class AttentionControl(abc.ABC): - - def step_callback(self, x_t): - return x_t - - def between_steps(self): - return - - @property - def num_uncond_att_layers(self): - return 0 # compute in parrallel - - @abc.abstractmethod - def forward(self, attn, is_cross: bool, place_in_unet: str): - raise NotImplementedError - - def __call__(self, attn, is_cross: bool, place_in_unet: str): - if self.cur_att_layer >= self.num_uncond_att_layers: - h = attn.shape[0] - attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet) - self.cur_att_layer += 1 - if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers: - self.cur_att_layer = 0 - self.cur_step += 1 - self.between_steps() - return attn - - def reset(self): - self.cur_step = 0 - self.cur_att_layer = 0 - - def __init__(self): - self.cur_step = 0 - self.num_att_layers = -1 - self.cur_att_layer = 0 - -class AttentionStore(AttentionControl): - - @staticmethod - def get_empty_store(): - return {"down_cross": [], "mid_cross": [], "up_cross": [], - "down_self": [], "mid_self": [], "up_self": []} - - def forward(self, attn, is_cross: bool, place_in_unet: str): - key = f"{place_in_unet}_{'cross' if is_cross 
else 'self'}" - # if attn.shape[1] <= 64 ** 2: # avoid memory overhead - resolution = int((attn.shape[0] // (self.batch_size * 2)) ** (0.5)) - if key in self.target_keys and resolution in self.target_resolutions: - self.step_store[key].append(attn) - return attn - - def between_steps(self): - if len(self.attention_store) == 0: - self.attention_store = self.step_store - else: - for key in self.attention_store: - for i in range(len(self.attention_store[key])): - self.attention_store[key][i] += self.step_store[key][i] - self.step_store = self.get_empty_store() - - def get_average_attention(self): - average_attention = self.attention_store - return average_attention - - def get_average_global_attention(self, type=None): - average_attention = {key: [item / self.cur_step for item in self.attention_store[key]] for key in - self.attention_store} - return average_attention - - def reset(self): - super(AttentionStore, self).reset() - self.step_store = self.get_empty_store() - self.attention_store = {} - self.global_store = {} - - def __init__(self): - ''' - Initialize an empty AttentionStore - :param step_index: used to visualize only a specific step in the diffusion process - ''' - super(AttentionStore, self).__init__() - self.step_store = self.get_empty_store() - self.attention_store = {} - self.target_keys = ['down_self', 'mid_self', 'up_self'] - self.target_resolutions = [16, 32, 64, 128] - self.batch_size = 1 - - - - -class AttentionReplacement(AttentionControl): - - @staticmethod - def get_empty_store(): - return {"down_cross": [], "mid_cross": [], "up_cross": [], - "down_self": [], "mid_self": [], "up_self": []} - - def forward(self, attn, is_cross: bool, place_in_unet: str): - key = f"{place_in_unet}_{'cross' if is_cross else 'self'}" - # if attn.shape[1] <= 64 ** 2: # avoid memory overhead - # self.step_store[key].append(attn) - resolution = int((attn.shape[0] // (self.batch_size * 2)) ** (0.5)) - if key in self.target_keys and resolution in self.target_resolutions: - h = attn.shape[0] // 2 - attn[h:] = attn[:h] - - return attn - - # def between_steps(self): - # if len(self.attention_store) == 0: - # self.attention_store = self.step_store - # else: - # for key in self.attention_store: - # for i in range(len(self.attention_store[key])): - # self.attention_store[key][i] += self.step_store[key][i] - # self.step_store = self.get_empty_store() - - # def get_average_attention(self): - # average_attention = self.attention_store - # return average_attention - - # def get_average_global_attention(self, type=None): - # average_attention = {key: [item / self.cur_step for item in self.attention_store[key]] for key in - # self.attention_store} - # return average_attention - - def reset(self): - super(AttentionReplacement, self).reset() - self.step_store = self.get_empty_store() - self.attention_store = {} - self.global_store = {} - - def __init__(self): - ''' - Initialize an empty AttentionStore - :param step_index: used to visualize only a specific step in the diffusion process - ''' - super(AttentionReplacement, self).__init__() - self.step_store = self.get_empty_store() - self.attention_store = {} - self.target_keys = ['down_self', 'mid_self', 'up_self'] - self.target_resolutions = [16, 32, 64, 128] - self.batch_size = 2 \ No newline at end of file diff --git a/utils/convert_diffusers_to_original_ms_text_to_video.py b/utils/convert_diffusers_to_original_ms_text_to_video.py new file mode 100644 index 0000000..83758fe --- /dev/null +++ b/utils/convert_diffusers_to_original_ms_text_to_video.py @@ -0,0 
+1,465 @@ +# Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint. +# *Only* converts the UNet, and Text Encoder. +# Does not convert optimizer state or any other thing. + +import argparse +import os.path as osp +import re + +import torch +from safetensors.torch import load_file, save_file + +# =================# +# UNet Conversion # +# =================# + +print ('Initializing the conversion map') + +unet_conversion_map = [ + # (ModelScope, HF Diffusers) + + # from Vanilla ModelScope/StableDiffusion + ("time_embed.0.weight", "time_embedding.linear_1.weight"), + ("time_embed.0.bias", "time_embedding.linear_1.bias"), + ("time_embed.2.weight", "time_embedding.linear_2.weight"), + ("time_embed.2.bias", "time_embedding.linear_2.bias"), + + + # from Vanilla ModelScope/StableDiffusion + ("input_blocks.0.0.weight", "conv_in.weight"), + ("input_blocks.0.0.bias", "conv_in.bias"), + + + # from Vanilla ModelScope/StableDiffusion + ("out.0.weight", "conv_norm_out.weight"), + ("out.0.bias", "conv_norm_out.bias"), + ("out.2.weight", "conv_out.weight"), + ("out.2.bias", "conv_out.bias"), +] + +unet_conversion_map_resnet = [ + # (ModelScope, HF Diffusers) + + # SD + ("in_layers.0", "norm1"), + ("in_layers.2", "conv1"), + ("out_layers.0", "norm2"), + ("out_layers.3", "conv2"), + ("emb_layers.1", "time_emb_proj"), + ("skip_connection", "conv_shortcut"), + + # MS + #("temopral_conv", "temp_convs"), # ROFL, they have a typo here --kabachuha +] + +unet_conversion_map_layer = [] + +# Convert input TemporalTransformer +unet_conversion_map_layer.append(('input_blocks.0.1', 'transformer_in')) + +# Reference for the default settings + +# "model_cfg": { +# "unet_in_dim": 4, +# "unet_dim": 320, +# "unet_y_dim": 768, +# "unet_context_dim": 1024, +# "unet_out_dim": 4, +# "unet_dim_mult": [1, 2, 4, 4], +# "unet_num_heads": 8, +# "unet_head_dim": 64, +# "unet_res_blocks": 2, +# "unet_attn_scales": [1, 0.5, 0.25], +# "unet_dropout": 0.1, +# "temporal_attention": "True", +# "num_timesteps": 1000, +# "mean_type": "eps", +# "var_type": "fixed_small", +# "loss_type": "mse" +# } + +# hardcoded number of downblocks and resnets/attentions... +# would need smarter logic for other networks. +for i in range(4): + # loop over downblocks/upblocks + + for j in range(2): + # loop over resnets/attentions for downblocks + + # Spacial SD stuff + hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}." + sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0." + unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix)) + + if i < 3: + # no attention layers in down_blocks.3 + hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}." + sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1." + unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix)) + + # Temporal MS stuff + hf_down_res_prefix = f"down_blocks.{i}.temp_convs.{j}." + sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0.temopral_conv." + unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix)) + + if i < 3: + # no attention layers in down_blocks.3 + hf_down_atn_prefix = f"down_blocks.{i}.temp_attentions.{j}." + sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.2." + unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix)) + + for j in range(3): + # loop over resnets/attentions for upblocks + + # Spacial SD stuff + hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}." + sd_up_res_prefix = f"output_blocks.{3*i + j}.0." 
+ unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix)) + + if i > 0: + # no attention layers in up_blocks.0 + hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}." + sd_up_atn_prefix = f"output_blocks.{3*i + j}.1." + unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix)) + + # loop over resnets/attentions for upblocks + hf_up_res_prefix = f"up_blocks.{i}.temp_convs.{j}." + sd_up_res_prefix = f"output_blocks.{3*i + j}.0.temopral_conv." + unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix)) + + if i > 0: + # no attention layers in up_blocks.0 + hf_up_atn_prefix = f"up_blocks.{i}.temp_attentions.{j}." + sd_up_atn_prefix = f"output_blocks.{3*i + j}.2." + unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix)) + + # Up/Downsamplers are 2D, so don't need to touch them + if i < 3: + # no downsample in down_blocks.3 + hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv." + sd_downsample_prefix = f"input_blocks.{3*(i+1)}.op." + unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix)) + + # no upsample in up_blocks.3 + hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." + sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 3}." + unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix)) + + +# Handle the middle block + +# Spacial +hf_mid_atn_prefix = "mid_block.attentions.0." +sd_mid_atn_prefix = "middle_block.1." +unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) + +for j in range(2): + hf_mid_res_prefix = f"mid_block.resnets.{j}." + sd_mid_res_prefix = f"middle_block.{3*j}." + unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix)) + +# Temporal +hf_mid_atn_prefix = "mid_block.temp_attentions.0." +sd_mid_atn_prefix = "middle_block.2." +unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) + +for j in range(2): + hf_mid_res_prefix = f"mid_block.temp_convs.{j}." + sd_mid_res_prefix = f"middle_block.{3*j}.temopral_conv." + unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix)) + +# The pipeline +def convert_unet_state_dict(unet_state_dict, strict_mapping=False): + print ('Converting the UNET') + # buyer beware: this is a *brittle* function, + # and correct output requires that all of these pieces interact in + # the exact order in which I have arranged them. 
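As a quick sanity check on the tables defined above, two Diffusers keys and the ModelScope names the replacement passes below should produce for them (the layer table rewrites the block prefix, the resnet table rewrites the sub-module name, and `temopral_conv` deliberately preserves the upstream ModelScope typo):

    # HF Diffusers (input)                        ModelScope (output)
    # down_blocks.0.resnets.0.norm1.weight     -> input_blocks.1.0.in_layers.0.weight
    #   ('norm1' -> 'in_layers.0' via unet_conversion_map_resnet;
    #    'down_blocks.0.resnets.0.' -> 'input_blocks.1.0.' via unet_conversion_map_layer, i=0, j=0)
    # down_blocks.0.downsamplers.0.conv.weight -> input_blocks.3.op.weight
    #   ('down_blocks.0.downsamplers.0.conv.' -> 'input_blocks.3.op.', the i=0 downsampler entry)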
+ mapping = {k: k for k in unet_state_dict.keys()} + + for sd_name, hf_name in unet_conversion_map: + if strict_mapping: + if hf_name in mapping: + mapping[hf_name] = sd_name + else: + mapping[hf_name] = sd_name + for k, v in mapping.items(): + if "resnets" in k: + for sd_part, hf_part in unet_conversion_map_resnet: + v = v.replace(hf_part, sd_part) + mapping[k] = v + # elif "temp_convs" in k: + # for sd_part, hf_part in unet_conversion_map_resnet: + # v = v.replace(hf_part, sd_part) + # mapping[k] = v + for k, v in mapping.items(): + for sd_part, hf_part in unet_conversion_map_layer: + v = v.replace(hf_part, sd_part) + mapping[k] = v + + + # there must be a pattern, but I don't want to bother atm + do_not_unsqueeze = [f'output_blocks.{i}.1.proj_out.weight' for i in range(3, 12)] + [f'output_blocks.{i}.1.proj_in.weight' for i in range(3, 12)] + ['middle_block.1.proj_in.weight', 'middle_block.1.proj_out.weight'] + [f'input_blocks.{i}.1.proj_out.weight' for i in [1, 2, 4, 5, 7, 8]] + [f'input_blocks.{i}.1.proj_in.weight' for i in [1, 2, 4, 5, 7, 8]] + print (do_not_unsqueeze) + + new_state_dict = {v: (unet_state_dict[k].unsqueeze(-1) if ('proj_' in k and ('bias' not in k) and (k not in do_not_unsqueeze)) else unet_state_dict[k]) for k, v in mapping.items()} + # HACK: idk why the hell it does not work with list comprehension + for k, v in new_state_dict.items(): + has_k = False + for n in do_not_unsqueeze: + if k == n: + has_k = True + + if has_k: + v = v.squeeze(-1) + new_state_dict[k] = v + + return new_state_dict + +# TODO: VAE conversion. We doesn't train it in the most cases, but may be handy for the future --kabachuha + +# =========================# +# Text Encoder Conversion # +# =========================# + +# IT IS THE SAME CLIP ENCODER, SO JUST COPYPASTING IT --kabachuha + +# =========================# +# Text Encoder Conversion # +# =========================# + + +textenc_conversion_lst = [ + # (stable-diffusion, HF Diffusers) + ("resblocks.", "text_model.encoder.layers."), + ("ln_1", "layer_norm1"), + ("ln_2", "layer_norm2"), + (".c_fc.", ".fc1."), + (".c_proj.", ".fc2."), + (".attn", ".self_attn"), + ("ln_final.", "transformer.text_model.final_layer_norm."), + ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"), + ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"), +] +protected = {re.escape(x[1]): x[0] for x in textenc_conversion_lst} +textenc_pattern = re.compile("|".join(protected.keys())) + +# Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp +code2idx = {"q": 0, "k": 1, "v": 2} + + +def convert_text_enc_state_dict_v20(text_enc_dict): + #print ('Converting the text encoder') + new_state_dict = {} + capture_qkv_weight = {} + capture_qkv_bias = {} + for k, v in text_enc_dict.items(): + if ( + k.endswith(".self_attn.q_proj.weight") + or k.endswith(".self_attn.k_proj.weight") + or k.endswith(".self_attn.v_proj.weight") + ): + k_pre = k[: -len(".q_proj.weight")] + k_code = k[-len("q_proj.weight")] + if k_pre not in capture_qkv_weight: + capture_qkv_weight[k_pre] = [None, None, None] + capture_qkv_weight[k_pre][code2idx[k_code]] = v + continue + + if ( + k.endswith(".self_attn.q_proj.bias") + or k.endswith(".self_attn.k_proj.bias") + or k.endswith(".self_attn.v_proj.bias") + ): + k_pre = k[: -len(".q_proj.bias")] + k_code = k[-len("q_proj.bias")] + if k_pre not in capture_qkv_bias: + capture_qkv_bias[k_pre] = [None, None, None] + capture_qkv_bias[k_pre][code2idx[k_code]] = 
v + continue + + relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k) + new_state_dict[relabelled_key] = v + + for k_pre, tensors in capture_qkv_weight.items(): + if None in tensors: + raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing") + relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre) + new_state_dict[relabelled_key + ".in_proj_weight"] = torch.cat(tensors) + + for k_pre, tensors in capture_qkv_bias.items(): + if None in tensors: + raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing") + relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre) + new_state_dict[relabelled_key + ".in_proj_bias"] = torch.cat(tensors) + + return new_state_dict + + +def convert_text_enc_state_dict(text_enc_dict): + return text_enc_dict + +textenc_conversion_lst = [ + # (stable-diffusion, HF Diffusers) + ("resblocks.", "text_model.encoder.layers."), + ("ln_1", "layer_norm1"), + ("ln_2", "layer_norm2"), + (".c_fc.", ".fc1."), + (".c_proj.", ".fc2."), + (".attn", ".self_attn"), + ("ln_final.", "transformer.text_model.final_layer_norm."), + ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"), + ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"), +] +protected = {re.escape(x[1]): x[0] for x in textenc_conversion_lst} +textenc_pattern = re.compile("|".join(protected.keys())) + +# Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp +code2idx = {"q": 0, "k": 1, "v": 2} + + +def convert_text_enc_state_dict_v20(text_enc_dict): + new_state_dict = {} + capture_qkv_weight = {} + capture_qkv_bias = {} + for k, v in text_enc_dict.items(): + if ( + k.endswith(".self_attn.q_proj.weight") + or k.endswith(".self_attn.k_proj.weight") + or k.endswith(".self_attn.v_proj.weight") + ): + k_pre = k[: -len(".q_proj.weight")] + k_code = k[-len("q_proj.weight")] + if k_pre not in capture_qkv_weight: + capture_qkv_weight[k_pre] = [None, None, None] + capture_qkv_weight[k_pre][code2idx[k_code]] = v + continue + + if ( + k.endswith(".self_attn.q_proj.bias") + or k.endswith(".self_attn.k_proj.bias") + or k.endswith(".self_attn.v_proj.bias") + ): + k_pre = k[: -len(".q_proj.bias")] + k_code = k[-len("q_proj.bias")] + if k_pre not in capture_qkv_bias: + capture_qkv_bias[k_pre] = [None, None, None] + capture_qkv_bias[k_pre][code2idx[k_code]] = v + continue + + relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k) + new_state_dict[relabelled_key] = v + + for k_pre, tensors in capture_qkv_weight.items(): + if None in tensors: + raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing") + relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre) + new_state_dict[relabelled_key + ".in_proj_weight"] = torch.cat(tensors) + + for k_pre, tensors in capture_qkv_bias.items(): + if None in tensors: + raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing") + relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre) + new_state_dict[relabelled_key + ".in_proj_bias"] = torch.cat(tensors) + + return new_state_dict + + +def convert_text_enc_state_dict(text_enc_dict): + return text_enc_dict + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument("--model_path", default=None, type=str, 
required=True, help="Path to the model to convert.") + parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.") + parser.add_argument("--clip_checkpoint_path", default=None, type=str, help="Path to the output CLIP model.") + parser.add_argument("--half", action="store_true", help="Save weights in half precision.") + parser.add_argument( + "--use_safetensors", action="store_true", help="Save weights use safetensors, default is ckpt." + ) + + args = parser.parse_args() + + assert args.model_path is not None, "Must provide a model path!" + + assert args.checkpoint_path is not None, "Must provide a checkpoint path!" + + assert args.clip_checkpoint_path is not None, "Must provide a CLIP checkpoint path!" + + # Path for safetensors + unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.safetensors") + #vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.safetensors") + text_enc_path = osp.join(args.model_path, "text_encoder", "model.safetensors") + + # Load models from safetensors if it exists, if it doesn't pytorch + if osp.exists(unet_path): + unet_state_dict = load_file(unet_path, device="cpu") + else: + unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.bin") + unet_state_dict = torch.load(unet_path, map_location="cpu") + + # if osp.exists(vae_path): + # vae_state_dict = load_file(vae_path, device="cpu") + # else: + # vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.bin") + # vae_state_dict = torch.load(vae_path, map_location="cpu") + + if osp.exists(text_enc_path): + text_enc_dict = load_file(text_enc_path, device="cpu") + else: + text_enc_path = osp.join(args.model_path, "text_encoder", "pytorch_model.bin") + text_enc_dict = torch.load(text_enc_path, map_location="cpu") + + # Convert the UNet model + unet_state_dict = convert_unet_state_dict(unet_state_dict) + #unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()} + + # Convert the VAE model + # vae_state_dict = convert_vae_state_dict(vae_state_dict) + # vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()} + + # Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper + is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict + + if is_v20_model: + + # MODELSCOPE always uses the 2.X encoder, btw --kabachuha + + # Need to add the tag 'transformer' in advance so we can knock it out from the final layer-norm + text_enc_dict = {"transformer." + k: v for k, v in text_enc_dict.items()} + text_enc_dict = convert_text_enc_state_dict_v20(text_enc_dict) + #text_enc_dict = {"cond_stage_model.model." + k: v for k, v in text_enc_dict.items()} + else: + text_enc_dict = convert_text_enc_state_dict(text_enc_dict) + #text_enc_dict = {"cond_stage_model.transformer." 
+ k: v for k, v in text_enc_dict.items()} + + # DON'T PUT TOGETHER FOR THE NEW CHECKPOINT AS MODELSCOPE USES THEM IN THE SPLITTED FORM --kabachuha + # Save CLIP and the Diffusion model to their own files + + #state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict} + print ('Saving UNET') + state_dict = {**unet_state_dict} + + if args.half: + state_dict = {k: v.half() for k, v in state_dict.items()} + + if args.use_safetensors: + save_file(state_dict, args.checkpoint_path) + else: + #state_dict = {"state_dict": state_dict} + torch.save(state_dict, args.checkpoint_path) + + # TODO: CLIP conversion doesn't work atm + # print ('Saving CLIP') + # state_dict = {**text_enc_dict} + + # if args.half: + # state_dict = {k: v.half() for k, v in state_dict.items()} + + # if args.use_safetensors: + # save_file(state_dict, args.checkpoint_path) + # else: + # #state_dict = {"state_dict": state_dict} + # torch.save(state_dict, args.clip_checkpoint_path) + + print('Operation successfull') diff --git a/utils/dataset_utils.py b/utils/dataset_utils.py new file mode 100644 index 0000000..a6d1d19 --- /dev/null +++ b/utils/dataset_utils.py @@ -0,0 +1,113 @@ +import os +import json +import decord +decord.bridge.set_bridge('torch') + +import torch +from torch.utils.data import Dataset +import torchvision +import torchvision.transforms as T + +from itertools import islice + +from glob import glob +from PIL import Image +from einops import rearrange, repeat + + +def read_caption_file(caption_file): + with open(caption_file, 'r', encoding="utf8") as t: + return t.read() + +def get_text_prompt( + text_prompt: str = '', + fallback_prompt: str= '', + file_path:str = '', + ext_types=['.mp4'], + use_caption=False + ): + try: + if use_caption: + if len(text_prompt) > 1: return text_prompt + caption_file = '' + # Use caption on per-video basis (One caption PER video) + for ext in ext_types: + maybe_file = file_path.replace(ext, '.txt') + if maybe_file.endswith(ext_types): continue + if os.path.exists(maybe_file): + caption_file = maybe_file + break + + if os.path.exists(caption_file): + return read_caption_file(caption_file) + + # Return fallback prompt if no conditions are met. + return fallback_prompt + + return text_prompt + except: + print(f"Couldn't read prompt caption for {file_path}. 
Using fallback.") + return fallback_prompt + +def get_video_frames(vr, start_idx, sample_rate=1, max_frames=24): + max_range = len(vr) + frame_number = sorted((0, start_idx, max_range))[1] + + frame_range = range(frame_number, max_range, sample_rate) + frame_range_indices = list(frame_range)[:max_frames] + + return frame_range_indices + +def get_prompt_ids(prompt, tokenizer): + prompt_ids = tokenizer( + prompt, + truncation=True, + padding="max_length", + max_length=tokenizer.model_max_length, + return_tensors="pt", + ).input_ids + + return prompt_ids + +def process_video(vid_path, use_bucketing, w, h, get_frame_buckets, get_frame_batch): + if use_bucketing: + vr = decord.VideoReader(vid_path) + resize = get_frame_buckets(vr) + video = get_frame_batch(vr, resize=resize) + + else: + vr = decord.VideoReader(vid_path, width=w, height=h) + video = get_frame_batch(vr) + + return video, vr + +def min_res(size, min_size): return 192 if size < 192 else size + +def up_down_bucket(m_size, in_size, direction): + if direction == 'down': return abs(int(m_size - in_size)) + if direction == 'up': return abs(int(m_size + in_size)) + +def get_bucket_sizes(size, direction: 'down', min_size): + multipliers = [64, 128] + for i, m in enumerate(multipliers): + res = up_down_bucket(m, size, direction) + multipliers[i] = min_res(res, min_size=min_size) + return multipliers + +def closest_bucket(m_size, size, direction, min_size): + lst = get_bucket_sizes(m_size, direction, min_size) + return lst[min(range(len(lst)), key=lambda i: abs(lst[i]-size))] + +def resolve_bucket(i,h,w): return (i / (h / w)) + +def sensible_buckets(m_width, m_height, w, h, min_size=192): + if h > w: + w = resolve_bucket(m_width, h, w) + w = closest_bucket(m_width, w, 'down', min_size=min_size) + return w, m_height + if h < w: + h = resolve_bucket(m_height, w, h) + h = closest_bucket(m_height, h, 'down', min_size=min_size) + return m_width, h + + return m_width, m_height \ No newline at end of file diff --git a/utils/ddim_utils.py b/utils/ddim_utils.py new file mode 100644 index 0000000..2c0163c --- /dev/null +++ b/utils/ddim_utils.py @@ -0,0 +1,76 @@ +import numpy as np +from typing import Union + +import torch + +from tqdm import tqdm +from diffusers import DDIMScheduler + + +# DDIM Inversion +@torch.no_grad() +def init_prompt(prompt, pipeline): + uncond_input = pipeline.tokenizer( + [""], padding="max_length", max_length=pipeline.tokenizer.model_max_length, + return_tensors="pt" + ) + uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0] + text_input = pipeline.tokenizer( + [prompt], + padding="max_length", + max_length=pipeline.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0] + context = torch.cat([uncond_embeddings, text_embeddings]) + + return context + + +def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int, + sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler): + timestep, next_timestep = min( + timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), timestep + alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod + alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep] + beta_prod_t = 1 - alpha_prod_t + next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5 + next_sample_direction = (1 - alpha_prod_t_next) ** 
0.5 * model_output + next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction + return next_sample + + +def get_noise_pred_single(latents, t, context, unet): + noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"] + return noise_pred + + +@torch.no_grad() +def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt): + context = init_prompt(prompt, pipeline) + uncond_embeddings, cond_embeddings = context.chunk(2) + all_latent = [latent] + latent = latent.clone().detach() + for i in tqdm(range(num_inv_steps)): + t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1] + noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet) + latent = next_step(noise_pred, t, latent, ddim_scheduler) + all_latent.append(latent) + return all_latent + + +@torch.no_grad() +def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""): + ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt) + return ddim_latents + + +def inverse_video(pipe, latents, num_steps): + ddim_inv_scheduler = DDIMScheduler.from_config(pipe.scheduler.config) + ddim_inv_scheduler.set_timesteps(num_steps) + + ddim_inv_latent = ddim_inversion( + pipe, ddim_inv_scheduler, video_latent=latents.to(pipe.device), + num_inv_steps=num_steps, prompt="")[-1] + return ddim_inv_latent \ No newline at end of file diff --git a/utils/extract_16frames.py b/utils/extract_16frames.py new file mode 100644 index 0000000..0985b97 --- /dev/null +++ b/utils/extract_16frames.py @@ -0,0 +1,36 @@ +import cv2 +import imageio + + +def get_total_frames(video_path): + cap = cv2.VideoCapture(video_path) + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + cap.release() + return total_frames + + +def extract_frames(input_path, output_path, target_fps, selected_frames): + total_frames = get_total_frames(input_path) + + video_reader = imageio.get_reader(input_path) + fps = video_reader.get_meta_data()['fps'] + + target_total_frames = selected_frames + frame_interval = max(1, int(fps / target_fps)) + selected_indices = [int(i * frame_interval) for i in range(target_total_frames)] + + target_frames = [video_reader.get_data(i) for i in selected_indices] + with imageio.get_writer(output_path, fps=target_fps) as video_writer: + for frame in target_frames: + video_writer.append_data(frame) + + +if __name__ == "__main__": + input_video_path = "/home/luozhouwang/projects/MotionInversion_source/resources/DAVIS/bmx-bumps.mp4" + output_video_path = "/home/luozhouwang/projects/MotionInversion/assets/DAVIS/bmx-bumps-24.mp4" + target_fps = 8 + selected_frames = 24 + + extract_frames(input_video_path, output_video_path, target_fps, selected_frames) + + diff --git a/utils/func_utils.py b/utils/func_utils.py new file mode 100644 index 0000000..f3a7203 --- /dev/null +++ b/utils/func_utils.py @@ -0,0 +1,276 @@ +import torch +import random +import torch.nn.functional as F +from torchvision import transforms +from diffusers.optimization import get_scheduler +from einops import rearrange, repeat +from omegaconf import OmegaConf +from dataset import * +from models.unet.motion_embeddings import * +from .lora import * +from .lora_handler import * + +def param_optim(model, condition, extra_params=None, is_lora=False, negation=None): + extra_params = extra_params if len(extra_params.keys()) > 0 else None + return { + "model": model, + "condition": condition, + 'extra_params': extra_params, + 'is_lora': is_lora, + "negation": negation + } + +def 
create_optim_params(name='param', params=None, lr=5e-6, extra_params=None): + params = { + "name": name, + "params": params, + "lr": lr + } + if extra_params is not None: + for k, v in extra_params.items(): + params[k] = v + + return params + +def create_optimizer_params(model_list, lr): + import itertools + optimizer_params = [] + + for optim in model_list: + model, condition, extra_params, is_lora, negation = optim.values() + # Check if we are doing LoRA training. + if is_lora and condition and isinstance(model, list): + params = create_optim_params( + params=itertools.chain(*model), + extra_params=extra_params + ) + optimizer_params.append(params) + continue + + if is_lora and condition and not isinstance(model, list): + for n, p in model.named_parameters(): + if 'lora' in n: + params = create_optim_params(n, p, lr, extra_params) + optimizer_params.append(params) + continue + + # If this is true, we can train it. + if condition: + for n, p in model.named_parameters(): + should_negate = 'lora' in n and not is_lora + if should_negate: continue + + params = create_optim_params(n, p, lr, extra_params) + optimizer_params.append(params) + + return optimizer_params + +def get_optimizer(use_8bit_adam): + if use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`" + ) + + return bnb.optim.AdamW8bit + else: + return torch.optim.AdamW + +# Initialize the optimizer +def prepare_optimizers(params, config, **extra_params): + optimizer_cls = get_optimizer(config.train.use_8bit_adam) + + optimizer_temporal = optimizer_cls( + params, + lr=config.loss.learning_rate, + betas=(config.loss.adam_beta1, config.loss.adam_beta2), + weight_decay=config.loss.adam_weight_decay, + eps=config.loss.adam_epsilon, + ) + + lr_scheduler_temporal = get_scheduler( + config.loss.lr_scheduler, + optimizer=optimizer_temporal, + num_warmup_steps=config.loss.lr_warmup_steps * config.train.gradient_accumulation_steps, + num_training_steps=config.train.max_train_steps * config.train.gradient_accumulation_steps, + ) + + # Insert Spatial LoRAs + if config.loss.type == 'DebiasHybrid': + unet_lora_params_spatial_list = extra_params.get('unet_lora_params_spatial_list', []) + spatial_lora_num = extra_params.get('spatial_lora_num', 1) + + optimizer_spatial_list = [] + lr_scheduler_spatial_list = [] + for i in range(spatial_lora_num): + unet_lora_params_spatial = unet_lora_params_spatial_list[i] + + optimizer_spatial = optimizer_cls( + create_optimizer_params( + [ + param_optim( + unet_lora_params_spatial, + config.loss.use_unet_lora, + is_lora=True, + extra_params={**{"lr": config.loss.learning_rate_spatial}} + ) + ], + config.loss.learning_rate_spatial + ), + lr=config.loss.learning_rate_spatial, + betas=(config.loss.adam_beta1, config.loss.adam_beta2), + weight_decay=config.loss.adam_weight_decay, + eps=config.loss.adam_epsilon, + ) + optimizer_spatial_list.append(optimizer_spatial) + + # Scheduler + lr_scheduler_spatial = get_scheduler( + config.loss.lr_scheduler, + optimizer=optimizer_spatial, + num_warmup_steps=config.loss.lr_warmup_steps * config.train.gradient_accumulation_steps, + num_training_steps=config.train.max_train_steps * config.train.gradient_accumulation_steps, + ) + lr_scheduler_spatial_list.append(lr_scheduler_spatial) + + else: + optimizer_spatial_list = [] + lr_scheduler_spatial_list = [] + + + + return [optimizer_temporal] + optimizer_spatial_list, [lr_scheduler_temporal] + 
lr_scheduler_spatial_list + +def sample_noise(latents, noise_strength, use_offset_noise=False): + b, c, f, *_ = latents.shape + noise_latents = torch.randn_like(latents, device=latents.device) + + if use_offset_noise: + offset_noise = torch.randn(b, c, f, 1, 1, device=latents.device) + noise_latents = noise_latents + noise_strength * offset_noise + + return noise_latents + +def tensor_to_vae_latent(t, vae): + video_length = t.shape[1] + + t = rearrange(t, "b f c h w -> (b f) c h w") + latents = vae.encode(t).latent_dist.sample() + latents = rearrange(latents, "(b f) c h w -> b c f h w", f=video_length) + latents = latents * 0.18215 + + return latents + +def extend_datasets(datasets, dataset_items, extend=False): + biggest_data_len = max(x.__len__() for x in datasets) + extended = [] + for dataset in datasets: + if dataset.__len__() == 0: + del dataset + continue + if dataset.__len__() < biggest_data_len: + for item in dataset_items: + if extend and item not in extended and hasattr(dataset, item): + print(f"Extending {item}") + + value = getattr(dataset, item) + value *= biggest_data_len + value = value[:biggest_data_len] + + setattr(dataset, item, value) + + print(f"New {item} dataset length: {dataset.__len__()}") + extended.append(item) + +def get_train_dataset(dataset_types, train_data, tokenizer): + train_datasets = [] + + # Loop through all available datasets, get the name, then add to list of data to process. + for DataSet in [VideoJsonDataset, SingleVideoDataset, ImageDataset, VideoFolderDataset]: + for dataset in dataset_types: + if dataset == DataSet.__getname__(): + train_datasets.append(DataSet(**train_data, tokenizer=tokenizer)) + + if len(train_datasets) > 0: + return train_datasets + else: + raise ValueError("Dataset type not found: 'json', 'single_video', 'folder', 'image'") + +def prepare_data(config, tokenizer): + # Get the training dataset based on types (json, single_video, image) + + # Assuming config.dataset is a DictConfig object + dataset_params_dict = OmegaConf.to_container(config.dataset, resolve=True) + + # Remove the 'type' key + dataset_params_dict.pop('type', None) # 'None' ensures no error if 'type' key doesn't exist + + train_datasets = get_train_dataset(config.dataset.type, dataset_params_dict, tokenizer) + + # If you have extra train data, you can add a list of however many you would like. + # Eg: extra_train_data: [{: {dataset_types, train_data: {etc...}}}] + try: + if config.train.extra_train_data is not None and len(config.train.extra_train_data) > 0: + for dataset in config.train.extra_train_data: + d_t = dataset.type + # Assuming config.dataset is a DictConfig object + dataset_params_dict = OmegaConf.to_container(dataset, resolve=True) + + # Remove the 'type' key + dataset_params_dict.pop('type', None) # 'None' ensures no error if 'type' key doesn't exist + t_d = dataset_params_dict + train_datasets += get_train_dataset(d_t, t_d, tokenizer) + + except Exception as e: + print(f"Could not process extra train datasets due to an error : {e}") + + # Extend datasets that are less than the greatest one. This allows for more balanced training. 
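To make that balancing concrete, here is a small sketch of what `extend_datasets` above does to one of the listed attributes when the datasets differ in length (sizes here are hypothetical):

# Dataset A exposes 2 video_files, dataset B exposes 8, so biggest_data_len == 8.
video_files = ['a.mp4', 'b.mp4']
video_files = (video_files * 8)[:8]    # value *= biggest_data_len, then truncate
# -> ['a.mp4', 'b.mp4', 'a.mp4', 'b.mp4', 'a.mp4', 'b.mp4', 'a.mp4', 'b.mp4']
# so the ConcatDataset built below samples both sources at a comparable rate.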
+ attrs = ['train_data', 'frames', 'image_dir', 'video_files'] + extend_datasets(train_datasets, attrs, extend=config.train.extend_dataset) + + # Process one dataset + if len(train_datasets) == 1: + train_dataset = train_datasets[0] + + # Process many datasets + else: + train_dataset = torch.utils.data.ConcatDataset(train_datasets) + + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_size=config.train.train_batch_size, + shuffle=True + ) + + return train_dataloader, train_dataset + +# create parameters for optimziation +def prepare_params(unet, config, train_dataset): + extra_params = {} + + params = inject_motion_embeddings( + unet, + sizes=config.model.motion_embeddings.dim, + modules=config.model.motion_embeddings.module + ) + + if config.loss.type == "DebiasHybrid": + if config.loss.spatial_lora_num == -1: + config.loss.spatial_lora_num = train_dataset.__len__() + + lora_managers_spatial, unet_lora_params_spatial_list, unet_negation_all = inject_spatial_loras( + unet=unet, + use_unet_lora=config.loss.use_unet_lora, + lora_unet_dropout=config.loss.lora_unet_dropout, + lora_path=config.loss.lora_path, + lora_rank=config.loss.lora_rank, + spatial_lora_num=config.loss.spatial_lora_num, + ) + + extra_params['lora_managers_spatial'] = lora_managers_spatial + extra_params['unet_lora_params_spatial_list'] = unet_lora_params_spatial_list + extra_params['unet_negation_all'] = unet_negation_all + + return params, extra_params \ No newline at end of file diff --git a/utils/lora.py b/utils/lora.py new file mode 100644 index 0000000..5ac36c1 --- /dev/null +++ b/utils/lora.py @@ -0,0 +1,1481 @@ +import json +import math +from itertools import groupby +import os +from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union + +import numpy as np +import PIL +import torch +import torch.nn as nn +import torch.nn.functional as F + +try: + from safetensors.torch import safe_open + from safetensors.torch import save_file as safe_save + + safetensors_available = True +except ImportError: + from .safe_open import safe_open + + def safe_save( + tensors: Dict[str, torch.Tensor], + filename: str, + metadata: Optional[Dict[str, str]] = None, + ) -> None: + raise EnvironmentError( + "Saving safetensors requires the safetensors library. Please install with pip or similar." + ) + + safetensors_available = False + +from diffusers.models.lora import LoRACompatibleLinear + +class LoraInjectedLinear(nn.Module): + def __init__( + self, in_features, out_features, bias=False, r=4, dropout_p=0.1, scale=1.0 + ): + super().__init__() + + if r > min(in_features, out_features): + #raise ValueError( + # f"LoRA rank {r} must be less or equal than {min(in_features, out_features)}" + #) + print(f"LoRA rank {r} is too large. 
setting to: {min(in_features, out_features)}") + r = min(in_features, out_features) + + self.r = r + self.linear = nn.Linear(in_features, out_features, bias) + self.lora_down = nn.Linear(in_features, r, bias=False) + self.dropout = nn.Dropout(dropout_p) + self.lora_up = nn.Linear(r, out_features, bias=False) + self.scale = scale + self.selector = nn.Identity() + + nn.init.normal_(self.lora_down.weight, std=1 / r) + nn.init.zeros_(self.lora_up.weight) + + def forward(self, hidden_states: torch.Tensor, scale: float = 1.0): + return ( + self.linear(hidden_states) + + self.dropout(self.lora_up(self.selector(self.lora_down(hidden_states)))) + * self.scale + ) + + def realize_as_lora(self): + return self.lora_up.weight.data * self.scale, self.lora_down.weight.data + + def set_selector_from_diag(self, diag: torch.Tensor): + # diag is a 1D tensor of size (r,) + assert diag.shape == (self.r,) + self.selector = nn.Linear(self.r, self.r, bias=False) + self.selector.weight.data = torch.diag(diag) + self.selector.weight.data = self.selector.weight.data.to( + self.lora_up.weight.device + ).to(self.lora_up.weight.dtype) + + +class MultiLoraInjectedLinear(nn.Module): + def __init__( + self, in_features, out_features, bias=False, r=4, dropout_p=0.1, lora_num=1, scales=[1.0] + ): + super().__init__() + + if r > min(in_features, out_features): + #raise ValueError( + # f"LoRA rank {r} must be less or equal than {min(in_features, out_features)}" + #) + print(f"LoRA rank {r} is too large. setting to: {min(in_features, out_features)}") + r = min(in_features, out_features) + + self.r = r + self.linear = nn.Linear(in_features, out_features, bias) + + for i in range(lora_num): + if i==0: + self.lora_down =[nn.Linear(in_features, r, bias=False)] + self.dropout = [nn.Dropout(dropout_p)] + self.lora_up = [nn.Linear(r, out_features, bias=False)] + self.scale = scales[i] + self.selector = [nn.Identity()] + else: + self.lora_down.append(nn.Linear(in_features, r, bias=False)) + self.dropout.append( nn.Dropout(dropout_p)) + self.lora_up.append( nn.Linear(r, out_features, bias=False)) + self.scale.append(scales[i]) + + nn.init.normal_(self.lora_down.weight, std=1 / r) + nn.init.zeros_(self.lora_up.weight) + + def forward(self, input): + return ( + self.linear(input) + + self.dropout(self.lora_up(self.selector(self.lora_down(input)))) + * self.scale + ) + + def realize_as_lora(self): + return self.lora_up.weight.data * self.scale, self.lora_down.weight.data + + def set_selector_from_diag(self, diag: torch.Tensor): + # diag is a 1D tensor of size (r,) + assert diag.shape == (self.r,) + self.selector = nn.Linear(self.r, self.r, bias=False) + self.selector.weight.data = torch.diag(diag) + self.selector.weight.data = self.selector.weight.data.to( + self.lora_up.weight.device + ).to(self.lora_up.weight.dtype) + + +class LoraInjectedConv2d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups: int = 1, + bias: bool = True, + r: int = 4, + dropout_p: float = 0.1, + scale: float = 1.0, + ): + super().__init__() + if r > min(in_channels, out_channels): + print(f"LoRA rank {r} is too large. 
setting to: {min(in_channels, out_channels)}") + r = min(in_channels, out_channels) + + self.r = r + self.conv = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + ) + + self.lora_down = nn.Conv2d( + in_channels=in_channels, + out_channels=r, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=False, + ) + self.dropout = nn.Dropout(dropout_p) + self.lora_up = nn.Conv2d( + in_channels=r, + out_channels=out_channels, + kernel_size=1, + stride=1, + padding=0, + bias=False, + ) + self.selector = nn.Identity() + self.scale = scale + + nn.init.normal_(self.lora_down.weight, std=1 / r) + nn.init.zeros_(self.lora_up.weight) + + def forward(self, input): + return ( + self.conv(input) + + self.dropout(self.lora_up(self.selector(self.lora_down(input)))) + * self.scale + ) + + def realize_as_lora(self): + return self.lora_up.weight.data * self.scale, self.lora_down.weight.data + + def set_selector_from_diag(self, diag: torch.Tensor): + # diag is a 1D tensor of size (r,) + assert diag.shape == (self.r,) + self.selector = nn.Conv2d( + in_channels=self.r, + out_channels=self.r, + kernel_size=1, + stride=1, + padding=0, + bias=False, + ) + self.selector.weight.data = torch.diag(diag) + + # same device + dtype as lora_up + self.selector.weight.data = self.selector.weight.data.to( + self.lora_up.weight.device + ).to(self.lora_up.weight.dtype) + +class LoraInjectedConv3d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: (3, 1, 1), + padding: (1, 0, 0), + bias: bool = False, + r: int = 4, + dropout_p: float = 0, + scale: float = 1.0, + ): + super().__init__() + if r > min(in_channels, out_channels): + print(f"LoRA rank {r} is too large. 
setting to: {min(in_channels, out_channels)}") + r = min(in_channels, out_channels) + + self.r = r + self.kernel_size = kernel_size + self.padding = padding + self.conv = nn.Conv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + padding=padding, + ) + + self.lora_down = nn.Conv3d( + in_channels=in_channels, + out_channels=r, + kernel_size=kernel_size, + bias=False, + padding=padding + ) + self.dropout = nn.Dropout(dropout_p) + self.lora_up = nn.Conv3d( + in_channels=r, + out_channels=out_channels, + kernel_size=1, + stride=1, + padding=0, + bias=False, + ) + self.selector = nn.Identity() + self.scale = scale + + nn.init.normal_(self.lora_down.weight, std=1 / r) + nn.init.zeros_(self.lora_up.weight) + + def forward(self, input): + return ( + self.conv(input) + + self.dropout(self.lora_up(self.selector(self.lora_down(input)))) + * self.scale + ) + + def realize_as_lora(self): + return self.lora_up.weight.data * self.scale, self.lora_down.weight.data + + def set_selector_from_diag(self, diag: torch.Tensor): + # diag is a 1D tensor of size (r,) + assert diag.shape == (self.r,) + self.selector = nn.Conv3d( + in_channels=self.r, + out_channels=self.r, + kernel_size=1, + stride=1, + padding=0, + bias=False, + ) + self.selector.weight.data = torch.diag(diag) + + # same device + dtype as lora_up + self.selector.weight.data = self.selector.weight.data.to( + self.lora_up.weight.device + ).to(self.lora_up.weight.dtype) + +UNET_DEFAULT_TARGET_REPLACE = {"CrossAttention", "Attention", "GEGLU"} + +UNET_EXTENDED_TARGET_REPLACE = {"ResnetBlock2D", "CrossAttention", "Attention", "GEGLU"} + +TEXT_ENCODER_DEFAULT_TARGET_REPLACE = {"CLIPAttention"} + +TEXT_ENCODER_EXTENDED_TARGET_REPLACE = {"CLIPAttention"} + +DEFAULT_TARGET_REPLACE = UNET_DEFAULT_TARGET_REPLACE + +EMBED_FLAG = "" + + +def _find_children( + model, + search_class: List[Type[nn.Module]] = [nn.Linear], +): + """ + Find all modules of a certain class (or union of classes). + + Returns all matching modules, along with the parent of those moduless and the + names they are referenced by. + """ + # For each target find every linear_class module that isn't a child of a LoraInjectedLinear + for parent in model.modules(): + for name, module in parent.named_children(): + if any([isinstance(module, _class) for _class in search_class]): + yield parent, name, module + + +def _find_modules_v2( + model, + ancestor_class: Optional[Set[str]] = None, + search_class: List[Type[nn.Module]] = [nn.Linear], + exclude_children_of: Optional[List[Type[nn.Module]]] = None, + # [ + # LoraInjectedLinear, + # LoraInjectedConv2d, + # LoraInjectedConv3d + # ], +): + """ + Find all modules of a certain class (or union of classes) that are direct or + indirect descendants of other modules of a certain class (or union of classes). + + Returns all matching modules, along with the parent of those moduless and the + names they are referenced by. + """ + + # Get the targets we should replace all linears under + if ancestor_class is not None: + ancestors = ( + module + for name, module in model.named_modules() + if module.__class__.__name__ in ancestor_class # and ('transformer_in' not in name) + ) + else: + # this, incase you want to naively iterate over all modules. 
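For orientation while reading this generator: it ultimately yields `(parent, attribute_name, module)` triples for every matching layer under the matched ancestor classes, which is what lets the injection helpers further down replace a layer in place. A tiny sketch of that pattern, with illustrative ancestor names and defaults:

# for parent, name, child in _find_modules(unet, {"Transformer2DModel"}, search_class=[nn.Linear]):
#     lora_layer = LoraInjectedLinear(child.in_features, child.out_features,
#                                     bias=child.bias is not None, r=4)
#     lora_layer.linear.weight = child.weight      # keep the pretrained weight
#     parent._modules[name] = lora_layer           # swap the child module in place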
+ ancestors = [module for module in model.modules()] + + # For each target find every linear_class module that isn't a child of a LoraInjectedLinear + for ancestor in ancestors: + for fullname, module in ancestor.named_modules(): + if any([isinstance(module, _class) for _class in search_class]): + continue_flag = True + if 'Transformer2DModel' in ancestor_class and ('attn1' in fullname or 'ff' in fullname): + continue_flag = False + if 'TransformerTemporalModel' in ancestor_class and ('attn1' in fullname or 'attn2' in fullname or 'ff' in fullname): + continue_flag = False + if continue_flag: + continue + # Find the direct parent if this is a descendant, not a child, of target + *path, name = fullname.split(".") + parent = ancestor + while path: + parent = parent.get_submodule(path.pop(0)) + # Skip this linear if it's a child of a LoraInjectedLinear + if exclude_children_of and any( + [isinstance(parent, _class) for _class in exclude_children_of] + ): + continue + if name in ['lora_up', 'dropout', 'lora_down']: + continue + # Otherwise, yield it + yield parent, name, module + + +def _find_modules_old( + model, + ancestor_class: Set[str] = DEFAULT_TARGET_REPLACE, + search_class: List[Type[nn.Module]] = [nn.Linear], + exclude_children_of: Optional[List[Type[nn.Module]]] = [LoraInjectedLinear], +): + ret = [] + for _module in model.modules(): + if _module.__class__.__name__ in ancestor_class: + + for name, _child_module in _module.named_modules(): + if _child_module.__class__ in search_class: + ret.append((_module, name, _child_module)) + print(ret) + return ret + + +_find_modules = _find_modules_v2 + + +def inject_trainable_lora( + model: nn.Module, + target_replace_module: Set[str] = DEFAULT_TARGET_REPLACE, + r: int = 4, + loras=None, # path to lora .pt + verbose: bool = False, + dropout_p: float = 0.0, + scale: float = 1.0, +): + """ + inject lora into model, and returns lora parameter groups. + """ + + require_grad_params = [] + names = [] + + if loras != None: + loras = torch.load(loras) + + for _module, name, _child_module in _find_modules( + model, target_replace_module, search_class=[nn.Linear] + ): + weight = _child_module.weight + bias = _child_module.bias + if verbose: + print("LoRA Injection : injecting lora into ", name) + print("LoRA Injection : weight shape", weight.shape) + _tmp = LoraInjectedLinear( + _child_module.in_features, + _child_module.out_features, + _child_module.bias is not None, + r=r, + dropout_p=dropout_p, + scale=scale, + ) + _tmp.linear.weight = weight + if bias is not None: + _tmp.linear.bias = bias + + # switch the module + _tmp.to(_child_module.weight.device).to(_child_module.weight.dtype) + _module._modules[name] = _tmp + + require_grad_params.append(_module._modules[name].lora_up.parameters()) + require_grad_params.append(_module._modules[name].lora_down.parameters()) + + if loras != None: + _module._modules[name].lora_up.weight = loras.pop(0) + _module._modules[name].lora_down.weight = loras.pop(0) + + _module._modules[name].lora_up.weight.requires_grad = True + _module._modules[name].lora_down.weight.requires_grad = True + names.append(name) + + return require_grad_params, names + + +def inject_trainable_lora_extended( + model: nn.Module, + target_replace_module: Set[str] = UNET_EXTENDED_TARGET_REPLACE, + r: int = 4, + loras=None, # path to lora .pt + dropout_p: float = 0.0, + scale: float = 1.0, +): + """ + inject lora into model, and returns lora parameter groups. 
+ """ + + require_grad_params = [] + names = [] + + if loras != None: + loras = torch.load(loras) + if True: + for target_replace_module_i in target_replace_module: + for _module, name, _child_module in _find_modules( + model, [target_replace_module_i], search_class=[LoRACompatibleLinear, nn.Conv2d, nn.Conv3d] + ): + # if name == 'to_q': + # continue + if _child_module.__class__ == LoRACompatibleLinear: + weight = _child_module.weight + bias = _child_module.bias + _tmp = LoraInjectedLinear( + _child_module.in_features, + _child_module.out_features, + _child_module.bias is not None, + r=r, + dropout_p=dropout_p, + scale=scale, + ) + _tmp.linear.weight = weight + if bias is not None: + _tmp.linear.bias = bias + elif _child_module.__class__ == nn.Conv2d: + weight = _child_module.weight + bias = _child_module.bias + _tmp = LoraInjectedConv2d( + _child_module.in_channels, + _child_module.out_channels, + _child_module.kernel_size, + _child_module.stride, + _child_module.padding, + _child_module.dilation, + _child_module.groups, + _child_module.bias is not None, + r=r, + dropout_p=dropout_p, + scale=scale, + ) + + _tmp.conv.weight = weight + if bias is not None: + _tmp.conv.bias = bias + + elif _child_module.__class__ == nn.Conv3d: + weight = _child_module.weight + bias = _child_module.bias + _tmp = LoraInjectedConv3d( + _child_module.in_channels, + _child_module.out_channels, + bias=_child_module.bias is not None, + kernel_size=_child_module.kernel_size, + padding=_child_module.padding, + r=r, + dropout_p=dropout_p, + scale=scale, + ) + + _tmp.conv.weight = weight + if bias is not None: + _tmp.conv.bias = bias + # switch the module + _tmp.to(_child_module.weight.device).to(_child_module.weight.dtype) + if bias is not None: + _tmp.to(_child_module.bias.device).to(_child_module.bias.dtype) + + _module._modules[name] = _tmp + require_grad_params.append(_module._modules[name].lora_up.parameters()) + require_grad_params.append(_module._modules[name].lora_down.parameters()) + + if loras != None: + _module._modules[name].lora_up.weight = loras.pop(0) + _module._modules[name].lora_down.weight = loras.pop(0) + + _module._modules[name].lora_up.weight.requires_grad = True + _module._modules[name].lora_down.weight.requires_grad = True + names.append(name) + else: + for _module, name, _child_module in _find_modules( + model, target_replace_module, search_class=[nn.Linear, nn.Conv2d, nn.Conv3d] + ): + if _child_module.__class__ == nn.Linear: + weight = _child_module.weight + bias = _child_module.bias + _tmp = LoraInjectedLinear( + _child_module.in_features, + _child_module.out_features, + _child_module.bias is not None, + r=r, + dropout_p=dropout_p, + scale=scale, + ) + _tmp.linear.weight = weight + if bias is not None: + _tmp.linear.bias = bias + elif _child_module.__class__ == nn.Conv2d: + weight = _child_module.weight + bias = _child_module.bias + _tmp = LoraInjectedConv2d( + _child_module.in_channels, + _child_module.out_channels, + _child_module.kernel_size, + _child_module.stride, + _child_module.padding, + _child_module.dilation, + _child_module.groups, + _child_module.bias is not None, + r=r, + dropout_p=dropout_p, + scale=scale, + ) + + _tmp.conv.weight = weight + if bias is not None: + _tmp.conv.bias = bias + + elif _child_module.__class__ == nn.Conv3d: + weight = _child_module.weight + bias = _child_module.bias + _tmp = LoraInjectedConv3d( + _child_module.in_channels, + _child_module.out_channels, + bias=_child_module.bias is not None, + kernel_size=_child_module.kernel_size, + 
padding=_child_module.padding, + r=r, + dropout_p=dropout_p, + scale=scale, + ) + + _tmp.conv.weight = weight + if bias is not None: + _tmp.conv.bias = bias + # switch the module + _tmp.to(_child_module.weight.device).to(_child_module.weight.dtype) + if bias is not None: + _tmp.to(_child_module.bias.device).to(_child_module.bias.dtype) + + _module._modules[name] = _tmp + require_grad_params.append(_module._modules[name].lora_up.parameters()) + require_grad_params.append(_module._modules[name].lora_down.parameters()) + + if loras != None: + _module._modules[name].lora_up.weight = loras.pop(0) + _module._modules[name].lora_down.weight = loras.pop(0) + + _module._modules[name].lora_up.weight.requires_grad = True + _module._modules[name].lora_down.weight.requires_grad = True + names.append(name) + + return require_grad_params, names + + +def inject_inferable_lora( + model, + lora_path='', + unet_replace_modules=["UNet3DConditionModel"], + text_encoder_replace_modules=["CLIPEncoderLayer"], + is_extended=False, + r=16 + ): + from transformers.models.clip import CLIPTextModel + from diffusers import UNet3DConditionModel + + def is_text_model(f): return 'text_encoder' in f and isinstance(model.text_encoder, CLIPTextModel) + def is_unet(f): return 'unet' in f and model.unet.__class__.__name__ == "UNet3DConditionModel" + + if os.path.exists(lora_path): + try: + for f in os.listdir(lora_path): + if f.endswith('.pt'): + lora_file = os.path.join(lora_path, f) + + if is_text_model(f): + monkeypatch_or_replace_lora( + model.text_encoder, + torch.load(lora_file), + target_replace_module=text_encoder_replace_modules, + r=r + ) + print("Successfully loaded Text Encoder LoRa.") + continue + + if is_unet(f): + monkeypatch_or_replace_lora_extended( + model.unet, + torch.load(lora_file), + target_replace_module=unet_replace_modules, + r=r + ) + print("Successfully loaded UNET LoRa.") + continue + + print("Found a .pt file, but doesn't have the correct name format. 
(unet.pt, text_encoder.pt)") + + except Exception as e: + print(e) + print("Couldn't inject LoRA's due to an error.") + +def extract_lora_ups_down(model, target_replace_module=DEFAULT_TARGET_REPLACE): + + loras = [] + + for target_replace_module_i in target_replace_module: + + for _m, _n, _child_module in _find_modules( + model, + [target_replace_module_i], + search_class=[LoraInjectedLinear, LoraInjectedConv2d, LoraInjectedConv3d], + ): + loras.append((_child_module.lora_up, _child_module.lora_down)) + + if len(loras) == 0: + raise ValueError("No lora injected.") + + return loras + + +def extract_lora_child_module(model, target_replace_module=DEFAULT_TARGET_REPLACE): + + loras = [] + + for target_replace_module_i in target_replace_module: + + for _m, _n, _child_module in _find_modules( + model, + [target_replace_module_i], + search_class=[LoraInjectedLinear, LoraInjectedConv2d, LoraInjectedConv3d], + ): + loras.append(_child_module) + + return loras + +def extract_lora_as_tensor( + model, target_replace_module=DEFAULT_TARGET_REPLACE, as_fp16=True +): + + loras = [] + + for _m, _n, _child_module in _find_modules( + model, + target_replace_module, + search_class=[LoraInjectedLinear, LoraInjectedConv2d, LoraInjectedConv3d], + ): + up, down = _child_module.realize_as_lora() + if as_fp16: + up = up.to(torch.float16) + down = down.to(torch.float16) + + loras.append((up, down)) + + if len(loras) == 0: + raise ValueError("No lora injected.") + + return loras + + +def save_lora_weight( + model, + path="./lora.pt", + target_replace_module=DEFAULT_TARGET_REPLACE, + flag=None +): + weights = [] + for _up, _down in extract_lora_ups_down( + model, target_replace_module=target_replace_module + ): + weights.append(_up.weight.to("cpu").to(torch.float32)) + weights.append(_down.weight.to("cpu").to(torch.float32)) + if not flag: + torch.save(weights, path) + else: + weights_new=[] + for i in range(0, len(weights), 4): + subset = weights[i+(flag-1)*2:i+(flag-1)*2+2] + weights_new.extend(subset) + torch.save(weights_new, path) + +def save_lora_as_json(model, path="./lora.json"): + weights = [] + for _up, _down in extract_lora_ups_down(model): + weights.append(_up.weight.detach().cpu().numpy().tolist()) + weights.append(_down.weight.detach().cpu().numpy().tolist()) + + import json + + with open(path, "w") as f: + json.dump(weights, f) + + +def save_safeloras_with_embeds( + modelmap: Dict[str, Tuple[nn.Module, Set[str]]] = {}, + embeds: Dict[str, torch.Tensor] = {}, + outpath="./lora.safetensors", +): + """ + Saves the Lora from multiple modules in a single safetensor file. 
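+ Each LoRA pair is stored as "<module name>:<index>:up" / ":down" tensors, with the
+ rank and target modules recorded in the safetensors metadata.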
+ + modelmap is a dictionary of { + "module name": (module, target_replace_module) + } + """ + weights = {} + metadata = {} + + for name, (model, target_replace_module) in modelmap.items(): + metadata[name] = json.dumps(list(target_replace_module)) + + for i, (_up, _down) in enumerate( + extract_lora_as_tensor(model, target_replace_module) + ): + rank = _down.shape[0] + + metadata[f"{name}:{i}:rank"] = str(rank) + weights[f"{name}:{i}:up"] = _up + weights[f"{name}:{i}:down"] = _down + + for token, tensor in embeds.items(): + metadata[token] = EMBED_FLAG + weights[token] = tensor + + print(f"Saving weights to {outpath}") + safe_save(weights, outpath, metadata) + + +def save_safeloras( + modelmap: Dict[str, Tuple[nn.Module, Set[str]]] = {}, + outpath="./lora.safetensors", +): + return save_safeloras_with_embeds(modelmap=modelmap, outpath=outpath) + + +def convert_loras_to_safeloras_with_embeds( + modelmap: Dict[str, Tuple[str, Set[str], int]] = {}, + embeds: Dict[str, torch.Tensor] = {}, + outpath="./lora.safetensors", +): + """ + Converts the Lora from multiple pytorch .pt files into a single safetensor file. + + modelmap is a dictionary of { + "module name": (pytorch_model_path, target_replace_module, rank) + } + """ + + weights = {} + metadata = {} + + for name, (path, target_replace_module, r) in modelmap.items(): + metadata[name] = json.dumps(list(target_replace_module)) + + lora = torch.load(path) + for i, weight in enumerate(lora): + is_up = i % 2 == 0 + i = i // 2 + + if is_up: + metadata[f"{name}:{i}:rank"] = str(r) + weights[f"{name}:{i}:up"] = weight + else: + weights[f"{name}:{i}:down"] = weight + + for token, tensor in embeds.items(): + metadata[token] = EMBED_FLAG + weights[token] = tensor + + print(f"Saving weights to {outpath}") + safe_save(weights, outpath, metadata) + + +def convert_loras_to_safeloras( + modelmap: Dict[str, Tuple[str, Set[str], int]] = {}, + outpath="./lora.safetensors", +): + convert_loras_to_safeloras_with_embeds(modelmap=modelmap, outpath=outpath) + + +def parse_safeloras( + safeloras, +) -> Dict[str, Tuple[List[nn.parameter.Parameter], List[int], List[str]]]: + """ + Converts a loaded safetensor file that contains a set of module Loras + into Parameters and other information + + Output is a dictionary of { + "module name": ( + [list of weights], + [list of ranks], + target_replacement_modules + ) + } + """ + loras = {} + metadata = safeloras.metadata() + + get_name = lambda k: k.split(":")[0] + + keys = list(safeloras.keys()) + keys.sort(key=get_name) + + for name, module_keys in groupby(keys, get_name): + info = metadata.get(name) + + if not info: + raise ValueError( + f"Tensor {name} has no metadata - is this a Lora safetensor?" 
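+ # Every key group must carry metadata: either the embed flag or a JSON list of target modules.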
+ ) + + # Skip Textual Inversion embeds + if info == EMBED_FLAG: + continue + + # Handle Loras + # Extract the targets + target = json.loads(info) + + # Build the result lists - Python needs us to preallocate lists to insert into them + module_keys = list(module_keys) + ranks = [4] * (len(module_keys) // 2) + weights = [None] * len(module_keys) + + for key in module_keys: + # Split the model name and index out of the key + _, idx, direction = key.split(":") + idx = int(idx) + + # Add the rank + ranks[idx] = int(metadata[f"{name}:{idx}:rank"]) + + # Insert the weight into the list + idx = idx * 2 + (1 if direction == "down" else 0) + weights[idx] = nn.parameter.Parameter(safeloras.get_tensor(key)) + + loras[name] = (weights, ranks, target) + + return loras + + +def parse_safeloras_embeds( + safeloras, +) -> Dict[str, torch.Tensor]: + """ + Converts a loaded safetensor file that contains Textual Inversion embeds into + a dictionary of embed_token: Tensor + """ + embeds = {} + metadata = safeloras.metadata() + + for key in safeloras.keys(): + # Only handle Textual Inversion embeds + meta = metadata.get(key) + if not meta or meta != EMBED_FLAG: + continue + + embeds[key] = safeloras.get_tensor(key) + + return embeds + + +def load_safeloras(path, device="cpu"): + safeloras = safe_open(path, framework="pt", device=device) + return parse_safeloras(safeloras) + + +def load_safeloras_embeds(path, device="cpu"): + safeloras = safe_open(path, framework="pt", device=device) + return parse_safeloras_embeds(safeloras) + + +def load_safeloras_both(path, device="cpu"): + safeloras = safe_open(path, framework="pt", device=device) + return parse_safeloras(safeloras), parse_safeloras_embeds(safeloras) + + +def collapse_lora(model, alpha=1.0): + + for _module, name, _child_module in _find_modules( + model, + UNET_EXTENDED_TARGET_REPLACE | TEXT_ENCODER_EXTENDED_TARGET_REPLACE, + search_class=[LoraInjectedLinear, LoraInjectedConv2d, LoraInjectedConv3d], + ): + + if isinstance(_child_module, LoraInjectedLinear): + print("Collapsing Lin Lora in", name) + + _child_module.linear.weight = nn.Parameter( + _child_module.linear.weight.data + + alpha + * ( + _child_module.lora_up.weight.data + @ _child_module.lora_down.weight.data + ) + .type(_child_module.linear.weight.dtype) + .to(_child_module.linear.weight.device) + ) + + else: + print("Collapsing Conv Lora in", name) + _child_module.conv.weight = nn.Parameter( + _child_module.conv.weight.data + + alpha + * ( + _child_module.lora_up.weight.data.flatten(start_dim=1) + @ _child_module.lora_down.weight.data.flatten(start_dim=1) + ) + .reshape(_child_module.conv.weight.data.shape) + .type(_child_module.conv.weight.dtype) + .to(_child_module.conv.weight.device) + ) + + +def monkeypatch_or_replace_lora( + model, + loras, + target_replace_module=DEFAULT_TARGET_REPLACE, + r: Union[int, List[int]] = 4, +): + for _module, name, _child_module in _find_modules( + model, target_replace_module, search_class=[nn.Linear, LoraInjectedLinear] + ): + _source = ( + _child_module.linear + if isinstance(_child_module, LoraInjectedLinear) + else _child_module + ) + + weight = _source.weight + bias = _source.bias + _tmp = LoraInjectedLinear( + _source.in_features, + _source.out_features, + _source.bias is not None, + r=r.pop(0) if isinstance(r, list) else r, + ) + _tmp.linear.weight = weight + + if bias is not None: + _tmp.linear.bias = bias + + # switch the module + _module._modules[name] = _tmp + + up_weight = loras.pop(0) + down_weight = loras.pop(0) + + 
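+ # The flat weight list stores alternating (up, down) tensors, so they are consumed pairwise in that order.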
_module._modules[name].lora_up.weight = nn.Parameter( + up_weight.type(weight.dtype) + ) + _module._modules[name].lora_down.weight = nn.Parameter( + down_weight.type(weight.dtype) + ) + + _module._modules[name].to(weight.device) + + +def monkeypatch_or_replace_lora_extended( + model, + loras, + target_replace_module=DEFAULT_TARGET_REPLACE, + r: Union[int, List[int]] = 4, +): + for _module, name, _child_module in _find_modules( + model, + target_replace_module, + search_class=[ + nn.Linear, + nn.Conv2d, + nn.Conv3d, + LoraInjectedLinear, + LoraInjectedConv2d, + LoraInjectedConv3d, + ], + ): + + if (_child_module.__class__ == nn.Linear) or ( + _child_module.__class__ == LoraInjectedLinear + ): + if len(loras[0].shape) != 2: + continue + + _source = ( + _child_module.linear + if isinstance(_child_module, LoraInjectedLinear) + else _child_module + ) + + weight = _source.weight + bias = _source.bias + _tmp = LoraInjectedLinear( + _source.in_features, + _source.out_features, + _source.bias is not None, + r=r.pop(0) if isinstance(r, list) else r, + ) + _tmp.linear.weight = weight + + if bias is not None: + _tmp.linear.bias = bias + + elif (_child_module.__class__ == nn.Conv2d) or ( + _child_module.__class__ == LoraInjectedConv2d + ): + if len(loras[0].shape) != 4: + continue + _source = ( + _child_module.conv + if isinstance(_child_module, LoraInjectedConv2d) + else _child_module + ) + + weight = _source.weight + bias = _source.bias + _tmp = LoraInjectedConv2d( + _source.in_channels, + _source.out_channels, + _source.kernel_size, + _source.stride, + _source.padding, + _source.dilation, + _source.groups, + _source.bias is not None, + r=r.pop(0) if isinstance(r, list) else r, + ) + + _tmp.conv.weight = weight + + if bias is not None: + _tmp.conv.bias = bias + + elif _child_module.__class__ == nn.Conv3d or( + _child_module.__class__ == LoraInjectedConv3d + ): + + if len(loras[0].shape) != 5: + continue + + _source = ( + _child_module.conv + if isinstance(_child_module, LoraInjectedConv3d) + else _child_module + ) + + weight = _source.weight + bias = _source.bias + _tmp = LoraInjectedConv3d( + _source.in_channels, + _source.out_channels, + bias=_source.bias is not None, + kernel_size=_source.kernel_size, + padding=_source.padding, + r=r.pop(0) if isinstance(r, list) else r, + ) + + _tmp.conv.weight = weight + + if bias is not None: + _tmp.conv.bias = bias + + # switch the module + _module._modules[name] = _tmp + + up_weight = loras.pop(0) + down_weight = loras.pop(0) + + _module._modules[name].lora_up.weight = nn.Parameter( + up_weight.type(weight.dtype) + ) + _module._modules[name].lora_down.weight = nn.Parameter( + down_weight.type(weight.dtype) + ) + + _module._modules[name].to(weight.device) + + +def monkeypatch_or_replace_safeloras(models, safeloras): + loras = parse_safeloras(safeloras) + + for name, (lora, ranks, target) in loras.items(): + model = getattr(models, name, None) + + if not model: + print(f"No model provided for {name}, contained in Lora") + continue + + monkeypatch_or_replace_lora_extended(model, lora, target, ranks) + + +def monkeypatch_remove_lora(model): + for _module, name, _child_module in _find_modules( + model, search_class=[LoraInjectedLinear, LoraInjectedConv2d, LoraInjectedConv3d] + ): + if isinstance(_child_module, LoraInjectedLinear): + _source = _child_module.linear + weight, bias = _source.weight, _source.bias + + _tmp = nn.Linear( + _source.in_features, _source.out_features, bias is not None + ) + + _tmp.weight = weight + if bias is not None: + _tmp.bias = bias + + 
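+ # Conv LoRA wrappers are unwrapped by rebuilding a plain Conv2d/Conv3d from the wrapped .conv layer below.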
else: + _source = _child_module.conv + weight, bias = _source.weight, _source.bias + + if isinstance(_source, nn.Conv2d): + _tmp = nn.Conv2d( + in_channels=_source.in_channels, + out_channels=_source.out_channels, + kernel_size=_source.kernel_size, + stride=_source.stride, + padding=_source.padding, + dilation=_source.dilation, + groups=_source.groups, + bias=bias is not None, + ) + + _tmp.weight = weight + if bias is not None: + _tmp.bias = bias + + if isinstance(_source, nn.Conv3d): + _tmp = nn.Conv3d( + _source.in_channels, + _source.out_channels, + bias=_source.bias is not None, + kernel_size=_source.kernel_size, + padding=_source.padding, + ) + + _tmp.weight = weight + if bias is not None: + _tmp.bias = bias + + _module._modules[name] = _tmp + + +def monkeypatch_add_lora( + model, + loras, + target_replace_module=DEFAULT_TARGET_REPLACE, + alpha: float = 1.0, + beta: float = 1.0, +): + for _module, name, _child_module in _find_modules( + model, target_replace_module, search_class=[LoraInjectedLinear] + ): + weight = _child_module.linear.weight + + up_weight = loras.pop(0) + down_weight = loras.pop(0) + + _module._modules[name].lora_up.weight = nn.Parameter( + up_weight.type(weight.dtype).to(weight.device) * alpha + + _module._modules[name].lora_up.weight.to(weight.device) * beta + ) + _module._modules[name].lora_down.weight = nn.Parameter( + down_weight.type(weight.dtype).to(weight.device) * alpha + + _module._modules[name].lora_down.weight.to(weight.device) * beta + ) + + _module._modules[name].to(weight.device) + + +def tune_lora_scale(model, alpha: float = 1.0): + for _module in model.modules(): + if _module.__class__.__name__ in ["LoraInjectedLinear", "LoraInjectedConv2d", "LoraInjectedConv3d"]: + _module.scale = alpha + + +def set_lora_diag(model, diag: torch.Tensor): + for _module in model.modules(): + if _module.__class__.__name__ in ["LoraInjectedLinear", "LoraInjectedConv2d", "LoraInjectedConv3d"]: + _module.set_selector_from_diag(diag) + + +def _text_lora_path(path: str) -> str: + assert path.endswith(".pt"), "Only .pt files are supported" + return ".".join(path.split(".")[:-1] + ["text_encoder", "pt"]) + + +def _ti_lora_path(path: str) -> str: + assert path.endswith(".pt"), "Only .pt files are supported" + return ".".join(path.split(".")[:-1] + ["ti", "pt"]) + + +def apply_learned_embed_in_clip( + learned_embeds, + text_encoder, + tokenizer, + token: Optional[Union[str, List[str]]] = None, + idempotent=False, +): + if isinstance(token, str): + trained_tokens = [token] + elif isinstance(token, list): + assert len(learned_embeds.keys()) == len( + token + ), "The number of tokens and the number of embeds should be the same" + trained_tokens = token + else: + trained_tokens = list(learned_embeds.keys()) + + for token in trained_tokens: + print(token) + embeds = learned_embeds[token] + + # cast to dtype of text_encoder + dtype = text_encoder.get_input_embeddings().weight.dtype + num_added_tokens = tokenizer.add_tokens(token) + + i = 1 + if not idempotent: + while num_added_tokens == 0: + print(f"The tokenizer already contains the token {token}.") + token = f"{token[:-1]}-{i}>" + print(f"Attempting to add the token {token}.") + num_added_tokens = tokenizer.add_tokens(token) + i += 1 + elif num_added_tokens == 0 and idempotent: + print(f"The tokenizer already contains the token {token}.") + print(f"Replacing {token} embedding.") + + # resize the token embeddings + text_encoder.resize_token_embeddings(len(tokenizer)) + + # get the id for the token and assign the embeds + token_id = 
tokenizer.convert_tokens_to_ids(token) + text_encoder.get_input_embeddings().weight.data[token_id] = embeds + return token + + +def load_learned_embed_in_clip( + learned_embeds_path, + text_encoder, + tokenizer, + token: Optional[Union[str, List[str]]] = None, + idempotent=False, +): + learned_embeds = torch.load(learned_embeds_path) + apply_learned_embed_in_clip( + learned_embeds, text_encoder, tokenizer, token, idempotent + ) + + +def patch_pipe( + pipe, + maybe_unet_path, + token: Optional[str] = None, + r: int = 4, + patch_unet=True, + patch_text=True, + patch_ti=True, + idempotent_token=True, + unet_target_replace_module=DEFAULT_TARGET_REPLACE, + text_target_replace_module=TEXT_ENCODER_DEFAULT_TARGET_REPLACE, +): + if maybe_unet_path.endswith(".pt"): + # torch format + + if maybe_unet_path.endswith(".ti.pt"): + unet_path = maybe_unet_path[:-6] + ".pt" + elif maybe_unet_path.endswith(".text_encoder.pt"): + unet_path = maybe_unet_path[:-16] + ".pt" + else: + unet_path = maybe_unet_path + + ti_path = _ti_lora_path(unet_path) + text_path = _text_lora_path(unet_path) + + if patch_unet: + print("LoRA : Patching Unet") + monkeypatch_or_replace_lora( + pipe.unet, + torch.load(unet_path), + r=r, + target_replace_module=unet_target_replace_module, + ) + + if patch_text: + print("LoRA : Patching text encoder") + monkeypatch_or_replace_lora( + pipe.text_encoder, + torch.load(text_path), + target_replace_module=text_target_replace_module, + r=r, + ) + if patch_ti: + print("LoRA : Patching token input") + token = load_learned_embed_in_clip( + ti_path, + pipe.text_encoder, + pipe.tokenizer, + token=token, + idempotent=idempotent_token, + ) + + elif maybe_unet_path.endswith(".safetensors"): + safeloras = safe_open(maybe_unet_path, framework="pt", device="cpu") + monkeypatch_or_replace_safeloras(pipe, safeloras) + tok_dict = parse_safeloras_embeds(safeloras) + if patch_ti: + apply_learned_embed_in_clip( + tok_dict, + pipe.text_encoder, + pipe.tokenizer, + token=token, + idempotent=idempotent_token, + ) + return tok_dict + + +def train_patch_pipe(pipe, patch_unet, patch_text): + if patch_unet: + print("LoRA : Patching Unet") + collapse_lora(pipe.unet) + monkeypatch_remove_lora(pipe.unet) + + if patch_text: + print("LoRA : Patching text encoder") + + collapse_lora(pipe.text_encoder) + monkeypatch_remove_lora(pipe.text_encoder) + +@torch.no_grad() +def inspect_lora(model): + moved = {} + + for name, _module in model.named_modules(): + if _module.__class__.__name__ in ["LoraInjectedLinear", "LoraInjectedConv2d", "LoraInjectedConv3d"]: + ups = _module.lora_up.weight.data.clone() + downs = _module.lora_down.weight.data.clone() + + wght: torch.Tensor = ups.flatten(1) @ downs.flatten(1) + + dist = wght.flatten().abs().mean().item() + if name in moved: + moved[name].append(dist) + else: + moved[name] = [dist] + + return moved + + +def save_all( + unet, + text_encoder, + save_path, + placeholder_token_ids=None, + placeholder_tokens=None, + save_lora=True, + save_ti=True, + target_replace_module_text=TEXT_ENCODER_DEFAULT_TARGET_REPLACE, + target_replace_module_unet=DEFAULT_TARGET_REPLACE, + safe_form=True, +): + if not safe_form: + # save ti + if save_ti: + ti_path = _ti_lora_path(save_path) + learned_embeds_dict = {} + for tok, tok_id in zip(placeholder_tokens, placeholder_token_ids): + learned_embeds = text_encoder.get_input_embeddings().weight[tok_id] + print( + f"Current Learned Embeddings for {tok}:, id {tok_id} ", + learned_embeds[:4], + ) + learned_embeds_dict[tok] = learned_embeds.detach().cpu() + + 
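+ # Textual-inversion embeddings are written to a separate "<save_path>.ti.pt" file; the LoRA weights are saved below.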
torch.save(learned_embeds_dict, ti_path) + print("Ti saved to ", ti_path) + + # save text encoder + if save_lora: + save_lora_weight( + unet, save_path, target_replace_module=target_replace_module_unet + ) + print("Unet saved to ", save_path) + + save_lora_weight( + text_encoder, + _text_lora_path(save_path), + target_replace_module=target_replace_module_text, + ) + print("Text Encoder saved to ", _text_lora_path(save_path)) + + else: + assert save_path.endswith( + ".safetensors" + ), f"Save path : {save_path} should end with .safetensors" + + loras = {} + embeds = {} + + if save_lora: + + loras["unet"] = (unet, target_replace_module_unet) + loras["text_encoder"] = (text_encoder, target_replace_module_text) + + if save_ti: + for tok, tok_id in zip(placeholder_tokens, placeholder_token_ids): + learned_embeds = text_encoder.get_input_embeddings().weight[tok_id] + print( + f"Current Learned Embeddings for {tok}:, id {tok_id} ", + learned_embeds[:4], + ) + embeds[tok] = learned_embeds.detach().cpu() + + save_safeloras_with_embeds(loras, embeds, save_path) \ No newline at end of file diff --git a/utils/lora_handler.py b/utils/lora_handler.py new file mode 100644 index 0000000..8c654fe --- /dev/null +++ b/utils/lora_handler.py @@ -0,0 +1,294 @@ +import os +from logging import warnings +import torch +from typing import Union +from types import SimpleNamespace +from models.unet.unet_3d_condition import UNet3DConditionModel +from transformers import CLIPTextModel +from .convert_diffusers_to_original_ms_text_to_video import convert_unet_state_dict, convert_text_enc_state_dict_v20 + +from .lora import ( + extract_lora_ups_down, + inject_trainable_lora_extended, + save_lora_weight, + train_patch_pipe, + monkeypatch_or_replace_lora, + monkeypatch_or_replace_lora_extended +) + + +FILE_BASENAMES = ['unet', 'text_encoder'] +LORA_FILE_TYPES = ['.pt', '.safetensors'] +CLONE_OF_SIMO_KEYS = ['model', 'loras', 'target_replace_module', 'r'] +STABLE_LORA_KEYS = ['model', 'target_module', 'search_class', 'r', 'dropout', 'lora_bias'] + +lora_versions = dict( + stable_lora = "stable_lora", + cloneofsimo = "cloneofsimo" +) + +lora_func_types = dict( + loader = "loader", + injector = "injector" +) + +lora_args = dict( + model = None, + loras = None, + target_replace_module = [], + target_module = [], + r = 4, + search_class = [torch.nn.Linear], + dropout = 0, + lora_bias = 'none' +) + +LoraVersions = SimpleNamespace(**lora_versions) +LoraFuncTypes = SimpleNamespace(**lora_func_types) + +LORA_VERSIONS = [LoraVersions.stable_lora, LoraVersions.cloneofsimo] +LORA_FUNC_TYPES = [LoraFuncTypes.loader, LoraFuncTypes.injector] + +def filter_dict(_dict, keys=[]): + if len(keys) == 0: + assert "Keys cannot empty for filtering return dict." 
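+ # Note: `assert` on a non-empty string literal never fails, so this check (and the
+ # one in the loop below) is effectively a no-op; raising ValueError would enforce it.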
+ + for k in keys: + if k not in lora_args.keys(): + assert f"{k} does not exist in available LoRA arguments" + + return {k: v for k, v in _dict.items() if k in keys} + +class LoraHandler(object): + def __init__( + self, + version: LORA_VERSIONS = LoraVersions.cloneofsimo, + use_unet_lora: bool = False, + use_text_lora: bool = False, + save_for_webui: bool = False, + only_for_webui: bool = False, + lora_bias: str = 'none', + unet_replace_modules: list = None, + text_encoder_replace_modules: list = None + ): + self.version = version + self.lora_loader = self.get_lora_func(func_type=LoraFuncTypes.loader) + self.lora_injector = self.get_lora_func(func_type=LoraFuncTypes.injector) + self.lora_bias = lora_bias + self.use_unet_lora = use_unet_lora + self.use_text_lora = use_text_lora + self.save_for_webui = save_for_webui + self.only_for_webui = only_for_webui + self.unet_replace_modules = unet_replace_modules + self.text_encoder_replace_modules = text_encoder_replace_modules + self.use_lora = any([use_text_lora, use_unet_lora]) + + def is_cloneofsimo_lora(self): + return self.version == LoraVersions.cloneofsimo + + + def get_lora_func(self, func_type: LORA_FUNC_TYPES = LoraFuncTypes.loader): + + if self.is_cloneofsimo_lora(): + + if func_type == LoraFuncTypes.loader: + return monkeypatch_or_replace_lora_extended + + if func_type == LoraFuncTypes.injector: + return inject_trainable_lora_extended + + assert "LoRA Version does not exist." + + def check_lora_ext(self, lora_file: str): + return lora_file.endswith(tuple(LORA_FILE_TYPES)) + + def get_lora_file_path( + self, + lora_path: str, + model: Union[UNet3DConditionModel, CLIPTextModel] + ): + if os.path.exists(lora_path): + lora_filenames = [fns for fns in os.listdir(lora_path)] + is_lora = self.check_lora_ext(lora_path) + + is_unet = isinstance(model, UNet3DConditionModel) + is_text = isinstance(model, CLIPTextModel) + idx = 0 if is_unet else 1 + + base_name = FILE_BASENAMES[idx] + + for lora_filename in lora_filenames: + is_lora = self.check_lora_ext(lora_filename) + if not is_lora: + continue + + if base_name in lora_filename: + return os.path.join(lora_path, lora_filename) + + return None + + def handle_lora_load(self, file_name:str, lora_loader_args: dict = None): + self.lora_loader(**lora_loader_args) + print(f"Successfully loaded LoRA from: {file_name}") + + def load_lora(self, model, lora_path: str = '', lora_loader_args: dict = None,): + try: + lora_file = self.get_lora_file_path(lora_path, model) + + if lora_file is not None: + lora_loader_args.update({"lora_path": lora_file}) + self.handle_lora_load(lora_file, lora_loader_args) + + else: + print(f"Could not load LoRAs for {model.__class__.__name__}. 
Injecting new ones instead...") + + except Exception as e: + print(f"An error occurred while loading a LoRA file: {e}") + + def get_lora_func_args(self, lora_path, use_lora, model, replace_modules, r, dropout, lora_bias, scale): + return_dict = lora_args.copy() + + if self.is_cloneofsimo_lora(): + return_dict = filter_dict(return_dict, keys=CLONE_OF_SIMO_KEYS) + return_dict.update({ + "model": model, + "loras": self.get_lora_file_path(lora_path, model), + "target_replace_module": replace_modules, + "r": r, + "scale": scale, + "dropout_p": dropout, + }) + + return return_dict + + def do_lora_injection( + self, + model, + replace_modules, + bias='none', + dropout=0, + r=4, + lora_loader_args=None, + ): + REPLACE_MODULES = replace_modules + + params = None + negation = None + is_injection_hybrid = False + + if self.is_cloneofsimo_lora(): + is_injection_hybrid = True + injector_args = lora_loader_args + + params, negation = self.lora_injector(**injector_args) # inject_trainable_lora_extended + for _up, _down in extract_lora_ups_down( + model, + target_replace_module=REPLACE_MODULES): + + if all(x is not None for x in [_up, _down]): + print(f"Lora successfully injected into {model.__class__.__name__}.") + + break + + return params, negation, is_injection_hybrid + + return params, negation, is_injection_hybrid + + def add_lora_to_model(self, use_lora, model, replace_modules, dropout=0.0, lora_path='', r=16, scale=1.0): + + params = None + negation = None + + lora_loader_args = self.get_lora_func_args( + lora_path, + use_lora, + model, + replace_modules, + r, + dropout, + self.lora_bias, + scale + ) + + if use_lora: + params, negation, is_injection_hybrid = self.do_lora_injection( + model, + replace_modules, + bias=self.lora_bias, + lora_loader_args=lora_loader_args, + dropout=dropout, + r=r + ) + + if not is_injection_hybrid: + self.load_lora(model, lora_path=lora_path, lora_loader_args=lora_loader_args) + + params = model if params is None else params + return params, negation + + def save_cloneofsimo_lora(self, model, save_path, step, flag): + + def save_lora(model, name, condition, replace_modules, step, save_path, flag=None): + if condition and replace_modules is not None: + save_path = f"{save_path}/{step}_{name}.pt" + save_lora_weight(model, save_path, replace_modules, flag) + + save_lora( + model.unet, + FILE_BASENAMES[0], + self.use_unet_lora, + self.unet_replace_modules, + step, + save_path, + flag + ) + save_lora( + model.text_encoder, + FILE_BASENAMES[1], + self.use_text_lora, + self.text_encoder_replace_modules, + step, + save_path, + flag + ) + + # train_patch_pipe(model, self.use_unet_lora, self.use_text_lora) + + def save_lora_weights(self, model: None, save_path: str ='',step: str = '', flag=None): + save_path = f"{save_path}/lora" + os.makedirs(save_path, exist_ok=True) + + if self.is_cloneofsimo_lora(): + if any([self.save_for_webui, self.only_for_webui]): + warnings.warn( + """ + You have 'save_for_webui' enabled, but are using cloneofsimo's LoRA implemention. + Only 'stable_lora' is supported for saving to a compatible webui file. 
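+ The weights will still be saved in cloneofsimo's .pt format.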
+ """ + ) + self.save_cloneofsimo_lora(model, save_path, step, flag) + + + +def inject_spatial_loras(unet, use_unet_lora, lora_unet_dropout, lora_path, lora_rank, spatial_lora_num): + + lora_managers_spatial = [] + unet_lora_params_spatial_list = [] + for i in range(spatial_lora_num): + lora_manager_spatial = LoraHandler( + use_unet_lora=use_unet_lora, + unet_replace_modules=["Transformer2DModel"] + ) + lora_managers_spatial.append(lora_manager_spatial) + + unet_lora_params_spatial, unet_negation_spatial = lora_manager_spatial.add_lora_to_model( + use_unet_lora, + unet, + lora_manager_spatial.unet_replace_modules, + lora_unet_dropout, + lora_path + '/spatial/lora/', + r=lora_rank + ) + unet_lora_params_spatial_list.append(unet_lora_params_spatial) + + return lora_managers_spatial, unet_lora_params_spatial_list, unet_negation_spatial \ No newline at end of file diff --git a/utils/pe_utils.py b/utils/pe_utils.py deleted file mode 100644 index 17b9191..0000000 --- a/utils/pe_utils.py +++ /dev/null @@ -1,131 +0,0 @@ -import re -import torch -from torch import nn -from diffusers.models.embeddings import SinusoidalPositionalEmbedding - -class SinusoidalPositionalEmbeddingForInversion(nn.Module): - """Apply positional information to a sequence of embeddings. - - Takes in a sequence of embeddings with shape (batch_size, seq_length, embed_dim) and adds positional embeddings to - them - - Args: - embed_dim: (int): Dimension of the positional embedding. - max_seq_length: Maximum sequence length to apply positional embeddings - - """ - - def __init__(self, pe=None, embed_dim: int = None, max_seq_length: int = 32, dtype=torch.float16): - super().__init__() - if pe is not None: - self.pe = nn.Parameter(pe.to(dtype)) - else: - self.pe = nn.Parameter(torch.zeros(1, max_seq_length, embed_dim).to(dtype)) - - def forward(self, x): - batch_size, seq_length, _ = x.shape - # if seq_length != 16: - - # pe_interpolated = self.pe[:, :16].permute(0, 2, 1) - - # pe_interpolated = F.interpolate(pe_interpolated, size=seq_length, mode='linear', align_corners=False) - - # pe_interpolated = pe_interpolated.permute(0, 2, 1) - - # pe_interpolated = pe_interpolated.expand(batch_size, -1, -1) - - # x = x + pe_interpolated - # else: - x = x + self.pe[:, :seq_length] - - return x - -def replace_positional_embedding(model,target_size=[320, 640, 1280], target_module=['up','down','mid']): - replacement_dict = {} - - # First, identify all modules that need to be replaced - for name, module in model.named_modules(): - if isinstance(module, SinusoidalPositionalEmbedding): - replacement_dict[name] = SinusoidalPositionalEmbeddingForInversion(pe=module.pe, dtype=model.dtype) - - # Now, replace the identified modules - for name, new_module in replacement_dict.items(): - parent_name = name.rsplit('.', 1)[0] if '.' 
in name else '' - module_name = name.rsplit('.', 1)[-1] - parent_module = model - if parent_name: - parent_module = dict(model.named_modules())[parent_name] - - if new_module.pe.shape[-1] in target_size and parent_name.split('_')[0] in target_module: - setattr(parent_module, module_name, new_module) - -def replace_positional_embedding_unet3d(model,target_size=[320, 640, 1280], target_module=['up','down','mid']): - replacement_dict = {} - - # First, identify all modules that need to be replaced - for name, module in model.named_modules(): - if 'temp_attention' in name and re.search(r'transformer_blocks\.\d+$', name): - replacement_dict[f'{name}.pos_embed'] = SinusoidalPositionalEmbeddingForInversion(embed_dim=module.norm1.normalized_shape[0], dtype=model.dtype) - - # Now, replace the identified modules - for name, new_module in replacement_dict.items(): - parent_name = name.rsplit('.', 1)[0] if '.' in name else '' - module_name = name.rsplit('.', 1)[-1] - parent_module = model - if parent_name: - parent_module = dict(model.named_modules())[parent_name] - - if new_module.pe.shape[-1] in target_size and parent_name.split('_')[0] in target_module: - setattr(parent_module, module_name, new_module) - -def save_positional_embeddings(model, file_path): - # Extract positional embeddings from all instances of SinusoidalPositionalEmbeddingForInversion - positional_embeddings = { - name: module.pe - for name, module in model.named_modules() - if isinstance(module, SinusoidalPositionalEmbeddingForInversion) - } - # Save the positional embeddings to the specified file path - torch.save(positional_embeddings, file_path) - -# def load_positional_embeddings(model, file_path): -# # Load the positional embeddings from the file -# saved_embeddings = torch.load(file_path) -# # Assign the loaded embeddings back to the corresponding modules in the model -# for name, module in model.named_modules(): -# if isinstance(module, SinusoidalPositionalEmbeddingForInversion): -# module.pe.data.copy_(saved_embeddings[name].data) - - -# def load_positional_embeddings(model, file_path): -# # Load the positional embeddings from the file -# saved_embeddings = torch.load(file_path) -# # Assign the loaded embeddings back to the corresponding modules in the model -# for name, module in model.named_modules(): -# if isinstance(module, SinusoidalPositionalEmbeddingForInversion): -# module.pe.data.copy_(saved_embeddings[name].data) - - - -def load_positional_embedding(model,file_path): - replacement_dict = {} - saved_embeddings = torch.load(file_path) - - # First, identify all modules that need to be replaced - # for name, module in model.named_modules(): - # if 'temp_attention' in name and re.search(r'transformer_blocks\.\d+$', name): - # replacement_dict[f'{name}.pos_embed'] = SinusoidalPositionalEmbeddingForInversion(pe=saved_embeddings[f'{name}.pos_embed'].data, dtype=model.dtype) - - for key in saved_embeddings.keys(): - replacement_dict[key] = SinusoidalPositionalEmbeddingForInversion(pe=saved_embeddings[key].data, dtype=model.dtype) - - - # Now, replace the identified modules - for name, new_module in replacement_dict.items(): - parent_name = name.rsplit('.', 1)[0] if '.' in name else '' - module_name = name.rsplit('.', 1)[-1] - parent_module = model - if parent_name: - parent_module = dict(model.named_modules())[parent_name] - # if new_module.pe.shape[-1] in target_size and parent_name.split('_')[0] in target_module: - setattr(parent_module, module_name, new_module) \ No newline at end of file