Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
225 changes: 4 additions & 221 deletions comfy/clip_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,227 +17,10 @@ def __getitem__(self, key):
def __setitem__(self, key, item):
setattr(self, key, item)


def cubic_kernel(x, a: float = -0.75):
    """Evaluate the piecewise cubic convolution kernel element-wise.

    For |x| <= 1 the inner polynomial is used, for 1 < |x| < 2 the outer
    polynomial is used, and everything else maps to zero.

    Args:
        x: tensor of signed distances from the sample position.
        a: kernel sharpness coefficient.

    Returns:
        A tensor of kernel weights with the same shape as ``x``.
    """
    dist = x.abs()
    dist_sq = dist ** 2
    dist_cu = dist ** 3

    # Polynomial for the central lobe (|x| <= 1).
    inner = (a + 2) * dist_cu - (a + 3) * dist_sq + 1
    # Polynomial for the side lobes (1 < |x| < 2).
    outer = a * dist_cu - 5*a * dist_sq + 8*a * dist - 4*a

    beyond_centre = torch.where(dist < 2, outer, torch.zeros_like(x))
    return torch.where(dist <= 1, inner, beyond_centre)

def get_indices_weights(in_size, out_size, scale):
    """Build the 4-tap gather indices and weights for a 1-D cubic resize.

    Args:
        in_size: number of samples along the source axis.
        out_size: number of samples along the destination axis.
        scale: ``out_size / in_size`` as computed by the caller.

    Returns:
        ``(indices, weights)``, both of shape ``(out_size, 4)``. ``indices``
        are clamped to ``[0, in_size - 1]``; each row of ``weights`` is
        normalised to sum to one.
    """
    tap_offsets = torch.arange(-1, 3)

    # OpenCV-style half-pixel mapping from destination to source coordinates
    dst = torch.arange(out_size, dtype=torch.float32)
    src = (dst + 0.5) / scale - 0.5

    # Four neighbouring source taps around each (fractional) source position.
    taps = src.floor().long().unsqueeze(1) + tap_offsets

    weights = cubic_kernel(src.unsqueeze(1) - taps)
    weights = weights / weights.sum(dim=1, keepdim=True)

    # Clamping happens after the weights are computed, so border taps
    # reuse the edge sample.
    return taps.clamp(0, in_size - 1), weights

def resize_cubic_1d(x, out_size, dim):
    """Resize one spatial axis of a 4-D tensor with cubic interpolation.

    Args:
        x: tensor that unpacks as ``(b, c, h, w)``.
        out_size: target length of the resized axis.
        dim: ``2`` resizes the height axis; any other value resizes the
            width axis.

    Returns:
        A tensor of shape ``(b, c, out_size, w)`` when ``dim == 2``,
        otherwise ``(b, c, h, out_size)``.
    """
    b, c, h, w = x.shape
    in_size = h if dim == 2 else w
    scale = out_size / in_size

    indices, weights = get_indices_weights(in_size, out_size, scale)
    # The helper builds its tensors on the default device; move them next to
    # the data so the weighted sum below does not fail on a device mismatch.
    indices = indices.to(x.device)
    weights = weights.to(x.device)

    # Flatten everything except the axis being resized into rows.
    if dim == 2:
        x = x.permute(0, 1, 3, 2)
        x = x.reshape(-1, h)
    else:
        x = x.reshape(-1, w)

    # gathered: (rows, out_size, 4) -> weighted sum over the 4 taps.
    gathered = x[:, indices]
    out = (gathered * weights.unsqueeze(0)).sum(dim=2)

    if dim == 2:
        out = out.reshape(b, c, w, out_size).permute(0, 1, 3, 2)
    else:
        out = out.reshape(b, c, h, out_size)

    return out

def resize_cubic(img: torch.Tensor, size: tuple) -> torch.Tensor:
    """
    Resize image using OpenCV-equivalent INTER_CUBIC interpolation.
    Implemented in pure PyTorch

    A 3-D input gets a leading batch axis added. The last axis of the input
    is moved to position 1 before resizing, and the result is returned in
    that channel-first layout with spatial size ``size = (out_h, out_w)``.
    """
    batched = img.unsqueeze(0) if img.ndim == 3 else img
    channel_first = batched.permute(0, 3, 1, 2)

    target_h, target_w = size

    # Separable resize: height pass first, then width pass.
    rows_resized = resize_cubic_1d(channel_first, target_h, dim=2)
    return resize_cubic_1d(rows_resized, target_w, dim=3)

