Skip to content

Commit

Permalink
search updates (#19)
Browse files Browse the repository at this point in the history
* search updates

* add helper function

* make format

* updates
  • Loading branch information
prasmussen15 authored Aug 22, 2024
1 parent 6ae9c4e commit 94873f1
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 13 deletions.
25 changes: 13 additions & 12 deletions core/search/search_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from time import time

from neo4j import AsyncDriver
from neo4j import time as neo4j_time

from core.edges import EntityEdge
from core.nodes import EntityNode, EpisodicNode
Expand All @@ -14,6 +15,10 @@
RELEVANT_SCHEMA_LIMIT = 3


def parse_db_date(neo_date: neo4j_time.Date | None) -> datetime | None:
return neo_date.to_native() if neo_date else None


async def get_mentioned_nodes(driver: AsyncDriver, episodes: list[EpisodicNode]):
episode_uuids = [episode.uuid for episode in episodes]
records, _, _ = await driver.execute_query(
Expand Down Expand Up @@ -122,8 +127,6 @@ async def edge_similarity_search(

edges: list[EntityEdge] = []

now = datetime.now()

for record in records:
edge = EntityEdge(
uuid=record['uuid'],
Expand All @@ -133,10 +136,10 @@ async def edge_similarity_search(
name=record['name'],
episodes=record['episodes'],
fact_embedding=record['fact_embedding'],
created_at=now,
expired_at=now,
valid_at=now,
invalid_At=now,
created_at=record['created_at'].to_native(),
expired_at=parse_db_date(record['expired_at']),
valid_at=parse_db_date(record['valid_at']),
invalid_at=parse_db_date(record['invalid_at']),
)

edges.append(edge)
Expand Down Expand Up @@ -244,8 +247,6 @@ async def edge_fulltext_search(

edges: list[EntityEdge] = []

now = datetime.now()

for record in records:
edge = EntityEdge(
uuid=record['uuid'],
Expand All @@ -255,10 +256,10 @@ async def edge_fulltext_search(
name=record['name'],
episodes=record['episodes'],
fact_embedding=record['fact_embedding'],
created_at=now,
expired_at=now,
valid_at=now,
invalid_At=now,
created_at=record['created_at'].to_native(),
expired_at=parse_db_date(record['expired_at']),
valid_at=parse_db_date(record['valid_at']),
invalid_at=parse_db_date(record['invalid_at']),
)

edges.append(edge)
Expand Down
2 changes: 1 addition & 1 deletion core/utils/maintenance/edge_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ async def dedupe_edge_list(
unique_edges_data = llm_response.get('unique_edges', [])

end = time()
logger.info(f'Extracted edge duplicates: {unique_edges_data} in {(end - start)*1000} ms ')
logger.info(f'Extracted edge duplicates: {unique_edges_data} in {(end - start) * 1000} ms ')

# Get full edge data
unique_edges = []
Expand Down

0 comments on commit 94873f1

Please sign in to comment.