[Distributed] Add lanes to KV cache #1174

Merged: 6 commits, Sep 23, 2024
79 changes: 47 additions & 32 deletions dist_run.py
@@ -273,13 +273,11 @@ def main(args):
pp_rank = pp_mesh.get_local_rank()
tp_group = tp_mesh.get_group()
pp_group = pp_mesh.get_group()
pp_group_size = pp_group.size()
tp_group_size = tp_group.size()
logger.info(f"{pp_group_size=}, {tp_group_size=}")
logger.info(f"{pp_degree=}, {tp_degree=}")

# Convenience variables
first_pp_rank = 0
last_pp_rank = pp_group_size - 1
last_pp_rank = pp_degree - 1

# Assuming same number of GPUs per node
device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
@@ -297,18 +295,22 @@ def main(args):
if rank == 0:
logger.info(f"Model: {model}")

mbs = 1 # number of micro-batches
mb_size = 4 # micro-batch size
batch_size = mbs * mb_size # total batch size

# Batch size. Since we push batches dynamically through the pipeline rather
# than chunking them, this is effectively micro-batch size in pipeline
# sense. Thus it is interchangeable with micro-batch size below.
batch_size = 4
seqlen_prefill = 1024 # sequence length
dim = 4096 # embedding dimension

# Setup KV caches (after model distribution)
# TODO: the setting below only works for 1 micro-batch case. To support
# multiple micro-batches, we need the KV cache in the model to be aware of
# the number of micro-batches and the current micro-batch index.
model.setup_caches(mb_size, seqlen_prefill)
# The number of cache lanes is the same as the maximum number of
# micro-batches that can be "in flight" in parallel -- imagine each
# micro-batch takes 1 "pipeline lane," they need distinct KV cache spaces.
# When decoding is done for certain micro-batches, we can reuse the KV cache
# lanes.
# TODO: bump up the lane count
pipeline_lanes = 1
model.setup_caches(batch_size, seqlen_prefill, cache_lanes=pipeline_lanes)

# Load weights
logger.info(f"Loading weights for {pp_rank=} on {device=}")
@@ -317,7 +319,7 @@ def main(args):
model.to(device)

logger.info(
f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for stage {rank}{color.reset}"
f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
)

# info on stage size and params
@@ -330,17 +332,16 @@ def main(args):

# Setup input position (input_pos) for prefill: a list of increasing integers from 0 to seqlen
input_pos = torch.arange(seqlen_prefill, device=device)
model.setup_input_pos(input_pos)
model.eval()

# Helper function to get example inputs and outputs for the stages.
def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
mb_ids = torch.randint(0, config.vocab_size, (mb_size, seqlen), device=device)
mb_ids = torch.randint(0, config.vocab_size, (batch_size, seqlen), device=device)
activation = torch.rand(
mb_size, seqlen, dim, device=device, dtype=model_dtype
batch_size, seqlen, dim, device=device, dtype=model_dtype
)
logits = torch.rand(
mb_size, seqlen, config.vocab_size, device=device, dtype=model_dtype
batch_size, seqlen, config.vocab_size, device=device, dtype=model_dtype
)
example_inputs = (mb_ids if pp_rank == first_pp_rank else activation,)
example_outputs = (logits if pp_rank == last_pp_rank else activation,)
@@ -358,8 +359,13 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
output_args=example_outputs,
group=pp_group,
)
# create schedule
prefill_schedule = ScheduleGPipe(prefill_stage, mbs)

# Create schedule
# Number of micro-batches for the schedule is 1, because each step() call we
# only push 1 micro-batch into the pipeline. But we can continuously push
# new micro-batches into the pipeline as they arrive, achieving same
# pipelining effect.
prefiller = ScheduleGPipe(prefill_stage, 1)

prompt = [
"What is a computer?",
@@ -388,7 +394,6 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
s = set(prompt_lengths)
assert len(s) == 1, f"prompt_lengths should be the same, got {s}"

# with CUDATrackTime() as timer:
# Need these global ids due to the API definition of dist.send and recv
first_pp_rank_global_id = dist.get_global_rank(pp_group, first_pp_rank)
last_pp_rank_global_id = dist.get_global_rank(pp_group, last_pp_rank)
@@ -401,14 +406,21 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
num_tokens = 40

# Prefill phase
# Run context input through pipeline, in 1 step
with torch.no_grad():
# Run context input through pipeline
# TODO: we need to pass `input_pos` and `cache_lane` to each stage.
lane = 0
kwargs = {"input_pos": input_pos, "cache_lane": lane}
with torch.no_grad(), CUDATrackTime() as timer:
if pp_rank == first_pp_rank:
output = prefill_schedule.step(padded_sequence)
output = prefiller.step(padded_sequence, **kwargs)
elif pp_rank == last_pp_rank:
output = prefill_schedule.step()
output = prefiller.step(**kwargs)
else: # middle pp ranks
prefill_schedule.step()
prefiller.step(**kwargs)

logger.info(
f"{color.green}Prefilling time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
)

# Decode the output -- first generated token
if pp_rank == last_pp_rank:
@@ -430,7 +442,6 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
# seqlen = 1 now
seqlen_decode = 1
input_pos = torch.tensor([prompt_lengths[0]], device=device)
model.setup_input_pos(input_pos)

# Create decode stage
logger.info(f"Creating pipeline stage for decode {pp_rank=}, {pp_degree=}")
@@ -445,11 +456,12 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
group=pp_group,
)
# create schedule
decode_schedule = ScheduleGPipe(decode_stage, mbs)
decorder = ScheduleGPipe(decode_stage, 1)
[Review comment, Contributor]: syntax error - this should be 'decoder' and not 'decorder'.

# Decoding
with torch.no_grad():
with torch.no_grad(), CUDATrackTime() as timer:
for step in range(num_tokens - 1):
kwargs = {"input_pos": input_pos, "cache_lane": lane}
# sendrecv between last and first ranks, only if:
# first_pp_rank != last_pp_rank.
if pp_rank == last_pp_rank and pp_rank != first_pp_rank:
Expand All @@ -467,11 +479,11 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:

# Run data through pipeline
if pp_rank == first_pp_rank:
output = decode_schedule.step(new_token)
output = decorder.step(new_token, **kwargs)
[Review comment, Contributor]: same, syntax error - this should be 'decoder' and not 'decorder'.

elif pp_rank == last_pp_rank:
output = decode_schedule.step()
output = decorder.step(**kwargs)
[Review comment, @lessw2020 (Contributor), Sep 23, 2024]: same, syntax error - this should be 'decoder' and not 'decorder'.

else: # middle pp ranks
decode_schedule.step()
decorder.step(**kwargs)
[Review comment, Contributor]: last one, syntax error - this should be 'decoder' and not 'decorder'.

# Decode the output
if pp_rank == last_pp_rank:
@@ -491,7 +503,10 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
) # decode_results[i][0]

input_pos += 1
model.setup_input_pos(input_pos)

logger.info(
f"{color.green}Decoding time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
)

# Display the decoding results
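
As a rough single-process illustration of the decode loop above (toy stand-in for the pipeline, hypothetical model_step function, no send/recv or pipeline parallelism): each step feeds the previously generated token back in, keeps using the same cache lane, and advances input_pos by one.

import torch

def model_step(token, input_pos, cache_lane=0):
    # Hypothetical stand-in for one full pipeline pass; returns random logits.
    vocab_size = 32
    return torch.randn(token.size(0), token.size(1), vocab_size)

lane = 0
prompt_length = 5
input_pos = torch.tensor([prompt_length])
new_token = torch.tensor([[3]])            # first token produced by prefill
generated = []

for _ in range(8):                         # num_tokens - 1 in the real loop
    logits = model_step(new_token, input_pos, cache_lane=lane)
    new_token = torch.argmax(logits[:, -1], dim=-1, keepdim=True)  # greedy pick
    generated.append(new_token.item())
    input_pos += 1                         # one position per decoded token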

4 changes: 2 additions & 2 deletions torchchat/export.py
@@ -152,9 +152,9 @@ def __init__(self, attention: Attention):
self.wo = attention.wo

max_batch_size, n_heads, max_seq_length, head_dim = (
attention.kv_cache.k_cache.shape
attention.kv_cache[0].k_cache.shape
)
cache_dtype = attention.kv_cache.k_cache.dtype
cache_dtype = attention.kv_cache[0].k_cache.dtype
self.kv_cache = CustomKVCache(
max_batch_size, max_seq_length, n_heads, head_dim, cache_dtype
)
29 changes: 12 additions & 17 deletions torchchat/model.py
@@ -606,7 +606,7 @@ def __init__(self, config: TransformerArgs) -> None:
self.max_batch_size = -1
self.max_seq_length = -1

def setup_caches(self, max_batch_size, max_seq_length):
def setup_caches(self, max_batch_size, max_seq_length, cache_lanes: int = 1):
if (
self.max_seq_length >= max_seq_length
and self.max_batch_size >= max_batch_size
@@ -620,7 +620,7 @@ def setup_caches(self, max_batch_size, max_seq_length):
# parallelism may have been applied there and the `n_local_heads``
# value being adjusted.
b.attention.setup_cache(
max_batch_size, max_seq_length,
max_batch_size, max_seq_length, cache_lanes=cache_lanes
)

freqs_cis = precompute_freqs_cis(
@@ -653,22 +653,15 @@ def distribute(self, device_mesh: DeviceMesh):
ColwiseParallel(output_layouts=Replicate()),
)

# This is a temporary solution to pass input_pos to non-0 pipeline stages
# TODO: make `step()` function of dist.pipelining accept args for non-0 stages
def setup_input_pos(self, input_pos: Tensor) -> None:
self._input_pos = input_pos

def forward(self, x: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
def forward(self, x: Tensor, input_pos: Optional[Tensor] = None, cache_lane: int = 1) -> Tensor:
assert self.freqs_cis is not None, "Caches must be initialized first"
# TODO: find a better way to pass input_pos to non-0 pipeline stages
input_pos = input_pos if input_pos is not None else self._input_pos
mask = self.causal_mask[None, None, input_pos]
freqs_cis = self.freqs_cis[input_pos]
if self.tok_embeddings:
x = self.tok_embeddings(x)

for _, layer in self.layers.items():
x = layer(x, input_pos, freqs_cis, mask)
x = layer(x, input_pos, freqs_cis, mask, cache_lane=cache_lane)

if self.norm:
x = self.norm(x)
@@ -691,7 +684,7 @@ def distribute(self, device_mesh: DeviceMesh):
self.feed_forward.distribute(device_mesh)

def forward(
self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor
self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor, cache_lane: int = 0
) -> Tensor:
h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
out = h + self.feed_forward(self.ffn_norm(h))
@@ -723,15 +716,16 @@ def __init__(self, config: TransformerArgs):
self.dim = config.dim
self._register_load_state_dict_pre_hook(self.load_hook)

def setup_cache(self, max_batch_size, max_seq_length):
def setup_cache(self, max_batch_size, max_seq_length, cache_lanes: int = 1):
n_local_heads = self.n_local_heads
# If TP is enabled, the heads would be divided and assigned to different ranks
if hasattr(self, "tp_degree"):
n_local_heads = self.n_local_heads // self.tp_degree

self.kv_cache = KVCache(
max_batch_size, max_seq_length, n_local_heads, self.head_dim
)
self.kv_cache = nn.ModuleList([
KVCache(max_batch_size, max_seq_length, n_local_heads, self.head_dim)
for _ in range(cache_lanes)
])

def load_hook(self, state_dict, prefix, *args):
# if prefix + "wq.weight" in state_dict:
@@ -784,6 +778,7 @@ def forward(
freqs_cis: Tensor,
mask: Tensor,
input_pos: Optional[Tensor] = None,
cache_lane: int = 0,
) -> Tensor:
bsz, seqlen, _ = x.shape

@@ -809,7 +804,7 @@ def forward(
q, k, v = (x.transpose(1, 2) for x in (q, k, v))

if self.kv_cache is not None:
k, v = self.kv_cache.update(input_pos, k, v)
k, v = self.kv_cache[cache_lane].update(input_pos, k, v)

k = k.repeat_interleave(self.n_heads // self.n_local_heads, dim=1)
v = v.repeat_interleave(self.n_heads // self.n_local_heads, dim=1)
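
For illustration, a self-contained sketch (simplified shapes and names, not the torchchat implementation) of the lane-indexed cache pattern this diff introduces: setup builds one cache per lane in an nn.ModuleList, and attention reads and writes only the lane selected by cache_lane.

import torch
import torch.nn as nn

class TinyKVCache(nn.Module):
    # Simplified stand-in for torchchat's KVCache: fixed-size k/v buffers.
    def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim):
        super().__init__()
        shape = (max_batch_size, n_heads, max_seq_length, head_dim)
        self.register_buffer("k_cache", torch.zeros(shape))
        self.register_buffer("v_cache", torch.zeros(shape))

    def update(self, input_pos, k, v):
        # Write new k/v at the given positions and return the full buffers.
        self.k_cache[:, :, input_pos] = k
        self.v_cache[:, :, input_pos] = v
        return self.k_cache, self.v_cache

# setup_cache(..., cache_lanes=N): one independent cache per lane.
cache_lanes = 2
kv_cache = nn.ModuleList(
    [TinyKVCache(max_batch_size=1, max_seq_length=16, n_heads=4, head_dim=8)
     for _ in range(cache_lanes)]
)

# forward(..., cache_lane=lane): only the selected lane is touched.
lane = 1
input_pos = torch.tensor([0])
k = torch.randn(1, 4, 1, 8)
v = torch.randn(1, 4, 1, 8)
k_full, v_full = kv_cache[lane].update(input_pos, k, v)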