Skip to content

Commit 894802b

Browse files
Add LatentCutToBatch node. (comfyanonymous#11411)
1 parent 28eaab6 commit 894802b

File tree

1 file changed

+43
-0
lines changed

1 file changed

+43
-0
lines changed

comfy_extras/nodes_latent.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing_extensions import override
66
from comfy_api.latest import ComfyExtension, io
77
import logging
8+
import math
89

910
def reshape_latent_to(target_shape, latent, repeat_batch=True):
1011
if latent.shape[1:] != target_shape[1:]:
@@ -207,6 +208,47 @@ def execute(cls, samples, dim, index, amount) -> io.NodeOutput:
207208
samples_out["samples"] = torch.narrow(s1, dim, index, amount)
208209
return io.NodeOutput(samples_out)
209210

211+
class LatentCutToBatch(io.ComfyNode):
212+
@classmethod
213+
def define_schema(cls):
214+
return io.Schema(
215+
node_id="LatentCutToBatch",
216+
category="latent/advanced",
217+
inputs=[
218+
io.Latent.Input("samples"),
219+
io.Combo.Input("dim", options=["t", "x", "y"]),
220+
io.Int.Input("slice_size", default=1, min=1, max=nodes.MAX_RESOLUTION, step=1),
221+
],
222+
outputs=[
223+
io.Latent.Output(),
224+
],
225+
)
226+
227+
@classmethod
228+
def execute(cls, samples, dim, slice_size) -> io.NodeOutput:
229+
samples_out = samples.copy()
230+
231+
s1 = samples["samples"]
232+
233+
if "x" in dim:
234+
dim = s1.ndim - 1
235+
elif "y" in dim:
236+
dim = s1.ndim - 2
237+
elif "t" in dim:
238+
dim = s1.ndim - 3
239+
240+
if dim < 2:
241+
return io.NodeOutput(samples)
242+
243+
s = s1.movedim(dim, 1)
244+
if s.shape[1] < slice_size:
245+
slice_size = s.shape[1]
246+
elif s.shape[1] % slice_size != 0:
247+
s = s[:, :math.floor(s.shape[1] / slice_size) * slice_size]
248+
new_shape = [-1, slice_size] + list(s.shape[2:])
249+
samples_out["samples"] = s.reshape(new_shape).movedim(1, dim)
250+
return io.NodeOutput(samples_out)
251+
210252
class LatentBatch(io.ComfyNode):
211253
@classmethod
212254
def define_schema(cls):
@@ -435,6 +477,7 @@ async def get_node_list(self) -> list[type[io.ComfyNode]]:
435477
LatentInterpolate,
436478
LatentConcat,
437479
LatentCut,
480+
LatentCutToBatch,
438481
LatentBatch,
439482
LatentBatchSeedBehavior,
440483
LatentApplyOperation,

0 commit comments

Comments
 (0)