1 change: 1 addition & 0 deletions README.md
@@ -127,5 +127,6 @@ Options:
[required]
--output-jsonl TEXT Optionally write embeddings to local JSONLines
file (primarily for testing).
--batch-size INTEGER Number of embeddings to process per batch.
--help Show this message and exit.
```
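For context, a hedged sketch of exercising the new `--batch-size` option from Python via Click's test runner, mirroring the pattern used in `tests/test_cli.py`; the `create-embeddings` command name, the `embeddings.cli.main` import path, and the omission of other required options are assumptions for illustration only:

```python
from click.testing import CliRunner

from embeddings.cli import main  # assumed import path for the CLI entrypoint

runner = CliRunner()
# Hypothetical invocation: other required options (model URI/path, dataset
# location, etc.) are omitted here and would need to be supplied in a real run.
result = runner.invoke(main, ["create-embeddings", "--batch-size", "50"])
print(result.exit_code, result.output)
```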
16 changes: 15 additions & 1 deletion embeddings/cli.py
@@ -87,6 +87,10 @@ def model_required(f: Callable) -> Callable:
)
@functools.wraps(f)
def wrapper(*args: tuple, **kwargs: dict[str, str | Path]) -> Callable:
# early exit if --help passed
if "help" in kwargs:
return f(*args, **kwargs) # pragma: nocover

# pop "model_uri" and "model_path" from CLI args
model_uri: str = str(kwargs.pop("model_uri"))
model_path: str | Path = str(kwargs.pop("model_path"))
@@ -210,6 +214,14 @@ def test_model_load(ctx: click.Context) -> None:
default=None,
help="Optionally write embeddings to local JSONLines file (primarily for testing).",
)
@click.option(
"--batch-size",
required=False,
type=int,
default=100,
envvar="EMBEDDING_BATCH_SIZE",
help="Number of embeddings to process per batch.",
)
def create_embeddings(
ctx: click.Context,
dataset_location: str,
@@ -219,6 +231,7 @@ def create_embeddings(
input_jsonl: str,
strategy: list[str],
output_jsonl: str,
batch_size: int,
) -> None:
"""Create embeddings for TIMDEX records."""
model: BaseEmbeddingModel = ctx.obj["model"]
@@ -260,7 +273,7 @@ def create_embeddings(
embedding_inputs = create_embedding_inputs(timdex_records, list(strategy))

# create embeddings via the embedding model
embeddings = model.create_embeddings(embedding_inputs)
embeddings = model.create_embeddings(embedding_inputs, batch_size=batch_size)

# write embeddings to TIMDEX dataset (default) or to a JSONLines file
if output_jsonl:
@@ -278,6 +291,7 @@ def create_embeddings(
# if input_jsonl, init TIMDEXDataset
timdex_dataset = TIMDEXDataset(dataset_location)
timdex_dataset.embeddings.write(_dataset_embedding_iter(embeddings))
logger.info("Embeddings written to TIMDEX dataset.")

logger.info("Embeddings creation complete.")

5 changes: 4 additions & 1 deletion embeddings/models/base.py
@@ -60,10 +60,13 @@ def create_embedding(self, embedding_input: EmbeddingInput) -> Embedding:

@abstractmethod
def create_embeddings(
self, embedding_inputs: Iterator[EmbeddingInput]
self,
embedding_inputs: Iterator[EmbeddingInput],
batch_size: int = 100,
) -> Iterator[Embedding]:
"""Yield Embeddings for multiple EmbeddingInputs.

Args:
embedding_inputs: iterator of EmbeddingInputs
batch_size: number of inputs to process per batch
"""
82 changes: 51 additions & 31 deletions embeddings/models/os_neural_sparse_doc_v3_gte.py
@@ -7,6 +7,7 @@
import tempfile
import time
from collections.abc import Iterator
from itertools import batched
from pathlib import Path
from typing import cast

@@ -163,6 +164,7 @@ def create_embedding(self, embedding_input: EmbeddingInput) -> Embedding:
def create_embeddings(
self,
embedding_inputs: Iterator[EmbeddingInput],
batch_size: int = 100,
) -> Iterator[Embedding]:
"""Yield Embeddings for multiple EmbeddingInputs.

@@ -173,25 +175,15 @@ def create_embeddings(
due to a "Bus Error". It is recommended to omit the env var TE_NUM_WORKERS, or
set to "1", in Docker contexts.

Currently, we also fully consume the input EmbeddingInputs before we start
embedding work. This may change in future iterations if we move to batching
embedding creation, so until then it's assumed that inputs to this method are
memory safe for the full run.
Embeddings are computed in batches to manage memory pressure. For this model
specifically, testing has shown that larger batches do not increase performance.

Args:
embedding_inputs: iterator of EmbeddingInputs
batch_size: number of inputs to process per batch
"""
# consume input EmbeddingInputs
embedding_inputs_list = list(embedding_inputs)
if not embedding_inputs_list:
return

# extract texts from all inputs
texts = [embedding_input.text for embedding_input in embedding_inputs_list]

# read env vars for configurations
num_workers = int(os.getenv("TE_NUM_WORKERS", "1"))
batch_size = int(os.getenv("TE_BATCH_SIZE", "32"))
te_batch_size = int(os.getenv("TE_BATCH_SIZE", "4"))
chunk_size_env = os.getenv("TE_CHUNK_SIZE")
chunk_size = int(chunk_size_env) if chunk_size_env else None

@@ -205,27 +197,55 @@ def create_embeddings(
device = self.device
pool = None
logger.info(
f"Num workers: {num_workers}, batch size: {batch_size}, "
f"chunk size: {chunk_size, }device: {device}, pool: {pool}"
f"Num workers: {num_workers}, application batch size: {batch_size}, "
f"model batch size: {te_batch_size}, device: {device}, pool: {pool}"
)

# get sparse vector embedding for input text(s)
inference_start = time.perf_counter()
sparse_vectors = self._model.encode_document(
texts,
batch_size=batch_size,
device=device,
pool=pool,
save_to_cpu=True,
chunk_size=chunk_size,
)
logger.info(f"Inference elapsed: {time.perf_counter()-inference_start}s")
sparse_vectors = cast("list[Tensor]", sparse_vectors)
batch_index = 0
try:
# create embeddings in batches
for embedding_inputs_batch in batched(embedding_inputs, batch_size):
batch_index += 1
batch_start = time.perf_counter()
texts = [
embedding_input.text for embedding_input in embedding_inputs_batch
]

# perform inference resulting in sparse vectors
sparse_vectors = self._model.encode_document(
texts,
batch_size=te_batch_size,
device=device,
pool=pool,
save_to_cpu=True,
chunk_size=chunk_size,
)
sparse_vectors = cast("list[Tensor]", sparse_vectors)
batch_elapsed = time.perf_counter() - batch_start
records_per_second = (
len(embedding_inputs_batch) / batch_elapsed
if batch_elapsed > 0
else 0.0
)
logger.debug(
f"Embeddings batch {batch_index}: "
f"{len(embedding_inputs_batch)} records, "
f"elapsed: {batch_elapsed:.2f}s, "
f"records/sec: {records_per_second:.2f}"
)

for i, embedding_input in enumerate(embedding_inputs_list):
sparse_vector = sparse_vectors[i]
sparse_vector = cast("Tensor", sparse_vector)
yield self._get_embedding_from_sparse_vector(embedding_input, sparse_vector)
# yield Embedding instances for batch
for i, embedding_input in enumerate(embedding_inputs_batch):
sparse_vector = sparse_vectors[i]
sparse_vector = cast("Tensor", sparse_vector)
yield self._get_embedding_from_sparse_vector(
embedding_input, sparse_vector
)
finally:
if pool is not None:
self._model.stop_multi_process_pool(pool)
logger.info(f"Inference elapsed: {time.perf_counter() - inference_start}s")

def _get_embedding_from_sparse_vector(
self,
11 changes: 8 additions & 3 deletions tests/conftest.py
@@ -2,6 +2,7 @@
import logging
import zipfile
from collections.abc import Iterator
from itertools import batched
from pathlib import Path

import pytest
@@ -60,10 +61,14 @@ def create_embedding(self, embedding_input: EmbeddingInput) -> Embedding:
)

def create_embeddings(
self, embedding_inputs: Iterator[EmbeddingInput]
self,
embedding_inputs: Iterator[EmbeddingInput],
batch_size: int = 100,
) -> Iterator[Embedding]:
for embedding_input in embedding_inputs:
yield self.create_embedding(embedding_input)
for embedding_inputs_batch in batched(embedding_inputs, batch_size):
logger.debug(f"Processing batch of {len(embedding_inputs_batch)} inputs")
for embedding_input in embedding_inputs_batch:
yield self.create_embedding(embedding_input)
Comment on lines +68 to +71 (Collaborator, Author):
This update to the testing model fixture shows very simply and directly how batching is applied.
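For readers less familiar with it, `itertools.batched` (new in Python 3.12) lazily groups any iterable into tuples of at most `n` items, which is what both this fixture and the real model rely on. A small illustrative sketch with made-up values:

```python
from itertools import batched

inputs = [f"record-{i}" for i in range(5)]

# batched() yields tuples of up to 2 items; the final tuple may be shorter
for batch in batched(inputs, 2):
    print(batch)

# ('record-0', 'record-1')
# ('record-2', 'record-3')
# ('record-4',)
```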



@pytest.fixture
7 changes: 7 additions & 0 deletions tests/test_cli.py
@@ -72,6 +72,13 @@ def test_model_required_decorator_with_env_var(
assert output_path.exists()


def test_model_required_decorator_help_flag_early_exit(runner):
"""Ensure that passing --help does NOT require model information."""
result = runner.invoke(main, ["download-model", "--help"])

assert result.exit_code == 0


def test_model_required_decorator_missing_parameter(runner):
"""Test decorator fails when --model-uri is not provided and env var is not set."""
result = runner.invoke(main, ["download-model", "--model-path", "out.zip"])
46 changes: 46 additions & 0 deletions tests/test_models.py
@@ -1,3 +1,4 @@
import logging
import zipfile

import pytest
@@ -53,6 +54,51 @@ def test_mock_model_create_embedding(mock_model):
assert embedding.embedding_token_weights == {"coffee": 0.9, "seattle": 0.5}


def test_mock_model_create_embeddings_embeds_all_records(mock_model):
embedding_inputs = [
EmbeddingInput(
timdex_record_id=f"test-id-{i}",
run_id="test-run",
run_record_offset=i,
embedding_strategy="full_record",
text=f"test text {i}",
)
for i in range(5)
]

embeddings = list(mock_model.create_embeddings(iter(embedding_inputs), batch_size=2))

assert len(embeddings) == len(embedding_inputs)
assert [e.timdex_record_id for e in embeddings] == [
inp.timdex_record_id for inp in embedding_inputs
]


def test_mock_model_create_embeddings_processes_in_batches(mock_model, caplog):
embedding_inputs = [
EmbeddingInput(
timdex_record_id=f"test-id-{i}",
run_id="test-run",
run_record_offset=i,
embedding_strategy="full_record",
text=f"test text {i}",
)
for i in range(5)
]

with caplog.at_level(logging.DEBUG):
list(mock_model.create_embeddings(iter(embedding_inputs), batch_size=2))

batch_logs = [r for r in caplog.records if "Processing batch" in r.message]
assert len(batch_logs) == 3 # 5 items with batch_size=2 = 3 batches


def test_mock_model_create_embeddings_empty_iterator(mock_model):
embeddings = list(mock_model.create_embeddings(iter([]), batch_size=2))

assert embeddings == []


def test_registry_contains_opensearch_model():
assert (
"opensearch-project/opensearch-neural-sparse-encoding-doc-v3-gte"