
add kv_cache to LLM #244

Merged: 12 commits merged into mosaicml:main from attn_kv_cache on Mar 21, 2023

Conversation

@vchiley (Contributor) commented Mar 18, 2023

This PR includes past_key_values (i.e. the kv_cache) in the LLM so that inference can be accelerated.
We also become explicit about how we apply the padding_mask for queries and keys.

Shoutout: @dakinggg for working through some of this with me.

cc @dskhudia @alextrott16 @samhavens for after training / inference
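For context on why caching helps: during autoregressive decoding only the newest token needs a query, while the keys and values of earlier tokens can be reused instead of recomputed at every step. Below is a minimal single-head sketch of the idea (the function `attn_step` and its signature are illustrative assumptions, not the attention code in this PR; masking is omitted for brevity):

import torch
import torch.nn.functional as F

def attn_step(q_new, k_new, v_new, kv_cache=None):
    # q_new, k_new, v_new: (batch, new_seq, head_dim) for only the new tokens
    if kv_cache is not None:
        past_k, past_v = kv_cache
        k = torch.cat([past_k, k_new], dim=1)  # reuse cached keys
        v = torch.cat([past_v, v_new], dim=1)  # reuse cached values
    else:
        k, v = k_new, v_new
    scores = q_new @ k.transpose(-2, -1) / k.size(-1) ** 0.5
    out = F.softmax(scores, dim=-1) @ v
    return out, (k, v)  # (k, v) is the updated cache, i.e. past_key_values

# Process a 4-token prompt once, then feed one token at a time.
b, d = 1, 64
_, cache = attn_step(torch.randn(b, 4, d), torch.randn(b, 4, d), torch.randn(b, 4, d))
q1, k1, v1 = (torch.randn(b, 1, d) for _ in range(3))
out, cache = attn_step(q1, k1, v1, cache)  # attends over all 5 positions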

@vchiley vchiley self-assigned this Mar 18, 2023
@vchiley vchiley requested a review from bmosaicml March 18, 2023 01:02
@vchiley vchiley changed the title init kv_cache pr add kv_cache to LLM Mar 18, 2023
@honglu2875 commented Mar 18, 2023

@vchiley Cached kv values shift the positions as well. Maybe you want to shift the position embeddings in the following?

pos = torch.arange(0, S, dtype=torch.long,

Compare with this in HF
https://github.com/huggingface/transformers/blob/60d51ef5123d949fd8c59cd4d3254e711541d278/src/transformers/models/gpt2/modeling_gpt2.py#L801

In our fork of mosaic models, we have the kv cache and the relevant part looks like the following:

        if past_key_values is None:
            past_key_values = [None] * self.cfg.n_layers
            past_position = 0
        else:
            assert len(past_key_values) == self.cfg.n_layers
            # get the key tensor whose spec should be (batch, seq, n_head, head_dim), and
            # collect the `seq`, so that we shift the position embedding later.
            past_position = past_key_values[0][0].size(1)

        tok_emb = self.transformer.wte(input_ids)  # type: ignore
        if self.alibi:
            x = tok_emb
        else:
            if S + past_position > self.cfg.max_seq_len:
                raise ValueError(
                    f'Cannot forward input with past sequence length {past_position} and current sequence length '
                    f'{S + 1}, this model only supports total sequence length <= {self.cfg.max_seq_len}.'
                )
            pos = torch.arange(past_position, S + past_position, dtype=torch.long,
                               device=input_ids.device).unsqueeze(0)
            pos_emb = self.transformer.wpe(pos)  # type: ignore
            x = tok_emb + pos_emb
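As a standalone illustration of the shift (the loop below uses illustrative values, it is not the fork's code): `past_position` is the number of cached key/value positions, so the position ids for the new tokens start at `past_position` and the cache grows by the number of new tokens each step.

import torch

max_seq_len = 8
past_position = 0                            # length of cached keys/values so far
for new_tokens in ([11, 42, 7], [3], [9]):   # prompt, then one token per step
    S = len(new_tokens)
    assert S + past_position <= max_seq_len
    pos = torch.arange(past_position, S + past_position,
                       dtype=torch.long).unsqueeze(0)
    print(pos)                               # [[0, 1, 2]], then [[3]], then [[4]]
    past_position += S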

@vchiley vchiley force-pushed the attn_kv_cache branch 4 times, most recently from f8a03e7 to 26431f7 Compare March 20, 2023 16:49
@vchiley vchiley requested review from dakinggg and removed request for alextrott16 and samhavens March 20, 2023 17:31
@vchiley vchiley marked this pull request as ready for review March 20, 2023 17:33
@vchiley vchiley force-pushed the attn_kv_cache branch 4 times, most recently from 68afe03 to aaf0658 Compare March 20, 2023 18:03
@dakinggg (Collaborator) left a comment

LGTM, can you train a model and make sure nothing is broken?

examples/llm/src/models/layers/attention.py (review comment, outdated, resolved)
examples/llm/src/models/layers/gpt_blocks.py (review comment, outdated, resolved)
examples/llm/src/models/mosaic_gpt.py (review comment, outdated, resolved)
@dskhudia (Contributor) commented:
Could you explain the reason for separating out query_padding_mask?

vchiley and others added 3 commits March 20, 2023 16:15
Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com>
@vchiley (Contributor, Author) commented Mar 20, 2023

@dskhudia this formulates a generic attn fn: the queries are potentially not the same as the keys/values and will need their own padding_mask.
This is useful for left-padded inputs: since past tokens influence future tokens, we NEED to mask them out.
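A minimal sketch of the point about separate masks (the helper `masked_attn` and its True-means-real-token convention are assumptions for illustration, not the PR's attention fn): padded key positions must never be attended to, and outputs at padded query positions are zeroed so they carry no signal downstream.

import torch
import torch.nn.functional as F

def masked_attn(q, k, v, query_padding_mask, key_padding_mask):
    # q: (b, s_q, d); k, v: (b, s_k, d)
    # *_padding_mask: bool, True for real tokens, False for padding
    scores = q @ k.transpose(-2, -1) / q.size(-1) ** 0.5        # (b, s_q, s_k)
    # keys at padded (e.g. left-padded) positions must not be attended to
    scores = scores.masked_fill(~key_padding_mask[:, None, :], float('-inf'))
    out = F.softmax(scores, dim=-1) @ v
    # zero outputs at padded query positions so they carry no signal downstream
    return out * query_padding_mask[:, :, None].to(out.dtype)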

examples/llm/src/models/layers/attention.py (review comment, resolved)
examples/llm/src/models/layers/attention.py (review comment, resolved)
examples/llm/src/models/mosaic_gpt.py (review comment, outdated, resolved)
@vchiley (Contributor, Author) commented Mar 21, 2023

Note: we should have a conversation about whether all raise Error calls should be changed to assert (see here)

@vchiley vchiley merged commit 83e6998 into mosaicml:main Mar 21, 2023