Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 73 additions & 30 deletions python/python/lance/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,26 +727,20 @@ def scanner(
for valid SQL expressions. Expression filter is applied to filtered scan,
full text search and vector search.

- VectorSearchQuery is a vector search that can only be applied to full
text search. Example:
.. code-block:: python

filter=VectorSearchQuery(
"vector",
np.array([12, 17, 300, 10], dtype=np.float32),
5,
20,
True,
)
- Search filter is a post-filter (query filter) that re-ranks results and
optionally removes rows.

- FullTextQuery is a full text search that can only be applied to vector
search. Example:
.. code-block:: python
* FullTextQuery re-ranks by BM25 score and can optionally apply
`score_threshold` (keep rows where `_score >= score_threshold`).
* VectorSearchQuery re-ranks by vector distance (top-k) and can
optionally apply `distance_range` (keep rows where
`lower_bound <= _distance < upper_bound`).

filter=PhraseQuery("hello world", "col")
Search filters can be used for any kind of scan, including non-search
scans.

- Dictionary is a combined filter containing both expression filter with
key `expr_filter` and search filter with key `search_filter`. Example:
key `expr_filter` and search filter with key `search_filter`. Example:
.. code-block:: python

scanner = ds.scanner(
Expand All @@ -759,7 +753,9 @@ def scanner(
},
filter={
"expr_filter": "category='geography'",
"search_filter": PhraseQuery("hello world", "col"),
"search_filter": MatchQuery("hello world", "col"),
# Optional; only used when "search_filter" is a FullTextQuery
"score_threshold": 0.0,
},
)
limit: int, default None
Expand Down Expand Up @@ -4900,7 +4896,13 @@ def filter(

search_filter = filter.get("search_filter")
if search_filter is not None:
self.filter(search_filter)
if isinstance(search_filter, FullTextQuery):
score_threshold = filter.get("score_threshold")
self._search_filter = PySearchFilter.from_full_text_query(
search_filter.inner, score_threshold
)
else:
self.filter(search_filter)

return self

Expand Down Expand Up @@ -6274,19 +6276,60 @@ def __init__(
refine_factor: Optional[int] = None,
use_index: bool = True,
ef: Optional[int] = None,
distance_range: Optional[tuple[Optional[float], Optional[float]]] = None,
):
self._inner = _build_vector_search_query(
column,
q,
k=k,
metric=metric,
nprobes=nprobes,
minimum_nprobes=minimum_nprobes,
maximum_nprobes=maximum_nprobes,
refine_factor=refine_factor,
use_index=use_index,
ef=ef,
)
q, q_dim = _coerce_query_vector(q)

if k is not None and int(k) <= 0:
raise ValueError(f"Nearest-K must be > 0 but got {k}")
if nprobes is not None and int(nprobes) <= 0:
raise ValueError(f"Nprobes must be > 0 but got {nprobes}")
if minimum_nprobes is not None and int(minimum_nprobes) < 0:
raise ValueError(f"Minimum nprobes must be >= 0 but got {minimum_nprobes}")
if maximum_nprobes is not None and int(maximum_nprobes) < 0:
raise ValueError(f"Maximum nprobes must be >= 0 but got {maximum_nprobes}")

if nprobes is not None:
if minimum_nprobes is not None or maximum_nprobes is not None:
raise ValueError(
"nprobes cannot be set in combination with minimum_nprobes or "
"maximum_nprobes"
)
else:
minimum_nprobes = nprobes
maximum_nprobes = nprobes
if (
minimum_nprobes is not None
and maximum_nprobes is not None
and minimum_nprobes > maximum_nprobes
):
raise ValueError("minimum_nprobes must be <= maximum_nprobes")
if refine_factor is not None and int(refine_factor) < 1:
raise ValueError(f"Refine factor must be 1 or more got {refine_factor}")
if ef is not None and int(ef) <= 0:
# `ef` should also be >= `k`, but `k` may be None at this point, so that
# check is left to the Rust code.
raise ValueError(f"ef must be > 0 but got {ef}")

if distance_range is not None:
if len(distance_range) != 2:
raise ValueError(
"distance_range must be a tuple of (lower_bound, upper_bound)"
)

self._inner = {
"column": column,
"q": q,
"k": k,
"metric": metric,
"minimum_nprobes": minimum_nprobes,
"maximum_nprobes": maximum_nprobes,
"refine_factor": refine_factor,
"use_index": use_index,
"ef": ef,
"distance_range": distance_range,
}

@property
def inner(self):
return self._inner
89 changes: 23 additions & 66 deletions python/python/tests/test_scalar_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import numpy as np
import pyarrow as pa
import pytest
from lance.dataset import VectorSearchQuery
from lance.indices import IndexConfig
from lance.query import (
BooleanQuery,
Expand Down Expand Up @@ -4375,7 +4376,7 @@ def test_describe_indices(tmp_path):
assert index.num_rows_indexed == 50


def test_vector_filter_fts_search(tmp_path):
def test_vector_post_filter_full_text_search(tmp_path):
# Create test data
ids = list(range(1, 301))
vectors = [[float(i)] * 4 for i in ids]
Expand Down Expand Up @@ -4421,74 +4422,30 @@ def test_vector_filter_fts_search(tmp_path):
)
ds.create_scalar_index("text", index_type="INVERTED", with_position=True)

# Create vector_query
vector_query = {
"column": "vector",
"q": np.array([300, 300, 300, 300], dtype=np.float32),
"k": 5,
"minimum_nprobes": 20,
"use_index": True,
}
# Base query: full text search
base_fts = MatchQuery("text", "text")

# Case 1: search with prefilter=true, query_filter=vector([300,300,300,300])
scanner = ds.scanner(
prefilter=False, nearest=vector_query, filter=MatchQuery("text", "text")
# Post-filter: vector search (re-rank by distance)
vector_post_filter = VectorSearchQuery(
"vector",
np.array([300, 300, 300, 300], dtype=np.float32),
k=5,
minimum_nprobes=20,
use_index=True,
)
result = scanner.to_table()
assert [300, 299] == result["id"].to_pylist()

# Case 2: search with prefilter=true, search_filter=match("text"),
# filter="category='geography'"
scanner = ds.scanner(
prefilter=True,
nearest=vector_query,
filter={
"expr_filter": "category='geography'",
"search_filter": MatchQuery("text", "text"),
},
)
result = scanner.to_table()
assert [300, 255, 252, 249, 246] == result["id"].to_pylist()
# Case 1: full text search + vector post-filter
result = ds.scanner(full_text_query=base_fts, filter=vector_post_filter).to_table()
assert [300, 299, 255, 254, 253] == result["id"].to_pylist()

# Case 3: search with prefilter=false, search_filter=match("text")
scanner = ds.scanner(
prefilter=False,
nearest=vector_query,
filter=MatchQuery("text", "text"),
# Case 2: full text search + vector post-filter + distance_range
vector_post_filter = VectorSearchQuery(
"vector",
np.array([300, 300, 300, 300], dtype=np.float32),
k=5,
minimum_nprobes=20,
use_index=True,
distance_range=(None, 5.0),
)
result = scanner.to_table()
result = ds.scanner(full_text_query=base_fts, filter=vector_post_filter).to_table()
assert [300, 299] == result["id"].to_pylist()

# Case 4: search with prefilter=false, search_filter=match("text"),
# filter="category='geography'"
scanner = ds.scanner(
prefilter=False,
nearest=vector_query,
filter={
"expr_filter": "category='geography'",
"search_filter": MatchQuery("text", "text"),
},
)
result = scanner.to_table()
assert [300] == result["id"].to_pylist()

# Case 5: search with prefilter=false, search_filter=phrase("text")
scanner = ds.scanner(
prefilter=False,
nearest=vector_query,
filter=PhraseQuery("text", "text"),
)
with pytest.raises(ValueError):
scanner.to_table()

# Case 6: search with prefilter=false, search_filter=phrase("text")
scanner = ds.scanner(
prefilter=False,
nearest=vector_query,
filter={
"expr_filter": "category='geography'",
"search_filter": PhraseQuery("text", "text"),
},
)
with pytest.raises(ValueError):
scanner.to_table()
86 changes: 14 additions & 72 deletions python/python/tests/test_vector_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from lance import LanceDataset, LanceFragment
from lance.dataset import Index, VectorIndexReader
from lance.indices import IndexFileVersion, IndicesBuilder
from lance.query import MatchQuery, PhraseQuery
from lance.query import MatchQuery
from lance.util import validate_vector_index # noqa: E402
from lance.vector import vec_to_table # noqa: E402

Expand Down Expand Up @@ -2792,7 +2792,7 @@ def collect_ids_and_distances(ds_with_index):
assert np.allclose(a, b, atol=1e-6)


def test_fts_filter_vector_search(tmp_path):
def test_fts_post_filter_vector_search(tmp_path):
# Create dataset with vector and text columns
ids = list(range(1, 301))
vectors = [[float(i)] * 4 for i in ids]
Expand Down Expand Up @@ -2838,78 +2838,20 @@ def test_fts_filter_vector_search(tmp_path):
dataset.create_scalar_index("text", index_type="INVERTED", with_position=True)

query_vector = [300.0, 300.0, 300.0, 300.0]
nearest = {"column": "vector", "q": query_vector, "k": 5}

# Case 1: search with prefilter=true, query_filter=match("text")
scanner = dataset.scanner(
filter=MatchQuery("text", "text"),
nearest={"column": "vector", "q": query_vector, "k": 5},
prefilter=True,
)

result = scanner.to_table()
ids_result = result["id"].to_pylist()
assert [300, 299, 255, 254, 253] == ids_result
# Case 1: vector search + FTS post-filter
result = dataset.scanner(
nearest=nearest, filter=MatchQuery("text", "text")
).to_table()
assert [299, 300, 296, 297, 298] == result["id"].to_pylist()

# Case 2: search with prefilter=true, search_filter=match("text"),
# filter="category='geography'"
scanner = dataset.scanner(
nearest={"column": "vector", "q": query_vector, "k": 5},
prefilter=True,
# Case 2: vector search + FTS post-filter + score_threshold
result = dataset.scanner(
nearest=nearest,
filter={
"expr_filter": "category='geography'",
"search_filter": MatchQuery("text", "text"),
"score_threshold": float("inf"),
},
)

result = scanner.to_table()
ids_result = result["id"].to_pylist()
assert [300, 255, 252, 249, 246] == ids_result

# Case 3: search with prefilter=false, search_filter=match("text")
scanner = dataset.scanner(
filter=MatchQuery("text", "text"),
nearest={"column": "vector", "q": query_vector, "k": 5},
prefilter=False,
)

result = scanner.to_table()
ids_result = result["id"].to_pylist()
assert [300, 299] == ids_result

# Case 4: search with prefilter=false, search_filter=match("text"),
# filter="category='geography'"
scanner = dataset.scanner(
nearest={"column": "vector", "q": query_vector, "k": 5},
prefilter=False,
filter={
"expr_filter": "category='geography'",
"search_filter": MatchQuery("text", "text"),
},
)

result = scanner.to_table()
ids_result = result["id"].to_pylist()
assert [300] == ids_result

# Case 5: search with prefilter=false, search_filter=phrase("text")
scanner = dataset.scanner(
nearest={"column": "vector", "q": query_vector, "k": 5},
prefilter=False,
filter=PhraseQuery("text", "text"),
)

with pytest.raises(ValueError):
scanner.to_table()

# Case 6: search with prefilter=false, search_filter=phrase("text")
scanner = dataset.scanner(
nearest={"column": "vector", "q": query_vector, "k": 5},
prefilter=False,
filter={
"expr_filter": "category='geography'",
"search_filter": PhraseQuery("text", "text"),
},
)

with pytest.raises(ValueError):
scanner.to_table()
).to_table()
assert [] == result["id"].to_pylist()
Loading
Loading