28 changes: 19 additions & 9 deletions README.md
@@ -94,10 +94,11 @@ Usage: embeddings download-model [OPTIONS]
Download a model from HuggingFace and save locally.

Options:
--model-uri TEXT HuggingFace model URI (e.g., 'org/model-name')
[required]
--model-uri TEXT HuggingFace model URI (e.g., 'org/model-name'). Defaults
to env var TE_MODEL_URI if set. [required]
--model-path PATH Path where the model will be downloaded to and loaded
from, e.g. '/path/to/model'. [required]
from, e.g. '/path/to/model'. Defaults to env var
TE_MODEL_PATH if set. [required]
--help Show this message and exit.
```
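
Both model options now fall back to environment variables, so the command can be run without flags when `TE_MODEL_URI` and `TE_MODEL_PATH` are set. A minimal sketch using click's test runner (assumes `main` is the CLI group importable from `embeddings.cli`, as in the test suite below; the URI and path are placeholders):

```python
from click.testing import CliRunner

from embeddings.cli import main  # assumed import path

runner = CliRunner()

# Explicit flags...
runner.invoke(
    main,
    ["download-model", "--model-uri", "org/model-name", "--model-path", "/path/to/model"],
)

# ...or rely on the documented env var defaults.
runner.invoke(
    main,
    ["download-model"],
    env={"TE_MODEL_URI": "org/model-name", "TE_MODEL_PATH": "/path/to/model"},
)
```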

@@ -118,11 +119,13 @@ Usage: embeddings test-model-load [OPTIONS]
model loads correctly.

Options:
--model-uri TEXT HuggingFace model URI (e.g., 'org/model-name')
[required]
--model-uri TEXT HuggingFace model URI (e.g., 'org/model-name'). Defaults
to env var TE_MODEL_URI if set. [required]
--model-path PATH Path where the model will be downloaded to and loaded
from, e.g. '/path/to/model'. [required]
from, e.g. '/path/to/model'. Defaults to env var
TE_MODEL_PATH if set. [required]
--help Show this message and exit.

```

### `create-embeddings`
@@ -132,13 +135,18 @@ Usage: embeddings create-embeddings [OPTIONS]
Create embeddings for TIMDEX records.

Options:
--model-uri TEXT HuggingFace model URI (e.g., 'org/model-name')
--model-uri TEXT HuggingFace model URI (e.g., 'org/model-name').
Defaults to env var TE_MODEL_URI if set.
[required]
--model-path PATH Path where the model will be downloaded to and
loaded from, e.g. '/path/to/model'. [required]
loaded from, e.g. '/path/to/model'. Defaults to
env var TE_MODEL_PATH if set. [required]
--dataset-location PATH TIMDEX dataset location, e.g.
's3://timdex/dataset', to read records from.
--run-id TEXT TIMDEX ETL run id.
--run-id TEXT TIMDEX ETL run id. Mutually exclusive with
--source.
--source TEXT Retrieve current records from a TIMDEX source.
Mutually exclusive with --run-id.
--run-record-offset INTEGER TIMDEX ETL run record offset to start from,
default = 0.
--record-limit INTEGER Limit number of records after --run-record-
@@ -151,5 +159,7 @@ Options:
--output-jsonl TEXT Optionally write embeddings to local JSONLines
file (primarily for testing).
--batch-size INTEGER Number of embeddings to process per batch.
Defaults to env var EMBEDDING_BATCH_SIZE if
set.
--help Show this message and exit.
```
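
The new `--source` option and the existing `--run-id` option select mutually exclusive read modes: a run-scoped read of the `records` table, or a current-records read for a source. A minimal sketch of both invocations (assumes `main` from `embeddings.cli`; the dataset location, run id, and source are placeholders):

```python
from click.testing import CliRunner

from embeddings.cli import main  # assumed import path

runner = CliRunner()

common = [
    "create-embeddings",
    "--model-uri", "org/model-name",
    "--dataset-location", "s3://timdex/dataset",
    "--strategy", "full_record",
]

# Read records from a single ETL run...
runner.invoke(main, [*common, "--run-id", "run-1"])

# ...or read the current records for a source; passing both flags is a UsageError.
runner.invoke(main, [*common, "--source", "apples"])
```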
109 changes: 77 additions & 32 deletions embeddings/cli.py
@@ -73,7 +73,10 @@ def model_required(f: Callable) -> Callable:
"--model-uri",
envvar="TE_MODEL_URI",
required=True,
help="HuggingFace model URI (e.g., 'org/model-name')",
help=(
"HuggingFace model URI (e.g., 'org/model-name'). "
"Defaults to env var TE_MODEL_URI if set."
),
)
@click.option(
"--model-path",
@@ -82,7 +85,7 @@ def model_required(f: Callable) -> Callable:
type=click.Path(path_type=Path),
help=(
"Path where the model will be downloaded to and loaded from, "
"e.g. '/path/to/model'."
"e.g. '/path/to/model'. Defaults to env var TE_MODEL_PATH if set."
),
)
@functools.wraps(f)
@@ -171,7 +174,15 @@ def test_model_load(ctx: click.Context) -> None:
"--run-id",
required=False,
type=str,
help="TIMDEX ETL run id.",
help="TIMDEX ETL run id. Mutually exclusive with --source.",
)
@click.option(
"--source",
required=False,
type=str,
help=(
"Retrieve current records from a TIMDEX source. Mutually exclusive with --run-id."
),
)
@click.option(
"--run-record-offset",
@@ -220,62 +231,93 @@ def test_model_load(ctx: click.Context) -> None:
type=int,
default=100,
envvar="EMBEDDING_BATCH_SIZE",
help="Number of embeddings to process per batch.",
help=(
"Number of embeddings to process per batch. Defaults to env var "
"EMBEDDING_BATCH_SIZE if set."
),
)
def create_embeddings(
ctx: click.Context,
dataset_location: str,
run_id: str,
dataset_location: str | None,
run_id: str | None,
source: str | None,
run_record_offset: int,
record_limit: int,
input_jsonl: str,
record_limit: int | None,
input_jsonl: str | None,
strategy: list[str],
output_jsonl: str,
output_jsonl: str | None,
batch_size: int,
) -> None:
"""Create embeddings for TIMDEX records."""
model: BaseEmbeddingModel = ctx.obj["model"]
model.load()
timdex_dataset: TIMDEXDataset | None = None

# read input records from TIMDEX dataset (default) or a JSONLines file
# JSONLines input (primarily for testing)
if input_jsonl:
with (
smart_open.open(input_jsonl, "r") as file_obj, # type: ignore[no-untyped-call]
jsonlines.Reader(file_obj) as reader,
):
timdex_records = iter(list(reader))

# default: read from TIMDEX dataset
else:
if not dataset_location or not run_id:
if not dataset_location:
raise click.UsageError(
"Both '--dataset-location' and '--run-id' are required arguments "
"when reading input records from the TIMDEX dataset."
"'--dataset-location' is required when reading input records from "
"the TIMDEX dataset."
)

if run_id and source:
raise click.UsageError("Use either '--run-id' or '--source', not both.")
if not run_id and not source:
raise click.UsageError(
"One of '--run-id' or '--source' is required when reading "
"input records from the TIMDEX dataset."
)

# init TIMDEXDataset
timdex_dataset = TIMDEXDataset(dataset_location)

# query TIMDEX dataset for an iterator of records
timdex_records = timdex_dataset.read_dicts_iter(
columns=[
"timdex_record_id",
"run_id",
"run_record_offset",
"transformed_record",
],
run_id=run_id,
where=f"""run_record_offset >= {run_record_offset}""",
limit=record_limit,
action="index",
)
# get ETL run records
if run_id:
timdex_records = timdex_dataset.read_dicts_iter(
table="records",
columns=[
"timdex_record_id",
"run_id",
"run_record_offset",
"transformed_record",
],
run_id=run_id,
where=f"run_record_offset >= {run_record_offset}",
limit=record_limit,
action="index",
)

# get current records for a source
else:
timdex_records = timdex_dataset.read_dicts_iter(
table="current_records",
columns=[
"timdex_record_id",
"run_id",
"run_record_offset",
"transformed_record",
],
source=source,
where=f"run_record_offset >= {run_record_offset}",
limit=record_limit,
action="index",
)

# create an iterator of EmbeddingInputs applying all requested strategies
embedding_inputs = create_embedding_inputs(timdex_records, list(strategy))

# create embeddings via the embedding model
embeddings = model.create_embeddings(embedding_inputs, batch_size=batch_size)

# write embeddings to TIMDEX dataset (default) or to a JSONLines file
# JSONLines output (primarily for testing)
if output_jsonl:
with (
smart_open.open(output_jsonl, "w") as s3_file, # type: ignore[no-untyped-call]
@@ -286,13 +328,16 @@ def create_embeddings(
):
for embedding in embeddings:
writer.write(embedding.to_dict())
else:
if not timdex_dataset:
# if input_jsonl, init TIMDEXDataset
timdex_dataset = TIMDEXDataset(dataset_location)
logger.info(f"Embeddings written to JSONLines file: {output_jsonl}")

# default: write to TIMDEX dataset
elif timdex_dataset:
timdex_dataset.embeddings.write(_dataset_embedding_iter(embeddings))
logger.info("Embeddings written to TIMDEX dataset.")

else:
logger.warning("No output destination specified for embeddings")

logger.info("Embeddings creation complete.")


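When records are read from the TIMDEX dataset (i.e. no `--input-jsonl`), the command enforces exactly one read mode. A minimal sketch of that validation, factored into a hypothetical helper that mirrors the checks above:

```python
# Hypothetical helper (not part of the PR) mirroring the read-mode validation
# performed inside create_embeddings.
import click


def validate_read_mode(
    dataset_location: str | None,
    run_id: str | None,
    source: str | None,
) -> None:
    """Raise click.UsageError unless exactly one dataset read mode is selected."""
    if not dataset_location:
        raise click.UsageError(
            "'--dataset-location' is required when reading input records from "
            "the TIMDEX dataset."
        )
    if run_id and source:
        raise click.UsageError("Use either '--run-id' or '--source', not both.")
    if not run_id and not source:
        raise click.UsageError(
            "One of '--run-id' or '--source' is required when reading "
            "input records from the TIMDEX dataset."
        )
```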
9 changes: 8 additions & 1 deletion embeddings/models/os_neural_sparse_doc_v3_gte.py
@@ -31,6 +31,8 @@ class OSNeuralSparseDocV3GTE(BaseEmbeddingModel):
"""

MODEL_URI = "opensearch-project/opensearch-neural-sparse-encoding-doc-v3-gte"
OPENSEARCH_MODEL_REVISION = "1646fef40807937e8e130c66d327a26421c408d5"
ALIBABA_NEW_IMPL_REVISION = "40ced75c3017eb27626c9d4ea981bde21a2662f4"

def __init__(self, model_path: str | Path) -> None:
"""Initialize the model.
@@ -56,7 +58,11 @@ def download(self) -> Path:
temp_path = Path(temp_dir)

# download snapshot of HuggingFace model
snapshot_download(repo_id=self.model_uri, local_dir=temp_path)
snapshot_download(
repo_id=self.model_uri,
local_dir=temp_path,
revision=self.OPENSEARCH_MODEL_REVISION,
)
logger.debug("Model download complete.")

# patch local model with files from dependency model "Alibaba-NLP/new-impl"
@@ -100,6 +106,7 @@ def _patch_local_model_with_alibaba_new_impl(self, model_temp_path: Path) -> None:
snapshot_download(
repo_id="Alibaba-NLP/new-impl",
local_dir=str(temp_path),
revision=self.ALIBABA_NEW_IMPL_REVISION,
)

logger.info("Copying Alibaba code and updating config.json")
1 change: 1 addition & 0 deletions pyproject.toml
@@ -15,6 +15,7 @@ dependencies = [
"sentry-sdk>=2.34.1",
"smart-open[s3]>=7.4.4",
"timdex-dataset-api",
"transformers>=4.57.6,<5.0.0",
]

[dependency-groups]
59 changes: 56 additions & 3 deletions tests/test_cli.py
@@ -79,8 +79,9 @@ def test_model_required_decorator_help_flag_early_exit(runner):
assert result.exit_code == 0


def test_model_required_decorator_missing_parameter(runner):
def test_model_required_decorator_missing_parameter(monkeypatch, runner):
"""Test decorator fails when --model-uri is not provided and env var is not set."""
monkeypatch.delenv("TE_MODEL_URI", raising=False)
result = runner.invoke(main, ["download-model", "--model-path", "out.zip"])

assert result.exit_code != 0
@@ -190,6 +191,37 @@ def test_create_embeddings_writes_to_timdex_dataset(
assert isinstance(embedding_row.embedding_vector, np.ndarray)


def test_create_embeddings_writes_to_timdex_dataset_by_source(
caplog,
runner,
dataset_with_records,
register_mock_model,
):
caplog.set_level("DEBUG")

result = runner.invoke(
main,
[
"--verbose",
"create-embeddings",
"--model-uri",
"test/mock-model",
"--dataset-location",
dataset_with_records.location,
"--source",
"apples",
"--strategy",
"full_record",
],
)

assert result.exit_code == 0

timdex_dataset = TIMDEXDataset(location=dataset_with_records.location)
embeddings_df = timdex_dataset.embeddings.read_dataframe(run_id="run-1")
assert len(embeddings_df) == 2


def test_create_embeddings_requires_strategy(register_mock_model, runner):
result = runner.invoke(
main,
@@ -221,7 +253,7 @@ def test_create_embeddings_requires_dataset_location(register_mock_model, runner):
],
)
assert result.exit_code != 0
assert "Both '--dataset-location' and '--run-id' are required" in result.output
assert "'--dataset-location' is required" in result.output


def test_create_embeddings_requires_run_id(register_mock_model, runner):
@@ -238,7 +270,28 @@ def test_create_embeddings_requires_run_id(register_mock_model, runner):
],
)
assert result.exit_code != 0
assert "Both '--dataset-location' and '--run-id' are required" in result.output
assert "One of '--run-id' or '--source' is required" in result.output


def test_create_embeddings_requires_single_read_mode(register_mock_model, runner):
result = runner.invoke(
main,
[
"create-embeddings",
"--model-uri",
"test/mock-model",
"--dataset-location",
"s3://test",
"--run-id",
"run-1",
"--source",
"apples",
"--strategy",
"full_record",
],
)
assert result.exit_code != 0
assert "Use either '--run-id' or '--source', not both." in result.output


def test_create_embeddings_optional_input_jsonl(register_mock_model, runner, tmp_path):
Expand Down