Skip to content

Commit 7d46cbc

Browse files
committed
[WIP] map opensearch math functions to calcite implementations
Signed-off-by: Yuanchun Shen <yuanchu@amazon.com>
1 parent 86f1b9c commit 7d46cbc

File tree

2 files changed

+289
-25
lines changed

2 files changed

+289
-25
lines changed

core/src/main/java/org/opensearch/sql/calcite/utils/BuiltinFunctionUtils.java

Lines changed: 54 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5,30 +5,17 @@
55

66
package org.opensearch.sql.calcite.utils;
77

8-
import java.math.BigDecimal;
9-
import java.util.Locale;
10-
import org.apache.calcite.sql.SqlOperator;
11-
import org.apache.calcite.sql.fun.SqlLibraryOperators;
12-
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
13-
import java.util.ArrayList;
14-
import java.util.Collections;
15-
import java.util.List;
16-
import java.util.Locale;
17-
18-
import org.apache.calcite.linq4j.tree.Types;
198
import org.apache.calcite.rex.RexNode;
20-
import org.apache.calcite.schema.ScalarFunction;
21-
import org.apache.calcite.schema.impl.ScalarFunctionImpl;
22-
import org.apache.calcite.sql.SqlIdentifier;
23-
import org.apache.calcite.sql.SqlKind;
249
import org.apache.calcite.sql.SqlOperator;
2510
import org.apache.calcite.sql.fun.SqlLibraryOperators;
2611
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
2712
import org.apache.calcite.sql.fun.SqlTrimFunction;
28-
import org.apache.calcite.sql.parser.SqlParserPos;
29-
import org.apache.calcite.sql.type.ReturnTypes;
30-
import org.apache.calcite.sql.validate.SqlUserDefinedFunction;
3113
import org.opensearch.sql.calcite.CalcitePlanContext;
14+
15+
import java.math.BigDecimal;
16+
import java.util.ArrayList;
17+
import java.util.List;
18+
import java.util.Locale;
3219
public interface BuiltinFunctionUtils {
3320

3421
static SqlOperator translate(String op) {
@@ -62,7 +49,7 @@ static SqlOperator translate(String op) {
6249
return SqlStdOperatorTable.MULTIPLY;
6350
case "/":
6451
return SqlStdOperatorTable.DIVIDE;
65-
// Built-in String Functions
52+
// Built-in String Functions
6653
case "CONCAT":
6754
return SqlLibraryOperators.CONCAT_FUNCTION;
6855
case "CONCAT_WS":
@@ -96,12 +83,58 @@ static SqlOperator translate(String op) {
9683
return SqlStdOperatorTable.IS_NULL;
9784
case "NULLIF":
9885
return SqlStdOperatorTable.NULLIF;
99-
// Built-in Math Functions
86+
// Built-in Math Functions
10087
case "ABS":
10188
return SqlStdOperatorTable.ABS;
89+
case "ACOS":
90+
return SqlStdOperatorTable.ACOS;
91+
case "ASIN":
92+
return SqlStdOperatorTable.ASIN;
10293
case "ATAN", "ATAN2":
10394
return SqlStdOperatorTable.ATAN2;
104-
// Built-in Date Functions
95+
case "CEILING":
96+
return SqlStdOperatorTable.CEIL;
97+
case "CONV":
98+
return SqlStdOperatorTable.CONVERT;
99+
case "COS":
100+
return SqlStdOperatorTable.COS;
101+
case "COT":
102+
return SqlStdOperatorTable.COT;
103+
case "DEGREES":
104+
return SqlStdOperatorTable.DEGREES;
105+
case "EXP":
106+
return SqlStdOperatorTable.EXP;
107+
case "FLOOR":
108+
return SqlStdOperatorTable.FLOOR;
109+
case "LN":
110+
return SqlStdOperatorTable.LN;
111+
case "LOG":
112+
return SqlLibraryOperators.LOG;
113+
case "LOG2":
114+
return SqlLibraryOperators.LOG2;
115+
case "LOG10":
116+
return SqlStdOperatorTable.LOG10;
117+
case "MOD":
118+
return SqlStdOperatorTable.MOD;
119+
case "PI":
120+
return SqlStdOperatorTable.PI;
121+
case "POW", "POWER":
122+
return SqlStdOperatorTable.POWER;
123+
case "RADIANS":
124+
return SqlStdOperatorTable.RADIANS;
125+
case "RAND":
126+
return SqlStdOperatorTable.RAND;
127+
case "ROUND":
128+
return SqlStdOperatorTable.ROUND;
129+
case "SIGN":
130+
return SqlStdOperatorTable.SIGN;
131+
case "SIN":
132+
return SqlStdOperatorTable.SIN;
133+
case "SQRT":
134+
return SqlStdOperatorTable.SQRT;
135+
case "CBRT":
136+
return SqlStdOperatorTable.CBRT;
137+
// Built-in Date Functions
105138
case "CURRENT_TIMESTAMP":
106139
return SqlStdOperatorTable.CURRENT_TIMESTAMP;
107140
case "CURRENT_DATE":

integ-test/src/test/java/org/opensearch/sql/calcite/CalcitePPLBasicIT.java

Lines changed: 235 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,20 @@
55

66
package org.opensearch.sql.calcite;
77

8-
import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_BANK;
9-
10-
import java.io.IOException;
8+
import com.google.gson.JsonArray;
9+
import com.google.gson.JsonObject;
10+
import com.google.gson.JsonParser;
1111
import org.junit.Ignore;
1212
import org.junit.jupiter.api.Test;
1313
import org.opensearch.client.Request;
1414

15+
import java.io.IOException;
16+
import java.math.BigDecimal;
17+
import java.util.List;
18+
19+
import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_BANK;
20+
21+
1522
public class CalcitePPLBasicIT extends CalcitePPLIntegTestCase {
1623

1724
@Override
@@ -23,7 +30,9 @@ public void init() throws IOException {
2330
Request request2 = new Request("PUT", "/test/_doc/2?refresh=true");
2431
request2.setJsonEntity("{\"name\": \"world\", \"age\": 30}");
2532
client().performRequest(request2);
26-
33+
Request request3 = new Request("PUT", "/people/_doc/2?refresh=true");
34+
request3.setJsonEntity("{\"name\": \"DummyEntityForMathVerification\", \"age\": 24}");
35+
client().performRequest(request3);
2736
loadIndex(Index.BANK);
2837
}
2938

@@ -1064,4 +1073,226 @@ public void testUpper() {
10641073
+ "}",
10651074
actual);
10661075
}
1076+
1077+
private static JsonArray parseAndGetFirstDataRow(String executionResult) {
1078+
JsonObject sqrtResJson = JsonParser.parseString(executionResult).getAsJsonObject();
1079+
JsonArray dataRows = sqrtResJson.getAsJsonArray("datarows");
1080+
return dataRows.get(0).getAsJsonArray();
1081+
}
1082+
1083+
private void testMathPPL(String query, List<? extends Number> expectedValues){
1084+
String execResult = execute(query);
1085+
JsonArray dataRow = parseAndGetFirstDataRow(execResult);
1086+
assertEquals(expectedValues.size(), dataRow.size());
1087+
for (int i = 0; i < expectedValues.size(); i++){
1088+
Number expected = expectedValues.get(i);
1089+
Number actual = dataRow.get(i).getAsNumber();
1090+
if (expected instanceof BigDecimal) {
1091+
assertEquals(expected, actual);
1092+
}
1093+
else if (expected instanceof Double || expected instanceof Float) {
1094+
assertDoubleUlpEquals(expected.doubleValue(), actual.doubleValue(), 8);
1095+
} else if (expected instanceof Long || expected instanceof Integer) {
1096+
assertEquals(expected.longValue(), actual.longValue());
1097+
} else {
1098+
fail("Unsupported number type: " + expected.getClass().getName());
1099+
}
1100+
}
1101+
}
1102+
1103+
@Test
1104+
public void testAbs() {
1105+
String absPpl = "source=people | eval `ABS(-1)` = ABS(-1) | fields `ABS(-1)`";
1106+
List<Integer> expected = List.of(1);
1107+
testMathPPL(absPpl, expected);
1108+
}
1109+
1110+
@Test
1111+
public void testAcos() {
1112+
String acosPpl = "source=people | eval `ACOS(0)` = ACOS(0) | fields `ACOS(0)`";
1113+
List<Double> expected = List.of(Math.PI / 2);
1114+
testMathPPL(acosPpl, expected);
1115+
}
1116+
1117+
@Test
1118+
public void testAsin() {
1119+
String asinPpl = "source=people | eval `ASIN(0)` = ASIN(0) | fields `ASIN(0)`";
1120+
List<Double> expected = List.of(0.0);
1121+
testMathPPL(asinPpl, expected);
1122+
}
1123+
1124+
@Test
1125+
public void testAtan() {
1126+
// TODO: Error while preparing plan [LogicalProject(ATAN(2)=[ATAN(2)], ATAN(2, 3)=[ATAN(2, 3)])
1127+
// ATAN defined in OpenSearch accepts single and double arguments, while that defined in SQL standard library accepts only single argument.
1128+
testMathPPL("source=people | eval `ATAN(2)` = ATAN(2), `ATAN(2, 3)` = ATAN(2, 3) | fields `ATAN(2)`, `ATAN(2, 3)`", List.of(Math.atan(2), Math.atan2(2, 3)));
1129+
}
1130+
1131+
@Test
1132+
public void testAtan2() {
1133+
testMathPPL("source=people | eval `ATAN2(2, 3)` = ATAN2(2, 3) | fields `ATAN2(2, 3)`", List.of(Math.atan2(2, 3)));
1134+
}
1135+
1136+
@Test
1137+
public void testCeiling() {
1138+
testMathPPL(
1139+
"source=people | eval `CEILING(0)` = CEILING(0), `CEILING(50.00005)` = CEILING(50.00005), `CEILING(-50.00005)` = CEILING(-50.00005) | fields `CEILING(0)`, `CEILING(50.00005)`, `CEILING(-50.00005)`",
1140+
List.of(Math.ceil(0.0), Math.ceil(50.00005), Math.ceil(-50.00005)));
1141+
testMathPPL(
1142+
"source=people | eval `CEILING(3147483647.12345)` = CEILING(3147483647.12345), `CEILING(113147483647.12345)` = CEILING(113147483647.12345), `CEILING(3147483647.00001)` = CEILING(3147483647.00001) | fields `CEILING(3147483647.12345)`, `CEILING(113147483647.12345)`, `CEILING(3147483647.00001)`",
1143+
List.of(Math.ceil(3147483647.12345), Math.ceil(113147483647.12345), Math.ceil(3147483647.00001)));
1144+
}
1145+
1146+
@Test
1147+
public void testConv() {
1148+
// TODO: Error while preparing plan [LogicalProject(CONV('12', 10, 16)=[CONVERT('12', 10, 16)], CONV('2C', 16, 10)=[CONVERT('2C', 16, 10)], CONV(12, 10, 2)=[CONVERT(12, 10, 2)], CONV(1111, 2, 10)=[CONVERT(1111, 2, 10)])
1149+
// OpenSearchTableScan(table=[[OpenSearch, people]])
1150+
String convPpl = "source=people | eval `CONV('12', 10, 16)` = CONV('12', 10, 16), `CONV('2C', 16, 10)` = CONV('2C', 16, 10), `CONV(12, 10, 2)` = CONV(12, 10, 2), `CONV(1111, 2, 10)` = CONV(1111, 2, 10) | fields `CONV('12', 10, 16)`, `CONV('2C', 16, 10)`, `CONV(12, 10, 2)`, `CONV(1111, 2, 10)`";
1151+
String execResult = execute(convPpl);
1152+
JsonArray dataRow = parseAndGetFirstDataRow(execResult);
1153+
assertEquals(4, dataRow.size());
1154+
assertEquals("c", dataRow.get(0).getAsString());
1155+
assertEquals("44", dataRow.get(1).getAsString());
1156+
assertEquals("1100", dataRow.get(2).getAsString());
1157+
assertEquals("15", dataRow.get(3).getAsString());
1158+
}
1159+
1160+
@Test
1161+
public void testCos() {
1162+
testMathPPL("source=people | eval `COS(0)` = COS(0) | fields `COS(0)`", List.of(1.0));
1163+
}
1164+
1165+
@Test
1166+
public void testCot() {
1167+
testMathPPL("source=people | eval `COT(1)` = COT(1) | fields `COT(1)`", List.of(1.0 / Math.tan(1)));
1168+
}
1169+
1170+
@Test
1171+
public void testCrc32() {
1172+
//TODO: No corresponding built-in implementation
1173+
testMathPPL("source=people | eval `CRC32('MySQL')` = CRC32('MySQL') | fields `CRC32('MySQL')`", List.of(3259397556L));
1174+
}
1175+
1176+
@Test
1177+
public void testDegrees() {
1178+
testMathPPL("source=people | eval `DEGREES(1.57)` = DEGREES(1.57) | fields `DEGREES(1.57)`", List.of(Math.toDegrees(1.57)));
1179+
}
1180+
1181+
@Test
1182+
public void testEuler() {
1183+
//TODO: No corresponding built-in implementation
1184+
testMathPPL("source=people | eval `E()` = E() | fields `E()`", List.of(Math.E));
1185+
}
1186+
1187+
@Test
1188+
public void testExp() {
1189+
testMathPPL("source=people | eval `EXP(2)` = EXP(2) | fields `EXP(2)`", List.of(Math.exp(2)));
1190+
}
1191+
1192+
@Test
1193+
public void testFloor() {
1194+
testMathPPL(
1195+
"source=people | eval `FLOOR(0)` = FLOOR(0), `FLOOR(50.00005)` = FLOOR(50.00005), `FLOOR(-50.00005)` = FLOOR(-50.00005) | fields `FLOOR(0)`, `FLOOR(50.00005)`, `FLOOR(-50.00005)`",
1196+
List.of(Math.floor(0.0), Math.floor(50.00005), Math.floor(-50.00005)));
1197+
testMathPPL(
1198+
"source=people | eval `FLOOR(3147483647.12345)` = FLOOR(3147483647.12345), `FLOOR(113147483647.12345)` = FLOOR(113147483647.12345), `FLOOR(3147483647.00001)` = FLOOR(3147483647.00001) | fields `FLOOR(3147483647.12345)`, `FLOOR(113147483647.12345)`, `FLOOR(3147483647.00001)`",
1199+
List.of(Math.floor(3147483647.12345), Math.floor(113147483647.12345), Math.floor(3147483647.00001)));
1200+
testMathPPL(
1201+
"source=people | eval `FLOOR(282474973688888.022)` = FLOOR(282474973688888.022), `FLOOR(9223372036854775807.022)` = FLOOR(9223372036854775807.022), `FLOOR(9223372036854775807.0000001)` = FLOOR(9223372036854775807.0000001) | fields `FLOOR(282474973688888.022)`, `FLOOR(9223372036854775807.022)`, `FLOOR(9223372036854775807.0000001)`",
1202+
List.of(Math.floor(282474973688888.022), Math.floor(9223372036854775807.022), Math.floor(9223372036854775807.0000001)));
1203+
}
1204+
1205+
@Test
1206+
public void testLn() {
1207+
testMathPPL("source=people | eval `LN(2)` = LN(2) | fields `LN(2)`", List.of(Math.log(2)));
1208+
}
1209+
1210+
@Test
1211+
public void testLog() {
1212+
// TODO: No built-in function for 2-operand log
1213+
testMathPPL("source=people | eval `LOG(2)` = LOG(2), `LOG(2, 8)` = LOG(2, 8) | fields `LOG(2)`, `LOG(2, 8)`", List.of(Math.log(2), Math.log(8) / Math.log(2)));
1214+
}
1215+
1216+
@Test
1217+
public void testLog2() {
1218+
testMathPPL("source=people | eval `LOG2(8)` = LOG2(8) | fields `LOG2(8)`", List.of(Math.log(8) / Math.log(2)));
1219+
}
1220+
1221+
@Test
1222+
public void testLog10() {
1223+
testMathPPL("source=people | eval `LOG10(100)` = LOG10(100) | fields `LOG10(100)`", List.of(Math.log10(100)));
1224+
}
1225+
1226+
@Test
1227+
public void testMod() {
1228+
// TODO: There is a difference between MOD in OpenSearch and SQL standard library
1229+
// For MOD in Calcite, MOD(3.1, 2) = 1
1230+
testMathPPL(
1231+
"source=people | eval `MOD(3, 2)` = MOD(3, 2), `MOD(3.1, 2)` = MOD(3.1, 2) | fields `MOD(3, 2)`, `MOD(3.1, 2)`",
1232+
List.of(1, 1.1));
1233+
}
1234+
1235+
@Test
1236+
public void testPi() {
1237+
testMathPPL("source=people | eval `PI()` = PI() | fields `PI()`", List.of(Math.PI));
1238+
}
1239+
1240+
@Test
1241+
public void testPowAndPower() {
1242+
testMathPPL(
1243+
"source=people | eval `POW(3, 2)` = POW(3, 2), `POW(-3, 2)` = POW(-3, 2), `POW(3, -2)` = POW(3, -2) | fields `POW(3, 2)`, `POW(-3, 2)`, `POW(3, -2)`",
1244+
List.of(Math.pow(3, 2), Math.pow(-3, 2), Math.pow(3, -2)));
1245+
testMathPPL(
1246+
"source=people | eval `POWER(3, 2)` = POWER(3, 2), `POWER(-3, 2)` = POWER(-3, 2), `POWER(3, -2)` = POWER(3, -2) | fields `POWER(3, 2)`, `POWER(-3, 2)`, `POWER(3, -2)`",
1247+
List.of(Math.pow(3, 2), Math.pow(-3, 2), Math.pow(3, -2)));
1248+
}
1249+
1250+
@Test
1251+
public void testRadians() {
1252+
testMathPPL("source=people | eval `RADIANS(90)` = RADIANS(90) | fields `RADIANS(90)`", List.of(Math.toRadians(90)));
1253+
}
1254+
1255+
@Test
1256+
public void testRand() {
1257+
String randPpl = "source=people | eval `RAND(3)` = RAND(3) | fields `RAND(3)`";
1258+
String execResult1 = execute(randPpl);
1259+
String execResult2 = execute(randPpl);
1260+
assertEquals(execResult1, execResult2);
1261+
double val = parseAndGetFirstDataRow(execResult1).get(0).getAsDouble();
1262+
assertTrue(val >= 0 && val <= 1);
1263+
}
1264+
1265+
@Test
1266+
public void testRound() {
1267+
testMathPPL(
1268+
"source=people | eval `ROUND(12.34)` = ROUND(12.34), `ROUND(12.34, 1)` = ROUND(12.34, 1), `ROUND(12.34, -1)` = ROUND(12.34, -1), `ROUND(12, 1)` = ROUND(12, 1) | fields `ROUND(12.34)`, `ROUND(12.34, 1)`, `ROUND(12.34, -1)`, `ROUND(12, 1)`",
1269+
List.of(Math.round(12.34), Math.round(12.34 * 10) / 10.0, Math.round(12.34 / 10) * 10.0, Math.round(12.0 * 10) / 10.0)
1270+
);
1271+
}
1272+
1273+
@Test
1274+
public void testSign() {
1275+
testMathPPL(
1276+
"source=people | eval `SIGN(1)` = SIGN(1), `SIGN(0)` = SIGN(0), `SIGN(-1.1)` = SIGN(-1.1) | fields `SIGN(1)`, `SIGN(0)`, `SIGN(-1.1)`",
1277+
List.of(1, 0, -1)
1278+
);
1279+
}
1280+
1281+
@Test
1282+
public void testSin() {
1283+
testMathPPL("source=people | eval `SIN(0)` = SIN(0) | fields `SIN(0)`", List.of(Math.sin(0.0)));
1284+
}
1285+
1286+
@Test
1287+
public void testSqrt() {
1288+
testMathPPL("source=people | eval `SQRT(4)` = SQRT(4), `SQRT(4.41)` = SQRT(4.41) | fields `SQRT(4)`, `SQRT(4.41)`", List.of(Math.sqrt(4), Math.sqrt(4.41)));
1289+
}
1290+
1291+
@Test
1292+
public void testCbrt() {
1293+
testMathPPL(
1294+
"source=people | eval `CBRT(8)` = CBRT(8), `CBRT(9.261)` = CBRT(9.261), `CBRT(-27)` = CBRT(-27) | fields `CBRT(8)`, `CBRT(9.261)`, `CBRT(-27)`",
1295+
List.of(Math.cbrt(8), Math.cbrt(9.261), Math.cbrt(-27))
1296+
);
1297+
}
10671298
}

0 commit comments

Comments
 (0)