Skip to content

Commit 4d4fa52

Browse files
Merge pull request stanfordnlp#1313 from lmoros-DB/DatabricksRM-returning_metadata
databricks_rm: returning extra columns
2 parents af5617a + 7e1f33e commit 4d4fa52

File tree

1 file changed

+120
-37
lines changed

1 file changed

+120
-37
lines changed

dspy/retrieve/databricks_rm.py

+120-37
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,7 @@
11
import json
22
import os
3-
from collections import defaultdict
4-
from typing import List, Union
5-
3+
from typing import List, Union, Any, Dict
64
import requests
7-
85
import dspy
96
from dspy.primitives.prediction import Prediction
107

@@ -26,7 +23,7 @@ class DatabricksRM(dspy.Retrieve):
2623
Examples:
2724
Below is a code snippet that shows how to configure Databricks Vector Search endpoints:
2825
29-
(example adapted from "Databricks: How to create and query a Vector Search Index:
26+
(example adapted from "Databricks: How to create and query a Vector Search Index:
3027
https://docs.databricks.com/en/generative-ai/create-query-vector-search.html#create-a-vector-search-index)
3128
3229
```python
@@ -41,7 +38,7 @@ class DatabricksRM(dspy.Retrieve):
4138
endpoint_type="STANDARD"
4239
)
4340
44-
#Creating Vector Search Index using Python SDK
41+
#Creating Vector Search Index using Python SDK
4542
#Example for Direct Vector Access Index
4643
4744
index = client.create_direct_access_index(
@@ -56,9 +53,9 @@ class DatabricksRM(dspy.Retrieve):
5653
"field3": "float",
5754
"text_vector": "array<float>"}
5855
)
59-
56+
6057
llm = dspy.OpenAI(model="gpt-3.5-turbo")
61-
retriever_model = DatabricksRM(databricks_index_name = "your_index_name",
58+
retriever_model = DatabricksRM(databricks_index_name = "your_index_name",
6259
databricks_endpoint = "your_databricks_host_url", databricks_token = "your_databricks_token", columns= ["id", "field2", "field3", "text_vector"], k=3)
6360
dspy.settings.configure(lm=llm, rm=retriever_model)
6461
```
@@ -68,26 +65,93 @@ class DatabricksRM(dspy.Retrieve):
6865
self.retrieve = DatabricksRM(query=[1, 2, 3], query_type = 'vector')
6966
```
7067
"""
71-
def __init__(self, databricks_index_name = None, databricks_endpoint = None, databricks_token = None, columns = None, filters_json = None, k = 3, docs_id_column_name = 'id', text_column_name = 'text'):
68+
69+
def __init__(
70+
self,
71+
databricks_index_name=None,
72+
databricks_endpoint=None,
73+
databricks_token=None,
74+
columns=None,
75+
filters_json=None,
76+
k=3,
77+
docs_id_column_name="id",
78+
text_column_name="text",
79+
):
7280
super().__init__(k=k)
7381
if not databricks_token and not os.environ.get("DATABRICKS_TOKEN"):
74-
raise ValueError("You must supply databricks_token or set environment variable DATABRICKS_TOKEN")
82+
raise ValueError(
83+
"You must supply databricks_token or set environment variable DATABRICKS_TOKEN"
84+
)
7585
if not databricks_endpoint and not os.environ.get("DATABRICKS_HOST"):
76-
raise ValueError("You must supply databricks_endpoint or set environment variable DATABRICKS_HOST")
86+
raise ValueError(
87+
"You must supply databricks_endpoint or set environment variable DATABRICKS_HOST"
88+
)
7789
if not databricks_index_name:
7890
raise ValueError("You must supply vector index name")
7991
if not columns:
80-
raise ValueError("You must specify a list of column names to be included in the response")
81-
self.databricks_token = databricks_token if databricks_token else os.environ["DATABRICKS_TOKEN"]
82-
self.databricks_endpoint = databricks_endpoint if databricks_endpoint else os.environ["DATABRICKS_HOST"]
92+
raise ValueError(
93+
"You must specify a list of column names to be included in the response"
94+
)
95+
self.databricks_token = (
96+
databricks_token if databricks_token else os.environ["DATABRICKS_TOKEN"]
97+
)
98+
self.databricks_endpoint = (
99+
databricks_endpoint
100+
if databricks_endpoint
101+
else os.environ["DATABRICKS_HOST"]
102+
)
83103
self.databricks_index_name = databricks_index_name
84104
self.columns = columns
85105
self.filters_json = filters_json
86106
self.k = k
87107
self.docs_id_column_name = docs_id_column_name
88108
self.text_column_name = text_column_name
89109

90-
def forward(self, query: Union[str, List[float]], query_type: str = 'text', filters_json: str = None) -> dspy.Prediction:
110+
def _extract_doc_ids(self, item: Dict[str, Any]) -> str:
111+
"""Extracts the document id from a search result
112+
113+
Args:
114+
item: Dict[str, Any]: a record from the search results.
115+
Returns:
116+
str: document id.
117+
"""
118+
if self.docs_id_column_name == "metadata":
119+
docs_dict = json.loads(item["metadata"])
120+
return docs_dict["document_id"]
121+
return item[self.docs_id_column_name]
122+
123+
def _get_extra_columns(self, item: Dict[str, Any]) -> Dict[str, Any]:
124+
"""Extracts search result column values, excluding the "text" and not "id" columns
125+
126+
Args:
127+
item: Dict[str, Any]: a record from the search results.
128+
Returns:
129+
Dict[str, Any]: Search result column values, excluding the "text" and "id" columns.
130+
"""
131+
extra_columns = {
132+
k: v
133+
for k, v in item.items()
134+
if k not in [self.docs_id_column_name, self.text_column_name]
135+
}
136+
if self.docs_id_column_name == "metadata":
137+
extra_columns = {
138+
**extra_columns,
139+
**{
140+
"metadata": {
141+
k: v
142+
for k, v in json.loads(item["metadata"]).items()
143+
if k != "document_id"
144+
}
145+
},
146+
}
147+
return extra_columns
148+
149+
def forward(
150+
self,
151+
query: Union[str, List[float]],
152+
query_type: str = "text",
153+
filters_json: str = None,
154+
) -> dspy.Prediction:
91155
"""Search with Databricks Vector Search Client for self.k top results for query
92156
93157
Args:
@@ -105,11 +169,11 @@ def forward(self, query: Union[str, List[float]], query_type: str = 'text', filt
105169
"columns": self.columns,
106170
"num_results": self.k,
107171
}
108-
if query_type == 'vector':
172+
if query_type == "vector":
109173
if not isinstance(query, list):
110174
raise ValueError("Query must be a list of floats for query_vector")
111175
payload["query_vector"] = query
112-
elif query_type == 'text':
176+
elif query_type == "text":
113177
if not isinstance(query, str):
114178
raise ValueError("Query must be a string for query_text")
115179
payload["query_text"] = query
@@ -125,23 +189,42 @@ def forward(self, query: Union[str, List[float]], query_type: str = 'text', filt
125189
)
126190
results = response.json()
127191

128-
docs = defaultdict(float)
129-
doc_ids = []
130-
text, score = None, None
131-
for data_row in results["result"]["data_array"]:
132-
for col, val in zip(results["manifest"]["columns"], data_row):
133-
if col["name"] == self.docs_id_column_name:
134-
if self.docs_id_column_name == 'metadata':
135-
docs_dict = json.loads(val)
136-
doc_ids.append(str(docs_dict["document_id"]))
137-
else:
138-
doc_ids.append(str(val))
139-
text = val
140-
if col["name"] == self.text_column_name:
141-
text = val
142-
if col["name"] == 'score':
143-
score = val
144-
docs[text] += score
145-
146-
sorted_docs = sorted(docs.items(), key=lambda x: x[1], reverse=True)[:self.k]
147-
return Prediction(docs=[doc for doc, _ in sorted_docs], doc_ids = doc_ids)
192+
# Check for errors from REST API call
193+
if response.json().get("error_code", None) != None:
194+
raise Exception(
195+
f"ERROR: {response.json()['error_code']} -- {response.json()['message']}"
196+
)
197+
198+
# Checking if defined columns are present in the index columns
199+
col_names = [column["name"] for column in results["manifest"]["columns"]]
200+
201+
if self.docs_id_column_name not in col_names:
202+
raise Exception(
203+
f"docs_id_column_name: '{self.docs_id_column_name}' is not in the index columns: \n {col_names}"
204+
)
205+
206+
if self.text_column_name not in col_names:
207+
raise Exception(
208+
f"text_column_name: '{self.text_column_name}' is not in the index columns: \n {col_names}"
209+
)
210+
211+
# Extracting the results
212+
items = []
213+
for idx, data_row in enumerate(results["result"]["data_array"]):
214+
item = {}
215+
for col_name, val in zip(col_names, data_row):
216+
item[col_name] = val
217+
items += [item]
218+
219+
# Sorting results by score in descending order
220+
sorted_docs = sorted(items, key=lambda x: x["score"], reverse=True)[:self.k]
221+
222+
# Returning the prediction
223+
return Prediction(
224+
docs=[doc[self.text_column_name] for doc in sorted_docs],
225+
doc_ids=[
226+
self._extract_doc_ids(doc)
227+
for doc in sorted_docs
228+
],
229+
extra_columns=[self._get_extra_columns(item) for item in sorted_docs],
230+
)

0 commit comments

Comments
 (0)