|
17 | 17 | import comfy.patcher_extension |
18 | 18 | import comfy.hooks |
19 | 19 | import comfy.context_windows |
| 20 | +import comfy.utils |
20 | 21 | import scipy.stats |
21 | 22 | import numpy |
22 | 23 |
|
@@ -61,15 +62,15 @@ def get_area_and_mult(conds, x_in, timestep_in): |
61 | 62 | if "mask_strength" in conds: |
62 | 63 | mask_strength = conds["mask_strength"] |
63 | 64 | mask = conds['mask'] |
64 | | - assert (mask.shape[1:] == x_in.shape[2:]) |
| 65 | + # assert (mask.shape[1:] == x_in.shape[2:]) |
65 | 66 |
|
66 | 67 | mask = mask[:input_x.shape[0]] |
67 | 68 | if area is not None: |
68 | 69 | for i in range(len(dims)): |
69 | 70 | mask = mask.narrow(i + 1, area[len(dims) + i], area[i]) |
70 | 71 |
|
71 | 72 | mask = mask * mask_strength |
72 | | - mask = mask.unsqueeze(1).repeat(input_x.shape[0] // mask.shape[0], input_x.shape[1], 1, 1) |
| 73 | + mask = mask.unsqueeze(1).repeat((input_x.shape[0] // mask.shape[0], input_x.shape[1]) + (1, ) * (mask.ndim - 1)) |
73 | 74 | else: |
74 | 75 | mask = torch.ones_like(input_x) |
75 | 76 | mult = mask * strength |
@@ -553,7 +554,10 @@ def resolve_areas_and_cond_masks_multidim(conditions, dims, device): |
553 | 554 | if len(mask.shape) == len(dims): |
554 | 555 | mask = mask.unsqueeze(0) |
555 | 556 | if mask.shape[1:] != dims: |
556 | | - mask = torch.nn.functional.interpolate(mask.unsqueeze(1), size=dims, mode='bilinear', align_corners=False).squeeze(1) |
| 557 | + if mask.ndim < 4: |
| 558 | + mask = comfy.utils.common_upscale(mask.unsqueeze(1), dims[-1], dims[-2], 'bilinear', 'none').squeeze(1) |
| 559 | + else: |
| 560 | + mask = comfy.utils.common_upscale(mask, dims[-1], dims[-2], 'bilinear', 'none') |
557 | 561 |
|
558 | 562 | if modified.get("set_area_to_bounds", False): #TODO: handle dim != 2 |
559 | 563 | bounds = torch.max(torch.abs(mask),dim=0).values.unsqueeze(0) |
|
0 commit comments