From 76a9aed5dfc6823c79d2c721d6332bdb6d8f8eaa Mon Sep 17 00:00:00 2001 From: shaomengwang Date: Fri, 26 Apr 2024 15:19:05 +0800 Subject: [PATCH] Merge the dev branch. --- .../common/insights/DistributionUtil.java | 144 +++++++++ .../common/insights/DvInsightDescription.java | 21 +- .../alink/common/insights/StatInsight.java | 36 ++- .../sql/builtin/BuiltInAggRegister.java | 21 ++ .../sql/builtin/string/string/DateAdd.java | 35 ++ .../sql/builtin/string/string/DateDiff.java | 58 ++++ .../sql/builtin/string/string/DateSub.java | 34 ++ .../sql/builtin/string/string/KeyValue.java | 29 ++ .../sql/builtin/string/string/RegExp.java | 28 ++ .../builtin/string/string/RegExpExtract.java | 31 ++ .../builtin/string/string/RegExpReplace.java | 28 ++ .../sql/builtin/string/string/SplitPart.java | 36 +++ .../batch/PipelinePredictBatchOp.java | 44 --- .../QuantileDiscretizerTrainBatchOp.java | 17 +- .../operator/batch/nlp/TfidfBatchOp.java | 10 +- .../NegativeItemSamplingBatchOp.java | 2 +- .../operator/batch/sql/SelectBatchOp.java | 38 +-- .../batch/statistics/QuantileBatchOp.java | 273 +++++++++------- .../common/optim/LocalFmOptimizer.java | 298 ------------------ .../operator/common/optim/LocalOptimizer.java | 7 +- .../common/sql/CalciteSelectMapper.java | 162 +++++++++- .../operator/common/sql/SelectMapper.java | 134 +++++++- .../operator/common/sql/SelectUtils.java | 37 ++- .../common/sql/functions/StringFunctions.java | 30 ++ .../basicstatistic/BaseSummary.java | 2 +- .../basicstatistic/TableSummary.java | 3 +- .../common/tree/seriestree/DecisionTree.java | 31 +- .../common/tree/seriestree/DenseData.java | 17 +- .../common/utils/PackBatchOperatorUtil.java | 10 +- .../local/sql/FullOuterJoinLocalOp.java | 47 --- .../alink/operator/local/sql/JoinLocalOp.java | 78 ----- .../local/sql/LeftOuterJoinLocalOp.java | 52 --- .../local/sql/RightOuterJoinLocalOp.java | 47 --- .../operator/local/sql/SelectLocalOp.java | 37 +-- .../stream/PipelinePredictStreamOp.java | 51 --- .../operator/stream/sql/SelectStreamOp.java | 40 +-- .../stream/statistics/QuantileStreamOp.java | 4 +- .../QuantileDiscretizerTrainParams.java | 4 +- .../statistics/QuantileBatchParams.java | 9 - .../params/statistics/QuantileParams.java | 12 +- .../statistics/QuantileStreamParams.java | 11 + .../batch/dataproc/SqlBatchOpsTest.java | 5 +- .../dataproc/TypeConvertBatchOpTest.java | 1 + .../operator/batch/sql/SelectBatchOpTest.java | 18 +- .../operator/common/sql/SelectUtilsTest.java | 16 + .../operator/local/sql/AllTypeOpTest.java | 46 +-- .../local/sql/BaseSqlApiLocalOpTest.java | 24 +- .../operator/local/sql/FilterLocalOpTest.java | 11 +- .../local/sql/FullOuterJoinLocalOpTest.java | 23 -- .../operator/local/sql/JoinLocalOpTest.java | 26 -- .../local/sql/LeftOuterJoinLocalOpTest.java | 27 -- .../local/sql/RightOuterJoinLocalOpTest.java | 24 -- .../operator/local/sql/SelectLocalOpTest.java | 121 +++++-- .../statistics/SummarizerLocalOpTest.java | 178 ++++------- .../dataproc/TypeConvertStreamOpTest.java | 2 +- .../pipeline/PipelinePredictBatchOpTest.java | 71 ----- 56 files changed, 1328 insertions(+), 1273 deletions(-) create mode 100644 core/src/main/java/com/alibaba/alink/common/insights/DistributionUtil.java create mode 100644 core/src/main/java/com/alibaba/alink/common/sql/builtin/string/string/DateAdd.java create mode 100644 core/src/main/java/com/alibaba/alink/common/sql/builtin/string/string/DateDiff.java create mode 100644 core/src/main/java/com/alibaba/alink/common/sql/builtin/string/string/DateSub.java create mode 100644 
core/src/main/java/com/alibaba/alink/common/sql/builtin/string/string/KeyValue.java create mode 100644 core/src/main/java/com/alibaba/alink/common/sql/builtin/string/string/RegExp.java create mode 100644 core/src/main/java/com/alibaba/alink/common/sql/builtin/string/string/RegExpExtract.java create mode 100644 core/src/main/java/com/alibaba/alink/common/sql/builtin/string/string/RegExpReplace.java create mode 100644 core/src/main/java/com/alibaba/alink/common/sql/builtin/string/string/SplitPart.java delete mode 100644 core/src/main/java/com/alibaba/alink/operator/batch/PipelinePredictBatchOp.java delete mode 100644 core/src/main/java/com/alibaba/alink/operator/common/optim/LocalFmOptimizer.java delete mode 100644 core/src/main/java/com/alibaba/alink/operator/local/sql/FullOuterJoinLocalOp.java delete mode 100644 core/src/main/java/com/alibaba/alink/operator/local/sql/JoinLocalOp.java delete mode 100644 core/src/main/java/com/alibaba/alink/operator/local/sql/LeftOuterJoinLocalOp.java delete mode 100644 core/src/main/java/com/alibaba/alink/operator/local/sql/RightOuterJoinLocalOp.java delete mode 100644 core/src/main/java/com/alibaba/alink/operator/stream/PipelinePredictStreamOp.java delete mode 100644 core/src/main/java/com/alibaba/alink/params/statistics/QuantileBatchParams.java create mode 100644 core/src/main/java/com/alibaba/alink/params/statistics/QuantileStreamParams.java delete mode 100644 core/src/test/java/com/alibaba/alink/operator/local/sql/FullOuterJoinLocalOpTest.java delete mode 100644 core/src/test/java/com/alibaba/alink/operator/local/sql/JoinLocalOpTest.java delete mode 100644 core/src/test/java/com/alibaba/alink/operator/local/sql/LeftOuterJoinLocalOpTest.java delete mode 100644 core/src/test/java/com/alibaba/alink/operator/local/sql/RightOuterJoinLocalOpTest.java delete mode 100644 core/src/test/java/com/alibaba/alink/pipeline/PipelinePredictBatchOpTest.java diff --git a/core/src/main/java/com/alibaba/alink/common/insights/DistributionUtil.java b/core/src/main/java/com/alibaba/alink/common/insights/DistributionUtil.java new file mode 100644 index 000000000..03e3cf6d1 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/common/insights/DistributionUtil.java @@ -0,0 +1,144 @@ +package com.alibaba.alink.common.insights; + +import org.apache.flink.api.java.tuple.Tuple3; + +import breeze.stats.distributions.LogNormal; +import org.apache.commons.math3.distribution.ChiSquaredDistribution; +import org.apache.commons.math3.distribution.ExponentialDistribution; +import org.apache.commons.math3.distribution.LogNormalDistribution; +import org.apache.commons.math3.distribution.LogisticDistribution; +import org.apache.commons.math3.distribution.NormalDistribution; +import org.apache.commons.math3.distribution.TDistribution; +import org.apache.commons.math3.distribution.UniformRealDistribution; +import org.apache.commons.math3.stat.inference.KolmogorovSmirnovTest; + +import java.util.Arrays; +import java.util.List; + +public class DistributionUtil { + + public static final KolmogorovSmirnovTest KS_TEST = new KolmogorovSmirnovTest(); + + public static double[] loadDataToVec(List dataList) { + double[] datas = new double[dataList.size()]; + for (int i = 0; i < dataList.size(); i++) { + datas[i] = Double.valueOf(String.valueOf(dataList.get(i))); + } + return datas; + } + + public static Tuple3 getMeanSd(List dataList) { + if (dataList.size() == 0) { + return Tuple3.of(0D, 0D, null); + } + double[] datas = loadDataToVec(dataList); + double avg = 
Arrays.stream(datas).average().getAsDouble(); + double variance = 0; + for (int i = 0; i < datas.length; i++) { + variance += Math.pow(datas[i] - avg, 2); + } + variance = variance / datas.length; + double sd = Math.sqrt(variance); + return Tuple3.of(avg, sd, datas); + } + + public static double testNormalDistribution(List dataList) { + if (dataList.size() == 0) { + return 0; + } + Tuple3 avgTuple = getMeanSd(dataList); + NormalDistribution distribution = new NormalDistribution(avgTuple.f0, avgTuple.f1); + return KS_TEST.kolmogorovSmirnovTest(distribution, avgTuple.f2); + } + + public static double testUniformDistribution(List dataList) { + if (dataList.size() == 0) { + return 0; + } + double[] datas = loadDataToVec(dataList); + double lower = Arrays.stream(datas).min().getAsDouble(); + double upper = Arrays.stream(datas).max().getAsDouble(); + + UniformRealDistribution distribution = new UniformRealDistribution(lower, upper); + return KS_TEST.kolmogorovSmirnovTest(distribution, datas); + } + + // if positive, return degreeOfFreedom + public static double testChiSquaredDistribution(List dataList) { + if (dataList.size() == 0) { + return 0; + } + double[] datas = loadDataToVec(dataList); + for (int i = 1; i <= 20; i++) { + ChiSquaredDistribution distribution = new ChiSquaredDistribution(i); + double p = KS_TEST.kolmogorovSmirnovTest(distribution, datas); + if (p > 0.05) { + return i; + } + } + return 0; + } + + public static double testExpDistribution(List dataList) { + if (dataList.size() == 0) { + return 0; + } + double[] datas = loadDataToVec(dataList); + double avg = Arrays.stream(datas).average().getAsDouble(); + ExponentialDistribution distribution = new ExponentialDistribution(avg); + return KS_TEST.kolmogorovSmirnovTest(distribution, datas); + } + + // if positive, return degreeOfFreedom + public static double testTDistribution(List dataList) { + if (dataList.size() == 0) { + return 0; + } + double[] datas = loadDataToVec(dataList); + for (int i = 1; i <= 20; i++) { + TDistribution distribution = new TDistribution(i); + double p = KS_TEST.kolmogorovSmirnovTest(distribution, datas); + if (p > 0.05) { + return i; + } + } + return 0; + } + + public static double testLogisticDistribution(List dataList) { + if (dataList.size() == 0) { + return 0; + } + Tuple3 tuple3 = getMeanSd(dataList); + double avg = tuple3.f0; + double sd = tuple3.f1; + double[] datas = tuple3.f2; + LogisticDistribution distribution = new LogisticDistribution(avg, sd); + return KS_TEST.kolmogorovSmirnovTest(distribution, datas); + } + + public static double testLogNormalDistribution(List dataList) { + if (dataList.size() == 0) { + return 0; + } + double[] datas = loadDataToVec(dataList); + double min = Arrays.stream(datas).min().getAsDouble(); + if (min <= 0) { + return 0; + } + double[] logValues = new double[datas.length]; + for (int i = 0; i < datas.length; i++) { + logValues[i] = Math.log(datas[i]); + } + double avg = Arrays.stream(logValues).average().getAsDouble(); + double var = 0; + for (int i = 0; i < logValues.length; i++) { + var += Math.pow((logValues[i] - avg), 2); + } + var = Math.sqrt(var / logValues.length); + LogNormalDistribution distribution = new LogNormalDistribution(avg, var); + return KS_TEST.kolmogorovSmirnovTest(distribution, datas); + } + + +} diff --git a/core/src/main/java/com/alibaba/alink/common/insights/DvInsightDescription.java b/core/src/main/java/com/alibaba/alink/common/insights/DvInsightDescription.java index facb0b1b1..7b07003e0 100644 --- 
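A minimal usage sketch of the new utility (not part of the patch); it calls testNormalDistribution and then repeats the same Kolmogorov-Smirnov check directly against commons-math3, which the class already depends on. The raw List mirrors the untyped signatures above:

    import com.alibaba.alink.common.insights.DistributionUtil;

    import org.apache.commons.math3.distribution.NormalDistribution;
    import org.apache.commons.math3.stat.inference.KolmogorovSmirnovTest;

    import java.util.Arrays;
    import java.util.List;

    public class DistributionUtilExample {
        public static void main(String[] args) {
            List sample = Arrays.asList(1.1, 0.9, 1.0, 1.2, 0.8, 1.05, 0.95);
            // Fits N(mean, sd) from the sample and returns the KS-test p-value;
            // a large p-value (e.g. > 0.05) means normality is not rejected.
            double p = DistributionUtil.testNormalDistribution(sample);
            System.out.println("p-value against fitted normal: " + p);
            // The same check, spelled out with commons-math3 directly:
            double[] data = DistributionUtil.loadDataToVec(sample);
            double mean = Arrays.stream(data).average().getAsDouble();
            double sd = Math.sqrt(Arrays.stream(data).map(v -> (v - mean) * (v - mean)).average().getAsDouble());
            double p2 = new KolmogorovSmirnovTest().kolmogorovSmirnovTest(new NormalDistribution(mean, sd), data);
            System.out.println("values agree: " + (Math.abs(p - p2) < 1e-12));
        }
    }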
a/core/src/main/java/com/alibaba/alink/common/insights/DvInsightDescription.java +++ b/core/src/main/java/com/alibaba/alink/common/insights/DvInsightDescription.java @@ -107,6 +107,10 @@ public static class View implements Serializable { } public static DvInsightDescription of(Insight insight) { + return DvInsightDescription.of(insight, new HashMap <>()); + } + + public static DvInsightDescription of(Insight insight, Map cnNamesMap) { InsightType type = insight.type; int colNum = insight.layout.data.getNumCol(); int rowNum = insight.layout.data.getNumRow(); @@ -122,7 +126,7 @@ public static DvInsightDescription of(Insight insight) { fields[i] = new Field(); fields[i].id = colNames[i]; fields[i].code = colNames[i]; - fields[i].alias = colName; + fields[i].alias = cnNamesMap.getOrDefault(colName, colName); fields[i].abstraction = new Abstraction(); if (i == 0) { fields[i].abstraction.aggregation = "COUNTDISTINCT"; @@ -144,7 +148,8 @@ public static DvInsightDescription of(Insight insight) { fields[i].id = colNames[i]; fields[i].code = colNames[i]; if (i != 0) { - fields[i].alias = insight.subject.measures.get(i - 1).colName; + String colName = insight.subject.measures.get(i - 1).colName; + fields[i].alias = cnNamesMap.getOrDefault(colName, colName); fields[i].abstraction = new Abstraction(); fields[i].abstraction.aggregation = insight.subject.measures.get(i - 1).aggr.getEnName(); } @@ -157,7 +162,8 @@ public static DvInsightDescription of(Insight insight) { if (i >= 1) { if (insight.subject != null && insight.subject.measures != null && i - 1 < insight.subject.measures.size()) { - fields[i].alias = insight.subject.measures.get(i - 1).colName; + String colName = insight.subject.measures.get(i - 1).colName; + fields[i].alias = cnNamesMap.getOrDefault(colName, colName); fields[i].abstraction = new Abstraction(); fields[i].abstraction.aggregation = insight.subject.measures.get(i - 1).aggr.getEnName(); } @@ -170,7 +176,8 @@ public static DvInsightDescription of(Insight insight) { fields[i].code = colNames[i]; if (i >= 1) { if (insight.subject != null && insight.subject.measures != null) { - fields[i].alias = insight.subject.measures.get(0).colName; + String colName = insight.subject.measures.get(0).colName; + fields[i].alias = cnNamesMap.getOrDefault(colName, colName); fields[i].abstraction = new Abstraction(); fields[i].abstraction.aggregation = insight.subject.measures.get(0).aggr.getEnName(); } @@ -184,7 +191,8 @@ public static DvInsightDescription of(Insight insight) { if (i >= 1) { if (insight.subject != null && insight.subject.measures != null && i - 1 < insight.subject.measures.size()) { - fields[i].alias = insight.subject.measures.get(i - 1).colName; + String colName = insight.subject.measures.get(i - 1).colName; + fields[i].alias = cnNamesMap.getOrDefault(colName, colName); fields[i].abstraction = new Abstraction(); fields[i].abstraction.aggregation = insight.subject.measures.get(i - 1).aggr.getEnName(); } @@ -197,7 +205,8 @@ public static DvInsightDescription of(Insight insight) { fields[i].code = colNames[i]; if (i == 1) { if (insight.subject != null && insight.subject.measures != null) { - fields[i].alias = insight.subject.measures.get(0).colName; + String colName = insight.subject.measures.get(0).colName; + fields[i].alias = cnNamesMap.getOrDefault(colName, colName); fields[i].abstraction = new Abstraction(); fields[i].abstraction.aggregation = insight.subject.measures.get(0).aggr.getEnName(); } diff --git a/core/src/main/java/com/alibaba/alink/common/insights/StatInsight.java 
b/core/src/main/java/com/alibaba/alink/common/insights/StatInsight.java index f88f73ba4..6fc3706b9 100644 --- a/core/src/main/java/com/alibaba/alink/common/insights/StatInsight.java +++ b/core/src/main/java/com/alibaba/alink/common/insights/StatInsight.java @@ -20,7 +20,41 @@ public class StatInsight { public static boolean isNumberType(TypeInformation type) { - return type.equals(Types.INT) || type.equals(Types.LONG) || type.equals(Types.DOUBLE) || type.equals(Types.FLOAT) || type.equals(Types.SHORT); + return type.equals(Types.INT) || type.equals(Types.LONG) || type.equals(Types.DOUBLE) || type.equals(Types.FLOAT) + || type.equals(Types.SHORT) || type.equals(Types.BIG_DEC) || type.equals(Types.BIG_INT); + } + + public static Insight basicStatForString(LocalOperator dataAggr, String colName) { + Insight insight = new Insight(); + insight.type = InsightType.BasicStat; + List list = dataAggr.getOutputTable().getRows(); + // colName, term, frequency + int distinct_count = 0; + long count = 0L; + List countList = new ArrayList <>(); + for (Row row : list) { + Object object = row.getField(0); + if (null == object) { + continue; + } + if (object.equals(colName)) { + distinct_count++; + count += Long.valueOf(String.valueOf(row.getField(2))); + countList.add((Integer) row.getField(2)); + } + } + LayoutData layoutData = new LayoutData(); + String schema = "distinct_count_value int, count_value long"; + Row row = new Row(2); + row.setField(0, distinct_count); + row.setField(1, count); + MTable mTable = new MTable(new Row[]{row}, schema); + layoutData.data = mTable; + layoutData.title = "数据列 " + colName + " 统计数据"; + layoutData.xAxis = colName; + insight.layout = layoutData; + insight.score = 0.8; + return insight; } public static Insight basicStat(LocalOperator dataAggr, String colName) { diff --git a/core/src/main/java/com/alibaba/alink/common/sql/builtin/BuiltInAggRegister.java b/core/src/main/java/com/alibaba/alink/common/sql/builtin/BuiltInAggRegister.java index 09f8920a5..6c81d6832 100644 --- a/core/src/main/java/com/alibaba/alink/common/sql/builtin/BuiltInAggRegister.java +++ b/core/src/main/java/com/alibaba/alink/common/sql/builtin/BuiltInAggRegister.java @@ -29,6 +29,14 @@ import com.alibaba.alink.common.sql.builtin.agg.TimeSeriesAgg; import com.alibaba.alink.common.sql.builtin.agg.VarPopUdaf; import com.alibaba.alink.common.sql.builtin.agg.VarSampUdaf; +import com.alibaba.alink.common.sql.builtin.string.string.DateAdd; +import com.alibaba.alink.common.sql.builtin.string.string.DateDiff; +import com.alibaba.alink.common.sql.builtin.string.string.DateSub; +import com.alibaba.alink.common.sql.builtin.string.string.KeyValue; +import com.alibaba.alink.common.sql.builtin.string.string.RegExp; +import com.alibaba.alink.common.sql.builtin.string.string.RegExpExtract; +import com.alibaba.alink.common.sql.builtin.string.string.RegExpReplace; +import com.alibaba.alink.common.sql.builtin.string.string.SplitPart; import com.alibaba.alink.common.sql.builtin.time.DataFormat; import com.alibaba.alink.common.sql.builtin.time.FromUnixTime; import com.alibaba.alink.common.sql.builtin.time.Now; @@ -47,6 +55,19 @@ public static void registerUdf(TableEnvironment env) { env.registerFunction("unix_timestamp", new UnixTimeStamp()); env.registerFunction("from_unixtime", new FromUnixTime()); env.registerFunction("date_format_ltz", new DataFormat()); + + env.registerFunction("split_part", new SplitPart()); + env.registerFunction("keyvalue", new KeyValue()); + env.registerFunction("datediff", new DateDiff()); + 
env.registerFunction("regexp_replace", new RegExpReplace()); + env.registerFunction("REGEXP_REPLACE", new RegExpReplace()); + env.registerFunction("regexp", new RegExp()); + env.registerFunction("REGEXP", new RegExp()); + env.registerFunction("regexp_extract", new RegExpExtract()); + env.registerFunction("REGEXP_EXTRACT", new RegExpExtract()); + env.registerFunction("DATE_ADD", new DateAdd()); + env.registerFunction("DATE_SUB", new DateSub()); + } public static void registerUdf(LocalOpCalciteSqlExecutor executor) { diff --git a/core/src/main/java/com/alibaba/alink/common/sql/builtin/string/string/DateAdd.java b/core/src/main/java/com/alibaba/alink/common/sql/builtin/string/string/DateAdd.java new file mode 100644 index 000000000..5c664fb02 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/common/sql/builtin/string/string/DateAdd.java @@ -0,0 +1,35 @@ +package com.alibaba.alink.common.sql.builtin.string.string; + +import org.apache.flink.table.functions.ScalarFunction; + +import java.sql.Timestamp; + +/** + * @author weibo zhao + */ +public class DateAdd extends ScalarFunction { + + private static final long serialVersionUID = -4716626352353627712L; + + public String eval(String end, int days) { + if (end == null) { + return null; + } + try { + Timestamp tsEnd = Timestamp.valueOf(end); + long ld = (tsEnd.getTime() + days * 86400000L); + return new Timestamp(ld).toString(); + } catch (Exception e) { + return null; + } + } + + public String eval(Timestamp end, int days) { + if (end == null) { + return null; + } + long ld = (end.getTime() + days * 86400000L); + return new Timestamp(ld).toString(); + + } +} diff --git a/core/src/main/java/com/alibaba/alink/common/sql/builtin/string/string/DateDiff.java b/core/src/main/java/com/alibaba/alink/common/sql/builtin/string/string/DateDiff.java new file mode 100644 index 000000000..c89a9e504 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/common/sql/builtin/string/string/DateDiff.java @@ -0,0 +1,58 @@ +package com.alibaba.alink.common.sql.builtin.string.string; + +import org.apache.flink.table.functions.ScalarFunction; + +import java.sql.Timestamp; + +/** + * @author weibo zhao + */ + +public class DateDiff extends ScalarFunction { + + private static final long serialVersionUID = 6298088633116239045L; + + public Long eval(String end, String start) { + if (start == null || end == null) { + return null; + } + try { + Timestamp tsEnd = Timestamp.valueOf(end); + Timestamp tsStart = Timestamp.valueOf(start); + return (tsEnd.getTime() - tsStart.getTime()) / 86400000L; + } catch (Exception e) { + return null; + } + } + + public Long eval(String end, Timestamp start) { + if (start == null || end == null) { + return null; + } + try { + Timestamp tsEnd = Timestamp.valueOf(end); + return (tsEnd.getTime() - start.getTime()) / 86400000L; + } catch (Exception e) { + return null; + } + } + + public Long eval(Timestamp end, Timestamp start) { + if (start == null || end == null) { + return null; + } + return (end.getTime() - start.getTime()) / 86400000L; + } + + public Long eval(Timestamp end, String start) { + if (start == null || end == null) { + return null; + } + try { + Timestamp tsStart = Timestamp.valueOf(start); + return (end.getTime() - tsStart.getTime()) / 86400000L; + } catch (Exception e) { + return null; + } + } +} diff --git a/core/src/main/java/com/alibaba/alink/common/sql/builtin/string/string/DateSub.java b/core/src/main/java/com/alibaba/alink/common/sql/builtin/string/string/DateSub.java new file mode 100644 index 000000000..df30644a7 
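For reference, a tiny driver (not part of the patch) showing what the two date UDFs above compute. Both do plain millisecond arithmetic at 86400000L ms per day, so calendar quirks such as daylight-saving shifts are ignored; date_sub, added just below, mirrors date_add with subtraction:

    import com.alibaba.alink.common.sql.builtin.string.string.DateAdd;
    import com.alibaba.alink.common.sql.builtin.string.string.DateDiff;

    public class DateUdfExample {
        public static void main(String[] args) {
            String day = "2024-04-26 00:00:00";
            // date_add: shift forward by whole days (Timestamp.toString keeps the ".0").
            System.out.println(new DateAdd().eval(day, 3));                      // 2024-04-29 00:00:00.0
            // datediff(end, start): whole days, fractional part truncated.
            System.out.println(new DateDiff().eval("2024-04-29 12:00:00", day)); // 3
            // Malformed input returns null rather than throwing.
            System.out.println(new DateAdd().eval("not-a-date", 1));             // null
        }
    }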
--- /dev/null +++ b/core/src/main/java/com/alibaba/alink/common/sql/builtin/string/string/DateSub.java @@ -0,0 +1,34 @@ +package com.alibaba.alink.common.sql.builtin.string.string; + +import org.apache.flink.table.functions.ScalarFunction; + +import java.sql.Timestamp; + +/** + * @author weibo zhao + */ +public class DateSub extends ScalarFunction { + private static final long serialVersionUID = -8304706070267129499L; + + public String eval(String end, int days) { + if (end == null) { + return null; + } + try { + Timestamp tsEnd = Timestamp.valueOf(end); + long ld = (tsEnd.getTime() - days * 86400000L); + return new Timestamp(ld).toString(); + } catch (Exception e) { + return null; + } + } + + public String eval(Timestamp end, int days) { + if (end == null) { + return null; + } + long ld = (end.getTime() - days * 86400000L); + return new Timestamp(ld).toString(); + + } +} diff --git a/core/src/main/java/com/alibaba/alink/common/sql/builtin/string/string/KeyValue.java b/core/src/main/java/com/alibaba/alink/common/sql/builtin/string/string/KeyValue.java new file mode 100644 index 000000000..67021ac01 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/common/sql/builtin/string/string/KeyValue.java @@ -0,0 +1,29 @@ +package com.alibaba.alink.common.sql.builtin.string.string; + +import org.apache.flink.table.functions.ScalarFunction; + +import org.apache.commons.lang3.StringUtils; + +/** + * Split srcStr (the source string) into "key-value" pairs by delimiter1, split each pair by delimiter2, and return the value corresponding to the given key. + * + * @author dota.zk + * @date 23/05/2018 + */ +public class KeyValue extends ScalarFunction { + + private static final long serialVersionUID = 772339384889917291L; + + public String eval(String src, String delimiter1, String delimiter2, String key) { + if (src == null || delimiter1 == null || delimiter2 == null || key == null) { + return null; + } + for (final String s : StringUtils.splitByWholeSeparator(src, delimiter1)) { + final String[] L = StringUtils.splitByWholeSeparatorPreserveAllTokens(s, delimiter2, 2); + if (L[0].equals(key)) { + return L.length == 2 ? 
L[1] : ""; + } + } + return null; + } +} diff --git a/core/src/main/java/com/alibaba/alink/common/sql/builtin/string/string/RegExp.java b/core/src/main/java/com/alibaba/alink/common/sql/builtin/string/string/RegExp.java new file mode 100644 index 000000000..c7b9f1abf --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/common/sql/builtin/string/string/RegExp.java @@ -0,0 +1,28 @@ +package com.alibaba.alink.common.sql.builtin.string.string; + +import org.apache.flink.table.functions.ScalarFunction; + +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +/** + * Check whether srcStr (the source string) contains a substring matched by pattern. + * + * @author weibo zhao + */ +public class RegExp extends ScalarFunction { + + private static final long serialVersionUID = -194627833036515975L; + + public boolean eval(String srcStr, String pattern) { + if (srcStr == null || pattern == null) { + return srcStr == null && pattern == null; + } + // create the Pattern object + Pattern r = Pattern.compile(pattern); + // create the Matcher object + Matcher m = r.matcher(srcStr); + + return m.find(); + } +} diff --git a/core/src/main/java/com/alibaba/alink/common/sql/builtin/string/string/RegExpExtract.java b/core/src/main/java/com/alibaba/alink/common/sql/builtin/string/string/RegExpExtract.java new file mode 100644 index 000000000..7b9a66aea --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/common/sql/builtin/string/string/RegExpExtract.java @@ -0,0 +1,31 @@ +package com.alibaba.alink.common.sql.builtin.string.string; + +import org.apache.flink.table.functions.ScalarFunction; + +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +/** + * Extract the substring of srcStr (the source string) captured by group idx of the first match of pattern. + * + * @author weibo zhao + */ +public class RegExpExtract extends ScalarFunction { + + private static final long serialVersionUID = -7829402026643969109L; + + public String eval(String srcStr, String pattern, int idx) { + String ret = null; + try { + // create the Pattern object + Pattern r = Pattern.compile(pattern); + // create the Matcher object + Matcher m = r.matcher(srcStr); + m.find(); + ret = m.group(idx); + } catch (Exception e) { + return null; + } + return ret; + } +} diff --git a/core/src/main/java/com/alibaba/alink/common/sql/builtin/string/string/RegExpReplace.java b/core/src/main/java/com/alibaba/alink/common/sql/builtin/string/string/RegExpReplace.java new file mode 100644 index 000000000..771d15c22 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/common/sql/builtin/string/string/RegExpReplace.java @@ -0,0 +1,28 @@ +package com.alibaba.alink.common.sql.builtin.string.string; + +import org.apache.flink.table.functions.ScalarFunction; + +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +/** + * Replace the substrings of srcStr (the source string) matched by pattern with replaceStr. + * + * @author weibo zhao + */ +public class RegExpReplace extends ScalarFunction { + + private static final long serialVersionUID = 6928900303963924422L; + + public String eval(String srcStr, String pattern, String replaceStr) { + if (srcStr == null || pattern == null || replaceStr == null) { + return null; + } + // create the Pattern object + Pattern r = Pattern.compile(pattern); + // create the Matcher object + Matcher m = r.matcher(srcStr); + + return m.replaceAll(replaceStr); + } +} diff --git a/core/src/main/java/com/alibaba/alink/common/sql/builtin/string/string/SplitPart.java b/core/src/main/java/com/alibaba/alink/common/sql/builtin/string/string/SplitPart.java new file mode 100644 index 000000000..8c72bec07 --- /dev/null +++ 
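A few input/output pairs (not part of the patch) for the keyvalue and regexp UDFs defined above:

    import com.alibaba.alink.common.sql.builtin.string.string.KeyValue;
    import com.alibaba.alink.common.sql.builtin.string.string.RegExp;
    import com.alibaba.alink.common.sql.builtin.string.string.RegExpExtract;
    import com.alibaba.alink.common.sql.builtin.string.string.RegExpReplace;

    public class StringUdfExample {
        public static void main(String[] args) {
            // keyvalue: split by ";" into pairs, then by ":" into key and value.
            System.out.println(new KeyValue().eval("k1:v1;k2:v2", ";", ":", "k2"));           // v2
            // regexp: true when the pattern matches anywhere in the source string.
            System.out.println(new RegExp().eval("abc123", "[0-9]+"));                        // true
            // regexp_extract: capture group idx of the first match.
            System.out.println(new RegExpExtract().eval("abc123def", "([a-z]+)([0-9]+)", 2)); // 123
            // regexp_replace: replace every match.
            System.out.println(new RegExpReplace().eval("a1b2c3", "[0-9]", "#"));             // a#b#c#
        }
    }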
b/core/src/main/java/com/alibaba/alink/common/sql/builtin/string/string/SplitPart.java @@ -0,0 +1,36 @@ +package com.alibaba.alink.common.sql.builtin.string.string; + +import org.apache.flink.table.functions.ScalarFunction; + +import org.apache.commons.lang3.StringUtils; + +/** + * @author dota.zk + * @reference http://odps.alibaba-inc.com/doc.htm SQL->udf->split_part + * @date 23/05/2018 + */ +public class SplitPart extends ScalarFunction { + + private static final long serialVersionUID = 3053182766744474677L; + + public String eval(String src, String delimiter, Integer nth) { + if (src == null || delimiter == null || nth == null) { return null; } + if (nth < 1) { return ""; } + final String[] L = StringUtils.splitByWholeSeparatorPreserveAllTokens(src, delimiter); + if (L.length < nth) { + return ""; + } + return L[nth - 1]; + } + + public String eval(String src, String delimiter, Integer start, Integer end) { + if (src == null || delimiter == null || start == null || end == null) { return null; } + if (delimiter.isEmpty()) { return src; } + if (start > end) { return ""; } + if (start < 0) { start = 1; } + final String[] L = StringUtils.splitByWholeSeparatorPreserveAllTokens(src, delimiter); + if (end > L.length) { end = L.length; } + if (start > end) { return ""; } + return StringUtils.join(L, delimiter, start - 1, end); + } +} diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/PipelinePredictBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/PipelinePredictBatchOp.java deleted file mode 100644 index a80e18c87..000000000 --- a/core/src/main/java/com/alibaba/alink/operator/batch/PipelinePredictBatchOp.java +++ /dev/null @@ -1,44 +0,0 @@ -package com.alibaba.alink.operator.batch; - -import org.apache.flink.ml.api.misc.param.Params; - -import com.alibaba.alink.common.annotation.NameCn; -import com.alibaba.alink.common.annotation.NameEn; -import com.alibaba.alink.common.exceptions.AkIllegalDataException; -import com.alibaba.alink.operator.batch.utils.DataSetConversionUtil; -import com.alibaba.alink.params.PipelinePredictParams; -import com.alibaba.alink.pipeline.PipelineModel; - -/** - * Pipeline prediction. - */ - -@NameCn("Pipeline 预测") -@NameEn("Pipeline prediction") -public final class PipelinePredictBatchOp extends BatchOperator - implements PipelinePredictParams { - - public PipelinePredictBatchOp() { - super(new Params()); - } - - public PipelinePredictBatchOp(Params params) { - super(params); - } - - @Override - public PipelinePredictBatchOp linkFrom(BatchOperator ... 
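Behavior of split_part above (not part of the patch): positions are 1-based, and out-of-range requests yield an empty string:

    import com.alibaba.alink.common.sql.builtin.string.string.SplitPart;

    public class SplitPartExample {
        public static void main(String[] args) {
            SplitPart f = new SplitPart();
            System.out.println(f.eval("a,b,c,d", ",", 2));    // b
            System.out.println(f.eval("a,b,c,d", ",", 9));    // "" (past the last token)
            // Range form: tokens start..end, joined back with the delimiter.
            System.out.println(f.eval("a,b,c,d", ",", 2, 3)); // b,c
        }
    }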
inputs) { - try { - BatchOperator data = checkAndGetFirst(inputs); - final PipelineModel pipelineModel = PipelineModel.load(getModelFilePath()) - .setMLEnvironmentId(data.getMLEnvironmentId()); - BatchOperator result = pipelineModel.transform(data); - this.setOutput(DataSetConversionUtil.toTable(data.getMLEnvironmentId(), - result.getDataSet(), result.getSchema())); - return this; - } catch (Exception ex) { - ex.printStackTrace(); - throw new AkIllegalDataException(ex.toString()); - } - } -} diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/feature/QuantileDiscretizerTrainBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/feature/QuantileDiscretizerTrainBatchOp.java index 00b201abc..3cd89f967 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/feature/QuantileDiscretizerTrainBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/feature/QuantileDiscretizerTrainBatchOp.java @@ -32,9 +32,9 @@ import com.alibaba.alink.common.exceptions.AkIllegalOperatorParameterException; import com.alibaba.alink.common.exceptions.AkIllegalStateException; import com.alibaba.alink.common.exceptions.AkUnclassifiedErrorException; -import com.alibaba.alink.operator.batch.utils.WithModelInfoBatchOp; import com.alibaba.alink.common.utils.TableUtil; import com.alibaba.alink.operator.batch.BatchOperator; +import com.alibaba.alink.operator.batch.utils.WithModelInfoBatchOp; import com.alibaba.alink.operator.common.dataproc.SortUtils; import com.alibaba.alink.operator.common.dataproc.SortUtilsNext; import com.alibaba.alink.operator.common.feature.ContinuousRanges; @@ -378,16 +378,24 @@ public QuantileDiscretizerModelInfoBatchOp getModelInfoBatchOp() { public static class MultiQuantile extends RichMapPartitionFunction > { private static final long serialVersionUID = -467677491431226184L; + private final HasRoundMode.RoundMode roundType; + private final boolean quantileIncludesBoundary; protected int[] quantileNum; private List > counts; private List > missingCounts; private long totalCnt = 0; - private HasRoundMode.RoundMode roundType; private int taskId; public MultiQuantile(int[] quantileNum, HasRoundMode.RoundMode roundType) { this.quantileNum = quantileNum; this.roundType = roundType; + this.quantileIncludesBoundary = false; + } + + public MultiQuantile(int[] quantileNum, HasRoundMode.RoundMode roundType, boolean quantileIncludesBoundary) { + this.quantileNum = quantileNum; + this.roundType = roundType; + this.quantileIncludesBoundary = quantileIncludesBoundary; } @Override @@ -519,7 +527,10 @@ public void mapPartition(Iterable values, Collector = subStart && index < subEnd) { diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/nlp/TfidfBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/nlp/TfidfBatchOp.java index 43e6958f2..db06be28f 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/nlp/TfidfBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/nlp/TfidfBatchOp.java @@ -59,9 +59,10 @@ public TfidfBatchOp linkFrom(BatchOperator ... 
inputs) { // Count doc and word count in a doc final BatchOperator docStat = in.groupBy(docIdColName, docIdColName + ",sum(" + countColName + ") as total_word_count"); + // Count total word count of words BatchOperator wordStat = in.groupBy(wordColName + "," + docIdColName, - wordColName + "," + docIdColName + ",COUNT(1 ) as tmp_count") + wordColName + "," + docIdColName + ",COUNT(1 ) as tmp_count") .groupBy(wordColName, wordColName + ",count(1) as doc_cnt"); final String tmpColNames = docIdColName + "," + wordColName + "," + countColName + "," + "total_word_count"; @@ -83,7 +84,11 @@ public TfidfBatchOp linkFrom(BatchOperator ... inputs) { // Count tf-idf results of words in docs this.setOutput(join2 .getDataSet() - .join(docStat.select("1 as id,count(1) as total_doc_count").getDataSet() + .join(docStat + // select count(1) alone is not supported now. + .select("1 as tmpId, *") + .groupBy("tmpId", "1 as id,count(1) as total_doc_count") + .getDataSet() , JoinOperatorBase.JoinHint.BROADCAST_HASH_SECOND) .where("id1").equalTo("id") .map(new MapFunction , Row>() { @@ -114,7 +119,6 @@ public Row map(Tuple2 rowRowTuple2) throws Exception { }), tmpColNames2.split(","), new TypeInformation [] {types[docIdIndex], Types.STRING, Types.LONG, Types.LONG, Types.LONG, Types.LONG, Types.DOUBLE, Types.DOUBLE, Types.DOUBLE}); - ; return this; } } diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/NegativeItemSamplingBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/NegativeItemSamplingBatchOp.java index efc3a4d0c..5d90ee049 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/NegativeItemSamplingBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/recommendation/NegativeItemSamplingBatchOp.java @@ -118,7 +118,7 @@ public Tuple3 map(Tuple2 value) { .reduceGroup( new RichGroupReduceFunction , Tuple3 >() { private static final long serialVersionUID = 306722066512456784L; - transient List candidates; + transient List candidates; transient Random random; @Override diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/sql/SelectBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/sql/SelectBatchOp.java index 14e2ba212..eca4d100d 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/sql/SelectBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/sql/SelectBatchOp.java @@ -7,6 +7,7 @@ import com.alibaba.alink.common.annotation.NameEn; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.operator.batch.utils.MapBatchOp; +import com.alibaba.alink.operator.common.sql.SelectMapper; import com.alibaba.alink.operator.common.sql.SelectUtils; import com.alibaba.alink.operator.common.sql.SimpleSelectMapper; import com.alibaba.alink.params.sql.SelectParams; @@ -16,7 +17,7 @@ */ @NameCn("SQL操作:Select") @NameEn("SQL Select Operation") -public final class SelectBatchOp extends BaseSqlApiBatchOp +public final class SelectBatchOp extends MapBatchOp implements SelectParams { private static final long serialVersionUID = -1867376056670775636L; @@ -30,40 +31,7 @@ public SelectBatchOp(String clause) { } public SelectBatchOp(Params params) { - super(params); - } - - @Override - public SelectBatchOp linkFrom(BatchOperator ... 
inputs) { - BatchOperator in = checkAndGetFirst(inputs); - String[] colNames = in.getColNames(); - - String clause = getClause(); - String newClause = SelectUtils.convertRegexClause2ColNames(colNames, clause); - - if (SelectUtils.isSimpleSelect(newClause, colNames)) { - this.setOutputTable( - in.link(new SimpleSelectBatchOp() - .setClause(newClause) - .setMLEnvironmentId(in.getMLEnvironmentId()) - ).getOutputTable()); - } else { - this.setOutputTable(BatchSqlOperators.select(in, newClause).getOutputTable()); - } - return this; - } - - @Internal - private class SimpleSelectBatchOp extends MapBatchOp - implements SelectParams { - - public SimpleSelectBatchOp() { - this(null); - } - - public SimpleSelectBatchOp(Params param) { - super(SimpleSelectMapper::new, param); - } + super(SelectMapper::new, params); } } \ No newline at end of file diff --git a/core/src/main/java/com/alibaba/alink/operator/batch/statistics/QuantileBatchOp.java b/core/src/main/java/com/alibaba/alink/operator/batch/statistics/QuantileBatchOp.java index 58845559f..afc8cb2a1 100644 --- a/core/src/main/java/com/alibaba/alink/operator/batch/statistics/QuantileBatchOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/batch/statistics/QuantileBatchOp.java @@ -1,14 +1,16 @@ package com.alibaba.alink.operator.batch.statistics; import org.apache.flink.api.common.functions.BroadcastVariableInitializer; -import org.apache.flink.api.common.functions.RichGroupReduceFunction; -import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.api.common.functions.MapPartitionFunction; +import org.apache.flink.api.common.functions.RichMapPartitionFunction; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.java.DataSet; import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.tuple.Tuple4; +import org.apache.flink.api.java.utils.DataSetUtils; import org.apache.flink.configuration.Configuration; import org.apache.flink.ml.api.misc.param.Params; -import org.apache.flink.table.api.TableSchema; import org.apache.flink.types.Row; import org.apache.flink.util.Collector; @@ -20,30 +22,37 @@ import com.alibaba.alink.common.annotation.PortSpec; import com.alibaba.alink.common.annotation.PortType; import com.alibaba.alink.common.annotation.TypeCollections; -import com.alibaba.alink.common.utils.RowUtil; +import com.alibaba.alink.common.exceptions.AkIllegalArgumentException; +import com.alibaba.alink.common.type.AlinkTypes; import com.alibaba.alink.common.utils.TableUtil; import com.alibaba.alink.operator.batch.BatchOperator; -import com.alibaba.alink.operator.batch.feature.QuantileDiscretizerTrainBatchOp; -import com.alibaba.alink.operator.common.dataproc.SortUtils; -import com.alibaba.alink.params.statistics.QuantileBatchParams; +import com.alibaba.alink.operator.batch.feature.QuantileDiscretizerTrainBatchOp.MultiQuantile; +import com.alibaba.alink.operator.common.feature.quantile.PairComparable; +import com.alibaba.alink.operator.common.tree.Preprocessing; +import com.alibaba.alink.params.shared.colname.HasSelectedColDefaultAsNull; +import com.alibaba.alink.params.shared.colname.HasSelectedColsDefaultAsNull; +import com.alibaba.alink.params.statistics.HasRoundMode; +import com.alibaba.alink.params.statistics.QuantileParams; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import java.util.ArrayList; -import java.util.Collections; +import java.util.Arrays; import java.util.Comparator; import 
java.util.List; +import java.util.Objects; + +import static com.alibaba.alink.operator.batch.feature.QuantileDiscretizerTrainBatchOp.quantilePreparing; /** - * In statistics and probability quantiles are cut points dividing - * the range of a probability distribution into contiguous intervals - * with equal probabilities, or dividing the observations in a sample - * in the same way. + * In statistics and probability quantiles are cut points dividing the range of a probability distribution into + * contiguous intervals with equal probabilities, or dividing the observations in a sample in the same way. * (https://en.wikipedia.org/wiki/Quantile) *

* reference: Yang, X. (2014). Chong gou da shu ju tong ji (1st ed., pp. 25-29). *

- * Note: This algorithm is improved on the base of the parallel - * sorting by regular sampling(PSRS). The following step is added - * to the PSRS + * Note: This algorithm is improved on the base of the parallel sorting by regular sampling(PSRS). The following step is + * added to the PSRS *

    * <ul>
  * <li>replace (val) with (val, task id) to distinguish the same value on different machines</li>
  * </ul> @@ -55,14 +64,16 @@ */ @InputPorts(values = @PortSpec(PortType.DATA)) @OutputPorts(values = @PortSpec(PortType.DATA)) -@ParamSelectColumnSpec(name = "selectedCol", allowedTypeCollections = TypeCollections.NUMERIC_TYPES) +@ParamSelectColumnSpec(name = "selectedCols", allowedTypeCollections = TypeCollections.NUMERIC_TYPES) @NameCn("分位数") @NameEn("Quantile") public final class QuantileBatchOp extends BatchOperator - implements QuantileBatchParams { + implements QuantileParams { private static final long serialVersionUID = -86119177892147044L; + private static final Logger LOG = LoggerFactory.getLogger(QuantileBatchOp.class); + public QuantileBatchOp() { super(null); } @@ -75,141 +86,175 @@ public QuantileBatchOp(Params params) { public QuantileBatchOp linkFrom(BatchOperator ... inputs) { BatchOperator in = checkAndGetFirst(inputs); - TableSchema tableSchema = in.getSchema(); + String[] quantileColNames = getParams().get(HasSelectedColsDefaultAsNull.SELECTED_COLS); - String quantileColName = getSelectedCol(); + if (quantileColNames == null) { - int index = TableUtil.findColIndexWithAssertAndHint(tableSchema.getFieldNames(), quantileColName); + if (getParams().get(HasSelectedColDefaultAsNull.SELECTED_COL) == null) { + throw new AkIllegalArgumentException("One or more columns must be selected in the quantile batch op."); + } + + quantileColNames = new String[] {getParams().get(HasSelectedColDefaultAsNull.SELECTED_COL)}; + } - /* filter the selected column from input */ - DataSet input = in.select(quantileColName).getDataSet(); + TypeInformation [] quantileColTypes + = TableUtil.findColTypesWithAssertAndHint(in.getSchema(), quantileColNames); - /* sort data */ - Tuple2 >, DataSet >> sortedData - = SortUtils.pSort(input, 0); + final int[] quantileNum = new int[quantileColNames.length]; + Arrays.fill(quantileNum, getQuantileNum()); + + final HasRoundMode.RoundMode roundMode = getRoundMode(); + + DataSet input = Preprocessing.select(in, quantileColNames).getDataSet(); + + Tuple4 , DataSet >, + DataSet , DataSet >> quantileData = + quantilePreparing(input, getParams().get(Preprocessing.ZERO_AS_MISSING)); /* calculate quantile */ - DataSet quantile = sortedData.f0. 
- groupBy(0) - .reduceGroup(new Quantile( - 0, getQuantileNum(), - getRoundMode())) - .withBroadcastSet(sortedData.f1, "counts"); + DataSet > quantileTuple = quantileData.f0 + .mapPartition(new MultiQuantile(quantileNum, roundMode, true)) + .withBroadcastSet(quantileData.f1, "counts") + .withBroadcastSet(quantileData.f2, "totalCnt") + .withBroadcastSet(quantileData.f3, "missingCounts"); + + DataSet > indexedQuantile = quantileTuple + .partitionByHash(0) + .mapPartition(new GiveQuantileIndices()); + + DataSet quantile = indexedQuantile + .partitionByHash(2) + .mapPartition(new SerializeOutput(quantileNum)) + .withBroadcastSet(DataSetUtils.countElementsPerPartition(indexedQuantile).sum(1), "count"); + + String[] outputColNames = new String[quantileColNames.length + 1]; + TypeInformation [] outputColTypes = new TypeInformation [quantileColNames.length + 1]; + + outputColNames[0] = "quantile"; + outputColTypes[0] = AlinkTypes.LONG; + + System.arraycopy(quantileColNames, 0, outputColNames, 1, quantileColNames.length); + System.arraycopy(quantileColTypes, 0, outputColTypes, 1, quantileColNames.length); /* set output */ - setOutput(quantile, - new String[] {tableSchema.getFieldNames()[index], "quantile"}, - new TypeInformation [] {tableSchema.getFieldTypes()[index], BasicTypeInfo.LONG_TYPE_INFO}); + setOutput(quantile, outputColNames, outputColTypes); return this; } - /** - * - */ - public static class Quantile extends RichGroupReduceFunction , Row> { - private static final long serialVersionUID = -6101513604891658021L; - private int index; - private List > counts; - private long countSum = 0; - private int quantileNum; - private RoundMode roundType; - - public Quantile(int index, int quantileNum, RoundMode roundType) { - this.index = index; - this.quantileNum = quantileNum; - this.roundType = roundType; - } + private static class SerializeOutput + extends RichMapPartitionFunction , Row> { + private final int[] quantileNum; + private long cnt; + + public SerializeOutput(int[] quantileNum) {this.quantileNum = quantileNum;} @Override public void open(Configuration parameters) throws Exception { - this.counts = getRuntimeContext().getBroadcastVariableWithInitializer( - "counts", - new BroadcastVariableInitializer , List >>() { - @Override - public List > initializeBroadcastVariable( - Iterable > data) { - // sort the list by task id to calculate the correct offset - List > sortedData = new ArrayList <>(); - for (Tuple2 datum : data) { - sortedData.add(datum); + super.open(parameters); + + cnt = getRuntimeContext() + .getBroadcastVariableWithInitializer( + "count", + new BroadcastVariableInitializer , Long>() { + @Override + public Long initializeBroadcastVariable(Iterable > data) { + return data.iterator().next().f1; } - Collections.sort(sortedData, new Comparator >() { - @Override - public int compare(Tuple2 o1, Tuple2 o2) { - return o1.f0.compareTo(o2.f0); - } - }); - - return sortedData; } - }); - - for (int i = 0; i < this.counts.size(); ++i) { - countSum += this.counts.get(i).f1; - } + ); } @Override - public void reduce(Iterable > values, Collector out) throws Exception { - ArrayList allRows = new ArrayList <>(); - int id = -1; - long start = 0; - long end = 0; - - for (Tuple2 value : values) { - id = value.f0; - allRows.add(Row.copy(value.f1)); + public void mapPartition(Iterable > values, Collector out) { + if (cnt <= 0L) { + if (getRuntimeContext().getIndexOfThisSubtask() == 0) { + Row row = new Row(quantileNum.length + 1); + for (int i = 0; i <= quantileNum[0]; ++i) { + row.setField(0, 
(long) i); + for (int j = 0; j < quantileNum.length; ++j) { + row.setField(j + 1, null); + } + out.collect(row); + } + } + + return; } - if (id < 0) { - throw new Exception("Error key. key: " + id); + List > buffer = new ArrayList <>(); + + for (Tuple3 value : values) { + buffer.add(value); } - int curListIndex = -1; - int size = counts.size(); + if (buffer.isEmpty()) { + return; + } - for (int i = 0; i < size; ++i) { - int curId = counts.get(i).f0; + buffer.sort(Comparator.comparing(o -> o.f2)); + Row row = new Row(quantileNum.length + 1); + int size = buffer.size(); - if (curId == id) { - curListIndex = i; - break; + for (int i = 0; i < size + 1; ++i) { + Tuple3 prev; + Tuple3 next; + + if (i == 0) { + prev = buffer.get(i); + } else { + prev = buffer.get(i - 1); } - if (curId > id) { - throw new Exception("Error curId: " + curId - + ". id: " + id); + if (i == size) { + next = buffer.get(i - 1); + } else { + next = buffer.get(i); } - start += counts.get(i).f1; - } + if (i == size || !Objects.equals(next.f2, prev.f2)) { + out.collect(row); + } - end = start + counts.get(curListIndex).f1; + if (i == 0 || !Objects.equals(next.f2, prev.f2)) { + row.setField(0, Long.valueOf(next.f2)); + for (int j = 0; j < quantileNum.length; ++j) { + row.setField(j + 1, null); + } + } - if (allRows.size() != end - start) { - throw new Exception("Error start end." - + " start: " + start - + ". end: " + end - + ". size: " + allRows.size()); + row.setField(next.f0 + 1, next.f1); } + } + } - SortUtils.RowComparator rowComparator = new SortUtils.RowComparator(this.index); - Collections.sort(allRows, rowComparator); - - QuantileDiscretizerTrainBatchOp.QIndex qIndex = new QuantileDiscretizerTrainBatchOp.QIndex( - countSum, quantileNum, roundType); + private static class GiveQuantileIndices + implements MapPartitionFunction , Tuple3 > { + @Override + public void mapPartition(Iterable > values, + Collector > out) { + List buffer = new ArrayList <>(); + + for (Tuple2 value : values) { + PairComparable pairComparable = new PairComparable(); + pairComparable.first = value.f0; + pairComparable.second = value.f1; + buffer.add(pairComparable); + } - for (int i = 0; i <= quantileNum; ++i) { - long index = qIndex.genIndex(i); + buffer.sort(PairComparable::compareTo); - if (index >= start && index < end) { - out.collect( - RowUtil.merge(allRows.get((int) (index - start)), Long.valueOf(i))); + int latest = -1; + int quantile = 0; + for (PairComparable pairComparable : buffer) { + if (latest != pairComparable.first) { + latest = pairComparable.first; + quantile = 0; } + out.collect(Tuple3.of(pairComparable.first, pairComparable.second, quantile)); + quantile++; } - } + } } - } diff --git a/core/src/main/java/com/alibaba/alink/operator/common/optim/LocalFmOptimizer.java b/core/src/main/java/com/alibaba/alink/operator/common/optim/LocalFmOptimizer.java deleted file mode 100644 index e49a1b4bd..000000000 --- a/core/src/main/java/com/alibaba/alink/operator/common/optim/LocalFmOptimizer.java +++ /dev/null @@ -1,298 +0,0 @@ -package com.alibaba.alink.operator.common.optim; - -import org.apache.flink.api.java.tuple.Tuple2; -import org.apache.flink.api.java.tuple.Tuple3; -import org.apache.flink.ml.api.misc.param.Params; - -import com.alibaba.alink.common.linalg.DenseVector; -import com.alibaba.alink.common.linalg.SparseVector; -import com.alibaba.alink.common.linalg.Vector; -import com.alibaba.alink.common.model.ModelParamName; -import com.alibaba.alink.operator.common.fm.BaseFmTrainBatchOp.FmDataFormat; -import 
com.alibaba.alink.operator.common.fm.BaseFmTrainBatchOp.LogitLoss; -import com.alibaba.alink.operator.common.fm.BaseFmTrainBatchOp.LossFunction; -import com.alibaba.alink.operator.common.fm.BaseFmTrainBatchOp.SquareLoss; -import com.alibaba.alink.operator.common.fm.BaseFmTrainBatchOp.Task; -import com.alibaba.alink.params.recommendation.FmTrainParams; - -import java.util.Arrays; -import java.util.Comparator; -import java.util.List; - -/** - * Local fm optimizer. - */ -public class LocalFmOptimizer { - private final List > trainData; - private final int[] dim; - protected FmDataFormat fmModel = null; - private final double[] lambda; - - private FmDataFormat sigmaGii; - private final double learnRate; - private final LossFunction lossFunc; - private final int numEpochs; - private final Task task; - - private final double[] y; - private final double[] loss; - private final double[] vx; - private final double[] v2x2; - private long oldTime; - private final double[] lossCurve; - private double oldLoss = 1.0; - - /** - * construct function. - * - * @param trainData train data. - * @param params parameters for optimizer. - */ - public LocalFmOptimizer(List > trainData, Params params) { - this.numEpochs = params.get(FmTrainParams.NUM_EPOCHS); - this.trainData = trainData; - this.y = new double[trainData.size()]; - this.loss = new double[4]; - this.dim = new int[3]; - dim[0] = params.get(FmTrainParams.WITH_INTERCEPT) ? 1 : 0; - dim[1] = params.get(FmTrainParams.WITH_LINEAR_ITEM) ? 1 : 0; - dim[2] = params.get(FmTrainParams.NUM_FACTOR); - vx = new double[dim[2]]; - v2x2 = new double[dim[2]]; - this.lambda = new double[3]; - lambda[0] = params.get(FmTrainParams.LAMBDA_0); - lambda[1] = params.get(FmTrainParams.LAMBDA_1); - lambda[2] = params.get(FmTrainParams.LAMBDA_2); - task = params.get(ModelParamName.TASK); - this.learnRate = params.get(FmTrainParams.LEARN_RATE); - oldTime = System.currentTimeMillis(); - if (task.equals(Task.REGRESSION)) { - double minTarget = -1.0e20; - double maxTarget = 1.0e20; - double d = maxTarget - minTarget; - d = Math.max(d, 1.0); - maxTarget = maxTarget + d * 0.2; - minTarget = minTarget - d * 0.2; - lossFunc = new SquareLoss(maxTarget, minTarget); - } else { - lossFunc = new LogitLoss(); - } - - lossCurve = new double[numEpochs * 3]; - } - - /** - * initialize fmModel. - */ - public void setWithInitFactors(FmDataFormat model) { - this.fmModel = model; - int vectorSize = fmModel.factors.length; - sigmaGii = new FmDataFormat(vectorSize, dim, 0.0); - } - - /** - * optimize Fm problem. - * - * @return fm model. - */ - public Tuple2 optimize() { - for (int i = 0; i < numEpochs; ++i) { - updateFactors(); - calcLossAndEvaluation(); - if (termination(i)) { - break; - } - } - return Tuple2.of(fmModel, lossCurve); - } - - /** - * Termination function of fm iteration. 
- */ - public boolean termination(int step) { - lossCurve[3 * step] = loss[0] / loss[1]; - lossCurve[3 * step + 2] = loss[3] / loss[1]; - if (task.equals(Task.BINARY_CLASSIFICATION)) { - lossCurve[3 * step + 1] = loss[2]; - - System.out.println("step : " + step + " loss : " - + loss[0] / loss[1] + " auc : " + loss[2] + " accuracy : " - + loss[3] / loss[1] + " time : " + (System.currentTimeMillis() - - oldTime)); - } else { - lossCurve[3 * step + 1] = loss[2] / loss[1]; - System.out.println("step : " + step + " loss : " - + loss[0] / loss[1] + " mae : " + loss[2] / loss[1] + " mse : " - + loss[3] / loss[1] + " time : " + (System.currentTimeMillis() - - oldTime)); - } - oldTime = System.currentTimeMillis(); - if (Math.abs(oldLoss - loss[0] / loss[1]) / oldLoss < 1.0e-6) { - oldLoss = loss[0] / loss[1]; - return true; - } else { - oldLoss = loss[0] / loss[1]; - return false; - } - } - - /** - * Calculate loss and evaluations. - */ - public void calcLossAndEvaluation() { - double lossSum = 0.; - for (int i = 0; i < y.length; i++) { - double yTruth = trainData.get(i).f1; - double l = lossFunc.l(yTruth, y[i]); - lossSum += l; - } - - if (this.task.equals(Task.REGRESSION)) { - double mae = 0.0; - double mse = 0.0; - for (int i = 0; i < y.length; i++) { - double yDiff = y[i] - trainData.get(i).f1; - mae += Math.abs(yDiff); - mse += yDiff * yDiff; - } - loss[2] = mae; - loss[3] = mse; - } else { - Integer[] order = new Integer[y.length]; - double correctNum = 0.0; - for (int i = 0; i < y.length; i++) { - order[i] = i; - if (y[i] > 0 && trainData.get(i).f1 > 0.5) { - correctNum += 1.0; - } - if (y[i] < 0 && trainData.get(i).f1 < 0.5) { - correctNum += 1.0; - } - } - Arrays.sort(order, Comparator.comparingDouble(o -> y[o])); - int mSum = 0; - int nSum = 0; - double posRankSum = 0.; - for (int i = 0; i < order.length; i++) { - int sampleId = order[i]; - int rank = i + 1; - boolean isPositiveSample = trainData.get(sampleId).f1 > 0.5; - if (isPositiveSample) { - mSum++; - posRankSum += rank; - } else { - nSum++; - } - } - if (mSum != 0 && nSum != 0) { - double auc = (posRankSum - 0.5 * mSum * (mSum + 1.0)) / ((double) mSum * (double) nSum); - loss[2] = auc; - } else { - loss[2] = 0.0; - } - loss[3] = correctNum; - } - loss[0] = lossSum; - loss[1] = y.length; - } - - private void updateFactors() { - for (int i1 = 0; i1 < trainData.size(); ++i1) { - Tuple3 sample = trainData.get(i1); - Vector vec = sample.f2; - Tuple2 yVx = calcY(vec, fmModel, dim); - y[i1] = yVx.f0; - - double yTruth = sample.f1; - double dldy = lossFunc.dldy(yTruth, yVx.f0); - - int[] indices; - double[] vals; - if (sample.f2 instanceof SparseVector) { - indices = ((SparseVector) sample.f2).getIndices(); - vals = ((SparseVector) sample.f2).getValues(); - } else { - indices = new int[sample.f2.size()]; - for (int i = 0; i < sample.f2.size(); ++i) { - indices[i] = i; - } - vals = ((DenseVector) sample.f2).getData(); - } - double localLearnRate = sample.f0 * learnRate; - - double eps = 1.0e-8; - if (dim[0] > 0) { - double grad = dldy + lambda[0] * fmModel.bias; - sigmaGii.bias += grad * grad; - fmModel.bias -= localLearnRate * grad / (Math.sqrt(sigmaGii.bias + eps)); - } - - for (int i = 0; i < indices.length; ++i) { - int idx = indices[i]; - // update fmModel - for (int j = 0; j < dim[2]; j++) { - double vixi = vals[i] * fmModel.factors[idx][j]; - double d = vals[i] * (yVx.f1[j] - vixi); - double grad = dldy * d + lambda[2] * fmModel.factors[idx][j]; - sigmaGii.factors[idx][j] += grad * grad; - fmModel.factors[idx][j] -= localLearnRate * 
grad / (Math.sqrt(sigmaGii.factors[idx][j] + eps)); - } - if (dim[1] > 0) { - double grad = dldy * vals[i] + lambda[1] * fmModel.factors[idx][dim[2]]; - sigmaGii.factors[idx][dim[2]] += grad * grad; - fmModel.factors[idx][dim[2]] - -= grad * localLearnRate / (Math.sqrt(sigmaGii.factors[idx][dim[2]]+ eps)); - } - } - } - } - - /** - * calculate the value of y with given fm model. - */ - private Tuple2 calcY(Vector vec, FmDataFormat fmModel, int[] dim) { - int[] featureIds; - double[] featureValues; - if (vec instanceof SparseVector) { - featureIds = ((SparseVector) vec).getIndices(); - featureValues = ((SparseVector) vec).getValues(); - } else { - featureIds = new int[vec.size()]; - for (int i = 0; i < vec.size(); ++i) { - featureIds[i] = i; - } - featureValues = ((DenseVector) vec).getData(); - } - - Arrays.fill(vx, 0.0); - Arrays.fill(v2x2, 0.0); - - // (1) compute y - double y = 0.; - - if (dim[0] > 0) { - y += fmModel.bias; - } - - for (int i = 0; i < featureIds.length; i++) { - int featurePos = featureIds[i]; - double x = featureValues[i]; - - // the linear term - if (dim[1] > 0) { - y += x * fmModel.factors[featurePos][dim[2]]; - } - // the quadratic term - for (int j = 0; j < dim[2]; j++) { - double vixi = x * fmModel.factors[featurePos][j]; - vx[j] += vixi; - v2x2[j] += vixi * vixi; - } - } - - for (int i = 0; i < dim[2]; i++) { - y += 0.5 * (vx[i] * vx[i] - v2x2[i]); - } - return Tuple2.of(y, vx); - } -} \ No newline at end of file diff --git a/core/src/main/java/com/alibaba/alink/operator/common/optim/LocalOptimizer.java b/core/src/main/java/com/alibaba/alink/operator/common/optim/LocalOptimizer.java index e65ccf7ff..8f4efdcd0 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/optim/LocalOptimizer.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/optim/LocalOptimizer.java @@ -15,7 +15,6 @@ import com.alibaba.alink.operator.local.LocalOperator; import com.alibaba.alink.params.regression.HasEpsilon; import com.alibaba.alink.params.shared.HasNumCorrections_30; -import com.alibaba.alink.params.shared.HasNumThreads; import com.alibaba.alink.params.shared.iter.HasMaxIterDefaultAs100; import com.alibaba.alink.params.shared.linear.HasEpsilonDefaultAs0000001; import com.alibaba.alink.params.shared.linear.HasL1; @@ -34,7 +33,7 @@ */ public class LocalOptimizer { - private static final int NEWTON_MAX_FEATURE_NUM = 1024; + public static final int NEWTON_MAX_FEATURE_NUM = 1024; private static final double EPS = 1.0e-18; /** @@ -50,9 +49,7 @@ public static Tuple2 optimize(OptimObjFunc objFunc, DenseVector initCoef, Params params) { LinearTrainParams.OptimMethod method = params.get(LinearTrainParams.OPTIM_METHOD); if (null == method) { - if (initCoef.size() <= NEWTON_MAX_FEATURE_NUM) { - method = OptimMethod.Newton; - } else if (params.get(HasL1.L_1) > 0) { + if (params.get(HasL1.L_1) > 0) { method = OptimMethod.OWLQN; } else { method = OptimMethod.LBFGS; diff --git a/core/src/main/java/com/alibaba/alink/operator/common/sql/CalciteSelectMapper.java b/core/src/main/java/com/alibaba/alink/operator/common/sql/CalciteSelectMapper.java index 9befb3445..c394e5d08 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/sql/CalciteSelectMapper.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/sql/CalciteSelectMapper.java @@ -6,20 +6,42 @@ import org.apache.flink.table.api.TableSchema; import com.alibaba.alink.common.MLEnvironmentFactory; +import com.alibaba.alink.common.MTable; import com.alibaba.alink.common.exceptions.AkIllegalStateException; import 
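The LocalOptimizer change above is small but behavioral: Newton is no longer auto-selected for low-dimensional problems. A sketch of the new default choice (not part of the patch):

    import com.alibaba.alink.params.shared.linear.LinearTrainParams.OptimMethod;

    public class OptimizerDefaultExample {
        // Mirrors the revised branch in LocalOptimizer.optimize: with no method set,
        // an L1 penalty selects OWLQN (which handles the non-smooth term), else LBFGS.
        static OptimMethod chooseDefault(double l1) {
            return l1 > 0 ? OptimMethod.OWLQN : OptimMethod.LBFGS;
        }

        public static void main(String[] args) {
            System.out.println(chooseDefault(0.1)); // OWLQN
            System.out.println(chooseDefault(0.0)); // LBFGS
        }
    }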
com.alibaba.alink.common.exceptions.AkParseErrorException; import com.alibaba.alink.common.exceptions.AkUnclassifiedErrorException; +import com.alibaba.alink.common.linalg.Tensor; +import com.alibaba.alink.common.linalg.Vector; import com.alibaba.alink.common.mapper.Mapper; +import com.alibaba.alink.common.sql.builtin.string.string.DateAdd; +import com.alibaba.alink.common.sql.builtin.string.string.DateDiff; +import com.alibaba.alink.common.sql.builtin.string.string.DateSub; +import com.alibaba.alink.common.sql.builtin.string.string.KeyValue; +import com.alibaba.alink.common.sql.builtin.string.string.RegExp; +import com.alibaba.alink.common.sql.builtin.string.string.RegExpExtract; +import com.alibaba.alink.common.sql.builtin.string.string.RegExpReplace; +import com.alibaba.alink.common.sql.builtin.string.string.SplitPart; +import com.alibaba.alink.common.sql.builtin.time.DataFormat; +import com.alibaba.alink.common.sql.builtin.time.FromUnixTime; +import com.alibaba.alink.common.sql.builtin.time.Now; +import com.alibaba.alink.common.sql.builtin.time.ToTimeStamp; +import com.alibaba.alink.common.sql.builtin.time.UnixTimeStamp; +import com.alibaba.alink.common.type.AlinkTypes; import com.alibaba.alink.operator.batch.source.MemSourceBatchOp; -import com.alibaba.alink.operator.batch.sql.SelectBatchOp; +import com.alibaba.alink.operator.batch.sql.BatchSqlOperators; import com.alibaba.alink.operator.common.io.types.FlinkTypeConverter; import com.alibaba.alink.operator.common.sql.functions.MathFunctions; import com.alibaba.alink.operator.common.sql.functions.StringFunctions; import com.alibaba.alink.params.sql.SelectParams; +import org.apache.calcite.avatica.util.Casing; +import org.apache.calcite.avatica.util.Quoting; +import org.apache.calcite.config.CalciteConnectionProperty; +import org.apache.calcite.config.NullCollation; import org.apache.calcite.jdbc.CalciteConnection; import org.apache.calcite.schema.SchemaPlus; import org.apache.calcite.schema.impl.ScalarFunctionImpl; import org.apache.calcite.util.BuiltInMethod; import java.lang.reflect.Method; import java.math.BigDecimal; @@ -30,7 +52,10 @@ import java.sql.ResultSet; import java.sql.ResultSetMetaData; import java.sql.SQLException; +import java.sql.Timestamp; +import java.sql.Types; import java.util.Collections; +import java.util.Properties; import java.util.concurrent.ConcurrentHashMap; import java.util.function.BiConsumer; @@ -47,6 +72,8 @@ public class CalciteSelectMapper extends Mapper { private final static String TEMPLATE = "SELECT %s FROM (SELECT %s FROM (VALUES (1))) foo"; + private final static String COL_NAME_PREFIX = "alink_prefix_"; + /** * Constructor. * @@ -57,6 +84,7 @@ public CalciteSelectMapper(TableSchema dataSchema, Params params) { super(dataSchema, params); } + // Function name registration is case sensitive.
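+ // Overloads registered under one name (e.g. the four CONCAT variants, or NOW with and without an argument) are resolved by Calcite from the argument count and types at the call site; for example, CONCAT(a, b, c) over three strings binds to StringFunctions.concat3.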
public static void registerFlinkBuiltInFunctions(SchemaPlus schema) { BiConsumer addScalarFunctionConsumer = (k, v) -> schema.add(k, ScalarFunctionImpl.create(v)); @@ -86,6 +114,10 @@ public static void registerFlinkBuiltInFunctions(SchemaPlus schema) { addScalarFunctionConsumer.accept("RPAD", StringFunctions.RPAD); //addScalarFunctionConsumer.accept("REGEXP_REPLACE", StringFunctions.REGEXP_REPLACE); addScalarFunctionConsumer.accept("REGEXP_EXTRACT", StringFunctions.REGEXP_EXTRACT); + addScalarFunctionConsumer.accept("CONCAT", StringFunctions.CONCAT); + addScalarFunctionConsumer.accept("CONCAT", StringFunctions.CONCAT3); + addScalarFunctionConsumer.accept("CONCAT", StringFunctions.CONCAT4); + addScalarFunctionConsumer.accept("CONCAT", StringFunctions.CONCAT5); addScalarFunctionConsumer.accept("LTRIM", BuiltInMethod.LTRIM.method); addScalarFunctionConsumer.accept("RTRIM", BuiltInMethod.RTRIM.method); @@ -97,6 +129,82 @@ public static void registerFlinkBuiltInFunctions(SchemaPlus schema) { addScalarFunctionConsumer.accept("SHA384", StringFunctions.SHA384); addScalarFunctionConsumer.accept("SHA512", StringFunctions.SHA512); addScalarFunctionConsumer.accept("SHA2", StringFunctions.SHA2); + + // time functions. + addScalarFunctionConsumer.accept("NOW", org.apache.calcite.linq4j.tree.Types.lookupMethod( + Now.class, "eval", int.class)); + addScalarFunctionConsumer.accept("NOW", org.apache.calcite.linq4j.tree.Types.lookupMethod( + Now.class, "eval")); + + addScalarFunctionConsumer.accept("DATE_FORMAT_LTZ", org.apache.calcite.linq4j.tree.Types.lookupMethod( + DataFormat.class, "eval", Timestamp.class)); + addScalarFunctionConsumer.accept("DATE_FORMAT_LTZ", org.apache.calcite.linq4j.tree.Types.lookupMethod( + DataFormat.class, "eval", Timestamp.class, String.class)); + + addScalarFunctionConsumer.accept("TO_TIMESTAMP", org.apache.calcite.linq4j.tree.Types.lookupMethod( + ToTimeStamp.class, "eval", Long.class)); + addScalarFunctionConsumer.accept("TO_TIMESTAMP", org.apache.calcite.linq4j.tree.Types.lookupMethod( + ToTimeStamp.class, "eval", Integer.class)); + addScalarFunctionConsumer.accept("TO_TIMESTAMP", org.apache.calcite.linq4j.tree.Types.lookupMethod( + ToTimeStamp.class, "eval", String.class)); + addScalarFunctionConsumer.accept("TO_TIMESTAMP", org.apache.calcite.linq4j.tree.Types.lookupMethod( + ToTimeStamp.class, "eval", String.class, String.class)); + + addScalarFunctionConsumer.accept("UNIX_TIMESTAMP", org.apache.calcite.linq4j.tree.Types.lookupMethod( + UnixTimeStamp.class, "eval")); + addScalarFunctionConsumer.accept("UNIX_TIMESTAMP", org.apache.calcite.linq4j.tree.Types.lookupMethod( + UnixTimeStamp.class, "eval", String.class)); + addScalarFunctionConsumer.accept("UNIX_TIMESTAMP", org.apache.calcite.linq4j.tree.Types.lookupMethod( + UnixTimeStamp.class, "eval", Timestamp.class)); + addScalarFunctionConsumer.accept("UNIX_TIMESTAMP", org.apache.calcite.linq4j.tree.Types.lookupMethod( + UnixTimeStamp.class, "eval", String.class, String.class)); + + addScalarFunctionConsumer.accept("FROM_UNIXTIME", org.apache.calcite.linq4j.tree.Types.lookupMethod( + FromUnixTime.class, "eval", Long.class)); + addScalarFunctionConsumer.accept("FROM_UNIXTIME", org.apache.calcite.linq4j.tree.Types.lookupMethod( + FromUnixTime.class, "eval", Integer.class)); + addScalarFunctionConsumer.accept("FROM_UNIXTIME", org.apache.calcite.linq4j.tree.Types.lookupMethod( + FromUnixTime.class, "eval", Long.class, String.class)); + addScalarFunctionConsumer.accept("FROM_UNIXTIME", 
org.apache.calcite.linq4j.tree.Types.lookupMethod( + FromUnixTime.class, "eval", Integer.class, String.class)); + + // other string and date helper functions. + addScalarFunctionConsumer.accept("SPLIT_PART", org.apache.calcite.linq4j.tree.Types.lookupMethod( + SplitPart.class, "eval", String.class, String.class, Integer.class)); + addScalarFunctionConsumer.accept("SPLIT_PART", org.apache.calcite.linq4j.tree.Types.lookupMethod( + SplitPart.class, "eval", String.class, String.class, Integer.class, Integer.class)); + + addScalarFunctionConsumer.accept("KEYVALUE", org.apache.calcite.linq4j.tree.Types.lookupMethod( + KeyValue.class, "eval", String.class, String.class, String.class, String.class)); + + addScalarFunctionConsumer.accept("DATEDIFF", org.apache.calcite.linq4j.tree.Types.lookupMethod( + DateDiff.class, "eval", String.class, Timestamp.class)); + addScalarFunctionConsumer.accept("DATEDIFF", org.apache.calcite.linq4j.tree.Types.lookupMethod( + DateDiff.class, "eval", String.class, String.class)); + addScalarFunctionConsumer.accept("DATEDIFF", org.apache.calcite.linq4j.tree.Types.lookupMethod( + DateDiff.class, "eval", Timestamp.class, Timestamp.class)); + addScalarFunctionConsumer.accept("DATEDIFF", org.apache.calcite.linq4j.tree.Types.lookupMethod( + DateDiff.class, "eval", Timestamp.class, String.class)); + + // if REGEXP_REPLACE is registered here, the unit tests fail. + //addScalarFunctionConsumer.accept("REGEXP_REPLACE", org.apache.calcite.linq4j.tree.Types.lookupMethod( + // RegExpReplace.class, "eval", String.class, String.class, String.class)); + + addScalarFunctionConsumer.accept("REGEXP", org.apache.calcite.linq4j.tree.Types.lookupMethod( + RegExp.class, "eval", String.class, String.class)); + + addScalarFunctionConsumer.accept("REGEXP_EXTRACT", org.apache.calcite.linq4j.tree.Types.lookupMethod( + RegExpExtract.class, "eval", String.class, String.class, int.class)); + + addScalarFunctionConsumer.accept("DATE_ADD", org.apache.calcite.linq4j.tree.Types.lookupMethod( + DateAdd.class, "eval", String.class, int.class)); + addScalarFunctionConsumer.accept("DATE_ADD", org.apache.calcite.linq4j.tree.Types.lookupMethod( + DateAdd.class, "eval", Timestamp.class, int.class)); + + addScalarFunctionConsumer.accept("DATE_SUB", org.apache.calcite.linq4j.tree.Types.lookupMethod( + DateSub.class, "eval", String.class, int.class)); + addScalarFunctionConsumer.accept("DATE_SUB", org.apache.calcite.linq4j.tree.Types.lookupMethod( + DateSub.class, "eval", Timestamp.class, int.class)); } @Override @@ -120,10 +228,17 @@ private Connection getConnection() { CompilerFactoryFactory#getDefaultCompilerFactory} failed. So we manually set it to the classloader of this class. 
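+ The connection is also created with explicit properties (see below): NULL values collate last, function lookup is case-insensitive, unquoted identifiers keep their casing, and identifiers are quoted with back-ticks.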
*/ + Thread.currentThread().setContextClassLoader(this.getClass().getClassLoader()); try { + Properties info = new Properties(); + info.setProperty(CalciteConnectionProperty.DEFAULT_NULL_COLLATION.camelName(), NullCollation.LAST.name()); + info.setProperty(CalciteConnectionProperty.CASE_SENSITIVE.camelName(), "false"); + info.setProperty(CalciteConnectionProperty.UNQUOTED_CASING.camelName(), Casing.UNCHANGED.name()); + info.setProperty(CalciteConnectionProperty.QUOTING.camelName(), Quoting.BACK_TICK.name()); Class.forName("org.apache.calcite.jdbc.Driver"); - return DriverManager.getConnection("jdbc:calcite:fun=mysql"); + + return DriverManager.getConnection("jdbc:calcite:fun=mysql", info); } catch (ClassNotFoundException | SQLException e) { throw new AkUnclassifiedErrorException("Failed to initialize JDBC connection.", e); } @@ -154,12 +269,13 @@ private PreparedStatement getPreparedStatement() { sb.append(", "); } sb.append("CAST(? as "); - sb.append(FlinkTypeConverter.getTypeString(fieldTypes[i])); + sb.append(getTypeString(fieldTypes[i])); sb.append(") as "); - sb.append(fieldNames[i]); + sb.append("`" + fieldNames[i] + "`"); } String query = String.format(TEMPLATE, clause, sb); + try { return calciteConnection.prepareStatement(query); } catch (SQLException e) { @@ -167,6 +283,15 @@ private PreparedStatement getPreparedStatement() { } } + private String getTypeString(TypeInformation dataType) { + if (AlinkTypes.isVectorType(dataType) + || AlinkTypes.isMTableType(dataType) + || AlinkTypes.isTensorType(dataType)) { + return "DOUBLE"; + } + return FlinkTypeConverter.getTypeString(dataType); + } + @Override protected void map(SlicedSelectedSample selection, SlicedResult result) throws Exception { PreparedStatement preparedStatement = threadPreparedStatementMap.computeIfAbsent( @@ -178,6 +303,12 @@ protected void map(SlicedSelectedSample selection, SlicedResult result) throws E preparedStatement.setObject(i + 1, v, java.sql.Types.DECIMAL); } else if (v instanceof BigInteger) { preparedStatement.setObject(i + 1, v, java.sql.Types.BIGINT); + } else if (v instanceof Vector || v instanceof MTable || v instanceof Tensor) { + preparedStatement.setObject(i + 1, null, java.sql.Types.DOUBLE); + } else if (v instanceof Float) { + preparedStatement.setObject(i + 1, v, java.sql.Types.FLOAT); + } else if (v instanceof Integer) { + preparedStatement.setObject(i + 1, v, Types.INTEGER); } else { preparedStatement.setObject(i + 1, v); } @@ -203,10 +334,29 @@ static Tuple4 [], String[]> prepareIoSch Params params) { String clause = params.get(SelectParams.CLAUSE); Long newMLEnvId = MLEnvironmentFactory.getNewMLEnvironmentId(); - MemSourceBatchOp source = new MemSourceBatchOp(Collections.emptyList(), dataSchema) + + TypeInformation [] colTypes = dataSchema.getFieldTypes(); + TypeInformation [] newColTypes = new TypeInformation[colTypes.length]; + for (int i = 0; i < colTypes.length; i++) { + if (AlinkTypes.isVectorType(colTypes[i]) + || AlinkTypes.isMTableType(colTypes[i]) + || AlinkTypes.isTensorType(colTypes[i])) { + newColTypes[i] = AlinkTypes.DOUBLE; + } else { + newColTypes[i] = colTypes[i]; + } + } + + MemSourceBatchOp source = new MemSourceBatchOp(Collections.emptyList(), + new TableSchema(dataSchema.getFieldNames(), newColTypes)) .setMLEnvironmentId(newMLEnvId); - TableSchema outputSchema = source.linkTo(new SelectBatchOp().setClause(clause)).getSchema(); + + String newClause = SelectUtils.convertRegexClause2ColNames(dataSchema.getFieldNames(), clause); + + TableSchema outputSchema = 
BatchSqlOperators.select(source, newClause).getOutputTable().getSchema(); + MLEnvironmentFactory.remove(newMLEnvId); + return Tuple4.of( dataSchema.getFieldNames(), outputSchema.getFieldNames(), diff --git a/core/src/main/java/com/alibaba/alink/operator/common/sql/SelectMapper.java b/core/src/main/java/com/alibaba/alink/operator/common/sql/SelectMapper.java index 37be805e0..fbed8e0d1 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/sql/SelectMapper.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/sql/SelectMapper.java @@ -1,6 +1,7 @@ package com.alibaba.alink.operator.common.sql; import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.tuple.Tuple4; import org.apache.flink.ml.api.misc.param.Params; import org.apache.flink.table.api.TableSchema; @@ -9,9 +10,21 @@ import com.alibaba.alink.common.mapper.Mapper; import com.alibaba.alink.params.sql.SelectParams; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +/** + * TODO: + * 1. count(1), count(*), etc. are not supported; such a clause would be split into multiple sub-clauses. + */ public class SelectMapper extends Mapper { - private Mapper mapper; + private Mapper[] mappers; + + private String[] outColNames; + + private TypeInformation [] outColTypes; public SelectMapper(TableSchema dataSchema, Params params) { super(dataSchema, params); @@ -19,18 +32,48 @@ public SelectMapper(TableSchema dataSchema, Params params) { @Override public void open() { - if (SelectUtils.isSimpleSelect(params.get(SelectParams.CLAUSE), this.getDataSchema().getFieldNames())) { - mapper = new SimpleSelectMapper(this.getDataSchema(), this.params); + String[] colNames = this.getDataSchema().getFieldNames(); + String clause = params.get(SelectParams.CLAUSE); + clause = SelectUtils.convertRegexClause2ColNames(colNames, clause); + + if (SelectUtils.isSimpleSelect(clause, colNames)) { + mappers = new Mapper[1]; + mappers[0] = new SimpleSelectMapper(this.getDataSchema(), new Params().set(SelectParams.CLAUSE, clause)); + mappers[0].open(); + TableSchema outSchema = mappers[0].getOutputSchema(); + this.outColNames = outSchema.getFieldNames(); + this.outColTypes = outSchema.getFieldTypes(); } else { - mapper = new CalciteSelectMapper(this.getDataSchema(), this.params); + Tuple2 [] clauseSplits = SelectUtils.splitClauseBySimpleClause(clause, colNames); + mappers = new Mapper[clauseSplits.length]; + List outColNames = new ArrayList <>(); + List > outColTypes = new ArrayList <>(); + for (int i = 0; i < clauseSplits.length; i++) { + String curClause = clauseSplits[i].f0; + Params curParams = new Params().set(SelectParams.CLAUSE, curClause); + if (SelectUtils.isSimpleSelect(curClause, colNames)) { + mappers[i] = new SimpleSelectMapper(this.getDataSchema(), curParams); + } else { + mappers[i] = new CalciteSelectMapper(this.getDataSchema(), curParams); + } + mappers[i].open(); + TableSchema outSchema = mappers[i].getOutputSchema(); + outColNames.addAll(Arrays.asList(outSchema.getFieldNames())); + outColTypes.addAll(Arrays.asList(outSchema.getFieldTypes())); + } + this.outColNames = outColNames.toArray(new String[0]); + this.outColTypes = outColTypes.toArray(new TypeInformation [0]); } - mapper.open(); } @Override public void close() { - if (mapper != null) { - mapper.close(); + if (mappers != null) { + for (Mapper mapper : mappers) { + if (mapper != null) { + mapper.close(); + } + } } } @@ -40,23 +83,90 @@ protected void map(SlicedSelectedSample selection, SlicedResult result) 
throws E @Override public Row map(Row row) throws Exception { - return mapper.map(row); + if (mappers.length == 0) { + return null; + } else if (mappers.length == 1) { + return mappers[0].map(row); + } else { + Row outRow = new Row(outColNames.length); + int idx = 0; + for (Mapper mapper : mappers) { + Row r = mapper.map(row); + for (int j = 0; j < r.getArity(); j++) { + outRow.setField(idx, r.getField(j)); + idx++; + } + } + return outRow; + } } @Override public void bufferMap(Row bufferRow, int[] bufferSelectedColIndices, int[] bufferResultColIndices) throws Exception { - mapper.bufferMap(bufferRow, bufferSelectedColIndices, bufferResultColIndices); + if (mappers.length == 0) { + return; + } else if (mappers.length == 1) { + mappers[0].bufferMap(bufferRow, bufferSelectedColIndices, bufferResultColIndices); + } else { + Row in = Row.project(bufferRow, bufferSelectedColIndices); + Row out = map(in); + for (int i = 0; i < bufferResultColIndices.length; i++) { + bufferRow.setField(bufferResultColIndices[i], out.getField(i)); + } + } } @Override protected Tuple4 [], String[]> prepareIoSchema(TableSchema dataSchema, Params params) { - if (SelectUtils.isSimpleSelect(params.get(SelectParams.CLAUSE), dataSchema.getFieldNames())) { - return SimpleSelectMapper.prepareIoSchemaImpl(dataSchema, params); + String clause = params.get(SelectParams.CLAUSE); + String[] colNames = dataSchema.getFieldNames(); + String newClause = SelectUtils.convertRegexClause2ColNames(colNames, clause); + if (SelectUtils.isSimpleSelect(newClause, colNames)) { + return SimpleSelectMapper.prepareIoSchemaImpl(dataSchema, new Params().set(SelectParams.CLAUSE, + newClause)); } else { - return CalciteSelectMapper.prepareIoSchemaImpl(dataSchema, params); + Tuple2 [] clauseSplits = SelectUtils.splitClauseBySimpleClause(newClause, colNames); + mappers = new Mapper[clauseSplits.length]; + List outColNames = new ArrayList <>(); + List > outColTypes = new ArrayList <>(); + for (int i = 0; i < clauseSplits.length; i++) { + String curClause = clauseSplits[i].f0; + Params curParams = new Params().set(SelectParams.CLAUSE, curClause); + if (SelectUtils.isSimpleSelect(curClause, colNames)) { + mappers[i] = new SimpleSelectMapper(this.getDataSchema(), curParams); + } else { + mappers[i] = new CalciteSelectMapper(this.getDataSchema(), curParams); + } + mappers[i].open(); + TableSchema outSchema = mappers[i].getOutputSchema(); + outColNames.addAll(Arrays.asList(outSchema.getFieldNames())); + outColTypes.addAll(Arrays.asList(outSchema.getFieldTypes())); + mappers[i].close(); + } + this.outColNames = outColNames.toArray(new String[0]); + this.outColTypes = outColTypes.toArray(new TypeInformation [0]); + return Tuple4.of( + dataSchema.getFieldNames(), + this.outColNames, + this.outColTypes, + new String[0] + ); + } + } + + // Override the output schema: output columns follow the clause order; otherwise they would follow the input schema order. + @Override + public TableSchema getOutputSchema() { + if (this.outColNames == null || this.outColTypes == null) { + Tuple4 [], String[]> t4 = prepareIoSchema(this.getDataSchema(), + this.params); + this.outColNames = t4.f1; + this.outColTypes = t4.f2; } + + return new TableSchema(this.outColNames, this.outColTypes); } } diff --git a/core/src/main/java/com/alibaba/alink/operator/common/sql/SelectUtils.java b/core/src/main/java/com/alibaba/alink/operator/common/sql/SelectUtils.java index 89d581188..ccb475d89 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/sql/SelectUtils.java +++ 
b/core/src/main/java/com/alibaba/alink/operator/common/sql/SelectUtils.java @@ -150,7 +150,7 @@ static public Tuple2 splitAndTrim(String clause, String[] c } } if (count > 0) { - outputCols[i] = outputCols[i] + (count-1); + outputCols[i] = outputCols[i] + (count - 1); } } return Tuple2.of(inputCols, outputCols); @@ -179,4 +179,39 @@ static public String replaceStar(String clause, String[] colNames) { } return newClause; } + + public static Tuple2 [] splitClauseBySimpleClause(String clause, String[] colNames) { + String[] splits = StringUtils.split(clause, ","); + Boolean[] isSimpleClauses = new Boolean[splits.length]; + List > out = new ArrayList <>(); + String tmp = splits[0]; + isSimpleClauses[0] = isSimpleSelect(splits[0], colNames); + for (int i = 1; i < splits.length; i++) { + isSimpleClauses[i] = isSimpleSelect(splits[i], colNames); + if (!isCompleteClause(tmp) || isSimpleClauses[i] == isSimpleClauses[i - 1]) { + tmp = tmp + ", " + splits[i]; + } else { + out.add(Tuple2.of(tmp.trim(), isSimpleClauses[i - 1])); + tmp = splits[i]; + } + } + if (!tmp.isEmpty()) { + out.add(Tuple2.of(tmp.trim(), isSimpleClauses[splits.length - 1])); + } + return out.toArray(new Tuple2[0]); + } + + // TODO: only parentheses are considered here, not quotation marks. + static boolean isCompleteClause(String clause) { + int leftParenthesisCount = 0; + int rightParenthesisCount = 0; + for (char c : clause.toCharArray()) { + if (c == '(') { + leftParenthesisCount++; + } else if (c == ')') { + rightParenthesisCount++; + } + } + return leftParenthesisCount == rightParenthesisCount; + } } diff --git a/core/src/main/java/com/alibaba/alink/operator/common/sql/functions/StringFunctions.java b/core/src/main/java/com/alibaba/alink/operator/common/sql/functions/StringFunctions.java index d53f2e502..05914a93e 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/sql/functions/StringFunctions.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/sql/functions/StringFunctions.java @@ -33,6 +33,15 @@ public class StringFunctions implements Serializable { public static Method CONCAT = Types.lookupMethod(StringFunctions.class, "concat", String.class, String.class); + public static Method CONCAT3 = Types.lookupMethod(StringFunctions.class, "concat3", String.class, + String.class, String.class); + + public static Method CONCAT4 = Types.lookupMethod(StringFunctions.class, "concat4", String.class, + String.class, String.class, String.class); + + public static Method CONCAT5 = Types.lookupMethod(StringFunctions.class, "concat5", String.class, + String.class, String.class, String.class, String.class); + public static Method MD5 = Types.lookupMethod(StringFunctions.class, "md5", String.class); public static Method SHA1 = Types.lookupMethod(StringFunctions.class, "sha1", String.class); public static Method SHA224 = Types.lookupMethod(StringFunctions.class, "sha224", String.class); @@ -203,4 +212,25 @@ public static String concat(String str1, String str2) { } return str1 + str2; } + + public static String concat3(String str1, String str2, String str3) { + if (str1 == null || str2 == null || str3 == null) { + return null; + } + return str1 + str2 + str3; + } + + public static String concat4(String str1, String str2, String str3, String str4) { + if (str1 == null || str2 == null || str3 == null || str4 == null) { + return null; + } + return str1 + str2 + str3 + str4; + } + + public static String concat5(String str1, String str2, String str3, String str4, String str5) { + if (str1 == null || str2 == null || str3 == null || 
str4 == null || str5 == null) { + return null; + } + return str1 + str2 + str3 + str4 + str5; + } } diff --git a/core/src/main/java/com/alibaba/alink/operator/common/statistics/basicstatistic/BaseSummary.java b/core/src/main/java/com/alibaba/alink/operator/common/statistics/basicstatistic/BaseSummary.java index 8deb94d40..bd19f60bf 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/statistics/basicstatistic/BaseSummary.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/statistics/basicstatistic/BaseSummary.java @@ -42,7 +42,7 @@ public abstract class BaseSummary implements Serializable { /** * count. */ - public long count; + protected long count; /** * count. diff --git a/core/src/main/java/com/alibaba/alink/operator/common/statistics/basicstatistic/TableSummary.java b/core/src/main/java/com/alibaba/alink/operator/common/statistics/basicstatistic/TableSummary.java index baf4824fd..3b1546c30 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/statistics/basicstatistic/TableSummary.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/statistics/basicstatistic/TableSummary.java @@ -287,7 +287,8 @@ public double centralMoment4(String colName) { int idx = findIdx(colName); double mean = mean(colName); if (idx >= 0) { - return (sum4.get(idx) - 4 * sum3.get(idx) * mean + 6 * sum2.get(idx) * mean * mean - 3 * sum.get(idx) * mean * mean * mean) / count; + return (sum4.get(idx) - 4 * sum3.get(idx) * mean + 6 * sum2.get(idx) * mean * mean - 3 * sum.get(idx) * mean + * mean * mean) / count; } else { return Double.NaN; } diff --git a/core/src/main/java/com/alibaba/alink/operator/common/tree/seriestree/DecisionTree.java b/core/src/main/java/com/alibaba/alink/operator/common/tree/seriestree/DecisionTree.java index 5978392b3..014ac8226 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/tree/seriestree/DecisionTree.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/tree/seriestree/DecisionTree.java @@ -22,6 +22,7 @@ public class DecisionTree { private final DenseData data; + private final int[] indices; private final Params params; private final Deque > queue = new ArrayDeque <>(); private final Random random; @@ -32,7 +33,12 @@ public class DecisionTree { private final ArrayList > futures; public DecisionTree(DenseData data, Params params, ExecutorService executorService) { + this(data, null, params, executorService); + } + + public DecisionTree(DenseData data, int[] indices, Params params, ExecutorService executorService) { this.data = data; + this.indices = indices; this.params = params; this.random = new Random(params.get(HasSeed.SEED)); @@ -116,7 +122,7 @@ private SequentialFeatureSplitter fitNodeMultiThread(SequentialFeatureSplitter[] try { gain = futures.get(j).get(); } catch (InterruptedException | ExecutionException e) { - throw new AkUnclassifiedErrorException("Error. ",e); + throw new AkUnclassifiedErrorException("Error. 
", e); } if (gain > bestGain || bestSplitter == null) { @@ -174,12 +180,25 @@ private static void shuffle(T[] array, Random random) { } private SequentialPartition initSequentialPartition() { - ArrayList > dataIndices = new ArrayList <>(data.m); - for (int i = 0; i < data.m; ++i) { - dataIndices.add(Tuple2.of(i, data.weights[i])); - } + if (null == indices) { + + ArrayList > dataIndices = new ArrayList <>(data.m); + for (int i = 0; i < data.m; ++i) { + dataIndices.add(Tuple2.of(i, data.weights[i])); + } + + return new SequentialPartition(dataIndices); - return new SequentialPartition(dataIndices); + } else { + + ArrayList > dataIndices = new ArrayList <>(indices.length); + for (int index : indices) { + dataIndices.add(Tuple2.of(index, data.weights[index])); + } + + return new SequentialPartition(dataIndices); + + } } private SequentialFeatureSplitter[] initSplitters(SequentialPartition partition) { diff --git a/core/src/main/java/com/alibaba/alink/operator/common/tree/seriestree/DenseData.java b/core/src/main/java/com/alibaba/alink/operator/common/tree/seriestree/DenseData.java index d2eed9e2b..fd81ca7ba 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/tree/seriestree/DenseData.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/tree/seriestree/DenseData.java @@ -9,8 +9,8 @@ * DataSet. */ public class DenseData { - private final static double CONTINUOUS_NULL = Double.NaN; - private final static int CATEGORICAL_NULL = Integer.MAX_VALUE; + final static double CONTINUOUS_NULL = Double.NaN; + final static int CATEGORICAL_NULL = Integer.MAX_VALUE; /** * flag for the length of data buffer array. @@ -31,10 +31,10 @@ public class DenseData { * Features. */ FeatureMeta[] featureMetas; - private final Object[] featureValues; + final Object[] featureValues; FeatureMeta labelMeta; - private Object labelValues; + Object labelValues; double[] weights; @@ -68,6 +68,10 @@ public DenseData( } + public int getNumRows() { + return this.m; + } + void resetM(int m) { AkPreconditions.checkState(m <= this.rawBufferLen); this.m = m; @@ -122,9 +126,10 @@ public void readFromInstances(Iterable instances) { } // initial weights. 
- if (labelMeta != null && instance.getArity() == n + 2 - || labelMeta == null && instance.getArity() == n + 1) { + if (labelMeta != null && instance.getArity() == n + 2) { weights[i] = (double) instance.getField(n + 1); + } else if (labelMeta == null && instance.getArity() == n + 1) { + weights[i] = (double) instance.getField(n); } else { weights[i] = 1.0; } diff --git a/core/src/main/java/com/alibaba/alink/operator/common/utils/PackBatchOperatorUtil.java b/core/src/main/java/com/alibaba/alink/operator/common/utils/PackBatchOperatorUtil.java index 11bc774ad..7d697879a 100644 --- a/core/src/main/java/com/alibaba/alink/operator/common/utils/PackBatchOperatorUtil.java +++ b/core/src/main/java/com/alibaba/alink/operator/common/utils/PackBatchOperatorUtil.java @@ -24,9 +24,9 @@ */ public class PackBatchOperatorUtil { //model col prefix - private static final String MODEL_COL_PREFIX = "p"; + public static final String MODEL_COL_PREFIX = "p"; //id col name - private static final String ID_COL_NAME = "id"; + public static final String ID_COL_NAME = "id"; /** * pack batch ops @@ -144,7 +144,7 @@ private static Tuple2 > mergeTypes(BatchOperator [] return Tuple2.of(new TableSchema(colNames, colTypes), colIndices); } - private static int[] addOne(int[] vec) { + public static int[] addOne(int[] vec) { for (int i = 0; i < vec.length; i++) { vec[i]++; } @@ -189,13 +189,13 @@ private static BatchOperator packBatchOp(BatchOperator op, /** * first entry is opIdx, second is meta, other fill with colIndices, others is null. */ - private static class FlattenMap implements MapFunction { + public static class FlattenMap implements MapFunction { private static final long serialVersionUID = -4502881391819047945L; private int colNum; private int opIdx; private int[] colIndices; - FlattenMap(int colNum, int opIdx, int[] colIndices) { + public FlattenMap(int colNum, int opIdx, int[] colIndices) { this.colNum = colNum; this.opIdx = opIdx; this.colIndices = colIndices; diff --git a/core/src/main/java/com/alibaba/alink/operator/local/sql/FullOuterJoinLocalOp.java b/core/src/main/java/com/alibaba/alink/operator/local/sql/FullOuterJoinLocalOp.java deleted file mode 100644 index f77ee38ae..000000000 --- a/core/src/main/java/com/alibaba/alink/operator/local/sql/FullOuterJoinLocalOp.java +++ /dev/null @@ -1,47 +0,0 @@ -package com.alibaba.alink.operator.local.sql; - -import org.apache.flink.ml.api.misc.param.Params; - -import com.alibaba.alink.common.LocalMLEnvironment; -import com.alibaba.alink.common.annotation.InputPorts; -import com.alibaba.alink.common.annotation.NameCn; -import com.alibaba.alink.common.annotation.OutputPorts; -import com.alibaba.alink.common.annotation.PortSpec; -import com.alibaba.alink.common.annotation.PortType; -import com.alibaba.alink.operator.local.LocalOperator; -import com.alibaba.alink.params.sql.JoinParams; - -/** - * Full outer join two batch operators. 
- */ - -@InputPorts(values = {@PortSpec(PortType.DATA), @PortSpec(PortType.DATA)}) -@OutputPorts(values = @PortSpec(PortType.DATA)) -@NameCn("SQL操作:FullOuterJoin") -public final class FullOuterJoinLocalOp extends BaseSqlApiLocalOp - implements JoinParams { - - private static final long serialVersionUID = 6002321920184611785L; - - public FullOuterJoinLocalOp() { - this(new Params()); - } - - public FullOuterJoinLocalOp(String whereClause, String selectClause) { - this(new Params() - .set(JoinParams.JOIN_PREDICATE, whereClause) - .set(JoinParams.SELECT_CLAUSE, selectClause)); - } - - public FullOuterJoinLocalOp(Params params) { - super(params); - } - - @Override - protected void linkFromImpl(LocalOperator ... inputs) { - String joinPredicate = getJoinPredicate(); - String selectClause = getSelectClause(); - this.setOutputTable(LocalMLEnvironment.getInstance().getSqlExecutor() - .fullOuterJoin(inputs[0], inputs[1], joinPredicate, selectClause).getOutputTable()); - } -} diff --git a/core/src/main/java/com/alibaba/alink/operator/local/sql/JoinLocalOp.java b/core/src/main/java/com/alibaba/alink/operator/local/sql/JoinLocalOp.java deleted file mode 100644 index 4f0ade9b6..000000000 --- a/core/src/main/java/com/alibaba/alink/operator/local/sql/JoinLocalOp.java +++ /dev/null @@ -1,78 +0,0 @@ -package com.alibaba.alink.operator.local.sql; - -import org.apache.flink.ml.api.misc.param.Params; - -import com.alibaba.alink.common.LocalMLEnvironment; -import com.alibaba.alink.common.annotation.InputPorts; -import com.alibaba.alink.common.annotation.NameCn; -import com.alibaba.alink.common.annotation.OutputPorts; -import com.alibaba.alink.common.annotation.PortSpec; -import com.alibaba.alink.common.annotation.PortType; -import com.alibaba.alink.common.exceptions.AkUnsupportedOperationException; -import com.alibaba.alink.operator.local.LocalOperator; -import com.alibaba.alink.params.sql.JoinParams; - -/** - * Join two batch operators. - */ -@InputPorts(values = {@PortSpec(PortType.DATA), @PortSpec(PortType.DATA)}) -@OutputPorts(values = @PortSpec(PortType.DATA)) -@NameCn("SQL操作:Join") -public final class JoinLocalOp extends BaseSqlApiLocalOp - implements JoinParams { - - private static final long serialVersionUID = -5284150849586086589L; - - public JoinLocalOp() { - this(new Params()); - } - - public JoinLocalOp(String joinPredicate) { - this(joinPredicate, "*"); - } - - public JoinLocalOp(String joinPredicate, String selectClause) { - this(new Params() - .set(JOIN_PREDICATE, joinPredicate) - .set(SELECT_CLAUSE, selectClause)); - } - - public JoinLocalOp(Params params) { - super(params); - } - - @Override - protected void linkFromImpl(LocalOperator ... 
inputs) { - String selectClause = "*"; - if (this.getParams().contains(JoinParams.SELECT_CLAUSE)) { - selectClause = this.getParams().get(JoinParams.SELECT_CLAUSE); - } - String joidPredicate = this.getParams().get(JoinParams.JOIN_PREDICATE); - - LocalOperator outputOp; - switch (getType()) { - case JOIN: - outputOp = LocalMLEnvironment.getInstance().getSqlExecutor().join(inputs[0], inputs[1], joidPredicate, - selectClause); - break; - case LEFTOUTERJOIN: - outputOp = LocalMLEnvironment.getInstance().getSqlExecutor().leftOuterJoin(inputs[0], inputs[1], - joidPredicate, - selectClause); - break; - case RIGHTOUTERJOIN: - outputOp = LocalMLEnvironment.getInstance().getSqlExecutor().rightOuterJoin(inputs[0], inputs[1], - joidPredicate, - selectClause); - break; - case FULLOUTERJOIN: - outputOp = LocalMLEnvironment.getInstance().getSqlExecutor().fullOuterJoin(inputs[0], inputs[1], - joidPredicate, - selectClause); - break; - default: - throw new AkUnsupportedOperationException("Not supported binary op: " + getType()); - } - this.setOutputTable(outputOp.getOutputTable()); - } -} diff --git a/core/src/main/java/com/alibaba/alink/operator/local/sql/LeftOuterJoinLocalOp.java b/core/src/main/java/com/alibaba/alink/operator/local/sql/LeftOuterJoinLocalOp.java deleted file mode 100644 index ec10dc4bb..000000000 --- a/core/src/main/java/com/alibaba/alink/operator/local/sql/LeftOuterJoinLocalOp.java +++ /dev/null @@ -1,52 +0,0 @@ -package com.alibaba.alink.operator.local.sql; - -import org.apache.flink.ml.api.misc.param.Params; - -import com.alibaba.alink.common.LocalMLEnvironment; -import com.alibaba.alink.common.annotation.InputPorts; -import com.alibaba.alink.common.annotation.NameCn; -import com.alibaba.alink.common.annotation.OutputPorts; -import com.alibaba.alink.common.annotation.PortSpec; -import com.alibaba.alink.common.annotation.PortType; -import com.alibaba.alink.operator.local.LocalOperator; -import com.alibaba.alink.params.sql.JoinParams; - -/** - * Left outer join two batch operators. - */ - -@InputPorts(values = {@PortSpec(PortType.DATA), @PortSpec(PortType.DATA)}) -@OutputPorts(values = @PortSpec(PortType.DATA)) -@NameCn("SQL操作:LeftOuterJoin") -public final class LeftOuterJoinLocalOp extends BaseSqlApiLocalOp - implements JoinParams { - - private static final long serialVersionUID = -4614107895339207282L; - - public LeftOuterJoinLocalOp() { - this(new Params()); - } - - public LeftOuterJoinLocalOp(String whereClause) { - this(whereClause, "*"); - } - - public LeftOuterJoinLocalOp(String whereClause, String selectClause) { - this(new Params() - .set(JOIN_PREDICATE, whereClause) - .set(SELECT_CLAUSE, selectClause)); - } - - public LeftOuterJoinLocalOp(Params params) { - super(params); - } - - @Override - protected void linkFromImpl(LocalOperator ... 
inputs) { - String joinPredicate = getJoinPredicate(); - String selectClause = getSelectClause(); - this.setOutputTable(LocalMLEnvironment.getInstance().getSqlExecutor() - .leftOuterJoin(inputs[0], inputs[1], joinPredicate, selectClause) - .getOutputTable()); - } -} diff --git a/core/src/main/java/com/alibaba/alink/operator/local/sql/RightOuterJoinLocalOp.java b/core/src/main/java/com/alibaba/alink/operator/local/sql/RightOuterJoinLocalOp.java deleted file mode 100644 index 93d07d32e..000000000 --- a/core/src/main/java/com/alibaba/alink/operator/local/sql/RightOuterJoinLocalOp.java +++ /dev/null @@ -1,47 +0,0 @@ -package com.alibaba.alink.operator.local.sql; - -import org.apache.flink.ml.api.misc.param.Params; - -import com.alibaba.alink.common.LocalMLEnvironment; -import com.alibaba.alink.common.annotation.InputPorts; -import com.alibaba.alink.common.annotation.NameCn; -import com.alibaba.alink.common.annotation.OutputPorts; -import com.alibaba.alink.common.annotation.PortSpec; -import com.alibaba.alink.common.annotation.PortType; -import com.alibaba.alink.operator.local.LocalOperator; -import com.alibaba.alink.params.sql.JoinParams; - -/** - * Right outer join two batch operators. - */ -@InputPorts(values = {@PortSpec(PortType.DATA), @PortSpec(PortType.DATA)}) -@OutputPorts(values = @PortSpec(PortType.DATA)) -@NameCn("SQL操作:RightOuterJoin") -public final class RightOuterJoinLocalOp extends BaseSqlApiLocalOp - implements JoinParams { - - private static final long serialVersionUID = -9188072782747998516L; - - public RightOuterJoinLocalOp() { - this(new Params()); - } - - public RightOuterJoinLocalOp(String whereClause, String selectClause) { - this(new Params() - .set(JOIN_PREDICATE, whereClause) - .set(SELECT_CLAUSE, selectClause)); - } - - public RightOuterJoinLocalOp(Params params) { - super(params); - } - - @Override - protected void linkFromImpl(LocalOperator ... inputs) { - String joinPredicate = getJoinPredicate(); - String selectClause = getSelectClause(); - this.setOutputTable(LocalMLEnvironment.getInstance().getSqlExecutor() - .rightOuterJoin(inputs[0], inputs[1], joinPredicate, selectClause) - .getOutputTable()); - } -} diff --git a/core/src/main/java/com/alibaba/alink/operator/local/sql/SelectLocalOp.java b/core/src/main/java/com/alibaba/alink/operator/local/sql/SelectLocalOp.java index dcb9fc020..afeaffcf4 100644 --- a/core/src/main/java/com/alibaba/alink/operator/local/sql/SelectLocalOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/local/sql/SelectLocalOp.java @@ -1,22 +1,17 @@ package com.alibaba.alink.operator.local.sql; -import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.ml.api.misc.param.Params; -import com.alibaba.alink.common.LocalMLEnvironment; -import com.alibaba.alink.common.MTable; -import com.alibaba.alink.common.MTableUtil; import com.alibaba.alink.common.annotation.NameCn; -import com.alibaba.alink.common.utils.TableUtil; -import com.alibaba.alink.operator.common.sql.SelectUtils; -import com.alibaba.alink.operator.local.LocalOperator; +import com.alibaba.alink.operator.common.sql.SelectMapper; +import com.alibaba.alink.operator.local.utils.MapLocalOp; import com.alibaba.alink.params.sql.SelectParams; /** * Select the fields of a batch operator. 
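+ * After this change the operator extends MapLocalOp and delegates clause handling to SelectMapper, so simple and Calcite-backed select clauses share one code path.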
*/ @NameCn("SQL操作:Select") -public final class SelectLocalOp extends BaseSqlApiLocalOp +public final class SelectLocalOp extends MapLocalOp implements SelectParams { private static final long serialVersionUID = -1867376056670775636L; @@ -25,32 +20,8 @@ public SelectLocalOp() { this(new Params()); } - public SelectLocalOp(String clause) { - this(new Params().set(CLAUSE, clause)); - } - public SelectLocalOp(Params params) { - super(params); + super(SelectMapper::new, params); } - @Override - protected void linkFromImpl(LocalOperator ... inputs) { - LocalOperator in = checkAndGetFirst(inputs); - String[] colNames = in.getColNames(); - - String clause = getClause(); - String newClause = SelectUtils.convertRegexClause2ColNames(colNames, clause); - - if (SelectUtils.isSimpleSelect(newClause, colNames)) { - Tuple2 sTuple = SelectUtils.splitAndTrim(clause, colNames); - - int[] colIndexes = TableUtil.findColIndicesWithAssertAndHint(in.getSchema(), sTuple.f0); - MTable mt = MTableUtil.selectAs(in.getOutputTable(), colIndexes, sTuple.f1); - - this.setOutputTable(mt); - } else { - this.setOutputTable( - LocalMLEnvironment.getInstance().getSqlExecutor().select(in, newClause).getOutputTable()); - } - } } diff --git a/core/src/main/java/com/alibaba/alink/operator/stream/PipelinePredictStreamOp.java b/core/src/main/java/com/alibaba/alink/operator/stream/PipelinePredictStreamOp.java deleted file mode 100644 index 468bc2cd3..000000000 --- a/core/src/main/java/com/alibaba/alink/operator/stream/PipelinePredictStreamOp.java +++ /dev/null @@ -1,51 +0,0 @@ -package com.alibaba.alink.operator.stream; - -import org.apache.flink.ml.api.misc.param.Params; - -import com.alibaba.alink.common.annotation.NameCn; -import com.alibaba.alink.common.annotation.NameEn; -import com.alibaba.alink.common.exceptions.AkIllegalDataException; -import com.alibaba.alink.operator.stream.utils.DataStreamConversionUtil; -import com.alibaba.alink.params.PipelinePredictParams; -import com.alibaba.alink.pipeline.PipelineModel; - -/** - * - */ -@NameCn("Pipeline在线预测") -@NameEn("Pipeline prediction") -public final class PipelinePredictStreamOp extends StreamOperator - implements PipelinePredictParams { - private PipelineModel pipelineModel; - - public PipelinePredictStreamOp(PipelineModel model) { - this(model, new Params()); - } - - public PipelinePredictStreamOp(PipelineModel pipelineModel, Params params) { - super(params); - this.pipelineModel = pipelineModel; - } - - public PipelinePredictStreamOp(Params params) { - super(params); - } - - @Override - public PipelinePredictStreamOp linkFrom(StreamOperator ... 
inputs) { - try { - if (getParams().contains(PipelinePredictParams.MODEL_FILE_PATH)) { - pipelineModel = PipelineModel.load(getModelFilePath()) - .setMLEnvironmentId(inputs[0].getMLEnvironmentId()); - } - - StreamOperator result = pipelineModel.transform(inputs[0]); - this.setOutput(DataStreamConversionUtil.toTable(getMLEnvironmentId(), - result.getDataStream(), result.getSchema())); - } catch (Exception ex) { - ex.printStackTrace(); - throw new AkIllegalDataException(ex.toString()); - } - return this; - } -} diff --git a/core/src/main/java/com/alibaba/alink/operator/stream/sql/SelectStreamOp.java b/core/src/main/java/com/alibaba/alink/operator/stream/sql/SelectStreamOp.java index 84584a4e1..0a080954e 100644 --- a/core/src/main/java/com/alibaba/alink/operator/stream/sql/SelectStreamOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/stream/sql/SelectStreamOp.java @@ -5,6 +5,7 @@ import com.alibaba.alink.common.annotation.Internal; import com.alibaba.alink.common.annotation.NameCn; import com.alibaba.alink.common.annotation.NameEn; +import com.alibaba.alink.operator.common.sql.SelectMapper; import com.alibaba.alink.operator.common.sql.SelectUtils; import com.alibaba.alink.operator.common.sql.SimpleSelectMapper; import com.alibaba.alink.operator.stream.StreamOperator; @@ -16,7 +17,7 @@ */ @NameCn("SQL操作:Select") @NameEn("SQL:Select") -public final class SelectStreamOp extends BaseSqlApiStreamOp +public final class SelectStreamOp extends MapStreamOp implements SelectParams { private static final long serialVersionUID = 7401063240614374140L; @@ -30,41 +31,8 @@ public SelectStreamOp(String clause) { } public SelectStreamOp(Params params) { - super(params); - } - - @Override - public SelectStreamOp linkFrom(StreamOperator ... inputs) { - StreamOperator in = checkAndGetFirst(inputs); - String[] colNames = in.getColNames(); - - String clause = getClause(); - String newClause = SelectUtils.convertRegexClause2ColNames(colNames, clause); - - if (SelectUtils.isSimpleSelect(newClause, colNames)) { - this.setOutputTable( - in.link(new SimpleSelectStreamOp() - .setClause(newClause) - .setMLEnvironmentId(in.getMLEnvironmentId()) - ).getOutputTable()); - } else { - this.setOutputTable(StreamSqlOperators.select(in, newClause).getOutputTable()); - } - - return this; - } - - @Internal - private class SimpleSelectStreamOp extends MapStreamOp - implements SelectParams { - - public SimpleSelectStreamOp() { - this(null); - } - - public SimpleSelectStreamOp(Params param) { - super(SimpleSelectMapper::new, param); - } + super(SelectMapper::new, params); } + } diff --git a/core/src/main/java/com/alibaba/alink/operator/stream/statistics/QuantileStreamOp.java b/core/src/main/java/com/alibaba/alink/operator/stream/statistics/QuantileStreamOp.java index 743d735f4..acaf60bab 100644 --- a/core/src/main/java/com/alibaba/alink/operator/stream/statistics/QuantileStreamOp.java +++ b/core/src/main/java/com/alibaba/alink/operator/stream/statistics/QuantileStreamOp.java @@ -19,14 +19,14 @@ import com.alibaba.alink.operator.common.statistics.basicstatistic.QuantileWindowFunction; import com.alibaba.alink.operator.stream.StreamOperator; import com.alibaba.alink.params.shared.HasTimeCol_null; -import com.alibaba.alink.params.statistics.QuantileParams; +import com.alibaba.alink.params.statistics.QuantileStreamParams; @InputPorts(values = {@PortSpec(PortType.DATA)}) @OutputPorts(values = {@PortSpec(value = PortType.DATA, desc = PortDesc.OUTPUT_RESULT)}) @NameCn("分位数") @NameEn("Quantile") public final class QuantileStreamOp 
extends StreamOperator - implements QuantileParams { + implements QuantileStreamParams { private static final long serialVersionUID = 8927492832239574864L; diff --git a/core/src/main/java/com/alibaba/alink/params/feature/QuantileDiscretizerTrainParams.java b/core/src/main/java/com/alibaba/alink/params/feature/QuantileDiscretizerTrainParams.java index 4f007b1fd..74b3d7b5d 100644 --- a/core/src/main/java/com/alibaba/alink/params/feature/QuantileDiscretizerTrainParams.java +++ b/core/src/main/java/com/alibaba/alink/params/feature/QuantileDiscretizerTrainParams.java @@ -1,6 +1,7 @@ package com.alibaba.alink.params.feature; import com.alibaba.alink.params.shared.colname.HasSelectedCols; +import com.alibaba.alink.params.statistics.HasRoundMode; /** * Params for QuantileDiscretizerTrain. @@ -9,5 +10,6 @@ public interface QuantileDiscretizerTrainParams extends HasSelectedCols , HasNumBuckets , HasNumBucketsArray , - HasLeftOpen { + HasLeftOpen , + HasRoundMode { } diff --git a/core/src/main/java/com/alibaba/alink/params/statistics/QuantileBatchParams.java b/core/src/main/java/com/alibaba/alink/params/statistics/QuantileBatchParams.java deleted file mode 100644 index ff84200c6..000000000 --- a/core/src/main/java/com/alibaba/alink/params/statistics/QuantileBatchParams.java +++ /dev/null @@ -1,9 +0,0 @@ -package com.alibaba.alink.params.statistics; - -import com.alibaba.alink.params.shared.colname.HasSelectedCol; - -public interface QuantileBatchParams extends - HasSelectedCol , - HasQuantileNum , - HasRoundMode { -} diff --git a/core/src/main/java/com/alibaba/alink/params/statistics/QuantileParams.java b/core/src/main/java/com/alibaba/alink/params/statistics/QuantileParams.java index 818da9dd3..7a75a4e8f 100644 --- a/core/src/main/java/com/alibaba/alink/params/statistics/QuantileParams.java +++ b/core/src/main/java/com/alibaba/alink/params/statistics/QuantileParams.java @@ -1,11 +1,9 @@ package com.alibaba.alink.params.statistics; -import com.alibaba.alink.params.shared.HasTimeCol_null; +import com.alibaba.alink.params.shared.colname.HasSelectedCols; public interface QuantileParams extends - StatBaseParams , - HasTimeCol_null , - HasDalayTime , - HasQuantileNum { - -} \ No newline at end of file + HasSelectedCols , + HasQuantileNum , + HasRoundMode { +} diff --git a/core/src/main/java/com/alibaba/alink/params/statistics/QuantileStreamParams.java b/core/src/main/java/com/alibaba/alink/params/statistics/QuantileStreamParams.java new file mode 100644 index 000000000..661dd9016 --- /dev/null +++ b/core/src/main/java/com/alibaba/alink/params/statistics/QuantileStreamParams.java @@ -0,0 +1,11 @@ +package com.alibaba.alink.params.statistics; + +import com.alibaba.alink.params.shared.HasTimeCol_null; + +public interface QuantileStreamParams extends + StatBaseParams , + HasTimeCol_null , + HasDalayTime , + HasQuantileNum { + +} diff --git a/core/src/test/java/com/alibaba/alink/operator/batch/dataproc/SqlBatchOpsTest.java b/core/src/test/java/com/alibaba/alink/operator/batch/dataproc/SqlBatchOpsTest.java index e8ac040b0..623ef1018 100644 --- a/core/src/test/java/com/alibaba/alink/operator/batch/dataproc/SqlBatchOpsTest.java +++ b/core/src/test/java/com/alibaba/alink/operator/batch/dataproc/SqlBatchOpsTest.java @@ -116,7 +116,6 @@ public void testAs() throws Exception { AsBatchOp op = new AsBatchOp() .setClause("a"); inOp.link(op).collect(); - op.print(); } @Test @@ -157,10 +156,10 @@ public void testRegExp() throws Exception { @Test public void testNow() throws Exception { MemSourceBatchOp inOp = 
createTable3(); - inOp.lazyPrint(-1); + //inOp.lazyPrint(-1); SelectBatchOp op = new SelectBatchOp() - .setClause("NOW(10) as b1, NOW(colb) as b2"); + .setClause("NOW(10) as b1, NOW(colb) as b2, now() as b3"); inOp.link(op).collect(); } diff --git a/core/src/test/java/com/alibaba/alink/operator/batch/dataproc/TypeConvertBatchOpTest.java b/core/src/test/java/com/alibaba/alink/operator/batch/dataproc/TypeConvertBatchOpTest.java index ff37583a2..dd030e7be 100644 --- a/core/src/test/java/com/alibaba/alink/operator/batch/dataproc/TypeConvertBatchOpTest.java +++ b/core/src/test/java/com/alibaba/alink/operator/batch/dataproc/TypeConvertBatchOpTest.java @@ -45,6 +45,7 @@ public void test2() throws Exception { ); BatchOperator mSource = new MemSourceBatchOp(rows, schemaStr); + mSource.print(); BatchOperator typeConvert = new TypeConvertBatchOp() .setTargetType("double") diff --git a/core/src/test/java/com/alibaba/alink/operator/batch/sql/SelectBatchOpTest.java b/core/src/test/java/com/alibaba/alink/operator/batch/sql/SelectBatchOpTest.java index ac1f1454d..3a11fda71 100644 --- a/core/src/test/java/com/alibaba/alink/operator/batch/sql/SelectBatchOpTest.java +++ b/core/src/test/java/com/alibaba/alink/operator/batch/sql/SelectBatchOpTest.java @@ -76,11 +76,27 @@ public void testCSelect4() throws Exception { @Test public void testCSelect5() throws Exception { - String[] originSqlCols = BatchSqlOperators.select(data(), "f_string, f_double, f_string, f_string").getColNames(); + String[] originSqlCols = BatchSqlOperators.select(data(), "f_string, f_double, f_string, f_string") + .getColNames(); String[] simpleSelectCols = data().select("f_string, f_double, f_string, f_string").getColNames(); Assert.assertArrayEquals(originSqlCols, simpleSelectCols); } + @Test + public void test() throws Exception { + List testArray = Arrays.asList( + Row.of("a", 1L, 1, 2.0, true), + Row.of(null, 2L, 2, -3.0, true), + Row.of("c", null, null, 2.0, false), + Row.of("a", 0L, 0, null, null) + ); + + String[] colNames = new String[] {"f_string", "group", "f_lint", "f_double", "f_boolean"}; + + BatchOperator source = new MemSourceBatchOp(testArray, colNames); + source.select("cast(`group` as VARCHAR) as `group`").print(); + } + private BatchOperator data() { List testArray = Arrays.asList( Row.of("a", 1L, 1, 2.0, true), diff --git a/core/src/test/java/com/alibaba/alink/operator/common/sql/SelectUtilsTest.java b/core/src/test/java/com/alibaba/alink/operator/common/sql/SelectUtilsTest.java index 97ffc79dd..64bdfe236 100644 --- a/core/src/test/java/com/alibaba/alink/operator/common/sql/SelectUtilsTest.java +++ b/core/src/test/java/com/alibaba/alink/operator/common/sql/SelectUtilsTest.java @@ -1,5 +1,8 @@ package com.alibaba.alink.operator.common.sql; +import org.apache.flink.api.java.tuple.Tuple2; + +import com.alibaba.alink.common.utils.JsonConverter; import com.alibaba.alink.testutil.AlinkTestBase; import org.junit.Assert; import org.junit.Test; @@ -33,5 +36,18 @@ public void testIsSimpleClause() { Assert.assertTrue(SelectUtils.isSimpleSelect("*, f_double as fr_1", colNames)); } + @Test + public void testSplit() { + String sqlStr = "ts, TIMESTAMPDIFF(DAY, ts, TIMESTAMP '2022-05-11 00:00:00') AS past_days," + + " TIMESTAMPDIFF(WEEK, ts, TIMESTAMP '2022-05-11 00:00:00') AS past_weeks," + + " TIMESTAMPDIFF(MONTH, ts, TIMESTAMP '2022-05-11 00:00:00') AS past_months," + + " TIMESTAMPDIFF(YEAR, ts, TIMESTAMP '2022-05-11 00:00:00') AS past_years"; + + String[] colNames = new String[] {"user_id", "ts", "type"}; + + Tuple2 [] t2 = 
SelectUtils.splitClauseBySimpleClause(sqlStr, colNames); + System.out.println(JsonConverter.toJson(t2)); + } + } diff --git a/core/src/test/java/com/alibaba/alink/operator/local/sql/AllTypeOpTest.java b/core/src/test/java/com/alibaba/alink/operator/local/sql/AllTypeOpTest.java index ea27ffbdd..aaa105ad7 100644 --- a/core/src/test/java/com/alibaba/alink/operator/local/sql/AllTypeOpTest.java +++ b/core/src/test/java/com/alibaba/alink/operator/local/sql/AllTypeOpTest.java @@ -43,29 +43,29 @@ public void test1() { for (int i = 0; i < colNames.length; i++) { newColNames[i] = colNames[i] + "_2"; } - LocalOperator data1 = data.as(newColNames); - - String join_select_clause = getJoinSelectClause(colNames, newColNames); - - for (int i = 0; i < colNames.length; i++) { - System.out.println("\n>>> " + "join on " + colNames[i]); - new JoinLocalOp() - .setJoinPredicate(colNames[i] + "=" + newColNames[i]) - .linkFrom(data, data1) - .print(); - System.out.println("\n>>> " + "left join on " + colNames[i]); - new LeftOuterJoinLocalOp() - .setSelectClause(join_select_clause) - .setJoinPredicate(colNames[i] + "=" + newColNames[i]) - .linkFrom(data, data1) - .print(); - System.out.println("\n>>> " + "right join on " + colNames[i]); - new RightOuterJoinLocalOp() - .setSelectClause(join_select_clause) - .setJoinPredicate(colNames[i] + "=" + newColNames[i]) - .linkFrom(data, data1) - .print(); - } + //LocalOperator data1 = data.as(newColNames); + // + //String join_select_clause = getJoinSelectClause(colNames, newColNames); + // + //for (int i = 0; i < colNames.length; i++) { + // System.out.println("\n>>> " + "join on " + colNames[i]); + // new JoinLocalOp() + // .setJoinPredicate(colNames[i] + "=" + newColNames[i]) + // .linkFrom(data, data1) + // .print(); + // System.out.println("\n>>> " + "left join on " + colNames[i]); + // new LeftOuterJoinLocalOp() + // .setSelectClause(join_select_clause) + // .setJoinPredicate(colNames[i] + "=" + newColNames[i]) + // .linkFrom(data, data1) + // .print(); + // System.out.println("\n>>> " + "right join on " + colNames[i]); + // new RightOuterJoinLocalOp() + // .setSelectClause(join_select_clause) + // .setJoinPredicate(colNames[i] + "=" + newColNames[i]) + // .linkFrom(data, data1) + // .print(); + //} //LocalOperator data2 = new UnionAllLocalOp().linkFrom(data, data); //LocalOperator data3 = new UnionAllLocalOp().linkFrom(data2, data); diff --git a/core/src/test/java/com/alibaba/alink/operator/local/sql/BaseSqlApiLocalOpTest.java b/core/src/test/java/com/alibaba/alink/operator/local/sql/BaseSqlApiLocalOpTest.java index a6055386c..75c980728 100644 --- a/core/src/test/java/com/alibaba/alink/operator/local/sql/BaseSqlApiLocalOpTest.java +++ b/core/src/test/java/com/alibaba/alink/operator/local/sql/BaseSqlApiLocalOpTest.java @@ -8,8 +8,6 @@ import org.junit.Assert; import org.junit.Test; -import java.sql.Timestamp; - public class BaseSqlApiLocalOpTest { Row[] rows = new Row[] { @@ -24,30 +22,16 @@ public class BaseSqlApiLocalOpTest { Row.of("4L", "3L", 10.0), }; - @Test - public void test1() throws Exception { - LocalOperator data = new MemSourceLocalOp(rows, new String[] {"f1", "f2", "f3"}); - LocalOperator data1 = new MemSourceLocalOp(rows1, new String[] {"f1", "f2", "f3"}); - - new FullOuterJoinLocalOp().setJoinPredicate("a.f1=b.f1") - .setSelectClause("case when a.f1 is null then b.f1 when b.f1 is null then a.f1 else b.f1 end as uid, " - + "case when a.f1 is null then b.f3 when b.f1 is null then a.f3 else b.f3 end as factors") - .linkFrom(data, data1).print(); - } @Test 
public void test() { LocalOperator data = new MemSourceLocalOp(rows, new String[] {"f1", "f2", "f3"}); Assert.assertEquals(data.select("f1").getColNames().length, 1); Assert.assertEquals(data.select(new String[] {"f1", "f2"}).getColNames().length, 2); - Assert.assertEquals(new JoinLocalOp().setJoinPredicate("a.f1=b.f1").setSelectClause("a.f1 as f1") - .linkFrom(data, data).getColNames().length, 1); - Assert.assertEquals(new LeftOuterJoinLocalOp().setJoinPredicate("a.f1=b.f1").setSelectClause("a.f1 as f1") - .linkFrom(data, data).getColNames().length, 1); - Assert.assertEquals(new RightOuterJoinLocalOp().setJoinPredicate("a.f1=b.f1").setSelectClause("a.f1 as f1") - .linkFrom(data, data).getColNames().length, 1); - Assert.assertEquals(new FullOuterJoinLocalOp().setJoinPredicate("a.f1=b.f1").setSelectClause("a.f1 as f1") - .linkFrom(data, data).getColNames().length, 1); + Assert.assertEquals(LocalMLEnvironment.getInstance().getSqlExecutor().join(data, data, "a.f1=b.f1", "a.f1 as f1").getColNames().length, 1); + Assert.assertEquals(LocalMLEnvironment.getInstance().getSqlExecutor().leftOuterJoin(data, data, "a.f1=b.f1", "a.f1 as f1").getColNames().length, 1); + Assert.assertEquals(LocalMLEnvironment.getInstance().getSqlExecutor().rightOuterJoin(data, data, "a.f1=b.f1", "a.f1 as f1").getColNames().length, 1); + Assert.assertEquals(LocalMLEnvironment.getInstance().getSqlExecutor().fullOuterJoin(data, data, "a.f1=b.f1", "a.f1 as f1").getColNames().length, 1); Assert.assertEquals(LocalMLEnvironment.getInstance().getSqlExecutor().minus(data,data).getColNames().length,3); //Assert.assertEquals(LocalMLEnvironment.getInstance().getSqlExecutor().minusAll(data,data).getColNames().length,3); Assert.assertEquals(LocalMLEnvironment.getInstance().getSqlExecutor().union(data,data).getColNames().length,3); diff --git a/core/src/test/java/com/alibaba/alink/operator/local/sql/FilterLocalOpTest.java b/core/src/test/java/com/alibaba/alink/operator/local/sql/FilterLocalOpTest.java index cafc05c2e..b9f7e505f 100644 --- a/core/src/test/java/com/alibaba/alink/operator/local/sql/FilterLocalOpTest.java +++ b/core/src/test/java/com/alibaba/alink/operator/local/sql/FilterLocalOpTest.java @@ -21,9 +21,12 @@ public void testFilterLocalOp() { Row.of("Nevada", 2002, 2.9), Row.of("Nevada", 2003, 3.2) ); - LocalOperator batch_data = new TableSourceLocalOp(new MTable(df, "f1 string, f2 int, f3 double")); - LocalOperator op = new FilterLocalOp().setClause("f1='Ohio'"); - batch_data = batch_data.link(op); - batch_data.print(); + LocalOperator source = new TableSourceLocalOp(new MTable(df, "f1 string, f2 int, f3 double")); + //LocalOperator op = new FilterLocalOp().setClause("f1='Ohio'"); + //source = source.link(op); + //source.print(); + + source.link(new FilterLocalOp().setClause("f1='Ohio'")).print(); + source.link(new FilterLocalOp().setClause("f2<=2001")).print(); } } \ No newline at end of file diff --git a/core/src/test/java/com/alibaba/alink/operator/local/sql/FullOuterJoinLocalOpTest.java b/core/src/test/java/com/alibaba/alink/operator/local/sql/FullOuterJoinLocalOpTest.java deleted file mode 100644 index 96b364ac6..000000000 --- a/core/src/test/java/com/alibaba/alink/operator/local/sql/FullOuterJoinLocalOpTest.java +++ /dev/null @@ -1,23 +0,0 @@ -package com.alibaba.alink.operator.local.sql; - -import com.alibaba.alink.operator.batch.source.CsvSourceBatchOp; -import com.alibaba.alink.operator.local.LocalOperator; -import com.alibaba.alink.operator.local.source.TableSourceLocalOp; -import org.junit.Test; - -public class 
diff --git a/core/src/test/java/com/alibaba/alink/operator/local/sql/FilterLocalOpTest.java b/core/src/test/java/com/alibaba/alink/operator/local/sql/FilterLocalOpTest.java
index cafc05c2e..b9f7e505f 100644
--- a/core/src/test/java/com/alibaba/alink/operator/local/sql/FilterLocalOpTest.java
+++ b/core/src/test/java/com/alibaba/alink/operator/local/sql/FilterLocalOpTest.java
@@ -21,9 +21,12 @@ public void testFilterLocalOp() {
             Row.of("Nevada", 2002, 2.9),
             Row.of("Nevada", 2003, 3.2)
         );
-        LocalOperator batch_data = new TableSourceLocalOp(new MTable(df, "f1 string, f2 int, f3 double"));
-        LocalOperator op = new FilterLocalOp().setClause("f1='Ohio'");
-        batch_data = batch_data.link(op);
-        batch_data.print();
+        LocalOperator source = new TableSourceLocalOp(new MTable(df, "f1 string, f2 int, f3 double"));
+        //LocalOperator op = new FilterLocalOp().setClause("f1='Ohio'");
+        //source = source.link(op);
+        //source.print();
+
+        source.link(new FilterLocalOp().setClause("f1='Ohio'")).print();
+        source.link(new FilterLocalOp().setClause("f2<=2001")).print();
     }
 }
\ No newline at end of file
diff --git a/core/src/test/java/com/alibaba/alink/operator/local/sql/FullOuterJoinLocalOpTest.java b/core/src/test/java/com/alibaba/alink/operator/local/sql/FullOuterJoinLocalOpTest.java
deleted file mode 100644
index 96b364ac6..000000000
--- a/core/src/test/java/com/alibaba/alink/operator/local/sql/FullOuterJoinLocalOpTest.java
+++ /dev/null
@@ -1,23 +0,0 @@
-package com.alibaba.alink.operator.local.sql;
-
-import com.alibaba.alink.operator.batch.source.CsvSourceBatchOp;
-import com.alibaba.alink.operator.local.LocalOperator;
-import com.alibaba.alink.operator.local.source.TableSourceLocalOp;
-import org.junit.Test;
-
-public class FullOuterJoinLocalOpTest {
-    @Test
-    public void testFullOuterJoinLocalOp() {
-        String URL = "https://alink-test-data.oss-cn-hangzhou.aliyuncs.com/iris.csv";
-        String SCHEMA_STR
-            = "sepal_length double, sepal_width double, petal_length double, petal_width double, category string";
-        LocalOperator data1 = new TableSourceLocalOp(
-            new CsvSourceBatchOp().setFilePath(URL).setSchemaStr(SCHEMA_STR).collectMTable());
-        LocalOperator data2 = new TableSourceLocalOp(
-            new CsvSourceBatchOp().setFilePath(URL).setSchemaStr(SCHEMA_STR).collectMTable());
-        LocalOperator joinOp =
-            new FullOuterJoinLocalOp().setJoinPredicate("a.category=b.category").setSelectClause(
-                "a.petal_length");
-        joinOp.linkFrom(data1, data2).print();
-    }
-}
\ No newline at end of file
diff --git a/core/src/test/java/com/alibaba/alink/operator/local/sql/JoinLocalOpTest.java b/core/src/test/java/com/alibaba/alink/operator/local/sql/JoinLocalOpTest.java
deleted file mode 100644
index 0a4d5a6fc..000000000
--- a/core/src/test/java/com/alibaba/alink/operator/local/sql/JoinLocalOpTest.java
+++ /dev/null
@@ -1,26 +0,0 @@
-package com.alibaba.alink.operator.local.sql;
-
-import com.alibaba.alink.operator.batch.source.CsvSourceBatchOp;
-import com.alibaba.alink.operator.local.LocalOperator;
-import com.alibaba.alink.operator.local.source.TableSourceLocalOp;
-import org.junit.Test;
-
-public class JoinLocalOpTest {
-    @Test
-    public void testJoinLocalOp() {
-        //String URL = "https://alink-test-data.oss-cn-hangzhou.aliyuncs.com/iris.csv";
-        //String SCHEMA_STR
-        //    = "sepal_length double, sepal_width double, petal_length double, petal_width double, category string";
-        //LocalOperator data1 = new TableSourceLocalOp(
-        //    new CsvSourceBatchOp().setFilePath(URL).setSchemaStr(SCHEMA_STR).collectMTable());
-        //LocalOperator data2 = new TableSourceLocalOp(
-        //    new CsvSourceBatchOp().setFilePath(URL).setSchemaStr(SCHEMA_STR).collectMTable());
-
-        LocalOperator data1 = IrisData.getLocalSourceOp();
-        LocalOperator data2 = IrisData.getLocalSourceOp();
-
-        LocalOperator joinOp = new JoinLocalOp().setJoinPredicate("a.category=b.category").setSelectClause(
-            "a.petal_length");
-        joinOp.linkFrom(data1, data2).print();
-    }
-}
\ No newline at end of file
diff --git a/core/src/test/java/com/alibaba/alink/operator/local/sql/LeftOuterJoinLocalOpTest.java b/core/src/test/java/com/alibaba/alink/operator/local/sql/LeftOuterJoinLocalOpTest.java
deleted file mode 100644
index e5cfa025e..000000000
--- a/core/src/test/java/com/alibaba/alink/operator/local/sql/LeftOuterJoinLocalOpTest.java
+++ /dev/null
@@ -1,27 +0,0 @@
-package com.alibaba.alink.operator.local.sql;
-
-import com.alibaba.alink.operator.batch.source.CsvSourceBatchOp;
-import com.alibaba.alink.operator.local.LocalOperator;
-import com.alibaba.alink.operator.local.source.TableSourceLocalOp;
-import org.junit.Test;
-
-public class LeftOuterJoinLocalOpTest {
-    @Test
-    public void testLeftOuterJoinLocalOp() {
-        //String URL = "https://alink-test-data.oss-cn-hangzhou.aliyuncs.com/iris.csv";
-        //String SCHEMA_STR
-        //    = "sepal_length double, sepal_width double, petal_length double, petal_width double, category string";
-        //LocalOperator data1 = new TableSourceLocalOp(
-        //    new CsvSourceBatchOp().setFilePath(URL).setSchemaStr(SCHEMA_STR).collectMTable());
-        //LocalOperator data2 = new TableSourceLocalOp(
-        //    new CsvSourceBatchOp().setFilePath(URL).setSchemaStr(SCHEMA_STR).collectMTable());
-
-        LocalOperator data1 = IrisData.getLocalSourceOp();
-        LocalOperator data2 = IrisData.getLocalSourceOp();
-
-        LocalOperator joinOp =
-            new LeftOuterJoinLocalOp().setJoinPredicate("a.category=b.category").setSelectClause(
-                "a.petal_length");
-        joinOp.linkFrom(data1, data2).print();
-    }
-}
\ No newline at end of file
diff --git a/core/src/test/java/com/alibaba/alink/operator/local/sql/RightOuterJoinLocalOpTest.java b/core/src/test/java/com/alibaba/alink/operator/local/sql/RightOuterJoinLocalOpTest.java
deleted file mode 100644
index 8a3d85580..000000000
--- a/core/src/test/java/com/alibaba/alink/operator/local/sql/RightOuterJoinLocalOpTest.java
+++ /dev/null
@@ -1,24 +0,0 @@
-package com.alibaba.alink.operator.local.sql;
-
-import com.alibaba.alink.operator.local.LocalOperator;
-import org.junit.Test;
-
-public class RightOuterJoinLocalOpTest {
-    @Test
-    public void testRightOuterJoinLocalOp() {
-        //String URL = "https://alink-test-data.oss-cn-hangzhou.aliyuncs.com/iris.csv";
-        //String SCHEMA_STR
-        //    = "sepal_length double, sepal_width double, petal_length double, petal_width double, category string";
-        //LocalOperator data1 = new TableSourceLocalOp(
-        //    new CsvSourceBatchOp().setFilePath(URL).setSchemaStr(SCHEMA_STR).collectMTable());
-        //LocalOperator data2 = new TableSourceLocalOp(
-        //    new CsvSourceBatchOp().setFilePath(URL).setSchemaStr(SCHEMA_STR).collectMTable());
-
-        LocalOperator data1 = IrisData.getLocalSourceOp();
-        LocalOperator data2 = IrisData.getLocalSourceOp();
-
-        LocalOperator joinOp = new RightOuterJoinLocalOp().setJoinPredicate("a.category=b.category")
-            .setSelectClause("a.petal_length");
-        joinOp.linkFrom(data1, data2).printStatistics().print();
-    }
-}
\ No newline at end of file
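All four join-specific local tests are deleted along with their operators. The full-outer-join case from the removed test1 of BaseSqlApiLocalOpTest, with its CASE-based select clause, can presumably be re-expressed through the same executor; a hedged sketch reusing that test's data and data1 sources:

    // Sketch only: the CASE clause is taken verbatim from the deleted test1.
    LocalMLEnvironment.getInstance().getSqlExecutor().fullOuterJoin(
        data, data1, "a.f1=b.f1",
        "case when a.f1 is null then b.f1 when b.f1 is null then a.f1 else b.f1 end as uid, "
            + "case when a.f1 is null then b.f3 when b.f1 is null then a.f3 else b.f3 end as factors")
        .print();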
diff --git a/core/src/test/java/com/alibaba/alink/operator/local/sql/SelectLocalOpTest.java b/core/src/test/java/com/alibaba/alink/operator/local/sql/SelectLocalOpTest.java
index 40f4e7036..48b8f6cea 100644
--- a/core/src/test/java/com/alibaba/alink/operator/local/sql/SelectLocalOpTest.java
+++ b/core/src/test/java/com/alibaba/alink/operator/local/sql/SelectLocalOpTest.java
@@ -1,15 +1,15 @@
 package com.alibaba.alink.operator.local.sql;
 
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.table.api.TableSchema;
 import org.apache.flink.types.Row;
 
-import com.alibaba.alink.operator.batch.BatchOperator;
-import com.alibaba.alink.operator.batch.source.CsvSourceBatchOp;
-import com.alibaba.alink.operator.batch.source.MemSourceBatchOp;
-import com.alibaba.alink.operator.batch.sql.BatchSqlOperators;
-import com.alibaba.alink.operator.batch.sql.SelectBatchOp;
+import com.alibaba.alink.common.linalg.VectorUtil;
+import com.alibaba.alink.common.type.AlinkTypes;
 import com.alibaba.alink.operator.local.LocalOperator;
 import com.alibaba.alink.operator.local.source.MemSourceLocalOp;
-import com.alibaba.alink.operator.local.source.TableSourceLocalOp;
+import com.alibaba.alink.pipeline.Pipeline;
+import com.alibaba.alink.pipeline.sql.Select;
 import org.junit.Assert;
 import org.junit.Test;
 
@@ -17,6 +17,90 @@
 import java.util.List;
 
 public class SelectLocalOpTest {
+
+    @Test
+    public void testVector() {
+        Row[] array = new Row[] {
+            Row.of(VectorUtil.getVector("$31$0:1.0 1:1.0 2:1.0 30:1.0"), "1.0 1.0 1.0 1.0", 1.0, 1.0, 1.0, 1.0, 1),
+            Row.of(VectorUtil.getVector("$31$0:1.0 1:1.0 2:0.0 30:1.0"), "1.0 1.0 0.0 1.0", 1.0, 1.0, 0.0, 1.0, 1),
+            Row.of(VectorUtil.getVector("$31$0:1.0 1:0.0 2:1.0 30:1.0"), "1.0 0.0 1.0 1.0", 1.0, 0.0, 1.0, 1.0, 1),
+            Row.of(VectorUtil.getVector("$31$0:1.0 1:0.0 2:1.0 30:1.0"), "1.0 0.0 1.0 1.0", 1.0, 0.0, 1.0, 1.0, 1),
+            Row.of(VectorUtil.getVector("$31$0:0.0 1:1.0 2:1.0 30:0.0"), "0.0 1.0 1.0 0.0", 0.0, 1.0, 1.0, 0.0, 0),
+            Row.of(VectorUtil.getVector("$31$0:0.0 1:1.0 2:1.0 30:0.0"), "0.0 1.0 1.0 0.0", 0.0, 1.0, 1.0, 0.0, 0),
+            Row.of(VectorUtil.getVector("$31$0:0.0 1:1.0 2:1.0 30:0.0"), "0.0 1.0 1.0 0.0", 0.0, 1.0, 1.0, 0.0, 0),
+            Row.of(VectorUtil.getVector("$31$0:0.0 1:1.0 2:1.0 30:0.0"), "0.0 1.0 1.0 0.0", 0.0, 1.0, 1.0, 0.0, 0)
+        };
+
+        LocalOperator source = new MemSourceLocalOp(
+            Arrays.asList(array),
+            new TableSchema(
+                new String[] {"svec", "vec", "f0", "f1", "f2", "f3", "labels"},
+                new TypeInformation<?>[] {
+                    AlinkTypes.SPARSE_VECTOR,
+                    AlinkTypes.DENSE_VECTOR,
+                    AlinkTypes.DOUBLE,
+                    AlinkTypes.DOUBLE,
+                    AlinkTypes.DOUBLE,
+                    AlinkTypes.DOUBLE,
+                    AlinkTypes.INT
+                }));
+
+        source.select("svec, f0, 'neg' as f11")
+            .print("********** simple ********");
+        source.select("svec, f0, (f1 + cast(0.5 as double)) as f11")
+            .print("********** test ********");
+        source.select("CAST(f0 AS VARCHAR) AS f0_str, CAST(f1 as VARCHAR) AS f1_str, svec, vec, f0, f1")
+            .select("f0,f1,concat(f0_str, ',') as f0_f1")
+            .print("********** test for , ********");
+
+    }
+
+    @Test
+    public void testPipeline() {
+        Row[] array = new Row[] {
+            Row.of(VectorUtil.getVector("$31$0:1.0 1:1.0 2:1.0 30:1.0"), "1.0 1.0 1.0 1.0", 1.0, 1.0, 1.0, 1.0, 1),
+            Row.of(VectorUtil.getVector("$31$0:1.0 1:1.0 2:0.0 30:1.0"), "1.0 1.0 0.0 1.0", 1.0, 1.0, 0.0, 1.0, 1),
+            Row.of(VectorUtil.getVector("$31$0:1.0 1:0.0 2:1.0 30:1.0"), "1.0 0.0 1.0 1.0", 1.0, 0.0, 1.0, 1.0, 1),
+            Row.of(VectorUtil.getVector("$31$0:1.0 1:0.0 2:1.0 30:1.0"), "1.0 0.0 1.0 1.0", 1.0, 0.0, 1.0, 1.0, 1),
+            Row.of(VectorUtil.getVector("$31$0:0.0 1:1.0 2:1.0 30:0.0"), "0.0 1.0 1.0 0.0", 0.0, 1.0, 1.0, 0.0, 0),
+            Row.of(VectorUtil.getVector("$31$0:0.0 1:1.0 2:1.0 30:0.0"), "0.0 1.0 1.0 0.0", 0.0, 1.0, 1.0, 0.0, 0),
+            Row.of(VectorUtil.getVector("$31$0:0.0 1:1.0 2:1.0 30:0.0"), "0.0 1.0 1.0 0.0", 0.0, 1.0, 1.0, 0.0, 0),
+            Row.of(VectorUtil.getVector("$31$0:0.0 1:1.0 2:1.0 30:0.0"), "0.0 1.0 1.0 0.0", 0.0, 1.0, 1.0, 0.0, 0)
+        };
+
+        LocalOperator source = new MemSourceLocalOp(
+            Arrays.asList(array),
+            new TableSchema(
+                new String[] {"svec", "vec", "f0", "f1", "f2", "f3", "labels"},
+                new TypeInformation<?>[] {
+                    AlinkTypes.SPARSE_VECTOR,
+                    AlinkTypes.DENSE_VECTOR,
+                    AlinkTypes.DOUBLE,
+                    AlinkTypes.DOUBLE,
+                    AlinkTypes.DOUBLE,
+                    AlinkTypes.DOUBLE,
+                    AlinkTypes.INT
+                }));
+
+        Pipeline pipeline = new Pipeline()
+            .add(new Select().setClause("svec, f0, 'neg' as f11"));
+
+        pipeline.fit(source).transform(source).print();
+
+        Pipeline pipeline2 = new Pipeline()
+            .add(new Select().setClause("svec, f0, (f1 + cast(0.5 as double)) as f11"));
+
+        pipeline2.fit(source).transform(source).print();
+
+        Pipeline pipeline3 = new Pipeline()
+            .add(new Select().setClause(
+                "CAST(f0 AS VARCHAR) AS f0_str, CAST(f1 as VARCHAR) AS f1_str, svec, vec, f0, f1"))
+            .add(new Select().setClause("f0,f1,concat(f0_str, ',') as f0_f1"))
+            ;
+
+        pipeline3.fit(source).transform(source).print();
+    }
+
     @Test
     public void testSelectLocalOp() {
         //String URL = "https://alink-test-data.oss-cn-hangzhou.aliyuncs.com/iris.csv";
@@ -28,7 +112,6 @@ public void testSelectLocalOp() {
         data.link(new SelectLocalOp().setClause("category as a")).print();
     }
 
-
     @Test
     public void testSimpleSelect() throws Exception {
         data().link(
@@ -42,32 +125,11 @@ public void testSimpleSelect2() throws Exception {
         data().select("f_double, f_long").print();
     }
 
-    //@Test
-    //public void testSelect() throws Exception {
-    //    data().link(
-    //        new SelectLocalOp()
-    //            .setClause("f_double, `f_l.*`")
-    //    ).print();
-    //}
-    //
-    //@Test
-    //public void testSelect2() throws Exception {
-    //    data().select("f_double, `f_l.*`").print();
-    //}
-    //
-    //@Test
-    //public void testSelect3() throws Exception {
-    //    data().link(
-    //        new SelectLocalOp()
-    //            .setClause("`f_d.*`, `f_l.*`")
-    //    ).print();
-    //}
-
     @Test
     public void testCSelect() throws Exception {
         data().link(
             new SelectLocalOp()
-                .setClause("f_double, `f_l.*`, f_double+1 as f_double_1")
+                .setClause("f_double,`f_l.*`, f_double+1 as f_double_1")
         ).print();
     }
 
@@ -105,6 +167,7 @@ private LocalOperator data() {
             Row.of("a", 0L, 0, null, null)
         );
 
+        // test data; the last two columns deliberately contain nulls
         String[] colNames = new String[] {"f_string", "f_long", "f_lint", "f_double", "f_boolean"};
 
         return new MemSourceLocalOp(testArray, colNames);
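The backquoted fragments in these clauses are column-name patterns rather than literal identifiers; `f_l.*` presumably expands to every column whose name matches the regex, here f_long and f_lint. A hedged sketch of that reading, reusing the test's data() helper (this mirrors the commented-out testSelect3 above):

    // Assumed semantics: each backquoted pattern expands to its matching columns.
    data().link(
        new SelectLocalOp()
            .setClause("`f_d.*`, `f_l.*`")  // presumably f_double, f_long, f_lint
    ).print();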
00:06:00"), 40), -// Row.of(6, "user2", Timestamp.valueOf("2021-01-01 00:06:00"), 20), -// Row.of(7, "user2", Timestamp.valueOf("2021-01-01 00:07:00"), 70), -// Row.of(8, "user1", Timestamp.valueOf("2021-01-01 00:08:00"), 80), -// Row.of(9, "user1", Timestamp.valueOf("2021-01-01 00:09:00"), 40), -// Row.of(10, "user1", Timestamp.valueOf("2021-01-01 00:10:00"), 20), -// Row.of(11, "user1", Timestamp.valueOf("2021-01-01 00:11:00"), 30), -// Row.of(12, "user1", Timestamp.valueOf("2021-01-01 00:11:00"), 50) -// }, -// colNames -// ); -// -// source -// .link( -// new ColumnsToTripleBatchOp() -// .setSelectedCols(colNames) -// .setTripleColumnValueSchemaStr("colName string, val string") -// ) -// .groupBy("colName", "colName, COUNT(DISTINCT val) AS valCount") -// .lazyPrint(100); -// -// source -// .link( -// new ColumnsToTripleBatchOp() -// .setSelectedCols(colNames) -// .setTripleColumnValueSchemaStr("colName string, val string") -// ) -// .groupBy("colName, val", "colName, val, COUNT(val) AS cnt") -// .link( -// new OverWindowBatchOp() -// .setOrderBy("cnt desc") -// .setGroupCols("colName") -// .setClause("ROW_NUMBER(cnt) AS rnk") -// .setReservedCols("colName","val", "cnt") -// ) -// .filter("rnk<=2") -// .print(100); -// -// } -//} \ No newline at end of file +package com.alibaba.alink.operator.local.statistics; + +import org.apache.flink.types.Row; + +import com.alibaba.alink.operator.local.LocalOperator; +import com.alibaba.alink.operator.local.source.MemSourceLocalOp; +import junit.framework.TestCase; +import org.junit.Ignore; +import org.junit.Test; + +import java.sql.Timestamp; +import java.util.Arrays; + +public class SummarizerLocalOpTest extends TestCase { + + @Test + @Ignore + public void test2() throws Exception { + Row[] testArray = + new Row[] { + Row.of("a", 1, 1.1, 1.2), + Row.of("b", null, 0.9, 1.0), + Row.of("c", 100, -0.01, 1.0), + Row.of("d", -99, 100.9, 0.1), + Row.of("a", 1, 1.1, 1.2), + Row.of("b", null, 0.9, 1.0), + Row.of("c", null, -0.01, 0.2), + Row.of("d", -99, 100.9, 0.3) + }; + + String[] colNames = new String[] {"col1", "col2", "col3", "col4"}; + + new MemSourceLocalOp(Arrays.asList(testArray), colNames) + .lazyPrintStatistics() + .lazyVizStatistics(); + LocalOperator.execute(); + } + + @Test + @Ignore + public void test3() throws Exception { + String[] colNames = new String[] {"id", "user", "sell_time", "price"}; + + MemSourceLocalOp source = new MemSourceLocalOp( + new Row[] { + Row.of(1, "user2", Timestamp.valueOf("2021-01-01 00:01:00"), 20), + Row.of(2, "user1", Timestamp.valueOf("2021-01-01 00:02:00"), 50), + Row.of(3, "user2", Timestamp.valueOf("2021-01-01 00:03:00"), 30), + Row.of(4, "user1", Timestamp.valueOf("2021-01-01 00:06:00"), 60), + Row.of(5, "user2", Timestamp.valueOf("2021-01-01 00:06:00"), 40), + Row.of(6, "user2", Timestamp.valueOf("2021-01-01 00:06:00"), 20), + Row.of(7, "user2", Timestamp.valueOf("2021-01-01 00:07:00"), 70), + Row.of(8, "user1", Timestamp.valueOf("2021-01-01 00:08:00"), 80), + Row.of(9, "user1", Timestamp.valueOf("2021-01-01 00:09:00"), 40), + Row.of(10, "user1", Timestamp.valueOf("2021-01-01 00:10:00"), 20), + Row.of(11, "user1", Timestamp.valueOf("2021-01-01 00:11:00"), 30), + Row.of(12, "user1", Timestamp.valueOf("2021-01-01 00:11:00"), 50) + }, + colNames + ); + + source + .lazyPrintStatistics() + .lazyVizStatistics(); + LocalOperator.execute(); + + } +} \ No newline at end of file diff --git a/core/src/test/java/com/alibaba/alink/operator/stream/dataproc/TypeConvertStreamOpTest.java 
diff --git a/core/src/test/java/com/alibaba/alink/operator/stream/dataproc/TypeConvertStreamOpTest.java b/core/src/test/java/com/alibaba/alink/operator/stream/dataproc/TypeConvertStreamOpTest.java
index 33c4a5ea7..23632bb20 100644
--- a/core/src/test/java/com/alibaba/alink/operator/stream/dataproc/TypeConvertStreamOpTest.java
+++ b/core/src/test/java/com/alibaba/alink/operator/stream/dataproc/TypeConvertStreamOpTest.java
@@ -54,6 +54,6 @@ public void test() throws Exception {
         ret1.orderBy(0);
 
         Assert.assertEquals(Row.of(1L, 1L, 1L, 1L, 1L), ret.getRows().get(0));
-        Assert.assertEquals(Row.of("1", "1", "1.1", "1.0", "true"), ret1.getRows().get(0));
+        Assert.assertEquals(Row.of("1", "1", "1.1", "1", "TRUE"), ret1.getRows().get(0));
     }
 }
diff --git a/core/src/test/java/com/alibaba/alink/pipeline/PipelinePredictBatchOpTest.java b/core/src/test/java/com/alibaba/alink/pipeline/PipelinePredictBatchOpTest.java
deleted file mode 100644
index 95c231d06..000000000
--- a/core/src/test/java/com/alibaba/alink/pipeline/PipelinePredictBatchOpTest.java
+++ /dev/null
@@ -1,71 +0,0 @@
-package com.alibaba.alink.pipeline;
-
-import org.apache.flink.types.Row;
-
-import com.alibaba.alink.operator.batch.BatchOperator;
-import com.alibaba.alink.operator.batch.PipelinePredictBatchOp;
-import com.alibaba.alink.operator.batch.sink.AkSinkBatchOp;
-import com.alibaba.alink.operator.batch.source.MemSourceBatchOp;
-import com.alibaba.alink.pipeline.classification.LogisticRegression;
-import com.alibaba.alink.pipeline.dataproc.JsonValue;
-import com.alibaba.alink.pipeline.dataproc.vector.VectorAssembler;
-import com.alibaba.alink.pipeline.feature.MultiHotEncoder;
-import com.alibaba.alink.pipeline.feature.OneHotEncoder;
-import com.alibaba.alink.testutil.AlinkTestBase;
-import org.junit.Test;
-
-import java.util.Arrays;
-
-public class PipelinePredictBatchOpTest extends AlinkTestBase {
-
-    @Test
-    public void test() throws Exception {
-
-        Row[] trainArray = new Row[] {
-            Row.of("u0", "1.0 1.0", 1.0, 1.0, 1, 18),
-            Row.of("u1", "1.0 1.0", 1.0, 1.0, 0, 19),
-            Row.of("u2", "1.0 0.0", 1.0, 0.0, 1, 88),
-            Row.of("u3", "1.0 0.0", 1.0, 0.0, 1, 18),
-            Row.of("u4", "0.0 1.0", 0.0, 1.0, 1, 88),
-            Row.of("u5", "0.0 1.0", 0.0, 1.0, 1, 19),
-            Row.of("u6", "0.0 1.0", 0.0, 1.0, 1, 88)
-        };
-        BatchOperator trainData = new MemSourceBatchOp(Arrays.asList(trainArray),
-            new String[] {"uid", "uf", "f0", "f1", "labels", "iid"});
-
-        String[] oneHotCols = new String[] {"uid", "f0", "f1", "iid"};
-        String[] multiHotCols = new String[] {"uf"};
-
-        Pipeline pipe = new Pipeline()
-            .add(
-                new OneHotEncoder()
-                    .setSelectedCols(oneHotCols)
-                    .setOutputCols("ovec"))
-            .add(
-                new MultiHotEncoder().setDelimiter(" ")
-                    .setSelectedCols(multiHotCols)
-                    .setOutputCols("mvec"))
-            .add(
-                new VectorAssembler()
-                    .setSelectedCols("ovec", "mvec")
-                    .setOutputCol("vec"))
-            .add(
-                new LogisticRegression()
-                    .setVectorCol("vec")
-                    .setLabelCol("labels")
-                    .setReservedCols("uid", "iid")
-                    .setPredictionDetailCol("detail")
-                    .setPredictionCol("pred"))
-            .add(
-                new JsonValue()
-                    .setSelectedCol("detail")
-                    .setJsonPath("$.1")
-                    .setOutputCols("score"));
-        BatchOperator model = pipe.fit(trainData).save();
-        String path = "/tmp/pipeline_predict_batch_op_test.ak";
-        model.link(new AkSinkBatchOp().setFilePath(path).setOverwriteSink(true));
-        BatchOperator.execute();
-
-        new PipelinePredictBatchOp().setModelFilePath(path).linkFrom(trainData).print();
-    }
-}
\ No newline at end of file
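PipelinePredictBatchOp and its test are removed without a replacement in this patch. If the saved .ak model still needs to be applied in batch, loading it back as a PipelineModel should presumably work; a hedged sketch reusing the deleted test's path and trainData (PipelineModel.load taking a file path is an assumption here):

    // Sketch only: assumes PipelineModel.load(path) reads the .ak file saved above.
    PipelineModel model = PipelineModel.load("/tmp/pipeline_predict_batch_op_test.ak");
    model.transform(trainData).print();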