Skip to content

Commit

Permalink
Improved support for Postgres arrays - closes #25
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed Dec 31, 2024
1 parent 884db5e commit 2c7c748
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 2 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
## 0.5.2 (unreleased)

- Improved support for Postgres arrays

## 0.5.1 (2024-12-03)

- Added experimental support for MariaDB 11.7
Expand Down
15 changes: 15 additions & 0 deletions lib/neighbor/postgresql.rb
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ def self.initialize!

# prevent unknown OID warning
ActiveRecord::ConnectionAdapters::PostgreSQLAdapter.singleton_class.prepend(RegisterTypes)

# support vector[]/halfvec[]
ActiveRecord::ConnectionAdapters::PostgreSQL::OID::Array.prepend(ArrayMethods)
end

module RegisterTypes
Expand All @@ -39,5 +42,17 @@ def initialize_type_map(m = type_map)
end
end
end

ArrayWrapper = Struct.new(:to_a)

module ArrayMethods
def type_cast_array(value, method, ...)
if (subtype.is_a?(Neighbor::Type::Vector) || subtype.is_a?(Neighbor::Type::Halfvec)) && method != :deserialize && value.is_a?(::Array) && value.all? { |v| v.is_a?(::Numeric) }
super(ArrayWrapper.new(value), method, ...)
else
super
end
end
end
end
end
12 changes: 12 additions & 0 deletions test/halfvec_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -61,4 +61,16 @@ def test_nan
end
assert_equal "Validation failed: Half embedding must have finite values", error.message
end

def test_array
item = Item.create!(half_embeddings: [[1, 2, 3], [4, 5, 6]])
assert_equal [[1, 2, 3], [4, 5, 6]], item.half_embeddings
assert_equal [[1, 2, 3], [4, 5, 6]], Item.last.half_embeddings
end

def test_array_2d
item = Item.create!(half_embeddings: [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]])
assert_equal [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]], item.half_embeddings
assert_equal [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]], Item.last.half_embeddings
end
end
1 change: 1 addition & 0 deletions test/support/postgresql.rb
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class PostgresRecord < ActiveRecord::Base
t.sparsevec :sparse_embedding, limit: 3
t.sparsevec :sparse_factors, limit: 5
t.vector :embeddings, limit: 3, array: true
t.halfvec :half_embeddings, limit: 3, array: true

This comment has been minimized.

Copy link
@dimroc

dimroc Jan 2, 2025

🎉

end
add_index :items, :cube_embedding, using: :gist
add_index :items, :embedding, using: :hnsw, opclass: :vector_cosine_ops
Expand Down
10 changes: 8 additions & 2 deletions test/vector_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,14 @@ def test_nan
end

def test_array
Item.connection.execute("INSERT INTO items (embeddings) VALUES (ARRAY['[1,2,3]', '[4,5,6]']::vector[])")
item = Item.last
item = Item.create!(embeddings: [[1, 2, 3], [4, 5, 6]])
assert_equal [[1, 2, 3], [4, 5, 6]], item.embeddings
assert_equal [[1, 2, 3], [4, 5, 6]], Item.last.embeddings
end

def test_array_2d
item = Item.create!(embeddings: [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]])
assert_equal [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]], item.embeddings
assert_equal [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]], Item.last.embeddings
end
end

0 comments on commit 2c7c748

Please sign in to comment.