Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
  • Loading branch information
camenduru authored Nov 22, 2023
1 parent 11ab0fa commit 6aa8346
Show file tree
Hide file tree
Showing 2 changed files with 392 additions and 3 deletions.
8 changes: 5 additions & 3 deletions stable_video_diffusion_fp16_colab.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
"\n",
"sys.path.append(\"generative-models\")\n",
"from sgm.util import default, instantiate_from_config\n",
"from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering\n",
"\n",
"def load_model(\n",
" config: str,\n",
Expand All @@ -54,8 +55,8 @@
" )\n",
" with torch.device(device):\n",
" model = instantiate_from_config(config.model).to(device).eval().requires_grad_(False)\n",
"\n",
" return model\n",
" filter = DeepFloydDataFiltering(verbose=False, device=device)\n",
" return model, filter\n",
"\n",
"version = \"svd_xt\"\n",
"if version == \"svd\":\n",
Expand All @@ -72,7 +73,7 @@
" raise ValueError(f\"Version {version} does not exist.\")\n",
"\n",
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"model = load_model(\n",
"model, filter = load_model(\n",
" model_config,\n",
" device,\n",
" num_frames,\n",
Expand Down Expand Up @@ -311,6 +312,7 @@
" )\n",
"\n",
" samples = embed_watermark(samples)\n",
" samples = filter(samples)\n",
" vid = (\n",
" (rearrange(samples, \"t c h w -> t h w c\") * 255)\n",
" .cpu()\n",
Expand Down
Loading

0 comments on commit 6aa8346

Please sign in to comment.