From 3ee90fa04d032b3c96c2c94014c9d55de89b8aae Mon Sep 17 00:00:00 2001
From: yujun <50394665+JunnYu@users.noreply.github.com>
Date: Wed, 12 Apr 2023 11:57:25 +0800
Subject: [PATCH] [PPDiffusers] update text2img_laion400M (#5619)

* update

* update
---
 .../data/filelist/laion_aes.filelist          | 50 ++++++++++++++
 .../data/filelist/laion_aes.filelist.list     |  1 +
 .../text_to_image_laion400m/ldm/ldm_args.py   |  1 +
 .../text_to_image_laion400m/ldm/model.py      | 67 +++++++++++++------
 .../ldm/text_image_pair_dataset.py            | 17 +++--
 .../train_txt2img_laion400m_trainer.py        | 16 ++++-
 .../ppdiffusers/patches/ppnlp_patch_utils.py  | 19 ++++--
 .../ppdiffusers/pipelines/pipeline_utils.py   |  8 ++-
 ppdiffusers/ppdiffusers/utils/paddle_utils.py |  4 ++
 9 files changed, 148 insertions(+), 35 deletions(-)
 create mode 100644 ppdiffusers/examples/text_to_image_laion400m/data/filelist/laion_aes.filelist
 create mode 100644 ppdiffusers/examples/text_to_image_laion400m/data/filelist/laion_aes.filelist.list

diff --git a/ppdiffusers/examples/text_to_image_laion400m/data/filelist/laion_aes.filelist b/ppdiffusers/examples/text_to_image_laion400m/data/filelist/laion_aes.filelist
new file mode 100644
index 000000000000..86b0e5191d63
--- /dev/null
+++ b/ppdiffusers/examples/text_to_image_laion400m/data/filelist/laion_aes.filelist
@@ -0,0 +1,50 @@
+/root/laion_aes/part-00000
+/root/laion_aes/part-00001
+/root/laion_aes/part-00002
+/root/laion_aes/part-00003
+/root/laion_aes/part-00004
+/root/laion_aes/part-00005
+/root/laion_aes/part-00006
+/root/laion_aes/part-00007
+/root/laion_aes/part-00008
+/root/laion_aes/part-00009
+/root/laion_aes/part-00010
+/root/laion_aes/part-00011
+/root/laion_aes/part-00012
+/root/laion_aes/part-00013
+/root/laion_aes/part-00014
+/root/laion_aes/part-00015
+/root/laion_aes/part-00016
+/root/laion_aes/part-00017
+/root/laion_aes/part-00018
+/root/laion_aes/part-00019
+/root/laion_aes/part-00020
+/root/laion_aes/part-00021
+/root/laion_aes/part-00022
+/root/laion_aes/part-00023
+/root/laion_aes/part-00024
+/root/laion_aes/part-00025
+/root/laion_aes/part-00026
+/root/laion_aes/part-00027
+/root/laion_aes/part-00028
+/root/laion_aes/part-00029
+/root/laion_aes/part-00030
+/root/laion_aes/part-00031
+/root/laion_aes/part-00032
+/root/laion_aes/part-00033
+/root/laion_aes/part-00034
+/root/laion_aes/part-00035
+/root/laion_aes/part-00036
+/root/laion_aes/part-00037
+/root/laion_aes/part-00038
+/root/laion_aes/part-00039
+/root/laion_aes/part-00040
+/root/laion_aes/part-00041
+/root/laion_aes/part-00042
+/root/laion_aes/part-00043
+/root/laion_aes/part-00044
+/root/laion_aes/part-00045
+/root/laion_aes/part-00046
+/root/laion_aes/part-00047
+/root/laion_aes/part-00048
+/root/laion_aes/part-00049
\ No newline at end of file
diff --git a/ppdiffusers/examples/text_to_image_laion400m/data/filelist/laion_aes.filelist.list b/ppdiffusers/examples/text_to_image_laion400m/data/filelist/laion_aes.filelist.list
new file mode 100644
index 000000000000..0e36e494e2a3
--- /dev/null
+++ b/ppdiffusers/examples/text_to_image_laion400m/data/filelist/laion_aes.filelist.list
@@ -0,0 +1 @@
+./data/filelist/laion_aes.filelist
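The two data files above form a two-level index: `laion_aes.filelist.list` names the `.filelist` file(s), and each `.filelist` names the actual `laion_aes` part-* shards, one path per line. A minimal sketch of the resolution, assuming the shard files exist on disk (`resolve_shards` is a hypothetical name; the real loading lives in `TextImagePair.sample_loader` below):

```python
# Hypothetical helper, not part of the PR: walk .filelist.list -> .filelist -> shards.
def resolve_shards(list_path):
    shards = []
    with open(list_path) as f_list:
        for filelist_path in f_list.read().splitlines():
            with open(filelist_path.strip()) as f:
                shards.extend(line.strip() for line in f if line.strip())
    return shards

# resolve_shards("./data/filelist/laion_aes.filelist.list")
# -> ["/root/laion_aes/part-00000", ..., "/root/laion_aes/part-00049"]
```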
default=False, metadata={"help": "enable_xformers_memory_efficient_attention."} ) + to_static: Optional[bool] = field(default=False, metadata={"help": "Whether or not to_static"}) @dataclass diff --git a/ppdiffusers/examples/text_to_image_laion400m/ldm/model.py b/ppdiffusers/examples/text_to_image_laion400m/ldm/model.py index 59b90a75548f..bbba9e2f889d 100644 --- a/ppdiffusers/examples/text_to_image_laion400m/ldm/model.py +++ b/ppdiffusers/examples/text_to_image_laion400m/ldm/model.py @@ -115,15 +115,18 @@ def __init__(self, model_args): self.noise_scheduler = DDPMScheduler( beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000 ) - self.eval_scheduler = DDIMScheduler( - beta_start=0.00085, - beta_end=0.012, - beta_schedule="scaled_linear", - clip_sample=False, - set_alpha_to_one=False, - steps_offset=1, - ) - self.eval_scheduler.set_timesteps(model_args.num_inference_steps) + self.register_buffer("alphas_cumprod", self.noise_scheduler.alphas_cumprod) + + if model_args.image_logging_steps > 0: + self.eval_scheduler = DDIMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + clip_sample=False, + set_alpha_to_one=False, + steps_offset=1, + ) + self.eval_scheduler.set_timesteps(model_args.num_inference_steps) self.init_weights() self.use_ema = model_args.use_ema if self.use_ema: @@ -138,6 +141,30 @@ def __init__(self, model_args): f" correctly and a GPU is available: {e}" ) + # make sure unet text_encoder in train mode, vae in eval mode + self.unet.train() + self.text_encoder.train() + self.vae.eval() + + def add_noise( + self, + original_samples: paddle.Tensor, + noise: paddle.Tensor, + timesteps: paddle.Tensor, + ) -> paddle.Tensor: + sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + def init_weights(self): # init text_encoder if not self.text_encoder_is_pretrained: @@ -180,17 +207,17 @@ def on_train_batch_end(self): self.model_ema(self.unet) def forward(self, input_ids=None, pixel_values=None, **kwargs): - self.train() - with paddle.amp.auto_cast(enable=False): - with paddle.no_grad(): - self.vae.eval() - latents = self.vae.encode(pixel_values).latent_dist.sample() - latents = latents * 0.18215 - noise = paddle.randn(latents.shape) - timesteps = paddle.randint(0, self.noise_scheduler.num_train_timesteps, (latents.shape[0],)).astype( - "int64" - ) - noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps) + with paddle.no_grad(): + # TODO add this + # with paddle.amp.auto_cast(enable=False): + self.vae.eval() + latents = self.vae.encode(pixel_values).latent_dist.sample() + latents = latents * 0.18215 + noise = paddle.randn(latents.shape) + timesteps = paddle.randint(0, self.noise_scheduler.num_train_timesteps, (latents.shape[0],)).astype( + "int64" + ) + noisy_latents = self.add_noise(latents, noise, timesteps) encoder_hidden_states = self.text_encoder(input_ids)[0] noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states).sample 
diff --git a/ppdiffusers/examples/text_to_image_laion400m/ldm/text_image_pair_dataset.py b/ppdiffusers/examples/text_to_image_laion400m/ldm/text_image_pair_dataset.py
index dcdf6b23b047..b41f0b799469 100644
--- a/ppdiffusers/examples/text_to_image_laion400m/ldm/text_image_pair_dataset.py
+++ b/ppdiffusers/examples/text_to_image_laion400m/ldm/text_image_pair_dataset.py
@@ -15,6 +15,7 @@
 import base64
 import gzip
 import io
+import json
 import random
 
 import numpy as np
@@ -30,7 +31,9 @@ def parse_line(line, filename):
 
 
 def parse_src(filename):
-    if "laion400m" in filename:
+    if "laion_aes" in filename:
+        return "laion_aes"
+    elif "laion400m" in filename:
         return "laion400m"
     else:
         raise NotImplementedError(f"Unkown data source, {filename}")
@@ -40,6 +43,10 @@ def parse_src(filename):
         data_source = parse_src(filename)
         if data_source == "laion400m":
             caption, _, img_b64 = vec[:3]
+        elif data_source == "laion_aes":
+            text_json = json.loads(vec[2])
+            img_b64 = vec[5]
+            caption = text_json.get("caption_en", text_json.get("blip_caption_en", ""))
         else:
             _, captions, _, _, _, img_b64 = vec[:6]
             caption = random.sample(captions.split("|"), 1)[0].replace("\1", "")
@@ -135,7 +142,7 @@ def sample_loader(self, file_ids, filenames):
         for i in file_ids:
             filename = filenames[i].strip("\n")
             with gzip.open(filename, "rb") if filename.endswith(".gz") else open(filename, "rb") as f:
-                retry = 0
+                # retry = 0
                 while True:
                     line = f.readline()
@@ -151,9 +158,9 @@
                         continue
                     data = parse_line(line, filename)
                     if data is None:
-                        retry += 1
-                        if retry > 100:
-                            break
+                        # retry += 1
+                        # if retry > 100:
+                        #     break
                         continue
                     else:
                         w, h = data["image"].size
diff --git a/ppdiffusers/examples/text_to_image_laion400m/train_txt2img_laion400m_trainer.py b/ppdiffusers/examples/text_to_image_laion400m/train_txt2img_laion400m_trainer.py
index cf5420e0d7b2..edc69bcb77c6 100644
--- a/ppdiffusers/examples/text_to_image_laion400m/train_txt2img_laion400m_trainer.py
+++ b/ppdiffusers/examples/text_to_image_laion400m/train_txt2img_laion400m_trainer.py
@@ -34,9 +34,13 @@ def main():
     # report to custom_visualdl
     training_args.report_to = ["custom_visualdl"]
     training_args.resolution = data_args.resolution
+
     training_args.image_logging_steps = model_args.image_logging_steps = (
-        math.ceil(model_args.image_logging_steps / training_args.logging_steps) * training_args.logging_steps
+        (math.ceil(model_args.image_logging_steps / training_args.logging_steps) * training_args.logging_steps)
+        if model_args.image_logging_steps > 0
+        else -1
     )
+
     training_args.print_config(model_args, "Model")
     training_args.print_config(data_args, "Data")
 
@@ -68,6 +72,16 @@ def main():
         tokenizer=model.tokenizer,
     )
 
+    if model_args.to_static:
+        input_ids = paddle.static.InputSpec(name="input_ids", shape=[-1, model_args.model_max_length], dtype="int64")
+        pixel_values = paddle.static.InputSpec(
+            name="pixel_values", shape=[-1, 3, data_args.resolution, data_args.resolution], dtype="float32"
+        )
+        specs = [input_ids, pixel_values]
+        paddle.jit.ignore_module([os])
+        model = paddle.jit.to_static(model, input_spec=specs)
+        logger.info("Successfully applied @to_static with specs: {}".format(specs))
+
     trainer = LatentDiffusionTrainer(
         model=model, args=training_args, train_dataset=train_dataset, tokenizer=model.tokenizer
     )
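The new `laion_aes` branch in `parse_line` implies a record layout of six tab-separated columns, with a JSON blob carrying the English captions in column 3 and a base64-encoded image in column 6. A self-contained sketch decoding a synthetic record (the layout is inferred from the diff; the column contents other than 3 and 6 are assumptions):

```python
import base64
import io
import json

from PIL import Image

# build a synthetic laion_aes-style record: 6 tab-separated columns,
# JSON caption metadata in column 3, base64 JPEG bytes in column 6
buf = io.BytesIO()
Image.new("RGB", (8, 8), "white").save(buf, format="JPEG")
vec = [
    b"sample-id",                                            # assumed
    b"meta",                                                 # assumed
    json.dumps({"caption_en": "a photo of a cat"}).encode(),  # column 3
    b"0",                                                    # assumed
    b"0",                                                    # assumed
    base64.b64encode(buf.getvalue()),                        # column 6
]

# mirror of the parse_line logic added above
text_json = json.loads(vec[2])
caption = text_json.get("caption_en", text_json.get("blip_caption_en", ""))
image = Image.open(io.BytesIO(base64.b64decode(vec[5]))).convert("RGB")
print(caption, image.size)  # a photo of a cat (8, 8)
```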
diff --git a/ppdiffusers/ppdiffusers/patches/ppnlp_patch_utils.py b/ppdiffusers/ppdiffusers/patches/ppnlp_patch_utils.py
index 96daf42cd4bb..21f68b584574 100644
--- a/ppdiffusers/ppdiffusers/patches/ppnlp_patch_utils.py
+++ b/ppdiffusers/ppdiffusers/patches/ppnlp_patch_utils.py
@@ -127,6 +127,7 @@ def permute_pt(x, *perm: builtins.int, name=None):
     # patch repeat_interleave
     raw_repeat_interleave = paddle.repeat_interleave
 
+    @paddle.jit.not_to_static
     def repeat_interleave(x, repeats, axis=None, name=None):
         fp16 = False
         if x.dtype == paddle.float16:
@@ -145,6 +146,7 @@ def repeat_interleave(x, repeats, axis=None, name=None):
     # patch max
     raw_max = paddle.max
 
+    @paddle.jit.not_to_static
     def max(x, axis=None, keepdim=False, name=None):
         fp16 = False
         if x.dtype == paddle.float16:
@@ -163,6 +165,7 @@ def max(x, axis=None, keepdim=False, name=None):
     # patch gather_nd support bfloat16
     raw_gather_nd = paddle.gather_nd
 
+    @paddle.jit.not_to_static
     def gather_nd(x, index, name=None):
         bfp16 = False
         if x.dtype == paddle.bfloat16:
@@ -180,7 +183,6 @@ def gather_nd(x, index, name=None):
 
     paddle.Tensor.contiguous = lambda x: x  # must return self!
 
-    @patch_to(nn.Layer)
     def eval(self):
         # Layer-level setting
         self.training = False
@@ -188,13 +190,16 @@ def eval(self):
             layer.training = False
         return self
 
-    @patch_to(nn)
+    nn.Layer.eval = eval
+
     def Parameter(data: paddle.Tensor, requires_grad=True):
         tensor = paddle.create_parameter(data.shape, dtype=data.dtype, default_initializer=nn.initializer.Assign(data))
         if not requires_grad:
             tensor.stop_gradient = True
         return tensor
 
+    nn.Parameter = Parameter
+
     @contextlib.contextmanager
     def device_scope(device="cpu"):
         new_device = device.replace("cuda", "gpu")
@@ -207,7 +212,6 @@ def device_scope(device="cpu"):
 
     paddle.device_scope = device_scope
 
-    @patch_to(nn.Layer)
     def get_sublayer(self, target: str):
         if target == "":
             return self
@@ -225,6 +229,8 @@ def get_sublayer(self, target: str):
                 raise AttributeError("`" + item + "` is not " "an nn.Layer")
         return mod
 
+    nn.Layer.get_sublayer = get_sublayer
+
     class _WrappedHook:
         def __init__(self, hook: Callable, module: Optional["nn.Layer"] = None):
             self.hook: Callable = hook
@@ -265,30 +271,31 @@ def __setstate__(self, state: Dict):
     except ImportError:
         from paddle.fluid.dygraph.layers import HookRemoveHelper
 
-    @patch_to(nn.Layer)
     def register_load_state_dict_pre_hook(self, hook, with_module=False):
         handle = HookRemoveHelper(self.load_state_dict_pre_hooks)
         self.load_state_dict_pre_hooks[handle._hook_id] = _WrappedHook(hook, self if with_module else None)
         return handle
 
+    nn.Layer.register_load_state_dict_pre_hook = register_load_state_dict_pre_hook
+
     raw_set_state_dict = nn.Layer.set_state_dict
 
-    @patch_to(nn.Layer)
     def set_state_dict(self, state_dict, use_structured_name: bool = True):
         for hook in self.load_state_dict_pre_hooks.values():
             hook(state_dict)
         return raw_set_state_dict(self, state_dict, use_structured_name=use_structured_name)
 
+    nn.Layer.set_state_dict = set_state_dict
     nn.Layer.load_dict = nn.Layer.set_state_dict
     nn.Layer.set_dict = nn.Layer.set_state_dict
 
     raw_init = nn.Layer.__init__
 
-    @patch_to(nn.Layer)
     def __init__(self, name_scope=None, dtype="float32"):
         raw_init(self, name_scope=name_scope, dtype=dtype)
         self.load_state_dict_pre_hooks = OrderedDict()
 
+    nn.Layer.__init__ = __init__
 
 if is_paddle_available() and is_paddlenlp_available():
     # set logger level warning
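Replacing the `@patch_to(...)` decorators with direct attribute assignment (`nn.Layer.eval = eval`, `nn.Parameter = Parameter`, and so on) leaves the patched functions as ordinary module-level functions, which `paddle.jit.to_static`'s source transcription handles more gracefully than decorator-produced wrappers. The pre-hook machinery itself is unchanged; a hypothetical usage sketch, assuming the patches are active after `import ppdiffusers`:

```python
import ppdiffusers  # noqa: F401 -- importing applies the nn.Layer patches
import paddle.nn as nn

layer = nn.Linear(4, 4)

def rename_legacy_keys(state_dict):
    # runs before set_state_dict consumes the dict
    for key in list(state_dict.keys()):
        if key.startswith("old."):
            state_dict[key.replace("old.", "")] = state_dict.pop(key)

handle = layer.register_load_state_dict_pre_hook(rename_legacy_keys)
# keys are renamed by the hook before loading, so this round-trips cleanly
layer.set_state_dict({"old." + k: v for k, v in layer.state_dict().items()})
handle.remove()  # HookRemoveHelper exposes remove()
```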
diff --git a/ppdiffusers/ppdiffusers/pipelines/pipeline_utils.py b/ppdiffusers/ppdiffusers/pipelines/pipeline_utils.py
index 8829e2f97eb2..c58e72707cb3 100644
--- a/ppdiffusers/ppdiffusers/pipelines/pipeline_utils.py
+++ b/ppdiffusers/ppdiffusers/pipelines/pipeline_utils.py
@@ -944,15 +944,17 @@ def load_module(name, value):
                 else:
                     # else load from the root directory
                     loaded_sub_model = load_method(cached_folder, **loading_kwargs)
-            except Exception:
+            except Exception as e:
                 # (TODO, junnyu)
                 # if we cant find this file, we will try to download this
                 if not is_local_dir and not from_hf_hub:
                     loaded_sub_model = load_method(
                         pretrained_model_name_or_path + "/" + name, cache_dir=cache_dir, **loading_kwargs
                     )
-            if loaded_sub_model is None:
-                raise ValueError(f"We cant load '{name}' from {pretrained_model_name_or_path} or {cached_folder}!")
+                if loaded_sub_model is None:
+                    raise ValueError(
+                        f"We can't load '{name}' from {pretrained_model_name_or_path} or {cached_folder}! \n {e} "
+                    )
 
         # paddlenlp's model is in training mode not eval mode
         # if isinstance(loaded_sub_model, PretrainedModel):
         #     if paddle_dtype is not None and next(loaded_sub_model.named_parameters())[1].dtype != paddle_dtype:
diff --git a/ppdiffusers/ppdiffusers/utils/paddle_utils.py b/ppdiffusers/ppdiffusers/utils/paddle_utils.py
index 394bd8d02860..f3ebdbac0662 100644
--- a/ppdiffusers/ppdiffusers/utils/paddle_utils.py
+++ b/ppdiffusers/ppdiffusers/utils/paddle_utils.py
@@ -75,6 +75,7 @@ def get_rng_state_tracker(*args, **kwargs):
     rand = paddle.rand
     randint = paddle.randint
 
+    @paddle.jit.not_to_static
     def randn_pt(shape, dtype=None, name=None, **kwargs):
         generator = kwargs.get("generator", None)
         if generator is None:
@@ -83,6 +84,7 @@ def randn_pt(shape, dtype=None, name=None, **kwargs):
         with get_rng_state_tracker().rng_state(generator):
             return randn(shape, dtype=dtype, name=name)
 
+    @paddle.jit.not_to_static
     def rand_pt(shape, dtype=None, name=None, **kwargs):
         generator = kwargs.get("generator", None)
         if generator is None:
@@ -91,6 +93,7 @@ def rand_pt(shape, dtype=None, name=None, **kwargs):
         with get_rng_state_tracker().rng_state(generator):
             return rand(shape, dtype=dtype, name=name)
 
+    @paddle.jit.not_to_static
     def randint_pt(low=0, high=None, shape=[1], dtype=None, name=None, **kwargs):
         generator = kwargs.get("generator", None)
         if generator is None:
@@ -99,6 +102,7 @@ def randint_pt(low=0, high=None, shape=[1], dtype=None, name=None, **kwargs):
         with get_rng_state_tracker().rng_state(generator):
             return randint(low=low, high=high, shape=shape, dtype=dtype, name=name)
 
+    @paddle.jit.not_to_static
     def randn_like_pt(x, dtype=None, name=None, **kwargs):
         generator = kwargs.get("generator", None)
         if dtype is None:
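The four random helpers are now marked `@paddle.jit.not_to_static`, so their Python-level `generator` handling (the `kwargs.get` branch and the RNG-state-tracker context) keeps running eagerly even inside a to_static program; when no `generator` is passed they fall through to the raw paddle ops, so global seeding behaves as usual. A trivial reproducibility check against the patched `paddle.randn`, assuming the patches are active after `import ppdiffusers`:

```python
import ppdiffusers  # noqa: F401 -- importing swaps in the patched randn/rand/randint
import paddle

paddle.seed(42)
a = paddle.randn([2, 3])
paddle.seed(42)
b = paddle.randn([2, 3])
assert bool(paddle.allclose(a, b))  # same seed -> identical noise, patched or not
```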