Skip to content

Commit 3a9262c

Browse files
committed
feature(dspy): Support more embedding models and return all fields
1 parent da1f8ae commit 3a9262c

File tree

1 file changed

+9
-4
lines changed

1 file changed

+9
-4
lines changed

dspy/retrieve/pgvector_rm.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ class PgVectorRM(dspy.Retrieve):
3030
k (Optional[int]): Default number of top passages to retrieve. Defaults to 20
3131
embedding_field (str = "embedding"): Field containing passage embeddings. Defaults to "embedding"
3232
fields (List[str] = ['text']): Fields to retrieve from the table. Defaults to "text"
33+
embedding_model (str = "text-embedding-ada-002"): Field containing the OpenAI embedding model to use. Defaults to "text-embedding-ada-002"
3334
3435
Examples:
3536
Below is a code snippet that shows how to use PgVector as the default retriever
@@ -61,9 +62,10 @@ def __init__(
6162
db_url: str,
6263
pg_table_name: str,
6364
openai_client: openai.OpenAI,
64-
k: Optional[int]=20,
65+
k: Optional[int] = 20,
6566
embedding_field: str = "embedding",
6667
fields: List[str] = ['text'],
68+
embedding_model: str = "text-embedding-ada-002",
6769
):
6870
"""
6971
k = 20 is the number of paragraphs to retrieve
@@ -75,10 +77,11 @@ def __init__(
7577
self.pg_table_name = pg_table_name
7678
self.fields = fields
7779
self.embedding_field = embedding_field
80+
self.embedding_model = embedding_model
7881

7982
super().__init__(k=k)
8083

81-
def forward(self, query: str, k: Optional[int]=20):
84+
def forward(self, query: str, k: Optional[int] = 20):
8285
"""Search with PgVector for self.k top passages for query
8386
8487
Args:
@@ -89,7 +92,7 @@ def forward(self, query: str, k: Optional[int]=20):
8992
"""
9093
# Embed query
9194
query_embedding = self.openai_client.embeddings.create(
92-
model="text-embedding-ada-002",
95+
model=self.embedding_model,
9396
input=query,
9497
encoding_format="float",
9598
).data[0].embedding
@@ -112,7 +115,9 @@ def forward(self, query: str, k: Optional[int]=20):
112115
sql_query,
113116
(query_embedding, self.k))
114117
rows = cur.fetchall()
118+
columns = [descrip[0] for descrip in cur.description]
115119
for row in rows:
116-
related_paragraphs.append(dspy.Example(long_text=row[0], document_id=row[1]))
120+
data = dict(zip(columns, row))
121+
related_paragraphs.append(dspy.Example(**data))
117122
# Return Prediction
118123
return related_paragraphs

0 commit comments

Comments
 (0)