11from pathlib import Path
2- from typing import Any , Dict , List , Optional , Type
2+ from typing import Any , Dict , List , Optional , Type , Union
33
44import redis .commands .search .reducers as reducers
55import yaml
88from redis .commands .search .aggregation import AggregateRequest , AggregateResult , Reducer
99from redis .exceptions import ResponseError
1010
11+ from redisvl .exceptions import RedisModuleVersionError
1112from redisvl .extensions .constants import ROUTE_VECTOR_FIELD_NAME
1213from redisvl .extensions .router .schema import (
1314 DistanceAggregationMethod ,
1718 SemanticRouterIndexSchema ,
1819)
1920from redisvl .index import SearchIndex
20- from redisvl .query import VectorRangeQuery
21+ from redisvl .query import FilterQuery , VectorRangeQuery
22+ from redisvl .query .filter import Tag
23+ from redisvl .redis .connection import RedisConnectionFactory
2124from redisvl .redis .utils import convert_bytes , hashify , make_dict
2225from redisvl .utils .log import get_logger
23- from redisvl .utils .utils import deprecated_argument , model_to_dict
26+ from redisvl .utils .utils import deprecated_argument , model_to_dict , scan_by_pattern
2427from redisvl .utils .vectorize .base import BaseVectorizer
2528from redisvl .utils .vectorize .text .huggingface import HFTextVectorizer
2629
@@ -98,9 +101,41 @@ def __init__(
98101 routes = routes ,
99102 vectorizer = vectorizer ,
100103 routing_config = routing_config ,
104+ redis_url = redis_url ,
105+ redis_client = redis_client ,
101106 )
107+
102108 self ._initialize_index (redis_client , redis_url , overwrite , ** connection_kwargs )
103109
110+ self ._index .client .json ().set (f"{ self .name } :route_config" , f"." , self .to_dict ()) # type: ignore
111+
112+ @classmethod
113+ def from_existing (
114+ cls ,
115+ name : str ,
116+ redis_client : Optional [Redis ] = None ,
117+ redis_url : str = "redis://localhost:6379" ,
118+ ** kwargs ,
119+ ) -> "SemanticRouter" :
120+ """Return SemanticRouter instance from existing index."""
121+ try :
122+ if redis_url :
123+ redis_client = RedisConnectionFactory .get_redis_connection (
124+ redis_url = redis_url ,
125+ ** kwargs ,
126+ )
127+ elif redis_client :
128+ RedisConnectionFactory .validate_sync_redis (redis_client )
129+ except RedisModuleVersionError as e :
130+ raise RedisModuleVersionError (
131+ f"Loading from existing index failed. { str (e )} "
132+ )
133+
134+ router_dict = redis_client .json ().get (f"{ name } :route_config" ) # type: ignore
135+ return cls .from_dict (
136+ router_dict , redis_url = redis_url , redis_client = redis_client
137+ )
138+
104139 @deprecated_argument ("dtype" )
105140 def _initialize_index (
106141 self ,
@@ -111,9 +146,11 @@ def _initialize_index(
111146 ** connection_kwargs ,
112147 ):
113148 """Initialize the search index and handle Redis connection."""
149+
114150 schema = SemanticRouterIndexSchema .from_params (
115151 self .name , self .vectorizer .dims , self .vectorizer .dtype # type: ignore
116152 )
153+
117154 self ._index = SearchIndex (
118155 schema = schema ,
119156 redis_client = redis_client ,
@@ -174,10 +211,10 @@ def update_route_thresholds(self, route_thresholds: Dict[str, Optional[float]]):
174211 if route .name in route_thresholds :
175212 route .distance_threshold = route_thresholds [route .name ] # type: ignore
176213
177- def _route_ref_key (self , route_name : str , reference : str ) -> str :
214+ @staticmethod
215+ def _route_ref_key (index : SearchIndex , route_name : str , reference_hash : str ) -> str :
178216 """Generate the route reference key."""
179- reference_hash = hashify (reference )
180- return f"{ self ._index .prefix } :{ route_name } :{ reference_hash } "
217+ return f"{ index .prefix } :{ route_name } :{ reference_hash } "
181218
182219 def _add_routes (self , routes : List [Route ]):
183220 """Add routes to the router and index.
@@ -195,14 +232,18 @@ def _add_routes(self, routes: List[Route]):
195232 )
196233 # set route references
197234 for i , reference in enumerate (route .references ):
235+ reference_hash = hashify (reference )
198236 route_references .append (
199237 {
238+ "reference_id" : reference_hash ,
200239 "route_name" : route .name ,
201240 "reference" : reference ,
202241 "vector" : reference_vectors [i ],
203242 }
204243 )
205- keys .append (self ._route_ref_key (route .name , reference ))
244+ keys .append (
245+ self ._route_ref_key (self ._index , route .name , reference_hash )
246+ )
206247
207248 # set route if does not yet exist client side
208249 if not self .get (route .name ):
@@ -438,7 +479,7 @@ def remove_route(self, route_name: str) -> None:
438479 else :
439480 self ._index .drop_keys (
440481 [
441- self ._route_ref_key (route .name , reference )
482+ self ._route_ref_key (self . _index , route .name , hashify ( reference ) )
442483 for reference in route .references
443484 ]
444485 )
@@ -596,3 +637,155 @@ def to_yaml(self, file_path: str, overwrite: bool = True) -> None:
596637 with open (fp , "w" ) as f :
597638 yaml_data = self .to_dict ()
598639 yaml .dump (yaml_data , f , sort_keys = False )
640+
641+ # reference methods
642+ def add_route_references (
643+ self ,
644+ route_name : str ,
645+ references : Union [str , List [str ]],
646+ ) -> List [str ]:
647+ """Add a reference(s) to an existing route.
648+
649+ Args:
650+ router_name (str): The name of the router.
651+ references (Union[str, List[str]]): The reference or list of references to add.
652+
653+ Returns:
654+ List[str]: The list of added references keys.
655+ """
656+
657+ if isinstance (references , str ):
658+ references = [references ]
659+
660+ route_references : List [Dict [str , Any ]] = []
661+ keys : List [str ] = []
662+
663+ # embed route references as a single batch
664+ reference_vectors = self .vectorizer .embed_many (references , as_buffer = True )
665+
666+ # set route references
667+ for i , reference in enumerate (references ):
668+ reference_hash = hashify (reference )
669+
670+ route_references .append (
671+ {
672+ "reference_id" : reference_hash ,
673+ "route_name" : route_name ,
674+ "reference" : reference ,
675+ "vector" : reference_vectors [i ],
676+ }
677+ )
678+ keys .append (self ._route_ref_key (self ._index , route_name , reference_hash ))
679+
680+ keys = self ._index .load (route_references , keys = keys )
681+
682+ route = self .get (route_name )
683+ if not route :
684+ raise ValueError (f"Route { route_name } not found in the SemanticRouter" )
685+ route .references .extend (references )
686+ self ._update_router_state ()
687+ return keys
688+
689+ @staticmethod
690+ def _make_filter_queries (ids : List [str ]) -> List [FilterQuery ]:
691+ """Create a filter query for the given ids."""
692+
693+ queries = []
694+
695+ for id in ids :
696+ fe = Tag ("reference_id" ) == id
697+ fq = FilterQuery (
698+ return_fields = ["reference_id" , "route_name" , "reference" ],
699+ filter_expression = fe ,
700+ )
701+ queries .append (fq )
702+
703+ return queries
704+
705+ def get_route_references (
706+ self ,
707+ route_name : str = "" ,
708+ reference_ids : List [str ] = [],
709+ keys : List [str ] = [],
710+ ) -> List [Dict [str , Any ]]:
711+ """Get references for an existing route route.
712+
713+ Args:
714+ router_name (str): The name of the router.
715+ references (Union[str, List[str]]): The reference or list of references to add.
716+
717+ Returns:
718+ List[Dict[str, Any]]]: Reference objects stored
719+ """
720+
721+ if reference_ids :
722+ queries = self ._make_filter_queries (reference_ids )
723+ elif route_name :
724+ if not keys :
725+ keys = scan_by_pattern (
726+ self ._index .client , f"{ self ._index .prefix } :{ route_name } :*" # type: ignore
727+ )
728+
729+ queries = self ._make_filter_queries (
730+ [key .split (":" )[- 1 ] for key in convert_bytes (keys )]
731+ )
732+ else :
733+ raise ValueError (
734+ "Must provide a route name, reference ids, or keys to get references"
735+ )
736+
737+ res = self ._index .batch_query (queries )
738+
739+ return [r [0 ] for r in res if len (r ) > 0 ]
740+
741+ def delete_route_references (
742+ self ,
743+ route_name : str = "" ,
744+ reference_ids : List [str ] = [],
745+ keys : List [str ] = [],
746+ ) -> int :
747+ """Get references for an existing semantic router route.
748+
749+ Args:
750+ router_name Optional(str): The name of the router.
751+ reference_ids Optional(List[str]]): The reference or list of references to delete.
752+ keys Optional(List[str]]): List of fully qualified keys (prefix:router:reference_id) to delete.
753+
754+ Returns:
755+ int: Number of objects deleted
756+ """
757+
758+ if reference_ids and not keys :
759+ queries = self ._make_filter_queries (reference_ids )
760+ res = self ._index .batch_query (queries )
761+ keys = [r [0 ]["id" ] for r in res if len (r ) > 0 ]
762+ elif not keys :
763+ keys = scan_by_pattern (
764+ self ._index .client , f"{ self ._index .prefix } :{ route_name } :*" # type: ignore
765+ )
766+
767+ if not keys :
768+ raise ValueError (f"No references found for route { route_name } " )
769+
770+ to_be_deleted = []
771+ for key in keys :
772+ route_name = key .split (":" )[- 2 ]
773+ to_be_deleted .append (
774+ (route_name , convert_bytes (self ._index .client .hgetall (key ))) # type: ignore
775+ )
776+
777+ deleted = self ._index .drop_keys (keys )
778+
779+ for route_name , delete in to_be_deleted :
780+ route = self .get (route_name )
781+ if not route :
782+ raise ValueError (f"Route { route_name } not found in the SemanticRouter" )
783+ route .references .remove (delete ["reference" ])
784+
785+ self ._update_router_state ()
786+
787+ return deleted
788+
789+ def _update_router_state (self ) -> None :
790+ """Update the router configuration in Redis."""
791+ self ._index .client .json ().set (f"{ self .name } :route_config" , f"." , self .to_dict ()) # type: ignore
0 commit comments