Skip to content

Commit

Permalink
Added Hamming distance for MariaDB
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed Oct 6, 2024
1 parent ad6918f commit 5f88029
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 9 deletions.
21 changes: 21 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,7 @@ Supported values are:

- `euclidean`
- `cosine`
- `hamming`

For cosine distance with MariaDB, vectors must be normalized before being stored.

Expand All @@ -409,6 +410,26 @@ class CreateItems < ActiveRecord::Migration[7.2]
end
```

### Binary Vectors

Use the `bigint` type to store binary vectors

```ruby
class AddEmbeddingToItems < ActiveRecord::Migration[7.2]
def change
add_column :items, :embedding, :bigint
end
end
```

Note: Binary vectors can have up to 64 dimensions

Get the nearest neighbors by Hamming distance

```ruby
Item.nearest_neighbors(:embedding, 5, distance: "hamming").first(5)
```

## MySQL

### Distance
Expand Down
9 changes: 7 additions & 2 deletions lib/neighbor/attribute.rb
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@ module Neighbor
class Attribute < ActiveRecord::Type::Value
delegate :type, :serialize, :deserialize, :cast, to: :new_cast_type

def initialize(cast_type:, model:, type:)
def initialize(cast_type:, model:, type:, attribute_name:)
@cast_type = cast_type
@model = model
@type = type
@attribute_name = attribute_name
end

private
Expand All @@ -30,7 +31,11 @@ def new_cast_type
raise ArgumentError, "Unsupported type"
end
when :mariadb
Type::MysqlVector.new
if @model.columns_hash[@attribute_name.to_s]&.type == :integer
@cast_type
else
Type::MysqlVector.new
end
else
@cast_type
end
Expand Down
12 changes: 8 additions & 4 deletions lib/neighbor/model.rb
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,13 @@ def self.neighbor_attributes
end

if ActiveRecord::VERSION::STRING.to_f >= 7.2
decorate_attributes(attribute_names) do |_name, cast_type|
Neighbor::Attribute.new(cast_type: cast_type, model: self, type: type)
decorate_attributes(attribute_names) do |name, cast_type|
Neighbor::Attribute.new(cast_type: cast_type, model: self, type: type, attribute_name: name)
end
else
attribute_names.each do |attribute_name|
attribute attribute_name do |cast_type|
Neighbor::Attribute.new(cast_type: cast_type, model: self, type: type)
Neighbor::Attribute.new(cast_type: cast_type, model: self, type: type, attribute_name: attribute_name)
end
end
end
Expand Down Expand Up @@ -142,7 +142,11 @@ def self.neighbor_attributes
"#{operator}(#{quoted_attribute}, #{query})"
end
when :mariadb
"VEC_DISTANCE(#{quoted_attribute}, #{query})"
if operator == "BIT_COUNT"
"BIT_COUNT(#{quoted_attribute} ^ #{query})"
else
"VEC_DISTANCE(#{quoted_attribute}, #{query})"
end
when :mysql
if operator == "BIT_COUNT"
"BIT_COUNT(#{quoted_attribute} ^ #{query})"
Expand Down
10 changes: 7 additions & 3 deletions lib/neighbor/utils.rb
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@ module Utils
def self.validate_dimensions(value, type, expected, adapter)
dimensions = type == :sparsevec ? value.dimensions : value.size
dimensions *= 8 if type == :bit && [:sqlite, :mysql].include?(adapter)

if expected && dimensions != expected
"Expected #{expected} dimensions, not #{dimensions}"
end
end

def self.validate_finite(value, type)
case type
when :bit
when :bit, :integer
true
when :sparsevec
value.values.all?(&:finite?)
Expand Down Expand Up @@ -63,8 +64,6 @@ def self.type(adapter, column_type)
else
column_type
end
when :mariadb
:vector
else
column_type
end
Expand All @@ -90,6 +89,11 @@ def self.operator(adapter, column_type, distance)
when "euclidean", "cosine"
"VEC_DISTANCE"
end
when :integer
case distance
when "hamming"
"BIT_COUNT"
end
else
raise ArgumentError, "Unsupported type: #{column_type}"
end
Expand Down
28 changes: 28 additions & 0 deletions test/mariadb_bit_test.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
require_relative "test_helper"
require_relative "support/mariadb"

class MariadbBitTest < Minitest::Test
def setup
MariadbBinaryItem.delete_all
end

def test_hamming
create_bit_items
result = MariadbBinaryItem.find(1).nearest_neighbors(:binary_embedding, distance: "hamming").first(3)
assert_equal [2, 3], result.map(&:id)
assert_elements_in_delta [2, 3], result.map(&:neighbor_distance)
end

def test_hamming_scope
create_bit_items
result = MariadbBinaryItem.nearest_neighbors(:binary_embedding, 5, distance: "hamming").first(5)
assert_equal [2, 3, 1], result.map(&:id)
assert_elements_in_delta [0, 1, 2], result.map(&:neighbor_distance)
end

def create_bit_items
MariadbBinaryItem.create!(id: 1, binary_embedding: 0)
MariadbBinaryItem.create!(id: 2, binary_embedding: 5)
MariadbBinaryItem.create!(id: 3, binary_embedding: 7)
end
end
8 changes: 8 additions & 0 deletions test/support/mariadb.rb
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ class MariadbRecord < ActiveRecord::Base
t.binary :embedding, null: false
t.index :embedding, type: :vector
end

create_table :mariadb_binary_items, force: true do |t|
t.bigint :binary_embedding
end
end

class MariadbItem < MariadbRecord
Expand All @@ -25,5 +29,9 @@ class MariadbDimensionsItem < MariadbRecord
self.table_name = "mariadb_items"
end

class MariadbBinaryItem < MariadbRecord
has_neighbors :binary_embedding
end

# ensure has_neighbors does not cause model schema to load
raise "has_neighbors loading model schema early" if MariadbItem.send(:schema_loaded?)

0 comments on commit 5f88029

Please sign in to comment.