72 changes: 53 additions & 19 deletions docs/components/vectordbs/dbs/databricks.mdx
@@ -13,9 +13,17 @@ config = {
"workspace_url": "https://your-workspace.databricks.com",
"access_token": "your-access-token",
"endpoint_name": "your-vector-search-endpoint",
"index_name": "catalog.schema.index_name",
"source_table_name": "catalog.schema.source_table",
"embedding_dimension": 1536
"catalog": "your_unity_catalog_catalog_name",
"schema": "your_unity_catalog_schema_name",
"table_name": "your_delta_table_name",
"collection_name": "your_vector_search_index_name",
"index_type": "DELTA_SYNC_OR_DIRECT_ACCESS",
"embedding_dimension": 1536,
"endpoint_type": "STANDARD_OR_STORAGE_OPTIMIZED",
"pipeline_type": "TRIGGERED_OR_CONTINUOUS",
"query_type": "ANN_OR_HYBRID",
"warehouse_name": "your_SQL_Warehouse_name",
"embedding_model_endpoint_name": "your_embedding_model_endpoint",
}
}
}
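For orientation, here is a minimal usage sketch of the new config shape (assuming mem0's standard `Memory.from_config` entry point; `config` is the dict above, and all values are placeholders):

```python
from mem0 import Memory

# Sketch only -- `config` is the dict shown above, filled with real
# workspace details instead of placeholder values.
memory = Memory.from_config(config)
memory.add("I prefer dark roast coffee", user_id="alice")
hits = memory.search("coffee preferences", user_id="alice")
```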
@@ -38,17 +46,39 @@ Here are the parameters available for configuring Databricks Vector Search:
| --- | --- | --- |
| `workspace_url` | The URL of your Databricks workspace | **Required** |
| `access_token` | Personal Access Token for authentication | `None` |
- | `service_principal_client_id` | Service principal client ID (alternative to access_token) | `None` |
- | `service_principal_client_secret` | Service principal client secret (required with client_id) | `None` |
+ | `client_id` | Service principal client ID (alternative to access_token) | `None` |
+ | `client_secret` | Service principal client secret (required with client_id) | `None` |
+ | `azure_client_id` | Azure AD application client ID (for Azure Databricks) | `None` |
+ | `azure_client_secret` | Azure AD application client secret (for Azure Databricks) | `None` |
| `endpoint_name` | Name of the Vector Search endpoint | **Required** |
- | `index_name` | Name of the vector index (Unity Catalog format: catalog.schema.index) | **Required** |
- | `source_table_name` | Name of the source Delta table (Unity Catalog format: catalog.schema.table) | **Required** |
- | `embedding_dimension` | Dimension of self-managed embeddings | `1536` |
- | `embedding_source_column` | Column name for text when using Databricks-computed embeddings | `None` |
+ | `catalog` | Unity Catalog catalog name | **Required** |
+ | `schema` | Unity Catalog schema name | **Required** |
+ | `table_name` | Name of the source Delta table | **Required** |
+ | `collection_name` | Name of the vector index | `mem0` |
+ | `index_type` | Index type, either `DELTA_SYNC` or `DIRECT_ACCESS` | `DELTA_SYNC` |
+ | `embedding_model_endpoint_name` | Databricks serving endpoint for embeddings | `None` |
+ | `embedding_vector_column` | Column name for self-managed embedding vectors | `embedding` |
+ | `embedding_dimension` | Dimension of self-managed embeddings | `1536` |
+ | `endpoint_type` | Type of endpoint, either `STANDARD` or `STORAGE_OPTIMIZED` | `STANDARD` |
+ | `sync_computed_embeddings` | Whether to sync computed embeddings automatically | `True` |
+ | `pipeline_type` | Sync pipeline type, either `TRIGGERED` or `CONTINUOUS` | `TRIGGERED` |
+ | `warehouse_name` | Databricks SQL warehouse name (if using a SQL warehouse) | `None` |
+ | `query_type` | Query type, either `ANN` or `HYBRID` | `ANN` |
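To make the two index modes concrete, here is a sketch of how these options typically combine (placeholder fragments, not verified configurations; `DELTA_SYNC` keeps the index in sync with the Delta table, while `DIRECT_ACCESS` expects caller-supplied vectors):

```python
# Sketch only -- placeholder fragments, to be merged into the full config.
delta_sync_fragment = {
    "index_type": "DELTA_SYNC",
    "pipeline_type": "TRIGGERED",  # or "CONTINUOUS" for streaming sync
    "embedding_model_endpoint_name": "your_embedding_model_endpoint",
}

direct_access_fragment = {
    "index_type": "DIRECT_ACCESS",
    "embedding_dimension": 1536,   # must match your embedder's output
    "query_type": "ANN",           # or "HYBRID" for keyword + vector search
}
```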

+ ### Schema
+ 
+ Columns used for storing memories and powering vector search:
+ 
+ | Name | Type | Description |
+ | --- | --- | --- |
+ | `memory_id` | `string` | Primary key |
+ | `hash` | `string` | Hash of the memory content |
+ | `agent_id` | `string` | ID of the agent |
+ | `run_id` | `string` | ID of the run |
+ | `user_id` | `string` | ID of the user |
+ | `memory` | `string` | Memory content |
+ | `metadata` | `string` | Additional metadata |
+ | `created_at` | `timestamp` | Creation timestamp |
+ | `updated_at` | `timestamp` | Last update timestamp |
+ | `embedding` | `array<float>` | Embedding vector (only if `index_type="DIRECT_ACCESS"`) |
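For reference, a single source-table row would carry these columns. The sketch below uses invented values; per the code later in this PR, `metadata` is stored as a JSON string:

```python
# Illustrative row only -- values are invented to match the schema above.
row = {
    "memory_id": "9f1c2d3e-...",       # primary key (UUID string)
    "hash": "5f4dcc3b5aa765d6...",     # hash of the memory content
    "agent_id": None,
    "run_id": None,
    "user_id": "alice",
    "memory": "Prefers dark roast coffee",
    "metadata": '{"source": "chat"}',  # serialized as a JSON string
    "created_at": "2025-01-01T00:00:00Z",
    "updated_at": "2025-01-02T00:00:00Z",
    "embedding": [0.12, -0.08],        # truncated; DIRECT_ACCESS only
}
```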

### Authentication

@@ -61,11 +91,13 @@ config = {
"provider": "databricks",
"config": {
"workspace_url": "https://your-workspace.databricks.com",
"service_principal_client_id": "your-service-principal-id",
"service_principal_client_secret": "your-service-principal-secret",
"client_id": "your-service-principal-id",
"client_secret": "your-service-principal-secret",
"endpoint_name": "your-endpoint",
"index_name": "catalog.schema.index_name",
"source_table_name": "catalog.schema.source_table"
"catalog": "your_unity_catalog_catalog_name",
"schema": "your_unity_catalog_schema_name",
"table_name": "your_delta_table_name",
"collection_name": "your_index_name",
}
}
}
@@ -80,8 +112,10 @@ config = {
"workspace_url": "https://your-workspace.databricks.com",
"access_token": "your-personal-access-token",
"endpoint_name": "your-endpoint",
"index_name": "catalog.schema.index_name",
"source_table_name": "catalog.schema.source_table"
"catalog": "your_unity_catalog_catalog_name",
"schema": "your_unity_catalog_schema_name",
"table_name": "your_delta_table_name",
"collection_name": "your_index_name",
}
}
}
@@ -98,8 +132,8 @@ config = {
"provider": "databricks",
"config": {
# ... authentication config ...
+ # By default, the column name will be "memory"
"embedding_dimension": 768, # Match your embedding model
"embedding_vector_column": "embedding"
}
}
}
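Assuming this section covers self-managed embeddings, a full configuration might look like the following sketch (placeholder names throughout, pairing `DIRECT_ACCESS` with the 768-dimensional example above):

```python
# Sketch only -- placeholder values for a self-managed embeddings setup.
config = {
    "vector_store": {
        "provider": "databricks",
        "config": {
            "workspace_url": "https://your-workspace.databricks.com",
            "access_token": "your-personal-access-token",
            "endpoint_name": "your-endpoint",
            "catalog": "your_catalog",
            "schema": "your_schema",
            "table_name": "your_delta_table",
            "collection_name": "your_index_name",
            "index_type": "DIRECT_ACCESS",   # you supply the vectors
            "embedding_dimension": 768,      # must match your embedder
        }
    }
}
```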
@@ -114,7 +148,7 @@ config = {
"provider": "databricks",
"config": {
# ... authentication config ...
"embedding_source_column": "text",
+ # By default, the column name will be "embedding"
"embedding_model_endpoint_name": "e5-small-v2"
}
}
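For Databricks-computed embeddings, the pieces combine roughly as below (a sketch; `e5-small-v2` is the serving endpoint named above, and all other values are placeholders):

```python
# Sketch only -- placeholder values for Databricks-computed embeddings.
config = {
    "vector_store": {
        "provider": "databricks",
        "config": {
            "workspace_url": "https://your-workspace.databricks.com",
            "access_token": "your-personal-access-token",
            "endpoint_name": "your-endpoint",
            "catalog": "your_catalog",
            "schema": "your_schema",
            "table_name": "your_delta_table",
            "collection_name": "your_index_name",
            "index_type": "DELTA_SYNC",       # Databricks computes embeddings
            "pipeline_type": "TRIGGERED",
            "embedding_model_endpoint_name": "e5-small-v2",
        }
    }
}
```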
73 changes: 53 additions & 20 deletions mem0/vector_stores/databricks.py
@@ -65,7 +65,7 @@ def __init__(
catalog (str): Unity Catalog catalog name.
schema (str): Unity Catalog schema name.
table_name (str): Source Delta table name.
- index_name (str, optional): Vector search index name (default: "mem0").
+ collection_name (str, optional): Vector search index name (default: "mem0").
index_type (str, optional): Index type, either "DELTA_SYNC" or "DIRECT_ACCESS" (default: "DELTA_SYNC").
embedding_model_endpoint_name (str, optional): Embedding model endpoint for Databricks-computed embeddings.
embedding_dimension (int, optional): Vector embedding dimensions (default: 1536).
@@ -572,14 +572,28 @@ def get(self, vector_id) -> MemoryResult:
filters = {"memory_id": vector_id}
filters_json = json.dumps(filters)

- results = self.client.vector_search_indexes.query_index(
-     index_name=self.fully_qualified_index_name,
-     columns=self.column_names,
-     query_text=" ",  # Empty query, rely on filters
-     num_results=1,
-     query_type=self.query_type,
-     filters_json=filters_json,
- )
+ if self.index_type == VectorIndexType.DELTA_SYNC:
+     # Text-based search: blank query text, rely on the filters
+     results = self.client.vector_search_indexes.query_index(
+         index_name=self.fully_qualified_index_name,
+         columns=self.column_names,
+         query_text=" ",
+         num_results=1,
+         query_type=self.query_type,
+         filters_json=filters_json,
+     )
+ elif self.index_type == VectorIndexType.DIRECT_ACCESS:
+     # Vector-based search: zero vector, rely on the filters
+     results = self.client.vector_search_indexes.query_index(
+         index_name=self.fully_qualified_index_name,
+         columns=self.column_names,
+         query_vector=[0.0] * self.embedding_dimension,
+         num_results=1,
+         query_type=self.query_type,
+         filters_json=filters_json,
+     )
+ else:
+     raise ValueError(f"Unsupported index type: {self.index_type}; expected DELTA_SYNC or DIRECT_ACCESS.")

# Process results
result_data = results.result if hasattr(results, "result") else results
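The two branches above differ only in passing a blank `query_text` versus an all-zeros `query_vector`. One way to factor that dispatch into a shared helper, using only the SDK parameters already visible in this diff (`_query_by_filters` is a hypothetical name):

```python
# Hypothetical refactor: filter-only lookup shared by both index types.
def _query_by_filters(self, filters_json: str, num_results: int):
    kwargs = dict(
        index_name=self.fully_qualified_index_name,
        columns=self.column_names,
        num_results=num_results,
        query_type=self.query_type,
        filters_json=filters_json,
    )
    if self.index_type == VectorIndexType.DELTA_SYNC:
        kwargs["query_text"] = " "  # blank query; the filters do the work
    elif self.index_type == VectorIndexType.DIRECT_ACCESS:
        kwargs["query_vector"] = [0.0] * self.embedding_dimension
    else:
        raise ValueError(f"Unsupported index type: {self.index_type}")
    return self.client.vector_search_indexes.query_index(**kwargs)
```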
@@ -589,7 +603,9 @@ def get(self, vector_id) -> MemoryResult:
raise KeyError(f"Vector with ID {vector_id} not found")

result = data_array[0]
- columns = columns = [col.name for col in results.manifest.columns] if results.manifest and results.manifest.columns else []
+ columns = (
+     [col.name for col in results.manifest.columns] if results.manifest and results.manifest.columns else []
+ )
row_data = dict(zip(columns, result))

# Build payload following the standard schema
@@ -609,7 +625,7 @@ def get(self, vector_id) -> MemoryResult:
payload[field] = row_data[field]

# Add metadata
if "metadata" in row_data and row_data.get('metadata'):
if "metadata" in row_data and row_data.get("metadata"):
try:
metadata = json.loads(extract_json(row_data["metadata"]))
payload.update(metadata)
@@ -686,14 +702,31 @@ def list(self, filters: dict = None, limit: int = None) -> list[MemoryResult]:
filters_json = json.dumps(filters) if filters else None
num_results = limit or 100
columns = self.column_names
- sdk_results = self.client.vector_search_indexes.query_index(
-     index_name=self.fully_qualified_index_name,
-     columns=columns,
-     query_text=" ",
-     num_results=num_results,
-     query_type=self.query_type,
-     filters_json=filters_json,
- )

+ # Choose the query form based on index type
+ if self.index_type == VectorIndexType.DELTA_SYNC:
+     # Text-based search: blank query text, rely on the filters
+     sdk_results = self.client.vector_search_indexes.query_index(
+         index_name=self.fully_qualified_index_name,
+         columns=self.column_names,
+         query_text=" ",
+         num_results=num_results,
+         query_type=self.query_type,
+         filters_json=filters_json,
+     )
+ elif self.index_type == VectorIndexType.DIRECT_ACCESS:
+     # Vector-based search: zero vector, rely on the filters
+     sdk_results = self.client.vector_search_indexes.query_index(
+         index_name=self.fully_qualified_index_name,
+         columns=self.column_names,
+         query_vector=[0.0] * self.embedding_dimension,
+         num_results=num_results,
+         query_type=self.query_type,
+         filters_json=filters_json,
+     )
+ else:
+     raise ValueError(f"Unsupported index type: {self.index_type}; expected DELTA_SYNC or DIRECT_ACCESS.")
+ 
result_data = sdk_results.result if hasattr(sdk_results, "result") else sdk_results
data_array = result_data.data_array if hasattr(result_data, "data_array") else []

@@ -708,7 +741,7 @@ def list(self, filters: dict = None, limit: int = None) -> list[MemoryResult]:
except Exception:
pass
memory_id = row_dict.get("memory_id") or row_dict.get("id")
- payload['data'] = payload['memory']
payload["data"] = payload["memory"]
memory_results.append(MemoryResult(id=memory_id, payload=payload))
return [memory_results]
except Exception as e:
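Finally, a hypothetical usage sketch for the updated `list` path (`store` stands for an initialized Databricks vector store; the tuple unpacking mirrors the `return [memory_results]` wrapper above):

```python
# Hypothetical usage -- note that list() returns its results wrapped in a list.
(results,) = store.list(filters={"user_id": "alice"}, limit=50)
for mem in results:
    print(mem.id, mem.payload.get("memory"))
```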