Commit

test
camenduru authored Nov 22, 2023
1 parent d2b7907 commit c2b67ea
Showing 1 changed file with 22 additions and 77 deletions.
99 changes: 22 additions & 77 deletions test.ipynb
@@ -30,87 +30,47 @@
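"# Link the detection-head weights (p_head_v1.npz / w_head_v1.npz) into the local scripts path\n",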
"!ln -s /content/generative-models/scripts/util/detection/p_head_v1.npz /content/scripts/util/detection/p_head_v1.npz\n",
"!ln -s /content/generative-models/scripts/util/detection/w_head_v1.npz /content/scripts/util/detection/w_head_v1.npz\n",
"\n",
"# Load Model\n",
"import sys\n",
"import sys, os, math, torch, cv2\n",
"from omegaconf import OmegaConf\n",
"import torch\n",
"from glob import glob\n",
"from pathlib import Path\n",
"from typing import Optional\n",
"import numpy as np\n",
"from einops import rearrange, repeat\n",
"\n",
"from PIL import Image\n",
"from torchvision.transforms import ToTensor\n",
"from torchvision.transforms import functional as TF\n",
"from sgm.util import instantiate_from_config\n",
"\n",
"sys.path.append(\"generative-models\")\n",
"from sgm.util import default, instantiate_from_config\n",
"\n",
"def load_model(\n",
" config: str,\n",
" device: str,\n",
" num_frames: int,\n",
" num_steps: int,\n",
"):\n",
"def load_model(config: str, device: str, num_frames: int, num_steps: int):\n",
" config = OmegaConf.load(config)\n",
" config.model.params.conditioner_config.params.emb_models[\n",
" 0\n",
" ].params.open_clip_embedding_config.params.init_device = device\n",
" config.model.params.conditioner_config.params.emb_models[0].params.open_clip_embedding_config.params.init_device = device\n",
" config.model.params.sampler_config.params.num_steps = num_steps\n",
" config.model.params.sampler_config.params.guider_config.params.num_frames = (\n",
" num_frames\n",
" )\n",
" config.model.params.sampler_config.params.guider_config.params.num_frames = (num_frames)\n",
" with torch.device(device):\n",
" model = instantiate_from_config(config.model).to(device).eval().requires_grad_(False)\n",
"\n",
" return model\n",
"\n",
"version = \"svd_xt\"\n",
"if version == \"svd\":\n",
" num_frames = 14\n",
" num_steps = 25\n",
" # output_folder = default(output_folder, \"outputs/simple_video_sample/svd/\")\n",
" model_config = \"generative-models/scripts/sampling/configs/svd.yaml\"\n",
"elif version == \"svd_xt\":\n",
" num_frames = 25\n",
" num_steps = 30\n",
" # output_folder = default(output_folder, \"outputs/simple_video_sample/svd_xt/\")\n",
" model_config = \"generative-models/scripts/sampling/configs/svd_xt.yaml\"\n",
"else:\n",
" raise ValueError(f\"Version {version} does not exist.\")\n",
"\n",
"num_frames = 25\n",
"num_steps = 30\n",
"model_config = \"generative-models/scripts/sampling/configs/svd_xt.yaml\"\n",
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"model = load_model(\n",
" model_config,\n",
" device,\n",
" num_frames,\n",
" num_steps,\n",
")\n",
"# move models expect unet to cpu\n",
"model = load_model(model_config, device, num_frames, num_steps)\n",
"model.conditioner.cpu()\n",
"model.first_stage_model.cpu()\n",
"# change the dtype of unet\n",
"model.model.to(dtype=torch.float16)\n",
"torch.cuda.empty_cache()\n",
"model = model.requires_grad_(False)\n",
"\n",
"# Sampling function\n",
"import math\n",
"import os\n",
"from glob import glob\n",
"from pathlib import Path\n",
"from typing import Optional\n",
"\n",
"import cv2\n",
"import numpy as np\n",
"import torch\n",
"from einops import rearrange, repeat\n",
"from fire import Fire\n",
"\n",
"from PIL import Image\n",
"from torchvision.transforms import ToTensor\n",
"from torchvision.transforms import functional as TF\n",
"from sgm.util import default, instantiate_from_config\n",
"\n",
"def get_unique_embedder_keys_from_conditioner(conditioner):\n",
" return list(set([x.input_key for x in conditioner.embedders]))\n",
"\n",
"def get_batch(keys, value_dict, N, T, device, dtype=None):\n",
" batch = {}\n",
" batch_uc = {}\n",
"\n",
" for key in keys:\n",
" if key == \"fps_id\":\n",
" batch[key] = (\n",
@@ -138,17 +98,15 @@
" )\n",
" else:\n",
" batch[key] = value_dict[key]\n",
"\n",
" if T is not None:\n",
" batch[\"num_video_frames\"] = T\n",
"\n",
" for key in batch.keys():\n",
" if key not in batch_uc and isinstance(batch[key], torch.Tensor):\n",
" batch_uc[key] = torch.clone(batch[key])\n",
" return batch, batch_uc\n",
"\n",
"def sample(\n",
" input_path: str = \"assets/test_image.png\", # Can either be image file or folder with image files\n",
" input_path: str = \"/content/test_image.png\",\n",
" resize_image: bool = False,\n",
" num_frames: Optional[int] = None,\n",
" num_steps: Optional[int] = None,\n",
@@ -194,14 +152,12 @@
" print(f\"Resizing {image.size} to (1024, 576)\")\n",
" image = TF.resize(TF.resize(image, 1024), (576, 1024))\n",
" w, h = image.size\n",
"\n",
" if h % 64 != 0 or w % 64 != 0:\n",
" width, height = map(lambda x: x - x % 64, (w, h))\n",
" image = image.resize((width, height))\n",
" print(\n",
" f\"WARNING: Your image is of size {h}x{w} which is not divisible by 64. We are resizing to {height}x{width}!\"\n",
" )\n",
"\n",
" image = ToTensor()(image)\n",
" image = image * 2.0 - 1.0\n",
"\n",
@@ -219,10 +175,8 @@
" print(\n",
" \"WARNING: High motion bucket! This may lead to suboptimal performance.\"\n",
" )\n",
"\n",
" if fps_id < 5:\n",
" print(\"WARNING: Small fps value! This may lead to suboptimal performance.\")\n",
"\n",
" if fps_id > 30:\n",
" print(\"WARNING: Large fps value! This may lead to suboptimal performance.\")\n",
"\n",
@@ -271,26 +225,19 @@
" c[k] = c[k].to(dtype=torch.float16)\n",
"\n",
" randn = torch.randn(shape, device=device, dtype=torch.float16)\n",
"\n",
" additional_model_inputs = {}\n",
" additional_model_inputs[\"image_only_indicator\"] = torch.zeros(\n",
" 2, num_frames\n",
" ).to(device, )\n",
" additional_model_inputs[\"image_only_indicator\"] = torch.zeros(2, num_frames).to(device)\n",
" additional_model_inputs[\"num_video_frames\"] = batch[\"num_video_frames\"]\n",
"\n",
" for k in additional_model_inputs:\n",
" if isinstance(additional_model_inputs[k], torch.Tensor):\n",
" additional_model_inputs[k] = additional_model_inputs[k].to(dtype=torch.float16)\n",
"\n",
" def denoiser(input, sigma, c):\n",
" return model.denoiser(\n",
" model.model, input, sigma, c, **additional_model_inputs\n",
" )\n",
" return model.denoiser(model.model, input, sigma, c, **additional_model_inputs)\n",
"\n",
" samples_z = model.sampler(denoiser, randn, cond=c, uc=uc)\n",
" samples_z.to(dtype=model.first_stage_model.dtype)\n",
" ##\n",
"\n",
" model.en_and_decode_n_samples_a_time = decoding_t\n",
" model.first_stage_model.to(device)\n",
" samples_x = model.decode_first_stage(samples_z)\n",
@@ -353,9 +300,7 @@
" decoding_t = gr.Number(precision=0, label=\"number of frames decoded at a time\", value=2)\n",
" with gr.Column():\n",
" video_out = gr.Video(label=\"generated video\")\n",
" examples = [\n",
" [\"https://user-images.githubusercontent.com/33302880/284758167-367a25d8-8d7b-42d3-8391-6d82813c7b0f.png\"]\n",
" ]\n",
" examples = [\"https://user-images.githubusercontent.com/33302880/284758167-367a25d8-8d7b-42d3-8391-6d82813c7b0f.png\"]\n",
" inputs = [image, resize_image, n_frames, n_steps, seed, decoding_t]\n",
" outputs = [video_out]\n",
" btn.click(infer, inputs=inputs, outputs=outputs)\n",
