-
Notifications
You must be signed in to change notification settings - Fork 63
/
example.py
47 lines (36 loc) · 1.44 KB
/
example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
from datasets import load_dataset
from imagehash import phash
import matplotlib.pyplot as plt
from pgvector.psycopg import register_vector, Bit
import psycopg
def hash_image(img):
    """Return the perceptual hash of ``img`` as a string of '0'/'1' characters.

    The boolean matrix produced by ``imagehash.phash`` is flattened and each
    entry is rendered as '1' (set) or '0' (unset). With imagehash's default
    8x8 hash this gives 64 characters, matching the ``bit(64)`` column of the
    ``images`` table — revisit if the hash size is ever changed.
    """
    bits = phash(img).hash.flatten()
    return ''.join(map(lambda bit: '1' if bit else '0', bits))
# Connect in autocommit mode so every statement below takes effect immediately.
conn = psycopg.connect(dbname='pgvector_example', autocommit=True)
conn.execute('CREATE EXTENSION IF NOT EXISTS vector')
# Register pgvector's psycopg adapters on this connection once the extension
# exists (presumably what lets Bit values be sent later in the script).
register_vector(conn)
# Recreate the table from scratch on every run.
for statement in (
    'DROP TABLE IF EXISTS images',
    'CREATE TABLE images (id bigserial PRIMARY KEY, hash bit(64))',
):
    conn.execute(statement)
# Load MNIST with the `datasets` library and compute a perceptual hash for
# every training image. The list keeps the training split's order, which the
# query step relies on when mapping table ids back to training rows.
print('Loading dataset')
dataset = load_dataset('mnist')

print('Generating hashes')
train_split = dataset['train']
images = [dict(hash=hash_image(sample['image'])) for sample in train_split]
print('Storing hashes')
# Bulk-load the hashes with COPY. The cursor is opened in the same `with` as
# the COPY so that it is closed as soon as the load finishes, instead of
# staying open for the rest of the script.
# Rows are written in training-set order. Because the bigserial sequence was
# just recreated, ids should come out as 1..N in that same order, and the query
# step depends on this (id - 1 == training index).
with conn.cursor() as cur, cur.copy('COPY images (hash) FROM STDIN') as copy:
    for image in images:
        # Wrap in Bit so pgvector's adapter serializes the '0'/'1' string for
        # the bit(64) column.
        copy.write_row([Bit(image['hash'])])
print('Querying hashes')
# For each of the first five test images, fetch the five training images whose
# stored hashes are closest under pgvector's `<~>` (Hamming distance) operator.
results = []
train = dataset['train']
for i in range(5):
    image = dataset['test'][i]['image']
    # Send the hash as Bit, the same type used when storing, rather than as a
    # plain string the server has to coerce. Hamming distances are small
    # integers and tie often, so `id` is added as a secondary sort key to make
    # the returned neighbours deterministic between runs.
    result = conn.execute(
        'SELECT id FROM images ORDER BY hash <~> %s, id LIMIT 5',
        (Bit(hash_image(image)),),
    ).fetchall()
    # Assumes ids start at 1 and follow training-set order (the bigserial
    # sequence is fresh and COPY wrote the rows in order), so id - 1 indexes
    # the training split.
    nearest_images = [train[row_id - 1]['image'] for (row_id,) in result]
    results.append([image] + nearest_images)
print('Showing results (first column is query image)')
# One row per query: the query image followed by its nearest neighbours.
# squeeze=False keeps `axs` a 2-D array even with a single row or column, so
# axs[i, j] indexing always works. The column count is taken from the longest
# row, in case a query returned fewer neighbours than the others.
n_rows = len(results)
n_cols = max(len(row) for row in results)
fig, axs = plt.subplots(n_rows, n_cols, squeeze=False)
# Hide every axis up front, including any cells left empty by shorter rows.
for ax in axs.flat:
    ax.set_axis_off()
for i, result in enumerate(results):
    for j, image in enumerate(result):
        axs[i, j].imshow(image)
plt.show(block=True)