2
2
3
3
import com .antfin .arc .arch .message .graph .Vertex ;
4
4
import com .antfin .util .DistanceFunction ;
5
+ import com .antfin .util .GraphHelper ;
6
+ import java .io .FileOutputStream ;
7
+ import java .io .IOException ;
8
+ import java .io .OutputStream ;
5
9
import java .util .ArrayList ;
6
10
import java .util .Collections ;
7
11
import java .util .HashMap ;
8
12
import java .util .List ;
9
13
import java .util .Map ;
14
+ import java .util .Map .Entry ;
10
15
import java .util .Stack ;
16
+ import javafx .util .Pair ;
17
+ import javax .swing .JFrame ;
18
+ import jsat .SimpleDataSet ;
19
+ import jsat .classifiers .CategoricalData ;
20
+ import jsat .classifiers .DataPoint ;
21
+ import jsat .datatransform .visualization .TSNE ;
22
+ import jsat .linear .DenseMatrix ;
23
+ import jsat .linear .Matrix ;
24
+ import org .jfree .chart .ChartFactory ;
25
+ import org .jfree .chart .ChartPanel ;
26
+ import org .jfree .chart .ChartUtils ;
27
+ import org .jfree .chart .JFreeChart ;
28
+ import org .jfree .chart .plot .PlotOrientation ;
29
+ import org .jfree .data .xy .XYSeries ;
30
+ import org .jfree .data .xy .XYSeriesCollection ;
11
31
12
32
public class GNNHelper {
13
33
@@ -136,7 +156,7 @@ public static <K> void createAliasTable(List<Double> edgeWeight, K v, Map<K, Lis
136
156
}
137
157
138
158
public static <K > List <List <K >> simulateWalks (List <Vertex > vertices , int numWalks , int walkLength , double stayProb , int initialLayer ,
139
- List <Map <K , List <Double >>> layersAlias , List <Map <K , List <Double >>> layersAccept , List <Map <K , List <K >>> layersAdj , List <Map <K , Integer >> gamma ) {
159
+ List <Map <K , List <Double >>> layersAlias , List <Map <K , List <Double >>> layersAccept , List <Map <K , List <K >>> layersAdj , List <Map <K , Integer >> gamma ) {
140
160
List <List <K >> walks = new ArrayList ();
141
161
while ((numWalks --) > 0 ) {
142
162
Collections .shuffle (vertices );
@@ -150,7 +170,7 @@ public static <K> List<List<K>> simulateWalks(List<Vertex> vertices, int numWalk
150
170
// same layer
151
171
if (r < stayProb ) {
152
172
layersAdj .get (layer ).get (v .getId ());
153
- int vid = (int ) (Math .random ()* layersAccept .get (layer ).get (v .getId ()).size ());
173
+ int vid = (int ) (Math .random () * layersAccept .get (layer ).get (v .getId ()).size ());
154
174
if (rx >= layersAccept .get (layer ).get (v .getId ()).get (vid )) {
155
175
vid = layersAlias .get (layer ).get (v .getId ()).get (vid ).intValue ();
156
176
}
@@ -159,11 +179,11 @@ public static <K> List<List<K>> simulateWalks(List<Vertex> vertices, int numWalk
159
179
} else {
160
180
// different layer
161
181
double w = Math .log (gamma .get (layer ).get (v .getId ()) + Math .E );
162
- double probUp = w /( w + 1 );
182
+ double probUp = w / ( w + 1 );
163
183
if (rx > probUp && layer > initialLayer ) {
164
184
layer = layer - 1 ;
165
185
} else {
166
- if (layer + 1 < layersAdj .size () && layersAdj .get (layer + 1 ).containsKey (v .getId ())) {
186
+ if (layer + 1 < layersAdj .size () && layersAdj .get (layer + 1 ).containsKey (v .getId ())) {
167
187
++layer ;
168
188
}
169
189
}
@@ -174,4 +194,70 @@ public static <K> List<List<K>> simulateWalks(List<Vertex> vertices, int numWalk
174
194
}
175
195
return walks ;
176
196
}
197
+
198
+ public static <K > void showEmbeddings (Map <K , List <Double >> embeddings , String labelPath , String outPath ) throws IOException {
199
+ List <Pair <String , String >> labels = GraphHelper .readKVFile (labelPath );
200
+ TSNE instance = new TSNE ();
201
+ instance .setTargetDimension (2 );
202
+
203
+ Matrix orig_dim = new DenseMatrix (embeddings .size (), embeddings .values ().iterator ().next ().size ());
204
+ int i = 0 , j = 0 ;
205
+ for (Pair label :labels ) {
206
+ j = 0 ;
207
+ for (Double val : embeddings .get (label .getKey ())) {
208
+ orig_dim .set (i , j ++, val );
209
+ }
210
+ i ++;
211
+ }
212
+ SimpleDataSet proj = new SimpleDataSet (new CategoricalData [0 ], orig_dim .cols ());
213
+ for (i = 0 ; i < orig_dim .rows (); i ++) {
214
+ proj .add (new DataPoint (orig_dim .getRow (i )));
215
+ }
216
+ SimpleDataSet nodePosition = instance .transform (proj );
217
+
218
+ Map <String , List <Integer >> colorId = new HashMap <>();
219
+ for (i =0 ; i <labels .size (); ++i ) {
220
+ if (!colorId .containsKey (labels .get (i ).getValue ())) {
221
+ List <Integer > index = new ArrayList <>();
222
+ index .add (i );
223
+ colorId .put (labels .get (i ).getValue (), index );
224
+ } else {
225
+ colorId .get (labels .get (i ).getValue ()).add (i );
226
+ }
227
+ }
228
+
229
+ XYSeriesCollection dataset = new XYSeriesCollection ();
230
+
231
+ colorId .forEach ((label , ids ) -> {
232
+ XYSeries XY = new XYSeries (label );
233
+ ids .forEach (id -> {
234
+ XY .add (nodePosition .getDataPoint (id ).getNumericalValues ().get (0 ), nodePosition .getDataPoint (id ).getNumericalValues ().get (1 ));
235
+ });
236
+ dataset .addSeries (XY );
237
+ });
238
+
239
+ JFreeChart freeChart = ChartFactory .createScatterPlot (
240
+ "embeddings" ,
241
+ "X" ,
242
+ "Y" ,
243
+ dataset ,
244
+ PlotOrientation .VERTICAL ,
245
+ true ,
246
+ true ,
247
+ false
248
+ );
249
+
250
+ OutputStream os_png =new FileOutputStream (outPath );
251
+ ChartUtils .writeChartAsPNG (os_png ,freeChart ,560 ,400 );
252
+
253
+ ChartPanel chartPanel = new ChartPanel (freeChart );
254
+ chartPanel .setPreferredSize (new java .awt .Dimension (560 , 400 ));
255
+
256
+ JFrame frame = new JFrame ("embeddings" );
257
+ frame .setLocation (500 , 400 );
258
+ frame .setSize (600 , 500 );
259
+ frame .setContentPane (chartPanel );
260
+ frame .setDefaultCloseOperation (JFrame .EXIT_ON_CLOSE );
261
+ frame .setVisible (true );
262
+ }
177
263
}
0 commit comments