1
1
import json
2
2
import os
3
- from collections import defaultdict
4
- from typing import List , Union
5
-
3
+ from typing import List , Union , Any , Dict
6
4
import requests
7
-
8
5
import dspy
9
6
from dspy .primitives .prediction import Prediction
10
7
@@ -26,7 +23,7 @@ class DatabricksRM(dspy.Retrieve):
26
23
Examples:
27
24
Below is a code snippet that shows how to configure Databricks Vector Search endpoints:
28
25
29
- (example adapted from "Databricks: How to create and query a Vector Search Index:
26
+ (example adapted from "Databricks: How to create and query a Vector Search Index:
30
27
https://docs.databricks.com/en/generative-ai/create-query-vector-search.html#create-a-vector-search-index)
31
28
32
29
```python
@@ -41,7 +38,7 @@ class DatabricksRM(dspy.Retrieve):
41
38
endpoint_type="STANDARD"
42
39
)
43
40
44
- #Creating Vector Search Index using Python SDK
41
+ #Creating Vector Search Index using Python SDK
45
42
#Example for Direct Vector Access Index
46
43
47
44
index = client.create_direct_access_index(
@@ -56,9 +53,9 @@ class DatabricksRM(dspy.Retrieve):
56
53
"field3": "float",
57
54
"text_vector": "array<float>"}
58
55
)
59
-
56
+
60
57
llm = dspy.OpenAI(model="gpt-3.5-turbo")
61
- retriever_model = DatabricksRM(databricks_index_name = "your_index_name",
58
+ retriever_model = DatabricksRM(databricks_index_name = "your_index_name",
62
59
databricks_endpoint = "your_databricks_host_url", databricks_token = "your_databricks_token", columns= ["id", "field2", "field3", "text_vector"], k=3)
63
60
dspy.settings.configure(lm=llm, rm=retriever_model)
64
61
```
@@ -68,26 +65,93 @@ class DatabricksRM(dspy.Retrieve):
68
65
self.retrieve = DatabricksRM(query=[1, 2, 3], query_type = 'vector')
69
66
```
70
67
"""
71
- def __init__ (self , databricks_index_name = None , databricks_endpoint = None , databricks_token = None , columns = None , filters_json = None , k = 3 , docs_id_column_name = 'id' , text_column_name = 'text' ):
68
+
69
def __init__(
    self,
    databricks_index_name=None,
    databricks_endpoint=None,
    databricks_token=None,
    columns=None,
    filters_json=None,
    k=3,
    docs_id_column_name="id",
    text_column_name="text",
):
    """Validate the connection settings and store the retriever configuration.

    Args:
        databricks_index_name: name of the vector index to query (required).
        databricks_endpoint: endpoint/host value; falls back to the
            ``DATABRICKS_HOST`` environment variable when not given.
        databricks_token: access token; falls back to the
            ``DATABRICKS_TOKEN`` environment variable when not given.
        columns: column names to include in the response (required,
            must be non-empty).
        filters_json: stored as-is on the instance.
        k: number of results; forwarded to the parent constructor and
            stored on the instance.
        docs_id_column_name: name of the column holding the document id.
        text_column_name: name of the column holding the document text.

    Raises:
        ValueError: if the token, the endpoint, the index name or the
            column list is missing.
    """
    super().__init__(k=k)

    # An explicit argument wins; otherwise use the environment variable.
    # An unset or empty value on both sides counts as "missing".
    resolved_token = databricks_token or os.environ.get("DATABRICKS_TOKEN")
    if not resolved_token:
        raise ValueError(
            "You must supply databricks_token or set environment variable DATABRICKS_TOKEN"
        )

    resolved_endpoint = databricks_endpoint or os.environ.get("DATABRICKS_HOST")
    if not resolved_endpoint:
        raise ValueError(
            "You must supply databricks_endpoint or set environment variable DATABRICKS_HOST"
        )

    if not databricks_index_name:
        raise ValueError("You must supply vector index name")

    if not columns:
        raise ValueError(
            "You must specify a list of column names to be included in the response"
        )

    self.databricks_token = resolved_token
    self.databricks_endpoint = resolved_endpoint
    self.databricks_index_name = databricks_index_name
    self.columns = columns
    self.filters_json = filters_json
    self.k = k
    self.docs_id_column_name = docs_id_column_name
    self.text_column_name = text_column_name
89
109
90
- def forward (self , query : Union [str , List [float ]], query_type : str = 'text' , filters_json : str = None ) -> dspy .Prediction :
110
def _extract_doc_ids(self, item: Dict[str, Any]) -> str:
    """Return the document id of one search-result record as a string.

    When ``docs_id_column_name`` is ``"metadata"``, the id is read from the
    ``"document_id"`` key of the JSON-encoded ``"metadata"`` column.
    Otherwise it is the value stored under ``docs_id_column_name``.

    Args:
        item: a record from the search results, keyed by column name.

    Returns:
        str: the document id, always converted with ``str()``.

    Raises:
        KeyError: if the id column, or ``"document_id"`` inside the
            metadata, is absent from the record.
        json.JSONDecodeError: if the metadata column is not valid JSON.
    """
    if self.docs_id_column_name == "metadata":
        raw_id = json.loads(item["metadata"])["document_id"]
    else:
        raw_id = item[self.docs_id_column_name]
    # The signature promises ``str``; without the cast a numeric id column
    # would hand callers an int instead.
    return str(raw_id)
122
+
123
def _get_extra_columns(self, item: Dict[str, Any]) -> Dict[str, Any]:
    """Collect the columns of a search-result record other than id and text.

    Every column whose name is neither ``docs_id_column_name`` nor
    ``text_column_name`` is copied over unchanged. When the id column is
    ``"metadata"``, the JSON-encoded metadata is decoded and added under
    the ``"metadata"`` key with its ``"document_id"`` entry removed.

    Args:
        item: a record from the search results, keyed by column name.

    Returns:
        Dict[str, Any]: the remaining column values.
    """
    reserved = [self.docs_id_column_name, self.text_column_name]
    extras = {}
    for name, value in item.items():
        if name in reserved:
            continue
        extras[name] = value

    if self.docs_id_column_name != "metadata":
        return extras

    # The id lives inside the metadata blob: keep the rest of the blob,
    # decoded, without the id entry.
    extras["metadata"] = {
        key: val
        for key, val in json.loads(item["metadata"]).items()
        if key != "document_id"
    }
    return extras
148
+
149
+ def forward (
150
+ self ,
151
+ query : Union [str , List [float ]],
152
+ query_type : str = "text" ,
153
+ filters_json : str = None ,
154
+ ) -> dspy .Prediction :
91
155
"""Search with Databricks Vector Search Client for self.k top results for query
92
156
93
157
Args:
@@ -105,11 +169,11 @@ def forward(self, query: Union[str, List[float]], query_type: str = 'text', filt
105
169
"columns" : self .columns ,
106
170
"num_results" : self .k ,
107
171
}
108
- if query_type == ' vector' :
172
+ if query_type == " vector" :
109
173
if not isinstance (query , list ):
110
174
raise ValueError ("Query must be a list of floats for query_vector" )
111
175
payload ["query_vector" ] = query
112
- elif query_type == ' text' :
176
+ elif query_type == " text" :
113
177
if not isinstance (query , str ):
114
178
raise ValueError ("Query must be a string for query_text" )
115
179
payload ["query_text" ] = query
@@ -125,23 +189,42 @@ def forward(self, query: Union[str, List[float]], query_type: str = 'text', filt
125
189
)
126
190
results = response .json ()
127
191
128
- docs = defaultdict (float )
129
- doc_ids = []
130
- text , score = None , None
131
- for data_row in results ["result" ]["data_array" ]:
132
- for col , val in zip (results ["manifest" ]["columns" ], data_row ):
133
- if col ["name" ] == self .docs_id_column_name :
134
- if self .docs_id_column_name == 'metadata' :
135
- docs_dict = json .loads (val )
136
- doc_ids .append (str (docs_dict ["document_id" ]))
137
- else :
138
- doc_ids .append (str (val ))
139
- text = val
140
- if col ["name" ] == self .text_column_name :
141
- text = val
142
- if col ["name" ] == 'score' :
143
- score = val
144
- docs [text ] += score
145
-
146
- sorted_docs = sorted (docs .items (), key = lambda x : x [1 ], reverse = True )[:self .k ]
147
- return Prediction (docs = [doc for doc , _ in sorted_docs ], doc_ids = doc_ids )
192
+ # Check for errors from REST API call
193
+ if response .json ().get ("error_code" , None ) != None :
194
+ raise Exception (
195
+ f"ERROR: { response .json ()['error_code' ]} -- { response .json ()['message' ]} "
196
+ )
197
+
198
+ # Checking if defined columns are present in the index columns
199
+ col_names = [column ["name" ] for column in results ["manifest" ]["columns" ]]
200
+
201
+ if self .docs_id_column_name not in col_names :
202
+ raise Exception (
203
+ f"docs_id_column_name: '{ self .docs_id_column_name } ' is not in the index columns: \n { col_names } "
204
+ )
205
+
206
+ if self .text_column_name not in col_names :
207
+ raise Exception (
208
+ f"text_column_name: '{ self .text_column_name } ' is not in the index columns: \n { col_names } "
209
+ )
210
+
211
+ # Extracting the results
212
+ items = []
213
+ for idx , data_row in enumerate (results ["result" ]["data_array" ]):
214
+ item = {}
215
+ for col_name , val in zip (col_names , data_row ):
216
+ item [col_name ] = val
217
+ items += [item ]
218
+
219
+ # Sorting results by score in descending order
220
+ sorted_docs = sorted (items , key = lambda x : x ["score" ], reverse = True )[:self .k ]
221
+
222
+ # Returning the prediction
223
+ return Prediction (
224
+ docs = [doc [self .text_column_name ] for doc in sorted_docs ],
225
+ doc_ids = [
226
+ self ._extract_doc_ids (doc )
227
+ for doc in sorted_docs
228
+ ],
229
+ extra_columns = [self ._get_extra_columns (item ) for item in sorted_docs ],
230
+ )
0 commit comments