Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "opteryx-sqlalchemy"
version = "0.0.5"
version = "0.0.6"
description = "Opteryx SQLAlchemy Dialect"
requires-python = ">=3.9"
dependencies = [
Expand Down
25 changes: 23 additions & 2 deletions sqlalchemy_dialect/dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,10 +171,31 @@ def __init__(self, connection: "Connection") -> None:
token = body.get("access_token") or body.get("token") or body.get("jwt")
if token:
self._jwt_token = token
logger.info("Authentication successful for user: %s", username)
token_type = body.get("token_type", "bearer")
# Capitalize token_type properly: "bearer" -> "Bearer"
if token_type:
token_type = (
token_type[0].upper() + token_type[1:].lower()
if len(token_type) > 0
else "Bearer"
)
else:
token_type = "Bearer"
expires_in = body.get("expires_in")
refresh_token = body.get("refresh_token")
logger.info(
"Authentication successful for user: %s (token_type=%s, expires_in=%s)",
username,
token_type,
expires_in,
)
if refresh_token:
logger.debug("Refresh token received for user: %s", username)
# Set Authorization header for subsequent requests via the connection session
try:
self._connection._session.headers["Authorization"] = f"Bearer {token}"
auth_header = f"{token_type} {token}"
self._connection._session.headers["Authorization"] = auth_header
logger.debug("Set Authorization header to: %s ...", auth_header[:50])
except Exception as e:
logger.warning("Failed to set Authorization header: %s", e)
else:
Expand Down
17 changes: 11 additions & 6 deletions tests/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,30 @@
import os
import sys

from sqlalchemy import create_engine
from sqlalchemy import text

# Make local package importable in editable/test mode (same pattern as tests/plain_script)
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

# Import the package so the dialect registers itself in editable/test mode
from tests import load_dotenv_simple
import sqlalchemy_dialect # noqa: F401

username = ""
password = ""
load_dotenv_simple("../.env")

from sqlalchemy import create_engine, text
DEFAULT_CLIENT_ID = os.environ.get("CLIENT_ID")
DEFAULT_CLIENT_SECRET = os.environ.get("CLIENT_SECRET")

# Configure logging to see the new debug output
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s %(message)s")
# Enable INFO level for opteryx dialect to see query execution times
logging.getLogger("sqlalchemy.dialects.opteryx").setLevel(logging.INFO)

# username:token@host:port/database?ssl=true
engine = create_engine(f"opteryx://{username}:{password}@opteryx.app:443/default?ssl=true")
engine = create_engine(
f"opteryx://{DEFAULT_CLIENT_ID}:{DEFAULT_CLIENT_SECRET}@opteryx.app:443/default?ssl=true"
)

with engine.connect() as conn:
res = conn.execute(text("SELECT * FROM benchmarks.tpch.lineitem LIMIT 50"))
res = conn.execute(text("SELECT * FROM public.examples.planets LIMIT 50"))
print(res.fetchall())
54 changes: 50 additions & 4 deletions tests/plain_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from __future__ import annotations

import json
import os
import sys
import time
Expand All @@ -14,7 +15,6 @@
from typing import Optional
from typing import Sequence

import orjson
import requests
from orso import DataFrame

Expand All @@ -35,7 +35,7 @@
DEFAULT_DATA_URL = "https://jobs.opteryx.app"
DEFAULT_CLIENT_ID = os.environ.get("CLIENT_ID")
DEFAULT_CLIENT_SECRET = os.environ.get("CLIENT_SECRET")
SQL_STATEMENT = "SELECT * FROM $planets AS P"
SQL_STATEMENT = "SELECT * FROM public.examples.planets AS P"


def fatal(msg: str) -> None:
Expand Down Expand Up @@ -79,11 +79,24 @@ def get_token(

r = requests.post(url, data=data, timeout=10)
if r.status_code != 200:
raise RuntimeError(f"token endpoint returned status {r.status_code}: {r.text}")
error_msg = f"token endpoint returned status {r.status_code}"
if r.status_code == 401:
print("\n⚠️ 401 Unauthorized Response:")
print(f"Headers: {dict(r.headers)}")
print(f"Body: {r.text}")
try:
json_body = r.json()
print(f"Parsed JSON: {json.dumps(json_body, indent=2)}")
except ValueError:
pass
raise RuntimeError(f"{error_msg}: {r.text}")
body = r.json()
token = body.get("access_token")
if not token:
raise RuntimeError("token endpoint returned no access_token")
expires_in = body.get("expires_in")
if expires_in:
print(f" Token expires in {expires_in} seconds")
return token


Expand All @@ -100,6 +113,17 @@ def create_statement(
payload["describeOnly"] = describe_only

r = requests.post(url, json=payload, headers=headers, timeout=10)
if r.status_code == 401:
print("\n⚠️ 401 Unauthorized Response on jobs endpoint:")
print(f"URL: {url}")
print(f"Headers sent: {headers}")
print(f"Response Headers: {dict(r.headers)}")
print(f"Body: {r.text}")
try:
json_body = r.json()
print(f"Parsed JSON: {json.dumps(json_body, indent=2)}")
except ValueError:
pass
r.raise_for_status()
return r.json()

Expand All @@ -108,6 +132,17 @@ def get_statement_status(data_url: str, token: str, handle: str) -> Dict[str, An
url = f"{data_url.rstrip('/')}/api/v1/jobs/{handle}/status"
headers = {"Authorization": f"Bearer {token}"}
r = requests.get(url, headers=headers, timeout=10)
if r.status_code == 401:
print("\n⚠️ 401 Unauthorized Response on status endpoint:")
print(f"URL: {url}")
print(f"Headers sent: {headers}")
print(f"Response Headers: {dict(r.headers)}")
print(f"Body: {r.text}")
try:
json_body = r.json()
print(f"Parsed JSON: {json.dumps(json_body, indent=2)}")
except ValueError:
pass
r.raise_for_status()
return r.json()

Expand All @@ -122,6 +157,17 @@ def get_statement_data(
if brotli is not None:
headers["Accept-Encoding"] = "br"
r = requests.get(url, headers=headers, timeout=10)
if r.status_code == 401:
print("\n⚠️ 401 Unauthorized Response on results endpoint:")
print(f"URL: {url}")
print(f"Headers sent: {headers}")
print(f"Response Headers: {dict(r.headers)}")
print(f"Body: {r.text}")
try:
json_body = r.json()
print(f"Parsed JSON: {json.dumps(json_body, indent=2)}")
except ValueError:
pass
r.raise_for_status()
encoding = r.headers.get("Content-Encoding", "")
if encoding.lower() == "br":
Expand All @@ -132,7 +178,7 @@ def get_statement_data(
except brotli.error:
# Already decompressed or corrupt stream; fall back to raw bytes
content = r.content
return orjson.loads(content)
return json.loads(content)
return r.json()


Expand Down
Loading