feat: macOS support #143

Merged
merged 5 commits on Jul 15, 2024
22 changes: 1 addition & 21 deletions requirements.txt
@@ -1,22 +1,2 @@
---extra-index-url https://download.pytorch.org/whl/cu118
-torch==2.3.0
-torchvision==0.18.0
-torchaudio==2.3.0
-
-numpy==1.26.4
-pyyaml==6.0.1
-opencv-python==4.10.0.84
-scipy==1.13.1
-imageio==2.34.2
-lmdb==1.4.1
-tqdm==4.66.4
-rich==13.7.1
-ffmpeg-python==0.2.0
+-r requirements_base.txt
 onnxruntime-gpu==1.18.0
-onnx==1.16.1
-scikit-image==0.24.0
-albumentations==1.4.10
-matplotlib==3.9.0
-imageio-ffmpeg==0.5.1
-tyro==0.8.5
-gradio==4.37.1
2 changes: 2 additions & 0 deletions requirements_apple.txt
@@ -0,0 +1,2 @@
+-r requirements_base.txt
+onnxruntime-silicon==1.16.3
21 changes: 21 additions & 0 deletions requirements_base.txt
@@ -0,0 +1,21 @@
+--extra-index-url https://download.pytorch.org/whl/cu118
+torch==2.3.0
+torchvision==0.18.0
+torchaudio==2.3.0
+
+numpy==1.26.4
+pyyaml==6.0.1
+opencv-python==4.10.0.84
+scipy==1.13.1
+imageio==2.34.2
+lmdb==1.4.1
+tqdm==4.66.4
+rich==13.7.1
+ffmpeg-python==0.2.0
+onnx==1.16.1
+scikit-image==0.24.0
+albumentations==1.4.10
+matplotlib==3.9.0
+imageio-ffmpeg==0.5.1
+tyro==0.8.5
+gradio==4.37.1
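
Taken together, the three files form a layered install: requirements.txt is now just the shared base plus onnxruntime-gpu, while requirements_apple.txt swaps in onnxruntime-silicon. Below is a minimal sketch of how a setup helper could choose between them at install time; `pick_requirements` is a hypothetical name for illustration, not part of this PR:

```python
import platform
import subprocess
import sys


def pick_requirements() -> str:
    """Return the requirements file that matches the current platform (sketch)."""
    # Apple Silicon reports system() == 'Darwin' and machine() == 'arm64'.
    if platform.system() == "Darwin" and platform.machine() == "arm64":
        return "requirements_apple.txt"
    return "requirements.txt"


if __name__ == "__main__":
    # Equivalent to running: pip install -r <chosen file>
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-r", pick_requirements()])
```

Either file pulls in requirements_base.txt via the `-r` include, so the pinned versions stay identical across platforms.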
6 changes: 3 additions & 3 deletions src/live_portrait_pipeline.py
@@ -216,14 +216,14 @@ def execute(self, args: ArgumentConfig):
         wfp_concat = None
         flag_has_audio = (not flag_load_from_template) and has_audio_stream(args.driving_info)

-        ######### build final concact result #########
+        ######### build final concat result #########
         # driving frame | source image | generation, or source image | generation
         frames_concatenated = concat_frames(driving_rgb_crop_256x256_lst, img_crop_256x256, I_p_lst)
         wfp_concat = osp.join(args.output_dir, f'{basename(args.source_image)}--{basename(args.driving_info)}_concat.mp4')
         images2video(frames_concatenated, wfp=wfp_concat, fps=output_fps)

         if flag_has_audio:
-            # final result with concact
+            # final result with concat
             wfp_concat_with_audio = osp.join(args.output_dir, f'{basename(args.source_image)}--{basename(args.driving_info)}_concat_with_audio.mp4')
             add_audio_to_video(wfp_concat, args.driving_info, wfp_concat_with_audio)
             os.replace(wfp_concat_with_audio, wfp_concat)
@@ -247,7 +247,7 @@ def execute(self, args: ArgumentConfig):
         if wfp_template not in (None, ''):
             log(f'Animated template: {wfp_template}, you can specify `-d` argument with this template path next time to avoid cropping video, motion making and protecting privacy.', style='bold green')
         log(f'Animated video: {wfp}')
-        log(f'Animated video with concact: {wfp_concat}')
+        log(f'Animated video with concat: {wfp_concat}')

         return wfp, wfp_concat

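For context on the hunks above: after writing the silent concat video, the pipeline muxes the driving clip's audio back in with add_audio_to_video and swaps the result into place via os.replace. A sketch of how such a helper can be written with ffmpeg-python (which requirements_base.txt pins); this is an assumption-laden illustration, not the repo's actual implementation:

```python
import ffmpeg  # the ffmpeg-python package pinned in requirements_base.txt


def add_audio_to_video_sketch(silent_video: str, audio_source: str, output_path: str) -> None:
    """Mux the audio track of audio_source onto silent_video (illustrative only)."""
    video = ffmpeg.input(silent_video).video
    audio = ffmpeg.input(audio_source).audio
    (
        ffmpeg
        .output(video, audio, output_path, vcodec="copy", acodec="aac", shortest=None)
        .overwrite_output()
        .run(quiet=True)  # shortest=None emits ffmpeg's bare -shortest flag
    )
```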
41 changes: 25 additions & 16 deletions src/live_portrait_wrapper.py
@@ -4,6 +4,7 @@
 Wrapper for LivePortrait core functions
 """

+import contextlib
 import os.path as osp
 import numpy as np
 import cv2
@@ -28,7 +29,10 @@ def __init__(self, inference_cfg: InferenceConfig):
         if inference_cfg.flag_force_cpu:
             self.device = 'cpu'
         else:
-            self.device = 'cuda:' + str(self.device_id)
+            if torch.backends.mps.is_available():
+                self.device = 'mps'
+            else:
+                self.device = 'cuda:' + str(self.device_id)

         model_config = yaml.load(open(inference_cfg.models_config, 'r'), Loader=yaml.SafeLoader)
         # init F
@@ -57,6 +61,14 @@ def __init__(self, inference_cfg: InferenceConfig):

         self.timer = Timer()

+    def inference_ctx(self):
+        if self.device == "mps":
+            ctx = contextlib.nullcontext()
+        else:
+            ctx = torch.autocast(device_type=self.device[:4], dtype=torch.float16,
+                                 enabled=self.inference_cfg.flag_use_half_precision)
+        return ctx
+
     def update_config(self, user_args):
         for k, v in user_args.items():
             if hasattr(self.inference_cfg, k):
@@ -105,9 +117,8 @@ def extract_feature_3d(self, x: torch.Tensor) -> torch.Tensor:
         """ get the appearance feature of the image by F
         x: Bx3xHxW, normalized to 0~1
         """
-        with torch.no_grad():
-            with torch.autocast(device_type=self.device[:4], dtype=torch.float16, enabled=self.inference_cfg.flag_use_half_precision):
-                feature_3d = self.appearance_feature_extractor(x)
+        with torch.no_grad(), self.inference_ctx():
+            feature_3d = self.appearance_feature_extractor(x)

         return feature_3d.float()

@@ -117,9 +128,8 @@ def get_kp_info(self, x: torch.Tensor, **kwargs) -> dict:
         flag_refine_info: whether to transform the pose to degrees and reshape the output dimensions
         return: a dict containing the keys: 'pitch', 'yaw', 'roll', 't', 'exp', 'scale', 'kp'
         """
-        with torch.no_grad():
-            with torch.autocast(device_type=self.device[:4], dtype=torch.float16, enabled=self.inference_cfg.flag_use_half_precision):
-                kp_info = self.motion_extractor(x)
+        with torch.no_grad(), self.inference_ctx():
+            kp_info = self.motion_extractor(x)

         if self.inference_cfg.flag_use_half_precision:
             # float the dict
@@ -263,15 +273,14 @@ def warp_decode(self, feature_3d: torch.Tensor, kp_source: torch.Tensor, kp_driv
         kp_driving: BxNx3
         """
         # The line 18 in Algorithm 1: D(W(f_s; x_s, x′_d,i))
-        with torch.no_grad():
-            with torch.autocast(device_type=self.device[:4], dtype=torch.float16, enabled=self.inference_cfg.flag_use_half_precision):
-                if self.compile:
-                    # Mark the beginning of a new CUDA Graph step
-                    torch.compiler.cudagraph_mark_step_begin()
-                # get decoder input
-                ret_dct = self.warping_module(feature_3d, kp_source=kp_source, kp_driving=kp_driving)
-                # decode
-                ret_dct['out'] = self.spade_generator(feature=ret_dct['out'])
+        with torch.no_grad(), self.inference_ctx():
+            if self.compile:
+                # Mark the beginning of a new CUDA Graph step
+                torch.compiler.cudagraph_mark_step_begin()
+            # get decoder input
+            ret_dct = self.warping_module(feature_3d, kp_source=kp_source, kp_driving=kp_driving)
+            # decode
+            ret_dct['out'] = self.spade_generator(feature=ret_dct['out'])

         # float the dict
         if self.inference_cfg.flag_use_half_precision:
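The two changes in this file cooperate: the constructor now prefers MPS whenever it is available (unless the CPU flag is set), and the new inference_ctx() returns a no-op context on MPS because torch 2.3 does not support float16 autocast on that backend. A condensed, self-contained restatement of the logic as a sketch (names mirror the diff, but this is not the repo code verbatim):

```python
import contextlib

import torch


def pick_device(device_id: int = 0, force_cpu: bool = False) -> str:
    # Same precedence as the constructor above: CPU flag, then MPS, then CUDA.
    if force_cpu:
        return "cpu"
    if torch.backends.mps.is_available():
        return "mps"
    return f"cuda:{device_id}"


def inference_ctx(device: str, use_half: bool = True):
    # MPS: hand back a no-op context; float16 autocast is unsupported there in 2.3.
    if device == "mps":
        return contextlib.nullcontext()
    # device[:4] maps 'cuda:0' -> 'cuda' and leaves 'cpu' unchanged.
    return torch.autocast(device_type=device[:4], dtype=torch.float16, enabled=use_half)


# Usage, mirroring extract_feature_3d / get_kp_info / warp_decode:
#   with torch.no_grad(), inference_ctx(device):
#       feature_3d = appearance_feature_extractor(x)
```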
2 changes: 1 addition & 1 deletion src/modules/dense_motion.py
@@ -59,7 +59,7 @@ def create_heatmap_representations(self, feature, kp_driving, kp_source):
         heatmap = gaussian_driving - gaussian_source # (bs, num_kp, d, h, w)

         # adding background feature
-        zeros = torch.zeros(heatmap.shape[0], 1, spatial_size[0], spatial_size[1], spatial_size[2]).type(heatmap.type()).to(heatmap.device)
+        zeros = torch.zeros(heatmap.shape[0], 1, spatial_size[0], spatial_size[1], spatial_size[2]).type(heatmap.dtype).to(heatmap.device)
         heatmap = torch.cat([zeros, heatmap], dim=1)
         heatmap = heatmap.unsqueeze(2) # (bs, 1+num_kp, 1, d, h, w)
         return heatmap
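Why this one-line change matters: with no arguments, Tensor.type() returns a type string such as 'torch.cuda.FloatTensor', and no corresponding tensor-type string exists for MPS, so round-tripping through .type(heatmap.type()) fails off CUDA. heatmap.dtype is a torch.dtype and is backend-agnostic. An equivalent, slightly more direct spelling, shown as a sketch:

```python
import torch

heatmap = torch.rand(2, 5, 4, 8, 8)  # (bs, num_kp, d, h, w), any device/dtype

# Allocating with dtype= and device= up front avoids the .type()/.to() round trip.
zeros = torch.zeros(
    heatmap.shape[0], 1, *heatmap.shape[2:],
    dtype=heatmap.dtype, device=heatmap.device,
)
heatmap = torch.cat([zeros, heatmap], dim=1)  # (bs, 1 + num_kp, d, h, w)
```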
15 changes: 11 additions & 4 deletions src/utils/cropper.py
@@ -6,6 +6,7 @@

 import cv2; cv2.setNumThreads(0); cv2.ocl.setUseOpenCL(False)
 import numpy as np
+import torch

 from ..config.crop_config import CropConfig
 from .crop import (
@@ -43,10 +44,16 @@ def __init__(self, **kwargs) -> None:
         flag_force_cpu = kwargs.get("flag_force_cpu", False)
         if flag_force_cpu:
             device = "cpu"
-            face_analysis_wrapper_provicer = ["CPUExecutionProvider"]
+            face_analysis_wrapper_provider = ["CPUExecutionProvider"]
         else:
-            device = "cuda"
-            face_analysis_wrapper_provicer = ["CUDAExecutionProvider"]
+            if torch.backends.mps.is_available():
+                # Shape inference currently fails with CoreMLExecutionProvider
+                # for the retinaface model
+                device = "mps"
+                face_analysis_wrapper_provider = ["CPUExecutionProvider"]
+            else:
+                device = "cuda"
+                face_analysis_wrapper_provider = ["CUDAExecutionProvider"]
         self.landmark_runner = LandmarkRunner(
             ckpt_path=make_abs_path(self.crop_cfg.landmark_ckpt_path),
             onnx_provider=device,
@@ -57,7 +64,7 @@ def __init__(self, **kwargs) -> None:
         self.face_analysis_wrapper = FaceAnalysisDIY(
             name="buffalo_l",
             root=make_abs_path(self.crop_cfg.insightface_root),
-            providers=face_analysis_wrapper_provicer,
+            providers=face_analysis_wrapper_provider,
         )
         self.face_analysis_wrapper.prepare(ctx_id=device_id, det_size=(512, 512))
         self.face_analysis_wrapper.warmup()
2 changes: 1 addition & 1 deletion src/utils/dependencies/insightface/model_zoo/model_zoo.py
@@ -68,7 +68,7 @@ def find_onnx_file(dir_path):
     return paths[-1]

 def get_default_providers():
-    return ['CUDAExecutionProvider', 'CPUExecutionProvider']
+    return ['CUDAExecutionProvider', 'CoreMLExecutionProvider', 'CPUExecutionProvider']

 def get_default_provider_options():
     return None
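ONNX Runtime walks the providers list in order and uses the first one the session can load, so inserting CoreMLExecutionProvider between CUDA and CPU makes Apple builds prefer CoreML while CUDA builds are unaffected. Since requesting a provider the installed build lacks can fail, one defensive pattern is to filter the preferred order against onnxruntime.get_available_providers(); a sketch, where resolve_providers and model.onnx are illustrative:

```python
import onnxruntime


def resolve_providers(preferred=None):
    """Keep the preferred priority order, dropping providers this build lacks."""
    preferred = preferred or [
        "CUDAExecutionProvider",
        "CoreMLExecutionProvider",
        "CPUExecutionProvider",
    ]
    available = set(onnxruntime.get_available_providers())
    return [p for p in preferred if p in available]


# On an onnxruntime-silicon build this typically yields
# ['CoreMLExecutionProvider', 'CPUExecutionProvider'].
session = onnxruntime.InferenceSession("model.onnx", providers=resolve_providers())
```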
6 changes: 6 additions & 0 deletions src/utils/landmark_runner.py
Expand Up @@ -39,6 +39,12 @@ def __init__(self, **kwargs):
('CUDAExecutionProvider', {'device_id': device_id})
]
)
elif onnx_provider.lower() == 'mps':
self.session = onnxruntime.InferenceSession(
ckpt_path, providers=[
'CoreMLExecutionProvider'
]
)
else:
opts = onnxruntime.SessionOptions()
opts.intra_op_num_threads = 4 # 默认线程数为 4
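With this hunk the session setup has three arms; note that onnxruntime has no MPS execution provider, so the 'mps' device name is mapped to CoreML here. A condensed sketch of the same branching (make_session is a hypothetical name, and the CPU arm's provider list is inferred from context):

```python
import onnxruntime


def make_session(ckpt_path: str, onnx_provider: str, device_id: int = 0):
    provider = onnx_provider.lower()
    if provider == "cuda":
        return onnxruntime.InferenceSession(
            ckpt_path,
            providers=[("CUDAExecutionProvider", {"device_id": device_id})],
        )
    if provider == "mps":
        # onnxruntime exposes Apple GPUs via CoreML, not an 'MPS' provider.
        return onnxruntime.InferenceSession(
            ckpt_path, providers=["CoreMLExecutionProvider"]
        )
    opts = onnxruntime.SessionOptions()
    opts.intra_op_num_threads = 4  # matches the CPU branch above
    return onnxruntime.InferenceSession(
        ckpt_path, sess_options=opts, providers=["CPUExecutionProvider"]
    )
```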