From 08756612418b33e427b12dfd8b6686bedc06e08b Mon Sep 17 00:00:00 2001
From: camenduru <54370274+camenduru@users.noreply.github.com>
Date: Wed, 22 Nov 2023 21:11:42 +0300
Subject: [PATCH] test

---
 stable_video_diffusion_fp32_colab.ipynb | 276 +++++++++++++++++++++++-
 1 file changed, 272 insertions(+), 4 deletions(-)

diff --git a/stable_video_diffusion_fp32_colab.ipynb b/stable_video_diffusion_fp32_colab.ipynb
index 65b0082..add407e 100644
--- a/stable_video_diffusion_fp32_colab.ipynb
+++ b/stable_video_diffusion_fp32_colab.ipynb
@@ -18,10 +18,278 @@
 "outputs": [],
 "source": [
 "%cd /content\n",
- "!git clone -b dev https://github.com/camenduru/stable-video-diffusion-hf\n",
- "%cd /content/stable-video-diffusion-hf\n",
- "!pip install -r https://github.com/camenduru/stable-video-diffusion-colab/raw/main/requirements.txt\n",
- "!python app.py"
+ "!git clone -b dev https://github.com/camenduru/generative-models\n",
+ "!pip install -q -r https://github.com/camenduru/stable-video-diffusion-colab/raw/main/requirements.txt\n",
+ "!pip install -q -e generative-models\n",
+ "!pip install -q -e git+https://github.com/Stability-AI/datapipelines@main#egg=sdata\n",
+ "\n",
+ "!apt -y install -qq aria2\n",
+ "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/vdo/stable-video-diffusion-img2vid-xt/resolve/main/svd_xt.safetensors?download=true -d /content/checkpoints -o svd_xt.safetensors\n",
+ "\n",
+ "!mkdir -p /content/scripts/util/detection\n",
+ "!ln -s /content/generative-models/scripts/util/detection/p_head_v1.npz /content/scripts/util/detection/p_head_v1.npz\n",
+ "!ln -s /content/generative-models/scripts/util/detection/w_head_v1.npz /content/scripts/util/detection/w_head_v1.npz\n",
+ "\n",
+ "import sys\n",
+ "sys.path.append(\"generative-models\")\n",
+ "\n",
+ "import os, math, torch, cv2\n",
+ "from omegaconf import OmegaConf\n",
+ "from glob import glob\n",
+ "from pathlib import Path\n",
+ "from typing import Optional\n",
+ "import numpy as np\n",
+ "from einops import rearrange, repeat\n",
+ "\n",
+ "from PIL import Image\n",
+ "from torchvision.transforms import ToTensor\n",
+ "from torchvision.transforms import functional as TF\n",
+ "from sgm.util import instantiate_from_config\n",
+ "from sgm.inference.helpers import embed_watermark\n",
+ "from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering\n",
+ "\n",
+ "def load_model(config: str, device: str, num_frames: int, num_steps: int):\n",
+ "    config = OmegaConf.load(config)\n",
+ "    config.model.params.conditioner_config.params.emb_models[0].params.open_clip_embedding_config.params.init_device = device\n",
+ "    config.model.params.sampler_config.params.num_steps = num_steps\n",
+ "    config.model.params.sampler_config.params.guider_config.params.num_frames = (num_frames)\n",
+ "    with torch.device(device):\n",
+ "        model = instantiate_from_config(config.model).to(device).eval().requires_grad_(False)\n",
+ "    filter = DeepFloydDataFiltering(verbose=False, device=device)\n",
+ "    return model, filter\n",
+ "\n",
+ "num_frames = 25\n",
+ "num_steps = 30\n",
+ "model_config = \"generative-models/scripts/sampling/configs/svd_xt.yaml\"\n",
+ "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
+ "model, filter = load_model(model_config, device, num_frames, num_steps)\n",
+ "\n",
+ "def get_unique_embedder_keys_from_conditioner(conditioner):\n",
+ "    return list(set([x.input_key for x in conditioner.embedders]))\n",
+ "\n",
+ "def get_batch(keys, value_dict, N, T, device):\n",
+ "    batch = {}\n",
+ "    batch_uc = {}\n",
+ "    for key in keys:\n",
+ "        if key == \"fps_id\":\n",
+ "            batch[key] = (\n",
+ "                torch.tensor([value_dict[\"fps_id\"]])\n",
+ "                .to(device)\n",
+ "                .repeat(int(math.prod(N)))\n",
+ "            )\n",
+ "        elif key == \"motion_bucket_id\":\n",
+ "            batch[key] = (\n",
+ "                torch.tensor([value_dict[\"motion_bucket_id\"]])\n",
+ "                .to(device)\n",
+ "                .repeat(int(math.prod(N)))\n",
+ "            )\n",
+ "        elif key == \"cond_aug\":\n",
+ "            batch[key] = repeat(\n",
+ "                torch.tensor([value_dict[\"cond_aug\"]]).to(device),\n",
+ "                \"1 -> b\",\n",
+ "                b=math.prod(N),\n",
+ "            )\n",
+ "        elif key == \"cond_frames\":\n",
+ "            batch[key] = repeat(value_dict[\"cond_frames\"], \"1 ... -> b ...\", b=N[0])\n",
+ "        elif key == \"cond_frames_without_noise\":\n",
+ "            batch[key] = repeat(\n",
+ "                value_dict[\"cond_frames_without_noise\"], \"1 ... -> b ...\", b=N[0]\n",
+ "            )\n",
+ "        else:\n",
+ "            batch[key] = value_dict[key]\n",
+ "    if T is not None:\n",
+ "        batch[\"num_video_frames\"] = T\n",
+ "    for key in batch.keys():\n",
+ "        if key not in batch_uc and isinstance(batch[key], torch.Tensor):\n",
+ "            batch_uc[key] = torch.clone(batch[key])\n",
+ "    return batch, batch_uc\n",
+ "\n",
+ "def sample(\n",
+ "    input_path: str = \"/content/test_image.png\",\n",
+ "    resize_image: bool = False,\n",
+ "    num_frames: Optional[int] = None,\n",
+ "    num_steps: Optional[int] = None,\n",
+ "    fps_id: int = 6,\n",
+ "    motion_bucket_id: int = 127,\n",
+ "    cond_aug: float = 0.02,\n",
+ "    seed: int = 23,\n",
+ "    decoding_t: int = 14,  # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.\n",
+ "    device: str = \"cuda\",\n",
+ "    output_folder: Optional[str] = \"/content/outputs\",\n",
+ "):\n",
+ "    \"\"\"\n",
+ "    Simple script to generate a single sample conditioned on an image `input_path` or multiple images, one for each\n",
+ "    image file in folder `input_path`. If you run out of VRAM, try decreasing `decoding_t`.\n",
+ "    \"\"\"\n",
+ "    torch.manual_seed(seed)\n",
+ "\n",
+ "    path = Path(input_path)\n",
+ "    all_img_paths = []\n",
+ "    if path.is_file():\n",
+ "        if any([input_path.endswith(x) for x in [\"jpg\", \"jpeg\", \"png\"]]):\n",
+ "            all_img_paths = [input_path]\n",
+ "        else:\n",
+ "            raise ValueError(\"Path is not valid image file.\")\n",
+ "    elif path.is_dir():\n",
+ "        all_img_paths = sorted(\n",
+ "            [\n",
+ "                f\n",
+ "                for f in path.iterdir()\n",
+ "                if f.is_file() and f.suffix.lower() in [\".jpg\", \".jpeg\", \".png\"]\n",
+ "            ]\n",
+ "        )\n",
+ "        if len(all_img_paths) == 0:\n",
+ "            raise ValueError(\"Folder does not contain any images.\")\n",
+ "    else:\n",
+ "        raise ValueError\n",
+ "    all_out_paths = []\n",
+ "    for input_img_path in all_img_paths:\n",
+ "        with Image.open(input_img_path) as image:\n",
+ "            if image.mode == \"RGBA\":\n",
+ "                image = image.convert(\"RGB\")\n",
+ "            if resize_image and image.size != (1024, 576):\n",
+ "                print(f\"Resizing {image.size} to (1024, 576)\")\n",
+ "                image = TF.resize(TF.resize(image, 1024), (576, 1024))\n",
+ "            w, h = image.size\n",
+ "            if h % 64 != 0 or w % 64 != 0:\n",
+ "                width, height = map(lambda x: x - x % 64, (w, h))\n",
+ "                image = image.resize((width, height))\n",
+ "                print(\n",
+ "                    f\"WARNING: Your image is of size {h}x{w} which is not divisible by 64. We are resizing to {height}x{width}!\"\n",
+ "                )\n",
+ "            image = ToTensor()(image)\n",
+ "            image = image * 2.0 - 1.0\n",
+ "\n",
+ "        image = image.unsqueeze(0).to(device)\n",
+ "        H, W = image.shape[2:]\n",
+ "        assert image.shape[1] == 3\n",
+ "        F = 8\n",
+ "        C = 4\n",
+ "        shape = (num_frames, C, H // F, W // F)\n",
+ "        if (H, W) != (576, 1024):\n",
+ "            print(\n",
+ "                \"WARNING: The conditioning frame you provided is not 576x1024. This leads to suboptimal performance as model was only trained on 576x1024. Consider increasing `cond_aug`.\"\n",
+ "            )\n",
+ "        if motion_bucket_id > 255:\n",
+ "            print(\n",
+ "                \"WARNING: High motion bucket! This may lead to suboptimal performance.\"\n",
+ "            )\n",
+ "        if fps_id < 5:\n",
+ "            print(\"WARNING: Small fps value! This may lead to suboptimal performance.\")\n",
+ "        if fps_id > 30:\n",
+ "            print(\"WARNING: Large fps value! This may lead to suboptimal performance.\")\n",
+ "\n",
+ "        value_dict = {}\n",
+ "        value_dict[\"motion_bucket_id\"] = motion_bucket_id\n",
+ "        value_dict[\"fps_id\"] = fps_id\n",
+ "        value_dict[\"cond_aug\"] = cond_aug\n",
+ "        value_dict[\"cond_frames_without_noise\"] = image\n",
+ "        value_dict[\"cond_frames\"] = image + cond_aug * torch.randn_like(image)\n",
+ "        value_dict[\"cond_aug\"] = cond_aug\n",
+ "        torch.cuda.empty_cache()\n",
+ "\n",
+ "        with torch.no_grad():\n",
+ "            with torch.autocast(device):\n",
+ "                batch, batch_uc = get_batch(\n",
+ "                    get_unique_embedder_keys_from_conditioner(model.conditioner),\n",
+ "                    value_dict,\n",
+ "                    [1, num_frames],\n",
+ "                    T=num_frames,\n",
+ "                    device=device,\n",
+ "                )\n",
+ "                c, uc = model.conditioner.get_unconditional_conditioning(\n",
+ "                    batch,\n",
+ "                    batch_uc=batch_uc,\n",
+ "                    force_uc_zero_embeddings=[\n",
+ "                        \"cond_frames\",\n",
+ "                        \"cond_frames_without_noise\",\n",
+ "                    ],\n",
+ "                )\n",
+ "                torch.cuda.empty_cache()\n",
+ "\n",
+ "                for k in [\"crossattn\", \"concat\"]:\n",
+ "                    uc[k] = repeat(uc[k], \"b ... -> b t ...\", t=num_frames)\n",
+ "                    uc[k] = rearrange(uc[k], \"b t ... -> (b t) ...\", t=num_frames)\n",
+ "                    c[k] = repeat(c[k], \"b ... -> b t ...\", t=num_frames)\n",
+ "                    c[k] = rearrange(c[k], \"b t ... -> (b t) ...\", t=num_frames)\n",
+ "\n",
+ "                randn = torch.randn(shape, device=device)\n",
+ "                additional_model_inputs = {}\n",
+ "                additional_model_inputs[\"image_only_indicator\"] = torch.zeros(2, num_frames).to(device)\n",
+ "                additional_model_inputs[\"num_video_frames\"] = batch[\"num_video_frames\"]\n",
+ "\n",
+ "                def denoiser(input, sigma, c):\n",
+ "                    return model.denoiser(model.model, input, sigma, c, **additional_model_inputs)\n",
+ "\n",
+ "                samples_z = model.sampler(denoiser, randn, cond=c, uc=uc)\n",
+ "                model.en_and_decode_n_samples_a_time = decoding_t\n",
+ "                samples_x = model.decode_first_stage(samples_z)\n",
+ "                samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)\n",
+ "                torch.cuda.empty_cache()\n",
+ "\n",
+ "                os.makedirs(output_folder, exist_ok=True)\n",
+ "                base_count = len(glob(os.path.join(output_folder, \"*.mp4\")))\n",
+ "                video_path = os.path.join(output_folder, f\"{base_count:06d}.mp4\")\n",
+ "                writer = cv2.VideoWriter(\n",
+ "                    video_path,\n",
+ "                    cv2.VideoWriter_fourcc(*\"MP4V\"),\n",
+ "                    fps_id + 1,\n",
+ "                    (samples.shape[-1], samples.shape[-2]),\n",
+ "                )\n",
+ "                samples = embed_watermark(samples)\n",
+ "                samples = filter(samples)\n",
+ "                vid = (\n",
+ "                    (rearrange(samples, \"t c h w -> t h w c\") * 255)\n",
+ "                    .cpu()\n",
+ "                    .numpy()\n",
+ "                    .astype(np.uint8)\n",
+ "                )\n",
+ "                for frame in vid:\n",
+ "                    frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)\n",
+ "                    writer.write(frame)\n",
+ "                writer.release()\n",
+ "        all_out_paths.append(video_path)\n",
+ "    return all_out_paths\n",
+ "\n",
+ "import gradio as gr\n",
+ "import random\n",
+ "\n",
+ "def infer(input_path: str, resize_image: bool, n_frames: int, n_steps: int, seed: str, decoding_t: int) -> str:\n",
+ "    if seed == \"random\":\n",
+ "        seed = random.randint(0, 2**32)\n",
+ "    seed = int(seed)\n",
+ "    output_paths = sample(\n",
+ "        input_path=input_path,\n",
+ "        resize_image=resize_image,\n",
+ "        num_frames=n_frames,\n",
+ "        num_steps=n_steps,\n",
+ "        fps_id=6,\n",
+ "        motion_bucket_id=127,\n",
+ "        cond_aug=0.02,\n",
+ "        seed=seed,\n",
+ "        decoding_t=decoding_t,  # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.\n",
+ "        device=device,\n",
+ "    )\n",
+ "    return output_paths[0]\n",
+ "\n",
+ "with gr.Blocks() as demo:\n",
+ "    with gr.Column():\n",
+ "        image = gr.Image(label=\"input image\", type=\"filepath\")\n",
+ "        resize_image = gr.Checkbox(label=\"resize to optimal size\", value=True)\n",
+ "        btn = gr.Button(\"Run\")\n",
+ "        with gr.Accordion(label=\"Advanced options\", open=False):\n",
+ "            n_frames = gr.Number(precision=0, label=\"number of frames\", value=num_frames)\n",
+ "            n_steps = gr.Number(precision=0, label=\"number of steps\", value=num_steps)\n",
+ "            seed = gr.Text(value=\"random\", label=\"seed (integer or 'random')\",)\n",
+ "            decoding_t = gr.Number(precision=0, label=\"number of frames decoded at a time\", value=2)\n",
+ "    with gr.Column():\n",
+ "        video_out = gr.Video(label=\"generated video\")\n",
+ "    examples = [[\"https://user-images.githubusercontent.com/33302880/284758167-367a25d8-8d7b-42d3-8391-6d82813c7b0f.png\"]]\n",
+ "    inputs = [image, resize_image, n_frames, n_steps, seed, decoding_t]\n",
+ "    outputs = [video_out]\n",
+ "    btn.click(infer, inputs=inputs, outputs=outputs)\n",
+ "    gr.Examples(examples=examples, inputs=inputs, outputs=outputs, fn=infer)\n",
+ "    demo.queue().launch(debug=True, share=True, inline=False, show_error=True)"
 ]
 }
 ],