Skip to content

Commit

Permalink
Merge the dev branch.
Browse files Browse the repository at this point in the history
  • Loading branch information
shaomengwang committed Apr 26, 2024
1 parent 76a9aed commit f53f07b
Show file tree
Hide file tree
Showing 13 changed files with 266 additions and 80 deletions.
Original file line number Diff line number Diff line change
@@ -1,8 +1,18 @@
package com.alibaba.alink.common.insights;

import org.apache.flink.api.java.tuple.Tuple3;

import com.alibaba.alink.common.MTable;
import com.alibaba.alink.operator.local.LocalOperator;

import org.apache.commons.math3.fitting.PolynomialCurveFitter;
import org.apache.commons.math3.fitting.WeightedObservedPoints;
import org.apache.commons.math3.stat.correlation.SpearmansCorrelation;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map.Entry;

public class CorrelationInsight extends CorrelationInsightBase {

Expand All @@ -22,29 +32,92 @@ public Insight processData(LocalOperator <?>... sources) {
sourceInput[1] = groupData(sourceInput[1], insight.subject).get(0);
}
insight.score = computeScore(sourceInput);
this.fillLayout();
return this.insight;
}

@Override
public void fillLayout() {
public void fillLayout(double score) {
String correlation = "正相关";
if (score < 0) {
correlation = "负相关";
}
List <Measure> measures = this.insight.subject.measures;
this.insight.layout.xAxis = this.insight.subject.breakdown.colName;
this.insight.layout.yAxis = measures.get(0).aggr + "(" + measures.get(0).colName + ")";
this.insight.layout.lineA = insight.getSubspaceStr(insight.subject.subspaces);
this.insight.layout.lineB = insight.getSubspaceStr(insight.attachSubspaces);
this.insight.layout.title = "子集的统计指标 " +
String.format("%s的%s", measures.get(0).colName, measures.get(0).aggr.getCnName())
+ " 存在相关性";
+ " 存在" + correlation;
StringBuilder builder = new StringBuilder();
builder.append(insight.layout.lineA).append(" 与 ").append(insight.layout.lineB).append(" 条件下,");
builder.append("统计指标 ")
.append(String.format("%s的%s", measures.get(0).colName, measures.get(0).aggr.getCnName()))
.append(" 存在相关性。");
.append(" 存在")
.append(correlation);
//if (this.range.intValue() > MAX_SCALAR_THRESHOLD.intValue()) {
// builder.append("*由于二者数值范围差异较大,对第二条线进行了缩放");
//}
this.insight.layout.description = builder.toString();
}

/**
 * Scores how strongly the measure values of two sources are correlated over their shared keys.
 *
 * <p>Both sources are turned into key-to-value maps via {@code initData}; only keys present in both
 * maps contribute a point. The method returns 0 when:
 * <ul>
 *   <li>fewer than {@code MIN_SAMPLE_NUM} shared keys exist,</li>
 *   <li>either series is constant (no spread on x or on y),</li>
 *   <li>the linear-fit quality check ({@code scoreA}) is negative, or</li>
 *   <li>the absolute Spearman correlation is below {@code MIN_CORRELATION_THRESHOLD}.</li>
 * </ul>
 * Otherwise the absolute Spearman correlation is returned and, as a side effect, the merged points
 * are stored in {@code insight.layout.data} and {@code fillLayout} is called with the fitted slope
 * so that the layout text can reflect the sign of the correlation.
 *
 * @param sources two operators; {@code sources[0]} supplies the x series and {@code sources[1]} the y series
 * @return the absolute Spearman correlation, or 0 if any of the conditions above applies
 */
public double computeScore(LocalOperator <?>... sources) {
    HashMap <Object, Number> meaValues1 = initData(sources[0]);
    HashMap <Object, Number> meaValues2 = initData(sources[1]);

    // Pair up the values of keys that occur in both sources.
    List <Tuple3 <Number, Number, Object>> points = new ArrayList <>();
    for (Entry <Object, Number> entry : meaValues1.entrySet()) {
        if (!meaValues2.containsKey(entry.getKey())) {
            continue;
        }
        points.add(Tuple3.of(entry.getValue(), meaValues2.get(entry.getKey()), entry.getKey()));
    }
    if (points.size() < MIN_SAMPLE_NUM) {
        return 0;
    }

    int n = points.size();
    double[] xArray = new double[n];
    double[] yArray = new double[n];
    // Double.MIN_VALUE is the smallest POSITIVE double, so it must not seed a running maximum:
    // with all-negative data the maximum would stay at ~4.9e-324. Use the infinities instead.
    double maxX = Double.NEGATIVE_INFINITY;
    double maxY = Double.NEGATIVE_INFINITY;
    double minX = Double.POSITIVE_INFINITY;
    double minY = Double.POSITIVE_INFINITY;
    for (int i = 0; i < n; i++) {
        xArray[i] = points.get(i).f0.doubleValue();
        yArray[i] = points.get(i).f1.doubleValue();
        maxX = Math.max(maxX, xArray[i]);
        maxY = Math.max(maxY, yArray[i]);
        minX = Math.min(minX, xArray[i]);
        minY = Math.min(minY, yArray[i]);
    }
    // A constant series has no defined correlation.
    if (maxX - minX == 0 || maxY - minY == 0) {
        return 0;
    }

    // Fit y = params[0] + params[1] * x (degree-1 polynomial).
    WeightedObservedPoints weightedObservedPoints = new WeightedObservedPoints();
    for (int i = 0; i < n; i++) {
        weightedObservedPoints.add(xArray[i], yArray[i]);
    }
    PolynomialCurveFitter polynomialCurveFitter = PolynomialCurveFitter.create(1);
    double[] params = polynomialCurveFitter.fit(weightedObservedPoints.toList());

    // Sum of squared residuals of the linear fit.
    double residualSquareSum = 0.0;
    for (int i = 0; i < n; i++) {
        residualSquareSum += Math.pow(params[0] + params[1] * xArray[i] - yArray[i], 2);
    }
    // Fit quality relative to the spread of y; only used as a gate, not as part of the returned score.
    double scoreA = 1 - Math.sqrt(residualSquareSum) / ((maxY - minY) * n);
    if (scoreA < 0) {
        return 0;
    }

    SpearmansCorrelation sc = new SpearmansCorrelation();
    double score = Math.abs(sc.correlation(xArray, yArray));
    if (score >= MIN_CORRELATION_THRESHOLD) {
        MTable mtable = mergeData(points, sources[0].getSchema(), sources[1].getSchema());
        insight.layout.data = mtable;
        // The sign of the fitted slope decides between positive and negative correlation wording.
        this.fillLayout(params[1]);
    } else {
        score = 0;
    }
    return score;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,6 @@ public HashMap <Object, Number> initData(LocalOperator <?> source) {

public abstract Insight processData(LocalOperator <?>... sources);

public abstract void fillLayout();

@Override
public String toString() {
return insight.toString();
Expand All @@ -99,48 +97,5 @@ public String toString() {
https://www.microsoft.com/en-us/research/uploads/prod/2016/12/Insight-Types-Specification.pdf Significance of
Correlation
*/
public double computeScore(LocalOperator <?>... sources) {
//String[] columns = new String[] {insight.subject.breakdown.colName, MEASURE_NAME_PREFIX + "0"};
HashMap <Object, Number> meaValues1 = initData(sources[0]);
HashMap <Object, Number> meaValues2 = initData(sources[1]);
List <Tuple3 <Number, Number, Object>> points = new ArrayList <>();
for (Entry <Object, Number> entry : meaValues1.entrySet()) {
if (!meaValues2.containsKey(entry.getKey())) {
continue;
}
points.add(Tuple3.of(entry.getValue(), meaValues2.get(entry.getKey()), entry.getKey()));
}
if (points.size() < MIN_SAMPLE_NUM) {
return 0;
}
double[] xArray = new double[points.size()];
double[] yArray = new double[points.size()];
double maxX = 0;
double maxY = 0;
for (int i = 0; i < points.size(); i++) {
xArray[i] = points.get(i).f0.doubleValue();
yArray[i] = points.get(i).f1.doubleValue();
maxX = Math.max(maxX, Math.abs(xArray[i]));
maxY = Math.max(maxY, Math.abs(yArray[i]));
}
if (maxX == 0 || maxY == 0) {
return 0;
}
if (maxX > maxY) {
range = Math.round(maxX / maxY);
} else {
range = Math.round(maxY / maxX);
}

//PearsonsCorrelation pc = new PearsonsCorrelation();
SpearmansCorrelation sc = new SpearmansCorrelation();
double score = Math.abs(sc.correlation(xArray, yArray));
if (score >= MIN_CORRELATION_THRESHOLD) {
MTable mtable = mergeData(points, sources[0].getSchema(), sources[1].getSchema());
insight.layout.data = mtable;
} else {
score = 0;
}
return score;
}
public abstract double computeScore(LocalOperator <?>... sources);
}
Original file line number Diff line number Diff line change
Expand Up @@ -52,25 +52,28 @@ public Insight processData(LocalOperator <?>... sources) {
}
LocalOperator <?>[] sourceInput = preprocess(sources);
insight.score = computeScore(sourceInput);
this.fillLayout();
return insight;
}

@Override
public void fillLayout() {
public void fillLayout(double score) {
String correlation = "正相关";
if (score < 0) {
correlation = "负相关";
}
List<Measure> measures = this.insight.subject.measures;
this.insight.layout.xAxis = measures.get(0).aggr + "(" + measures.get(0).colName + ")";
this.insight.layout.yAxis = measures.get(1).aggr + "(" + measures.get(1).colName + ")";;
this.insight.layout.yAxis = measures.get(1).aggr + "(" + measures.get(1).colName + ")";
this.insight.layout.title = String.format("%s的%s", measures.get(0).colName, measures.get(0).aggr.getCnName())
+ " 和 " + String.format("%s的%s", measures.get(1).colName, measures.get(1).aggr.getCnName()) + " 存在相关性";
+ " 和 " + String.format("%s的%s", measures.get(1).colName, measures.get(1).aggr.getCnName()) + " 存在" + correlation;
StringBuilder builder = new StringBuilder();
if (null != insight.subject.subspaces && !insight.subject.subspaces.isEmpty()) {
builder.append(insight.getSubspaceStr(insight.subject.subspaces)).append(" 条件下,");
}
builder.append(String.format("%s的%s", measures.get(0).colName, measures.get(0).aggr.getCnName()))
.append(" 与 ")
.append(String.format("%s的%s", measures.get(1).colName, measures.get(1).aggr.getCnName()))
.append(" 存在相关性");
.append(" 存在")
.append(correlation);
this.insight.layout.description = builder.toString();
}

Expand All @@ -92,12 +95,18 @@ public double computeScore(LocalOperator <?>... sources) {
double[] yArray = new double[points.size()];
double maxY = Double.MIN_VALUE;
double minY = Double.MAX_VALUE;

double maxX = Double.MIN_VALUE;
double minX = Double.MAX_VALUE;
for (int i = 0; i < points.size(); i++) {
xArray[i] = points.get(i).f0.doubleValue();
yArray[i] = points.get(i).f1.doubleValue();
maxY = Math.max(maxY, yArray[i]);
minY = Math.min(minY, yArray[i]);
maxX = Math.max(maxX, xArray[i]);
minX = Math.min(minX, xArray[i]);
}
if (maxX - minX == 0 || maxY - minY == 0) {
return 0;
}
WeightedObservedPoints weightedObservedPoints = new WeightedObservedPoints();
for (int i = 0; i < points.size(); i++) {
Expand All @@ -115,12 +124,13 @@ public double computeScore(LocalOperator <?>... sources) {
}
//PearsonsCorrelation pc = new PearsonsCorrelation();
SpearmansCorrelation sc = new SpearmansCorrelation();
double scoreB = Math.abs(sc.correlation(xArray, yArray));
double score = (scoreA + scoreB) / 2;
double scoreB = sc.correlation(xArray, yArray);
double score = (scoreA + Math.abs(scoreB)) / 2;
if (score >= MIN_CORRELATION_THRESHOLD) {
MTable mtable = mergeData(points, sources[0].getSchema(), sources[1].getSchema());
insight.layout.data = mtable;
insight.score = score;
this.fillLayout(params[1]);
} else {
score = 0;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ public ScatterplotClusteringInsight(Insight insight) {
super(insight);
}

@Override
public void fillLayout() {
List<Measure> measures = this.insight.subject.measures;
this.insight.layout.xAxis = String.format("%s(%s)", measures.get(0).aggr, measures.get(0).colName);
Expand All @@ -43,6 +42,9 @@ public void fillLayout() {
if (null != insight.subject.subspaces && !insight.subject.subspaces.isEmpty()) {
builder.append(insight.getSubspaceStr(insight.subject.subspaces)).append(" 条件下,");
}
if (null != insight.subject.breakdown) {
builder.append("按照列").append(insight.subject.breakdown.colName).append("维度统计,");
}
builder.append(String.format("%s的%s", measures.get(0).colName, measures.get(0).aggr.getCnName()))
.append(" 与 ")
.append(String.format("%s的%s", measures.get(1).colName, measures.get(1).aggr.getCnName()))
Expand Down Expand Up @@ -137,6 +139,7 @@ public double computeScore(LocalOperator <?>... sources) {
}
MTable mTable = new MTable(rows, outSchema);
this.insight.layout.data = mTable;
this.fillLayout();
return score;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;

import com.alibaba.alink.common.AlinkGlobalConfiguration;
import com.alibaba.alink.common.MLEnvironmentFactory;
import com.alibaba.alink.common.MTable;
import com.alibaba.alink.common.exceptions.AkIllegalStateException;
Expand Down Expand Up @@ -112,7 +113,7 @@ public static void registerFlinkBuiltInFunctions(SchemaPlus schema) {
addScalarFunctionConsumer.accept("TO_BASE64", StringFunctions.TOBASE64);
addScalarFunctionConsumer.accept("LPAD", StringFunctions.LPAD);
addScalarFunctionConsumer.accept("RPAD", StringFunctions.RPAD);
//addScalarFunctionConsumer.accept("REGEXP_REPLACE", StringFunctions.REGEXP_REPLACE);

addScalarFunctionConsumer.accept("REGEXP_EXTRACT", StringFunctions.REGEXP_EXTRACT);
addScalarFunctionConsumer.accept("CONCAT", StringFunctions.CONCAT);
addScalarFunctionConsumer.accept("CONCAT", StringFunctions.CONCAT3);
Expand Down Expand Up @@ -186,10 +187,6 @@ public static void registerFlinkBuiltInFunctions(SchemaPlus schema) {
addScalarFunctionConsumer.accept("DATEDIFF", org.apache.calcite.linq4j.tree.Types.lookupMethod(
DateDiff.class, "eval", Timestamp.class, String.class));

// if have, ut will fail.
//addScalarFunctionConsumer.accept("REGEXP_REPLACE", org.apache.calcite.linq4j.tree.Types.lookupMethod(
// RegExpReplace.class, "eval", String.class, String.class, String.class));

addScalarFunctionConsumer.accept("REGEXP", org.apache.calcite.linq4j.tree.Types.lookupMethod(
RegExp.class, "eval", String.class, String.class));

Expand All @@ -205,6 +202,14 @@ public static void registerFlinkBuiltInFunctions(SchemaPlus schema) {
DateSub.class, "eval", String.class, int.class));
addScalarFunctionConsumer.accept("DATE_SUB", org.apache.calcite.linq4j.tree.Types.lookupMethod(
DateSub.class, "eval", Timestamp.class, int.class));

if (AlinkGlobalConfiguration.getFlinkVersion().equals("flink-1.9") ||
AlinkGlobalConfiguration.getFlinkVersion().equals("flink-1.10") ||
AlinkGlobalConfiguration.getFlinkVersion().equals("flink-1.11")) {
addScalarFunctionConsumer.accept("REGEXP_REPLACE", StringFunctions.REGEXP_REPLACE);
addScalarFunctionConsumer.accept("REGEXP_REPLACE", org.apache.calcite.linq4j.tree.Types.lookupMethod(
RegExpReplace.class, "eval", String.class, String.class, String.class));
}
}

@Override
Expand Down Expand Up @@ -309,7 +314,11 @@ protected void map(SlicedSelectedSample selection, SlicedResult result) throws E
preparedStatement.setObject(i + 1, v, java.sql.Types.FLOAT);
} else if (v instanceof Integer) {
preparedStatement.setObject(i + 1, v, Types.INTEGER);
} else {
}
//else if (v instanceof Timestamp) {
// preparedStatement.setObject(i + 1, v, Types.TIMESTAMP);
//}
else {
preparedStatement.setObject(i + 1, v);
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package com.alibaba.alink.operator.common.tree;

/**
 * Contract for objects that can produce a deep copy of themselves.
 *
 * @param <T> the type of the copy returned by {@link #deepCopy()}
 */
public interface DeepCopyable<T> {
/**
 * Creates a deep copy of this object.
 *
 * @return the copy
 */
T deepCopy();
}
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
package com.alibaba.alink.operator.common.tree;

import java.io.Serializable;
import java.util.Arrays;

public class LabelCounter implements Serializable {
public class LabelCounter implements Serializable, DeepCopyable <LabelCounter> {
private static final long serialVersionUID = 5749266833722532209L;
private double weightSum;
private int numInst;
Expand Down Expand Up @@ -57,4 +58,15 @@ public LabelCounter normWithWeight() {

return this;
}

/**
 * Builds a new {@code LabelCounter} carrying the same weight sum and instance count, together
 * with its own copy of the distributions array (or {@code null} when this counter has none).
 *
 * @return a copy that does not share the distributions array with this instance
 */
@Override
public LabelCounter deepCopy() {
    return new LabelCounter(
        weightSum,
        numInst,
        distributions != null ? distributions.clone() : null
    );
}
}
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
package com.alibaba.alink.operator.common.tree;

import java.io.Serializable;
import java.util.Arrays;

/**
* Tree node in the decision tree that will be serialized to json and deserialized from json.
*/
public class Node implements Serializable {
public class Node implements Serializable, DeepCopyable <Node> {
private static final long serialVersionUID = 1788232094688921790L;
/**
* featureIndex == -1 using for leaf
Expand Down Expand Up @@ -140,4 +141,21 @@ public Node setMissingSplit(int[] missingSplit) {
this.missingSplit = missingSplit;
return this;
}

/**
 * Creates a copy of this node's own split information.
 *
 * <p>Copies {@code featureIndex}, {@code gain}, {@code continuousSplit}, a deep copy of
 * {@code counter}, and fresh copies of the {@code categoricalSplit} and {@code missingSplit}
 * arrays. Reference fields that are {@code null} stay {@code null} in the copy.
 *
 * <p>NOTE(review): only the fields assigned below are copied; any other field of {@code Node}
 * keeps the default value given by {@code new Node()} — confirm this is intended.
 *
 * @return a new node that shares no arrays or counter with this instance
 */
@Override
public Node deepCopy() {
    Node newNode = new Node();
    newNode.featureIndex = this.featureIndex;
    newNode.gain = this.gain;
    // Guard against a node whose counter was never set, consistent with the null handling
    // of the array fields below; previously this dereference threw a NullPointerException.
    newNode.counter = this.counter == null
        ? null
        : this.counter.deepCopy();
    newNode.categoricalSplit = this.categoricalSplit == null
        ? null
        : Arrays.copyOf(this.categoricalSplit, this.categoricalSplit.length);
    newNode.continuousSplit = this.continuousSplit;
    newNode.missingSplit = this.missingSplit == null
        ? null
        : Arrays.copyOf(this.missingSplit, this.missingSplit.length);

    return newNode;
}
}
Loading

0 comments on commit f53f07b

Please sign in to comment.