[Distributed] Add lanes to KV cache #1174
@@ -273,13 +273,11 @@ def main(args):
pp_rank = pp_mesh.get_local_rank()
tp_group = tp_mesh.get_group()
pp_group = pp_mesh.get_group()
- pp_group_size = pp_group.size()
- tp_group_size = tp_group.size()
- logger.info(f"{pp_group_size=}, {tp_group_size=}")
+ logger.info(f"{pp_degree=}, {tp_degree=}")

# Convenience variables
first_pp_rank = 0
- last_pp_rank = pp_group_size - 1
+ last_pp_rank = pp_degree - 1

# Assuming same number of GPUs per node
device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
@@ -297,18 +295,22 @@ def main(args):
if rank == 0:
logger.info(f"Model: {model}")

- mbs = 1 # number of micro-batches
- mb_size = 4 # micro-batch size
- batch_size = mbs * mb_size # total batch size
+ # Batch size. Since we push batches dynamically through the pipeline rather
+ # than chunking them, this is effectively micro-batch size in pipeline
+ # sense. Thus it is interchangeable with micro-batch size below.
+ batch_size = 4
seqlen_prefill = 1024 # sequence length
dim = 4096 # embedding dimension

# Setup KV caches (after model distribution)
- # TODO: the setting below only works for 1 micro-batch case. To support
- # multiple micro-batches, we need the KV cache in the model to be aware of
- # the number of micro-batches and the current micro-batch index.
- model.setup_caches(mb_size, seqlen_prefill)
+ # The number of cache lanes is the same as the maximum number of
+ # micro-batches that can be "in flight" in parallel -- imagine each
+ # micro-batch takes 1 "pipeline lane," they need distinct KV cache spaces.
+ # When decoding is done for certain micro-batches, we can reuse the KV cache
+ # lanes.
+ # TODO: bump up the lane count
+ pipeline_lanes = 1
+ model.setup_caches(batch_size, seqlen_prefill, cache_lanes=pipeline_lanes)

# Load weights
logger.info(f"Loading weights for {pp_rank=} on {device=}")
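The comments added in the hunk above carry the core idea of this PR: the KV cache gains a "lane" dimension so that several in-flight micro-batches can each keep their own attention history. The model's actual cache implementation is not part of this hunk, so the sketch below is only an illustration of what a lane-aware cache could look like; the class name, buffer layout, and `update()` signature are assumptions for the example, not the repo's API.

```python
import torch

# Illustrative sketch only -- not code from this PR. Shapes, names, and the
# update() signature are assumed for the sake of the example.
class LanedKVCache(torch.nn.Module):
    def __init__(self, lanes, batch_size, n_heads, max_seq_len, head_dim, dtype=torch.bfloat16):
        super().__init__()
        # One K/V slice per lane, so concurrent micro-batches never overwrite
        # each other's cached keys/values.
        shape = (lanes, batch_size, n_heads, max_seq_len, head_dim)
        self.register_buffer("k_cache", torch.zeros(shape, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(shape, dtype=dtype))

    def update(self, input_pos, k_val, v_val, cache_lane=0):
        # Write the new keys/values at the given positions, inside one lane only.
        self.k_cache[cache_lane, :, :, input_pos] = k_val
        self.v_cache[cache_lane, :, :, input_pos] = v_val
        return self.k_cache[cache_lane], self.v_cache[cache_lane]
```

Under this reading, `model.setup_caches(batch_size, seqlen_prefill, cache_lanes=pipeline_lanes)` would allocate one such buffer per attention layer, and the `cache_lane` index would be threaded through the forward pass (see the `kwargs` passed to `step()` further down).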
@@ -317,7 +319,7 @@ def main(args):
model.to(device)

logger.info(
- f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for stage {rank}{color.reset}"
+ f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
)

# info on stage size and params
@@ -330,17 +332,16 @@ def main(args):

# Setup input position (input_pos) for prefill: a list of increasing integers from 0 to seqlen
input_pos = torch.arange(seqlen_prefill, device=device)
- model.setup_input_pos(input_pos)
model.eval()

# Helper function to get example inputs and outputs for the stages.
def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
- mb_ids = torch.randint(0, config.vocab_size, (mb_size, seqlen), device=device)
+ mb_ids = torch.randint(0, config.vocab_size, (batch_size, seqlen), device=device)
activation = torch.rand(
- mb_size, seqlen, dim, device=device, dtype=model_dtype
+ batch_size, seqlen, dim, device=device, dtype=model_dtype
)
logits = torch.rand(
- mb_size, seqlen, config.vocab_size, device=device, dtype=model_dtype
+ batch_size, seqlen, config.vocab_size, device=device, dtype=model_dtype
)
example_inputs = (mb_ids if pp_rank == first_pp_rank else activation,)
example_outputs = (logits if pp_rank == last_pp_rank else activation,)
@@ -358,8 +359,13 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
output_args=example_outputs,
group=pp_group,
)
- # create schedule
- prefill_schedule = ScheduleGPipe(prefill_stage, mbs)
+ # Create schedule
+ # Number of micro-batches for the schedule is 1, because each step() call we
+ # only push 1 micro-batch into the pipeline. But we can continuously push
+ # new micro-batches into the pipeline as they arrive, achieving same
+ # pipelining effect.
+ prefiller = ScheduleGPipe(prefill_stage, 1)

prompt = [
"What is a computer?",
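The comment block above is the scheduling half of the lane idea: the schedule is built with a single micro-batch per `step()` call, and pipelining comes from issuing further `step()` calls while earlier micro-batches are still draining through downstream stages. Below is a hedged sketch of that usage pattern; it reuses names from the diff (`prefiller`, `pp_rank`, `first_pp_rank`, `last_pp_rank`, `input_pos`, `pipeline_lanes`), while `prompt_batches` is a hypothetical list of already-tokenized micro-batches, not something defined in this PR.

```python
# Sketch only: push successive micro-batches through the same GPipe schedule,
# assigning each in-flight micro-batch its own KV-cache lane.
with torch.no_grad():
    for i, batch_ids in enumerate(prompt_batches):  # hypothetical input list
        # Cycle through lanes; assumes a lane is free again by the time it is reused.
        lane = i % pipeline_lanes
        kwargs = {"input_pos": input_pos, "cache_lane": lane}
        if pp_rank == first_pp_rank:
            output = prefiller.step(batch_ids, **kwargs)
        elif pp_rank == last_pp_rank:
            output = prefiller.step(**kwargs)
        else:  # middle pp ranks
            prefiller.step(**kwargs)
```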
@@ -388,7 +394,6 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
s = set(prompt_lengths)
assert len(s) == 1, f"prompt_lengths should be the same, got {s}"

- # with CUDATrackTime() as timer:
# Need these global ids due to the API definition of dist.send and recv
first_pp_rank_global_id = dist.get_global_rank(pp_group, first_pp_rank)
last_pp_rank_global_id = dist.get_global_rank(pp_group, last_pp_rank)
@@ -401,14 +406,21 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
num_tokens = 40

# Prefill phase
- # Run context input through pipeline, in 1 step
- with torch.no_grad():
+ # Run context input through pipeline
+ # TODO: we need to pass `input_pos` and `cache_lane` to each stage.
+ lane = 0
+ kwargs = {"input_pos": input_pos, "cache_lane": lane}
+ with torch.no_grad(), CUDATrackTime() as timer:
if pp_rank == first_pp_rank:
- output = prefill_schedule.step(padded_sequence)
+ output = prefiller.step(padded_sequence, **kwargs)
elif pp_rank == last_pp_rank:
- output = prefill_schedule.step()
+ output = prefiller.step(**kwargs)
else: # middle pp ranks
- prefill_schedule.step()
+ prefiller.step(**kwargs)

+ logger.info(
+ f"{color.green}Prefilling time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
+ )

# Decode the output -- first generated token
if pp_rank == last_pp_rank:
@@ -430,7 +442,6 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
# seqlen = 1 now
seqlen_decode = 1
input_pos = torch.tensor([prompt_lengths[0]], device=device)
- model.setup_input_pos(input_pos)

# Create decode stage
logger.info(f"Creating pipeline stage for decode {pp_rank=}, {pp_degree=}")
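For decode, the sequence length drops to 1 and `input_pos` becomes a single absolute position rather than the full prefill range; with `setup_input_pos` removed, that position now travels with each `step()` call instead of being stored on the model. A minimal illustration of the two shapes, assuming the 1024-token prefill length from the settings above:

```python
import torch

# Assumed value for illustration; in the script it comes from the prompts.
prompt_len = 1024

# Prefill: one position per prompt token, shape [prompt_len].
input_pos_prefill = torch.arange(prompt_len)

# Decode: a single position just past the prompt, shape [1],
# advanced by one after every generated token.
input_pos_decode = torch.tensor([prompt_len])
input_pos_decode += 1
```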
@@ -445,11 +456,12 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
group=pp_group,
)
# create schedule
- decode_schedule = ScheduleGPipe(decode_stage, mbs)
+ decorder = ScheduleGPipe(decode_stage, 1)

# Decoding
- with torch.no_grad():
+ with torch.no_grad(), CUDATrackTime() as timer:
for step in range(num_tokens - 1):
+ kwargs = {"input_pos": input_pos, "cache_lane": lane}
# sendrecv between last and first ranks, only if:
# first_pp_rank != last_pp_rank.
if pp_rank == last_pp_rank and pp_rank != first_pp_rank:
@@ -467,11 +479,11 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:

# Run data through pipeline
if pp_rank == first_pp_rank:
- output = decode_schedule.step(new_token)
+ output = decorder.step(new_token, **kwargs)

Review comment: same, syntax error - this should be 'decoder' and not 'decorder'.

elif pp_rank == last_pp_rank:
- output = decode_schedule.step()
+ output = decorder.step(**kwargs)

Review comment: same, syntax error - this should be 'decoder' and not 'decorder'.

else: # middle pp ranks
- decode_schedule.step()
+ decorder.step(**kwargs)

Review comment: last one, syntax error - this should be 'decoder' and not 'decorder'.

# Decode the output
if pp_rank == last_pp_rank:
@@ -491,7 +503,10 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
) # decode_results[i][0]

input_pos += 1
- model.setup_input_pos(input_pos)

+ logger.info(
+ f"{color.green}Decoding time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
+ )

# Display the decoding results
Review comment: syntax error - this should be 'decoder' and not 'decorder'.