Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1408,6 +1408,30 @@ public static Double cosineSimilarity(
return dotProduct / (normLeftMap * normRightMap);
}

/**
 * Computes the cosine distance between two sparse vectors, defined as
 * {@code 1 - cosineSimilarity(leftMap, rightMap)}.
 *
 * @param varcharIdentical injected IDENTICAL operator for varchar keys; forwarded to {@code cosineSimilarity}
 * @param varcharHashCode injected HASH_CODE operator for varchar keys; forwarded to {@code cosineSimilarity}
 * @param leftMap first sparse vector, a {@code map(varchar,double)}
 * @param rightMap second sparse vector, a {@code map(varchar,double)}
 * @return the cosine distance, or {@code null} whenever {@code cosineSimilarity} yields {@code null}
 */
@Description("Calculates the cosine distance between the given sparse vectors")
@ScalarFunction
@SqlNullable
@SqlType(StandardTypes.DOUBLE)
public static Double cosineDistance(
        @OperatorDependency(
                operator = IDENTICAL,
                argumentTypes = {"varchar", "varchar"},
                convention = @Convention(arguments = {BLOCK_POSITION, BLOCK_POSITION}, result = NULLABLE_RETURN)) BlockPositionIsIdentical varcharIdentical,
        @OperatorDependency(
                operator = HASH_CODE,
                argumentTypes = "varchar",
                convention = @Convention(arguments = BLOCK_POSITION, result = FAIL_ON_NULL)) BlockPositionHashCode varcharHashCode,
        @SqlType("map(varchar,double)") SqlMap leftMap,
        @SqlType("map(varchar,double)") SqlMap rightMap)
{
    // Delegate to the similarity implementation so both functions share the same null handling.
    Double similarity = cosineSimilarity(varcharIdentical, varcharHashCode, leftMap, rightMap);
    if (similarity == null) {
        // Propagate SQL NULL instead of unboxing a null Double below.
        return null;
    }

    return 1.0 - similarity;
}

private static double mapDotProduct(BlockPositionIsIdentical varcharIdentical, BlockPositionHashCode varcharHashCode, SqlMap leftMap, SqlMap rightMap)
{
int leftRawOffset = leftMap.getRawOffset();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3457,6 +3457,25 @@ public void testCosineSimilarity()
.isNull(DOUBLE);
}

/**
 * Exercises the sparse-vector (map) overload of {@code cosine_distance}:
 * overlapping keys, negative components, disjoint keys, and null propagation.
 */
@Test
public void testCosineDistance()
{
    // Only key 'b' is shared, so the dot product is 2 * 3.
    assertThat(assertions.function("cosine_distance", "map(ARRAY['a', 'b'], ARRAY[1.0E0, 2.0E0])", "map(ARRAY['c', 'b'], ARRAY[1.0E0, 3.0E0])"))
            .isEqualTo(1 - (2 * 3 / (Math.sqrt(5) * Math.sqrt(10))));

    // Keys 'b' and 'c' are shared; 'c' contributes a negative term.
    assertThat(assertions.function("cosine_distance", "map(ARRAY['a', 'b', 'c'], ARRAY[1.0E0, 2.0E0, -1.0E0])", "map(ARRAY['c', 'b'], ARRAY[1.0E0, 3.0E0])"))
            .isEqualTo(1 - ((2 * 3 + -1 * 1) / (Math.sqrt(1 + 4 + 1) * Math.sqrt(1 + 9))));

    // No shared keys: similarity is 0, so the distance is exactly 1.
    assertThat(assertions.function("cosine_distance", "map(ARRAY['a', 'b', 'c'], ARRAY[1.0E0, 2.0E0, -1.0E0])", "map(ARRAY['d', 'e'], ARRAY[1.0E0, 3.0E0])"))
            .isEqualTo(1.0);

    // A null argument yields a null result; pass DOUBLE as the cosine_similarity test does.
    assertThat(assertions.function("cosine_distance", "null", "map(ARRAY['c', 'b'], ARRAY[1.0E0, 3.0E0])"))
            .isNull(DOUBLE);

    // A null map value yields a null result.
    assertThat(assertions.function("cosine_distance", "map(ARRAY['a', 'b'], ARRAY[1.0E0, null])", "map(ARRAY['c', 'b'], ARRAY[1.0E0, 3.0E0])"))
            .isNull(DOUBLE);
}

@Test
public void testInverseNormalCdf()
{
Expand Down
10 changes: 10 additions & 0 deletions docs/src/main/sphinx/functions/math.md
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,16 @@ SELECT cosine_distance(ARRAY[1.0, 2.0], ARRAY[3.0, 4.0]);
```
:::

:::{function} cosine_distance(x, y) -> double
:no-index:
Calculates the cosine distance between two sparse vectors, each given as a map from key to value:

```sql
SELECT cosine_distance(MAP(ARRAY['a'], ARRAY[1.0]), MAP(ARRAY['a'], ARRAY[2.0]));
-- 0.0
```
:::

:::{function} cosine_similarity(array(double), array(double)) -> double
Calculates the cosine similarity of two dense vectors:

Expand Down
Loading