def _area_axis_taps(in_size, out_size, device):
    """Return ``(gather_index, overlap_weight)`` pairs for one axis of an area resize."""
    scale = in_size / out_size

    # Footprint [start, end) of every output pixel in input coordinates.
    # Clamping `end` keeps float round-off from spilling past the image edge.
    start = torch.arange(out_size, device=device).float() * scale
    end = (start + scale).clamp(max=in_size)

    first = torch.floor(start).long()
    n_taps = int(torch.max(torch.ceil(end).long() - first).item())

    taps = []
    for offset in range(n_taps):
        idx = first + offset
        lo = idx.float()
        # Overlap is measured against the UNCLAMPED input pixel [lo, lo + 1),
        # so taps that fall outside a footprint (or the image) get weight 0
        # instead of re-counting the border pixel.
        weight = (torch.min(end, lo + 1.0) - torch.max(start, lo)).clamp(min=0)
        # Clamp only for the gather, so the index is always valid.
        taps.append((idx.clamp(0, in_size - 1), weight))
    return taps


def resize_area(img: torch.Tensor, size: tuple) -> torch.Tensor:
    """Area-weighted resize (OpenCV INTER_AREA style) in pure PyTorch.

    Each output pixel is the overlap-weighted average of the input pixels
    covered by its footprint.

    Args:
        img: 3-D or 4-D tensor. A 4-D tensor is unpacked as (B, C, H, W).
            A 3-D tensor is treated as (C, H, W) when its first dimension
            is <= 4, otherwise as (H, W, C).
        size: ``(out_h, out_w)`` target size.

    Returns:
        A float32 tensor in the same layout and rank as the input.

    Raises:
        ValueError: if ``img`` is not 3-D/4-D or the target size is not
            positive.
    """
    is_hwc = False
    is_chw = False

    if img.ndim == 3:
        if img.shape[0] <= 4:
            is_chw = True
            img = img.unsqueeze(0)
        else:
            is_hwc = True
            img = img.permute(2, 0, 1).unsqueeze(0)
    elif img.ndim != 4:
        raise ValueError("Expected image with 3 or 4 dims.")

    B, C, H, W = img.shape
    out_h, out_w = size
    if out_h <= 0 or out_w <= 0:
        raise ValueError("Output size must be positive.")

    device = img.device

    # compute the grid boundaries and per-offset weights once per axis
    y_taps = _area_axis_taps(H, out_h, device)
    x_taps = _area_axis_taps(W, out_w, device)

    # We build the weighted sums by iterating over contributing input pixels once
    output = torch.zeros((B, C, out_h, out_w), dtype=torch.float32, device=device)
    area = torch.zeros((out_h, out_w), dtype=torch.float32, device=device)

    for y_idx, y_weight in y_taps:
        y_idx = y_idx.unsqueeze(1)        # (out_h, 1)
        y_weight = y_weight.unsqueeze(1)  # (out_h, 1)
        for x_idx, x_weight in x_taps:
            weight = y_weight * x_weight.unsqueeze(0)  # (out_h, out_w)

            # The two index tensors broadcast to (out_h, out_w).
            pixels = img[:, :, y_idx, x_idx.unsqueeze(0)]

            output += pixels * weight
            area += weight

    # Normalize by area
    output /= area

    if is_hwc:
        return output[0].permute(1, 2, 0)
    if is_chw:
        return output[0]
    # 4-D input keeps its batch dimension, even when B == 1.
    return output

def recenter(image, border_ratio: float = 0.2):
    """Crop to the non-zero mask region, rescale, and centre it on a white square.

    Args:
        image: tensor that unpacks as (H, W, C). When the last axis has 4
            entries, index 3 is used as the mask; otherwise a constant 255
            mask channel is appended.
            NOTE(review): values look like they are expected in 0..255
            (the mask is divided by 255 below) - confirm against the caller.
        border_ratio: fraction of the square canvas left as border around
            the rescaled content.

    Returns:
        A uint8 tensor of shape (size, size, 3) with ``size = max(H, W)``,
        composited over a white (255) background.

    Raises:
        ValueError: if the mask has no non-zero entries, or its bounding box
            has zero height or width.
    """
    if image.shape[-1] == 4:
        mask = image[..., 3]
    else:
        mask = torch.ones_like(image[..., 0:1]) * 255
        image = torch.concatenate([image, mask], axis=-1)
        mask = mask[..., 0]

    H, W, C = image.shape

    size = max(H, W)
    result = torch.zeros((size, size, C), dtype = torch.uint8)

    # as_tuple to match numpy behaviour; x_* index rows, y_* index columns
    x_coords, y_coords = torch.nonzero(mask, as_tuple=True)

    # An all-zero mask would otherwise fail inside min()/max() with a
    # RuntimeError before the emptiness check below is ever reached.
    if x_coords.numel() == 0:
        raise ValueError('input image is empty')

    y_min, y_max = y_coords.min(), y_coords.max()
    x_min, x_max = x_coords.min(), x_coords.max()

    h = x_max - x_min
    w = y_max - y_min

    if h == 0 or w == 0:
        raise ValueError('input image is empty')

    desired_size = int(size * (1 - border_ratio))
    scale = desired_size / max(h, w)

    h2 = int(h * scale)
    w2 = int(w * scale)

    x2_min = (size - h2) // 2
    x2_max = x2_min + h2

    y2_min = (size - w2) // 2
    y2_max = y2_min + w2

    # note: opencv takes columns first (opposite to pytorch and numpy that take the row first)
    result[x2_min:x2_max, y2_min:y2_max] = resize_area(image[x_min:x_max, y_min:y_max], (h2, w2))

    bg = torch.ones((result.shape[0], result.shape[1], 3), dtype = torch.uint8) * 255

    # Alpha-composite the resized content over the white background.
    alpha = result[..., 3:].to(torch.float32) / 255
    result = result[..., :3] * alpha + bg * (1 - alpha)

    return result.clip(0, 255).to(torch.uint8)

