Commit c35c4f8

fix: update search_service (#174)
* fix: update search_service
* refactor: refactor search_service
1 parent 9facd1f commit c35c4f8

4 files changed (+73, -100 lines)

examples/search/search_dna/search_dna_config.yaml

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@ nodes:
       batch_size: 10
       save_output: true
       params:
-        data_sources: [ncbi] # data source for searcher, support: wikipedia, google, uniprot, ncbi, rnacentral
+        data_source: ncbi # data source for searcher, support: wikipedia, google, uniprot, ncbi, rnacentral
         ncbi_params:
           email: test@example.com # NCBI requires an email address
           tool: GraphGen # tool name for NCBI API
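
This commit replaces the `data_sources` list with a single `data_source` string in all three example configs, so each search node now targets exactly one database. A minimal sketch of how the new params block maps onto the operator constructor shown in the Python diff at the end of this commit; the `SearchService` class name and its import path are assumptions for illustration, not part of the commit:

    # Hypothetical wiring: class name and import path are assumed.
    # Keyword arguments mirror the refactored __init__ signature;
    # ncbi_params is forwarded through **kwargs and read via kwargs.get().
    from graphgen.operators import SearchService  # assumed import path

    service = SearchService(
        working_dir="cache",
        kv_backend="rocksdb",
        data_source="ncbi",  # previously: data_sources: [ncbi]
        ncbi_params={"email": "test@example.com", "tool": "GraphGen"},
    )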

examples/search/search_protein/search_protein_config.yaml

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@ nodes:
       batch_size: 10
       save_output: true
       params:
-        data_sources: [uniprot] # data source for searcher, support: wikipedia, google, uniprot
+        data_source: uniprot # data source for searcher, support: wikipedia, google, uniprot
         uniprot_params:
           use_local_blast: true # whether to use local blast for uniprot search
           local_blast_db: /path/to/uniprot_sprot # format: /path/to/${RELEASE}/uniprot_sprot

examples/search/search_rna/search_rna_config.yaml

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@ nodes:
       batch_size: 10
       save_output: true
       params:
-        data_sources: [rnacentral] # data source for searcher, support: wikipedia, google, uniprot, ncbi, rnacentral
+        data_source: rnacentral # data source for searcher, support: wikipedia, google, uniprot, ncbi, rnacentral
         rnacentral_params:
           use_local_blast: true # whether to use local blast for RNA search
           local_blast_db: rnacentral_ensembl_gencode_YYYYMMDD/ensembl_gencode_YYYYMMDD # path to local BLAST database (without .nhr extension)
Lines changed: 70 additions & 97 deletions
@@ -1,9 +1,9 @@
 from functools import partial
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Optional, Tuple
 
 from graphgen.bases import BaseOperator
 from graphgen.common.init_storage import init_storage
-from graphgen.utils import compute_content_hash, logger, run_concurrent
+from graphgen.utils import logger, run_concurrent
 
 if TYPE_CHECKING:
     import pandas as pd
@@ -19,42 +19,47 @@ def __init__(
         self,
         working_dir: str = "cache",
         kv_backend: str = "rocksdb",
-        data_sources: list = None,
+        data_source: str = None,
         **kwargs,
     ):
-        super().__init__(working_dir=working_dir, op_name="search_service")
-        self.working_dir = working_dir
-        self.data_sources = data_sources or []
+        super().__init__(
+            working_dir=working_dir, kv_backend=kv_backend, op_name="search"
+        )
+        self.data_source = data_source
         self.kwargs = kwargs
         self.search_storage = init_storage(
             backend=kv_backend, working_dir=working_dir, namespace="search"
         )
-        self.searchers = {}
+        self.searcher = None
 
-    def _init_searchers(self):
+    def _init_searcher(self):
         """
-        Initialize all searchers (deferred import to avoid circular imports).
+        Initialize the searcher (deferred import to avoid circular imports).
         """
-        for datasource in self.data_sources:
-            if datasource in self.searchers:
-                continue
-            if datasource == "uniprot":
-                from graphgen.models import UniProtSearch
+        if self.searcher is not None:
+            return
+
+        if not self.data_source:
+            logger.error("Data source not specified")
+            return
 
-                params = self.kwargs.get("uniprot_params", {})
-                self.searchers[datasource] = UniProtSearch(**params)
-            elif datasource == "ncbi":
-                from graphgen.models import NCBISearch
+        if self.data_source == "uniprot":
+            from graphgen.models import UniProtSearch
 
-                params = self.kwargs.get("ncbi_params", {})
-                self.searchers[datasource] = NCBISearch(**params)
-            elif datasource == "rnacentral":
-                from graphgen.models import RNACentralSearch
+            params = self.kwargs.get("uniprot_params", {})
+            self.searcher = UniProtSearch(**params)
+        elif self.data_source == "ncbi":
+            from graphgen.models import NCBISearch
 
-                params = self.kwargs.get("rnacentral_params", {})
-                self.searchers[datasource] = RNACentralSearch(**params)
-            else:
-                logger.error(f"Unknown data source: {datasource}, skipping")
+            params = self.kwargs.get("ncbi_params", {})
+            self.searcher = NCBISearch(**params)
+        elif self.data_source == "rnacentral":
+            from graphgen.models import RNACentralSearch
+
+            params = self.kwargs.get("rnacentral_params", {})
+            self.searcher = RNACentralSearch(**params)
+        else:
+            logger.error(f"Unknown data source: {self.data_source}")
 
     @staticmethod
     async def _perform_search(
@@ -76,91 +81,59 @@ async def _perform_search(
 
         result = searcher_obj.search(query)
         if result:
-            result["_doc_id"] = compute_content_hash(str(data_source) + query, "doc-")
             result["data_source"] = data_source
             result["type"] = seed.get("type", "text")
 
         return result
 
-    def _process_single_source(
-        self, data_source: str, seed_data: list[dict]
-    ) -> list[dict]:
-        """
-        process a single data source: check cache, search missing, update cache.
+    def process(self, batch: list) -> Tuple[list, dict]:
         """
-        searcher = self.searchers[data_source]
-
-        seeds_with_ids = []
-        for seed in seed_data:
-            query = seed.get("content", "")
-            if not query:
-                continue
-            doc_id = compute_content_hash(str(data_source) + query, "doc-")
-            seeds_with_ids.append((doc_id, seed))
-
-        if not seeds_with_ids:
-            return []
-
-        doc_ids = [doc_id for doc_id, _ in seeds_with_ids]
-        cached_results = self.search_storage.get_by_ids(doc_ids)
-
-        to_search_seeds = []
-        final_results = []
+        Search for items in the batch using the configured data source.
 
-        for (doc_id, seed), cached in zip(seeds_with_ids, cached_results):
-            if cached is not None:
-                if "_doc_id" not in cached:
-                    cached["_doc_id"] = doc_id
-                final_results.append(cached)
-            else:
-                to_search_seeds.append(seed)
-
-        if to_search_seeds:
-            new_results = run_concurrent(
-                partial(
-                    self._perform_search, searcher_obj=searcher, data_source=data_source
-                ),
-                to_search_seeds,
-                desc=f"Searching {data_source} database",
-                unit="keyword",
-            )
-            new_results = [res for res in new_results if res is not None]
-
-            if new_results:
-                upsert_data = {res["_doc_id"]: res for res in new_results}
-                self.search_storage.upsert(upsert_data)
-                logger.info(
-                    f"Saved {len(upsert_data)} new results to {data_source} cache"
-                )
-
-            final_results.extend(new_results)
-
-        return final_results
-
-    def process(self, batch: "pd.DataFrame") -> "pd.DataFrame":
-        import pandas as pd
-
-        docs = batch.to_dict(orient="records")
+        :param batch: List of items with 'content' and '_trace_id' fields
+        :return: A tuple of (results, meta_updates)
+            results: A list of search results.
+            meta_updates: A dict mapping source IDs to lists of trace IDs for the search results.
+        """
+        self._init_searcher()
 
-        self._init_searchers()
+        if not self.searcher:
+            logger.error("Searcher not initialized")
+            return [], {}
 
-        seed_data = [doc for doc in docs if doc and "content" in doc]
+        # Filter seeds with valid content and _trace_id
+        seed_data = [
+            item for item in batch if item and "content" in item and "_trace_id" in item
+        ]
 
         if not seed_data:
            logger.warning("No valid seeds in batch")
-            return pd.DataFrame([])
-
-        all_results = []
+            return [], {}
+
+        # Perform concurrent searches
+        results = run_concurrent(
+            partial(
+                self._perform_search,
+                searcher_obj=self.searcher,
+                data_source=self.data_source,
+            ),
+            seed_data,
+            desc=f"Searching {self.data_source} database",
+            unit="keyword",
+        )
 
-        for data_source in self.data_sources:
-            if data_source not in self.searchers:
-                logger.error(f"Data source {data_source} not initialized, skipping")
+        # Filter out None results and add _trace_id from original seeds
+        final_results = []
+        meta_updates = {}
+        for result, seed in zip(results, seed_data):
+            if result is None:
                 continue
+            result["_trace_id"] = self.get_trace_id(result)
+            final_results.append(result)
+            # Map from source seed trace ID to search result trace ID
+            meta_updates.setdefault(seed["_trace_id"], []).append(result["_trace_id"])
 
-            source_results = self._process_single_source(data_source, seed_data)
-            all_results.extend(source_results)
-
-        if not all_results:
+        if not final_results:
            logger.warning("No search results generated for this batch")
 
-        return pd.DataFrame(all_results)
+        return final_results, meta_updates
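
After the refactor, `process` accepts a plain list of seed dicts instead of a pandas DataFrame and returns a `(results, meta_updates)` tuple, where `meta_updates` maps each input seed's `_trace_id` to the trace IDs of the results derived from it. A minimal usage sketch, reusing the hypothetical `SearchService` instance from the config example above (seed values are illustrative):

    # Each seed needs "content" (the search query) and "_trace_id" (provenance);
    # items missing either field are filtered out before searching.
    batch = [
        {"content": "BRCA1", "type": "text", "_trace_id": "trace-001"},
        {"content": "TP53", "type": "text", "_trace_id": "trace-002"},
    ]

    results, meta_updates = service.process(batch)

    for res in results:
        # Each result carries its data source and a freshly assigned trace ID.
        print(res["data_source"], res["_trace_id"])

    # meta_updates maps input trace IDs to lists of result trace IDs, e.g.
    # {"trace-001": ["..."], "trace-002": ["..."]}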
