Skip to content

Commit

Permalink
Add SearchApi integration (#7936)
Browse files Browse the repository at this point in the history
  • Loading branch information
SebastjanPrachovskij authored Feb 28, 2024
1 parent db49062 commit 5bab940
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 3 deletions.
2 changes: 1 addition & 1 deletion pipelines/examples/agents/react_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@

# yapf: disable
parser = argparse.ArgumentParser()
parser.add_argument("--search_api_key", default=None, type=str, help="The Serper.dev or SerpAPI key.")
parser.add_argument("--search_api_key", default=None, type=str, help="The Serper.dev, SerpAPI or SearchApi.io key.")
parser.add_argument('--llm_name', choices=['THUDM/chatglm-6b', "THUDM/chatglm-6b-v1.1", "gpt-3.5-turbo", "gpt-4"], default="THUDM/chatglm-6b-v1.1", help="The chatbot models ")
parser.add_argument("--api_key", default=None, type=str, help="The API Key.")
args = parser.parse_args()
Expand Down
4 changes: 2 additions & 2 deletions pipelines/examples/agents/react_example_cn.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,15 +60,15 @@
parser.add_argument('--device', choices=['cpu', 'gpu'], default="gpu", help="Select which device to run dense_qa system, defaults to gpu.")
parser.add_argument("--index_name", default='dureader_index', type=str, help="The ann index name of ANN.")
parser.add_argument("--search_engine", choices=['faiss', 'milvus'], default="faiss", help="The type of ANN search engine.")
parser.add_argument("--retriever", choices=['dense', 'SerperDev', 'SerpAPI'], default="dense", help="The type of Retriever.")
parser.add_argument("--retriever", choices=['dense', 'SerperDev', 'SerpAPI', 'SearchApi'], default="dense", help="The type of Retriever.")
parser.add_argument("--max_seq_len_query", default=64, type=int, help="The maximum total length of query after tokenization.")
parser.add_argument("--max_seq_len_passage", default=256, type=int, help="The maximum total length of passage after tokenization.")
parser.add_argument("--retriever_batch_size", default=16, type=int, help="The batch size of retriever to extract passage embedding for building ANN index.")
parser.add_argument("--query_embedding_model", default="rocketqa-zh-base-query-encoder", type=str, help="The query_embedding_model path")
parser.add_argument("--passage_embedding_model", default="rocketqa-zh-base-query-encoder", type=str, help="The passage_embedding_model path")
parser.add_argument("--params_path", default="checkpoints/model_40/model_state.pdparams", type=str, help="The checkpoint path")
parser.add_argument("--embedding_dim", default=768, type=int, help="The embedding_dim of index")
parser.add_argument("--search_api_key", default=None, type=str, help="The Serper.dev or SerpAPI key.")
parser.add_argument("--search_api_key", default=None, type=str, help="The Serper.dev, SerpAPI or SearchApi.io key.")
parser.add_argument('--embed_title', default=False, type=bool, help="The title to be embedded into embedding")
parser.add_argument('--model_type', choices=['ernie_search', 'ernie', 'bert', 'neural_search'], default="ernie", help="the ernie model types")
parser.add_argument('--llm_name', choices=['ernie-bot', 'THUDM/chatglm-6b', "gpt-3.5-turbo", "gpt-4"], default="THUDM/chatglm-6b", help="The chatbot models ")
Expand Down
107 changes: 107 additions & 0 deletions pipelines/pipelines/nodes/search_engine/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,3 +239,110 @@ def search(self, query: str, **kwargs) -> List[Document]:
logger.debug("Serper.dev API returned %s documents for the query '%s'", len(documents), query)
result_docs = documents[:top_k]
return self.score_results(result_docs, len(answer_box) > 0)


class SearchApi(SearchEngine):
    """
    SearchApi is a real-time search engine that provides an API to access search results from Google, Google Scholar, YouTube,
    YouTube transcripts and more. See the [SearchApi website](https://www.searchapi.io/) for more details.
    """

    def __init__(
        self,
        api_key: str,
        top_k: Optional[int] = 10,
        engine: Optional[str] = "google",
        search_engine_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """
        :param api_key: API key for SearchApi.
        :param top_k: Number of results to return.
        :param engine: Search engine to use, for example google, google_scholar, youtube, youtube_transcripts.
            See the [SearchApi documentation](https://www.searchapi.io/docs/google) for the full list of supported engines.
        :param search_engine_kwargs: Additional parameters passed to the SearchApi.
            See the [SearchApi documentation](https://www.searchapi.io/docs/google) for the full list of supported parameters.
        """
        super().__init__()
        # Kept for interface parity with the other SearchEngine providers; not read by this class.
        self.params_dict: Dict[str, Union[str, int, float]] = {}
        self.api_key = api_key
        self.kwargs = search_engine_kwargs if search_engine_kwargs else {}
        self.engine = engine
        self.top_k = top_k

    def search(self, query: str, **kwargs) -> List[Document]:
        """
        Query SearchApi and convert the JSON response into scored Documents.

        :param query: Query string.
        :param kwargs: Additional parameters passed to the SearchApi. For example, you can set 'location' to
            'New York,United States' to localize search to the specific location. A 'top_k' entry overrides
            the instance-level ``top_k`` for this call.
        :return: List[Document] of at most ``top_k`` scored results.
        :raises Exception: If the HTTP response status is not 200.
        """
        kwargs = {**self.kwargs, **kwargs}
        top_k = kwargs.pop("top_k", self.top_k)
        url = "https://www.searchapi.io/api/v1/search"

        params = {"q": query, **kwargs}
        headers = {"Authorization": f"Bearer {self.api_key}", "X-SearchApi-Source": "PaddleNLP"}

        if self.engine:
            params["engine"] = self.engine
        response = requests.get(url, params=params, headers=headers, timeout=90)

        if response.status_code != 200:
            raise Exception(f"Error while querying {self.__class__.__name__}: {response.text}")

        json_content = json.loads(response.text)
        documents = []
        has_answer_box = False

        answer_box = json_content.get("answer_box")
        if answer_box:
            # The top-level answer_box title/link act as the fallback; the more specific
            # result types below override them. (Previously the generic assignment ran
            # *after* the specific branches, turning those branches into dead stores.)
            title = answer_box.get("title", "")
            link = answer_box.get("link")
            if answer_box.get("organic_result"):
                title = answer_box["organic_result"].get("title", "")
                link = answer_box["organic_result"].get("link", "")
            if answer_box.get("type") == "population_graph":
                title = answer_box.get("place", "")
                link = answer_box.get("explore_more_link", "")

            content = answer_box.get("answer") or answer_box.get("snippet")

            if link and content:
                has_answer_box = True
                documents.append(Document.from_dict({"title": title, "content": content, "link": link}))

        knowledge_graph = json_content.get("knowledge_graph")
        if knowledge_graph:
            # Prefer the entity's own website; fall back to the source link.
            # (Previously the source link was unconditionally overwritten by "website".)
            link = knowledge_graph.get("website", "")
            if not link and knowledge_graph.get("source"):
                link = knowledge_graph["source"].get("link", "")

            content = knowledge_graph.get("description")

            if link and content:
                documents.append(
                    Document.from_dict(
                        {"title": knowledge_graph.get("title", ""), "content": content, "link": link}
                    )
                )

        # "organic_results" can be absent for some responses; default to an empty list
        # instead of raising KeyError (every other key here is read defensively too).
        documents += [
            Document.from_dict({"title": c["title"], "content": c.get("snippet", ""), "link": c["link"]})
            for c in json_content.get("organic_results", [])
        ]

        for question in json_content.get("related_questions", []):
            source = question.get("source")
            link = source.get("link", "") if source else ""
            content = question.get("answer", "")
            if link and content:
                documents.append(
                    Document.from_dict({"title": question.get("question", ""), "content": content, "link": link})
                )

        logger.debug("SearchApi returned %s documents for the query '%s'", len(documents), query)
        result_docs = documents[:top_k]
        return self.score_results(result_docs, has_answer_box)
1 change: 1 addition & 0 deletions pipelines/pipelines/nodes/search_engine/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class WebSearch(BaseComponent):
WebSearch currently supports the following search engines providers (bridges):
- SerperDev (default)
- SearchApi
- SerpAPI
- BingAPI
Expand Down

0 comments on commit 5bab940

Please sign in to comment.