Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ dependencies = [
"httpx[socks]>=0.28.1",
"litellm>=1.74.0.post1",
"jinja2>=3.0.0",
"ddgs",
]

[project.scripts]
Expand Down
2 changes: 2 additions & 0 deletions quantmind/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
BaseSourceConfig,
NewsSourceConfig,
WebSourceConfig,
SearchSourceConfig,
)
from .storage import BaseStorageConfig, LocalStorageConfig
from .taggers import LLMTaggerConfig
Expand All @@ -37,6 +38,7 @@
"ArxivSourceConfig",
"NewsSourceConfig",
"WebSourceConfig",
"SearchSourceConfig",
# Storage Configurations
"BaseStorageConfig",
"LocalStorageConfig",
Expand Down
11 changes: 11 additions & 0 deletions quantmind/config/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,11 +192,22 @@ def validate_delay(cls, v):
return v


class SearchSourceConfig(BaseSourceConfig):
"""Configuration for search sources."""

site: Optional[str] = Field(default=None, description="Restrict search to a specific domain.")
filetype: Optional[str] = Field(default=None, description="Search for specific file types.")
start_date: Optional[str] = Field(default=None, description="Start date for search results (YYYY-MM-DD).")
end_date: Optional[str] = Field(default=None, description="End date for search results (YYYY-MM-DD).")



# Source configuration registry
SOURCE_CONFIGS = {
"arxiv": ArxivSourceConfig,
"news": NewsSourceConfig,
"web": WebSourceConfig,
"search": SearchSourceConfig,
}


Expand Down
3 changes: 2 additions & 1 deletion quantmind/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@

from .content import BaseContent, KnowledgeItem
from .paper import Paper
from .search import SearchContent

__all__ = ["Paper", "BaseContent", "KnowledgeItem"]
__all__ = ["Paper", "BaseContent", "KnowledgeItem", "SearchContent"]
35 changes: 35 additions & 0 deletions quantmind/models/search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""Search content model."""

from typing import Any, Dict, List, Optional

from quantmind.models.content import BaseContent


class SearchContent(BaseContent):
"""Represents content from a search engine result."""

title: str
url: str
snippet: str
source: str = "search"
query: Optional[str] = None
meta_info: Dict[str, Any] = {}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the function of meta_info?


def get_primary_id(self) -> str:
"""Return the primary identifier for the content."""
return self.url

def get_text_for_embedding(self) -> str:
"""Return the text to be used for generating embeddings."""
return f"{self.title}{self.snippet}"

def to_dict(self) -> Dict[str, any]:
"""Convert the object to a dictionary."""
return {
"title": self.title,
"url": self.url,
"snippet": self.snippet,
"source": self.source,
"query": self.query,
"meta_info": self.meta_info,
}
8 changes: 8 additions & 0 deletions quantmind/sources/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,11 @@
__all__.append("ArxivSource")
except ImportError:
pass

# Conditionally import Search source
try:
from quantmind.sources.search_source import SearchSource

__all__.append("SearchSource")
except ImportError:
pass
111 changes: 111 additions & 0 deletions quantmind/sources/search_source.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
"""Search source for fetching content from search engines."""

from typing import List, Optional

from ddgs import DDGS
from quantmind.config import SearchSourceConfig
from quantmind.models.search import SearchContent
from quantmind.sources.base import BaseSource
from quantmind.utils.logger import get_logger

logger = get_logger(__name__)


class SearchSource(BaseSource[SearchContent]):
"""SearchSource provides a way to fetch content from search engines.

Currently, it uses DuckDuckGo as the search provider.
"""

def __init__(self, config: Optional[SearchSourceConfig] = None):
"""
Initializes the SearchSource with an optional configuration.

Args:
config: A SearchSourceConfig object. If not provided, a default config is used.
"""
self.config = config or SearchSourceConfig()
super().__init__(self.config)
self.client = DDGS()

def search(
self,
query: str,
max_results: Optional[int] = None,
site: Optional[str] = None,
filetype: Optional[str] = None,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
) -> List[SearchContent]:
"""Performs a search query and returns a list of SearchContent objects.

Args:
query: The search query string.
max_results: The maximum number of results to return. Defaults to
the value in the config.
site: Restrict search to a specific domain.
filetype: Search for specific file types.
start_date: Start date for search results (YYYY-MM-DD).
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if the datetime format is wrong?

Add something like this?