def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711],
crop=True, value_range = (-1, 1), border_ratio: float = None, recenter_size: int = 512):

if border_ratio is not None:

image = (image * 255).clamp(0, 255).to(torch.uint8)
image = [recenter(img, border_ratio = border_ratio) for img in image]

image = torch.stack(image, dim = 0)
image = resize_cubic(image, size = (recenter_size, recenter_size))

image = image / 255 * 2 - 1
low, high = value_range

image = (image - low) / (high - low)
image = image.permute(0, 2, 3, 1)

def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True):
image = image[:, :, :, :3] if image.shape[3] > 3 else image

mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
std = torch.tensor(std, device=image.device, dtype=image.dtype)

image = image.movedim(-1, 1)
if not (image.shape[2] == size and image.shape[3] == size):
if crop:
Expand All @@ -246,7 +29,7 @@ def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], s
else:
scale_size = (size, size)

image = torch.nn.functional.interpolate(image, size=scale_size, mode="bilinear" if border_ratio is not None else "bicubic", antialias=True)
image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True)
h = (image.shape[2] - size)//2
w = (image.shape[3] - size)//2
image = image[:,:,h:h+size,w:w+size]
Expand Down Expand Up @@ -288,9 +71,9 @@ def load_sd(self, sd):
def get_sd(self):
    """Return whatever ``self.model.state_dict()`` produces."""
    state = self.model.state_dict()
    return state

def encode_image(self, image, crop=True, border_ratio: float = None):
def encode_image(self, image, crop=True):
comfy.model_management.load_model_gpu(self.patcher)
pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop, border_ratio=border_ratio).float()
pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
out = self.model(pixel_values=pixel_values, intermediate_output='all' if self.return_all_hidden_states else -2)

outputs = Output()
Expand Down
21 changes: 0 additions & 21 deletions comfy/sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -1058,27 +1058,6 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
model = None
model_patcher = None

if isinstance(sd, dict) and all(k in sd for k in ["model", "vae", "conditioner"]):
from collections import OrderedDict
import gc

merged_sd = OrderedDict()

for k, v in sd["model"].items():
merged_sd[f"model.{k}"] = v

for k, v in sd["vae"].items():
merged_sd[f"vae.{k}"] = v

for key, value in sd["conditioner"].items():
merged_sd[f"conditioner.{key}"] = value

sd = merged_sd

del merged_sd
gc.collect()
torch.cuda.empty_cache()

diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
parameters = comfy.utils.calculate_parameters(sd, diffusion_model_prefix)
weight_dtype = comfy.utils.weight_dtype(sd, diffusion_model_prefix)
Expand Down
29 changes: 9 additions & 20 deletions nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -998,31 +998,20 @@ def load_clip(self, clip_name):
class CLIPVisionEncode:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"clip_vision": ("CLIP_VISION",),
"image": ("IMAGE",),
"crop": (["center", "none", "recenter"],),
},
"optional": {
"border_ratio": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 0.5, "step": 0.01, "visible_if": {"crop": "recenter"},}),
}
}

return {"required": { "clip_vision": ("CLIP_VISION",),
"image": ("IMAGE",),
"crop": (["center", "none"],)
}}
RETURN_TYPES = ("CLIP_VISION_OUTPUT",)
FUNCTION = "encode"

CATEGORY = "conditioning"

def encode(self, clip_vision, image, crop, border_ratio):
crop_image = crop == "center"

if crop == "recenter":
crop_image = True
else:
border_ratio = None

output = clip_vision.encode_image(image, crop=crop_image, border_ratio = border_ratio)
def encode(self, clip_vision, image, crop):
crop_image = True
if crop != "center":
crop_image = False
output = clip_vision.encode_image(image, crop=crop_image)
return (output,)

class StyleModelLoader:
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,4 @@ kornia>=0.7.1
spandrel
soundfile
pydantic~=2.0
pydantic-settings~=2.0
pydantic-settings~=2.0
Loading