 from functools import partial
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Optional, Tuple

 from graphgen.bases import BaseOperator
 from graphgen.common.init_storage import init_storage
-from graphgen.utils import compute_content_hash, logger, run_concurrent
+from graphgen.utils import logger, run_concurrent

 if TYPE_CHECKING:
     import pandas as pd
@@ -19,42 +19,47 @@ def __init__(
         self,
         working_dir: str = "cache",
         kv_backend: str = "rocksdb",
-        data_sources: list = None,
+        data_source: str = None,
         **kwargs,
     ):
-        super().__init__(working_dir=working_dir, op_name="search_service")
-        self.working_dir = working_dir
-        self.data_sources = data_sources or []
+        super().__init__(
+            working_dir=working_dir, kv_backend=kv_backend, op_name="search"
+        )
+        self.data_source = data_source
         self.kwargs = kwargs
         self.search_storage = init_storage(
             backend=kv_backend, working_dir=working_dir, namespace="search"
         )
-        self.searchers = {}
+        self.searcher = None

-    def _init_searchers(self):
+    def _init_searcher(self):
         """
-        Initialize all searchers (deferred import to avoid circular imports).
+        Initialize the searcher (deferred import to avoid circular imports).
         """
-        for datasource in self.data_sources:
-            if datasource in self.searchers:
-                continue
-            if datasource == "uniprot":
-                from graphgen.models import UniProtSearch
+        if self.searcher is not None:
+            return
+
+        if not self.data_source:
+            logger.error("Data source not specified")
+            return

-                params = self.kwargs.get("uniprot_params", {})
-                self.searchers[datasource] = UniProtSearch(**params)
-            elif datasource == "ncbi":
-                from graphgen.models import NCBISearch
+        if self.data_source == "uniprot":
+            from graphgen.models import UniProtSearch

-                params = self.kwargs.get("ncbi_params", {})
-                self.searchers[datasource] = NCBISearch(**params)
-            elif datasource == "rnacentral":
-                from graphgen.models import RNACentralSearch
+            params = self.kwargs.get("uniprot_params", {})
+            self.searcher = UniProtSearch(**params)
+        elif self.data_source == "ncbi":
+            from graphgen.models import NCBISearch

-                params = self.kwargs.get("rnacentral_params", {})
-                self.searchers[datasource] = RNACentralSearch(**params)
-            else:
-                logger.error(f"Unknown data source: {datasource}, skipping")
+            params = self.kwargs.get("ncbi_params", {})
+            self.searcher = NCBISearch(**params)
+        elif self.data_source == "rnacentral":
+            from graphgen.models import RNACentralSearch
+
+            params = self.kwargs.get("rnacentral_params", {})
+            self.searcher = RNACentralSearch(**params)
+        else:
+            logger.error(f"Unknown data source: {self.data_source}")

     @staticmethod
     async def _perform_search(
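After this hunk, the operator is configured for exactly one data source per instance: data_source selects which searcher class to build, and a matching <source>_params keyword argument is forwarded to that class through self.kwargs. Below is a minimal, hypothetical sketch of how the refactored constructor might be driven; the class name SearchOperator and its import path are assumptions (the class definition sits outside this diff), and the contents of uniprot_params depend entirely on UniProtSearch's own signature.

    from graphgen.operators.search_service import SearchOperator  # hypothetical path and name

    op = SearchOperator(
        working_dir="cache",
        kv_backend="rocksdb",
        data_source="uniprot",  # one source per operator after this change
        uniprot_params={},      # forwarded as UniProtSearch(**params); keys depend on that class
    )
    op._init_searcher()  # lazy init; normally triggered from process()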
@@ -76,91 +81,59 @@ async def _perform_search(

         result = searcher_obj.search(query)
         if result:
-            result["_doc_id"] = compute_content_hash(str(data_source) + query, "doc-")
             result["data_source"] = data_source
             result["type"] = seed.get("type", "text")

         return result

-    def _process_single_source(
-        self, data_source: str, seed_data: list[dict]
-    ) -> list[dict]:
-        """
-        process a single data source: check cache, search missing, update cache.
+    def process(self, batch: list) -> Tuple[list, dict]:
         """
-        searcher = self.searchers[data_source]
-
-        seeds_with_ids = []
-        for seed in seed_data:
-            query = seed.get("content", "")
-            if not query:
-                continue
-            doc_id = compute_content_hash(str(data_source) + query, "doc-")
-            seeds_with_ids.append((doc_id, seed))
-
-        if not seeds_with_ids:
-            return []
-
-        doc_ids = [doc_id for doc_id, _ in seeds_with_ids]
-        cached_results = self.search_storage.get_by_ids(doc_ids)
-
-        to_search_seeds = []
-        final_results = []
+        Search for items in the batch using the configured data source.

-        for (doc_id, seed), cached in zip(seeds_with_ids, cached_results):
-            if cached is not None:
-                if "_doc_id" not in cached:
-                    cached["_doc_id"] = doc_id
-                final_results.append(cached)
-            else:
-                to_search_seeds.append(seed)
-
-        if to_search_seeds:
-            new_results = run_concurrent(
-                partial(
-                    self._perform_search, searcher_obj=searcher, data_source=data_source
-                ),
-                to_search_seeds,
-                desc=f"Searching {data_source} database",
-                unit="keyword",
-            )
-            new_results = [res for res in new_results if res is not None]
-
-            if new_results:
-                upsert_data = {res["_doc_id"]: res for res in new_results}
-                self.search_storage.upsert(upsert_data)
-                logger.info(
-                    f"Saved {len(upsert_data)} new results to {data_source} cache"
-                )
-
-            final_results.extend(new_results)
-
-        return final_results
-
-    def process(self, batch: "pd.DataFrame") -> "pd.DataFrame":
-        import pandas as pd
-
-        docs = batch.to_dict(orient="records")
+        :param batch: List of items with 'content' and '_trace_id' fields
+        :return: A tuple of (results, meta_updates)
+            results: A list of search results.
+            meta_updates: A dict mapping source IDs to lists of trace IDs for the search results.
+        """
+        self._init_searcher()

-        self._init_searchers()
+        if not self.searcher:
+            logger.error("Searcher not initialized")
+            return [], {}

-        seed_data = [doc for doc in docs if doc and "content" in doc]
+        # Filter seeds with valid content and _trace_id
+        seed_data = [
+            item for item in batch if item and "content" in item and "_trace_id" in item
+        ]

         if not seed_data:
             logger.warning("No valid seeds in batch")
-            return pd.DataFrame([])
-
-        all_results = []
+            return [], {}
+
+        # Perform concurrent searches
+        results = run_concurrent(
+            partial(
+                self._perform_search,
+                searcher_obj=self.searcher,
+                data_source=self.data_source,
+            ),
+            seed_data,
+            desc=f"Searching {self.data_source} database",
+            unit="keyword",
+        )

-        for data_source in self.data_sources:
-            if data_source not in self.searchers:
-                logger.error(f"Data source {data_source} not initialized, skipping")
+        # Filter out None results and add _trace_id from original seeds
+        final_results = []
+        meta_updates = {}
+        for result, seed in zip(results, seed_data):
+            if result is None:
                 continue
+            result["_trace_id"] = self.get_trace_id(result)
+            final_results.append(result)
+            # Map from source seed trace ID to search result trace ID
+            meta_updates.setdefault(seed["_trace_id"], []).append(result["_trace_id"])

-            source_results = self._process_single_source(data_source, seed_data)
-            all_results.extend(source_results)
-
-        if not all_results:
+        if not final_results:
             logger.warning("No search results generated for this batch")

-        return pd.DataFrame(all_results)
+        return final_results, meta_updates
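A rough sketch of the new process() contract under the same naming assumptions as above; the seed contents and trace IDs are placeholders and no real UniProt responses are implied, the point is only the shape of the (results, meta_updates) tuple.

    # Seeds must carry 'content' and '_trace_id'; anything else is filtered out.
    batch = [
        {"content": "P69905", "type": "text", "_trace_id": "trace-1"},
        {"content": "hemoglobin subunit alpha", "type": "text", "_trace_id": "trace-2"},
    ]

    results, meta_updates = op.process(batch)

    # Each surviving result keeps data_source and type and receives a fresh _trace_id
    # via self.get_trace_id(result); meta_updates maps each seed's _trace_id to the
    # trace IDs of the results it produced, e.g. {"trace-1": ["..."], "trace-2": ["..."]}.
    for res in results:
        print(res["data_source"], res["_trace_id"])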