Commit 7c3d097
Merge pull request #7 from hotshotco/feature/precision
feat: --precision arg now in the inference script
johnmullan authored Oct 4, 2023
2 parents 099632c + 3f0bacb commit 7c3d097
Showing 3 changed files with 49 additions and 20 deletions.
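
As a usage sketch, the two new flags could be exercised like this (the prompt and output path are illustrative; the remaining flags already exist in inference.py):

    python inference.py --prompt "a camera panning over a mountain lake" --output output.gif --precision bf16 --autocast bf16

--precision selects the dtype the pipeline weights are loaded with (torch_dtype), while --autocast optionally wraps the sampling call in torch autocast; left unset, sampling runs without autocast.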
hotshot_xl/pipelines/hotshot_xl_controlnet_pipeline.py (3 additions, 2 deletions)

@@ -1301,10 +1301,11 @@ def decode_latents(self, latents):
         # video = self.vae.decode(latents).sample
         video = []
         for frame_idx in tqdm(range(latents.shape[0])):
-            video.append(self.vae.decode(latents[frame_idx:frame_idx+1]).sample)
+            video.append(self.vae.decode(
+                latents[frame_idx:frame_idx+1]).sample)
         video = torch.cat(video)
         video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
-        video = (video / 2 + 0.5).clamp(0, 1)
+        video = (video / 2.0 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         video = video.cpu().float().numpy()
         return video
hotshot_xl/pipelines/hotshot_xl_pipeline.py (1 addition, 1 deletion)

@@ -962,7 +962,7 @@ def decode_latents(self, latents):
         video = []
         for frame_idx in tqdm(range(latents.shape[0])):
             video.append(self.vae.decode(
-                latents[frame_idx:frame_idx+1].to(dtype=self.vae.dtype)).sample)
+                latents[frame_idx:frame_idx+1]).sample)
         video = torch.cat(video)
         video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
         video = (video / 2.0 + 0.5).clamp(0, 1)
inference.py (45 additions, 17 deletions)

@@ -26,7 +26,7 @@
 from hotshot_xl.utils import save_as_gif, extract_gif_frames_from_midpoint, scale_aspect_fill
 from torch import autocast
 from diffusers import ControlNetModel
-
+from contextlib import contextmanager
 from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler
 from diffusers.schedulers.scheduling_euler_discrete import EulerDiscreteScheduler

@@ -66,6 +66,12 @@ def parse_args():
parser.add_argument("--control_guidance_start", type=float, default=0.0)
parser.add_argument("--control_guidance_end", type=float, default=1.0)
parser.add_argument("--gif", type=str, default=None)
parser.add_argument("--precision", type=str, default='f16', choices=[
'f16', 'f32', 'bf16'
])
parser.add_argument("--autocast", type=str, default=None, choices=[
'f16', 'bf16'
])

return parser.parse_args()

@@ -86,6 +92,14 @@ def to_pil_images(video_frames: torch.Tensor, output_type='pil'):
             images.append(video[j])
     return images
 
+@contextmanager
+def maybe_auto_cast(data_type):
+    if data_type:
+        with autocast("cuda", dtype=data_type):
+            yield
+    else:
+        yield
+
 
 def main():
     args = parse_args()
@@ -112,10 +126,18 @@ def main():

     data_type = torch.float32
 
+    if args.precision == 'f16':
+        data_type = torch.half
+    elif args.precision == 'f32':
+        data_type = torch.float32
+    elif args.precision == 'bf16':
+        data_type = torch.bfloat16
+
     pipe_line_args = {
         "torch_dtype": data_type,
         "use_safetensors": True
     }
+
     PipelineClass = HotshotXLPipeline
 
     if control_net_model_pretrained_path:
@@ -158,19 +180,25 @@ def main():

     generator = torch.Generator().manual_seed(args.seed) if args.seed else None
 
-    with autocast("cuda", dtype=torch.bfloat16):
+    autocast_type = None
+    if args.autocast == 'f16':
+        autocast_type = torch.half
+    elif args.autocast == 'bf16':
+        autocast_type = torch.bfloat16
 
-        kwargs = {}
-        if args.gif and type(pipe) is HotshotXLControlNetPipeline:
-            kwargs['control_images'] = [
-                scale_aspect_fill(img, args.width, args.height).convert("RGB") \
-                for img in
-                extract_gif_frames_from_midpoint(args.gif, fps=args.video_length, target_duration=args.video_duration)
-            ]
-            kwargs['controlnet_conditioning_scale'] = args.controlnet_conditioning_scale
-            kwargs['control_guidance_start'] = args.control_guidance_start
-            kwargs['control_guidance_end'] = args.control_guidance_end
+    kwargs = {}
+
+    if args.gif and type(pipe) is HotshotXLControlNetPipeline:
+        kwargs['control_images'] = [
+            scale_aspect_fill(img, args.width, args.height).convert("RGB") \
+            for img in
+            extract_gif_frames_from_midpoint(args.gif, fps=args.video_length, target_duration=args.video_duration)
+        ]
+        kwargs['controlnet_conditioning_scale'] = args.controlnet_conditioning_scale
+        kwargs['control_guidance_start'] = args.control_guidance_start
+        kwargs['control_guidance_end'] = args.control_guidance_end
+    with maybe_auto_cast(autocast_type):
 
         images = pipe(args.prompt,
                       negative_prompt=args.negative_prompt,
@@ -183,12 +211,12 @@ def main():
                       generator=generator,
                       output_type="tensor", **kwargs).videos
 
-        images = to_pil_images(images, output_type="pil")
+    images = to_pil_images(images, output_type="pil")
 
-        if args.video_length > 1:
-            save_as_gif(images, args.output, duration=args.video_duration // args.video_length)
-        else:
-            images[0].save(args.output, format='JPEG', quality=95)
+    if args.video_length > 1:
+        save_as_gif(images, args.output, duration=args.video_duration // args.video_length)
+    else:
+        images[0].save(args.output, format='JPEG', quality=95)
 
 
 if __name__ == "__main__":

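For reference, a minimal self-contained sketch of the pattern this commit introduces: weights held in one dtype, compute optionally run under autocast in another. The Linear layer is a hypothetical stand-in for the real pipeline, the dtype choices mirror --precision bf16 and --autocast f16, and the snippet assumes a CUDA device:

    from contextlib import contextmanager

    import torch
    from torch import autocast

    @contextmanager
    def maybe_auto_cast(data_type):
        # No-op unless an autocast dtype was requested, mirroring the helper added to inference.py.
        if data_type:
            with autocast("cuda", dtype=data_type):
                yield
        else:
            yield

    weights_dtype = torch.bfloat16    # what --precision bf16 selects
    autocast_dtype = torch.half       # what --autocast f16 selects

    layer = torch.nn.Linear(8, 8, device="cuda", dtype=weights_dtype)  # stand-in for the pipeline
    x = torch.randn(1, 8, device="cuda", dtype=weights_dtype)

    with maybe_auto_cast(autocast_dtype):
        y = layer(x)

    print(y.dtype)  # torch.float16 under autocast; torch.bfloat16 when autocast_dtype is None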