
Commit

Add option for float32 sampling with float16 UNet
This also handles type casting so that ROCm and MPS torch devices work correctly without --no-half. One cast is required for deepbooru in deepbooru_model.py, and some explicit casting is required for img2img and inpainting. depth_model can't be converted to float16 or it won't work correctly on some systems (it is known to have issues on MPS), so in sd_models.py, model.depth_model is detached before model.half() and restored afterward.
brkirch committed Jan 25, 2023
1 parent 48a1582 commit 84d9ce3
Showing 8 changed files with 82 additions and 8 deletions.
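
The commit message above describes the mechanism; a minimal, self-contained sketch of the technique follows (illustrative names only, not the webui API). The sampler keeps its tensors in float32; inputs are cast down to the UNet's float16 dtype at the model boundary, and the prediction is cast back up before the sampler's arithmetic continues:

    import torch

    dtype_unet = torch.float16  # the UNet weights stay in half precision

    def call_unet_upcast(unet, x_noisy, t):
        # Cast sampler-side float32 tensors down at the model boundary...
        out = unet(x_noisy.to(dtype_unet), t.to(dtype_unet))
        # ...and hand the prediction back to the sampler in float32.
        return out.float()

    # Toy stand-in for a float16 UNet; elementwise half ops run fine on CPU.
    toy_unet = lambda x, t: x * t
    pred = call_unet_upcast(toy_unet, torch.randn(2, 4), torch.tensor(999.))
    assert pred.dtype == torch.float32

This trades a few boundary casts for the memory savings of float16 weights, while keeping the sampler's accumulation in float32.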
1 change: 1 addition & 0 deletions README.md
@@ -157,4 +157,5 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al
- DeepDanbooru - interrogator for anime diffusers https://github.com/KichangKim/DeepDanbooru
- Security advice - RyotaK
- Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.
- Sampling in float32 precision from a float16 UNet - marunine for the idea, Birch-san for the example Diffusers implementation (https://github.com/Birch-san/diffusers-play/tree/92feee6)
- (You)
4 changes: 3 additions & 1 deletion modules/deepbooru_model.py
@@ -2,6 +2,8 @@
import torch.nn as nn
import torch.nn.functional as F

from modules import devices

# see https://github.com/AUTOMATIC1111/TorchDeepDanbooru for more


@@ -196,7 +198,7 @@ def forward(self, *inputs):
t_358, = inputs
t_359 = t_358.permute(*[0, 3, 1, 2])
t_359_padded = F.pad(t_359, [2, 3, 2, 3], value=0)
t_360 = self.n_Conv_0(t_359_padded)
t_360 = self.n_Conv_0(t_359_padded.to(self.n_Conv_0.bias.dtype) if devices.unet_needs_upcast else t_359_padded)
t_361 = F.relu(t_360)
t_361 = F.pad(t_361, [0, 1, 0, 1], value=float('-inf'))
t_362 = self.n_MaxPool_0(t_361)
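
The one-line change above uses a general idiom: rather than hard-coding float16, cast the incoming tensor to the dtype of one of the layer's own parameters, so the code is correct whether or not the model was halved. A minimal sketch of the idiom (illustrative layer, not the deepbooru model):

    import torch

    conv = torch.nn.Conv2d(3, 8, kernel_size=3).half()
    x = torch.rand(1, 3, 64, 64)  # float32, e.g. fresh from image preprocessing

    # Match the input to whatever precision the layer's parameters actually use.
    x = x.to(conv.weight.dtype)
    assert x.dtype == torch.float16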
2 changes: 2 additions & 0 deletions modules/devices.py
@@ -79,6 +79,8 @@ def enable_tf32():
device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = None
dtype = torch.float16
dtype_vae = torch.float16
dtype_unet = torch.float16
unet_needs_upcast = False


def randn(seed, shape):
15 changes: 8 additions & 7 deletions modules/processing.py
@@ -172,7 +172,8 @@ def depth2img_image_conditioning(self, source_image):
midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device)
midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size)

conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image))
conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image.to(devices.dtype_unet) if devices.unet_needs_upcast else source_image))
conditioning_image = conditioning_image.float() if devices.unet_needs_upcast else conditioning_image
conditioning = torch.nn.functional.interpolate(
self.sd_model.depth_model(midas_in),
size=conditioning_image.shape[2:],
@@ -203,15 +204,15 @@ def inpainting_image_conditioning(self, source_image, latent_image, image_mask =

# Create another latent image, this time with a masked version of the original input.
# Smoothly interpolate between the masked and unmasked latent conditioning image using a parameter.
conditioning_mask = conditioning_mask.to(source_image.device).to(source_image.dtype)
conditioning_mask = conditioning_mask.to(device=source_image.device, dtype=source_image.dtype)
conditioning_image = torch.lerp(
source_image,
source_image * (1.0 - conditioning_mask),
getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight)
)

# Encode the new masked image using first stage of network.
conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))
conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image.to(devices.dtype_unet) if devices.unet_needs_upcast else conditioning_image))

# Create the concatenated conditioning tensor to be fed to `c_concat`
conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:])
@@ -225,10 +226,10 @@ def img2img_image_conditioning(self, source_image, latent_image, image_mask=None
# HACK: Using introspection as the Depth2Image model doesn't appear to uniquely
# identify itself with a field common to all models. The conditioning_key is also hybrid.
if isinstance(self.sd_model, LatentDepth2ImageDiffusion):
return self.depth2img_image_conditioning(source_image)
return self.depth2img_image_conditioning(source_image.float() if devices.unet_needs_upcast else source_image)

if self.sampler.conditioning_key in {'hybrid', 'concat'}:
return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)
return self.inpainting_image_conditioning(source_image.float() if devices.unet_needs_upcast else source_image, latent_image, image_mask=image_mask)

# Dummy zero conditioning if we're not using inpainting or depth model.
return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)
@@ -610,7 +611,7 @@ def get_conds_with_caching(function, required_prompts, steps, cache):
if p.n_iter > 1:
shared.state.job = f"Batch {n+1} out of {p.n_iter}"

with devices.autocast():
with devices.autocast(disable=devices.unet_needs_upcast):
samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts)

x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae))[0].cpu() for i in range(samples_ddim.size(0))]
@@ -988,7 +989,7 @@ def init(self, all_prompts, all_seeds, all_subseeds):

image = torch.from_numpy(batch_images)
image = 2. * image - 1.
image = image.to(shared.device)
image = image.to(device=shared.device, dtype=devices.dtype_unet if devices.unet_needs_upcast else None)

self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))

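
The processing.py edits repeat one conditional-cast pattern: tensors go to float32 on the way into sampling, and to devices.dtype_unet on the way into the first-stage encoder, in both cases only when unet_needs_upcast is set. As a sketch, the float32 half of the pattern could be factored into a helper (the commit inlines it at each call site instead; this helper is hypothetical):

    import torch

    def float_if_upcast(t: torch.Tensor, needs_upcast: bool) -> torch.Tensor:
        # Upcast to float32 only when sampling runs in higher precision than
        # the UNet weights; otherwise pass the tensor through untouched.
        return t.float() if needs_upcast else t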
29 changes: 29 additions & 0 deletions modules/sd_hijack_unet.py
@@ -1,4 +1,8 @@
import torch
from packaging import version

from modules import devices
from modules.sd_hijack_utils import CondFunc


class TorchHijackForUnet:
Expand Down Expand Up @@ -28,3 +32,28 @@ def cat(self, tensors, *args, **kwargs):


th = TorchHijackForUnet()


# Below are monkey patches to enable upcasting a float16 UNet for float32 sampling
def apply_model(orig_func, self, x_noisy, t, cond, **kwargs):
    for y in cond.keys():
        cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]]
    with devices.autocast():
        return orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs).float()

class GELUHijack(torch.nn.GELU, torch.nn.Module):
    def __init__(self, *args, **kwargs):
        torch.nn.GELU.__init__(self, *args, **kwargs)
    def forward(self, x):
        if devices.unet_needs_upcast:
            return torch.nn.GELU.forward(self.float(), x.float()).to(devices.dtype_unet)
        else:
            return torch.nn.GELU.forward(self, x)

unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).to(devices.dtype_unet), unet_needs_upcast)
if version.parse(torch.__version__) <= version.parse("1.13.1"):
    CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast)
    CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast)
    CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU)
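
GELUHijack is the boundary-cast idea applied to a single activation: compute in float32 (half-precision GELU appears to be unreliable on some backends), then hand the result back in the caller's dtype. A standalone sketch with a hypothetical class name:

    import torch

    class UpcastGELU(torch.nn.GELU):
        def forward(self, x):
            # Compute the activation in float32, return in the incoming dtype.
            return super().forward(x.float()).to(x.dtype)

    y = UpcastGELU()(torch.randn(4, dtype=torch.float16))
    assert y.dtype == torch.float16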
28 changes: 28 additions & 0 deletions modules/sd_hijack_utils.py
@@ -0,0 +1,28 @@
import importlib

class CondFunc:
    def __new__(cls, orig_func, sub_func, cond_func):
        self = super(CondFunc, cls).__new__(cls)
        if isinstance(orig_func, str):
            func_path = orig_func.split('.')
            for i in range(len(func_path)-2, -1, -1):
                try:
                    resolved_obj = importlib.import_module('.'.join(func_path[:i]))
                    break
                except ImportError:
                    pass
            for attr_name in func_path[i:-1]:
                resolved_obj = getattr(resolved_obj, attr_name)
            orig_func = getattr(resolved_obj, func_path[-1])
            setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs))
        self.__init__(orig_func, sub_func, cond_func)
        return lambda *args, **kwargs: self(*args, **kwargs)
    def __init__(self, orig_func, sub_func, cond_func):
        self.__orig_func = orig_func
        self.__sub_func = sub_func
        self.__cond_func = cond_func
    def __call__(self, *args, **kwargs):
        if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs):
            return self.__sub_func(self.__orig_func, *args, **kwargs)
        else:
            return self.__orig_func(*args, **kwargs)
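
CondFunc resolves a dotted path, swaps the target for a wrapper, and routes calls through sub_func only while cond_func returns truthy; otherwise the original function runs. A usage sketch against a standard-library target (illustrative only; the real call sites are in sd_hijack_unet.py above, and this assumes the webui's modules package is importable):

    import os.path
    from modules.sd_hijack_utils import CondFunc

    patch_enabled = True

    CondFunc(
        'os.path.join',
        lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).upper(),  # sub_func
        lambda orig_func, *args, **kwargs: patch_enabled,                       # cond_func
    )

    print(os.path.join('a', 'b'))  # 'A/B' on POSIX while patch_enabled is True
    patch_enabled = False
    print(os.path.join('a', 'b'))  # cond_func now fails, so the original runs: 'a/b'

Note the wrapper stays installed either way; cond_func decides per call whether the substitute runs.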
10 changes: 10 additions & 0 deletions modules/sd_models.py
@@ -257,16 +257,24 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo):

if not shared.cmd_opts.no_half:
vae = model.first_stage_model
depth_model = getattr(model, 'depth_model', None)

# with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
if shared.cmd_opts.no_half_vae:
model.first_stage_model = None
# with --upcast-sampling, don't convert the depth model weights to float16
if shared.cmd_opts.upcast_sampling and depth_model:
model.depth_model = None

model.half()
model.first_stage_model = vae
if depth_model:
model.depth_model = depth_model

devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
devices.dtype_unet = model.model.diffusion_model.dtype
devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16

model.first_stage_model.to(devices.dtype_vae)
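
The depth-model and VAE handling above is a small reusable pattern: temporarily detach any submodule that must stay in float32, call half() on the rest, then reattach. A standalone sketch with illustrative names:

    import torch

    class FakeSD(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.diffusion_model = torch.nn.Linear(4, 4)
            self.depth_model = torch.nn.Linear(4, 4)  # must remain float32

    model = FakeSD()
    depth_model = model.depth_model
    model.depth_model = None   # setting a child module to None hides it from half()
    model.half()
    model.depth_model = depth_model

    assert model.diffusion_model.weight.dtype == torch.float16
    assert model.depth_model.weight.dtype == torch.float32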

@@ -372,6 +380,8 @@ def load_model(checkpoint_info=None):

if shared.cmd_opts.no_half:
sd_config.model.params.unet_config.params.use_fp16 = False
elif shared.cmd_opts.upcast_sampling:
sd_config.model.params.unet_config.params.use_fp16 = True

timer = Timer()

1 change: 1 addition & 0 deletions modules/shared.py
@@ -45,6 +45,7 @@
parser.add_argument("--always-batch-cond-uncond", action='store_true', help="disables cond/uncond batching that is enabled to save memory with --medvram or --lowvram")
parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.")
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
parser.add_argument("--upcast-sampling", action='store_true', help="upcast sampling. No effect with --no-half. Usually produces similar results to --no-half with better performance while using less memory.")
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site")
parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None)
parser.add_argument("--ngrok-region", type=str, help="The region in which ngrok should start.", default="us")
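
In use, the flag is passed at launch like any other webui option, e.g. python launch.py --upcast-sampling (assuming the standard launch script; it can also go into COMMANDLINE_ARGS in webui-user). As the help text says, it has no effect alongside --no-half, since the whole model then already runs in float32.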
