[Distributed] Fix cache lane #1194

Merged · 2 commits · Sep 24, 2024
18 changes: 12 additions & 6 deletions torchchat/export.py
@@ -155,16 +155,21 @@ def __init__(self, attention: Attention):
             attention.kv_cache[0].k_cache.shape
         )
         cache_dtype = attention.kv_cache[0].k_cache.dtype
-        self.kv_cache = CustomKVCache(
-            max_batch_size, max_seq_length, n_heads, head_dim, cache_dtype
-        )
+        # The `Attention` module being replaced can have multiple KV caches
+        # (denoted by `cache_lanes`). Thus we follow the same setup format
+        # as in `Attention.setup_cache`.
+        cache_lanes = len(attention.kv_cache)
+        self.kv_cache = nn.ModuleList([
+            CustomKVCache(max_batch_size, max_seq_length, n_heads, head_dim, cache_dtype)
+            for _ in range(cache_lanes)
+        ])
 
         self.n_heads = attention.n_heads
         self.head_dim = attention.head_dim
         self.n_local_heads = attention.n_local_heads
         self.dim = attention.dim
 
-    def forward(self, x, freqs_cis, mask, input_pos=None):
+    def forward(self, x, freqs_cis, mask, input_pos=None, cache_lane: int = 0):
         bsz, seqlen, _ = x.shape
 
         q = self.wq(x)
@@ -181,12 +186,13 @@ def forward(self, x, freqs_cis, mask, input_pos=None):
 
         # KV cache should always be enabled
         assert self.kv_cache is not None
+        kv_cache = self.kv_cache[cache_lane]
         output = torch.ops.llama.sdpa_with_kv_cache(
             q,
             k,
             v,
-            self.kv_cache.k_cache,
-            self.kv_cache.v_cache,
+            kv_cache.k_cache,
+            kv_cache.v_cache,
             input_pos[-1].item(),
             seqlen,
         )
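
For context, here is a minimal runnable sketch of the per-lane cache pattern this hunk adopts. A toy cache stands in for CustomKVCache (the class names, shapes, and update logic below are illustrative, not torchchat or ExecuTorch APIs); the lane is selected by a plain index at forward time, the same way the patched forward above does it.

import torch
import torch.nn as nn


class ToyKVCache(nn.Module):
    """Stand-in for CustomKVCache: fixed-size k/v buffers updated in place."""

    def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype):
        super().__init__()
        shape = (max_batch_size, max_seq_length, n_heads, head_dim)
        self.register_buffer("k_cache", torch.zeros(shape, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(shape, dtype=dtype))

    def update(self, input_pos, k, v):
        # Write the new k/v entries at the given positions.
        self.k_cache[:, input_pos] = k
        self.v_cache[:, input_pos] = v
        return self.k_cache, self.v_cache


class PerLaneCache(nn.Module):
    """One independent cache per lane; the caller picks the lane per step."""

    def __init__(self, cache_lanes=2, **cache_kwargs):
        super().__init__()
        self.kv_cache = nn.ModuleList(
            ToyKVCache(**cache_kwargs) for _ in range(cache_lanes)
        )

    def forward(self, k, v, input_pos, cache_lane: int = 0):
        # Same selection pattern as the patched forward above.
        kv_cache = self.kv_cache[cache_lane]
        return kv_cache.update(input_pos, k, v)


caches = PerLaneCache(
    cache_lanes=2, max_batch_size=1, max_seq_length=16,
    n_heads=4, head_dim=8, dtype=torch.float32,
)
k = v = torch.randn(1, 1, 4, 8)
caches(k, v, input_pos=torch.tensor([0]), cache_lane=1)  # touches lane 1 only

With a single lane this degenerates to the old single-cache behaviour; with several lanes, concurrent decode streams (presumably the distributed path that motivates this PR) can each write to their own cache without clobbering one another.
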
6 changes: 4 additions & 2 deletions torchchat/model.py
@@ -653,7 +653,7 @@ def distribute(self, device_mesh: DeviceMesh):
             ColwiseParallel(output_layouts=Replicate()),
         )
 
-    def forward(self, x: Tensor, input_pos: Optional[Tensor] = None, cache_lane: int = 1) -> Tensor:
+    def forward(self, x: Tensor, input_pos: Optional[Tensor] = None, cache_lane: int = 0) -> Tensor:
         assert self.freqs_cis is not None, "Caches must be initialized first"
         mask = self.causal_mask[None, None, input_pos]
         freqs_cis = self.freqs_cis[input_pos]
@@ -686,7 +686,9 @@ def distribute(self, device_mesh: DeviceMesh):
     def forward(
         self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor, cache_lane: int = 0
     ) -> Tensor:
-        h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
+        h = x + self.attention(
+            self.attention_norm(x), freqs_cis, mask, input_pos, cache_lane=cache_lane
+        )
         out = h + self.feed_forward(self.ffn_norm(h))
         return out
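
And a condensed sketch of the plumbing this second set of hunks fixes, using toy modules rather than the real torchchat Transformer and TransformerBlock classes: the lane now defaults to 0 and is forwarded explicitly from the block into attention instead of being dropped.

import torch
import torch.nn as nn


class AttnSketch(nn.Module):
    def __init__(self, dim, lanes=2):
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        # One cache slot per lane (toy stand-in for the per-lane KV caches).
        self.register_buffer("cache", torch.zeros(lanes, dim))

    def forward(self, x, cache_lane: int = 0):
        self.cache[cache_lane] = x.detach().mean(dim=(0, 1))  # toy "cache update"
        return self.proj(x)


class BlockSketch(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.attention = AttnSketch(dim)
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, cache_lane: int = 0):
        # The fix: forward the lane instead of silently using the default.
        return x + self.attention(self.norm(x), cache_lane=cache_lane)


x = torch.randn(1, 4, 8)
block = BlockSketch(8)
block(x)                # lane 0, the corrected default
block(x, cache_lane=1)  # a distributed caller selecting another lane
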