Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions embeddings/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import click
import jsonlines
import smart_open
from timdex_dataset_api import DatasetEmbedding, TIMDEXDataset, TIMDEXEmbeddings
from timdex_dataset_api import DatasetEmbedding, TIMDEXDataset

from embeddings.config import configure_logger, configure_sentry
from embeddings.models.base import Embedding
Expand Down Expand Up @@ -277,8 +277,7 @@ def create_embeddings(
if not timdex_dataset:
# if input_jsonl, init TIMDEXDataset
timdex_dataset = TIMDEXDataset(dataset_location)
timdex_embeddings = TIMDEXEmbeddings(timdex_dataset)
timdex_embeddings.write(_dataset_embedding_iter(embeddings))
timdex_dataset.embeddings.write(_dataset_embedding_iter(embeddings))

logger.info("Embeddings creation complete.")

Expand All @@ -296,4 +295,5 @@ def _dataset_embedding_iter(
embedding_strategy=embedding.embedding_strategy,
embedding_vector=embedding.embedding_vector,
embedding_object=json.dumps(embedding.embedding_token_weights).encode(),
embedding_timestamp=embedding.embedding_timestamp.isoformat(),
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is probably the most important change: the renamed column embedding_timestamp is now set explicitly on write. Formerly we weren't setting a timestamp at all and instead relied on the defaults in the timdex-dataset-api (TDA) DatasetEmbedding class; setting it here is more explicit.

)
2 changes: 1 addition & 1 deletion embeddings/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class Embedding:
embedding_vector: list[float] | None
embedding_token_weights: dict | None

timestamp: datetime.datetime = field(
embedding_timestamp: datetime.datetime = field(
default_factory=lambda: datetime.datetime.now(datetime.UTC)
)

Expand Down
49 changes: 49 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

import pytest
from click.testing import CliRunner
from timdex_dataset_api import TIMDEXDataset
from timdex_dataset_api.record import DatasetRecord

from embeddings.embedding import Embedding, EmbeddingInput
from embeddings.models import registry
Expand Down Expand Up @@ -111,3 +113,50 @@ def mock_snapshot(repo_id, local_dir, **kwargs):
"embeddings.models.os_neural_sparse_doc_v3_gte.snapshot_download", mock_snapshot
)
return mock_snapshot


@pytest.fixture
def dataset_with_records(tmp_path) -> TIMDEXDataset:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This new fixture gives us a real dataset with real records, so embeddings can be written against those records, which in turn supports the read methods.

dataset_path = tmp_path / "dataset"

# Two minimal records for source "apples", both belonging to run "run-1".
# Each record gets its own title so the two rows are distinguishable when
# inspecting what was written (the second record previously reused the
# title "Apple 1").
# NOTE(review): wrapped in iter(...) presumably because TIMDEXDataset.write
# consumes an iterator of records — confirm against timdex_dataset_api.
records = iter(
    [
        DatasetRecord(
            timdex_record_id="apple:1",
            source="apples",
            run_id="run-1",
            run_record_offset=0,
            run_date="2025-12-16",
            run_timestamp="2025-12-16T00:00:00",
            run_type="full",
            source_record=b"",
            transformed_record=(
                b"""{"title":"Apple 1","description":"""
                b""""This is a tale about apples."}"""
            ),
            action="index",
        ),
        DatasetRecord(
            timdex_record_id="apple:2",
            source="apples",
            run_id="run-1",
            run_record_offset=1,
            run_date="2025-12-16",
            run_timestamp="2025-12-16T00:00:00",
            run_type="full",
            source_record=b"",
            transformed_record=(
                b"""{"title":"Apple 2","description":"""
                b""""This is a tale about apples."}"""
            ),
            action="index",
        ),
    ]
)

timdex_dataset = TIMDEXDataset(str(dataset_path))
timdex_dataset.write(records)
timdex_dataset.metadata.rebuild_dataset_metadata()

# reload and return dataset
return TIMDEXDataset(str(dataset_path))
48 changes: 26 additions & 22 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import json
from pathlib import Path
from unittest.mock import patch

import numpy as np
from timdex_dataset_api import TIMDEXDataset
from timdex_dataset_api.embeddings import TIMDEXEmbeddings

from embeddings.cli import main

Expand Down Expand Up @@ -139,44 +139,48 @@ def test_model_required_decorator_works_across_commands(
assert "OK" in result.output


@patch("timdex_dataset_api.TIMDEXDataset.read_dicts_iter")
def test_create_embeddings_writes_to_timdex_dataset(
mock_timdex_dataset_read_dicts_iter, register_mock_model, runner, tmp_path
caplog,
runner,
dataset_with_records,
register_mock_model,
):
mock_timdex_dataset_read_dicts_iter.return_value = iter(
[
{
"timdex_record_id": "record:1",
"run_id": "run-1",
"run_record_offset": 0,
"transformed_record": '{"title":"Record 1","description":"This is a record about coffee in the mountains."}', # noqa: E501
}
]
)

# init TIMDEX Dataset and Embeddings
timdex_dataset = TIMDEXDataset(location=str(tmp_path / "dataset"))
timdex_embeddings = TIMDEXEmbeddings(timdex_dataset)
caplog.set_level("DEBUG")

result = runner.invoke(
main,
[
"--verbose",
"create-embeddings",
"--model-uri",
"test/mock-model",
"--dataset-location",
str(tmp_path / "dataset"),
dataset_with_records.location,
"--run-id",
"run-1",
"--strategy",
"full_record",
],
)

# TODO @jonavellecuerdo: Update to use TIMDEXEmbeddings # noqa: FIX002
# read method when ready
# assert CLI logged and exited cleanly
assert result.exit_code == 0
assert Path(timdex_embeddings.data_embeddings_root).exists()
assert "total files: 1, total rows: 2" in caplog.text

# reload temp TIMDEXDataset post embeddings write
timdex_dataset = TIMDEXDataset(location=dataset_with_records.location)

# assert embeddings written
assert Path(timdex_dataset.embeddings.data_embeddings_root).exists()
embeddings_df = timdex_dataset.embeddings.read_dataframe(run_id="run-1")
assert len(embeddings_df) == 2

# assert embedding row structure
embedding_row = embeddings_df.iloc[0]
assert embedding_row.embedding_model == "test/mock-model"
assert embedding_row.embedding_strategy == "full_record"
assert isinstance(json.loads(embedding_row.embedding_object), dict)
assert isinstance(embedding_row.embedding_vector, np.ndarray)


def test_create_embeddings_requires_strategy(register_mock_model, runner):
Expand Down
Loading