Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add cosine_distance for sparse vectors #24027

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1404,6 +1404,24 @@ public static Double cosineSimilarity(
return dotProduct / (normLeftMap * normRightMap);
}

@Description("Calculates the cosine distance between the give sparse vectors")
@ScalarFunction
@SqlType(StandardTypes.DOUBLE)
public static double cosineDistance(
@OperatorDependency(
operator = IDENTICAL,
argumentTypes = {"varchar", "varchar"},
convention = @Convention(arguments = {BLOCK_POSITION, BLOCK_POSITION}, result = NULLABLE_RETURN)) BlockPositionIsIdentical varcharIdentical,
@OperatorDependency(
operator = HASH_CODE,
argumentTypes = "varchar",
convention = @Convention(arguments = BLOCK_POSITION, result = FAIL_ON_NULL)) BlockPositionHashCode varcharHashCode,
@SqlType("map(varchar,double)") SqlMap leftMap,
@SqlType("map(varchar,double)") SqlMap rightMap)
{
return 1.0 - cosineSimilarity(varcharIdentical, varcharHashCode, leftMap, rightMap);
}

private static double mapDotProduct(BlockPositionIsIdentical varcharIdentical, BlockPositionHashCode varcharHashCode, SqlMap leftMap, SqlMap rightMap)
{
int leftRawOffset = leftMap.getRawOffset();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3457,6 +3457,25 @@
.isNull(DOUBLE);
}

@Test
public void testCosineDistance()
mosabua marked this conversation as resolved.
Show resolved Hide resolved
{
assertThat(assertions.function("cosine_distance", "map(ARRAY['a', 'b'], ARRAY[1.0E0, 2.0E0])", "map(ARRAY['c', 'b'], ARRAY[1.0E0, 3.0E0])"))
.isEqualTo(1 - (2 * 3 / (Math.sqrt(5) * Math.sqrt(10))));

assertThat(assertions.function("cosine_distance", "map(ARRAY['a', 'b', 'c'], ARRAY[1.0E0, 2.0E0, -1.0E0])", "map(ARRAY['c', 'b'], ARRAY[1.0E0, 3.0E0])"))
.isEqualTo(1 - ((2 * 3 + -1 * 1) / (Math.sqrt(1 + 4 + 1) * Math.sqrt(1 + 9))));

assertThat(assertions.function("cosine_distance", "map(ARRAY['a', 'b', 'c'], ARRAY[1.0E0, 2.0E0, -1.0E0])", "map(ARRAY['d', 'e'], ARRAY[1.0E0, 3.0E0])"))
.isEqualTo(1.0);

assertThat(assertions.function("cosine_distance", "null", "map(ARRAY['c', 'b'], ARRAY[1.0E0, 3.0E0])"))
.isNull();

assertThat(assertions.function("cosine_distance", "map(ARRAY['a', 'b'], ARRAY[1.0E0, null])", "map(ARRAY['c', 'b'], ARRAY[1.0E0, 3.0E0])"))

Check failure on line 3475 in core/trino-main/src/test/java/io/trino/operator/scalar/TestMathFunctions.java

View workflow job for this annotation

GitHub Actions / test (core/trino-main)

TestMathFunctions.testCosineDistance

Cannot invoke "java.lang.Double.doubleValue()" because the return value of "io.trino.operator.scalar.MathFunctions.cosineSimilarity(io.trino.type.BlockTypeOperators$BlockPositionIsIdentical, io.trino.type.BlockTypeOperators$BlockPositionHashCode, io.trino.spi.block.SqlMap, io.trino.spi.block.SqlMap)" is null
.isNull();
}

@Test
public void testInverseNormalCdf()
{
Expand Down
10 changes: 10 additions & 0 deletions docs/src/main/sphinx/functions/math.md
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,16 @@ SELECT cosine_distance(ARRAY[1.0, 2.0], ARRAY[3.0, 4.0]);
```
:::

:::{function} cosine_distance(x, y) -> double
:no-index:
Calculates the cosine distance between two sparse vectors:

```sql
SELECT cosine_distance(MAP(ARRAY['a'], ARRAY[1.0]), MAP(ARRAY['a'], ARRAY[2.0]));
mosabua marked this conversation as resolved.
Show resolved Hide resolved
-- 0.0
```
:::

:::{function} cosine_similarity(array(double), array(double)) -> double
Calculates the cosine similarity of two dense vectors:

Expand Down
Loading