Skip to content

Commit 443bb27

Browse files
authored
Fix/run sql api key (#78)
* wip * release: bump version to 0.4.3 (patch release) * create changelog file on new release * fix run_sql when an API is passed
1 parent 390cf6a commit 443bb27

File tree

7 files changed

+777
-47
lines changed

7 files changed

+777
-47
lines changed

changelog/0.4.3.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# [0.4.3] - 2025-07-23
2+
3+
## Fixed
4+
5+
- `run_sql` for virtual workspaces when an API key is given.

scripts/mark-release.sh

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -96,15 +96,35 @@ git commit -m "release: bump version to $NEW_VERSION ($RELEASE_TYPE release)"
9696
git push origin "$CURRENT_BRANCH"
9797

9898
echo ""
99+
# Create changelog file if it doesn't exist
100+
CHANGELOG_FILE="changelog/$NEW_VERSION.md"
101+
if [ ! -f "$CHANGELOG_FILE" ]; then
102+
CURRENT_DATE=$(date +%Y-%m-%d)
103+
echo "# [$NEW_VERSION] - $CURRENT_DATE
104+
105+
## Added
106+
-
107+
-
108+
109+
## Fixed
110+
-
111+
- " > "$CHANGELOG_FILE"
112+
113+
# Add changelog to git
114+
git add "$CHANGELOG_FILE"
115+
git commit --amend --no-edit
116+
fi
117+
99118
echo -e "${GREEN}✅ Branch marked for release!${NC}"
100119
echo ""
101120
echo -e "${BLUE}Next steps:${NC}"
102-
echo "1. Push your branch: ${YELLOW}git push origin $CURRENT_BRANCH${NC}"
103-
echo "2. Create PR to main"
104-
echo "3. When PR is merged, automatic PyPI publication will be triggered"
105-
echo "4. Optionally create git tag manually: ${YELLOW}git tag v$NEW_VERSION && git push origin v$NEW_VERSION${NC}"
121+
echo "1. Update the changelog at: ${YELLOW}$CHANGELOG_FILE${NC}"
122+
echo "2. Push your branch: ${YELLOW}git push origin $CURRENT_BRANCH${NC}"
123+
echo "3. Create PR to main"
124+
echo "4. When PR is merged, automatic PyPI publication will be triggered"
125+
echo "5. Optionally create git tag manually: ${YELLOW}git tag v$NEW_VERSION && git push origin v$NEW_VERSION${NC}"
106126
echo ""
107127
echo -e "${BLUE}Release details:${NC}"
108128
echo " 🏷️ Release type: $RELEASE_TYPE"
109129
echo " 📦 New version: $NEW_VERSION"
110-
echo " 📝 Version file updated and committed"
130+
echo " 📝 Version file and changelog updated and committed"

src/api/tools/tools.py

Lines changed: 216 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@
2323
query_graphql_organizations,
2424
)
2525
from src.api.tools.types import Tool, WorkspaceTarget
26+
from src.auth.session_credentials_manager import (
27+
get_session_credentials_manager,
28+
invalidate_credentials,
29+
)
2630
from src.utils.uuid_validation import validate_workspace_id, validate_uuid_string
2731
from src.utils.elicitation import try_elicitation, ElicitationError
2832
from src.logger import get_logger
@@ -35,6 +39,102 @@
3539
)
3640

3741

