Commit c115613

[Feature][Transform V2] Add vector dimension reduction transform (apache#9783)
1 parent 507e000 commit c115613

File tree

8 files changed

+639
-2
lines changed

docs/en/transform-v2/sql-functions.md

Lines changed: 41 additions & 1 deletion
@@ -1221,4 +1221,44 @@ Calculates the Euclidean (L2) distance between two vectors.
12211221

12221222
Example:
12231223

1224-
L2_DISTANCE(vector1, vector2)
1224+
L2_DISTANCE(vector1, vector2)
1225+
1226+
### VECTOR_REDUCE
1227+
1228+
```VECTOR_REDUCE(vector_field, target_dimension, method)```
1229+
1230+
Generic vector dimension reduction function that supports multiple reduction methods.
1231+
1232+
**Parameters:**
1233+
- `vector_field`: The vector field to reduce (VECTOR type)
1234+
- `target_dimension`: The target dimension (INTEGER, must be smaller than the source dimension)
1235+
- `method`: The reduction method (STRING):
1236+
- **'TRUNCATE'**: Keeps only the first `target_dimension` elements of the vector. This is the simplest and fastest method, but any information carried by the dropped dimensions is lost.
1237+
- **'RANDOM_PROJECTION'**: Applies a Gaussian random projection, multiplying the vector by a random matrix with normally distributed entries. Following the Johnson-Lindenstrauss lemma, this approximately preserves pairwise distances between vectors while reducing dimensionality.
1238+
- **'SPARSE_RANDOM_PROJECTION'**: Applies a sparse random projection in which most matrix entries are 0 and the nonzero entries are ±√3. This is cheaper to compute than a Gaussian random projection while offering similar distance preservation.
1239+
1240+
**Returns:** VECTOR type with reduced dimensions
1241+
1242+
**Example:**
1243+
```sql
1244+
SELECT id, VECTOR_REDUCE(embedding, 256, 'TRUNCATE') as reduced_embedding FROM table
1245+
SELECT id, VECTOR_REDUCE(embedding, 128, 'RANDOM_PROJECTION') as reduced_embedding FROM table
1246+
SELECT id, VECTOR_REDUCE(embedding, 64, 'SPARSE_RANDOM_PROJECTION') as reduced_embedding FROM table
1247+
```
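The three methods described above can be sketched in plain Java. This is a minimal illustration under stated assumptions, not the code added by this commit: the class name `VectorReduceSketch`, the explicit `seed` parameter, the `1/sqrt(targetDim)` scaling, and the 1/6, 2/3, 1/6 entry probabilities (the classic Achlioptas scheme, which matches the ±√3 and 0 entries mentioned in the docs) are all hypothetical choices made for the example.

```java
import java.util.Arrays;
import java.util.Random;

// Hypothetical sketch; names and scaling are assumptions, not SeaTunnel's API.
public class VectorReduceSketch {

    // TRUNCATE: keep the first targetDim elements.
    public static float[] truncate(float[] v, int targetDim) {
        checkDims(v.length, targetDim);
        float[] out = new float[targetDim];
        System.arraycopy(v, 0, out, 0, targetDim);
        return out;
    }

    // RANDOM_PROJECTION: multiply by a targetDim x sourceDim matrix of N(0, 1)
    // entries, scaled by 1/sqrt(targetDim). The same seed yields the same matrix.
    public static float[] gaussianProject(float[] v, int targetDim, long seed) {
        checkDims(v.length, targetDim);
        Random rng = new Random(seed);
        double scale = 1.0 / Math.sqrt(targetDim);
        float[] out = new float[targetDim];
        for (int i = 0; i < targetDim; i++) {
            double sum = 0.0;
            for (int j = 0; j < v.length; j++) {
                sum += rng.nextGaussian() * v[j];
            }
            out[i] = (float) (sum * scale);
        }
        return out;
    }

    // SPARSE_RANDOM_PROJECTION: entries are +sqrt(3) with probability 1/6,
    // -sqrt(3) with probability 1/6, and 0 with probability 2/3 (assumed scheme).
    public static float[] sparseProject(float[] v, int targetDim, long seed) {
        checkDims(v.length, targetDim);
        Random rng = new Random(seed);
        double sqrt3 = Math.sqrt(3.0);
        double scale = 1.0 / Math.sqrt(targetDim);
        float[] out = new float[targetDim];
        for (int i = 0; i < targetDim; i++) {
            double sum = 0.0;
            for (int j = 0; j < v.length; j++) {
                int r = rng.nextInt(6);
                if (r == 0) {
                    sum += sqrt3 * v[j];
                } else if (r == 1) {
                    sum -= sqrt3 * v[j];
                }
                // r in 2..5: the matrix entry is zero, nothing to add
            }
            out[i] = (float) (sum * scale);
        }
        return out;
    }

    // The docs require the target dimension to be smaller than the source dimension.
    private static void checkDims(int sourceDim, int targetDim) {
        if (targetDim <= 0 || targetDim >= sourceDim) {
            throw new IllegalArgumentException(
                    "target dimension must be in [1, " + (sourceDim - 1) + "], got " + targetDim);
        }
    }

    public static void main(String[] args) {
        float[] v = {1f, 2f, 3f, 4f, 5f};
        System.out.println(Arrays.toString(truncate(v, 3))); // prints [1.0, 2.0, 3.0]
        System.out.println(gaussianProject(v, 3, 42L).length); // prints 3
        System.out.println(sparseProject(v, 3, 42L).length); // prints 3
    }
}
```

In this sketch the projection matrix is regenerated from the seed on every call, so two vectors projected with the same seed share the same matrix. That sharing is what makes distances between the reduced vectors comparable.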
1248+
1249+
### VECTOR_NORMALIZE
1250+
1251+
```VECTOR_NORMALIZE(vector_field)```
1252+
1253+
Normalizes a vector to unit length (magnitude = 1). This is useful for computing cosine similarity, because the inner product of two unit vectors equals their cosine similarity.
1254+
1255+
**Parameters:**
1256+
- `vector_field`: The vector field to normalize (VECTOR type)
1257+
1258+
**Returns:** VECTOR type - the normalized vector
1259+
1260+
**Example:**
1261+
```sql
1262+
SELECT id, VECTOR_NORMALIZE(embedding) as normalized_embedding FROM table
1263+
```
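The link between normalization and cosine similarity can be sketched in Java as follows. This is an illustration only: the class name `VectorNormalizeSketch` and the choice to reject a zero vector are assumptions for the example, and the commit's own handling of zero vectors is not shown in this diff.

```java
import java.util.Arrays;

// Hypothetical sketch; not the implementation added by this commit.
public class VectorNormalizeSketch {

    // Divide every element by the Euclidean (L2) norm of the vector.
    public static float[] normalize(float[] v) {
        double sumSq = 0.0;
        for (float x : v) {
            sumSq += (double) x * x;
        }
        double norm = Math.sqrt(sumSq);
        if (norm == 0.0) {
            // Assumption: a zero vector has no direction, so reject it.
            throw new IllegalArgumentException("cannot normalize a zero vector");
        }
        float[] out = new float[v.length];
        for (int i = 0; i < v.length; i++) {
            out[i] = (float) (v[i] / norm);
        }
        return out;
    }

    // Inner product of two vectors of equal length.
    public static double dot(float[] a, float[] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException("dimension mismatch");
        }
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            sum += (double) a[i] * b[i];
        }
        return sum;
    }

    public static void main(String[] args) {
        float[] a = {3f, 4f, 0f, 0f, 0f};
        float[] b = {1f, 2f, 3f, 4f, 5f};
        // The norm of a is 5, so the result is approximately [0.6, 0.8, 0, 0, 0].
        System.out.println(Arrays.toString(normalize(a)));
        // For unit vectors, the inner product is the cosine similarity.
        System.out.println(dot(normalize(a), normalize(b)));
    }
}
```

Normalizing once at ingestion time means later similarity queries need only an inner product, with no per-query norm computation.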
1264+

docs/zh/transform-v2/sql-functions.md

Lines changed: 40 additions & 1 deletion
@@ -1215,4 +1215,43 @@ L1_DISTANCE(vector1, vector2)
12151215

12161216
Example:
12171217

1218-
L2_DISTANCE(vector1, vector2)
1218+
L2_DISTANCE(vector1, vector2)
1219+
1220+
### VECTOR_REDUCE
1221+
1222+
```VECTOR_REDUCE(vector_field, target_dimension, method)```
1223+
1224+
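Generic vector dimension reduction function that supports multiple reduction methods.
1225+
1226+
**Parameters:**
1227+
- `vector_field`: The vector field to reduce (VECTOR type)
1228+
- `target_dimension`: The target dimension (INTEGER, must be smaller than the source dimension)
1229+
- `method`: The reduction method (STRING):
1230+
- **'TRUNCATE'**: Truncation. Reduces the vector dimension by keeping only the first N elements. This is the simplest and fastest method, but important information in the truncated dimensions may be lost.
1231+
- **'RANDOM_PROJECTION'**: Random projection. Applies a Gaussian random projection with a normally distributed random matrix. This method approximately preserves the relative distances between vectors while reducing dimensionality, following the Johnson-Lindenstrauss lemma.
1232+
- **'SPARSE_RANDOM_PROJECTION'**: Sparse random projection. Most matrix entries are 0 and the nonzero entries are ±√3. This is cheaper to compute than regular random projection while keeping similar distance preservation properties.
1233+
1234+
**Returns:** VECTOR type with reduced dimensions
1235+
1236+
**Example:**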
1237+
```sql
1238+
SELECT id, VECTOR_REDUCE(embedding, 256, 'TRUNCATE') as reduced_embedding FROM table
1239+
SELECT id, VECTOR_REDUCE(embedding, 128, 'RANDOM_PROJECTION') as reduced_embedding FROM table
1240+
SELECT id, VECTOR_REDUCE(embedding, 64, 'SPARSE_RANDOM_PROJECTION') as reduced_embedding FROM table
1241+
```
1242+
1243+
### VECTOR_NORMALIZE
1244+
1245+
```VECTOR_NORMALIZE(vector_field)```
1246+
1247+
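Normalizes a vector to unit length (magnitude = 1). This is useful for computing cosine similarity.
1248+
1249+
**Parameters:**
1250+
- `vector_field`: The vector field to normalize (VECTOR type)
1251+
1252+
**Returns:** VECTOR type - the normalized vector
1253+
1254+
**Example:**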
1255+
```sql
1256+
SELECT id, VECTOR_NORMALIZE(embedding) as normalized_embedding FROM table
1257+
```

seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-2/src/test/java/org/apache/seatunnel/e2e/transform/TestSQLIT.java

Lines changed: 12 additions & 0 deletions
@@ -86,6 +86,18 @@ public void testSQLTransform(TestContainer container) throws IOException, Interr
8686
Assertions.assertEquals(0, multiIfSql.getExitCode());
8787
}
8888

89+
@TestTemplate
90+
@DisabledOnContainer(
91+
value = {},
92+
type = {EngineType.SPARK},
93+
disabledReason = "Vector functions are not supported in Spark engine")
94+
public void testVectorFunctions(TestContainer container)
95+
throws IOException, InterruptedException {
96+
Container.ExecResult vectorFunctionResult =
97+
container.executeJob("/sql_transform/func_vector.conf");
98+
Assertions.assertEquals(0, vectorFunctionResult.getExitCode());
99+
}
100+
89101
@TestTemplate
90102
public void testSQLTransformMultiTable(TestContainer container)
91103
throws IOException, InterruptedException {
Lines changed: 142 additions & 0 deletions
@@ -0,0 +1,142 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one or more
3+
# contributor license agreements. See the NOTICE file distributed with
4+
# this work for additional information regarding copyright ownership.
5+
# The ASF licenses this file to You under the Apache License, Version 2.0
6+
# (the "License"); you may not use this file except in compliance with
7+
# the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
######
18+
###### This config file is a demonstration of vector functions in SQL transform
19+
######
20+
21+
env {
22+
parallelism = 1
23+
job.mode = "BATCH"
24+
checkpoint.interval = 10000
25+
}
26+
27+
source {
28+
FakeSource {
29+
plugin_output = "fake"
30+
schema = {
31+
fields {
32+
id = "int"
33+
name = "string"
34+
vector_field = "array<float>"
35+
vector_field2 = "array<float>"
36+
}
37+
}
38+
rows = [
39+
{
40+
fields = [1, "test1", [1.0, 2.0, 3.0, 4.0, 5.0], [1.0, 2.0, 3.0, 4.0, 5.0]]
41+
kind = INSERT
42+
},
43+
{
44+
fields = [2, "test2", [2.0, 4.0, 6.0, 8.0, 10.0], [0.6, 0.8, 0.0, 0.0, 0.0]]
45+
kind = INSERT
46+
},
47+
{
48+
fields = [3, "test3", [3.0, 4.0, 0.0, 0.0, 0.0], [3.0, 4.0, 0.0, 0.0, 0.0]]
49+
kind = INSERT
50+
}
51+
]
52+
}
53+
}
54+
55+
transform {
56+
Sql {
57+
plugin_input = "fake"
58+
plugin_output = "fake1"
59+
query = """SELECT
60+
id,
61+
name,
62+
VECTOR_DIMS(vector_field) as original_dim,
63+
VECTOR_DIMS(VECTOR_REDUCE(vector_field, 3, 'TRUNCATE')) as truncated_dim,
64+
VECTOR_DIMS(VECTOR_REDUCE(vector_field, 3, 'RANDOM_PROJECTION')) as projected_dim,
65+
VECTOR_DIMS(VECTOR_REDUCE(vector_field, 3, 'SPARSE_RANDOM_PROJECTION')) as sparse_projected_dim,
66+
VECTOR_DIMS(VECTOR_NORMALIZE(vector_field)) as normalized_dim
67+
FROM dual"""
68+
}
69+
}
70+
71+
sink {
72+
Assert {
73+
plugin_input = "fake1"
74+
rules = {
75+
field_rules = [
76+
{
77+
field_name = "id"
78+
field_type = "int"
79+
field_value = [
80+
{
81+
rule_type = NOT_NULL
82+
}
83+
]
84+
},
85+
{
86+
field_name = "name"
87+
field_type = "string"
88+
field_value = [
89+
{
90+
rule_type = NOT_NULL
91+
}
92+
]
93+
},
94+
{
95+
field_name = "original_dim"
96+
field_type = "int"
97+
field_value = [
98+
{equals_to = 5}
99+
]
100+
},
101+
{
102+
field_name = "truncated_dim"
103+
field_type = "int"
104+
field_value = [
105+
{equals_to = 3}
106+
]
107+
},
108+
{
109+
field_name = "projected_dim"
110+
field_type = "int"
111+
field_value = [
112+
{equals_to = 3}
113+
]
114+
},
115+
{
116+
field_name = "sparse_projected_dim"
117+
field_type = "int"
118+
field_value = [
119+
{equals_to = 3}
120+
]
121+
},
122+
{
123+
field_name = "normalized_dim"
124+
field_type = "int"
125+
field_value = [
126+
{equals_to = 5}
127+
]
128+
}
129+
]
130+
row_rules = [
131+
{
132+
rule_type = MAX_ROW
133+
rule_value = 3
134+
},
135+
{
136+
rule_type = MIN_ROW
137+
rule_value = 3
138+
}
139+
]
140+
}
141+
}
142+
}

seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/sql/zeta/ZetaSQLFunction.java

Lines changed: 8 additions & 0 deletions
@@ -212,6 +212,9 @@ public class ZetaSQLFunction {
212212
public static final String VECTOR_NORM = "VECTOR_NORM";
213213
public static final String INNER_PRODUCT = "INNER_PRODUCT";
214214

215+
public static final String VECTOR_REDUCE = "VECTOR_REDUCE";
216+
public static final String VECTOR_NORMALIZE = "VECTOR_NORMALIZE";
217+
215218
private final SeaTunnelRowType inputRowType;
216219

217220
private final ZetaSQLType zetaSQLType;
@@ -619,6 +622,11 @@ public Object executeFunctionExpr(
619622
return VectorFunction.vectorNorm(args);
620623
case INNER_PRODUCT:
621624
return VectorFunction.innerProduct(args);
625+
case VECTOR_REDUCE:
626+
return VectorFunction.vectorReduce(
627+
args.get(0), (Integer) args.get(1), (String) args.get(2));
628+
case VECTOR_NORMALIZE:
629+
return VectorFunction.vectorNormalize(args.get(0));
622630
default:
623631
for (ZetaUDF udf : udfList) {
624632
if (udf.functionName().equalsIgnoreCase(functionName)) {
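The `VECTOR_REDUCE` branch above casts its second and third arguments directly to `Integer` and `String`. Whether the SQL parser can ever pass a different numeric type (for example a `Long`) is not visible in this diff, so the following is only a defensive sketch of how such arguments could be validated before dispatch. The class `VectorReduceArgsSketch` and both helper methods are hypothetical and are not part of SeaTunnel.

```java
import java.util.Locale;

// Hypothetical helpers; not part of ZetaSQLFunction or VectorFunction.
public class VectorReduceArgsSketch {

    // Accept any integral Number and convert it to a positive int.
    public static int toTargetDimension(Object arg) {
        if (!(arg instanceof Number)) {
            throw new IllegalArgumentException(
                    "target_dimension must be numeric, got: " + arg);
        }
        Number n = (Number) arg;
        long value = n.longValue();
        // Reject fractional values, non-positive values, and int overflow.
        if (n.doubleValue() != (double) value || value <= 0 || value > Integer.MAX_VALUE) {
            throw new IllegalArgumentException(
                    "target_dimension must be a positive integer, got: " + arg);
        }
        return (int) value;
    }

    // Normalize the method name so that 'truncate' and 'TRUNCATE' are treated alike.
    public static String toMethod(Object arg) {
        if (!(arg instanceof String)) {
            throw new IllegalArgumentException("method must be a string, got: " + arg);
        }
        return ((String) arg).trim().toUpperCase(Locale.ROOT);
    }

    public static void main(String[] args) {
        System.out.println(toTargetDimension(3L)); // prints 3
        System.out.println(toMethod(" truncate ")); // prints TRUNCATE
    }
}
```

Failing with a descriptive `IllegalArgumentException` gives the job author a clearer message than a bare `ClassCastException` would.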

seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/sql/zeta/ZetaSQLType.java

Lines changed: 5 additions & 0 deletions
@@ -25,6 +25,7 @@
2525
import org.apache.seatunnel.api.table.type.SeaTunnelDataType;
2626
import org.apache.seatunnel.api.table.type.SeaTunnelRowType;
2727
import org.apache.seatunnel.api.table.type.SqlType;
28+
import org.apache.seatunnel.api.table.type.VectorType;
2829
import org.apache.seatunnel.common.exception.CommonErrorCodeDeprecated;
2930
import org.apache.seatunnel.transform.exception.TransformException;
3031
import org.apache.seatunnel.transform.sql.zeta.functions.ArrayFunction;
@@ -489,6 +490,10 @@ private SeaTunnelDataType<?> getFunctionType(Function function) {
489490
case ZetaSQLFunction.MOD:
490491
// Result has the same type as second argument
491492
return getExpressionType(function.getParameters().getExpressions().get(1));
493+
// Vector functions
494+
case ZetaSQLFunction.VECTOR_REDUCE:
495+
case ZetaSQLFunction.VECTOR_NORMALIZE:
496+
return VectorType.VECTOR_FLOAT_TYPE;
492497
default:
493498
for (ZetaUDF udf : udfList) {
494499
if (udf.functionName().equalsIgnoreCase(function.getName())) {
