Skip to content

Commit d44231c

Browse files
committed
add GNNHelper.showEmbeddings
1 parent e8a8db0 commit d44231c

23 files changed

+311
-22
lines changed

pom.xml

+21
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,27 @@
1919
</properties>
2020

2121
<dependencies>
22+
<!-- https://mvnrepository.com/artifact/com.edwardraff/JSAT -->
23+
<dependency>
24+
<groupId>com.edwardraff</groupId>
25+
<artifactId>JSAT</artifactId>
26+
<version>0.0.9</version>
27+
</dependency>
28+
29+
<!-- https://mvnrepository.com/artifact/org.jfree/jcommon -->
30+
<dependency>
31+
<groupId>org.jfree</groupId>
32+
<artifactId>jcommon</artifactId>
33+
<version>1.0.24</version>
34+
</dependency>
35+
36+
<!-- https://mvnrepository.com/artifact/org.jfree/jfreechart -->
37+
<dependency>
38+
<groupId>org.jfree</groupId>
39+
<artifactId>jfreechart</artifactId>
40+
<version>1.5.0</version>
41+
</dependency>
42+
2243
<!-- https://mvnrepository.com/artifact/com.medallia.word2vec/Word2VecJava -->
2344
<dependency>
2445
<groupId>com.medallia.word2vec</groupId>

src/main/java/com/antfin/graph/refObj/Graph_Map_CSR.java