42+
class DatabaseCredentials(BaseModel):
43+
"""Schema for database authentication credentials when using API key."""
44+
45+
username: str = Field(..., description="Database username for authentication")
46+
password: str = Field(..., description="Database password for authentication")
47+
48+
49+
async def _get_database_credentials(
50+
ctx: Context, target: WorkspaceTarget, database_name: str | None = None
51+
) -> tuple[str, str]:
52+
"""
53+
Get database credentials based on the authentication method.
54+
55+
Args:
56+
ctx: The MCP context
57+
target: The workspace target
58+
database_name: The database name to use for key generation
59+
60+
Returns:
61+
Tuple of (username, password)
62+
63+
Raises:
64+
Exception: If credentials cannot be obtained
65+
"""
66+
settings = config.get_settings()
67+
68+
# Check if we're using API key authentication
69+
is_using_api_key = (
70+
not settings.is_remote
71+
and isinstance(settings, config.LocalSettings)
72+
and settings.api_key is not None
73+
)
74+
75+
if is_using_api_key:
76+
# For API key authentication, we need database credentials
77+
# Generate database key using credentials manager
78+
credentials_manager = get_session_credentials_manager()
79+
database_key = credentials_manager.generate_database_key(
80+
workspace_name=target.name, database_name=database_name
81+
)
82+
83+
# Check if we have cached credentials for this database
84+
if credentials_manager.has_credentials(database_key):
85+
cached_creds = credentials_manager.get_credentials(database_key)
86+
if cached_creds:
87+
logger.debug(f"Using cached credentials for workspace: {target.name}")
88+
return cached_creds
89+
90+
# Dedicated workspaces: need to request database credentials from user
91+
elicitation_message = (
92+
f"API key authentication detected. To connect to the dedicated workspace '{target.name}', "
93+
f"please provide your database username and password for this workspace."
94+
)
95+
96+
try:
97+
elicitation_result, error = await try_elicitation(
98+
ctx=ctx, message=elicitation_message, schema=DatabaseCredentials
99+
)
100+
101+
if error == ElicitationError.NOT_SUPPORTED:
102+
# Fallback: raise exception with clear message
103+
raise Exception(
104+
"Database credentials required for API key authentication on dedicated workspaces. "
105+
f"Please provide your database username and password for workspace '{target.name}'. "
106+
"You can obtain these credentials from your SingleStore portal. "
107+
"Note: This is different from your SingleStore account credentials - these are "
108+
"database-specific credentials for connecting to the workspace."
109+
)
110+
elif elicitation_result.status == "success" and elicitation_result.data:
111+
username = elicitation_result.data.username
112+
password = elicitation_result.data.password
113+
114+
# Store credentials in session cache for future use
115+
try:
116+
credentials_manager.store_credentials(
117+
database_key, username, password
118+
)
119+
logger.debug(f"Cached credentials for workspace: {target.name}")
120+
except Exception as e:
121+
logger.warning(f"Failed to cache credentials: {e}")
122+
123+
return (username, password)
124+
else:
125+
raise Exception(
126+
"Database credentials are required but were not provided. Please ask the user to provide the database credentials"
127+
)
128+
except Exception as e:
129+
if "Database credentials required" in str(e):
130+
raise # Re-raise our specific credential error
131+
logger.error(f"Error during credential elicitation: {e}")
132+
raise Exception(f"Failed to obtain database credentials: {str(e)}")
133+
else:
134+
# JWT authentication: use user_id and access_token as before
135+
return __get_user_id(), get_access_token()
136+
137+
38138
async def __execute_sql_unified(
39139
ctx: Context,
40140
target: WorkspaceTarget,
@@ -60,39 +160,71 @@ async def __execute_sql_unified(
60160
host = endpoint
61161
port = None
62162

63-
s2_manager = S2Manager(
64-
host=host,
65-
port=port,
66-
user=username,
67-
password=password,
68-
database=database_name,
163+
# Generate database key for credential management
164+
credentials_manager = get_session_credentials_manager()
165+
database_key = credentials_manager.generate_database_key(
166+
workspace_name=target.name, database_name=database_name
69167
)
70168

71-
workspace_type = "shared/virtual" if target.is_shared else "dedicated"
72-
await ctx.info(
73-
f"Executing SQL query on {workspace_type} workspace '{target.name}' with database '{database_name}': {sql_query}"
74-
"This query may take some time depending on the complexity and size of the data."
75-
)
76-
s2_manager.execute(sql_query)
77-
columns = (
78-
[desc[0] for desc in s2_manager.cursor.description]
79-
if s2_manager.cursor.description
80-
else []
81-
)
82-
rows = s2_manager.fetchmany()
83-
results = []
84-
for row in rows:
85-
result_dict = {}
86-
for i, column in enumerate(columns):
87-
result_dict[column] = row[i]
88-
results.append(result_dict)
89-
s2_manager.close()
90-
return {
91-
"data": results,
92-
"row_count": len(rows),
93-
"columns": columns,
94-
"status": "Success",
95-
}
169+
try:
170+
s2_manager = S2Manager(
171+
host=host,
172+
port=port,
173+
user=username,
174+
password=password,
175+
database=database_name,
176+
)
177+
178+
workspace_type = "shared/virtual" if target.is_shared else "dedicated"
179+
await ctx.info(
180+
f"Executing SQL query on {workspace_type} workspace '{target.name}' with database '{database_name}': {sql_query}"
181+
"This query may take some time depending on the complexity and size of the data."
182+
)
183+
s2_manager.execute(sql_query)
184+
columns = (
185+
[desc[0] for desc in s2_manager.cursor.description]
186+
if s2_manager.cursor.description
187+
else []
188+
)
189+
rows = s2_manager.fetchmany()
190+
results = []
191+
for row in rows:
192+
result_dict = {}
193+
for i, column in enumerate(columns):
194+
result_dict[column] = row[i]
195+
results.append(result_dict)
196+
s2_manager.close()
197+
return {
198+
"data": results,
199+
"row_count": len(rows),
200+
"columns": columns,
201+
"status": "Success",
202+
}
203+
except Exception as e:
204+
# Check if this is an authentication error
205+
error_msg = str(e).lower()
206+
is_auth_error = any(
207+
auth_keyword in error_msg
208+
for auth_keyword in [
209+
"access denied",
210+
"authentication failed",
211+
"invalid credentials",
212+
"login failed",
213+
"permission denied",
214+
"unauthorized",
215+
"auth",
216+
]
217+
)
218+
219+
if is_auth_error:
220+
logger.warning(
221+
f"Authentication failed for database {database_key}, invalidating cached credentials"
222+
)
223+
invalidate_credentials(database_key)
224+
raise Exception(f"Authentication failed: {str(e)}")
225+
else:
226+
# Non-authentication error, re-raise as-is
227+
raise
96228

97229

98230
def __get_virtual_workspace(virtual_workspace_id: str):
@@ -781,19 +913,62 @@ async def run_sql(
781913
if target.is_shared and target.database_name and not database_name:
782914
database_name = target.database_name
783915

784-
username = __get_user_id()
785-
password = get_access_token()
916+
# Get database credentials based on authentication method
917+
try:
918+
username, password = await _get_database_credentials(ctx, target, database_name)
919+
except Exception as e:
920+
if "Database credentials required" in str(e):
921+
# Handle the specific case where elicitation is not supported
922+
return {
923+
"status": "error",
924+
"message": str(e),
925+
"error_code": "CREDENTIALS_REQUIRED",
926+
"workspace_id": validated_id,
927+
"workspace_name": target.name,
928+
"workspace_type": "shared" if target.is_shared else "dedicated",
929+
"instruction": (
930+
"Please call this function again with the same parameters once you have "
931+
"the database credentials available, or ask the user to provide their "
932+
"database username and password for this workspace."
933+
),
934+
}
935+
else:
936+
return {
937+
"status": "error",
938+
"message": f"Failed to obtain database credentials: {str(e)}",
939+
"error_code": "AUTHENTICATION_ERROR",
940+
}
786941

787942
# Execute the SQL query
788943
start_time = time.time()
789-
result = await __execute_sql_unified(
790-
ctx=ctx,
791-
target=target,
792-
sql_query=sql_query,
793-
username=username,
794-
password=password,
795-
database=database_name,
796-
)
944+
try:
945+
result = await __execute_sql_unified(
946+
ctx=ctx,
947+
target=target,
948+
sql_query=sql_query,
949+
username=username,
950+
password=password,
951+
database=database_name,
952+
)
953+
except Exception as e:
954+
# Check if this is an authentication error from __execute_sql_unified
955+
if "Authentication failed:" in str(e):
956+
# Authentication error already handled by __execute_sql_unified (credentials invalidated)
957+
return {
958+
"status": "error",
959+
"message": str(e),
960+
"error_code": "AUTHENTICATION_ERROR",
961+
"workspace_id": validated_id,
962+
"workspace_name": target.name,
963+
"workspace_type": "shared" if target.is_shared else "dedicated",
964+
"instruction": (
965+
"Authentication failed. Please provide valid database credentials "
966+
"for this workspace and try again."
967+
),
968+
}
969+
else:
970+
# Non-authentication error, re-raise
971+
raise
797972

798973
results_data = result.get("data", [])
799974

0 commit comments

Comments
 (0)