23
23
query_graphql_organizations ,
24
24
)
25
25
from src .api .tools .types import Tool , WorkspaceTarget
26
+ from src .auth .session_credentials_manager import (
27
+ get_session_credentials_manager ,
28
+ invalidate_credentials ,
29
+ )
26
30
from src .utils .uuid_validation import validate_workspace_id , validate_uuid_string
27
31
from src .utils .elicitation import try_elicitation , ElicitationError
28
32
from src .logger import get_logger
35
39
)
36
40
37
41
42
+ class DatabaseCredentials (BaseModel ):
43
+ """Schema for database authentication credentials when using API key."""
44
+
45
+ username : str = Field (..., description = "Database username for authentication" )
46
+ password : str = Field (..., description = "Database password for authentication" )
47
+
48
+
49
+ async def _get_database_credentials (
50
+ ctx : Context , target : WorkspaceTarget , database_name : str | None = None
51
+ ) -> tuple [str , str ]:
52
+ """
53
+ Get database credentials based on the authentication method.
54
+
55
+ Args:
56
+ ctx: The MCP context
57
+ target: The workspace target
58
+ database_name: The database name to use for key generation
59
+
60
+ Returns:
61
+ Tuple of (username, password)
62
+
63
+ Raises:
64
+ Exception: If credentials cannot be obtained
65
+ """
66
+ settings = config .get_settings ()
67
+
68
+ # Check if we're using API key authentication
69
+ is_using_api_key = (
70
+ not settings .is_remote
71
+ and isinstance (settings , config .LocalSettings )
72
+ and settings .api_key is not None
73
+ )
74
+
75
+ if is_using_api_key :
76
+ # For API key authentication, we need database credentials
77
+ # Generate database key using credentials manager
78
+ credentials_manager = get_session_credentials_manager ()
79
+ database_key = credentials_manager .generate_database_key (
80
+ workspace_name = target .name , database_name = database_name
81
+ )
82
+
83
+ # Check if we have cached credentials for this database
84
+ if credentials_manager .has_credentials (database_key ):
85
+ cached_creds = credentials_manager .get_credentials (database_key )
86
+ if cached_creds :
87
+ logger .debug (f"Using cached credentials for workspace: { target .name } " )
88
+ return cached_creds
89
+
90
+ # Dedicated workspaces: need to request database credentials from user
91
+ elicitation_message = (
92
+ f"API key authentication detected. To connect to the dedicated workspace '{ target .name } ', "
93
+ f"please provide your database username and password for this workspace."
94
+ )
95
+
96
+ try :
97
+ elicitation_result , error = await try_elicitation (
98
+ ctx = ctx , message = elicitation_message , schema = DatabaseCredentials
99
+ )
100
+
101
+ if error == ElicitationError .NOT_SUPPORTED :
102
+ # Fallback: raise exception with clear message
103
+ raise Exception (
104
+ "Database credentials required for API key authentication on dedicated workspaces. "
105
+ f"Please provide your database username and password for workspace '{ target .name } '. "
106
+ "You can obtain these credentials from your SingleStore portal. "
107
+ "Note: This is different from your SingleStore account credentials - these are "
108
+ "database-specific credentials for connecting to the workspace."
109
+ )
110
+ elif elicitation_result .status == "success" and elicitation_result .data :
111
+ username = elicitation_result .data .username
112
+ password = elicitation_result .data .password
113
+
114
+ # Store credentials in session cache for future use
115
+ try :
116
+ credentials_manager .store_credentials (
117
+ database_key , username , password
118
+ )
119
+ logger .debug (f"Cached credentials for workspace: { target .name } " )
120
+ except Exception as e :
121
+ logger .warning (f"Failed to cache credentials: { e } " )
122
+
123
+ return (username , password )
124
+ else :
125
+ raise Exception (
126
+ "Database credentials are required but were not provided. Please ask the user to provide the database credentials"
127
+ )
128
+ except Exception as e :
129
+ if "Database credentials required" in str (e ):
130
+ raise # Re-raise our specific credential error
131
+ logger .error (f"Error during credential elicitation: { e } " )
132
+ raise Exception (f"Failed to obtain database credentials: { str (e )} " )
133
+ else :
134
+ # JWT authentication: use user_id and access_token as before
135
+ return __get_user_id (), get_access_token ()
136
+
137
+
38
138
async def __execute_sql_unified (
39
139
ctx : Context ,
40
140
target : WorkspaceTarget ,
@@ -60,39 +160,71 @@ async def __execute_sql_unified(
60
160
host = endpoint
61
161
port = None
62
162
63
- s2_manager = S2Manager (
64
- host = host ,
65
- port = port ,
66
- user = username ,
67
- password = password ,
68
- database = database_name ,
163
+ # Generate database key for credential management
164
+ credentials_manager = get_session_credentials_manager ()
165
+ database_key = credentials_manager .generate_database_key (
166
+ workspace_name = target .name , database_name = database_name
69
167
)
70
168
71
- workspace_type = "shared/virtual" if target .is_shared else "dedicated"
72
- await ctx .info (
73
- f"Executing SQL query on { workspace_type } workspace '{ target .name } ' with database '{ database_name } ': { sql_query } "
74
- "This query may take some time depending on the complexity and size of the data."
75
- )
76
- s2_manager .execute (sql_query )
77
- columns = (
78
- [desc [0 ] for desc in s2_manager .cursor .description ]
79
- if s2_manager .cursor .description
80
- else []
81
- )
82
- rows = s2_manager .fetchmany ()
83
- results = []
84
- for row in rows :
85
- result_dict = {}
86
- for i , column in enumerate (columns ):
87
- result_dict [column ] = row [i ]
88
- results .append (result_dict )
89
- s2_manager .close ()
90
- return {
91
- "data" : results ,
92
- "row_count" : len (rows ),
93
- "columns" : columns ,
94
- "status" : "Success" ,
95
- }
169
+ try :
170
+ s2_manager = S2Manager (
171
+ host = host ,
172
+ port = port ,
173
+ user = username ,
174
+ password = password ,
175
+ database = database_name ,
176
+ )
177
+
178
+ workspace_type = "shared/virtual" if target .is_shared else "dedicated"
179
+ await ctx .info (
180
+ f"Executing SQL query on { workspace_type } workspace '{ target .name } ' with database '{ database_name } ': { sql_query } "
181
+ "This query may take some time depending on the complexity and size of the data."
182
+ )
183
+ s2_manager .execute (sql_query )
184
+ columns = (
185
+ [desc [0 ] for desc in s2_manager .cursor .description ]
186
+ if s2_manager .cursor .description
187
+ else []
188
+ )
189
+ rows = s2_manager .fetchmany ()
190
+ results = []
191
+ for row in rows :
192
+ result_dict = {}
193
+ for i , column in enumerate (columns ):
194
+ result_dict [column ] = row [i ]
195
+ results .append (result_dict )
196
+ s2_manager .close ()
197
+ return {
198
+ "data" : results ,
199
+ "row_count" : len (rows ),
200
+ "columns" : columns ,
201
+ "status" : "Success" ,
202
+ }
203
+ except Exception as e :
204
+ # Check if this is an authentication error
205
+ error_msg = str (e ).lower ()
206
+ is_auth_error = any (
207
+ auth_keyword in error_msg
208
+ for auth_keyword in [
209
+ "access denied" ,
210
+ "authentication failed" ,
211
+ "invalid credentials" ,
212
+ "login failed" ,
213
+ "permission denied" ,
214
+ "unauthorized" ,
215
+ "auth" ,
216
+ ]
217
+ )
218
+
219
+ if is_auth_error :
220
+ logger .warning (
221
+ f"Authentication failed for database { database_key } , invalidating cached credentials"
222
+ )
223
+ invalidate_credentials (database_key )
224
+ raise Exception (f"Authentication failed: { str (e )} " )
225
+ else :
226
+ # Non-authentication error, re-raise as-is
227
+ raise
96
228
97
229
98
230
def __get_virtual_workspace (virtual_workspace_id : str ):
@@ -781,19 +913,62 @@ async def run_sql(
781
913
if target .is_shared and target .database_name and not database_name :
782
914
database_name = target .database_name
783
915
784
- username = __get_user_id ()
785
- password = get_access_token ()
916
+ # Get database credentials based on authentication method
917
+ try :
918
+ username , password = await _get_database_credentials (ctx , target , database_name )
919
+ except Exception as e :
920
+ if "Database credentials required" in str (e ):
921
+ # Handle the specific case where elicitation is not supported
922
+ return {
923
+ "status" : "error" ,
924
+ "message" : str (e ),
925
+ "error_code" : "CREDENTIALS_REQUIRED" ,
926
+ "workspace_id" : validated_id ,
927
+ "workspace_name" : target .name ,
928
+ "workspace_type" : "shared" if target .is_shared else "dedicated" ,
929
+ "instruction" : (
930
+ "Please call this function again with the same parameters once you have "
931
+ "the database credentials available, or ask the user to provide their "
932
+ "database username and password for this workspace."
933
+ ),
934
+ }
935
+ else :
936
+ return {
937
+ "status" : "error" ,
938
+ "message" : f"Failed to obtain database credentials: { str (e )} " ,
939
+ "error_code" : "AUTHENTICATION_ERROR" ,
940
+ }
786
941
787
942
# Execute the SQL query
788
943
start_time = time .time ()
789
- result = await __execute_sql_unified (
790
- ctx = ctx ,
791
- target = target ,
792
- sql_query = sql_query ,
793
- username = username ,
794
- password = password ,
795
- database = database_name ,
796
- )
944
+ try :
945
+ result = await __execute_sql_unified (
946
+ ctx = ctx ,
947
+ target = target ,
948
+ sql_query = sql_query ,
949
+ username = username ,
950
+ password = password ,
951
+ database = database_name ,
952
+ )
953
+ except Exception as e :
954
+ # Check if this is an authentication error from __execute_sql_unified
955
+ if "Authentication failed:" in str (e ):
956
+ # Authentication error already handled by __execute_sql_unified (credentials invalidated)
957
+ return {
958
+ "status" : "error" ,
959
+ "message" : str (e ),
960
+ "error_code" : "AUTHENTICATION_ERROR" ,
961
+ "workspace_id" : validated_id ,
962
+ "workspace_name" : target .name ,
963
+ "workspace_type" : "shared" if target .is_shared else "dedicated" ,
964
+ "instruction" : (
965
+ "Authentication failed. Please provide valid database credentials "
966
+ "for this workspace and try again."
967
+ ),
968
+ }
969
+ else :
970
+ # Non-authentication error, re-raise
971
+ raise
797
972
798
973
results_data = result .get ("data" , [])
799
974
0 commit comments