Skip to content

Commit

Permalink
Fix Broken Demo
Browse files Browse the repository at this point in the history
  • Loading branch information
graphemecluster committed Sep 30, 2022
1 parent 9cfaaa4 commit d861745
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 81 deletions.
112 changes: 62 additions & 50 deletions demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,12 @@
"id": "UCMFMJV7K-ag"
},
"source": [
"!pip install imageio==2.4.1 &> /dev/null",
"!pip install ffmpy &> /dev/null\n",
"!git init -q .\n",
"%%capture\n",
"%pip install ffmpeg-python imageio-ffmpeg\n",
"!git init .\n",
"!git remote add origin https://github.com/AliaksandrSiarohin/first-order-model\n",
"!git pull -q origin master\n",
"!git clone -q https://github.com/graphemecluster/first-order-model-demo demo"
"!git pull origin master\n",
"!git clone https://github.com/graphemecluster/first-order-model-demo demo"
],
"execution_count": null,
"outputs": []
Expand All @@ -59,6 +59,7 @@
"import IPython.display\n",
"import PIL.Image\n",
"import cv2\n",
"import ffmpeg\n",
"import imageio\n",
"import io\n",
"import ipywidgets\n",
Expand All @@ -68,11 +69,12 @@
"import skimage.transform\n",
"import warnings\n",
"from base64 import b64encode\n",
"from demo import load_checkpoints, make_animation\n",
"from ffmpy import FFmpeg\n",
"from demo import load_checkpoints, make_animation # type: ignore (local file)\n",
"from google.colab import files, output\n",
"from IPython.display import HTML, Javascript\n",
"from shutil import copyfileobj\n",
"from skimage import img_as_ubyte\n",
"from tempfile import NamedTemporaryFile\n",
"warnings.filterwarnings(\"ignore\")\n",
"os.makedirs(\"user\", exist_ok=True)\n",
"\n",
Expand Down Expand Up @@ -101,10 +103,10 @@
"\twidth: 250px;\n",
"}\n",
".widget-checkbox {\n",
" width: 650px;\n",
"\twidth: 650px;\n",
"}\n",
".widget-checkbox + .widget-checkbox {\n",
" margin-top: -6px;\n",
"\tmargin-top: -6px;\n",
"}\n",
".input-widget .output_html {\n",
"\ttext-align: center;\n",
Expand All @@ -114,9 +116,6 @@
"\tcolor: lightgray;\n",
"\tfont-size: 72px;\n",
"}\n",
"div.stream {\n",
"\tdisplay: none;\n",
"}\n",
".title {\n",
"\tfont-size: 20px;\n",
"\tfont-weight: bold;\n",
Expand Down Expand Up @@ -202,6 +201,9 @@
".loading-label {\n",
"\tcolor: gray;\n",
"}\n",
".video {\n",
"\tmargin: 0;\n",
"}\n",
".comparison-widget {\n",
"\twidth: 256px;\n",
"\theight: 256px;\n",
Expand All @@ -226,10 +228,7 @@
"\treturn imageio.get_reader(file, mode='I', format='FFMPEG').get_next_data()\n",
"\n",
"def create_image(i, j):\n",
"\timage_widget = ipywidgets.Image(\n",
"\t\tvalue=open('demo/images/%d%d.png' % (i, j), 'rb').read(),\n",
"\t\tformat='png'\n",
"\t)\n",
"\timage_widget = ipywidgets.Image.from_file('demo/images/%d%d.png' % (i, j))\n",
"\timage_widget.add_class('resource')\n",
"\timage_widget.add_class('resource-image')\n",
"\timage_widget.add_class('resource-image%d%d' % (i, j))\n",
Expand Down Expand Up @@ -260,7 +259,7 @@
"def convert_output(button):\n",
"\tcomplete.layout.display = 'none'\n",
"\tloading.layout.display = ''\n",
"\tFFmpeg(inputs={'output.mp4': None}, outputs={'scaled.mp4': '-vf \"scale=1080x1080:flags=lanczos,pad=1920:1080:420:0\" -y'}).run()\n",
"\tffmpeg.input('output.mp4').output('scaled.mp4', vf='scale=1080x1080:flags=lanczos,pad=1920:1080:420:0').overwrite_output().run()\n",
"\tfiles.download('scaled.mp4')\n",
"\tloading.layout.display = 'none'\n",
"\tcomplete.layout.display = ''\n",
Expand Down Expand Up @@ -340,7 +339,8 @@
"loader.add_class(\"loader\")\n",
"loading_label = ipywidgets.Label(\"This may take several minutes to process…\")\n",
"loading_label.add_class(\"loading-label\")\n",
"loading = ipywidgets.VBox([loader, loading_label])\n",
"progress_bar = ipywidgets.Output()\n",
"loading = ipywidgets.VBox([loader, loading_label, progress_bar])\n",
"loading.add_class('loading')\n",
"\n",
"output_widget = ipywidgets.Output()\n",
Expand Down Expand Up @@ -421,10 +421,10 @@
"output.register_callback(\"notebook.select_video\", select_video)\n",
"\n",
"def resize(image, size=(256, 256)):\n",
" w, h = image.size\n",
" d = min(w, h)\n",
" r = ((w - d) // 2, (h - d) // 2, (w + d) // 2, (h + d) // 2)\n",
" return image.resize(size, resample=PIL.Image.LANCZOS, box=r)\n",
"\tw, h = image.size\n",
"\td = min(w, h)\n",
"\tr = ((w - d) // 2, (h - d) // 2, (w + d) // 2, (h + d) // 2)\n",
"\treturn image.resize(size, resample=PIL.Image.LANCZOS, box=r)\n",
"\n",
"def upload_image(change):\n",
"\tglobal selected_image\n",
Expand Down Expand Up @@ -476,39 +476,51 @@
"\tfor frame in reader:\n",
"\t\tdriving_video.append(frame)\n",
"\tgenerator, kp_detector = load_checkpoints(config_path='config/%s-256.yaml' % model.value, checkpoint_path=filename)\n",
"\tpredictions = make_animation(\n",
"\t\tskimage.transform.resize(numpy.asarray(selected_image), (256, 256)),\n",
"\t\t[skimage.transform.resize(frame, (256, 256)) for frame in driving_video],\n",
"\t\tgenerator,\n",
"\t\tkp_detector,\n",
"\t\trelative=relative.value,\n",
"\t\tadapt_movement_scale=adapt_movement_scale.value\n",
"\t)\n",
"\twith progress_bar:\n",
"\t\tpredictions = make_animation(\n",
"\t\t\tskimage.transform.resize(numpy.asarray(selected_image), (256, 256)),\n",
"\t\t\t[skimage.transform.resize(frame, (256, 256)) for frame in driving_video],\n",
"\t\t\tgenerator,\n",
"\t\t\tkp_detector,\n",
"\t\t\trelative=relative.value,\n",
"\t\t\tadapt_movement_scale=adapt_movement_scale.value\n",
"\t\t)\n",
"\tprogress_bar.clear_output()\n",
"\timageio.mimsave('output.mp4', [img_as_ubyte(frame) for frame in predictions], fps=fps)\n",
"\tif selected_video.startswith('user/') or selected_video == 'demo/videos/0.mp4':\n",
"\t\timageio.mimsave('temp.mp4', [img_as_ubyte(frame) for frame in predictions], fps=fps)\n",
"\t\tFFmpeg(inputs={'temp.mp4': None, selected_video: None}, outputs={'output.mp4': '-c copy -y'}).run()\n",
"\telse:\n",
"\t\timageio.mimsave('output.mp4', [img_as_ubyte(frame) for frame in predictions], fps=fps)\n",
"\tloading.layout.display = 'none'\n",
"\tcomplete.layout.display = ''\n",
"\t\twith NamedTemporaryFile(suffix='.mp4') as output:\n",
"\t\t\tffmpeg.output(ffmpeg.input('output.mp4').video, ffmpeg.input(selected_video).audio, output.name, c='copy').overwrite_output().run()\n",
"\t\t\twith open('output.mp4', 'wb') as result:\n",
"\t\t\t\tcopyfileobj(output, result)\n",
"\twith output_widget:\n",
"\t\tdisplay(HTML('<video id=\"left\" controls src=\"data:video/mp4;base64,%s\" />' % b64encode(open('output.mp4', 'rb').read()).decode()))\n",
"\t\tvideo_widget = ipywidgets.Video.from_file('output.mp4', autoplay=False, loop=False)\n",
"\t\tvideo_widget.add_class('video')\n",
"\t\tvideo_widget.add_class('video-left')\n",
"\t\tdisplay(video_widget)\n",
"\twith comparison_widget:\n",
"\t\tdisplay(HTML('<video id=\"right\" muted src=\"data:video/mp4;base64,%s\" />' % b64encode(open(selected_video, 'rb').read()).decode()))\n",
"\t\tvideo_widget = ipywidgets.Video.from_file(selected_video, autoplay=False, loop=False, controls=False)\n",
"\t\tvideo_widget.add_class('video')\n",
"\t\tvideo_widget.add_class('video-right')\n",
"\t\tdisplay(video_widget)\n",
"\tdisplay(Javascript(\"\"\"\n",
"\t(function(left, right) {\n",
"\t\tleft.addEventListener(\"play\", function() {\n",
"\t\t\tright.play();\n",
"\t\t});\n",
"\t\tleft.addEventListener(\"pause\", function() {\n",
"\t\t\tright.pause();\n",
"\t\t});\n",
"\t\tleft.addEventListener(\"seeking\", function() {\n",
"\t\t\tright.currentTime = left.currentTime;\n",
"\t\t});\n",
"\t})(document.getElementById(\"left\"), document.getElementById(\"right\"));\n",
"\tsetTimeout(function() {\n",
"\t\t(function(left, right) {\n",
"\t\t\tleft.addEventListener(\"play\", function() {\n",
"\t\t\t\tright.play();\n",
"\t\t\t});\n",
"\t\t\tleft.addEventListener(\"pause\", function() {\n",
"\t\t\t\tright.pause();\n",
"\t\t\t});\n",
"\t\t\tleft.addEventListener(\"seeking\", function() {\n",
"\t\t\t\tright.currentTime = left.currentTime;\n",
"\t\t\t});\n",
"\t\t\tright.muted = true;\n",
"\t\t})(document.getElementsByClassName(\"video-left\")[0], document.getElementsByClassName(\"video-right\")[0]);\n",
"\t}, 1000);\n",
"\t\"\"\"))\n",
"\t\n",
"\tloading.layout.display = 'none'\n",
"\tcomplete.layout.display = ''\n",
"\n",
"generate_button.on_click(generate)\n",
"\n",
"loading.layout.display = 'none'\n",
Expand Down
52 changes: 24 additions & 28 deletions demo.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import matplotlib
matplotlib.use('Agg')
import os, sys
import sys
import yaml
from argparse import ArgumentParser
from tqdm import tqdm
from tqdm.auto import tqdm

