Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
  • Loading branch information
camenduru authored Nov 22, 2023
1 parent 533b5e6 commit ea7d0d4
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions test.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@
"from torchvision.transforms import ToTensor\n",
"from torchvision.transforms import functional as TF\n",
"from sgm.util import instantiate_from_config\n",
"from sgm.inference.helpers import embed_watermark\n",
"from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering\n",
"\n",
"def load_model(config: str, device: str, num_frames: int, num_steps: int):\n",
" config = OmegaConf.load(config)\n",
Expand All @@ -53,13 +55,14 @@
" config.model.params.sampler_config.params.guider_config.params.num_frames = (num_frames)\n",
" with torch.device(device):\n",
" model = instantiate_from_config(config.model).to(device).eval().requires_grad_(False)\n",
" return model\n",
" filter = DeepFloydDataFiltering(verbose=False, device=device)\n",
" return model, filter\n",
"\n",
"num_frames = 25\n",
"num_steps = 30\n",
"model_config = \"generative-models/scripts/sampling/configs/svd_xt.yaml\"\n",
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"model = load_model(model_config, device, num_frames, num_steps)\n",
"model, filter = load_model(model_config, device, num_frames, num_steps)\n",
"model.conditioner.cpu()\n",
"model.first_stage_model.cpu()\n",
"model.model.to(dtype=torch.float16)\n",
Expand Down Expand Up @@ -255,6 +258,8 @@
" fps_id + 1,\n",
" (samples.shape[-1], samples.shape[-2]),\n",
" )\n",
" samples = embed_watermark(samples)\n",
" samples = filter(samples)\n",
" vid = (\n",
" (rearrange(samples, \"t c h w -> t h w c\") * 255)\n",
" .cpu()\n",
Expand Down

0 comments on commit ea7d0d4

Please sign in to comment.