From ca1b2f4994608c977628dd29848ecc8f85ffced8 Mon Sep 17 00:00:00 2001
From: Vidya Galli
Date: Fri, 14 Jun 2024 13:36:44 -0700
Subject: [PATCH] Updated kv cache for starcoder (#128)

---
 server/tests/models/test_starcoder.py         | 372 ++++++++++++++++++
 .../text_generation_server/models/__init__.py |  10 +-
 .../models/starcoder.py                       |  51 +++
 3 files changed, 425 insertions(+), 8 deletions(-)
 create mode 100644 server/tests/models/test_starcoder.py
 create mode 100644 server/text_generation_server/models/starcoder.py

diff --git a/server/tests/models/test_starcoder.py b/server/tests/models/test_starcoder.py
new file mode 100644
index 00000000000..05bdd72b713
--- /dev/null
+++ b/server/tests/models/test_starcoder.py
@@ -0,0 +1,372 @@
+# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
+
+import pytest
+import torch
+
+from copy import copy
+
+from text_generation_server.pb import generate_pb2
+from text_generation_server.models import get_model
+from text_generation_server.models.starcoder import StarCoderCausalLMBatch
+from text_generation_server.models.causal_lm import (
+    PREFILL_BATCH_BUCKET_SIZE,
+    PAD_SEQUENCE_TO_MULTIPLE_OF,
+    MAX_TOTAL_TOKENS,
+    BATCH_BUCKET_SIZE,
+)
+
+PAD_TOKEN = 0
+
+
+@pytest.fixture(scope="session")
+def default_starcoder():
+    return get_model("bigcode/starcoder", None, None, None, None)
+
+
+@pytest.fixture(scope="session")
+def default_tokenizer(default_starcoder):
+    default_starcoder.tokenizer.pad_token_id = PAD_TOKEN
+    return default_starcoder.tokenizer
+
+
+@pytest.fixture
+def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
+    return generate_pb2.Request(
+        id=0,
+        inputs="Test",
+        prefill_logprobs=True,
+        truncate=PAD_SEQUENCE_TO_MULTIPLE_OF,
+        parameters=default_pb_parameters,
+        stopping_parameters=default_pb_stop_parameters,
+    )
+
+
+@pytest.fixture
+def default_pb_batch(default_pb_request):
+    return generate_pb2.Batch(id=0, requests=[default_pb_request], size=1)
+
+
+@pytest.fixture
+def default_starcoder_batch(default_pb_batch, default_tokenizer):
+    return StarCoderCausalLMBatch.from_pb(
+        default_pb_batch, default_tokenizer, torch.float32, torch.device("hpu")
+    )
+
+
+@pytest.fixture
+def default_multi_requests_starcoder_batch(default_pb_request, default_tokenizer):
+    req_0 = copy(default_pb_request)
+    req_0.id = 1
+    req_1 = default_pb_request
+    req_1.id = 2
+    req_1.stopping_parameters.max_new_tokens = 5
+
+    batch_pb = generate_pb2.Batch(id=1, requests=[req_0, req_1], size=2)
+    return StarCoderCausalLMBatch.from_pb(
+        batch_pb, default_tokenizer, torch.float32, torch.device("hpu")
+    )
+
+
+def test_starcoder_batch_type(default_starcoder):
+    assert default_starcoder.batch_type == StarCoderCausalLMBatch
+
+
+def test_batch_from_pb(default_pb_batch, default_starcoder_batch):
+    batch = default_starcoder_batch
+
+    assert batch.batch_id == default_pb_batch.id
+    assert len(batch.requests) == len(default_pb_batch.requests)
+
+    for r in range(len(default_pb_batch.requests)):
+        assert batch.requests[r].data == default_pb_batch.requests[r]
+
+    # On Gaudi the batch size is padded up to a multiple of the prefill bucket size
+    size_of_padded_to_bucket = ((default_pb_batch.size + PREFILL_BATCH_BUCKET_SIZE - 1) // PREFILL_BATCH_BUCKET_SIZE) * PREFILL_BATCH_BUCKET_SIZE
+
+    assert len(batch.input_ids) == size_of_padded_to_bucket
+    assert batch.input_ids.shape == torch.Size([4, 128])
+
+    assert batch.input_ids[0][-2] == 1006
+    assert batch.input_ids[1][-2] == 49
+    assert batch.input_ids[2][-2] == 49
+    assert batch.attention_mask[0][-2] == 1
+    assert batch.attention_mask[1][-2] == 1
+    assert batch.attention_mask[2][-2] == 1
+    assert torch.all(batch.attention_mask[0, :-3] == 0)
+
+    assert batch.past_key_values is None
+    assert all(
+        [
+            torch.equal(input_ids, request.all_input_ids[:batch.input_length + 1, 0])
+            for input_ids, request in zip(batch.input_ids, batch.requests)
+        ]
+    )
+
+    assert len(batch) == default_pb_batch.size
+
+    assert batch.max_input_length + 1 == default_pb_batch.requests[0].truncate
+
+
+def test_starcoder_generate_token(default_starcoder, default_starcoder_batch):
+    generations, next_batch, _ = default_starcoder.generate_token([default_starcoder_batch])
+    padding = next_batch.requests[0].stopping_criteria.max_new_tokens
+
+    assert isinstance(next_batch, StarCoderCausalLMBatch)
+    assert len(next_batch.attention_mask[0]) == PAD_SEQUENCE_TO_MULTIPLE_OF
+    assert next_batch.requests[0].all_input_ids[-padding - 2] == 1006
+
+    assert torch.all(next_batch.requests[0].all_input_ids[-padding - 1:] == PAD_TOKEN)
+    assert torch.all(next_batch.requests[0].all_input_ids[:-padding - 3] == PAD_TOKEN)
+
+    generations, next_batch, _ = default_starcoder.generate_token([default_starcoder_batch])
+    assert torch.all(next_batch.attention_mask[0][PAD_SEQUENCE_TO_MULTIPLE_OF - 2:PAD_SEQUENCE_TO_MULTIPLE_OF] == 1)
+    assert torch.all(next_batch.attention_mask[0][:PAD_SEQUENCE_TO_MULTIPLE_OF - 3] == 0)
+    assert torch.all(next_batch.attention_mask[0][PAD_SEQUENCE_TO_MULTIPLE_OF + 1:] == 0)
+
+    assert next_batch.requests[0].all_input_ids[-padding - 2] == 1006
+    assert next_batch.requests[0].all_input_ids[-padding - 1] == 26
+    assert torch.all(next_batch.requests[0].all_input_ids[-padding:] == PAD_TOKEN)
+    assert torch.all(next_batch.requests[0].all_input_ids[:-padding - 3] == PAD_TOKEN)
+
+    assert next_batch.input_length == PAD_SEQUENCE_TO_MULTIPLE_OF
+    assert next_batch.max_input_length == next_batch.input_length
+
+    assert next_batch.past_key_values is not None
+    assert all(
+        [p[0].shape == (MAX_TOTAL_TOKENS, 256) for p in next_batch.past_key_values]
+    )
+    assert all(
+        [p[1].shape == (MAX_TOTAL_TOKENS, 256) for p in next_batch.past_key_values]
+    )
+    assert all([generation.generated_text is None for generation in generations])
+    assert all([len(generation.prefill_tokens) == PAD_SEQUENCE_TO_MULTIPLE_OF - 1 for generation in generations])
+    assert all([generation.tokens.token_ids[0] == 26 for generation in generations])
+    assert all([generation.tokens.texts[0] == "(" for generation in generations])
+    assert generations[0].request_id == 0
+
+
+def test_starcoder_generate_token_completion(
+    default_starcoder, default_starcoder_batch
+):
+    next_batch = default_starcoder_batch
+    generations, next_batch, _ = default_starcoder.generate_token([next_batch])
+
+    for _ in range(default_starcoder_batch.requests[0].stopping_criteria.max_new_tokens - 1):
+        generations, next_batch, _ = default_starcoder.generate_token([next_batch])
+        assert len(generations) == len(next_batch)
+
+    generations, next_batch, _ = default_starcoder.generate_token([next_batch])
+
+    assert next_batch is None
+
+    assert len(generations) == 1
+    assert generations[0].generated_text.text == '(self):\n """\n Test that the test'
+    assert generations[0].request_id == default_starcoder_batch.requests[0].data.id
+    assert (
+        generations[0].generated_text.generated_tokens
+        == default_starcoder_batch.requests[0].stopping_criteria.max_new_tokens
+    )
+
+
+def test_starcoder_generate_token_completion_multi(
+    default_starcoder, default_multi_requests_starcoder_batch
+):
+    next_batch = default_multi_requests_starcoder_batch
+    generations, next_batch, _ = default_starcoder.generate_token([next_batch])
+
+    for _ in range(
+        default_multi_requests_starcoder_batch.requests[1].stopping_criteria.max_new_tokens - 1
+    ):
+        generations, next_batch, _ = default_starcoder.generate_token([next_batch])
+        assert len(generations) == len(next_batch)
+
+    generations, next_batch, _ = default_starcoder.generate_token([next_batch])
+    assert next_batch is not None
+
+    assert len(generations) == 2
+    assert generations[1].generated_text.text == '(self):\n """'
+    assert (
+        generations[1].request_id
+        == default_multi_requests_starcoder_batch.requests[1].data.id
+    )
+    assert (
+        generations[1].generated_text.generated_tokens
+        == default_multi_requests_starcoder_batch.requests[1].stopping_criteria.max_new_tokens
+    )
+
+    next_batch = next_batch.filter([next_batch.requests[0].data.id])
+
+    for _ in range(
+        default_multi_requests_starcoder_batch.requests[0].stopping_criteria.max_new_tokens
+        - default_multi_requests_starcoder_batch.requests[1].stopping_criteria.max_new_tokens
+        - 1
+    ):
+        generations, next_batch, _ = default_starcoder.generate_token([next_batch])
+        assert len(generations) == len(next_batch)
+
+    generations, next_batch, _ = default_starcoder.generate_token([next_batch])
+
+    assert next_batch is None
+
+    assert len(generations) == 1
+    assert generations[0].generated_text.text == '(self):\n """\n Test that the test'
+    assert (
+        generations[0].request_id
+        == default_multi_requests_starcoder_batch.requests[0].data.id
+    )
+    assert (
+        generations[0].generated_text.generated_tokens
+        == default_multi_requests_starcoder_batch.requests[0].stopping_criteria.max_new_tokens
+    )
+
+
+def test_batch_concatenate(
+    default_starcoder, default_starcoder_batch, default_multi_requests_starcoder_batch
+):
+    next_batch_0 = default_starcoder_batch
+    _, next_batch_0, _ = default_starcoder.generate_token([next_batch_0])
+    _, next_batch_0, _ = default_starcoder.generate_token([next_batch_0])
+    _, next_batch_0, _ = default_starcoder.generate_token([next_batch_0])
+
+    next_batch_1 = default_multi_requests_starcoder_batch
+    _, next_batch_1, _ = default_starcoder.generate_token([next_batch_1])
+    _, next_batch_1, _ = default_starcoder.generate_token([next_batch_1])
+
+    # Clone past_key_values before concatenating so they can be compared afterwards,
+    # because they are removed from the concatenated batches
+    next_batch_0_past_key_values = [x.clone() for x in next_batch_0.past_key_values]
+    next_batch_1_past_key_values = [x.clone() for x in next_batch_1.past_key_values]
+
+    next_batch = StarCoderCausalLMBatch.concatenate([next_batch_0, next_batch_1])
+
+    assert torch.equal(next_batch.requests[0].all_input_ids, next_batch_0.requests[0].all_input_ids)
+    assert torch.equal(next_batch.requests[1].all_input_ids, next_batch_1.requests[0].all_input_ids)
+    assert torch.equal(next_batch.requests[2].all_input_ids, next_batch_1.requests[1].all_input_ids)
+
+    assert torch.all(
+        next_batch.attention_mask[0:2, -next_batch.right_padding - 2: -next_batch.right_padding] == 1
+    )
+    assert torch.all(
+        next_batch.attention_mask[2, -next_batch.right_padding - 3: -next_batch.right_padding] == 1
+    )
+    assert torch.all(
+        next_batch.attention_mask[3, -next_batch.right_padding - 2: -next_batch.right_padding] == 1
+    )
+
+    assert torch.all(
+        next_batch.attention_mask[0:2, :-next_batch.right_padding - 2] == 0
+    )
+    assert torch.all(
+        next_batch.attention_mask[2, :-next_batch.right_padding - 4] == 0
+    )
+    assert torch.all(
+        next_batch.attention_mask[3, :-next_batch.right_padding - 3] == 0
+    )
+
+    assert next_batch.batch_id == 0
+    assert next_batch.input_ids[0, -next_batch.right_padding - 2] == 1006
+    assert next_batch.input_ids[0, -next_batch.right_padding - 1] == 26
+
+    assert next_batch.max_input_length == 129
+
+    assert torch.all(next_batch.input_ids[0, -next_batch.right_padding:] == PAD_TOKEN)
+    assert torch.all(next_batch.input_ids[1, -next_batch.right_padding:] == PAD_TOKEN)
+    assert torch.all(next_batch.input_ids[2, -next_batch.right_padding:] == PAD_TOKEN)
+    assert torch.all(next_batch.input_ids[3, -next_batch.right_padding:] == PAD_TOKEN)
+
+    assert next_batch.input_length == PAD_SEQUENCE_TO_MULTIPLE_OF + 1
+    assert next_batch.max_input_length == PAD_SEQUENCE_TO_MULTIPLE_OF + 1
+
+    assert next_batch.requests[0] == next_batch_0.requests[0]
+    assert next_batch.requests[1:] == next_batch_1.requests
+
+    assert next_batch.requests[0].stopping_criteria == next_batch_0.requests[0].stopping_criteria
+    assert next_batch.requests[1].stopping_criteria == next_batch_1.requests[0].stopping_criteria
+    assert next_batch.requests[2].stopping_criteria == next_batch_1.requests[1].stopping_criteria
+
+    assert next_batch.past_key_values is not None
+    assert all([p[0].shape == (2048, 256) for p in next_batch.past_key_values])
+    assert all([p[1].shape == (2048, 256) for p in next_batch.past_key_values])
+
+    for i, past in enumerate(next_batch.past_key_values):
+        assert torch.equal(next_batch_0_past_key_values[i][0, 0, 0:128], past[0][1:129][0, 0:128])
+        assert torch.equal(next_batch_0_past_key_values[i][0, 1, 0:128], past[1][1:129][0, 0:128])
+        assert torch.equal(
+            next_batch_1_past_key_values[i][:, :, 0:1][0][0][0], past[0][1:, :][0][0]
+        )
+        assert torch.equal(
+            next_batch_1_past_key_values[i][1:, :, 0:1][0][0][0], past[1][1:, :][0][0]
+        )
+
+    generations, next_batch, _ = default_starcoder.generate_token([next_batch])
+
+    for _ in range(
+        default_multi_requests_starcoder_batch.requests[1].stopping_criteria.max_new_tokens - 2
+    ):
+        generations, next_batch, _ = default_starcoder.generate_token([next_batch])
+        assert len(generations) == len(next_batch)
+
+    generations, next_batch, _ = default_starcoder.generate_token([next_batch])
+    assert next_batch is not None
+
+    assert len(generations) == 3
+    assert generations[2].generated_text.text == '(self):\n """'
+    assert (
+        generations[2].request_id
+        == default_multi_requests_starcoder_batch.requests[1].data.id
+    )
+    assert (
+        generations[2].generated_text.generated_tokens
+        == default_multi_requests_starcoder_batch.requests[1].stopping_criteria.max_new_tokens
+    )
+
+    next_batch = next_batch.filter(
+        [next_batch.requests[0].data.id, next_batch.requests[1].data.id]
+    )
+
+    for _ in range(
+        default_starcoder_batch.requests[0].stopping_criteria.max_new_tokens
+        - default_multi_requests_starcoder_batch.requests[1].stopping_criteria.max_new_tokens
+        - 2
+    ):
+        generations, next_batch, _ = default_starcoder.generate_token([next_batch])
+        assert len(generations) == len(next_batch)
+
+    generations, next_batch, _ = default_starcoder.generate_token([next_batch])
+    assert next_batch is not None
+
+    assert len(generations) == 2
+    assert generations[0].generated_text.text == '(self):\n """\n Test that the test'
+    assert generations[0].request_id == default_starcoder_batch.requests[0].data.id
+    assert (
+        generations[0].generated_text.generated_tokens
+        == default_starcoder_batch.requests[0].stopping_criteria.max_new_tokens
+    )
+
+    next_batch = next_batch.filter([next_batch.requests[1].data.id])
+
+    for _ in range(
+        default_multi_requests_starcoder_batch.requests[0].stopping_criteria.max_new_tokens
+        - default_starcoder_batch.requests[0].stopping_criteria.max_new_tokens
+        - default_multi_requests_starcoder_batch.requests[1].stopping_criteria.max_new_tokens
+        - 4
+    ):
+        generations, next_batch, _ = default_starcoder.generate_token([next_batch])
+        assert len(generations) == len(next_batch)
+
+    generations, next_batch, _ = default_starcoder.generate_token([next_batch])
+    assert next_batch is None
+
+    assert len(generations) == 1
+    assert generations[0].generated_text.text == '(self):\n """\n Test that the test'
+    assert (
+        generations[0].request_id
+        == default_multi_requests_starcoder_batch.requests[0].data.id
+    )
+    assert (
+        generations[0].generated_text.generated_tokens
+        == default_multi_requests_starcoder_batch.requests[0].stopping_criteria.max_new_tokens
+    )
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 7eb9669f179..1cd1563f4c2 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -14,7 +14,7 @@
 from text_generation_server.models.model import Model
 from text_generation_server.models.causal_lm import CausalLM
 from text_generation_server.models.bloom import BLOOM
-from text_generation_server.models.santacoder import SantaCoder
+from text_generation_server.models.starcoder import StarCoder
 
 from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
 
@@ -86,13 +86,7 @@ def get_model(
     model_type = config_dict["model_type"]
 
     if model_type == "gpt_bigcode":
-        return SantaCoder(
-            model_id,
-            revision,
-            use_medusa=use_medusa,
-            dtype=dtype,
-            trust_remote_code=trust_remote_code,
-        )
+        return StarCoder(model_id, revision, dtype)
 
     if model_type == "bloom":
         return BLOOM(
diff --git a/server/text_generation_server/models/starcoder.py b/server/text_generation_server/models/starcoder.py
new file mode 100644
index 00000000000..b956a94c77e
--- /dev/null
+++ b/server/text_generation_server/models/starcoder.py
@@ -0,0 +1,51 @@
+from loguru import logger
+import torch
+from dataclasses import dataclass
+import os
+from typing import List, Optional, Type
+
+from text_generation_server.models import CausalLM
+from text_generation_server.models.causal_lm import CausalLMBatch
+
+
+@dataclass
+class StarCoderCausalLMBatch(CausalLMBatch):
+    past_key_values: Optional[List[torch.Tensor]]
+
+    def detach_kv_cache(self):
+        # Each cached tensor fuses the key and value halves along the last
+        # dimension; split them apart for the caller.
+        past_keys = []
+        past_values = []
+        last_dim = self.past_key_values[0].size(dim=-1) // 2
+        for key_value in self.past_key_values:
+            key, value = key_value.split((last_dim, last_dim), dim=-1)
+            past_keys.append(key)
+            past_values.append(value)
+        del self.past_key_values
+
+        return past_keys, past_values
+
+    def attach_kv_cache(self, past_keys, past_values):
+        # Re-fuse the key and value halves into a single tensor per layer.
+        self.past_key_values = [
+            torch.cat((key, value), dim=-1)
+            for key, value in zip(past_keys, past_values)
+        ]
+
+
+class StarCoder(CausalLM):
+    def __init__(
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
+    ):
+        # Bypasses the runtime error "Empty tensor optional" seen with HPU graphs
+        os.environ["ENABLE_HPU_GRAPH"] = "false"
+        logger.warning("Disabling HPU graphs as they are not supported with the StarCoder model!")
+
+        super(StarCoder, self).__init__(
+            model_id=model_id,
+            revision=revision,
+            dtype=dtype,
+        )
+
+    @property
+    def batch_type(self) -> Type[CausalLMBatch]:
+        return StarCoderCausalLMBatch
\ No newline at end of file
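
Reviewer note: gpt_bigcode-style checkpoints use multi-query attention, so each
layer's key and value cache are kept fused in a single tensor along the last
dimension; detach_kv_cache/attach_kv_cache above round-trip that layout. Below
is a minimal standalone sketch of the same split/merge, assuming a per-layer
fused shape of [total_tokens, 2 * head_dim]; the values 2048 and 256 mirror the
shapes asserted in test_batch_concatenate, not anything read from the model
config.

    import torch

    # Hypothetical fused KV tensor for one layer: [total_tokens, 2 * head_dim].
    fused = torch.randn(2048, 512)

    # detach: split the fused tensor into its key and value halves.
    head_dim = fused.size(dim=-1) // 2
    key, value = fused.split((head_dim, head_dim), dim=-1)
    assert key.shape == value.shape == (2048, 256)

    # attach: re-fuse the (possibly updated) halves into one tensor per layer.
    refused = torch.cat((key, value), dim=-1)
    assert torch.equal(refused, fused)

Because the split is a view and the cat is a copy, detaching is cheap while
attaching allocates; the helpers keep that cost confined to batch bookkeeping.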