Skip to content

Commit a40114c

Browse files
authored
[Feature][Transform-V2] Support vector series sql function (apache#9765)
1 parent 0d52102 commit a40114c

File tree

18 files changed

+643
-43
lines changed

18 files changed

+643
-43
lines changed

docs/en/transform-v2/sql-functions.md

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1155,4 +1155,70 @@ SELECT * FROM dual
11551155
LATERAL VIEW EXPLODE ( SPLIT ( pk_id, ';' ) ) AS pk_id
11561156
LATERAL VIEW OUTER EXPLODE ( age ) AS age
11571157
LATERAL VIEW OUTER EXPLODE ( ARRAY(1,1) ) AS num
1158-
```
1158+
```
1159+
1160+
## Vector Functions
1161+
1162+
### VECTOR_DIMS
1163+
1164+
```VECTOR_DIMS(vector) -> INT```
1165+
1166+
Returns an INT value representing the number of dimensions (elements) in the vector.
1167+
1168+
Example:
1169+
1170+
VECTOR_DIMS(vector)
1171+
1172+
### VECTOR_NORM
1173+
1174+
```VECTOR_NORM(vector) -> DOUBLE```
1175+
1176+
Calculates the L2 norm (Euclidean norm) of a vector, which represents the length or magnitude of the vector.
1177+
1178+
Example:
1179+
1180+
VECTOR_NORM(vector)
1181+
1182+
### INNER_PRODUCT
1183+
1184+
```INNER_PRODUCT(vector1, vector2) -> DOUBLE```
1185+
1186+
Calculates the inner product (dot product) of two vectors, which is used to measure the similarity and projection between the vectors.
1187+
1188+
Example:
1189+
1190+
INNER_PRODUCT(vector1, vector2)
1191+
1192+
### COSINE_DISTANCE
1193+
1194+
```COSINE_DISTANCE(vector1, vector2) -> DOUBLE```
1195+
1196+
Returns a DOUBLE value between 0 and 1:
1197+
1198+
0: Identical vectors (completely similar)
1199+
1200+
1: Orthogonal vectors (completely dissimilar)
1201+
1202+
Example:
1203+
1204+
COSINE_DISTANCE(vector1, vector2)
1205+
1206+
### L1_DISTANCE
1207+
1208+
```L1_DISTANCE(vector1, vector2) -> DOUBLE```
1209+
1210+
Calculates the Manhattan (L1) distance between two vectors.
1211+
1212+
Example:
1213+
1214+
L1_DISTANCE(vector1, vector2)
1215+
1216+
### L2_DISTANCE
1217+
1218+
```L2_DISTANCE(vector1, vector2) -> DOUBLE```
1219+
1220+
Calculates the Euclidean (L2) distance between two vectors.
1221+
1222+
Example:
1223+
1224+
L2_DISTANCE(vector1, vector2)

docs/zh/transform-v2/sql-functions.md

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1150,3 +1150,69 @@ SELECT * FROM dual
11501150
LATERAL VIEW OUTER EXPLODE ( age ) AS age
11511151
LATERAL VIEW OUTER EXPLODE ( ARRAY(1,1) ) AS num
11521152
```
1153+
1154+
## 向量函数
1155+
1156+
### VECTOR_DIMS
1157+
1158+
```VECTOR_DIMS(vector) -> INT```
1159+
1160+
返回一个INT值,表示向量中的维数(元素)。
1161+
1162+
示例:
1163+
1164+
VECTOR_DIMS(vector)
1165+
1166+
### VECTOR_NORM
1167+
1168+
```VECTOR_NORM(vector) -> DOUBLE```
1169+
1170+
计算向量的L2范数(欧几里得范数),表示向量的长度或大小。
1171+
1172+
示例:
1173+
1174+
VECTOR_NORM(vector)
1175+
1176+
### INNER_PRODUCT
1177+
1178+
```INNER_PRODUCT(vector1, vector2) -> DOUBLE```
1179+
1180+
计算两个向量的内积(点积),用于测量向量之间的相似性和投影。
1181+
1182+
示例:
1183+
1184+
INNER_PRODUCT(vector1, vector2)
1185+
1186+
### COSINE_DISTANCE
1187+
1188+
```COSINE_DISTANCE(vector1, vector2) -> DOUBLE```
1189+
1190+
返回介于 0 和 1 之间的 DOUBLE 值:
1191+
1192+
0:相同的向量(完全相似)
1193+
1194+
1:正交向量(完全不同)
1195+
1196+
示例:
1197+
1198+
COSINE_DISTANCE(vector1, vector2)
1199+
1200+
### L1_DISTANCE
1201+
1202+
```L1_DISTANCE(vector1, vector2) -> DOUBLE```
1203+
1204+
计算两个向量之间的曼哈顿(L1)距离。
1205+
1206+
示例:
1207+
1208+
L1_DISTANCE(vector1, vector2)
1209+
1210+
### L2_DISTANCE
1211+
1212+
```L2_DISTANCE(vector1, vector2) -> DOUBLE```
1213+
1214+
计算两个向量之间的欧几里得(L2)距离。
1215+
1216+
示例:
1217+
1218+
L2_DISTANCE(vector1, vector2)

seatunnel-common/src/main/java/org/apache/seatunnel/common/utils/BufferUtils.java renamed to seatunnel-common/src/main/java/org/apache/seatunnel/common/utils/VectorUtils.java

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,10 @@
3535

3636
import java.nio.Buffer;
3737
import java.nio.ByteBuffer;
38+
import java.util.Arrays;
39+
import java.util.Map;
3840

39-
public class BufferUtils {
41+
public class VectorUtils {
4042

4143
public static ByteBuffer toByteBuffer(Short[] shortArray) {
4244
ByteBuffer byteBuffer = ByteBuffer.allocate(shortArray.length * 2);
@@ -127,4 +129,46 @@ public static Integer[] toIntArray(ByteBuffer byteBuffer) {
127129

128130
return intArray;
129131
}
132+
133+
public static Float[] convertSparseVectorToFloatArray(Map<?, ?> sparseVector) {
134+
if (sparseVector.isEmpty()) {
135+
return new Float[0];
136+
}
137+
int maxIndex = -1;
138+
for (Map.Entry<?, ?> entry : sparseVector.entrySet()) {
139+
Object key = entry.getKey();
140+
if (!(key instanceof Integer)) {
141+
throw new IllegalArgumentException(
142+
String.format(
143+
"Sparse vector key must be Integer, but got: %s,",
144+
key.getClass().getName()));
145+
}
146+
int index = (Integer) key;
147+
if (index < 0) {
148+
throw new IllegalArgumentException(
149+
String.format("Sparse vector index cannot be negative: %d", index));
150+
}
151+
// prevent OOM
152+
if (index > 1000000) {
153+
throw new IllegalArgumentException(
154+
String.format("Sparse vector index too large: %d", index));
155+
}
156+
maxIndex = Math.max(maxIndex, index);
157+
}
158+
Float[] denseVector = new Float[maxIndex + 1];
159+
Arrays.fill(denseVector, 0.0f);
160+
for (Map.Entry<?, ?> entry : sparseVector.entrySet()) {
161+
Object key = entry.getKey();
162+
Object value = entry.getValue();
163+
if (!(value instanceof Number)) {
164+
throw new IllegalArgumentException(
165+
String.format(
166+
"Sparse vector value must be a Number, but got: %s",
167+
value.getClass().getName()));
168+
}
169+
int index = (Integer) key;
170+
denseVector[index] = ((Number) value).floatValue();
171+
}
172+
return denseVector;
173+
}
130174
}

seatunnel-common/src/test/java/org/apache/seatunnel/common/utils/BufferUtilsTest.java renamed to seatunnel-common/src/test/java/org/apache/seatunnel/common/utils/VectorUtilsTest.java

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -22,40 +22,40 @@
2222

2323
import java.nio.ByteBuffer;
2424

25-
public class BufferUtilsTest {
25+
public class VectorUtilsTest {
2626

2727
@Test
2828
public void testToByteBufferAndToShortArray() {
2929
Short[] shortArray = {1, 2, 3, 4, 5};
30-
ByteBuffer byteBuffer = BufferUtils.toByteBuffer(shortArray);
31-
Short[] resultArray = BufferUtils.toShortArray(byteBuffer);
30+
ByteBuffer byteBuffer = VectorUtils.toByteBuffer(shortArray);
31+
Short[] resultArray = VectorUtils.toShortArray(byteBuffer);
3232

3333
Assertions.assertArrayEquals(shortArray, resultArray, "Short array conversion failed");
3434
}
3535

3636
@Test
3737
public void testToByteBufferAndToFloatArray() {
3838
Float[] floatArray = {1.1f, 2.2f, 3.3f, 4.4f, 5.5f};
39-
ByteBuffer byteBuffer = BufferUtils.toByteBuffer(floatArray);
40-
Float[] resultArray = BufferUtils.toFloatArray(byteBuffer);
39+
ByteBuffer byteBuffer = VectorUtils.toByteBuffer(floatArray);
40+
Float[] resultArray = VectorUtils.toFloatArray(byteBuffer);
4141

4242
Assertions.assertArrayEquals(floatArray, resultArray, "Float array conversion failed");
4343
}
4444

4545
@Test
4646
public void testToByteBufferAndToDoubleArray() {
4747
Double[] doubleArray = {1.1, 2.2, 3.3, 4.4, 5.5};
48-
ByteBuffer byteBuffer = BufferUtils.toByteBuffer(doubleArray);
49-
Double[] resultArray = BufferUtils.toDoubleArray(byteBuffer);
48+
ByteBuffer byteBuffer = VectorUtils.toByteBuffer(doubleArray);
49+
Double[] resultArray = VectorUtils.toDoubleArray(byteBuffer);
5050

5151
Assertions.assertArrayEquals(doubleArray, resultArray, "Double array conversion failed");
5252
}
5353

5454
@Test
5555
public void testToByteBufferAndToIntArray() {
5656
Integer[] intArray = {1, 2, 3, 4, 5};
57-
ByteBuffer byteBuffer = BufferUtils.toByteBuffer(intArray);
58-
Integer[] resultArray = BufferUtils.toIntArray(byteBuffer);
57+
ByteBuffer byteBuffer = VectorUtils.toByteBuffer(intArray);
58+
Integer[] resultArray = VectorUtils.toIntArray(byteBuffer);
5959

6060
Assertions.assertArrayEquals(intArray, resultArray, "Integer array conversion failed");
6161
}
@@ -64,26 +64,26 @@ public void testToByteBufferAndToIntArray() {
6464
public void testEmptyArrayConversion() {
6565
// Test empty arrays
6666
Short[] shortArray = {};
67-
ByteBuffer shortBuffer = BufferUtils.toByteBuffer(shortArray);
68-
Short[] shortResultArray = BufferUtils.toShortArray(shortBuffer);
67+
ByteBuffer shortBuffer = VectorUtils.toByteBuffer(shortArray);
68+
Short[] shortResultArray = VectorUtils.toShortArray(shortBuffer);
6969
Assertions.assertArrayEquals(
7070
shortArray, shortResultArray, "Empty Short array conversion failed");
7171

7272
Float[] floatArray = {};
73-
ByteBuffer floatBuffer = BufferUtils.toByteBuffer(floatArray);
74-
Float[] floatResultArray = BufferUtils.toFloatArray(floatBuffer);
73+
ByteBuffer floatBuffer = VectorUtils.toByteBuffer(floatArray);
74+
Float[] floatResultArray = VectorUtils.toFloatArray(floatBuffer);
7575
Assertions.assertArrayEquals(
7676
floatArray, floatResultArray, "Empty Float array conversion failed");
7777

7878
Double[] doubleArray = {};
79-
ByteBuffer doubleBuffer = BufferUtils.toByteBuffer(doubleArray);
80-
Double[] doubleResultArray = BufferUtils.toDoubleArray(doubleBuffer);
79+
ByteBuffer doubleBuffer = VectorUtils.toByteBuffer(doubleArray);
80+
Double[] doubleResultArray = VectorUtils.toDoubleArray(doubleBuffer);
8181
Assertions.assertArrayEquals(
8282
doubleArray, doubleResultArray, "Empty Double array conversion failed");
8383

8484
Integer[] intArray = {};
85-
ByteBuffer intBuffer = BufferUtils.toByteBuffer(intArray);
86-
Integer[] intResultArray = BufferUtils.toIntArray(intBuffer);
85+
ByteBuffer intBuffer = VectorUtils.toByteBuffer(intArray);
86+
Integer[] intResultArray = VectorUtils.toIntArray(intBuffer);
8787
Assertions.assertArrayEquals(
8888
intArray, intResultArray, "Empty Integer array conversion failed");
8989
}

seatunnel-connectors-v2/connector-elasticsearch/src/main/java/org/apache/seatunnel/connectors/seatunnel/elasticsearch/serialize/ElasticsearchRowSerializer.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
import org.apache.seatunnel.api.table.type.SeaTunnelRowType;
2525
import org.apache.seatunnel.common.exception.CommonError;
2626
import org.apache.seatunnel.common.exception.CommonErrorCodeDeprecated;
27-
import org.apache.seatunnel.common.utils.BufferUtils;
27+
import org.apache.seatunnel.common.utils.VectorUtils;
2828
import org.apache.seatunnel.connectors.seatunnel.elasticsearch.dto.ElasticsearchClusterInfo;
2929
import org.apache.seatunnel.connectors.seatunnel.elasticsearch.dto.IndexInfo;
3030
import org.apache.seatunnel.connectors.seatunnel.elasticsearch.exception.ElasticsearchConnectorException;
@@ -218,7 +218,7 @@ private Object convertValue(String fieldName, Object value) {
218218
// Check if this field is configured as a vectorization field
219219
if (vectorizationFields != null && vectorizationFields.contains(fieldName)) {
220220
ByteBuffer buffer = (ByteBuffer) value;
221-
Float[] floats = BufferUtils.toFloatArray(buffer);
221+
Float[] floats = VectorUtils.toFloatArray(buffer);
222222

223223
// Use the configured dimension or calculate it from the buffer size
224224
int dimension = vectorDimension > 0 ? vectorDimension : buffer.remaining() / 4;
@@ -232,7 +232,7 @@ private Object convertValue(String fieldName, Object value) {
232232
} else {
233233
// Default behavior for ByteBuffer fields not specified as vectorization fields
234234
ByteBuffer buffer = (ByteBuffer) value;
235-
Float[] floats = BufferUtils.toFloatArray(buffer);
235+
Float[] floats = VectorUtils.toFloatArray(buffer);
236236
int floatCount = buffer.remaining() / 4;
237237

238238
for (int i = 0; i < floatCount; i++) {

seatunnel-connectors-v2/connector-fake/src/main/java/org/apache/seatunnel/connectors/seatunnel/fake/utils/FakeDataRandomUtils.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
import org.apache.seatunnel.api.table.catalog.Column;
2121
import org.apache.seatunnel.api.table.type.DecimalType;
22-
import org.apache.seatunnel.common.utils.BufferUtils;
22+
import org.apache.seatunnel.common.utils.VectorUtils;
2323
import org.apache.seatunnel.connectors.seatunnel.fake.config.FakeConfig;
2424

2525
import org.apache.commons.collections4.CollectionUtils;
@@ -231,7 +231,7 @@ public ByteBuffer randomFloatVector(Column column) {
231231
RandomUtils.nextFloat(
232232
fakeConfig.getVectorFloatMin(), fakeConfig.getVectorFloatMax());
233233
}
234-
return BufferUtils.toByteBuffer(floatVector);
234+
return VectorUtils.toByteBuffer(floatVector);
235235
}
236236

237237
public ByteBuffer randomFloat16Vector(Column column) {
@@ -244,7 +244,7 @@ public ByteBuffer randomFloat16Vector(Column column) {
244244
fakeConfig.getVectorFloatMin(), fakeConfig.getVectorFloatMax());
245245
float16Vector[i] = floatToFloat16(value);
246246
}
247-
return BufferUtils.toByteBuffer(float16Vector);
247+
return VectorUtils.toByteBuffer(float16Vector);
248248
}
249249

250250
public ByteBuffer randomBFloat16Vector(Column column) {
@@ -257,7 +257,7 @@ public ByteBuffer randomBFloat16Vector(Column column) {
257257
fakeConfig.getVectorFloatMin(), fakeConfig.getVectorFloatMax());
258258
bfloat16Vector[i] = floatToBFloat16(value);
259259
}
260-
return BufferUtils.toByteBuffer(bfloat16Vector);
260+
return VectorUtils.toByteBuffer(bfloat16Vector);
261261
}
262262

263263
public Map<Integer, Float> randomSparseFloatVector(Column column) {

seatunnel-connectors-v2/connector-jdbc/src/main/java/org/apache/seatunnel/connectors/seatunnel/jdbc/internal/dialect/oceanbase/OceanBaseMysqlJdbcRowConverter.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
import org.apache.seatunnel.api.table.type.SqlType;
2626
import org.apache.seatunnel.common.exception.CommonError;
2727
import org.apache.seatunnel.common.exception.CommonErrorCodeDeprecated;
28-
import org.apache.seatunnel.common.utils.BufferUtils;
28+
import org.apache.seatunnel.common.utils.VectorUtils;
2929
import org.apache.seatunnel.connectors.seatunnel.jdbc.exception.JdbcConnectorErrorCode;
3030
import org.apache.seatunnel.connectors.seatunnel.jdbc.exception.JdbcConnectorException;
3131
import org.apache.seatunnel.connectors.seatunnel.jdbc.internal.converter.AbstractJdbcRowConverter;
@@ -101,7 +101,7 @@ public SeaTunnelRow toInternal(ResultSet rs, TableSchema tableSchema) throws SQL
101101
for (int i = 0; i < stringArray.length; i++) {
102102
arrays[i] = Float.parseFloat(stringArray[i]);
103103
}
104-
fields[fieldIndex] = BufferUtils.toByteBuffer(arrays);
104+
fields[fieldIndex] = VectorUtils.toByteBuffer(arrays);
105105
}
106106
break;
107107
case DOUBLE:
@@ -188,7 +188,7 @@ public PreparedStatement toExternal(
188188
if (row.getField(fieldIndex) instanceof ByteBuffer) {
189189
ByteBuffer byteBuffer = (ByteBuffer) row.getField(fieldIndex);
190190
// Convert ByteBuffer to Float[]
191-
Float[] floatArray = BufferUtils.toFloatArray(byteBuffer);
191+
Float[] floatArray = VectorUtils.toFloatArray(byteBuffer);
192192
StringBuilder vector = new StringBuilder();
193193
vector.append("[");
194194
for (Float aFloat : floatArray) {

seatunnel-connectors-v2/connector-milvus/src/main/java/org/apache/seatunnel/connectors/seatunnel/milvus/utils/sink/MilvusSinkConverter.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@
2828
import org.apache.seatunnel.api.table.type.SeaTunnelRow;
2929
import org.apache.seatunnel.api.table.type.SeaTunnelRowType;
3030
import org.apache.seatunnel.api.table.type.SqlType;
31-
import org.apache.seatunnel.common.utils.BufferUtils;
3231
import org.apache.seatunnel.common.utils.JsonUtils;
32+
import org.apache.seatunnel.common.utils.VectorUtils;
3333
import org.apache.seatunnel.connectors.seatunnel.milvus.exception.MilvusConnectionErrorCode;
3434
import org.apache.seatunnel.connectors.seatunnel.milvus.exception.MilvusConnectorException;
3535

@@ -74,7 +74,7 @@ public Object convertBySeaTunnelType(
7474
return value.toString();
7575
case FLOAT_VECTOR:
7676
ByteBuffer floatVectorBuffer = (ByteBuffer) value;
77-
Float[] floats = BufferUtils.toFloatArray(floatVectorBuffer);
77+
Float[] floats = VectorUtils.toFloatArray(floatVectorBuffer);
7878
return Arrays.stream(floats).collect(Collectors.toList());
7979
case BINARY_VECTOR:
8080
case BFLOAT16_VECTOR:

0 commit comments

Comments
 (0)