Skip to content

Commit 2e8002d

Browse files
committed
Added Citus example [skip ci]
1 parent 3772ac2 commit 2e8002d

File tree

3 files changed

+129
-0
lines changed

3 files changed

+129
-0
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ Or check out some examples:
2626
- [Sentence embeddings](https://github.com/pgvector/pgvector-elixir/blob/master/examples/bumblebee/example.exs) with Bumblebee
2727
- [Hybrid search](https://github.com/pgvector/pgvector-elixir/blob/master/examples/hybrid_search/example.exs) with Bumblebee (Reciprocal Rank Fusion)
2828
- [Sparse search](https://github.com/pgvector/pgvector-elixir/blob/master/examples/sparse_search/example.exs) with Bumblebee
29+
- [Horizontal scaling](https://github.com/pgvector/pgvector-elixir/blob/master/examples/citus/example.exs) with Citus
2930
- [Bulk loading](https://github.com/pgvector/pgvector-elixir/blob/master/examples/loading/example.exs) with `COPY`
3031

3132
## Ecto

examples/citus/example.exs

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
Postgrex.Types.define(Example.PostgrexTypes, Pgvector.extensions(), [])
2+
3+
rows = 100_000
4+
dimensions = 128
5+
6+
IO.puts("Generating random data")
7+
8+
key = Nx.Random.key(42)
9+
{embeddings, new_key} = Nx.Random.uniform(key, shape: {rows, dimensions})
10+
{categories, new_key} = Nx.Random.randint(new_key, 1, 100, shape: {rows})
11+
{queries, _new_key} = Nx.Random.uniform(new_key, shape: {10, dimensions})
12+
13+
# enable extensions
14+
{:ok, pid} = Postgrex.start_link(database: "pgvector_citus", types: Example.PostgrexTypes)
15+
Postgrex.query!(pid, "CREATE EXTENSION IF NOT EXISTS citus", [])
16+
Postgrex.query!(pid, "CREATE EXTENSION IF NOT EXISTS vector", [])
17+
18+
# GUC variables set on the session do not propagate to Citus workers
19+
# https://github.com/citusdata/citus/issues/462
20+
# you can either:
21+
# 1. set them on the system, user, or database and reconnect
22+
# 2. set them for a transaction with SET LOCAL
23+
Postgrex.query!(pid, "ALTER DATABASE pgvector_citus SET maintenance_work_mem = '512MB'", [])
24+
Postgrex.query!(pid, "ALTER DATABASE pgvector_citus SET hnsw.ef_search = 20", [])
25+
# TODO close connection
26+
27+
# reconnect for updated GUC variables to take effect
28+
{:ok, pid} = Postgrex.start_link(database: "pgvector_citus", types: Example.PostgrexTypes)
29+
30+
IO.puts("Creating distributed table")
31+
32+
Postgrex.query!(pid, "DROP TABLE IF EXISTS items", [])
33+
34+
Postgrex.query!(
35+
pid,
36+
"CREATE TABLE items (id bigserial, embedding vector(#{dimensions}), category_id bigint, PRIMARY KEY (id, category_id))",
37+
[]
38+
)
39+
40+
Postgrex.query!(pid, "SET citus.shard_count = 4", [])
41+
Postgrex.query!(pid, "SELECT create_distributed_table('items', 'category_id')", [])
42+
43+
defmodule Example do
44+
# https://www.postgresql.org/docs/current/sql-copy.html
45+
def copy(stream, rows) do
46+
signature = "PGCOPY\n\xFF\r\n\0"
47+
48+
Enum.into(
49+
[
50+
<<signature::binary, 0::unsigned-32, 0::unsigned-32>>,
51+
Enum.map(rows, &copy_row(&1)),
52+
<<-1::unsigned-16>>
53+
],
54+
stream
55+
)
56+
end
57+
58+
defp copy_row(row) do
59+
count = row |> length()
60+
data = Enum.join(Enum.map(row, &copy_value(&1)))
61+
<<count::unsigned-16, data::binary>>
62+
end
63+
64+
defp copy_value(value) when is_struct(value, Pgvector) do
65+
data = value |> Pgvector.to_binary()
66+
<<IO.iodata_length(data)::unsigned-32, data::binary>>
67+
end
68+
69+
defp copy_value(value) when is_integer(value) do
70+
<<8::unsigned-32, value::64>>
71+
end
72+
end
73+
74+
IO.puts("Loading data in parallel")
75+
76+
Postgrex.transaction(
77+
pid,
78+
fn conn ->
79+
stream =
80+
Postgrex.stream(
81+
conn,
82+
"COPY items (embedding, category_id) FROM STDIN WITH (FORMAT BINARY)",
83+
[]
84+
)
85+
86+
rows =
87+
Enum.map(Enum.zip(embeddings |> Nx.to_list(), categories |> Nx.to_list()), fn {v, c} ->
88+
[v |> Pgvector.new(), c]
89+
end)
90+
91+
stream |> Example.copy(rows)
92+
end,
93+
timeout: 30000
94+
)
95+
96+
IO.puts("Creating index in parallel")
97+
98+
Postgrex.query!(pid, "CREATE INDEX ON items USING hnsw (embedding vector_l2_ops)", [])
99+
100+
IO.puts("Running distributed queries")
101+
102+
for query <- Nx.to_list(queries) do
103+
result =
104+
Postgrex.query!(pid, "SELECT id FROM items ORDER BY embedding <-> $1 LIMIT 10", [
105+
query |> Pgvector.new()
106+
])
107+
108+
IO.inspect(Enum.map(result.rows, fn v -> List.first(v) end))
109+
end

examples/citus/mix.exs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
defmodule Example.MixProject do
2+
use Mix.Project
3+
4+
def project do
5+
[
6+
app: :example,
7+
version: "0.1.0",
8+
deps: deps()
9+
]
10+
end
11+
12+
defp deps do
13+
[
14+
{:pgvector, path: "../.."},
15+
{:postgrex, "~> 0.17"},
16+
{:nx, "~> 0.5"}
17+
]
18+
end
19+
end

0 commit comments

Comments
 (0)