Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,22 @@ public enum TransformFunctionType {
OperandTypes.family(ImmutableList.of(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC, SqlTypeFamily.NUMERIC),
ordinal -> ordinal > 1 && ordinal < 4)),

// Vector functions
// TODO: Once a VECTOR type is defined, update the operand types of these functions (currently ARRAY).
COSINE_DISTANCE("cosineDistance", ReturnTypes.explicit(SqlTypeName.DOUBLE),
OperandTypes.family(ImmutableList.of(SqlTypeFamily.ARRAY, SqlTypeFamily.ARRAY, SqlTypeFamily.NUMERIC),
ordinal -> ordinal > 1 && ordinal < 4), "cosine_distance"),
INNER_PRODUCT("innerProduct", ReturnTypes.explicit(SqlTypeName.DOUBLE),
OperandTypes.family(ImmutableList.of(SqlTypeFamily.ARRAY, SqlTypeFamily.ARRAY)), "inner_product"),
L1_DISTANCE("l1Distance", ReturnTypes.explicit(SqlTypeName.DOUBLE),
OperandTypes.family(ImmutableList.of(SqlTypeFamily.ARRAY, SqlTypeFamily.ARRAY)), "l1_distance"),
L2_DISTANCE("l2Distance", ReturnTypes.explicit(SqlTypeName.DOUBLE),
OperandTypes.family(ImmutableList.of(SqlTypeFamily.ARRAY, SqlTypeFamily.ARRAY)), "l2_distance"),
VECTOR_DIMS("vectorDims", ReturnTypes.explicit(SqlTypeName.INTEGER),
OperandTypes.family(ImmutableList.of(SqlTypeFamily.ARRAY)), "vector_dims"),
VECTOR_NORM("vectorNorm", ReturnTypes.explicit(SqlTypeName.DOUBLE),
OperandTypes.family(ImmutableList.of(SqlTypeFamily.ARRAY)), "vector_norm"),

// Trigonometry
SIN("sin"),
COS("cos"),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.pinot.common.function.scalar;

import com.google.common.base.Preconditions;
import org.apache.pinot.spi.annotations.ScalarFunction;


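/**
 * Inbuilt vector transformation functions.
 * The functions can be used as UDFs in a query when added to the FunctionRegistry.
 * The {@code @ScalarFunction} annotation is used on each method for the registration.
 *
 * Example usage:
 * <code> SELECT cosineDistance(vector1, vector2) FROM myTable </code>
 * <code> SELECT l2Distance(vector1, vector2) FROM myTable </code>
 * <code> SELECT vectorNorm(vector1) FROM myTable </code>
 */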
public class VectorFunctions {
// Private no-arg constructor: the class cannot be instantiated from outside.
private VectorFunctions() {
}

/**
 * Computes the cosine distance between two vectors.
 *
 * <p>Delegates to {@link #cosineDistance(float[], float[], double)} using {@link Double#NaN} as the default
 * value, so {@code NaN} is returned when either vector has a norm of 0.
 *
 * @param vector1 first vector
 * @param vector2 second vector
 * @return cosine distance, or {@code NaN} when either vector has a norm of 0
 */
@ScalarFunction(names = {"cosinedistance", "cosine_distance"})
public static double cosineDistance(float[] vector1, float[] vector2) {
  final double undefinedDistance = Double.NaN;
  return cosineDistance(vector1, vector2, undefinedDistance);
}

/**
* Returns the cosine distance between two vectors, with a default value if the norm of either vector is 0.
* @param vector1 vector1
* @param vector2 vector2
* @param defaultValue default value when either vector has a norm of 0
* @return cosine distance
*/
@ScalarFunction(names = {"cosinedistance", "cosine_distance"})
public static double cosineDistance(float[] vector1, float[] vector2, double defaultValue) {
validateVectors(vector1, vector2);
double dotProduct = 0.0;
double norm1 = 0.0;
double norm2 = 0.0;
for (int i = 0; i < vector1.length; i++) {
dotProduct += vector1[i] * vector2[i];
norm1 += Math.pow(vector1[i], 2);
norm2 += Math.pow(vector2[i], 2);
}
if (norm1 == 0 || norm2 == 0) {
return defaultValue;
}
return 1 - (dotProduct / (Math.sqrt(norm1) * Math.sqrt(norm2)));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might need to check for divide by zero

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make cosineDistance can take third optional argument default value, which is used when either vector has a norm of 0

}

/**
 * Returns the inner (dot) product of two vectors, accumulated in double precision.
 *
 * @param vector1 first vector
 * @param vector2 second vector; must have the same length as {@code vector1}
 * @return inner product; {@code 0.0} for two empty vectors
 * @throws IllegalArgumentException if either vector is null or the lengths do not match
 */
@ScalarFunction(names = {"innerproduct", "inner_product"})
public static double innerProduct(float[] vector1, float[] vector2) {
  validateVectors(vector1, vector2);
  double dotProduct = 0.0;
  for (int i = 0; i < vector1.length; i++) {
    // Cast one operand so the multiplication itself happens in double; a float * float product is rounded
    // (or overflows to Infinity) in float precision before being widened.
    dotProduct += (double) vector1[i] * vector2[i];
  }
  return dotProduct;
}

/**
 * Returns the L2 (Euclidean) distance between two vectors: the square root of the sum of squared
 * component-wise differences.
 *
 * @param vector1 first vector
 * @param vector2 second vector; must have the same length as {@code vector1}
 * @return L2 distance; {@code 0.0} for two empty vectors
 * @throws IllegalArgumentException if either vector is null or the lengths do not match
 */
@ScalarFunction(names = {"l2distance", "l2_distance"})
public static double l2Distance(float[] vector1, float[] vector2) {
  validateVectors(vector1, vector2);
  double distance = 0.0;
  for (int i = 0; i < vector1.length; i++) {
    // Subtract in double: a float - float difference is rounded to float precision (and can overflow to
    // Infinity for operands of opposite sign near Float.MAX_VALUE) before being widened.
    double difference = (double) vector1[i] - vector2[i];
    distance += difference * difference;
  }
  return Math.sqrt(distance);
}

/**
 * Returns the L1 (Manhattan) distance between two vectors: the sum of absolute component-wise differences.
 *
 * @param vector1 first vector
 * @param vector2 second vector; must have the same length as {@code vector1}
 * @return L1 distance; {@code 0.0} for two empty vectors
 * @throws IllegalArgumentException if either vector is null or the lengths do not match
 */
@ScalarFunction(names = {"l1distance", "l1_distance"})
public static double l1Distance(float[] vector1, float[] vector2) {
  validateVectors(vector1, vector2);
  double distance = 0.0;
  for (int i = 0; i < vector1.length; i++) {
    // Subtract in double: a float - float difference is rounded to float precision before being widened,
    // e.g. 1e10f - 1f evaluates to 1e10f in float arithmetic.
    distance += Math.abs((double) vector1[i] - vector2[i]);
  }
  return distance;
}

/**
 * Reports how many dimensions (elements) a vector has.
 *
 * @param vector input vector; validated by {@link #validateVector(float[])}, so it must be non-null and
 *               non-empty
 * @return length of the vector
 */
@ScalarFunction(names = {"vectordims", "vector_dims"})
public static int vectorDims(float[] vector) {
  validateVector(vector);
  final int dimensions = vector.length;
  return dimensions;
}

/**
 * Computes the norm of a vector as the square root of the sum of its squared components.
 *
 * @param vector input vector; validated by {@link #validateVector(float[])}, so it must be non-null and
 *               non-empty
 * @return norm of the vector
 */
@ScalarFunction(names = {"vectornorm", "vector_norm"})
public static double vectorNorm(float[] vector) {
  validateVector(vector);
  double sumOfSquares = 0.0;
  for (float component : vector) {
    sumOfSquares += Math.pow(component, 2);
  }
  return Math.sqrt(sumOfSquares);
}

/**
 * Checks that a pair of vectors can be combined element by element: both must be non-null and of equal
 * length. Empty vectors are accepted as long as both are empty.
 *
 * @param vector1 first vector
 * @param vector2 second vector
 */
public static void validateVectors(float[] vector1, float[] vector2) {
  boolean bothPresent = vector1 != null && vector2 != null;
  Preconditions.checkArgument(bothPresent, "Null vector passed");
  // Lengths are only read after the null check above has passed.
  boolean sameLength = vector1.length == vector2.length;
  Preconditions.checkArgument(sameLength, "Vector lengths do not match");
}

/**
 * Checks that a single vector is usable: it must be non-null and contain at least one element.
 *
 * @param vector vector to check
 */
public static void validateVector(float[] vector) {
  boolean present = vector != null;
  Preconditions.checkArgument(present, "Null vector passed");
  // An array length is never negative, so "not zero" is the same condition as "greater than zero".
  boolean nonEmpty = vector.length != 0;
  Preconditions.checkArgument(nonEmpty, "Empty vector passed");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,12 @@
import org.apache.pinot.core.operator.transform.function.TrigonometricTransformFunctions.SinhTransformFunction;
import org.apache.pinot.core.operator.transform.function.TrigonometricTransformFunctions.TanTransformFunction;
import org.apache.pinot.core.operator.transform.function.TrigonometricTransformFunctions.TanhTransformFunction;
import org.apache.pinot.core.operator.transform.function.VectorTransformFunctions.CosineDistanceTransformFunction;
import org.apache.pinot.core.operator.transform.function.VectorTransformFunctions.InnerProductTransformFunction;
import org.apache.pinot.core.operator.transform.function.VectorTransformFunctions.L1DistanceTransformFunction;
import org.apache.pinot.core.operator.transform.function.VectorTransformFunctions.L2DistanceTransformFunction;
import org.apache.pinot.core.operator.transform.function.VectorTransformFunctions.VectorDimsTransformFunction;
import org.apache.pinot.core.operator.transform.function.VectorTransformFunctions.VectorNormTransformFunction;
import org.apache.pinot.core.query.request.context.QueryContext;
import org.apache.pinot.core.query.request.context.utils.QueryContextConverterUtils;
import org.apache.pinot.segment.spi.datasource.DataSource;
Expand Down Expand Up @@ -217,6 +223,14 @@ private static Map<String, Class<? extends TransformFunction>> createRegistry()
typeToImplementation.put(TransformFunctionType.DEGREES, DegreesTransformFunction.class);
typeToImplementation.put(TransformFunctionType.RADIANS, RadiansTransformFunction.class);

// Vector functions
typeToImplementation.put(TransformFunctionType.COSINE_DISTANCE, CosineDistanceTransformFunction.class);
typeToImplementation.put(TransformFunctionType.INNER_PRODUCT, InnerProductTransformFunction.class);
typeToImplementation.put(TransformFunctionType.L1_DISTANCE, L1DistanceTransformFunction.class);
typeToImplementation.put(TransformFunctionType.L2_DISTANCE, L2DistanceTransformFunction.class);
typeToImplementation.put(TransformFunctionType.VECTOR_DIMS, VectorDimsTransformFunction.class);
typeToImplementation.put(TransformFunctionType.VECTOR_NORM, VectorNormTransformFunction.class);

Map<String, Class<? extends TransformFunction>> registry = new HashMap<>(typeToImplementation.size());
for (Map.Entry<TransformFunctionType, Class<? extends TransformFunction>> entry : typeToImplementation.entrySet()) {
for (String alias : entry.getKey().getAlternativeNames()) {
Expand Down
Loading