[PPDiffusers] update text2img_laion400M (#5619)
* update

* update
JunnYu committed Apr 12, 2023
1 parent d6e316a commit 3ee90fa
Showing 9 changed files with 148 additions and 35 deletions.
@@ -0,0 +1,50 @@
/root/laion_aes/part-00000
/root/laion_aes/part-00001
/root/laion_aes/part-00002
/root/laion_aes/part-00003
/root/laion_aes/part-00004
/root/laion_aes/part-00005
/root/laion_aes/part-00006
/root/laion_aes/part-00007
/root/laion_aes/part-00008
/root/laion_aes/part-00009
/root/laion_aes/part-00010
/root/laion_aes/part-00011
/root/laion_aes/part-00012
/root/laion_aes/part-00013
/root/laion_aes/part-00014
/root/laion_aes/part-00015
/root/laion_aes/part-00016
/root/laion_aes/part-00017
/root/laion_aes/part-00018
/root/laion_aes/part-00019
/root/laion_aes/part-00020
/root/laion_aes/part-00021
/root/laion_aes/part-00022
/root/laion_aes/part-00023
/root/laion_aes/part-00024
/root/laion_aes/part-00025
/root/laion_aes/part-00026
/root/laion_aes/part-00027
/root/laion_aes/part-00028
/root/laion_aes/part-00029
/root/laion_aes/part-00030
/root/laion_aes/part-00031
/root/laion_aes/part-00032
/root/laion_aes/part-00033
/root/laion_aes/part-00034
/root/laion_aes/part-00035
/root/laion_aes/part-00036
/root/laion_aes/part-00037
/root/laion_aes/part-00038
/root/laion_aes/part-00039
/root/laion_aes/part-00040
/root/laion_aes/part-00041
/root/laion_aes/part-00042
/root/laion_aes/part-00043
/root/laion_aes/part-00044
/root/laion_aes/part-00045
/root/laion_aes/part-00046
/root/laion_aes/part-00047
/root/laion_aes/part-00048
/root/laion_aes/part-00049
@@ -0,0 +1 @@
./data/filelist/laion_aes.filelist
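
The new laion_aes.filelist above (at data/filelist/laion_aes.filelist, judging by the path referenced here) enumerates the fifty laion_aes shard files, and this one-line file points the training run at that filelist. A hypothetical sketch of how such a two-level filelist resolves into shard paths; the helper name and reading logic below are illustrative only, the example's real loading lives in its dataset code:

# illustrative helper, not part of the commit
def resolve_filelists(top_level_filelist):
    """Read a filelist-of-filelists and return the concrete shard paths."""
    shard_paths = []
    with open(top_level_filelist) as f:
        for filelist in (line.strip() for line in f):
            if not filelist:
                continue
            with open(filelist) as g:
                shard_paths.extend(line.strip() for line in g if line.strip())
    return shard_paths

# resolve_filelists("path/to/this/one-line/filelist")
# -> ["/root/laion_aes/part-00000", ..., "/root/laion_aes/part-00049"]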
@@ -49,6 +49,7 @@ class ModelArguments:
enable_xformers_memory_efficient_attention: bool = field(
default=False, metadata={"help": "enable_xformers_memory_efficient_attention."}
)
to_static: Optional[bool] = field(default=False, metadata={"help": "Whether or not to_static"})


@dataclass
67 changes: 47 additions & 20 deletions ppdiffusers/examples/text_to_image_laion400m/ldm/model.py
@@ -115,15 +115,18 @@ def __init__(self, model_args):
self.noise_scheduler = DDPMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
)
self.eval_scheduler = DDIMScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
clip_sample=False,
set_alpha_to_one=False,
steps_offset=1,
)
self.eval_scheduler.set_timesteps(model_args.num_inference_steps)
self.register_buffer("alphas_cumprod", self.noise_scheduler.alphas_cumprod)

if model_args.image_logging_steps > 0:
self.eval_scheduler = DDIMScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
clip_sample=False,
set_alpha_to_one=False,
steps_offset=1,
)
self.eval_scheduler.set_timesteps(model_args.num_inference_steps)
self.init_weights()
self.use_ema = model_args.use_ema
if self.use_ema:
@@ -138,6 +141,30 @@ def __init__(self, model_args):
f" correctly and a GPU is available: {e}"
)

# make sure unet and text_encoder are in train mode and vae is in eval mode
self.unet.train()
self.text_encoder.train()
self.vae.eval()

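# Note: add_noise below mirrors DDPMScheduler.add_noise, but reads from the
# `alphas_cumprod` tensor registered as a buffer in __init__, i.e. it computes
#   x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
# Keeping the computation inside the Layer instead of calling into the scheduler
# object is presumably what lets it be traced when the model is converted with
# paddle.jit.to_static.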
def add_noise(
self,
original_samples: paddle.Tensor,
noise: paddle.Tensor,
timesteps: paddle.Tensor,
) -> paddle.Tensor:
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples

def init_weights(self):
# init text_encoder
if not self.text_encoder_is_pretrained:
@@ -180,17 +207,17 @@ def on_train_batch_end(self):
self.model_ema(self.unet)

def forward(self, input_ids=None, pixel_values=None, **kwargs):
self.train()
with paddle.amp.auto_cast(enable=False):
with paddle.no_grad():
self.vae.eval()
latents = self.vae.encode(pixel_values).latent_dist.sample()
latents = latents * 0.18215
noise = paddle.randn(latents.shape)
timesteps = paddle.randint(0, self.noise_scheduler.num_train_timesteps, (latents.shape[0],)).astype(
"int64"
)
noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)
with paddle.no_grad():
# TODO add this
# with paddle.amp.auto_cast(enable=False):
self.vae.eval()
latents = self.vae.encode(pixel_values).latent_dist.sample()
latents = latents * 0.18215
noise = paddle.randn(latents.shape)
timesteps = paddle.randint(0, self.noise_scheduler.num_train_timesteps, (latents.shape[0],)).astype(
"int64"
)
noisy_latents = self.add_noise(latents, noise, timesteps)

encoder_hidden_states = self.text_encoder(input_ids)[0]
noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states).sample
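
With the scheduler math inlined, forward() now calls self.add_noise instead of self.noise_scheduler.add_noise. A minimal equivalence check, assuming the DDPMScheduler exported by ppdiffusers exposes alphas_cumprod and add_noise as used in this file:

import paddle
from ppdiffusers import DDPMScheduler

scheduler = DDPMScheduler(
    beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
)
alphas_cumprod = scheduler.alphas_cumprod

latents = paddle.randn([2, 4, 8, 8])
noise = paddle.randn(latents.shape)
timesteps = paddle.randint(0, 1000, (latents.shape[0],)).astype("int64")

# inline version, same computation as the add_noise method added above
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
while len(sqrt_alpha_prod.shape) < len(latents.shape):
    sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
    sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
noisy = sqrt_alpha_prod * latents + sqrt_one_minus_alpha_prod * noise

# should print True: both paths produce the same noisy latents
print(paddle.allclose(noisy, scheduler.add_noise(latents, noise, timesteps)))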
@@ -15,6 +15,7 @@
import base64
import gzip
import io
import json
import random

import numpy as np
@@ -30,7 +31,9 @@

def parse_line(line, filename):
def parse_src(filename):
if "laion400m" in filename:
if "laion_aes" in filename:
return "laion_aes"
elif "laion400m" in filename:
return "laion400m"
else:
raise NotImplementedError(f"Unkown data source, {filename}")
@@ -40,6 +43,10 @@ def parse_src(filename):
data_source = parse_src(filename)
if data_source == "laion400m":
caption, _, img_b64 = vec[:3]
elif data_source == "laion_aes":
text_json = json.loads(vec[2])
img_b64 = vec[5]
caption = text_json.get("caption_en", text_json.get("blip_caption_en", ""))
else:
_, captions, _, _, _, img_b64 = vec[:6]
caption = random.sample(captions.split("|"), 1)[0].replace("\1", "")
@@ -135,7 +142,7 @@ def sample_loader(self, file_ids, filenames):
for i in file_ids:
filename = filenames[i].strip("\n")
with gzip.open(filename, "rb") if filename.endswith(".gz") else open(filename, "rb") as f:
retry = 0
# retry = 0
while True:
line = f.readline()

@@ -151,9 +158,9 @@ def sample_loader(self, file_ids, filenames):
continue
data = parse_line(line, filename)
if data is None:
retry += 1
if retry > 100:
break
# retry += 1
# if retry > 100:
# break
continue
else:
w, h = data["image"].size
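
For the new laion_aes source, parse_line expects each (presumably tab-separated) row to carry a JSON blob with the English captions in column 3 and the base64-encoded image in column 6. A sketch of that parsing on a made-up row; the field values below are invented for illustration:

import base64
import json

# fabricated laion_aes-style row: caption JSON in vec[2], base64 image in vec[5]
caption_json = json.dumps({"caption_en": "a photo of a cat", "blip_caption_en": "a cat"})
img_field = base64.b64encode(b"<raw image bytes>").decode()
line = "\t".join(["id-0", "url", caption_json, "w", "h", img_field])

vec = line.split("\t")
text_json = json.loads(vec[2])
caption = text_json.get("caption_en", text_json.get("blip_caption_en", ""))
img_b64 = vec[5]
image_bytes = base64.b64decode(img_b64)
print(caption)  # "a photo of a cat"
# the real parse_line goes on to decode image_bytes into a PIL image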
@@ -34,9 +34,13 @@ def main():
# report to custom_visualdl
training_args.report_to = ["custom_visualdl"]
training_args.resolution = data_args.resolution

training_args.image_logging_steps = model_args.image_logging_steps = (
math.ceil(model_args.image_logging_steps / training_args.logging_steps) * training_args.logging_steps
(math.ceil(model_args.image_logging_steps / training_args.logging_steps) * training_args.logging_steps)
if model_args.image_logging_steps > 0
else -1
)
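# e.g. with logging_steps=30 and image_logging_steps=1000 this rounds the image
# logging interval up to the next multiple of logging_steps, ceil(1000 / 30) * 30 = 1020;
# a non-positive image_logging_steps now disables image logging entirely (-1).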

training_args.print_config(model_args, "Model")
training_args.print_config(data_args, "Data")

@@ -68,6 +72,16 @@ def main():
tokenizer=model.tokenizer,
)

if model_args.to_static:
input_ids = paddle.static.InputSpec(name="input_ids", shape=[-1, model_args.model_max_length], dtype="int64")
pixel_values = paddle.static.InputSpec(
name="pixel_values", shape=[-1, 3, data_args.resolution, data_args.resolution], dtype="float32"
)
specs = [input_ids, pixel_values]
paddle.jit.ignore_module([os])
model = paddle.jit.to_static(model, input_spec=specs)
logger.info("Successfully to apply @to_static with specs: {}".format(specs))

trainer = LatentDiffusionTrainer(
model=model, args=training_args, train_dataset=train_dataset, tokenizer=model.tokenizer
)
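
The new to_static branch wraps the whole training model with paddle.jit.to_static, using InputSpec placeholders with a dynamic (-1) batch dimension, and registers os with paddle.jit.ignore_module so the transcriber skips that module. A self-contained sketch of the same pattern on a toy layer; ToyModel and its shapes are illustrative, not part of the example:

import paddle
from paddle.static import InputSpec


class ToyModel(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        # stand-ins for the real text encoder / UNet / VAE
        self.embed = paddle.nn.Embedding(100, 8)
        self.proj = paddle.nn.Linear(3 * 4 * 4, 8)

    def forward(self, input_ids, pixel_values):
        text = self.embed(input_ids).mean(axis=1)      # [batch, 8]
        image = self.proj(pixel_values.flatten(1))     # [batch, 8]
        return ((text - image) ** 2).mean()            # scalar "loss"


model = ToyModel()
specs = [
    InputSpec(name="input_ids", shape=[-1, 77], dtype="int64"),
    InputSpec(name="pixel_values", shape=[-1, 3, 4, 4], dtype="float32"),
]
model = paddle.jit.to_static(model, input_spec=specs)
loss = model(paddle.randint(0, 100, [2, 77]), paddle.randn([2, 3, 4, 4]))
print(float(loss))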
19 changes: 13 additions & 6 deletions ppdiffusers/ppdiffusers/patches/ppnlp_patch_utils.py
@@ -127,6 +127,7 @@ def permute_pt(x, *perm: builtins.int, name=None):
# patch repeat_interleave
raw_repeat_interleave = paddle.repeat_interleave

@paddle.jit.not_to_static
def repeat_interleave(x, repeats, axis=None, name=None):
fp16 = False
if x.dtype == paddle.float16:
@@ -145,6 +146,7 @@ def repeat_interleave(x, repeats, axis=None, name=None):
# patch max
raw_max = paddle.max

@paddle.jit.not_to_static
def max(x, axis=None, keepdim=False, name=None):
fp16 = False
if x.dtype == paddle.float16:
@@ -163,6 +165,7 @@ def max(x, axis=None, keepdim=False, name=None):
# patch gather_nd support bfloat16
raw_gather_nd = paddle.gather_nd

@paddle.jit.not_to_static
def gather_nd(x, index, name=None):
bfp16 = False
if x.dtype == paddle.bfloat16:
@@ -180,21 +183,23 @@ def gather_nd(x, index, name=None):
paddle.Tensor.contiguous = lambda x: x

# must return self!
@patch_to(nn.Layer)
def eval(self):
# Layer-level setting
self.training = False
for layer in self.sublayers():
layer.training = False
return self

@patch_to(nn)
nn.Layer.eval = eval

def Parameter(data: paddle.Tensor, requires_grad=True):
tensor = paddle.create_parameter(data.shape, dtype=data.dtype, default_initializer=nn.initializer.Assign(data))
if not requires_grad:
tensor.stop_gradient = True
return tensor

nn.Parameter = Parameter

@contextlib.contextmanager
def device_scope(device="cpu"):
new_device = device.replace("cuda", "gpu")
@@ -207,7 +212,6 @@ def device_scope(device="cpu"):

paddle.device_scope = device_scope

@patch_to(nn.Layer)
def get_sublayer(self, target: str):
if target == "":
return self
@@ -225,6 +229,8 @@ def get_sublayer(self, target: str):
raise AttributeError("`" + item + "` is not " "an nn.Layer")
return mod

nn.Layer.get_sublayer = get_sublayer

class _WrappedHook:
def __init__(self, hook: Callable, module: Optional["nn.Layer"] = None):
self.hook: Callable = hook
@@ -265,30 +271,31 @@ def __setstate__(self, state: Dict):
except ImportError:
from paddle.fluid.dygraph.layers import HookRemoveHelper

@patch_to(nn.Layer)
def register_load_state_dict_pre_hook(self, hook, with_module=False):
handle = HookRemoveHelper(self.load_state_dict_pre_hooks)
self.load_state_dict_pre_hooks[handle._hook_id] = _WrappedHook(hook, self if with_module else None)
return handle

nn.Layer.register_load_state_dict_pre_hook = register_load_state_dict_pre_hook

raw_set_state_dict = nn.Layer.set_state_dict

@patch_to(nn.Layer)
def set_state_dict(self, state_dict, use_structured_name: bool = True):
for hook in self.load_state_dict_pre_hooks.values():
hook(state_dict)
return raw_set_state_dict(self, state_dict, use_structured_name=use_structured_name)

nn.Layer.set_state_dict = set_state_dict
nn.Layer.load_dict = nn.Layer.set_state_dict
nn.Layer.set_dict = nn.Layer.set_state_dict

raw_init = nn.Layer.__init__

@patch_to(nn.Layer)
def __init__(self, name_scope=None, dtype="float32"):
raw_init(self, name_scope=name_scope, dtype=dtype)
self.load_state_dict_pre_hooks = OrderedDict()

nn.Layer.__init__ = __init__

if is_paddle_available() and is_paddlenlp_available():
# set logger level warning
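
Two patterns recur in this patch file after the change: helpers are assigned directly onto nn.Layer / nn (replacing the earlier @patch_to decorator), and the dtype-workaround wrappers are marked @paddle.jit.not_to_static so the dynamic-to-static transcriber does not try to rewrite their Python-level control flow. A small sketch of both, assuming a Paddle version that exposes paddle.jit.not_to_static (as this file does); the helper names are illustrative:

import paddle
import paddle.nn as nn


# 1) keep a helper out of dynamic-to-static transcription
@paddle.jit.not_to_static
def max_fp16_safe(x, axis=None, keepdim=False):
    # compute reductions in float32 when the input is float16, like the patched max above
    if x.dtype == paddle.float16:
        return paddle.max(x.cast("float32"), axis=axis, keepdim=keepdim).cast("float16")
    return paddle.max(x, axis=axis, keepdim=keepdim)


# 2) monkey-patch a method directly onto nn.Layer; note it returns self,
#    which is the point of the "must return self!" eval patch above
def train_(self):
    self.training = True
    for layer in self.sublayers():
        layer.training = True
    return self


nn.Layer.train_ = train_

layer = nn.Linear(4, 4)
assert layer.train_() is layer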
8 changes: 5 additions & 3 deletions ppdiffusers/ppdiffusers/pipelines/pipeline_utils.py
@@ -944,15 +944,17 @@ def load_module(name, value):
else:
# else load from the root directory
loaded_sub_model = load_method(cached_folder, **loading_kwargs)
except Exception:
except Exception as e:
# (TODO, junnyu)
# if we cant find this file, we will try to download this
if not is_local_dir and not from_hf_hub:
loaded_sub_model = load_method(
pretrained_model_name_or_path + "/" + name, cache_dir=cache_dir, **loading_kwargs
)
if loaded_sub_model is None:
raise ValueError(f"We cant load '{name}' from {pretrained_model_name_or_path} or {cached_folder}!")
if loaded_sub_model is None:
raise ValueError(
f"We cant load '{name}' from {pretrained_model_name_or_path} or {cached_folder}! \n {e} "
)
# paddlenlp's model is in training mode not eval mode
# if isinstance(loaded_sub_model, PretrainedModel):
# if paddle_dtype is not None and next(loaded_sub_model.named_parameters())[1].dtype != paddle_dtype:
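
The fallback path now captures the original exception and includes it in the final ValueError, so when neither the cached folder nor the hub fallback yields the sub-model, the root cause is surfaced instead of a bare "cant load" error. The same pattern in isolation, with hypothetical loader callables:

def load_submodule(primary_loader, fallback_loader, name):
    loaded = None
    try:
        loaded = primary_loader()
    except Exception as e:
        # keep the first exception so it can be shown in the final message
        try:
            loaded = fallback_loader()
        except Exception:
            loaded = None
        if loaded is None:
            raise ValueError(f"We can't load '{name}' from either location!\n{e}")
    return loaded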
4 changes: 4 additions & 0 deletions ppdiffusers/ppdiffusers/utils/paddle_utils.py
@@ -75,6 +75,7 @@ def get_rng_state_tracker(*args, **kwargs):
rand = paddle.rand
randint = paddle.randint

@paddle.jit.not_to_static
def randn_pt(shape, dtype=None, name=None, **kwargs):
generator = kwargs.get("generator", None)
if generator is None:
@@ -83,6 +84,7 @@ def randn_pt(shape, dtype=None, name=None, **kwargs):
with get_rng_state_tracker().rng_state(generator):
return randn(shape, dtype=dtype, name=name)

@paddle.jit.not_to_static
def rand_pt(shape, dtype=None, name=None, **kwargs):
generator = kwargs.get("generator", None)
if generator is None:
@@ -91,6 +93,7 @@ def rand_pt(shape, dtype=None, name=None, **kwargs):
with get_rng_state_tracker().rng_state(generator):
return rand(shape, dtype=dtype, name=name)

@paddle.jit.not_to_static
def randint_pt(low=0, high=None, shape=[1], dtype=None, name=None, **kwargs):
generator = kwargs.get("generator", None)
if generator is None:
@@ -99,6 +102,7 @@ def randint_pt(low=0, high=None, shape=[1], dtype=None, name=None, **kwargs):
with get_rng_state_tracker().rng_state(generator):
return randint(low=low, high=high, shape=shape, dtype=dtype, name=name)

@paddle.jit.not_to_static
def randn_like_pt(x, dtype=None, name=None, **kwargs):
generator = kwargs.get("generator", None)
if dtype is None:
