Skip to content

SeaDatabricksClient: Add Metadata Commands #593

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 6 commits into
base: exec-phase-sea
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 166 additions & 0 deletions src/databricks/sql/backend/filters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
"""
Client-side filtering utilities for the Databricks SQL connector.

This module provides filtering capabilities for result sets returned by different backends.
"""

import logging
from typing import (
List,
Optional,
Any,
Dict,
Callable,
TypeVar,
Generic,
cast,
TYPE_CHECKING,
)

from databricks.sql.backend.types import ExecuteResponse, CommandId
from databricks.sql.backend.sea.models.base import ResultData
from databricks.sql.backend.sea.backend import SeaDatabricksClient

if TYPE_CHECKING:
from databricks.sql.result_set import ResultSet, SeaResultSet

logger = logging.getLogger(__name__)


class ResultSetFilter:
    """
    A general-purpose filter for result sets that can be applied to any backend.

    This class provides methods to filter result sets based on various criteria,
    similar to the client-side filtering in the JDBC connector. All methods are
    static; the class only serves as a namespace.
    """

    @staticmethod
    def _filter_sea_result_set(
        result_set: "SeaResultSet", filter_func: Callable[[List[Any]], bool]
    ) -> "SeaResultSet":
        """
        Build a new SEA result set containing only the rows accepted by ``filter_func``.

        Args:
            result_set: The SEA result set to filter
            filter_func: Function that takes a row and returns True if the row
                should be included

        Returns:
            A new SeaResultSet that reuses the command id, status, description and
            connection settings of ``result_set`` but carries only the kept rows.
        """
        # NOTE(review): remaining_rows() presumably drains the source result set's
        # buffered rows, leaving it exhausted afterwards -- confirm against the queue
        # implementation.
        filtered_rows = [
            row for row in result_set.results.remaining_rows() if filter_func(row)
        ]

        # SeaResultSet is only imported under TYPE_CHECKING at module level, so it is
        # imported here at call time to avoid a circular import.
        from databricks.sql.result_set import SeaResultSet

        # Mirror the metadata of the original response; the command id is reused so
        # the filtered set still refers to the same server-side statement.
        execute_response = ExecuteResponse(
            command_id=result_set.command_id,
            status=result_set.status,
            description=result_set.description,
            has_been_closed_server_side=result_set.has_been_closed_server_side,
            lz4_compressed=result_set.lz4_compressed,
            arrow_schema_bytes=result_set._arrow_schema_bytes,
            is_staging_operation=False,
        )

        # The filtered rows are supplied inline; no external links are attached.
        # ResultData is already imported at module level, so no local import is needed.
        result_data = ResultData(data=filtered_rows, external_links=None)

        return SeaResultSet(
            connection=result_set.connection,
            execute_response=execute_response,
            sea_client=cast(SeaDatabricksClient, result_set.backend),
            buffer_size_bytes=result_set.buffer_size_bytes,
            arraysize=result_set.arraysize,
            result_data=result_data,
        )

    @staticmethod
    def filter_by_column_values(
        result_set: "ResultSet",
        column_index: int,
        allowed_values: List[str],
        case_sensitive: bool = False,
    ) -> "ResultSet":
        """
        Filter a result set by values in a specific column.

        A row is kept only if it has a column at ``column_index``, that column holds
        a string, and the string (upper-cased when ``case_sensitive`` is False) is
        one of ``allowed_values``.

        Args:
            result_set: The result set to filter
            column_index: The index of the column to filter on
            allowed_values: List of allowed values for the column
            case_sensitive: Whether to perform case-sensitive comparison

        Returns:
            A filtered result set for SeaResultSet inputs. For any other result set
            type a warning is logged and the original result set is returned
            unchanged.
        """
        # Normalise the allowed values once and store them in a set so the per-row
        # membership test does not scan a list. The caller's list is left untouched.
        if case_sensitive:
            allowed = frozenset(allowed_values)
        else:
            allowed = frozenset(value.upper() for value in allowed_values)

        # SeaResultSet is only imported under TYPE_CHECKING at module level.
        from databricks.sql.result_set import SeaResultSet

        def _row_matches(row: List[Any]) -> bool:
            """Return True when the row's target column holds an allowed string."""
            if len(row) <= column_index:
                return False
            cell = row[column_index]
            if not isinstance(cell, str):
                return False
            candidate = cell if case_sensitive else cell.upper()
            return candidate in allowed

        if isinstance(result_set, SeaResultSet):
            return ResultSetFilter._filter_sea_result_set(result_set, _row_matches)

        # For other result set types, return the original (should be handled by
        # specific implementations)
        logger.warning(
            f"Filtering not implemented for result set type: {type(result_set).__name__}"
        )
        return result_set

    @staticmethod
    def filter_tables_by_type(
        result_set: "ResultSet", table_types: Optional[List[str]] = None
    ) -> "ResultSet":
        """
        Filter a result set of tables by the specified table types.

        This is a client-side filter that processes the result set after it has been
        retrieved from the server. It filters out tables whose type does not match
        any of the types in the table_types list. The comparison is case-sensitive.

        Args:
            result_set: The original result set containing tables
            table_types: List of table types to include (e.g., ["TABLE", "VIEW"]).
                When None or empty, the defaults "TABLE", "VIEW" and "SYSTEM TABLE"
                are used.

        Returns:
            A filtered result set containing only tables of the specified types
        """
        # Default table types if none specified
        DEFAULT_TABLE_TYPES = ["TABLE", "VIEW", "SYSTEM TABLE"]
        # None and an empty list are both falsy, so both fall back to the defaults.
        valid_types = table_types or DEFAULT_TABLE_TYPES

        # Table type is the 6th column (index 5)
        return ResultSetFilter.filter_by_column_values(
            result_set, 5, valid_types, case_sensitive=True
        )
119 changes: 107 additions & 12 deletions src/databricks/sql/backend/sea/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,9 +319,22 @@ def get_catalogs(
max_rows: int,
max_bytes: int,
cursor: "Cursor",
):
"""Not implemented yet."""
raise NotImplementedError("get_catalogs is not yet implemented for SEA backend")
) -> "ResultSet":
"""Get available catalogs by executing 'SHOW CATALOGS'."""
result = self.execute_command(
operation="SHOW CATALOGS",
session_id=session_id,
max_rows=max_rows,
max_bytes=max_bytes,
lz4_compression=False,
cursor=cursor,
use_cloud_fetch=False,
parameters=[],
async_op=False,
enforce_embedded_schema_correctness=False,
)
assert result is not None, "execute_command returned None in synchronous mode"
return result

def get_schemas(
self,
Expand All @@ -331,9 +344,30 @@ def get_schemas(
cursor: "Cursor",
catalog_name: Optional[str] = None,
schema_name: Optional[str] = None,
):
"""Not implemented yet."""
raise NotImplementedError("get_schemas is not yet implemented for SEA backend")
) -> "ResultSet":
"""Get schemas by executing 'SHOW SCHEMAS IN catalog [LIKE pattern]'."""
if not catalog_name:
raise ValueError("Catalog name is required for get_schemas")

operation = f"SHOW SCHEMAS IN `{catalog_name}`"

if schema_name:
operation += f" LIKE '{schema_name}'"

result = self.execute_command(
operation=operation,
session_id=session_id,
max_rows=max_rows,
max_bytes=max_bytes,
lz4_compression=False,
cursor=cursor,
use_cloud_fetch=False,
parameters=[],
async_op=False,
enforce_embedded_schema_correctness=False,
)
assert result is not None, "execute_command returned None in synchronous mode"
return result

def get_tables(
self,
Expand All @@ -345,9 +379,43 @@ def get_tables(
schema_name: Optional[str] = None,
table_name: Optional[str] = None,
table_types: Optional[List[str]] = None,
):
"""Not implemented yet."""
raise NotImplementedError("get_tables is not yet implemented for SEA backend")
) -> "ResultSet":
"""Get tables by executing 'SHOW TABLES IN catalog [SCHEMA LIKE pattern] [LIKE pattern]'."""
if not catalog_name:
raise ValueError("Catalog name is required for get_tables")

operation = "SHOW TABLES IN " + (
"ALL CATALOGS"
if catalog_name in [None, "*", "%"]
else f"CATALOG `{catalog_name}`"
)

if schema_name:
operation += f" SCHEMA LIKE '{schema_name}'"

if table_name:
operation += f" LIKE '{table_name}'"

result = self.execute_command(
operation=operation,
session_id=session_id,
max_rows=max_rows,
max_bytes=max_bytes,
lz4_compression=False,
cursor=cursor,
use_cloud_fetch=False,
parameters=[],
async_op=False,
enforce_embedded_schema_correctness=False,
)
assert result is not None, "execute_command returned None in synchronous mode"

# Apply client-side filtering by table_types if specified
from databricks.sql.backend.filters import ResultSetFilter

result = ResultSetFilter.filter_tables_by_type(result, table_types)

return result

def get_columns(
self,
Expand All @@ -359,6 +427,33 @@ def get_columns(
schema_name: Optional[str] = None,
table_name: Optional[str] = None,
column_name: Optional[str] = None,
):
"""Not implemented yet."""
raise NotImplementedError("get_columns is not yet implemented for SEA backend")
) -> "ResultSet":
"""Get columns by executing 'SHOW COLUMNS IN CATALOG catalog [SCHEMA LIKE pattern] [TABLE LIKE pattern] [LIKE pattern]'."""
if not catalog_name:
raise ValueError("Catalog name is required for get_columns")

operation = f"SHOW COLUMNS IN CATALOG `{catalog_name}`"

if schema_name:
operation += f" SCHEMA LIKE '{schema_name}'"

if table_name:
operation += f" TABLE LIKE '{table_name}'"

if column_name:
operation += f" LIKE '{column_name}'"

result = self.execute_command(
operation=operation,
session_id=session_id,
max_rows=max_rows,
max_bytes=max_bytes,
lz4_compression=False,
cursor=cursor,
use_cloud_fetch=False,
parameters=[],
async_op=False,
enforce_embedded_schema_correctness=False,
)
assert result is not None, "execute_command returned None in synchronous mode"
return result
4 changes: 2 additions & 2 deletions src/databricks/sql/backend/sea/models/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def to_dict(self) -> Dict[str, Any]:

@dataclass
class CreateSessionRequest:
"""Representation of a request to create a new session."""
"""Request to create a new session."""

warehouse_id: str
session_confs: Optional[Dict[str, str]] = None
Expand All @@ -123,7 +123,7 @@ def to_dict(self) -> Dict[str, Any]:

@dataclass
class DeleteSessionRequest:
"""Representation of a request to delete a session."""
"""Request to delete a session."""

warehouse_id: str
session_id: str
Expand Down
2 changes: 1 addition & 1 deletion src/databricks/sql/backend/sea/models/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "GetStatementResponse":

@dataclass
class CreateSessionResponse:
"""Representation of the response from creating a new session."""
"""Response from creating a new session."""

session_id: str

Expand Down
2 changes: 1 addition & 1 deletion src/databricks/sql/backend/sea/utils/http_client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import logging
import requests
from typing import Callable, Dict, Any, Optional, List, Tuple
from typing import Callable, Dict, Any, Optional, Union, List, Tuple
from urllib.parse import urljoin

from databricks.sql.auth.authenticators import AuthProvider
Expand Down
Loading
Loading