Description
Hi, I assumed compress_block_size was the size of a whole compression block and compress_block_overlap_len was how much adjacent blocks overlap with each other. After inspecting the code, it turns out the whole compression block size is compress_window_size, which is the sum of compress_block_size and compress_block_overlap_len. Later in the code:
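```python
import torch.nn as nn
from einops.layers.torch import Rearrange

split_compress_window_fn = nn.Sequential(
    Rearrange('b h n d -> (b h) d 1 n'),
    # zero-pad compress_block_overlap_len positions on the left of the sequence axis
    nn.ZeroPad2d((compress_block_overlap_len, 0, 0, 0)),
    # windows of compress_window_size tokens, advancing by compress_block_size
    nn.Unfold(kernel_size = (1, compress_window_size), stride = (1, compress_block_size)),
    Rearrange('(b h) (d n) w -> b h w n d', d = dim_head, h = kv_heads)
)
```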
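To make the layout concrete, here is a minimal pure-Python sketch (no torch) of the index arithmetic performed by the ZeroPad2d + Unfold pair above. It assumes nn.Unfold with no padding and dilation 1, which keeps floor((L - kernel) / stride) + 1 windows along an axis of length L. The helper name compress_windows is hypothetical and does not exist in the library.

```python
def compress_windows(seq_len, compress_block_size, compress_block_overlap_len):
    """Hypothetical helper: return the (start, end) token range of each window
    produced by left zero-padding followed by a strided unfold, expressed in
    original (unpadded) coordinates. A negative start means the window
    begins inside the zero padding."""
    compress_window_size = compress_block_size + compress_block_overlap_len
    padded_len = seq_len + compress_block_overlap_len
    # Unfold only keeps windows that fit entirely inside the padded sequence
    num_windows = max(0, (padded_len - compress_window_size) // compress_block_size + 1)
    windows = []
    for w in range(num_windows):
        # window w starts at w * stride in padded coordinates; shift back by the pad
        start = w * compress_block_size - compress_block_overlap_len
        windows.append((start, start + compress_window_size))
    return windows


print(compress_windows(16, 4, 2))  # [(-2, 4), (2, 8), (6, 12), (10, 16)]
```

Each window covers one compress_block_size block plus compress_block_overlap_len extra tokens to its left, so compress_block_size behaves as the stride while the real window length is compress_window_size.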
So the stride of the unfold is compress_block_size, not the window length. I think this may confuse users, and it would be clearer to let compress_block_size denote the whole compression block and to add a separate argument, compress_block_sliding_stride, for the step between blocks. This would have two advantages:
- it sticks to the notation of the original paper
- the concept mirrors the sliding window (kernel size and stride) of a CNN, which is easier to understand
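A minimal sketch of how the proposed parameterization could map onto the existing arguments, assuming compress_block_size means the full window (the paper's l) and compress_block_sliding_stride is the hop between windows (the paper's d). ProposedCompressConfig and to_current_args are hypothetical names for illustration only, and the requirement 0 < stride <= window is an assumed constraint (a larger stride would skip tokens between windows).

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ProposedCompressConfig:
    """Hypothetical config in paper-style notation: compress_block_size is the
    full window length (l) and compress_block_sliding_stride is the hop (d),
    analogous to kernel_size and stride of a 1D convolution."""
    compress_block_size: int
    compress_block_sliding_stride: int

    def __post_init__(self):
        # assumed constraint: windows must not leave gaps between them
        if not 0 < self.compress_block_sliding_stride <= self.compress_block_size:
            raise ValueError("stride must satisfy 0 < stride <= compress_block_size")

    def to_current_args(self):
        """Map onto the current arguments: the stride plays the role of today's
        compress_block_size, and the rest of the window is the overlap."""
        return {
            "compress_block_size": self.compress_block_sliding_stride,
            "compress_block_overlap_len": self.compress_block_size - self.compress_block_sliding_stride,
        }


print(ProposedCompressConfig(32, 16).to_current_args())
```

With this split, users reason only about window length and stride, and the overlap becomes a derived quantity rather than an input.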
What do you think?
Another issue: the original paper sets the compression block size to l = 32 and the sliding stride to d = 16. Translated to this implementation, that means compress_block_size = 16 and compress_block_overlap_len = 16, but this triggers the assertion on line 286, assert compress_block_overlap_len < compress_block_size, since 16 < 16 does not hold.
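The arithmetic can be checked with a short sketch. It assumes the mapping described in this issue (compress_block_size = d, compress_block_overlap_len = l - d); the function names are hypothetical, and passes_current_assertion only mirrors the condition quoted from line 286.

```python
def current_args_from_paper(l, d):
    """Hypothetical helper: translate paper notation (block length l,
    sliding stride d) into this implementation's arguments."""
    compress_block_size = d
    compress_block_overlap_len = l - d
    return compress_block_size, compress_block_overlap_len


def passes_current_assertion(compress_block_size, compress_block_overlap_len):
    # mirrors: assert compress_block_overlap_len < compress_block_size
    return compress_block_overlap_len < compress_block_size


block, overlap = current_args_from_paper(l=32, d=16)
print(block, overlap)                            # 16 16
print(passes_current_assertion(block, overlap))  # False
```

Under this mapping the assertion is equivalent to l - d < d, i.e. l < 2d, so every setting whose stride is at most half the window is rejected; the paper's choice d = l / 2 sits exactly on that boundary.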