
[Distributed] Add lanes to KV cache #1174


Merged
6 commits merged on Sep 23, 2024

Conversation

@kwen2501 kwen2501 commented Sep 21, 2024

The KV cache is extended to have multiple lanes. Each lane serves a separate micro-batch, so several micro-batches can pass through the pipeline at the same time, which is what pipeline parallelism needs.

# The number of cache lanes is the same as the maximum number of
# micro-batches that can be "in flight" in parallel. Each micro-batch
# occupies one "pipeline lane" and needs its own KV cache space.
# When decoding is done for a micro-batch, its KV cache lane can be
# reused.
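
The lane bookkeeping described in the comment above could look roughly like the following. This is a minimal sketch, not code from this PR: `LanePool`, `acquire`, and `release` are hypothetical names introduced only to illustrate handing out lane indices to in-flight micro-batches and reusing them once decoding finishes.

```python
from collections import deque


class LanePool:
    """Hypothetical helper: maps in-flight micro-batches to KV cache lanes."""

    def __init__(self, cache_lanes: int) -> None:
        if cache_lanes < 1:
            raise ValueError("cache_lanes must be >= 1")
        # Lane indices that are currently unused.
        self._free = deque(range(cache_lanes))
        # Micro-batch id -> lane index for micro-batches still decoding.
        self._in_use = {}

    def acquire(self, mb_id: int) -> int:
        """Return the lane for a micro-batch, assigning a free one if needed."""
        if mb_id in self._in_use:
            return self._in_use[mb_id]
        if not self._free:
            raise RuntimeError("all lanes are busy; wait for a micro-batch to finish")
        lane = self._free.popleft()
        self._in_use[mb_id] = lane
        return lane

    def release(self, mb_id: int) -> None:
        """Mark a micro-batch as done so its lane can be reused."""
        self._free.append(self._in_use.pop(mb_id))


pool = LanePool(cache_lanes=2)
lane_a = pool.acquire(mb_id=0)
lane_b = pool.acquire(mb_id=1)
pool.release(mb_id=0)           # micro-batch 0 finished decoding
lane_c = pool.acquire(mb_id=2)  # reuses the lane freed by micro-batch 0
```

With two lanes, at most two micro-batches can be in flight; a third has to wait until one of them releases its lane.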

Major changes

  1. setup_caches now takes a kwarg cache_lanes (default 1):
def setup_caches(self, max_batch_size, max_seq_length, cache_lanes: int = 1)
  2. attention.kv_cache is now an nn.ModuleList containing multiple KVCache modules, one per lane.

  3. We now pass kwargs = {"input_pos": input_pos, "cache_lane": lane} to the step() function. This removes the temporary helper function model.setup_input_pos.

Requires pytorch/pytorch#136416 to support pass-in of kwargs.
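
To make the changes listed above concrete, here is a simplified sketch of a per-lane cache layout. It is not the torchchat implementation: the `KVCache` buffer shapes, the `n_heads`/`head_dim` defaults, and the `write_kv` method are assumptions made for illustration. Only the `setup_caches` signature, the `nn.ModuleList` of caches, and the `input_pos`/`cache_lane` arguments come from the PR description.

```python
import torch
from torch import nn


class KVCache(nn.Module):
    """One lane: preallocated key/value buffers (shapes are an assumption)."""

    def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim):
        super().__init__()
        shape = (max_batch_size, n_heads, max_seq_length, head_dim)
        self.register_buffer("k_cache", torch.zeros(shape))
        self.register_buffer("v_cache", torch.zeros(shape))

    def update(self, input_pos, k_val, v_val):
        # input_pos is a 1-D tensor of sequence positions to overwrite.
        self.k_cache[:, :, input_pos] = k_val
        self.v_cache[:, :, input_pos] = v_val
        return self.k_cache, self.v_cache


class Attention(nn.Module):
    """Toy attention shell that only demonstrates the cache plumbing."""

    def __init__(self, n_heads: int = 2, head_dim: int = 4):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = head_dim
        self.kv_cache = None

    def setup_caches(self, max_batch_size, max_seq_length, cache_lanes: int = 1):
        # One KVCache per lane, held in an nn.ModuleList.
        self.kv_cache = nn.ModuleList(
            [
                KVCache(max_batch_size, max_seq_length, self.n_heads, self.head_dim)
                for _ in range(cache_lanes)
            ]
        )

    def write_kv(self, k_val, v_val, input_pos, cache_lane: int = 0):
        # Hypothetical helper: select the lane, then update only that cache.
        return self.kv_cache[cache_lane].update(input_pos, k_val, v_val)


attn = Attention()
attn.setup_caches(max_batch_size=1, max_seq_length=8, cache_lanes=2)
pos = torch.tensor([0, 1])
kv = torch.ones(1, 2, 2, 4)
attn.write_kv(kv, kv, input_pos=pos, cache_lane=1)  # lane 0 stays untouched
```

Keeping the lanes in an `nn.ModuleList` means each lane's buffers are registered with the parent module, so they move with it on `.to(device)`.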


pytorch-bot bot commented Sep 21, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchchat/1174

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 9514b54 with merge base 8d01d9b:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label Sep 21, 2024
@kwen2501 changed the title from "[WIP][Distributed] Add lanes to KV cache" to "[Distributed] Add lanes to KV cache" Sep 23, 2024
@@ -445,11 +456,12 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
group=pp_group,
)
# create schedule
- decode_schedule = ScheduleGPipe(decode_stage, mbs)
+ decorder = ScheduleGPipe(decode_stage, 1)

Typo: this should be 'decoder', not 'decorder'.

@@ -467,11 +479,11 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:

# Run data through pipeline
if pp_rank == first_pp_rank:
- output = decode_schedule.step(new_token)
+ output = decorder.step(new_token, **kwargs)

Same typo: this should be 'decoder', not 'decorder'.

elif pp_rank == last_pp_rank:
- output = decode_schedule.step()
+ output = decorder.step(**kwargs)
@lessw2020 lessw2020 Sep 23, 2024


Same typo: this should be 'decoder', not 'decorder'.

else: # middle pp ranks
- decode_schedule.step()
+ decorder.step(**kwargs)

Last one, same typo: this should be 'decoder', not 'decorder'.
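
The rank-dependent calls in the hunks above can be summarized in a small standalone sketch. It does not use the real pipeline schedule: `StubSchedule` and `run_decode_step` are hypothetical stand-ins that only record which arguments each pipeline rank would pass to `step()`, assuming the kwargs dict from the PR description.

```python
class StubSchedule:
    """Hypothetical stand-in for a pipeline schedule; records step() calls."""

    def __init__(self):
        self.calls = []

    def step(self, *args, **kwargs):
        self.calls.append((args, kwargs))
        return "output"  # placeholder; a real schedule returns stage outputs


def run_decode_step(schedule, pp_rank, first_pp_rank, last_pp_rank,
                    new_token, input_pos, lane):
    # Every rank passes the same kwargs so each stage picks the same lane.
    kwargs = {"input_pos": input_pos, "cache_lane": lane}
    if pp_rank == first_pp_rank:
        # Only the first stage feeds the new token into the pipeline.
        return schedule.step(new_token, **kwargs)
    elif pp_rank == last_pp_rank:
        # The last stage receives activations and produces the output.
        return schedule.step(**kwargs)
    else:
        # Middle stages just forward activations.
        schedule.step(**kwargs)
        return None


sched = StubSchedule()
first_out = run_decode_step(sched, 0, 0, 2, new_token="tok", input_pos=5, lane=1)
middle_out = run_decode_step(sched, 1, 0, 2, new_token="tok", input_pos=5, lane=1)
```

The point of the sketch is that only the first rank supplies positional input, while `input_pos` and `cache_lane` travel as kwargs on every rank.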


@lessw2020 lessw2020 left a comment


Nice addition!
Minor note: 'decorder' should be 'decoder' in the code, for readability.

@kwen2501 kwen2501 merged commit 2cf4016 into main Sep 23, 2024
51 checks passed
kwen2501 added a commit that referenced this pull request Sep 23, 2024
3 participants