import imageio
import numpy as np
Expand All @@ -15,10 +13,11 @@
from modules.generator import OcclusionAwareGenerator
from modules.keypoint_detector import KPDetector
from animate import normalize_kp
from scipy.spatial import ConvexHull

import moviepy
import moviepy.editor as mpe
import ffmpeg
from os.path import splitext
from shutil import copyfileobj
from tempfile import NamedTemporaryFile

if sys.version_info[0] < 3:
raise Exception("You must use Python 3 or higher. Recommended version is Python 3.7")
Expand All @@ -37,22 +36,22 @@ def load_checkpoints(config_path, checkpoint_path, cpu=False):
**config['model_params']['common_params'])
if not cpu:
kp_detector.cuda()

if cpu:
checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
else:
checkpoint = torch.load(checkpoint_path)

generator.load_state_dict(checkpoint['generator'])
kp_detector.load_state_dict(checkpoint['kp_detector'])

if not cpu:
generator = DataParallelWithCallback(generator)
kp_detector = DataParallelWithCallback(kp_detector)

generator.eval()
kp_detector.eval()

return generator, kp_detector


Expand Down Expand Up @@ -80,7 +79,8 @@ def make_animation(source_image, driving_video, generator, kp_detector, relative
return predictions

def find_best_frame(source, driving, cpu=False):
import face_alignment
import face_alignment # type: ignore (local file)
from scipy.spatial import ConvexHull

def normalize_kp(kp):
kp = kp - kp.mean(axis=0, keepdims=True)
Expand Down Expand Up @@ -110,21 +110,20 @@ def normalize_kp(kp):
parser.add_argument("--checkpoint", default='vox-cpk.pth.tar', help="path to checkpoint to restore")

parser.add_argument("--source_image", default='sup-mat/source.png', help="path to source image")
parser.add_argument("--driving_video", default='sup-mat/source.png', help="path to driving video")
parser.add_argument("--driving_video", default='driving.mp4', help="path to driving video")
parser.add_argument("--result_video", default='result.mp4', help="path to output")

parser.add_argument("--relative", dest="relative", action="store_true", help="use relative or absolute keypoint coordinates")
parser.add_argument("--adapt_scale", dest="adapt_scale", action="store_true", help="adapt movement scale based on convex hull of keypoints")

parser.add_argument("--find_best_frame", dest="find_best_frame", action="store_true",
parser.add_argument("--find_best_frame", dest="find_best_frame", action="store_true",
help="Generate from the frame that is the most alligned with source. (Only for faces, requires face_aligment lib)")

parser.add_argument("--best_frame", dest="best_frame", type=int, default=None,
help="Set frame to start from.")

parser.add_argument("--best_frame", dest="best_frame", type=int, default=None, help="Set frame to start from.")

parser.add_argument("--cpu", dest="cpu", action="store_true", help="cpu mode.")
parser.add_argument("--audio_on", dest="audio_on", action="store_true", help="option to have audio on." )

parser.add_argument("--audio", dest="audio", action="store_true", help="copy audio to output from the driving video" )

parser.set_defaults(relative=False)
parser.set_defaults(adapt_scale=False)
Expand Down Expand Up @@ -159,11 +158,8 @@ def normalize_kp(kp):
predictions = make_animation(source_image, driving_video, generator, kp_detector, relative=opt.relative, adapt_movement_scale=opt.adapt_scale, cpu=opt.cpu)
imageio.mimsave(opt.result_video, [img_as_ubyte(frame) for frame in predictions], fps=fps)

if opt.audio_on:
video_clip = mpe.VideoFileClip(opt.result_video)
audio_clip = mpe.AudioFileClip(opt.driving_video)

final_clip = video_clip.set_audio(audio_clip)
file_name_with_audio = opt.result_video.split(".")[0] + "-with-audio.mp4"
final_clip.write_videofile(file_name_with_audio, fps=fps, codec='libx264', audio_codec='aac', write_logfile=True, ffmpeg_params=['-level','4.0','-b:a','128k'])

if opt.audio:
with NamedTemporaryFile(suffix='.' + splitext(opt.result_video)[1]) as output:
ffmpeg.output(ffmpeg.input(opt.result_video).video, ffmpeg.input(opt.driving_video).audio, output.name, c='copy').run()
with open(opt.result_video, 'wb') as result:
copyfileobj(output, result)
7 changes: 4 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
imageio==2.13.3
ffmpeg-python==0.2.0
imageio==2.22.0
imageio-ffmpeg==0.4.7
matplotlib==2.2.2
numpy==1.22.0
pandas==0.23.4
Expand All @@ -10,5 +12,4 @@ scikit-learn==0.19.2
scipy==1.1.0
torch==1.0.0
torchvision==0.2.1
tqdm==4.24.0
moviepy==1.0.3
tqdm==4.64.1

0 comments on commit d861745

Please sign in to comment.