Commit

test
camenduru authored Nov 22, 2023
1 parent cdb88eb commit 3b5b720
Showing 1 changed file with 85 additions and 25 deletions.
stable_video_diffusion_fp16_colab.ipynb (110 changes: 85 additions & 25 deletions)
@@ -30,48 +30,90 @@
"!ln -s /content/generative-models/scripts/util/detection/p_head_v1.npz /content/scripts/util/detection/p_head_v1.npz\n",
"!ln -s /content/generative-models/scripts/util/detection/w_head_v1.npz /content/scripts/util/detection/w_head_v1.npz\n",
"\n",
"# Load Model\n",
"import sys\n",
"sys.path.append(\"generative-models\")\n",
"\n",
"import os, math, torch, cv2\n",
"from omegaconf import OmegaConf\n",
"from glob import glob\n",
"from pathlib import Path\n",
"from typing import Optional\n",
"import numpy as np\n",
"from einops import rearrange, repeat\n",
"import torch\n",
"\n",
"from PIL import Image\n",
"from torchvision.transforms import ToTensor\n",
"from torchvision.transforms import functional as TF\n",
"from sgm.util import instantiate_from_config\n",
"sys.path.append(\"generative-models\")\n",
"from sgm.util import default, instantiate_from_config\n",
"from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering\n",
"\n",
"def load_model(config: str, device: str, num_frames: int, num_steps: int):\n",
"def load_model(\n",
" config: str,\n",
" device: str,\n",
" num_frames: int,\n",
" num_steps: int,\n",
"):\n",
" config = OmegaConf.load(config)\n",
" config.model.params.conditioner_config.params.emb_models[0].params.open_clip_embedding_config.params.init_device = device\n",
" config.model.params.conditioner_config.params.emb_models[\n",
" 0\n",
" ].params.open_clip_embedding_config.params.init_device = device\n",
" config.model.params.sampler_config.params.num_steps = num_steps\n",
" config.model.params.sampler_config.params.guider_config.params.num_frames = (num_frames)\n",
" config.model.params.sampler_config.params.guider_config.params.num_frames = (\n",
" num_frames\n",
" )\n",
" with torch.device(device):\n",
" model = instantiate_from_config(config.model).to(device).eval().requires_grad_(False)\n",
" return model\n",
" filter = DeepFloydDataFiltering(verbose=False, device=device)\n",
" return model, filter\n",
"\n",
"version = \"svd_xt\"\n",
"if version == \"svd\":\n",
" num_frames = 14\n",
" num_steps = 25\n",
" # output_folder = default(output_folder, \"outputs/simple_video_sample/svd/\")\n",
" model_config = \"generative-models/scripts/sampling/configs/svd.yaml\"\n",
"elif version == \"svd_xt\":\n",
" num_frames = 25\n",
" num_steps = 30\n",
" # output_folder = default(output_folder, \"outputs/simple_video_sample/svd_xt/\")\n",
" model_config = \"generative-models/scripts/sampling/configs/svd_xt.yaml\"\n",
"else:\n",
" raise ValueError(f\"Version {version} does not exist.\")\n",
"\n",
"num_frames = 25\n",
"num_steps = 30\n",
"model_config = \"generative-models/scripts/sampling/configs/svd_xt.yaml\"\n",
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"model = load_model(model_config, device, num_frames, num_steps)\n",
"model, filter = load_model(\n",
" model_config,\n",
" device,\n",
" num_frames,\n",
" num_steps,\n",
")\n",
"# move models expect unet to cpu\n",
"model.conditioner.cpu()\n",
"model.first_stage_model.cpu()\n",
"# change the dtype of unet\n",
"model.model.to(dtype=torch.float16)\n",
"torch.cuda.empty_cache()\n",
"model = model.requires_grad_(False)\n",
"\n",
"# Sampling function\n",
"import math\n",
"import os\n",
"from glob import glob\n",
"from pathlib import Path\n",
"from typing import Optional\n",
"\n",
"import cv2\n",
"import numpy as np\n",
"import torch\n",
"from einops import rearrange, repeat\n",
"from fire import Fire\n",
"\n",
"from PIL import Image\n",
"from torchvision.transforms import ToTensor\n",
"from torchvision.transforms import functional as TF\n",
"\n",
"from sgm.inference.helpers import embed_watermark\n",
"from sgm.util import default, instantiate_from_config\n",
"\n",
"def get_unique_embedder_keys_from_conditioner(conditioner):\n",
" return list(set([x.input_key for x in conditioner.embedders]))\n",
"\n",
"def get_batch(keys, value_dict, N, T, device, dtype=None):\n",
" batch = {}\n",
" batch_uc = {}\n",
"\n",
" for key in keys:\n",
" if key == \"fps_id\":\n",
" batch[key] = (\n",
@@ -99,15 +141,17 @@
" )\n",
" else:\n",
" batch[key] = value_dict[key]\n",
"\n",
" if T is not None:\n",
" batch[\"num_video_frames\"] = T\n",
"\n",
" for key in batch.keys():\n",
" if key not in batch_uc and isinstance(batch[key], torch.Tensor):\n",
" batch_uc[key] = torch.clone(batch[key])\n",
" return batch, batch_uc\n",
"\n",
"def sample(\n",
" input_path: str = \"/content/test_image.png\",\n",
" input_path: str = \"assets/test_image.png\", # Can either be image file or folder with image files\n",
" resize_image: bool = False,\n",
" num_frames: Optional[int] = None,\n",
" num_steps: Optional[int] = None,\n",
@@ -153,12 +197,14 @@
" print(f\"Resizing {image.size} to (1024, 576)\")\n",
" image = TF.resize(TF.resize(image, 1024), (576, 1024))\n",
" w, h = image.size\n",
"\n",
" if h % 64 != 0 or w % 64 != 0:\n",
" width, height = map(lambda x: x - x % 64, (w, h))\n",
" image = image.resize((width, height))\n",
" print(\n",
" f\"WARNING: Your image is of size {h}x{w} which is not divisible by 64. We are resizing to {height}x{width}!\"\n",
" )\n",
"\n",
" image = ToTensor()(image)\n",
" image = image * 2.0 - 1.0\n",
"\n",
@@ -176,8 +222,10 @@
" print(\n",
" \"WARNING: High motion bucket! This may lead to suboptimal performance.\"\n",
" )\n",
"\n",
" if fps_id < 5:\n",
" print(\"WARNING: Small fps value! This may lead to suboptimal performance.\")\n",
"\n",
" if fps_id > 30:\n",
" print(\"WARNING: Large fps value! This may lead to suboptimal performance.\")\n",
"\n",
@@ -226,19 +274,26 @@
" c[k] = c[k].to(dtype=torch.float16)\n",
"\n",
" randn = torch.randn(shape, device=device, dtype=torch.float16)\n",
"\n",
" additional_model_inputs = {}\n",
" additional_model_inputs[\"image_only_indicator\"] = torch.zeros(2, num_frames).to(device)\n",
" additional_model_inputs[\"image_only_indicator\"] = torch.zeros(\n",
" 2, num_frames\n",
" ).to(device, )\n",
" additional_model_inputs[\"num_video_frames\"] = batch[\"num_video_frames\"]\n",
"\n",
" for k in additional_model_inputs:\n",
" if isinstance(additional_model_inputs[k], torch.Tensor):\n",
" additional_model_inputs[k] = additional_model_inputs[k].to(dtype=torch.float16)\n",
"\n",
" def denoiser(input, sigma, c):\n",
" return model.denoiser(model.model, input, sigma, c, **additional_model_inputs)\n",
" return model.denoiser(\n",
" model.model, input, sigma, c, **additional_model_inputs\n",
" )\n",
"\n",
" samples_z = model.sampler(denoiser, randn, cond=c, uc=uc)\n",
" samples_z.to(dtype=model.first_stage_model.dtype)\n",
" ##\n",
"\n",
" model.en_and_decode_n_samples_a_time = decoding_t\n",
" model.first_stage_model.to(device)\n",
" samples_x = model.decode_first_stage(samples_z)\n",
Expand All @@ -255,6 +310,9 @@
" fps_id + 1,\n",
" (samples.shape[-1], samples.shape[-2]),\n",
" )\n",
"\n",
" samples = embed_watermark(samples)\n",
" samples = filter(samples)\n",
" vid = (\n",
" (rearrange(samples, \"t c h w -> t h w c\") * 255)\n",
" .cpu()\n",
@@ -301,7 +359,9 @@
" decoding_t = gr.Number(precision=0, label=\"number of frames decoded at a time\", value=2)\n",
" with gr.Column():\n",
" video_out = gr.Video(label=\"generated video\")\n",
" examples = [[\"https://user-images.githubusercontent.com/33302880/284758167-367a25d8-8d7b-42d3-8391-6d82813c7b0f.png\"]]\n",
" examples = [\n",
" [\"https://user-images.githubusercontent.com/33302880/284758167-367a25d8-8d7b-42d3-8391-6d82813c7b0f.png\"]\n",
" ]\n",
" inputs = [image, resize_image, n_frames, n_steps, seed, decoding_t]\n",
" outputs = [video_out]\n",
" btn.click(infer, inputs=inputs, outputs=outputs)\n",
@@ -326,4 +386,4 @@
},
"nbformat": 4,
"nbformat_minor": 0
}
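
The core memory trick in this notebook is the pattern around `load_model`: everything except the unet stays on the CPU, the unet itself runs in fp16, and the VAE (`first_stage_model`) is moved to the GPU only for decoding. Below is a minimal, self-contained sketch of that pattern; `TinyPipeline` and its `nn.Linear` submodules are hypothetical stand-ins for the real sgm model, and only the device/dtype moves mirror what the notebook does:

```python
import torch
import torch.nn as nn

class TinyPipeline(nn.Module):
    """Hypothetical stand-in for the sgm video model used in the notebook."""
    def __init__(self):
        super().__init__()
        self.conditioner = nn.Linear(8, 8)        # stands in for the CLIP conditioner
        self.model = nn.Linear(8, 8)              # stands in for the video unet
        self.first_stage_model = nn.Linear(8, 8)  # stands in for the VAE

device = "cuda" if torch.cuda.is_available() else "cpu"
# fp16 only pays off on the GPU; fall back to fp32 on CPU
unet_dtype = torch.float16 if device == "cuda" else torch.float32

pipe = TinyPipeline().to(device).eval().requires_grad_(False)

# Keep everything except the unet off the GPU until it is actually needed,
# and run the unet itself in half precision, as the notebook does.
pipe.conditioner.cpu()
pipe.first_stage_model.cpu()
pipe.model.to(dtype=unet_dtype)
if device == "cuda":
    torch.cuda.empty_cache()

x = torch.randn(1, 8, device=device, dtype=unet_dtype)
with torch.no_grad():
    latents = pipe.model(x)                   # denoising runs on the GPU in fp16
    pipe.first_stage_model.to(device)         # bring the VAE over only for decoding
    frames = pipe.first_stage_model(latents.to(pipe.first_stage_model.weight.dtype))
    pipe.first_stage_model.cpu()              # then push it back off the GPU
print(frames.shape)  # torch.Size([1, 8])
```

This is why the diff has `load_model` return the model with gradients disabled and why `model.first_stage_model.to(device)` appears inside `sample` just before `decode_first_stage`: the VAE only occupies GPU memory during the decode step.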
