[Distributed] Add lanes to KV cache (#1174)
* [WIP][Distributed] Add lanes to KV cache

* Compatibility change

* Naming

* Remove setup_input_pos

* Add timer

* Remove mbs
kwen2501 authored Sep 23, 2024
1 parent dc832fb commit 2cf4016
Showing 3 changed files with 61 additions and 51 deletions.
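
The core idea, in a nutshell: instead of one KV cache per attention module, each attention module now owns a small pool of caches ("lanes"), and every forward call names the lane it should read and write. Distinct in-flight micro-batches get distinct lanes, so their cached keys and values never collide; once a micro-batch finishes decoding, its lane can be reused. Below is a minimal, self-contained sketch of the pattern -- class and method names are illustrative stand-ins, not the exact torchchat definitions in the diff that follows.

import torch
import torch.nn as nn


class ToyKVCache(nn.Module):
    """One lane: a pre-allocated k/v buffer indexed by position."""

    def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim):
        super().__init__()
        shape = (max_batch_size, n_heads, max_seq_length, head_dim)
        self.register_buffer("k_cache", torch.zeros(shape))
        self.register_buffer("v_cache", torch.zeros(shape))

    def update(self, input_pos, k, v):
        # Write the new keys/values at the given positions, return the full buffers.
        self.k_cache[:, :, input_pos] = k
        self.v_cache[:, :, input_pos] = v
        return self.k_cache, self.v_cache


class ToyAttention(nn.Module):
    """Holds one ToyKVCache per lane and picks a lane on each call."""

    def __init__(self):
        super().__init__()
        self.kv_cache = None

    def setup_cache(self, max_batch_size, max_seq_length, n_heads, head_dim, cache_lanes=1):
        # One KV cache per lane; lanes are reused once a micro-batch finishes decoding.
        self.kv_cache = nn.ModuleList(
            ToyKVCache(max_batch_size, max_seq_length, n_heads, head_dim)
            for _ in range(cache_lanes)
        )

    def update_kv(self, k, v, input_pos, cache_lane=0):
        # Each in-flight micro-batch touches only its own lane.
        return self.kv_cache[cache_lane].update(input_pos, k, v)

The real KVCache and Attention classes in torchchat/model.py follow the same shape, as the diff below shows: setup_cache builds an nn.ModuleList of KVCache modules and forward indexes it with cache_lane.
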
dist_run.py: 79 changes (47 additions & 32 deletions)
@@ -273,13 +273,11 @@ def main(args):
     pp_rank = pp_mesh.get_local_rank()
     tp_group = tp_mesh.get_group()
     pp_group = pp_mesh.get_group()
-    pp_group_size = pp_group.size()
-    tp_group_size = tp_group.size()
-    logger.info(f"{pp_group_size=}, {tp_group_size=}")
+    logger.info(f"{pp_degree=}, {tp_degree=}")

     # Convenience variables
     first_pp_rank = 0
-    last_pp_rank = pp_group_size - 1
+    last_pp_rank = pp_degree - 1

     # Assuming same number of GPUs per node
     device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
@@ -297,18 +295,22 @@ def main(args):
     if rank == 0:
         logger.info(f"Model: {model}")

-    mbs = 1 # number of micro-batches
-    mb_size = 4 # micro-batch size
-    batch_size = mbs * mb_size # total batch size
-
+    # Batch size. Since we push batches dynamically through the pipeline rather
+    # than chunking them, this is effectively micro-batch size in pipeline
+    # sense. Thus it is interchangeable with micro-batch size below.
+    batch_size = 4
     seqlen_prefill = 1024 # sequence length
     dim = 4096 # embedding dimension

     # Setup KV caches (after model distribution)
-    # TODO: the setting below only works for 1 micro-batch case. To support
-    # multiple micro-batches, we need the KV cache in the model to be aware of
-    # the number of micro-batches and the current micro-batch index.
-    model.setup_caches(mb_size, seqlen_prefill)
+    # The number of cache lanes is the same as the maximum number of
+    # micro-batches that can be "in flight" in parallel -- imagine each
+    # micro-batch takes 1 "pipeline lane," they need distinct KV cache spaces.
+    # When decoding is done for certain micro-batches, we can reuse the KV cache
+    # lanes.
+    # TODO: bump up the lane count
+    pipeline_lanes = 1
+    model.setup_caches(batch_size, seqlen_prefill, cache_lanes=pipeline_lanes)

     # Load weights
     logger.info(f"Loading weights for {pp_rank=} on {device=}")
@@ -317,7 +319,7 @@ def main(args):
     model.to(device)

     logger.info(
-        f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for stage {rank}{color.reset}"
+        f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
     )

     # info on stage size and params
@@ -330,17 +332,16 @@ def main(args):

     # Setup input position (input_pos) for prefill: a list of increasing integers from 0 to seqlen
     input_pos = torch.arange(seqlen_prefill, device=device)
-    model.setup_input_pos(input_pos)
     model.eval()

     # Helper function to get example inputs and outputs for the stages.
     def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
-        mb_ids = torch.randint(0, config.vocab_size, (mb_size, seqlen), device=device)
+        mb_ids = torch.randint(0, config.vocab_size, (batch_size, seqlen), device=device)
         activation = torch.rand(
-            mb_size, seqlen, dim, device=device, dtype=model_dtype
+            batch_size, seqlen, dim, device=device, dtype=model_dtype
         )
         logits = torch.rand(
-            mb_size, seqlen, config.vocab_size, device=device, dtype=model_dtype
+            batch_size, seqlen, config.vocab_size, device=device, dtype=model_dtype
         )
         example_inputs = (mb_ids if pp_rank == first_pp_rank else activation,)
         example_outputs = (logits if pp_rank == last_pp_rank else activation,)
@@ -358,8 +359,13 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
         output_args=example_outputs,
         group=pp_group,
     )
-    # create schedule
-    prefill_schedule = ScheduleGPipe(prefill_stage, mbs)
+
+    # Create schedule
+    # Number of micro-batches for the schedule is 1, because each step() call we
+    # only push 1 micro-batch into the pipeline. But we can continuously push
+    # new micro-batches into the pipeline as they arrive, achieving same
+    # pipelining effect.
+    prefiller = ScheduleGPipe(prefill_stage, 1)

     prompt = [
         "What is a computer?",
@@ -388,7 +394,6 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
     s = set(prompt_lengths)
     assert len(s) == 1, f"prompt_lengths should be the same, got {s}"

-    # with CUDATrackTime() as timer:
     # Need these global ids due to the API definition of dist.send and recv
     first_pp_rank_global_id = dist.get_global_rank(pp_group, first_pp_rank)
     last_pp_rank_global_id = dist.get_global_rank(pp_group, last_pp_rank)
@@ -401,14 +406,21 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
     num_tokens = 40

     # Prefill phase
-    # Run context input through pipeline, in 1 step
-    with torch.no_grad():
+    # Run context input through pipeline
+    # TODO: we need to pass `input_pos` and `cache_lane` to each stage.
+    lane = 0
+    kwargs = {"input_pos": input_pos, "cache_lane": lane}
+    with torch.no_grad(), CUDATrackTime() as timer:
         if pp_rank == first_pp_rank:
-            output = prefill_schedule.step(padded_sequence)
+            output = prefiller.step(padded_sequence, **kwargs)
         elif pp_rank == last_pp_rank:
-            output = prefill_schedule.step()
+            output = prefiller.step(**kwargs)
         else: # middle pp ranks
-            prefill_schedule.step()
+            prefiller.step(**kwargs)

+    logger.info(
+        f"{color.green}Prefilling time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
+    )
+
     # Decode the output -- first generated token
     if pp_rank == last_pp_rank:
@@ -430,7 +442,6 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
     # seqlen = 1 now
     seqlen_decode = 1
     input_pos = torch.tensor([prompt_lengths[0]], device=device)
-    model.setup_input_pos(input_pos)

     # Create decode stage
     logger.info(f"Creating pipeline stage for decode {pp_rank=}, {pp_degree=}")
@@ -445,11 +456,12 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
         group=pp_group,
     )
     # create schedule
-    decode_schedule = ScheduleGPipe(decode_stage, mbs)
+    decorder = ScheduleGPipe(decode_stage, 1)

     # Decoding
-    with torch.no_grad():
+    with torch.no_grad(), CUDATrackTime() as timer:
         for step in range(num_tokens - 1):
+            kwargs = {"input_pos": input_pos, "cache_lane": lane}
             # sendrecv between last and first ranks, only if:
             # first_pp_rank != last_pp_rank.
             if pp_rank == last_pp_rank and pp_rank != first_pp_rank:
@@ -467,11 +479,11 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:

             # Run data through pipeline
             if pp_rank == first_pp_rank:
-                output = decode_schedule.step(new_token)
+                output = decorder.step(new_token, **kwargs)
             elif pp_rank == last_pp_rank:
-                output = decode_schedule.step()
+                output = decorder.step(**kwargs)
             else: # middle pp ranks
-                decode_schedule.step()
+                decorder.step(**kwargs)

             # Decode the output
             if pp_rank == last_pp_rank:
@@ -491,7 +503,10 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
                ) # decode_results[i][0]

             input_pos += 1
-            model.setup_input_pos(input_pos)

+    logger.info(
+        f"{color.green}Decoding time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
+    )
+
     # Display the decoding results

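The commit keeps pipeline_lanes = 1 (see the TODO above), so only lane 0 is ever exercised. The sketch below shows how the same kwargs plumbing would be used once the lane count is bumped: each in-flight micro-batch is pinned to its own lane and pushes its own step() call through the decode schedule (decorder is the ScheduleGPipe object created above). This is a hypothetical extension for illustration, not code from the commit, and the loop here runs the calls sequentially for clarity; real overlap would require the schedule to interleave them.

def decode_step_multilane(decorder, new_tokens, input_positions,
                          pp_rank, first_pp_rank, last_pp_rank):
    # new_tokens[i] and input_positions[i] belong to the micro-batch that owns
    # cache lane i; each step() call carries that lane so the stages update
    # the right KV cache.
    outputs = []
    for lane, (tok, pos) in enumerate(zip(new_tokens, input_positions)):
        kwargs = {"input_pos": pos, "cache_lane": lane}
        if pp_rank == first_pp_rank:
            outputs.append(decorder.step(tok, **kwargs))
        elif pp_rank == last_pp_rank:
            outputs.append(decorder.step(**kwargs))
        else:  # middle pp ranks
            decorder.step(**kwargs)
    return outputs
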
torchchat/export.py: 4 changes (2 additions & 2 deletions)
@@ -152,9 +152,9 @@ def __init__(self, attention: Attention):
         self.wo = attention.wo

         max_batch_size, n_heads, max_seq_length, head_dim = (
-            attention.kv_cache.k_cache.shape
+            attention.kv_cache[0].k_cache.shape
         )
-        cache_dtype = attention.kv_cache.k_cache.dtype
+        cache_dtype = attention.kv_cache[0].k_cache.dtype
         self.kv_cache = CustomKVCache(
             max_batch_size, max_seq_length, n_heads, head_dim, cache_dtype
         )
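Why the [0]: attention.kv_cache is now an nn.ModuleList with one KVCache per lane, and every lane is allocated with the same shape and dtype, so the exporter only needs a representative lane to recover the cache geometry; lane 0 always exists because cache_lanes defaults to 1. A small sketch of that lookup, with attention standing in for the module passed to __init__ above:

def cache_geometry(attention):
    # Lane 0 is representative because all lanes are allocated identically
    # in Attention.setup_cache.
    k0 = attention.kv_cache[0].k_cache
    max_batch_size, n_heads, max_seq_length, head_dim = k0.shape
    return (max_batch_size, n_heads, max_seq_length, head_dim), k0.dtype
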
torchchat/model.py: 29 changes (12 additions & 17 deletions)
@@ -606,7 +606,7 @@ def __init__(self, config: TransformerArgs) -> None:
         self.max_batch_size = -1
         self.max_seq_length = -1

-    def setup_caches(self, max_batch_size, max_seq_length):
+    def setup_caches(self, max_batch_size, max_seq_length, cache_lanes: int = 1):
         if (
             self.max_seq_length >= max_seq_length
             and self.max_batch_size >= max_batch_size
@@ -620,7 +620,7 @@ def setup_caches(self, max_batch_size, max_seq_length):
             # parallelism may have been applied there and the `n_local_heads``
             # value being adjusted.
             b.attention.setup_cache(
-                max_batch_size, max_seq_length,
+                max_batch_size, max_seq_length, cache_lanes=cache_lanes
             )

         freqs_cis = precompute_freqs_cis(
@@ -653,22 +653,15 @@ def distribute(self, device_mesh: DeviceMesh):
            ColwiseParallel(output_layouts=Replicate()),
        )

-    # This is a temporary solution to pass input_pos to non-0 pipeline stages
-    # TODO: make `step()` function of dist.pipelining accept args for non-0 stages
-    def setup_input_pos(self, input_pos: Tensor) -> None:
-        self._input_pos = input_pos
-
-    def forward(self, x: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
+    def forward(self, x: Tensor, input_pos: Optional[Tensor] = None, cache_lane: int = 1) -> Tensor:
         assert self.freqs_cis is not None, "Caches must be initialized first"
-        # TODO: find a better way to pass input_pos to non-0 pipeline stages
-        input_pos = input_pos if input_pos is not None else self._input_pos
         mask = self.causal_mask[None, None, input_pos]
         freqs_cis = self.freqs_cis[input_pos]
         if self.tok_embeddings:
             x = self.tok_embeddings(x)

         for _, layer in self.layers.items():
-            x = layer(x, input_pos, freqs_cis, mask)
+            x = layer(x, input_pos, freqs_cis, mask, cache_lane=cache_lane)

         if self.norm:
             x = self.norm(x)
@@ -691,7 +684,7 @@ def distribute(self, device_mesh: DeviceMesh):
         self.feed_forward.distribute(device_mesh)

     def forward(
-        self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor
+        self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor, cache_lane: int = 0
     ) -> Tensor:
         h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
         out = h + self.feed_forward(self.ffn_norm(h))
@@ -723,15 +716,16 @@ def __init__(self, config: TransformerArgs):
         self.dim = config.dim
         self._register_load_state_dict_pre_hook(self.load_hook)

-    def setup_cache(self, max_batch_size, max_seq_length):
+    def setup_cache(self, max_batch_size, max_seq_length, cache_lanes: int = 1):
         n_local_heads = self.n_local_heads
         # If TP is enabled, the heads would be divided and assigned to different ranks
         if hasattr(self, "tp_degree"):
             n_local_heads = self.n_local_heads // self.tp_degree

-        self.kv_cache = KVCache(
-            max_batch_size, max_seq_length, n_local_heads, self.head_dim
-        )
+        self.kv_cache = nn.ModuleList([
+            KVCache(max_batch_size, max_seq_length, n_local_heads, self.head_dim)
+            for _ in range(cache_lanes)
+        ])

     def load_hook(self, state_dict, prefix, *args):
         # if prefix + "wq.weight" in state_dict:
@@ -784,6 +778,7 @@ def forward(
         freqs_cis: Tensor,
         mask: Tensor,
         input_pos: Optional[Tensor] = None,
+        cache_lane: int = 0,
     ) -> Tensor:
         bsz, seqlen, _ = x.shape

@@ -809,7 +804,7 @@ def forward(
         q, k, v = (x.transpose(1, 2) for x in (q, k, v))

         if self.kv_cache is not None:
-            k, v = self.kv_cache.update(input_pos, k, v)
+            k, v = self.kv_cache[cache_lane].update(input_pos, k, v)

         k = k.repeat_interleave(self.n_heads // self.n_local_heads, dim=1)
         v = v.repeat_interleave(self.n_heads // self.n_local_heads, dim=1)
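One design consequence worth spelling out: every lane is a full KV cache, so cache memory grows linearly with cache_lanes, multiplied by the number of layers, since each Attention module owns its own lane pool. A back-of-the-envelope sketch, assuming the usual (batch, heads, seq, head_dim) k/v buffer layout; the model numbers are illustrative only, not measured:

def kv_cache_bytes(cache_lanes, n_layers, max_batch_size, n_local_heads,
                   max_seq_length, head_dim, bytes_per_elem=2):
    # k and v buffers per lane, per layer.
    per_lane_per_layer = 2 * max_batch_size * n_local_heads * max_seq_length * head_dim
    return cache_lanes * n_layers * per_lane_per_layer * bytes_per_elem

# Example: 4 lanes, 32 layers, batch 4, 8 local KV heads, 1024-token window, fp16.
print(kv_cache_bytes(4, 32, 4, 8, 1024, 128) / 2**20, "MiB")  # -> 2048.0 MiB

This linear cost is why the new comment in dist_run.py notes that lanes can be reused once a micro-batch finishes decoding.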