-
Notifications
You must be signed in to change notification settings - Fork 103
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
simple context windows with freenoise shuffling
- Loading branch information
Showing
4 changed files
with
409 additions
and
111 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,184 @@ | ||
import numpy as np | ||
from typing import Callable, Optional, List | ||
|
||
|
||
def ordered_halving(val): | ||
bin_str = f"{val:064b}" | ||
bin_flip = bin_str[::-1] | ||
as_int = int(bin_flip, 2) | ||
|
||
return as_int / (1 << 64) | ||
|
||
def does_window_roll_over(window: list[int], num_frames: int) -> tuple[bool, int]: | ||
prev_val = -1 | ||
for i, val in enumerate(window): | ||
val = val % num_frames | ||
if val < prev_val: | ||
return True, i | ||
prev_val = val | ||
return False, -1 | ||
|
||
def shift_window_to_start(window: list[int], num_frames: int): | ||
start_val = window[0] | ||
for i in range(len(window)): | ||
# 1) subtract each element by start_val to move vals relative to the start of all frames | ||
# 2) add num_frames and take modulus to get adjusted vals | ||
window[i] = ((window[i] - start_val) + num_frames) % num_frames | ||
|
||
def shift_window_to_end(window: list[int], num_frames: int): | ||
# 1) shift window to start | ||
shift_window_to_start(window, num_frames) | ||
end_val = window[-1] | ||
end_delta = num_frames - end_val - 1 | ||
for i in range(len(window)): | ||
# 2) add end_delta to each val to slide windows to end | ||
window[i] = window[i] + end_delta | ||
|
||
def get_missing_indexes(windows: list[list[int]], num_frames: int) -> list[int]: | ||
all_indexes = list(range(num_frames)) | ||
for w in windows: | ||
for val in w: | ||
try: | ||
all_indexes.remove(val) | ||
except ValueError: | ||
pass | ||
return all_indexes | ||
|
||
def uniform_looped( | ||
step: int = ..., | ||
num_steps: Optional[int] = None, | ||
num_frames: int = ..., | ||
context_size: Optional[int] = None, | ||
context_stride: int = 3, | ||
context_overlap: int = 4, | ||
closed_loop: bool = True, | ||
): | ||
if num_frames <= context_size: | ||
yield list(range(num_frames)) | ||
return | ||
|
||
context_stride = min(context_stride, int(np.ceil(np.log2(num_frames / context_size))) + 1) | ||
|
||
for context_step in 1 << np.arange(context_stride): | ||
pad = int(round(num_frames * ordered_halving(step))) | ||
for j in range( | ||
int(ordered_halving(step) * context_step) + pad, | ||
num_frames + pad + (0 if closed_loop else -context_overlap), | ||
(context_size * context_step - context_overlap), | ||
): | ||
yield [e % num_frames for e in range(j, j + context_size * context_step, context_step)] | ||
|
||
#from AnimateDiff-Evolved by Kosinkadink (https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved) | ||
def uniform_standard( | ||
step: int = ..., | ||
num_steps: Optional[int] = None, | ||
num_frames: int = ..., | ||
context_size: Optional[int] = None, | ||
context_stride: int = 3, | ||
context_overlap: int = 4, | ||
closed_loop: bool = True, | ||
): | ||
windows = [] | ||
if num_frames <= context_size: | ||
windows.append(list(range(num_frames))) | ||
return windows | ||
|
||
context_stride = min(context_stride, int(np.ceil(np.log2(num_frames / context_size))) + 1) | ||
|
||
for context_step in 1 << np.arange(context_stride): | ||
pad = int(round(num_frames * ordered_halving(step))) | ||
for j in range( | ||
int(ordered_halving(step) * context_step) + pad, | ||
num_frames + pad + (0 if closed_loop else -context_overlap), | ||
(context_size * context_step - context_overlap), | ||
): | ||
windows.append([e % num_frames for e in range(j, j + context_size * context_step, context_step)]) | ||
|
||
# now that windows are created, shift any windows that loop, and delete duplicate windows | ||
delete_idxs = [] | ||
win_i = 0 | ||
while win_i < len(windows): | ||
# if window is rolls over itself, need to shift it | ||
is_roll, roll_idx = does_window_roll_over(windows[win_i], num_frames) | ||
if is_roll: | ||
roll_val = windows[win_i][roll_idx] # roll_val might not be 0 for windows of higher strides | ||
shift_window_to_end(windows[win_i], num_frames=num_frames) | ||
# check if next window (cyclical) is missing roll_val | ||
if roll_val not in windows[(win_i+1) % len(windows)]: | ||
# need to insert new window here - just insert window starting at roll_val | ||
windows.insert(win_i+1, list(range(roll_val, roll_val + context_size))) | ||
# delete window if it's not unique | ||
for pre_i in range(0, win_i): | ||
if windows[win_i] == windows[pre_i]: | ||
delete_idxs.append(win_i) | ||
break | ||
win_i += 1 | ||
|
||
# reverse delete_idxs so that they will be deleted in an order that doesn't break idx correlation | ||
delete_idxs.reverse() | ||
for i in delete_idxs: | ||
windows.pop(i) | ||
return windows | ||
|
||
def static_standard( | ||
step: int = ..., | ||
num_steps: Optional[int] = None, | ||
num_frames: int = ..., | ||
context_size: Optional[int] = None, | ||
context_stride: int = 3, | ||
context_overlap: int = 4, | ||
closed_loop: bool = True, | ||
): | ||
windows = [] | ||
if num_frames <= context_size: | ||
windows.append(list(range(num_frames))) | ||
return windows | ||
# always return the same set of windows | ||
delta = context_size - context_overlap | ||
for start_idx in range(0, num_frames, delta): | ||
# if past the end of frames, move start_idx back to allow same context_length | ||
ending = start_idx + context_size | ||
if ending >= num_frames: | ||
final_delta = ending - num_frames | ||
final_start_idx = start_idx - final_delta | ||
windows.append(list(range(final_start_idx, final_start_idx + context_size))) | ||
break | ||
windows.append(list(range(start_idx, start_idx + context_size))) | ||
return windows | ||
|
||
def get_context_scheduler(name: str) -> Callable: | ||
if name == "uniform_looped": | ||
return uniform_looped | ||
elif name == "uniform_standard": | ||
return uniform_standard | ||
elif name == "static_standard": | ||
return static_standard | ||
else: | ||
raise ValueError(f"Unknown context_overlap policy {name}") | ||
|
||
|
||
def get_total_steps( | ||
scheduler, | ||
timesteps: List[int], | ||
num_steps: Optional[int] = None, | ||
num_frames: int = ..., | ||
context_size: Optional[int] = None, | ||
context_stride: int = 3, | ||
context_overlap: int = 4, | ||
closed_loop: bool = True, | ||
): | ||
return sum( | ||
len( | ||
list( | ||
scheduler( | ||
i, | ||
num_steps, | ||
num_frames, | ||
context_size, | ||
context_stride, | ||
context_overlap, | ||
) | ||
) | ||
) | ||
for i in range(len(timesteps)) | ||
) |
Oops, something went wrong.