diff --git a/stable_video_diffusion_fp16_colab.ipynb b/stable_video_diffusion_fp16_colab.ipynb
index 6c5d788..39fd69a 100644
--- a/stable_video_diffusion_fp16_colab.ipynb
+++ b/stable_video_diffusion_fp16_colab.ipynb
@@ -30,90 +30,47 @@
     "!ln -s /content/generative-models/scripts/util/detection/p_head_v1.npz /content/scripts/util/detection/p_head_v1.npz\n",
     "!ln -s /content/generative-models/scripts/util/detection/w_head_v1.npz /content/scripts/util/detection/w_head_v1.npz\n",
     "\n",
-    "# Load Model\n",
-    "import sys\n",
+    "import sys, os, math, torch, cv2\n",
     "from omegaconf import OmegaConf\n",
-    "import torch\n",
+    "from glob import glob\n",
+    "from pathlib import Path\n",
+    "from typing import Optional\n",
+    "import numpy as np\n",
+    "from einops import rearrange, repeat\n",
+    "\n",
+    "from PIL import Image\n",
+    "from torchvision.transforms import ToTensor\n",
+    "from torchvision.transforms import functional as TF\n",
+    "from sgm.util import instantiate_from_config\n",
     "\n",
     "sys.path.append(\"generative-models\")\n",
-    "from sgm.util import default, instantiate_from_config\n",
-    "from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering\n",
     "\n",
-    "def load_model(\n",
-    "    config: str,\n",
-    "    device: str,\n",
-    "    num_frames: int,\n",
-    "    num_steps: int,\n",
-    "):\n",
+    "def load_model(config: str, device: str, num_frames: int, num_steps: int):\n",
     "    config = OmegaConf.load(config)\n",
-    "    config.model.params.conditioner_config.params.emb_models[\n",
-    "        0\n",
-    "    ].params.open_clip_embedding_config.params.init_device = device\n",
+    "    config.model.params.conditioner_config.params.emb_models[0].params.open_clip_embedding_config.params.init_device = device\n",
     "    config.model.params.sampler_config.params.num_steps = num_steps\n",
-    "    config.model.params.sampler_config.params.guider_config.params.num_frames = (\n",
-    "        num_frames\n",
-    "    )\n",
+    "    config.model.params.sampler_config.params.guider_config.params.num_frames = (num_frames)\n",
     "    with torch.device(device):\n",
     "        model = instantiate_from_config(config.model).to(device).eval().requires_grad_(False)\n",
-    "    filter = DeepFloydDataFiltering(verbose=False, device=device)\n",
-    "    return model, filter\n",
-    "\n",
-    "version = \"svd_xt\"\n",
-    "if version == \"svd\":\n",
-    "    num_frames = 14\n",
-    "    num_steps = 25\n",
-    "    # output_folder = default(output_folder, \"outputs/simple_video_sample/svd/\")\n",
-    "    model_config = \"generative-models/scripts/sampling/configs/svd.yaml\"\n",
-    "elif version == \"svd_xt\":\n",
-    "    num_frames = 25\n",
-    "    num_steps = 30\n",
-    "    # output_folder = default(output_folder, \"outputs/simple_video_sample/svd_xt/\")\n",
-    "    model_config = \"generative-models/scripts/sampling/configs/svd_xt.yaml\"\n",
-    "else:\n",
-    "    raise ValueError(f\"Version {version} does not exist.\")\n",
+    "    return model\n",
     "\n",
+    "num_frames = 25\n",
+    "num_steps = 30\n",
+    "model_config = \"generative-models/scripts/sampling/configs/svd_xt.yaml\"\n",
     "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
-    "model, filter = load_model(\n",
-    "    model_config,\n",
-    "    device,\n",
-    "    num_frames,\n",
-    "    num_steps,\n",
-    ")\n",
-    "# move models expect unet to cpu\n",
+    "model = load_model(model_config, device, num_frames, num_steps)\n",
     "model.conditioner.cpu()\n",
     "model.first_stage_model.cpu()\n",
-    "# change the dtype of unet\n",
     "model.model.to(dtype=torch.float16)\n",
     "torch.cuda.empty_cache()\n",
     "model = model.requires_grad_(False)\n",
     "\n",
-    "# Sampling function\n",
-    "import math\n",
-    "import os\n",
-    "from glob import glob\n",
-    "from pathlib import Path\n",
-    "from typing import Optional\n",
-    "\n",
-    "import cv2\n",
-    "import numpy as np\n",
-    "import torch\n",
-    "from einops import rearrange, repeat\n",
-    "from fire import Fire\n",
-    "\n",
-    "from PIL import Image\n",
-    "from torchvision.transforms import ToTensor\n",
-    "from torchvision.transforms import functional as TF\n",
-    "\n",
-    "from sgm.inference.helpers import embed_watermark\n",
-    "from sgm.util import default, instantiate_from_config\n",
-    "\n",
     "def get_unique_embedder_keys_from_conditioner(conditioner):\n",
     "    return list(set([x.input_key for x in conditioner.embedders]))\n",
     "\n",
     "def get_batch(keys, value_dict, N, T, device, dtype=None):\n",
     "    batch = {}\n",
     "    batch_uc = {}\n",
-    "\n",
     "    for key in keys:\n",
     "        if key == \"fps_id\":\n",
     "            batch[key] = (\n",
@@ -141,17 +98,15 @@
     "            )\n",
     "        else:\n",
     "            batch[key] = value_dict[key]\n",
-    "\n",
     "    if T is not None:\n",
     "        batch[\"num_video_frames\"] = T\n",
-    "\n",
     "    for key in batch.keys():\n",
     "        if key not in batch_uc and isinstance(batch[key], torch.Tensor):\n",
     "            batch_uc[key] = torch.clone(batch[key])\n",
     "    return batch, batch_uc\n",
     "\n",
     "def sample(\n",
-    "    input_path: str = \"assets/test_image.png\",  # Can either be image file or folder with image files\n",
+    "    input_path: str = \"/content/test_image.png\",\n",
     "    resize_image: bool = False,\n",
     "    num_frames: Optional[int] = None,\n",
     "    num_steps: Optional[int] = None,\n",
@@ -197,14 +152,12 @@
     "                print(f\"Resizing {image.size} to (1024, 576)\")\n",
     "                image = TF.resize(TF.resize(image, 1024), (576, 1024))\n",
     "            w, h = image.size\n",
-    "\n",
     "            if h % 64 != 0 or w % 64 != 0:\n",
     "                width, height = map(lambda x: x - x % 64, (w, h))\n",
     "                image = image.resize((width, height))\n",
     "                print(\n",
     "                    f\"WARNING: Your image is of size {h}x{w} which is not divisible by 64. We are resizing to {height}x{width}!\"\n",
     "                )\n",
-    "\n",
     "            image = ToTensor()(image)\n",
     "            image = image * 2.0 - 1.0\n",
     "\n",
@@ -222,10 +175,8 @@
     "            print(\n",
     "                \"WARNING: High motion bucket! This may lead to suboptimal performance.\"\n",
     "            )\n",
-    "\n",
     "        if fps_id < 5:\n",
     "            print(\"WARNING: Small fps value! This may lead to suboptimal performance.\")\n",
-    "\n",
     "        if fps_id > 30:\n",
     "            print(\"WARNING: Large fps value! This may lead to suboptimal performance.\")\n",
     "\n",
@@ -274,11 +225,8 @@
     "                    c[k] = c[k].to(dtype=torch.float16)\n",
     "\n",
     "                randn = torch.randn(shape, device=device, dtype=torch.float16)\n",
-    "\n",
     "                additional_model_inputs = {}\n",
-    "                additional_model_inputs[\"image_only_indicator\"] = torch.zeros(\n",
-    "                    2, num_frames\n",
-    "                ).to(device, )\n",
+    "                additional_model_inputs[\"image_only_indicator\"] = torch.zeros(2, num_frames).to(device)\n",
     "                additional_model_inputs[\"num_video_frames\"] = batch[\"num_video_frames\"]\n",
     "\n",
     "                for k in additional_model_inputs:\n",
@@ -286,14 +234,10 @@
     "                        additional_model_inputs[k] = additional_model_inputs[k].to(dtype=torch.float16)\n",
     "\n",
     "                def denoiser(input, sigma, c):\n",
-    "                    return model.denoiser(\n",
-    "                        model.model, input, sigma, c, **additional_model_inputs\n",
-    "                    )\n",
+    "                    return model.denoiser(model.model, input, sigma, c, **additional_model_inputs)\n",
     "\n",
     "                samples_z = model.sampler(denoiser, randn, cond=c, uc=uc)\n",
     "                samples_z.to(dtype=model.first_stage_model.dtype)\n",
-    "                ##\n",
-    "\n",
     "                model.en_and_decode_n_samples_a_time = decoding_t\n",
     "                model.first_stage_model.to(device)\n",
     "                samples_x = model.decode_first_stage(samples_z)\n",
@@ -310,9 +254,6 @@
     "                    fps_id + 1,\n",
     "                    (samples.shape[-1], samples.shape[-2]),\n",
     "                )\n",
-    "\n",
-    "                samples = embed_watermark(samples)\n",
-    "                samples = filter(samples)\n",
     "                vid = (\n",
     "                    (rearrange(samples, \"t c h w -> t h w c\") * 255)\n",
     "                    .cpu()\n",
@@ -359,9 +300,7 @@
     "                decoding_t = gr.Number(precision=0, label=\"number of frames decoded at a time\", value=2)\n",
     "        with gr.Column():\n",
     "            video_out = gr.Video(label=\"generated video\")\n",
-    "    examples = [\n",
-    "        [\"https://user-images.githubusercontent.com/33302880/284758167-367a25d8-8d7b-42d3-8391-6d82813c7b0f.png\"]\n",
-    "    ]\n",
+    "    examples = [[\"https://user-images.githubusercontent.com/33302880/284758167-367a25d8-8d7b-42d3-8391-6d82813c7b0f.png\"]]\n",
     "    inputs = [image, resize_image, n_frames, n_steps, seed, decoding_t]\n",
     "    outputs = [video_out]\n",
     "    btn.click(infer, inputs=inputs, outputs=outputs)\n",
diff --git a/test.ipynb b/test.ipynb
index 958862c..074f6a8 100644
--- a/test.ipynb
+++ b/test.ipynb
@@ -300,7 +300,7 @@
     "                decoding_t = gr.Number(precision=0, label=\"number of frames decoded at a time\", value=2)\n",
     "        with gr.Column():\n",
     "            video_out = gr.Video(label=\"generated video\")\n",
-    "    examples = [\"https://user-images.githubusercontent.com/33302880/284758167-367a25d8-8d7b-42d3-8391-6d82813c7b0f.png\"]\n",
+    "    examples = [[\"https://user-images.githubusercontent.com/33302880/284758167-367a25d8-8d7b-42d3-8391-6d82813c7b0f.png\"]]\n",
     "    inputs = [image, resize_image, n_frames, n_steps, seed, decoding_t]\n",
     "    outputs = [video_out]\n",
     "    btn.click(infer, inputs=inputs, outputs=outputs)\n",
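
For orientation only, not part of the patch: a minimal sketch of the low-VRAM pattern the rewritten notebook cell relies on. Only the UNet stays resident on the GPU, cast to fp16, while the conditioner and first-stage VAE are parked on the CPU until needed. `model` is assumed to be the object returned by the notebook's `load_model` above.

```python
import torch

# model = load_model(model_config, device, num_frames, num_steps)  # as in the patched cell

# Keep only the UNet on the GPU, in half precision; park the rest on the CPU.
model.conditioner.cpu()               # image conditioner off the GPU until embedding
model.first_stage_model.cpu()         # VAE off the GPU until decoding
model.model.to(dtype=torch.float16)   # UNet weights cast to fp16
torch.cuda.empty_cache()
model = model.requires_grad_(False)   # inference only, no gradients needed

# Inside sample(), the patched code mirrors this for decoding: the VAE is moved
# back to the GPU and latents are decoded decoding_t frames per pass to keep
# peak memory low, e.g.
#   model.en_and_decode_n_samples_a_time = decoding_t
#   model.first_stage_model.to(device)
#   samples_x = model.decode_first_stage(samples_z)
```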