2 changes: 1 addition & 1 deletion examples/search/search_dna/search_dna_config.yaml
@@ -22,7 +22,7 @@ nodes:
     batch_size: 10
     save_output: true
     params:
-      data_sources: [ncbi] # data source for searcher, support: wikipedia, google, uniprot, ncbi, rnacentral
+      data_source: ncbi # data source for searcher, support: wikipedia, google, uniprot, ncbi, rnacentral
       ncbi_params:
         email: test@example.com # NCBI requires an email address
         tool: GraphGen # tool name for NCBI API
2 changes: 1 addition & 1 deletion examples/search/search_protein/search_protein_config.yaml
@@ -22,7 +22,7 @@ nodes:
     batch_size: 10
     save_output: true
     params:
-      data_sources: [uniprot] # data source for searcher, support: wikipedia, google, uniprot
+      data_source: uniprot # data source for searcher, support: wikipedia, google, uniprot
       uniprot_params:
         use_local_blast: true # whether to use local blast for uniprot search
         local_blast_db: /path/to/uniprot_sprot # format: /path/to/${RELEASE}/uniprot_sprot
2 changes: 1 addition & 1 deletion examples/search/search_rna/search_rna_config.yaml
@@ -22,7 +22,7 @@ nodes:
     batch_size: 10
     save_output: true
     params:
-      data_sources: [rnacentral] # data source for searcher, support: wikipedia, google, uniprot, ncbi, rnacentral
+      data_source: rnacentral # data source for searcher, support: wikipedia, google, uniprot, ncbi, rnacentral
       rnacentral_params:
         use_local_blast: true # whether to use local blast for RNA search
         local_blast_db: rnacentral_ensembl_gencode_YYYYMMDD/ensembl_gencode_YYYYMMDD # path to local BLAST database (without .nhr extension)
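Note: all three example configs change the plural list key `data_sources` to the singular string `data_source`, matching the operator rewrite below. A minimal sketch of reading the renamed key, assuming PyYAML and inferring the `nodes`/`params` layout from the YAML fragments above (the exact schema is an assumption):

```python
# Minimal sketch: load an updated example config and read the renamed key.
# Assumes PyYAML; the nodes -> params layout is inferred from the YAML
# fragments above and may not match the full schema.
import yaml

with open("examples/search/search_dna/search_dna_config.yaml") as f:
    config = yaml.safe_load(f)

for node in config.get("nodes", []):
    params = node.get("params") or {}
    if "data_source" in params:
        # After this change the value is a plain string such as "ncbi",
        # not a one-element list such as [ncbi].
        assert isinstance(params["data_source"], str)
```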
167 changes: 70 additions & 97 deletions graphgen/operators/search/search_service.py
@@ -1,9 +1,9 @@
 from functools import partial
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Optional, Tuple

 from graphgen.bases import BaseOperator
 from graphgen.common.init_storage import init_storage
-from graphgen.utils import compute_content_hash, logger, run_concurrent
+from graphgen.utils import logger, run_concurrent

 if TYPE_CHECKING:
     import pandas as pd
@@ -19,42 +19,47 @@ def __init__(
         self,
         working_dir: str = "cache",
         kv_backend: str = "rocksdb",
-        data_sources: list = None,
+        data_source: str = None,
         **kwargs,
     ):
-        super().__init__(working_dir=working_dir, op_name="search_service")
-        self.working_dir = working_dir
-        self.data_sources = data_sources or []
+        super().__init__(
+            working_dir=working_dir, kv_backend=kv_backend, op_name="search"
+        )
+        self.data_source = data_source
         self.kwargs = kwargs
-        self.search_storage = init_storage(
-            backend=kv_backend, working_dir=working_dir, namespace="search"
-        )
-        self.searchers = {}
+        self.searcher = None

-    def _init_searchers(self):
+    def _init_searcher(self):
         """
-        Initialize all searchers (deferred import to avoid circular imports).
+        Initialize the searcher (deferred import to avoid circular imports).
         """
-        for datasource in self.data_sources:
-            if datasource in self.searchers:
-                continue
-            if datasource == "uniprot":
-                from graphgen.models import UniProtSearch
+        if self.searcher is not None:
+            return

+        if not self.data_source:
+            logger.error("Data source not specified")
+            return

-                params = self.kwargs.get("uniprot_params", {})
-                self.searchers[datasource] = UniProtSearch(**params)
-            elif datasource == "ncbi":
-                from graphgen.models import NCBISearch
+        if self.data_source == "uniprot":
+            from graphgen.models import UniProtSearch

-                params = self.kwargs.get("ncbi_params", {})
-                self.searchers[datasource] = NCBISearch(**params)
-            elif datasource == "rnacentral":
-                from graphgen.models import RNACentralSearch
+            params = self.kwargs.get("uniprot_params", {})
+            self.searcher = UniProtSearch(**params)
+        elif self.data_source == "ncbi":
+            from graphgen.models import NCBISearch

-                params = self.kwargs.get("rnacentral_params", {})
-                self.searchers[datasource] = RNACentralSearch(**params)
-            else:
-                logger.error(f"Unknown data source: {datasource}, skipping")
+            params = self.kwargs.get("ncbi_params", {})
+            self.searcher = NCBISearch(**params)
+        elif self.data_source == "rnacentral":
+            from graphgen.models import RNACentralSearch

+            params = self.kwargs.get("rnacentral_params", {})
+            self.searcher = RNACentralSearch(**params)
+        else:
+            logger.error(f"Unknown data source: {self.data_source}")

     @staticmethod
     async def _perform_search(
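The rewritten `_init_searcher` is idempotent: the early `return` guard makes repeated calls cheap, and each searcher class is imported only inside its branch to avoid circular imports. Since every branch follows the same `<source>_params` naming convention from the configs, the dispatch can also be pictured as a table. The sketch below is an illustrative reduction, not GraphGen code; `SEARCHER_FACTORIES` and `build_searcher` are hypothetical names:

```python
# Illustrative reduction of the _init_searcher dispatch; not GraphGen code.
# Each factory defers its import until first use, mirroring the diff above.
def _make_uniprot(params):
    from graphgen.models import UniProtSearch
    return UniProtSearch(**params)

def _make_ncbi(params):
    from graphgen.models import NCBISearch
    return NCBISearch(**params)

def _make_rnacentral(params):
    from graphgen.models import RNACentralSearch
    return RNACentralSearch(**params)

SEARCHER_FACTORIES = {
    "uniprot": _make_uniprot,
    "ncbi": _make_ncbi,
    "rnacentral": _make_rnacentral,
}

def build_searcher(data_source: str, kwargs: dict):
    factory = SEARCHER_FACTORIES.get(data_source)
    if factory is None:
        raise ValueError(f"Unknown data source: {data_source}")
    # Per-source params follow the "<source>_params" key convention
    # visible in the example configs (uniprot_params, ncbi_params, ...).
    return factory(kwargs.get(f"{data_source}_params", {}))
```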
@@ -76,91 +81,59 @@ async def _perform_search(

         result = searcher_obj.search(query)
         if result:
-            result["_doc_id"] = compute_content_hash(str(data_source) + query, "doc-")
             result["data_source"] = data_source
             result["type"] = seed.get("type", "text")

         return result

-    def _process_single_source(
-        self, data_source: str, seed_data: list[dict]
-    ) -> list[dict]:
-        """
-        process a single data source: check cache, search missing, update cache.
+    def process(self, batch: list) -> Tuple[list, dict]:
         """
-        searcher = self.searchers[data_source]

-        seeds_with_ids = []
-        for seed in seed_data:
-            query = seed.get("content", "")
-            if not query:
-                continue
-            doc_id = compute_content_hash(str(data_source) + query, "doc-")
-            seeds_with_ids.append((doc_id, seed))

-        if not seeds_with_ids:
-            return []

-        doc_ids = [doc_id for doc_id, _ in seeds_with_ids]
-        cached_results = self.search_storage.get_by_ids(doc_ids)

-        to_search_seeds = []
-        final_results = []
+        Search for items in the batch using the configured data source.

-        for (doc_id, seed), cached in zip(seeds_with_ids, cached_results):
-            if cached is not None:
-                if "_doc_id" not in cached:
-                    cached["_doc_id"] = doc_id
-                final_results.append(cached)
-            else:
-                to_search_seeds.append(seed)

-        if to_search_seeds:
-            new_results = run_concurrent(
-                partial(
-                    self._perform_search, searcher_obj=searcher, data_source=data_source
-                ),
-                to_search_seeds,
-                desc=f"Searching {data_source} database",
-                unit="keyword",
-            )
-            new_results = [res for res in new_results if res is not None]

-            if new_results:
-                upsert_data = {res["_doc_id"]: res for res in new_results}
-                self.search_storage.upsert(upsert_data)
-                logger.info(
-                    f"Saved {len(upsert_data)} new results to {data_source} cache"
-                )

-            final_results.extend(new_results)

-        return final_results

-    def process(self, batch: "pd.DataFrame") -> "pd.DataFrame":
-        import pandas as pd

-        docs = batch.to_dict(orient="records")
+        :param batch: List of items with 'content' and '_trace_id' fields
+        :return: A tuple of (results, meta_updates)
+            results: A list of search results.
+            meta_updates: A dict mapping source IDs to lists of trace IDs for the search results.
+        """
+        self._init_searcher()

-        self._init_searchers()
+        if not self.searcher:
+            logger.error("Searcher not initialized")
+            return [], {}

-        seed_data = [doc for doc in docs if doc and "content" in doc]
+        # Filter seeds with valid content and _trace_id
+        seed_data = [
+            item for item in batch if item and "content" in item and "_trace_id" in item
+        ]

         if not seed_data:
             logger.warning("No valid seeds in batch")
-            return pd.DataFrame([])

-        all_results = []
+            return [], {}

+        # Perform concurrent searches
+        results = run_concurrent(
+            partial(
+                self._perform_search,
+                searcher_obj=self.searcher,
+                data_source=self.data_source,
+            ),
+            seed_data,
+            desc=f"Searching {self.data_source} database",
+            unit="keyword",
+        )

-        for data_source in self.data_sources:
-            if data_source not in self.searchers:
-                logger.error(f"Data source {data_source} not initialized, skipping")
+        # Filter out None results and add _trace_id from original seeds
+        final_results = []
+        meta_updates = {}
+        for result, seed in zip(results, seed_data):
+            if result is None:
                 continue
+            result["_trace_id"] = self.get_trace_id(result)
+            final_results.append(result)
+            # Map from source seed trace ID to search result trace ID
+            meta_updates.setdefault(seed["_trace_id"], []).append(result["_trace_id"])

-            source_results = self._process_single_source(data_source, seed_data)
-            all_results.extend(source_results)

-        if not all_results:
+        if not final_results:
             logger.warning("No search results generated for this batch")

-        return pd.DataFrame(all_results)
+        return final_results, meta_updates
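For orientation, a hedged usage sketch of the new `(results, meta_updates)` contract. The operator class name `SearchService` is an assumption (the class statement is not visible in this hunk), the seed values are made up, and the constructor kwargs follow the NCBI example config above:

```python
# Hypothetical driver for the new process() contract; seed values are made up
# and the SearchService class name is assumed, as the diff does not show it.
from graphgen.operators.search.search_service import SearchService

service = SearchService(
    working_dir="cache",
    data_source="ncbi",  # a single string now, not a list
    ncbi_params={"email": "test@example.com", "tool": "GraphGen"},
)

batch = [
    {"content": "BRCA1 human", "_trace_id": "seed-001", "type": "text"},
    {"content": "TP53 homo sapiens", "_trace_id": "seed-002", "type": "text"},
]

results, meta_updates = service.process(batch)
# results: one dict per successful search, each tagged with data_source,
# type, and a fresh _trace_id from self.get_trace_id(...).
# meta_updates: maps each seed's _trace_id to the trace IDs of the results
# derived from it, e.g. {"seed-001": ["<result-trace-id>"]}.
```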