@@ -30,6 +30,7 @@ class PgVectorRM(dspy.Retrieve):
30
30
k (Optional[int]): Default number of top passages to retrieve. Defaults to 20
31
31
embedding_field (str = "embedding"): Field containing passage embeddings. Defaults to "embedding"
32
32
fields (List[str] = ['text']): Fields to retrieve from the table. Defaults to ['text']
33
+ embedding_model (str = "text-embedding-ada-002"): Name of the OpenAI embedding model used to embed the query. Defaults to "text-embedding-ada-002"
33
34
34
35
Examples:
35
36
Below is a code snippet that shows how to use PgVector as the default retriever
@@ -61,9 +62,10 @@ def __init__(
61
62
db_url : str ,
62
63
pg_table_name : str ,
63
64
openai_client : openai .OpenAI ,
64
- k : Optional [int ]= 20 ,
65
+ k : Optional [int ] = 20 ,
65
66
embedding_field : str = "embedding" ,
66
67
fields : List [str ] = ['text' ],
68
+ embedding_model : str = "text-embedding-ada-002" ,
67
69
):
68
70
"""
69
71
k = 20 is the number of paragraphs to retrieve
@@ -75,10 +77,11 @@ def __init__(
75
77
self .pg_table_name = pg_table_name
76
78
self .fields = fields
77
79
self .embedding_field = embedding_field
80
+ self .embedding_model = embedding_model
78
81
79
82
super ().__init__ (k = k )
80
83
81
- def forward (self , query : str , k : Optional [int ]= 20 ):
84
+ def forward (self , query : str , k : Optional [int ] = 20 ):
82
85
"""Search with PgVector for self.k top passages for query
83
86
84
87
Args:
@@ -89,7 +92,7 @@ def forward(self, query: str, k: Optional[int]=20):
89
92
"""
90
93
# Embed query
91
94
query_embedding = self .openai_client .embeddings .create (
92
- model = "text-embedding-ada-002" ,
95
+ model = self . embedding_model ,
93
96
input = query ,
94
97
encoding_format = "float" ,
95
98
).data [0 ].embedding
@@ -112,7 +115,9 @@ def forward(self, query: str, k: Optional[int]=20):
112
115
sql_query ,
113
116
(query_embedding , self .k ))
114
117
rows = cur .fetchall ()
118
+ columns = [descrip [0 ] for descrip in cur .description ]
115
119
for row in rows :
116
- related_paragraphs .append (dspy .Example (long_text = row [0 ], document_id = row [1 ]))
120
+ data = dict (zip (columns , row ))
121
+ related_paragraphs .append (dspy .Example (** data ))
117
122
# Return Prediction
118
123
return related_paragraphs
0 commit comments