Skip to content

Commit

Permalink
Fix streaming propagate + fixing bug in KV Cache (#148)
Browse files Browse the repository at this point in the history
* fix

* further fix potential bug in KVCache when wrapping around

* bump version

* update doc
  • Loading branch information
adefossez authored Oct 31, 2024
1 parent 2df3a38 commit 4a40e26
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 45 deletions.
2 changes: 1 addition & 1 deletion moshi/moshi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@
from . import models
from . import quantization

__version__ = "0.1.0"
__version__ = "0.1.1a1"
5 changes: 4 additions & 1 deletion moshi/moshi/models/lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,9 @@ def __init__(
dtype=dtype,
**kwargs_dep,
)
self.depformer.set_streaming_propagate(False)
# The Depformer follows its own streaming cycle, entirely contained within one time step,
# and should not follow the streaming over the time steps dimension.
self.depformer.set_streaming_detached(True)
dim = depformer_dim # we will directly apply the next linears to the output of the Depformer.

self.linears = nn.ModuleList(
Expand Down Expand Up @@ -465,6 +467,7 @@ def depformer_step(
depformer_tokens: list[torch.Tensor] = []
assert not lm_model.depformer.is_streaming
with lm_model.depformer.streaming(B):
assert lm_model.depformer.is_streaming
for cb_index in range(lm_model.dep_q):
input_ = prev_token[:, None, None]
logits = lm_model.forward_depformer(cb_index, input_, transformer_out)
Expand Down
56 changes: 26 additions & 30 deletions moshi/moshi/modules/streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,56 +33,52 @@ def reset(self) -> None:
class StreamingModule(abc.ABC, nn.Module, tp.Generic[State]):
"""Common API for streaming components.
Each streaming component has a streaming state, which is just a dict[str, Tensor].
By convention, the first dim of each tensor must be the batch size.
Don't use dots in the key names, as this would clash with submodules
(like in state_dict).
If `self._is_streaming` is True, the component should use and remember
the proper state inside `self._streaming_state`.
Each streaming component has a streaming state, `self._streaming_state`, which is None by default.
To set a streaming component in streaming state, use
with module.streaming():
...
This will automatically reset the streaming state when exiting the context manager.
This will automatically void the streaming state when exiting the context manager.
This also automatically propagates to all streaming child modules.
Some module might also implement the `StreamingModule.flush` method, although
this one is trickier, as all parents module must be StreamingModule and implement
it as well for it to work properly. See `StreamingSequential` after.
When the streaming state is set, modules should store whatever state they need in there.
"""

def __init__(self) -> None:
super().__init__()
self._streaming_state: State | None = None
self._streaming_propagate: bool = True
self._streaming_detached: bool = False

@property
def is_streaming(self):
return self._streaming_state is not None

def set_streaming_propagate(self, streaming_propagate: bool):
self._streaming_propagate = streaming_propagate
def set_streaming_detached(self, streaming_detached: bool):
"""If set to False, the default, this module and all submodules will switch to streaming mode
if a parent module is set to streaming mode.
If set to True, i.e. in detached mode, only a direct call to this module's `.streaming(...)` method
will set it into streaming mode, ignoring the changes from its parents.
This is useful when streaming over two different dimensions, e.g. for the RQ-Transformer
with the inner Depth Transformer working on the dimension of the codebooks."""
self._streaming_detached = streaming_detached

def _apply_named_streaming(self, fn: tp.Any):
def _handle_module(prefix: str, module: nn.Module, recurse: bool = True):
propagate = True
def _handle_module(prefix: str, module: nn.Module):
if isinstance(module, StreamingModule):
if module._streaming_propagate:
fn(prefix, module)
# If prefix is empty, we are the direct receiver of the streaming request,
# otherwise, we are inheriting from a parent and will stop if detached.
if module._streaming_detached and prefix != "":
return
fn(prefix, module)
for name, child in module.named_children():
if prefix:
new_prefix = prefix + "." + name
else:
propagate = False
if not recurse:
return
if propagate:
for name, child in module.named_children():
_handle_module(prefix + "." + name, child)

_handle_module("", self, recurse=False)
for name, child in self.named_children():
_handle_module(name, child)
new_prefix = name
_handle_module(new_prefix, child)

_handle_module("", self)

def _start_streaming(self, batch_size: int):
def _start_streaming(name: str, module: StreamingModule):
Expand Down
26 changes: 13 additions & 13 deletions moshi/moshi/modules/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,37 +242,37 @@ def reset(self):
def complete(self, k: torch.Tensor, v: torch.Tensor) -> KVCacheResult:
assert k.shape[:-1] == v.shape[:-1], (k.shape, v.shape)
B, H, T, D = k.shape
assert T > 0
indexes = torch.arange(T, device=self.end_offset.device, dtype=self.end_offset.dtype) + self.end_offset
indexes = indexes % self.capacity
self.cache[0].index_copy_(2, indexes, k)
self.cache[1].index_copy_(2, indexes, v)
self.end_offset.add_(T)

keys = self.cache[0]
values = self.cache[1]

indexes = torch.arange(
self.capacity, device=self.end_offset.device, dtype=torch.long
)
invalid = indexes >= self.end_offset

end_index = self.end_offset % self.capacity
# end_index corresponds to the actual index where the last value was written.
last_offset = self.end_offset + T - 1
end_index = last_offset % self.capacity
delta = indexes - end_index

# If last key is for step S, and capacity is C, last key was written at index S % C.
# then end_offset = S + 1, and end_index = (S + 1) % C.
# Then for index = (S % C), delta = -1, and the next code gives us:
# position(index) = (S + 1) - 1 = S, all good.
# Now the time step at end_offset is actually the oldest in the KVCache, e.g., its
# position should be (S - self.capacity + 1).
# The following code gives us:
# position(index + 1) = S + 1 + 0 - self.capacity.
# We know that if `index == end_index`, then we should output `last_offset`.
# If `index = end_index - 1` we should output `last_offset - 1`
# If `index = end_index - n` we should output `last_offset - n`
# Now, for `index == end_index + 1`, we actually have the oldest entry in the cache,
# so we should output `last_offset + 1 - self.capacity`

positions = torch.where(
delta <= 0,
self.end_offset + delta,
self.end_offset + delta - self.capacity,
last_offset + delta,
last_offset + delta - self.capacity,
)
self.end_offset.add_(T)
invalid = indexes >= self.end_offset
positions = torch.where(invalid, torch.full_like(positions, -1), positions)

return KVCacheResult(keys, values, positions)
Expand Down

0 comments on commit 4a40e26

Please sign in to comment.