Skip to content

Commit ac9e398

Browse files
committed
Improved examples [skip ci]
1 parent 70ff5d4 commit ac9e398

File tree

4 files changed

+28
-23
lines changed

4 files changed

+28
-23
lines changed

examples/cohere/example.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
conn.execute('CREATE TABLE documents (id bigserial PRIMARY KEY, content text, embedding bit(1024))')
1313

1414

15-
def fetch_embeddings(input, input_type):
15+
def embed(input, input_type):
1616
co = cohere.Client()
1717
response = co.embed(texts=input, model='embed-english-v3.0', input_type=input_type, embedding_types=['ubinary'])
1818
return [np.unpackbits(np.array(embedding, dtype=np.uint8)) for embedding in response.embeddings.ubinary]
@@ -23,12 +23,12 @@ def fetch_embeddings(input, input_type):
2323
'The cat is purring',
2424
'The bear is growling'
2525
]
26-
embeddings = fetch_embeddings(input, 'search_document')
26+
embeddings = embed(input, 'search_document')
2727
for content, embedding in zip(input, embeddings):
2828
conn.execute('INSERT INTO documents (content, embedding) VALUES (%s, %s)', (content, Bit(embedding)))
2929

3030
query = 'forest'
31-
query_embedding = fetch_embeddings([query], 'search_query')[0]
31+
query_embedding = embed([query], 'search_query')[0]
3232
result = conn.execute('SELECT content FROM documents ORDER BY embedding <~> %s LIMIT 5', (Bit(query_embedding),)).fetchall()
3333
for row in result:
3434
print(row[0])

examples/openai/example.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import numpy as np
12
from openai import OpenAI
23
from pgvector.psycopg import register_vector
34
import psycopg
@@ -10,20 +11,24 @@
1011
conn.execute('DROP TABLE IF EXISTS documents')
1112
conn.execute('CREATE TABLE documents (id bigserial PRIMARY KEY, content text, embedding vector(1536))')
1213

14+
15+
def embed(input):
16+
client = OpenAI()
17+
response = client.embeddings.create(input=input, model='text-embedding-3-small')
18+
return [v.embedding for v in response.data]
19+
20+
1321
input = [
1422
'The dog is barking',
1523
'The cat is purring',
1624
'The bear is growling'
1725
]
18-
19-
client = OpenAI()
20-
response = client.embeddings.create(input=input, model='text-embedding-3-small')
21-
embeddings = [v.embedding for v in response.data]
22-
26+
embeddings = embed(input)
2327
for content, embedding in zip(input, embeddings):
24-
conn.execute('INSERT INTO documents (content, embedding) VALUES (%s, %s)', (content, embedding))
28+
conn.execute('INSERT INTO documents (content, embedding) VALUES (%s, %s)', (content, np.array(embedding)))
2529

26-
document_id = 1
27-
neighbors = conn.execute('SELECT content FROM documents WHERE id != %(id)s ORDER BY embedding <=> (SELECT embedding FROM documents WHERE id = %(id)s) LIMIT 5', {'id': document_id}).fetchall()
28-
for neighbor in neighbors:
29-
print(neighbor[0])
30+
query = 'forest'
31+
query_embedding = embed([query])[0]
32+
result = conn.execute('SELECT content FROM documents ORDER BY embedding <=> %s LIMIT 5', (np.array(query_embedding),)).fetchall()
33+
for row in result:
34+
print(row[0])

examples/sentence_transformers/example.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,19 @@
1010
conn.execute('DROP TABLE IF EXISTS documents')
1111
conn.execute('CREATE TABLE documents (id bigserial PRIMARY KEY, content text, embedding vector(384))')
1212

13+
model = SentenceTransformer('all-MiniLM-L6-v2')
14+
1315
input = [
1416
'The dog is barking',
1517
'The cat is purring',
1618
'The bear is growling'
1719
]
18-
19-
model = SentenceTransformer('all-MiniLM-L6-v2')
2020
embeddings = model.encode(input)
21-
2221
for content, embedding in zip(input, embeddings):
2322
conn.execute('INSERT INTO documents (content, embedding) VALUES (%s, %s)', (content, embedding))
2423

25-
document_id = 1
26-
neighbors = conn.execute('SELECT content FROM documents WHERE id != %(id)s ORDER BY embedding <=> (SELECT embedding FROM documents WHERE id = %(id)s) LIMIT 5', {'id': document_id}).fetchall()
27-
for neighbor in neighbors:
28-
print(neighbor[0])
24+
query = 'forest'
25+
query_embedding = model.encode(query)
26+
result = conn.execute('SELECT content FROM documents ORDER BY embedding <=> %s LIMIT 5', (query_embedding,)).fetchall()
27+
for row in result:
28+
print(row[0])

examples/sparse_search/example.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
special_token_ids = [tokenizer.vocab[token] for token in tokenizer.special_tokens_map.values()]
2121

2222

23-
def fetch_embeddings(input):
23+
def embed(input):
2424
feature = tokenizer(
2525
input,
2626
padding=True,
@@ -42,12 +42,12 @@ def fetch_embeddings(input):
4242
'The cat is purring',
4343
'The bear is growling'
4444
]
45-
embeddings = fetch_embeddings(input)
45+
embeddings = embed(input)
4646
for content, embedding in zip(input, embeddings):
4747
conn.execute('INSERT INTO documents (content, embedding) VALUES (%s, %s)', (content, SparseVector(embedding)))
4848

4949
query = 'forest'
50-
query_embedding = fetch_embeddings([query])[0]
50+
query_embedding = embed([query])[0]
5151
result = conn.execute('SELECT content FROM documents ORDER BY embedding <#> %s LIMIT 5', (SparseVector(query_embedding),)).fetchall()
5252
for row in result:
5353
print(row[0])

0 commit comments

Comments (0)