Skip to content

Commit

Permalink
Merge pull request Tencent#49 from zsxkib/main
Browse files Browse the repository at this point in the history
Replicate API Added Part 2!: New checkpoint added, fp16 input param removed for simplicity
  • Loading branch information
gujiaxi authored Jul 17, 2024
2 parents ccab45e + 6677c3a commit ce20af1
Showing 1 changed file with 39 additions and 12 deletions.
51 changes: 39 additions & 12 deletions predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def setup(self):
model_files = [
"DWPose.tar",
"MimicMotion.pth",
"MimicMotion_1-1.pth",
"SVD.tar",
]
for model_file in model_files:
Expand All @@ -72,16 +73,18 @@ def setup(self):
from mimicmotion.utils.utils import save_to_mp4
from mimicmotion.dwpose.preprocess import get_video_pose, get_image_pose

# Load config
# Load config with new checkpoint as default
self.config = OmegaConf.create(
{
"base_model_path": "models/SVD/stable-video-diffusion-img2vid-xt-1-1",
"ckpt_path": "models/MimicMotion.pth",
"ckpt_path": "models/MimicMotion_1-1.pth",
}
)

# Create pipeline
# Create the pipeline with the new checkpoint
self.pipeline = create_pipeline(self.config, self.device)
self.current_checkpoint = "v1-1"
self.current_dtype = torch.get_default_dtype()

def predict(
self,
Expand Down Expand Up @@ -140,9 +143,10 @@ def predict(
description="Random seed. Leave blank to randomize the seed",
default=None,
),
use_fp16: bool = Input(
description="Use half-precision floating point (float16). Faster but slightly less accurate than float32.",
default=True,
checkpoint_version: str = Input(
description="Choose the checkpoint version to use",
choices=["v1", "v1-1"],
default="v1-1",
),
) -> Path:
"""Run a single prediction on the model"""
Expand All @@ -153,11 +157,40 @@ def predict(
num_inference_steps = denoising_steps
noise_aug_strength = noise_strength
fps = output_frames_per_second
use_fp16 = True

if seed is None:
seed = int.from_bytes(os.urandom(2), "big")
print(f"Using seed: {seed}")

need_pipeline_update = False

# Check if we need to switch checkpoints
if checkpoint_version != self.current_checkpoint:
if checkpoint_version == "v1":
self.config.ckpt_path = "models/MimicMotion.pth"
else: # v1-1
self.config.ckpt_path = "models/MimicMotion_1-1.pth"
need_pipeline_update = True
self.current_checkpoint = checkpoint_version

# Check if we need to switch dtype
target_dtype = torch.float16 if use_fp16 else torch.float32
if target_dtype != self.current_dtype:
torch.set_default_dtype(target_dtype)
need_pipeline_update = True
self.current_dtype = target_dtype

# Update pipeline if needed
if need_pipeline_update:
print(
f"Updating pipeline with checkpoint: {self.config.ckpt_path} and dtype: {torch.get_default_dtype()}"
)
self.pipeline = create_pipeline(self.config, self.device)

print(f"Using checkpoint: {self.config.ckpt_path}")
print(f"Using dtype: {torch.get_default_dtype()}")

print(
f"[!] ({type(ref_video)}) ref_video={ref_video}, "
f"[!] ({type(ref_image)}) ref_image={ref_image}, "
Expand Down Expand Up @@ -221,12 +254,6 @@ def predict(
if fps < 1 or fps > 60:
raise ValueError(f"FPS must be between 1 and 60, got {fps}")

if use_fp16:
torch.set_default_dtype(torch.float16)

# Recreate the pipeline with the new dtype
self.pipeline = create_pipeline(self.config, self.device)

try:
# Preprocess
pose_pixels, image_pixels = self.preprocess(
Expand Down

0 comments on commit ce20af1

Please sign in to comment.