-
Notifications
You must be signed in to change notification settings - Fork 250
[Distributed] Add lanes to KV cache #1174
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchchat/1174
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit 9514b54 with merge base 8d01d9b ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
e33e681
to
39eff90
Compare
@@ -445,11 +456,12 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]: | |||
group=pp_group, | |||
) | |||
# create schedule | |||
decode_schedule = ScheduleGPipe(decode_stage, mbs) | |||
decorder = ScheduleGPipe(decode_stage, 1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
syntax error - this should be 'decoder' and not 'decorder'.
@@ -467,11 +479,11 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]: | |||
|
|||
# Run data through pipeline | |||
if pp_rank == first_pp_rank: | |||
output = decode_schedule.step(new_token) | |||
output = decorder.step(new_token, **kwargs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same, syntax error - this should be 'decoder' and not 'decorder'.
elif pp_rank == last_pp_rank: | ||
output = decode_schedule.step() | ||
output = decorder.step(**kwargs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same, syntax error - this should be 'decoder' and not 'decorder'.
else: # middle pp ranks | ||
decode_schedule.step() | ||
decorder.step(**kwargs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
last one, syntax error - this should be 'decoder' and not 'decorder'.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nice addition!
minor note that 'decorder' should be 'decoder' in the code for ease of understanding/syntax.
KV cache is extended to have multiple lanes, each letting a separate batch pass through, achieving pipeline parallelism.
Major changes
setup_caches
will take one kwargcache_lanes
(default to 1).attention.kv_cache
is now ann.ModuleList
, containing multipleKVCache
's, each corresponding to a lane.We now pass
kwargs = {"input_pos": input_pos, "cache_lane": lane}
to thestep()
function. Removing the temporary helper functionmodel.setup_input_pos
.Requires pytorch/pytorch#136416 to support pass-in of kwargs.