1818_singleton_lock = Lock ()
1919
2020
21+ class MuxMatchingError (Exception ):
22+ """An exception for muxing matching errors."""
23+
24+ pass
25+
26+
2127async def get_muxing_rules_registry ():
2228 """Returns a singleton instance of the muxing rules registry."""
2329
@@ -48,9 +54,9 @@ def __init__(
4854class MuxingRuleMatcher (ABC ):
4955 """Base class for matching muxing rules."""
5056
51- def __init__ (self , route : ModelRoute , matcher_blob : str ):
57+ def __init__ (self , route : ModelRoute , mux_rule : mux_models . MuxRule ):
5258 self ._route = route
53- self ._matcher_blob = matcher_blob
59+ self ._mux_rule = mux_rule
5460
5561 @abstractmethod
5662 def match (self , thing_to_match : mux_models .ThingToMatchMux ) -> bool :
@@ -67,32 +73,24 @@ class MuxingMatcherFactory:
6773 """Factory for creating muxing matchers."""
6874
6975 @staticmethod
70- def create (mux_rule : db_models .MuxRule , route : ModelRoute ) -> MuxingRuleMatcher :
76+ def create (db_mux_rule : db_models .MuxRule , route : ModelRoute ) -> MuxingRuleMatcher :
7177 """Create a muxing matcher for the given endpoint and model."""
7278
7379 factory : Dict [mux_models .MuxMatcherType , MuxingRuleMatcher ] = {
74- mux_models .MuxMatcherType .catch_all : CatchAllMuxingRuleMatcher ,
75- mux_models .MuxMatcherType .filename_match : FileMuxingRuleMatcher ,
76- mux_models .MuxMatcherType .request_type_match : RequestTypeMuxingRuleMatcher ,
80+ mux_models .MuxMatcherType .catch_all : RequestTypeAndFileMuxingRuleMatcher ,
81+ mux_models .MuxMatcherType .fim : RequestTypeAndFileMuxingRuleMatcher ,
82+ mux_models .MuxMatcherType .chat : RequestTypeAndFileMuxingRuleMatcher ,
7783 }
7884
7985 try :
8086 # Initialize the MuxingRuleMatcher
81- return factory [mux_rule .matcher_type ](route , mux_rule .matcher_blob )
87+ mux_rule = mux_models .MuxRule .from_db_mux_rule (db_mux_rule )
88+ return factory [mux_rule .matcher_type ](route , mux_rule )
8289 except KeyError :
8390 raise ValueError (f"Unknown matcher type: { mux_rule .matcher_type } " )
8491
8592
86- class CatchAllMuxingRuleMatcher (MuxingRuleMatcher ):
87- """A catch all muxing rule matcher."""
88-
89- def match (self , thing_to_match : mux_models .ThingToMatchMux ) -> bool :
90- logger .info ("Catch all rule matched" )
91- return True
92-
93-
94- class FileMuxingRuleMatcher (MuxingRuleMatcher ):
95- """A file muxing rule matcher."""
93+ class RequestTypeAndFileMuxingRuleMatcher (MuxingRuleMatcher ):
9694
9795 def _extract_request_filenames (self , detected_client : ClientType , data : dict ) -> set [str ]:
9896 """
@@ -103,47 +101,51 @@ def _extract_request_filenames(self, detected_client: ClientType, data: dict) ->
103101 return body_extractor .extract_unique_filenames (data )
104102 except BodyCodeSnippetExtractorError as e :
105103 logger .error (f"Error extracting filenames from request: { e } " )
106- return set ( )
104+ raise MuxMatchingError ( "Error extracting filenames from request" )
107105
108- def match (self , thing_to_match : mux_models . ThingToMatchMux ) -> bool :
106+ def _is_matcher_in_filenames (self , detected_client : ClientType , data : dict ) -> bool :
109107 """
110- Retun True if there is a filename in the request that matches the matcher_blob.
111- The matcher_blob is either an extension (e.g. .py) or a filename (e.g. main.py).
108+ Check if the matcher is in the request filenames.
112109 """
113- # If there is no matcher_blob, we don't match
114- if not self ._matcher_blob :
115- return False
116- filenames_to_match = self ._extract_request_filenames (
117- thing_to_match .client_type , thing_to_match .body
110+ # Empty matcher_blob means we match everything
111+ if not self ._mux_rule .matcher :
112+ return True
113+ filenames_to_match = self ._extract_request_filenames (detected_client , data )
114+ # _mux_rule.matcher can be a filename or a file extension. We match if any of the filenames
115+ # match the rule.
116+ is_filename_match = any (
117+ self ._mux_rule .matcher == filename or filename .endswith (self ._mux_rule .matcher )
118+ for filename in filenames_to_match
118119 )
119- is_filename_match = any (self ._matcher_blob in filename for filename in filenames_to_match )
120- if is_filename_match :
121- logger .info (
122- "Filename rule matched" , filenames = filenames_to_match , matcher = self ._matcher_blob
123- )
124120 return is_filename_match
125121
126-
127- class RequestTypeMuxingRuleMatcher (MuxingRuleMatcher ):
128- """A catch all muxing rule matcher."""
122+ def _is_request_type_match (self , is_fim_request : bool ) -> bool :
123+ """
124+ Check if the request type matches the MuxMatcherType.
125+ """
126+ # Catch all rule matches both chat and FIM requests
127+ if self ._mux_rule .matcher_type == mux_models .MuxMatcherType .catch_all :
128+ return True
129+ incoming_request_type = "fim" if is_fim_request else "chat"
130+ if incoming_request_type == self ._mux_rule .matcher_type :
131+ return True
132+ return False
129133
130134 def match (self , thing_to_match : mux_models .ThingToMatchMux ) -> bool :
131135 """
132- Return True if the request type matches the matcher_blob.
133- The matcher_blob is either "fim" or "chat" .
136+ Return True if the matcher is in one of the request filenames and
137+ if the request type matches the MuxMatcherType .
134138 """
135- # If there is no matcher_blob, we don't match
136- if not self ._matcher_blob :
137- return False
138- incoming_request_type = "fim" if thing_to_match .is_fim_request else "chat"
139- is_request_type_match = self ._matcher_blob == incoming_request_type
140- if is_request_type_match :
139+ is_rule_matched = self ._is_matcher_in_filenames (
140+ thing_to_match .client_type , thing_to_match .body
141+ ) and self ._is_request_type_match (thing_to_match .is_fim_request )
142+ if is_rule_matched :
141143 logger .info (
142- "Request type rule matched" ,
143- matcher = self ._matcher_blob ,
144- request_type = incoming_request_type ,
144+ "Request type and rule matched" ,
145+ matcher = self ._mux_rule . matcher ,
146+ is_fim_request = thing_to_match . is_fim_request ,
145147 )
146- return is_request_type_match
148+ return is_rule_matched
147149
148150
149151class MuxingRulesinWorkspaces :
0 commit comments