[TextGeneration] Fix llama tokenizer (#1635) #1636

Merged
merged 3 commits on Mar 18, 2024
2 changes: 1 addition & 1 deletion setup.py
@@ -99,7 +99,7 @@ def _parse_requirements_file(file_path):
"black==22.12.0",
"flake8>=3.8.3",
"isort>=5.7.0",
"flaky~=3.7.0",
"pytest-rerunfailures>=13.0",
"ndjson>=0.3.1",
"wheel>=0.36.2",
"pytest>=6.0.0",
@@ -101,6 +101,7 @@ def run(
else [],
"finished_reason": [],
"token_generator": token_generator,
"past_tokens_queue": copy.copy(tokens),
}

if kv_cache is None:
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
from typing import Optional
from typing import List, Optional

import numpy

@@ -54,6 +54,33 @@ def _create_generated_text_output(
finished=False,
)

def _generate_streamed_text_from_past_tokens(
self, generated_tokens: numpy.ndarray, past_tokens_queue: List[int]
) -> List[str]:
"""
An auxiliary method that helps generate the streamed text properly.
Some models, such as Llama 2 and Mistral, use LlamaTokenizer, which is
based on the SentencePiece tokenizer. This tokenizer does not reliably
emit the appropriate prefix spaces when decoding token by token. It can
be made to work if the previously generated tokens are included in the
decode call, which lets the tokenizer infer the appropriate spaces from
the last n consecutive tokens.

:param generated_tokens: the generated tokens from the engine
:param past_tokens_queue: the queue of last n tokens (n is the
original prompt length in tokens)
:return: a single-element list containing the newly generated text
"""
string_from_n_tokens = self.tokenizer.decode(
past_tokens_queue, skip_special_tokens=True
)
past_tokens_queue.append(generated_tokens[0])
string_from_n_plus_1_tokens = self.tokenizer.decode(
past_tokens_queue, skip_special_tokens=True
)
past_tokens_queue.pop(0)
return [string_from_n_plus_1_tokens[len(string_from_n_tokens) :]]

def run(
self,
generated_tokens: numpy.ndarray,
@@ -64,9 +91,24 @@
):
generation_config = inference_state.current_state.get("generation_config")
generated_logits = generated_logits if generation_config.output_scores else None
sequences = self.tokenizer.batch_decode(
generated_tokens, skip_special_tokens=True
)

import transformers

# Fix for Llama-specific models when running streaming.
# TODO: make streaming a conditional input to this operator; using
# inference state is a quick fix.
if isinstance(
self.tokenizer,
(transformers.LlamaTokenizer, transformers.LlamaTokenizerFast),
) and inference_state.current_state.get("streaming"):
past_tokens_queue = inference_state.current_state.get("past_tokens_queue")
sequences = self._generate_streamed_text_from_past_tokens(
generated_tokens, past_tokens_queue
)
else:
sequences = self.tokenizer.batch_decode(
generated_tokens, skip_special_tokens=True
)

try:
finished_reason = [f[-1] for f in finished_reason]
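For readers unfamiliar with the SentencePiece behavior the new helper works around, here is a minimal standalone sketch of the same past-token-queue idea, assuming the transformers package is available; the checkpoint name, prompt, and function name are illustrative assumptions, not part of this PR.

from transformers import AutoTokenizer

# Illustrative public Llama tokenizer checkpoint (assumption, not from the PR)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")

prompt = "Streaming with SentencePiece"
past_tokens_queue = tokenizer(prompt)["input_ids"]  # seed the queue with the prompt tokens

def stream_decode(new_token_id):
    # Decode the queue before and after appending the new token; the string
    # difference is the newly generated text, including the prefix space that
    # a single-token decode would typically drop.
    before = tokenizer.decode(past_tokens_queue, skip_special_tokens=True)
    past_tokens_queue.append(new_token_id)
    after = tokenizer.decode(past_tokens_queue, skip_special_tokens=True)
    past_tokens_queue.pop(0)  # keep the queue at the original prompt length
    return after[len(before):]

Decoding each generated token in isolation tends to lose leading spaces with this tokenizer, which is why the diff of two decodes over the rolling queue is used instead.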
2 changes: 1 addition & 1 deletion src/deepsparse/version.py
@@ -39,7 +39,7 @@
from deepsparse.generated_version import is_enterprise, is_release, splash, version
except Exception:
# otherwise, fall back to version info in this file
version = "1.7.0"
version = "1.7.1"
is_release = False
is_enterprise = False
splash = (
3 changes: 1 addition & 2 deletions tests/deepsparse/pipelines/test_pipeline.py
@@ -16,7 +16,6 @@
from concurrent.futures import ThreadPoolExecutor
from unittest import mock

import flaky
import pytest
from deepsparse.legacy.base_pipeline import BasePipeline

@@ -125,7 +124,7 @@ def test_pipeline_executor_num_workers():
assert executor._max_workers >= 1


@flaky.flaky(max_runs=2, min_passes=1)
@pytest.mark.flaky(reruns=2, min_passes=1)
@mock_engine(rng_seed=0)
def test_pipeline_call_is_async(engine_mock):
# attempts to verify that pipeline calls to engine are async
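For context on the decorator swap above, a minimal sketch of how the pytest-rerunfailures marker is typically used, assuming pytest and pytest-rerunfailures are installed; the test body is purely illustrative and not part of this PR.

import random

import pytest

@pytest.mark.flaky(reruns=2)  # rerun this test up to two more times if it fails
def test_occasionally_flaky():
    # Illustrative nondeterministic assertion standing in for a genuinely flaky test
    assert random.random() < 0.9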
6 changes: 3 additions & 3 deletions tests/server/test_legacy_loggers.py
@@ -16,6 +16,7 @@
from collections import Counter
from unittest import mock

import pytest
from deepsparse.legacy.loggers import PythonLogger
from deepsparse.legacy.loggers.config import (
PipelineSystemLoggingConfig,
@@ -30,7 +31,6 @@
from deepsparse.server.deepsparse_server import DeepsparseServer
from deepsparse.server.helpers import server_logger_from_config
from fastapi.testclient import TestClient
from flaky import flaky
from tests.deepsparse.legacy.loggers.helpers import fetch_leaf_logger
from tests.helpers import find_free_port
from tests.test_data.server_test_data import SAMPLE_LOGS_DICT
@@ -106,7 +106,7 @@ def test_data_logging_from_predefined():
assert log == expected_log


@flaky(max_runs=4, min_passes=3)
@pytest.mark.flaky(reruns=4, min_passes=3)
def test_logging_only_system_info():
server_config = ServerConfig(
endpoints=[EndpointConfig(task=task, name=name, model=stub)],
@@ -195,7 +195,7 @@ def test_multiple_targets_logging():
)


@flaky(max_runs=3, min_passes=2)
@pytest.mark.flaky(reruns=3, min_passes=2)
def test_function_metric_with_target_loggers():
server_config = ServerConfig(
endpoints=[