@field_validator("start_date", "end_date")
@classmethod
def _validate_date(cls, v):
    if v is None:
        return v
    if not re.match(r"^\d{4}-\d{2}-\d{2}$", v):
        raise ValueError("date must be YYYY-MM-DD")
    return v

end_date: End date for search results (YYYY-MM-DD).

Returns:
A list of SearchContent objects.
"""
if max_results is None:
max_results = self.config.max_results

# Build the query with advanced search operators
search_query = query
if site or self.config.site:
search_query += f" site:{site or self.config.site}"
if filetype or self.config.filetype:
search_query += f" filetype:{filetype or self.config.filetype}"

# Handle date range
final_start_date = start_date or self.config.start_date
final_end_date = end_date or self.config.end_date
if final_start_date and final_end_date:
search_query += f" daterange:{final_start_date}..{final_end_date}"
elif final_start_date:
search_query += f" daterange:{final_start_date}.."
elif final_end_date:
search_query += f" daterange:..{final_end_date}"

try:
results = self.client.text(search_query, max_results=max_results)
search_content_list = [
SearchContent(
title=result["title"],
url=result["href"],
snippet=result["body"],
query=search_query,
source=self.name,
meta_info={},
)
for result in results
]
logger.info(
f"Found {len(search_content_list)} results for query: '{search_query}'"
)
return search_content_list
except Exception as e:
logger.error(f"An error occurred while searching with DuckDuckGo: {e}")
return []

def get_by_id(self, content_id: str) -> Optional[SearchContent]:
"""Retrieves content by its ID (URL).

This is not a standard use case for a search source, but it's
implemented for interface consistency. It performs a search for the URL.

Args:
content_id: The URL of the content to retrieve.

Returns:
A SearchContent object if the URL is found, otherwise None.
"""
# A bit of a hack to satisfy the interface. Search for the URL.
results = self.search(query=content_id, max_results=1)
if results and results[0].url == content_id:
return results[0]
return None
44 changes: 44 additions & 0 deletions tests/sources/test_search_source_advanced_integration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
"""Integration tests for the advanced features of the SearchSource."""

import unittest

from quantmind.sources.search_source import SearchSource
from quantmind.config.sources import SearchSourceConfig
from quantmind.models.search import SearchContent


class TestSearchSourceAdvancedIntegration(unittest.TestCase):
"""Test suite for the advanced features of the SearchSource with real network requests."""

def setUp(self):
"""Set up the test case."""
self.source = SearchSource()

def test_search_with_site_filter(self):
"""Test a real search with a site filter."""
results = self.source.search("machine learning", site="arxiv.org")

self.assertGreater(len(results), 0)
for result in results:
self.assertIn("arxiv.org", result.url)

def test_search_with_filetype_filter(self):
"""Test a real search with a filetype filter."""
results = self.source.search("financial report", filetype="pdf")

self.assertGreater(len(results), 0)
# We can't guarantee that all results will have a .pdf extension in the URL,
# as the filetype search is a hint to the search engine.
# However, we can check if the query was constructed correctly.
self.assertIn("filetype:pdf", results[0].query)

def test_search_with_date_filter(self):
"""Test a real search with a date filter."""
results = self.source.search("AI", start_date="2023-01-01", end_date="2023-01-31")

self.assertGreater(len(results), 0)
self.assertIn("daterange:2023-01-01..2023-01-31", results[0].query)


if __name__ == "__main__":
unittest.main()
31 changes: 31 additions & 0 deletions tests/sources/test_search_source_integration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""Integration tests for the SearchSource."""

import unittest

from quantmind.sources.search_source import SearchSource
from quantmind.config.sources import SearchSourceConfig
from quantmind.models.search import SearchContent


class TestSearchSourceIntegration(unittest.TestCase):
"""Test suite for the SearchSource with real network requests."""

def setUp(self):
"""Set up the test case."""
self.config = SearchSourceConfig(max_results=5)
self.source = SearchSource(config=self.config)

def test_search_finreport(self):
"""Test a real search for 'finreport'."""
results = self.source.search("finreport")

self.assertGreater(len(results), 0)
self.assertIsInstance(results[0], SearchContent)
self.assertIsNotNone(results[0].title)
self.assertIsNotNone(results[0].url)
self.assertIsNotNone(results[0].snippet)
self.assertIn("finreport", results[0].query.lower())


if __name__ == "__main__":
unittest.main()
Loading
Loading