+15-9
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ public void addVertex(Vertex<K, VV> vertex) {
5555
this.vertices.add(vertex);
5656
this.dictV.put(vertex.getId(), this.dictV.size());
5757
} else {
58-
if (vertex.isEmpty()) {
58+
if (vertex.isEmpty() && this.dictV.get(vertex.getId()) < this.edges.size()) {
5959
this.vertices.set(this.dictV.get(vertex.getId()), nullVertex);
6060
} else {
6161
this.vertices.set(this.dictV.get(vertex.getId()), vertex);
@@ -71,25 +71,31 @@ public void addEdge(Edge<K, EV> edge) {
7171
// the source vertex of edge has other edges.
7272
this.edges.get(this.dictV.get(srcId)).add(edge);
7373
} else {
74-
int _id = 0;
74+
int dictVal = 0;
7575
if (!exist) {
7676
this.vertices.add(nullVertex);
77-
if (this.dictV.size() > 0) {
78-
_id = this.dictV.size() - 1;
79-
}
77+
dictVal = this.dictV.size();
8078
} else {
81-
_id = this.dictV.get(srcId);
79+
dictVal = this.dictV.get(srcId);
80+
if (this.vertices.get(dictVal) != nullVertex && this.vertices.get(dictVal).isEmpty()) {
81+
this.vertices.set(dictVal, nullVertex);
82+
}
8283
}
83-
if (this.edges.size() != this.dictV.size()) {
84-
this.dictV.put(this.vertices.get(this.edges.size()).getId(), _id);
84+
if (this.edges.size() != this.dictV.size() && dictVal != this.edges.size()) {
85+
this.dictV.put(this.vertices.get(this.edges.size()).getId(), dictVal);
86+
GraphHelper.swap(this.edges.size(), dictVal, this.vertices);
8587
}
8688
this.dictV.put(srcId, this.edges.size());
87-
GraphHelper.swap(this.edges.size(), _id, this.vertices);
8889

8990
List<Edge<K, EV>> edges = new ArrayList<>();
9091
edges.add(edge);
9192
this.edges.add(edges);
9293
}
94+
// add target
95+
if (!dictV.containsKey(edge.getTargetId())) {
96+
this.vertices.add(new Vertex<>(edge.getTargetId()));
97+
this.dictV.put(edge.getTargetId(), this.dictV.size());
98+
}
9399
}
94100

95101
@Override

src/main/java/com/antfin/util/GraphHelper.java

+12-3
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import java.util.HashMap;
1414
import java.util.List;
1515
import java.util.Map;
16+
import javafx.util.Pair;
1617

1718
public class GraphHelper {
1819

@@ -122,11 +123,19 @@ public static void writeObject(Object object, File file) throws IOException {
122123
}
123124

124125
public static List<Edge<String, String>> loadEdges(String path){
126+
List<Edge<String, String>> edges = new ArrayList<>();
127+
readKVFile(path).forEach(pair->{
128+
edges.add(new Edge<>(pair.getKey(), pair.getValue(), RandomWord.getWords(100)));
129+
});
130+
return edges;
131+
}
132+
133+
public static List<Pair<String, String>> readKVFile(String path){
125134
File file = new File(path);
126135
if (!file.exists()) {
127136
System.err.println(path + " is not exist!");
128137
}
129-
List<Edge<String, String>> edges = new ArrayList<>();
138+
List<Pair<String, String>> pairs = new ArrayList<>();
130139
try (BufferedReader reader = Files.newBufferedReader(file.toPath(), Charset.forName("utf-8"))) {
131140
String line = null;
132141
while ((line = reader.readLine()) != null) {
@@ -138,11 +147,11 @@ public static List<Edge<String, String>> loadEdges(String path){
138147
{
139148
System.err.println(line + " must include source and sink!");
140149
}
141-
edges.add(new Edge<>(vid[0], vid[1], RandomWord.getWords(100)));
150+
pairs.add(new Pair<>(vid[0], vid[1]));
142151
}
143152
} catch (IOException x) {
144153
System.err.format("IOException: %s%n", x);
145154
}
146-
return edges;
155+
return pairs;
147156
}
148157
}

src/main/java/com/gnn/embedding/Struc2vec.java

+33-6
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
import java.util.stream.Collectors;
2828
import javafx.util.Pair;
2929
import org.apache.commons.collections.map.HashedMap;
30-
import org.apache.commons.lang3.StringUtils;
3130

3231
public class Struc2vec<K, VV, EV> {
3332

@@ -62,7 +61,7 @@ public Struc2vec() {
6261

6362
this.opt1_reduce_len = true;
6463
this.opt2_reduce_sim_calc = false;
65-
this.opt3_reduce_layers = false;
64+
this.opt3_reduce_layers = true;
6665
this.opt3_num_layers = 10;
6766

6867
this.tempPath = "./temp/struc2vec/";
@@ -99,8 +98,8 @@ public Struc2vec(String path, int walkLength, int numWalks, int workers, double
9998
public Struc2vec(String path) throws IOException {
10099
this();
101100
this.graph = new Graph_Map_CSR(GraphHelper.loadEdges(path), false);
102-
this.createContextGraph(10, 4);
103-
this.walks = this.struc2vecWalk(100, 10, 0.3, 4);
101+
this.createContextGraph(this.opt3_num_layers, this.workers);
102+
this.walks = this.struc2vecWalk(this.numWalks, this.walkLength, this.stayProb, this.workers);
104103
}
105104

106105
public void createContextGraph(int numLayers, int workers) throws IOException {
@@ -133,7 +132,30 @@ public Map<Pair<K, K>, List<Double>> computeStructuralDistance(int numLayers, in
133132

134133
Map<K, List<K>> vertices = new HashMap<>();
135134
if (this.opt2_reduce_sim_calc) {
135+
// store v list of degree
136+
Map<Integer, Map<String, Object>> degrees = new HashMap();
137+
// store degree
138+
List<Integer> degreeSet = new ArrayList<>();
139+
this.graph.getVertexList().forEach(v -> {
140+
int degree = ((List) this.graph.getEdge(((Vertex) v).getId())).size();
141+
if (!degreeSet.contains(degree)) {
142+
degreeSet.add(degree);
143+
}
144+
if (!degrees.containsKey(degree)) {
145+
Map<String, Object> temp = new HashMap<>();
146+
temp.put("vertices", new ArrayList<K>());
147+
degrees.put(degree, temp);
148+
}
149+
((List<K>) degrees.get(degree).get("vertices")).add((K) ((Vertex) v).getId());
150+
});
151+
Collections.sort(degreeSet);
152+
for (int i = 1; i < degreeSet.size(); ++i) {
153+
degrees.get(degreeSet.get(i)).put("before", degreeSet.get(i - 1));
154+
degrees.get(degreeSet.get(i - 1)).put("after", degreeSet.get(i));
155+
}
156+
this.graph.getVertexList().forEach(v -> {
136157

158+
});
137159
} else {
138160
degreeList.keySet().forEach(k -> {
139161
vertices.keySet().forEach(item -> {
@@ -361,10 +383,10 @@ public void train(int embed_size, int window_size, int workers, int iterator) th
361383
.setLayerSize(embed_size)
362384
.setDownSamplingRate(1e-3)
363385
.setNumIterations(iterator)
364-
.train(Iterables.partition((List<String>)this.walks.stream().flatMap(List::stream).collect(Collectors.toList()), this.walkLength));
386+
.train(Iterables.partition((List<String>) this.walks.stream().flatMap(List::stream).collect(Collectors.toList()), this.walkLength));
365387
}
366388

367-
public Map<K, List<Double>> getEmbeddings() {
389+
public Map<K, List<Double>> getEmbeddings() throws IOException {
368390
if (this.model == null) {
369391
System.err.println("this mode is not trained!");
370392
}
@@ -379,6 +401,7 @@ public Map<K, List<Double>> getEmbeddings() {
379401
e.printStackTrace();
380402
}
381403
});
404+
GraphHelper.writeObject(this.embeddings, new File(String.format("%sembeddings.kryo", this.tempPath)));
382405
return this.embeddings;
383406
}
384407

@@ -389,4 +412,8 @@ public Graph getGraph() {
389412
public Word2VecModel getModel() {
390413
return this.model;
391414
}
415+
416+
public String getTempPath() {
417+
return this.tempPath;
418+
}
392419
}

src/main/java/com/gnn/util/GNNHelper.java

+90-4
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,32 @@
22

33
import com.antfin.arc.arch.message.graph.Vertex;
44
import com.antfin.util.DistanceFunction;
5+
import com.antfin.util.GraphHelper;
6+
import java.io.FileOutputStream;
7+
import java.io.IOException;
8+
import java.io.OutputStream;
59
import java.util.ArrayList;
610
import java.util.Collections;
711
import java.util.HashMap;
812
import java.util.List;
913
import java.util.Map;
14+
import java.util.Map.Entry;
1015
import java.util.Stack;
16+
import javafx.util.Pair;
17+
import javax.swing.JFrame;
18+
import jsat.SimpleDataSet;
19+
import jsat.classifiers.CategoricalData;
20+
import jsat.classifiers.DataPoint;
21+
import jsat.datatransform.visualization.TSNE;
22+
import jsat.linear.DenseMatrix;
23+
import jsat.linear.Matrix;
24+
import org.jfree.chart.ChartFactory;
25+
import org.jfree.chart.ChartPanel;
26+
import org.jfree.chart.ChartUtils;
27+
import org.jfree.chart.JFreeChart;
28+
import org.jfree.chart.plot.PlotOrientation;
29+
import org.jfree.data.xy.XYSeries;
30+
import org.jfree.data.xy.XYSeriesCollection;
1131

1232
public class GNNHelper {
1333

@@ -136,7 +156,7 @@ public static <K> void createAliasTable(List<Double> edgeWeight, K v, Map<K, Lis
136156
}
137157

138158
public static <K> List<List<K>> simulateWalks(List<Vertex> vertices, int numWalks, int walkLength, double stayProb, int initialLayer,
139-
List<Map<K, List<Double>>> layersAlias, List<Map<K, List<Double>>> layersAccept, List<Map<K, List<K>>> layersAdj, List<Map<K, Integer>> gamma) {
159+
List<Map<K, List<Double>>> layersAlias, List<Map<K, List<Double>>> layersAccept, List<Map<K, List<K>>> layersAdj, List<Map<K, Integer>> gamma) {
140160
List<List<K>> walks = new ArrayList();
141161
while ((numWalks--) > 0) {
142162
Collections.shuffle(vertices);
@@ -150,7 +170,7 @@ public static <K> List<List<K>> simulateWalks(List<Vertex> vertices, int numWalk
150170
// same layer
151171
if (r < stayProb) {
152172
layersAdj.get(layer).get(v.getId());
153-
int vid = (int) (Math.random()*layersAccept.get(layer).get(v.getId()).size());
173+
int vid = (int) (Math.random() * layersAccept.get(layer).get(v.getId()).size());
154174
if (rx >= layersAccept.get(layer).get(v.getId()).get(vid)) {
155175
vid = layersAlias.get(layer).get(v.getId()).get(vid).intValue();
156176
}
@@ -159,11 +179,11 @@ public static <K> List<List<K>> simulateWalks(List<Vertex> vertices, int numWalk
159179
} else {
160180
// different layer
161181
double w = Math.log(gamma.get(layer).get(v.getId()) + Math.E);
162-
double probUp = w/(w+1);
182+
double probUp = w / (w + 1);
163183
if (rx > probUp && layer > initialLayer) {
164184
layer = layer - 1;
165185
} else {
166-
if (layer + 1 < layersAdj.size() && layersAdj.get(layer+1).containsKey(v.getId())) {
186+
if (layer + 1 < layersAdj.size() && layersAdj.get(layer + 1).containsKey(v.getId())) {
167187
++layer;
168188
}
169189
}
@@ -174,4 +194,70 @@ public static <K> List<List<K>> simulateWalks(List<Vertex> vertices, int numWalk
174194
}
175195
return walks;
176196
}
197+
198+
public static <K> void showEmbeddings(Map<K, List<Double>> embeddings, String labelPath, String outPath) throws IOException {
199+
List<Pair<String, String>> labels = GraphHelper.readKVFile(labelPath);
200+
TSNE instance = new TSNE();
201+
instance.setTargetDimension(2);
202+
203+
Matrix orig_dim = new DenseMatrix(embeddings.size(), embeddings.values().iterator().next().size());
204+
int i = 0, j = 0;
205+
for (Pair label:labels) {
206+
j = 0;
207+
for (Double val : embeddings.get(label.getKey())) {
208+
orig_dim.set(i, j++, val);
209+
}
210+
i++;
211+
}
212+
SimpleDataSet proj = new SimpleDataSet(new CategoricalData[0], orig_dim.cols());
213+
for (i = 0; i < orig_dim.rows(); i++) {
214+
proj.add(new DataPoint(orig_dim.getRow(i)));
215+
}
216+
SimpleDataSet nodePosition = instance.transform(proj);
217+
218+
Map<String, List<Integer>> colorId = new HashMap<>();
219+
for (i=0; i<labels.size(); ++i) {
220+
if(!colorId.containsKey(labels.get(i).getValue())) {
221+
List<Integer> index = new ArrayList<>();
222+
index.add(i);
223+
colorId.put(labels.get(i).getValue(), index);
224+
} else {
225+
colorId.get(labels.get(i).getValue()).add(i);
226+
}
227+
}
228+
229+
XYSeriesCollection dataset = new XYSeriesCollection();
230+
231+
colorId.forEach((label, ids) -> {
232+
XYSeries XY = new XYSeries(label);
233+
ids.forEach(id -> {
234+
XY.add(nodePosition.getDataPoint(id).getNumericalValues().get(0), nodePosition.getDataPoint(id).getNumericalValues().get(1));
235+
});
236+
dataset.addSeries(XY);
237+
});
238+
239+
JFreeChart freeChart = ChartFactory.createScatterPlot(
240+
"embeddings",
241+
"X",
242+
"Y",
243+
dataset,
244+
PlotOrientation.VERTICAL,
245+
true,
246+
true,
247+
false
248+
);
249+
250+
OutputStream os_png=new FileOutputStream(outPath);
251+
ChartUtils.writeChartAsPNG(os_png,freeChart,560,400);
252+
253+
ChartPanel chartPanel = new ChartPanel(freeChart);
254+
chartPanel.setPreferredSize(new java.awt.Dimension(560, 400));
255+
256+
JFrame frame = new JFrame("embeddings");
257+
frame.setLocation(500, 400);
258+
frame.setSize(600, 500);
259+
frame.setContentPane(chartPanel);
260+
frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
261+
frame.setVisible(true);
262+
}
177263
}

0 commit comments

Comments
 (0)