|
5 | 5 | from typing_extensions import override |
6 | 6 | from comfy_api.latest import ComfyExtension, io |
7 | 7 | import logging |
| 8 | +import math |
8 | 9 |
|
9 | 10 | def reshape_latent_to(target_shape, latent, repeat_batch=True): |
10 | 11 | if latent.shape[1:] != target_shape[1:]: |
@@ -207,6 +208,47 @@ def execute(cls, samples, dim, index, amount) -> io.NodeOutput: |
207 | 208 | samples_out["samples"] = torch.narrow(s1, dim, index, amount) |
208 | 209 | return io.NodeOutput(samples_out) |
209 | 210 |
|
| 211 | +class LatentCutToBatch(io.ComfyNode): |
| 212 | + @classmethod |
| 213 | + def define_schema(cls): |
| 214 | + return io.Schema( |
| 215 | + node_id="LatentCutToBatch", |
| 216 | + category="latent/advanced", |
| 217 | + inputs=[ |
| 218 | + io.Latent.Input("samples"), |
| 219 | + io.Combo.Input("dim", options=["t", "x", "y"]), |
| 220 | + io.Int.Input("slice_size", default=1, min=1, max=nodes.MAX_RESOLUTION, step=1), |
| 221 | + ], |
| 222 | + outputs=[ |
| 223 | + io.Latent.Output(), |
| 224 | + ], |
| 225 | + ) |
| 226 | + |
| 227 | + @classmethod |
| 228 | + def execute(cls, samples, dim, slice_size) -> io.NodeOutput: |
| 229 | + samples_out = samples.copy() |
| 230 | + |
| 231 | + s1 = samples["samples"] |
| 232 | + |
| 233 | + if "x" in dim: |
| 234 | + dim = s1.ndim - 1 |
| 235 | + elif "y" in dim: |
| 236 | + dim = s1.ndim - 2 |
| 237 | + elif "t" in dim: |
| 238 | + dim = s1.ndim - 3 |
| 239 | + |
| 240 | + if dim < 2: |
| 241 | + return io.NodeOutput(samples) |
| 242 | + |
| 243 | + s = s1.movedim(dim, 1) |
| 244 | + if s.shape[1] < slice_size: |
| 245 | + slice_size = s.shape[1] |
| 246 | + elif s.shape[1] % slice_size != 0: |
| 247 | + s = s[:, :math.floor(s.shape[1] / slice_size) * slice_size] |
| 248 | + new_shape = [-1, slice_size] + list(s.shape[2:]) |
| 249 | + samples_out["samples"] = s.reshape(new_shape).movedim(1, dim) |
| 250 | + return io.NodeOutput(samples_out) |
| 251 | + |
210 | 252 | class LatentBatch(io.ComfyNode): |
211 | 253 | @classmethod |
212 | 254 | def define_schema(cls): |
@@ -435,6 +477,7 @@ async def get_node_list(self) -> list[type[io.ComfyNode]]: |
435 | 477 | LatentInterpolate, |
436 | 478 | LatentConcat, |
437 | 479 | LatentCut, |
| 480 | + LatentCutToBatch, |
438 | 481 | LatentBatch, |
439 | 482 | LatentBatchSeedBehavior, |
440 | 483 | LatentApplyOperation, |
|
0 commit comments