diff --git a/README.md b/README.md index 7d1c418..eb32c41 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,32 @@ # GAT -Graph Attention Networks +Graph Attention Networks (Veličković *et al.*, ICLR 2018) + +## Overview +Here we provide the implementation of a Graph Attention Network (GAT) layer in TensorFlow, along with a minimal execution example (on the Cora dataset). The repository is organised as follows: +- `data/` contains the necessary dataset files for Cora; +- `models/` contains the implementation of the GAT network (`gat.py`); +- `pre_trained/` contains a pre-trained Cora model (achieving 84.4% accuracy on the test set); +- `utils/` contains: + * an implementation of an attention head, along with an experimental sparse version (`layers.py`); + * preprocessing subroutines (`process.py`); + * preprocessing utilities for the PPI benchmark (`process_ppi.py`). + +Finally, `execute_cora.py` puts all of the above together and may be used to execute a full training run on Cora. + +## Reference +If you make advantage of the GAT model in your research, please cite the following in your manuscript: + +``` +@article{ + velickovic2018graph, + title={Graph Attention Networks}, + author={Petar Veli{\v{c}}kovi{\'{c}}, Guillem Cucurull, Arantxa Casanova, Adriana Romero, Pietro Li{\`{o}}, Yoshua Bengio}, + journal={International Conference on Learning Representations}, + year={2018}, + url={https://openreview.net/forum?id=rJXMpikCZ}, + note={accepted as poster}, +} +``` + +## License +MIT diff --git a/data/ind.cora.allx b/data/ind.cora.allx new file mode 100644 index 0000000..44d53b1 Binary files /dev/null and b/data/ind.cora.allx differ diff --git a/data/ind.cora.ally b/data/ind.cora.ally new file mode 100644 index 0000000..04fbd0b Binary files /dev/null and b/data/ind.cora.ally differ diff --git a/data/ind.cora.graph b/data/ind.cora.graph new file mode 100644 index 0000000..4d3bf85 Binary files /dev/null and b/data/ind.cora.graph differ diff --git a/data/ind.cora.test.index b/data/ind.cora.test.index new file mode 100644 index 0000000..ded8092 --- /dev/null +++ b/data/ind.cora.test.index @@ -0,0 +1,1000 @@ +2692 +2532 +2050 +1715 +2362 +2609 +2622 +1975 +2081 +1767 +2263 +1725 +2588 +2259 +2357 +1998 +2574 +2179 +2291 +2382 +1812 +1751 +2422 +1937 +2631 +2510 +2378 +2589 +2345 +1943 +1850 +2298 +1825 +2035 +2507 +2313 +1906 +1797 +2023 +2159 +2495 +1886 +2122 +2369 +2461 +1925 +2565 +1858 +2234 +2000 +1846 +2318 +1723 +2559 +2258 +1763 +1991 +1922 +2003 +2662 +2250 +2064 +2529 +1888 +2499 +2454 +2320 +2287 +2203 +2018 +2002 +2632 +2554 +2314 +2537 +1760 +2088 +2086 +2218 +2605 +1953 +2403 +1920 +2015 +2335 +2535 +1837 +2009 +1905 +2636 +1942 +2193 +2576 +2373 +1873 +2463 +2509 +1954 +2656 +2455 +2494 +2295 +2114 +2561 +2176 +2275 +2635 +2442 +2704 +2127 +2085 +2214 +2487 +1739 +2543 +1783 +2485 +2262 +2472 +2326 +1738 +2170 +2100 +2384 +2152 +2647 +2693 +2376 +1775 +1726 +2476 +2195 +1773 +1793 +2194 +2581 +1854 +2524 +1945 +1781 +1987 +2599 +1744 +2225 +2300 +1928 +2042 +2202 +1958 +1816 +1916 +2679 +2190 +1733 +2034 +2643 +2177 +1883 +1917 +1996 +2491 +2268 +2231 +2471 +1919 +1909 +2012 +2522 +1865 +2466 +2469 +2087 +2584 +2563 +1924 +2143 +1736 +1966 +2533 +2490 +2630 +1973 +2568 +1978 +2664 +2633 +2312 +2178 +1754 +2307 +2480 +1960 +1742 +1962 +2160 +2070 +2553 +2433 +1768 +2659 +2379 +2271 +1776 +2153 +1877 +2027 +2028 +2155 +2196 +2483 +2026 +2158 +2407 +1821 +2131 +2676 +2277 +2489 +2424 +1963 +1808 +1859 +2597 +2548 +2368 +1817 +2405 +2413 +2603 +2350 +2118 +2329 +1969 +2577 +2475 +2467 
+2425 +1769 +2092 +2044 +2586 +2608 +1983 +2109 +2649 +1964 +2144 +1902 +2411 +2508 +2360 +1721 +2005 +2014 +2308 +2646 +1949 +1830 +2212 +2596 +1832 +1735 +1866 +2695 +1941 +2546 +2498 +2686 +2665 +1784 +2613 +1970 +2021 +2211 +2516 +2185 +2479 +2699 +2150 +1990 +2063 +2075 +1979 +2094 +1787 +2571 +2690 +1926 +2341 +2566 +1957 +1709 +1955 +2570 +2387 +1811 +2025 +2447 +2696 +2052 +2366 +1857 +2273 +2245 +2672 +2133 +2421 +1929 +2125 +2319 +2641 +2167 +2418 +1765 +1761 +1828 +2188 +1972 +1997 +2419 +2289 +2296 +2587 +2051 +2440 +2053 +2191 +1923 +2164 +1861 +2339 +2333 +2523 +2670 +2121 +1921 +1724 +2253 +2374 +1940 +2545 +2301 +2244 +2156 +1849 +2551 +2011 +2279 +2572 +1757 +2400 +2569 +2072 +2526 +2173 +2069 +2036 +1819 +1734 +1880 +2137 +2408 +2226 +2604 +1771 +2698 +2187 +2060 +1756 +2201 +2066 +2439 +1844 +1772 +2383 +2398 +1708 +1992 +1959 +1794 +2426 +2702 +2444 +1944 +1829 +2660 +2497 +2607 +2343 +1730 +2624 +1790 +1935 +1967 +2401 +2255 +2355 +2348 +1931 +2183 +2161 +2701 +1948 +2501 +2192 +2404 +2209 +2331 +1810 +2363 +2334 +1887 +2393 +2557 +1719 +1732 +1986 +2037 +2056 +1867 +2126 +1932 +2117 +1807 +1801 +1743 +2041 +1843 +2388 +2221 +1833 +2677 +1778 +2661 +2306 +2394 +2106 +2430 +2371 +2606 +2353 +2269 +2317 +2645 +2372 +2550 +2043 +1968 +2165 +2310 +1985 +2446 +1982 +2377 +2207 +1818 +1913 +1766 +1722 +1894 +2020 +1881 +2621 +2409 +2261 +2458 +2096 +1712 +2594 +2293 +2048 +2359 +1839 +2392 +2254 +1911 +2101 +2367 +1889 +1753 +2555 +2246 +2264 +2010 +2336 +2651 +2017 +2140 +1842 +2019 +1890 +2525 +2134 +2492 +2652 +2040 +2145 +2575 +2166 +1999 +2434 +1711 +2276 +2450 +2389 +2669 +2595 +1814 +2039 +2502 +1896 +2168 +2344 +2637 +2031 +1977 +2380 +1936 +2047 +2460 +2102 +1745 +2650 +2046 +2514 +1980 +2352 +2113 +1713 +2058 +2558 +1718 +1864 +1876 +2338 +1879 +1891 +2186 +2451 +2181 +2638 +2644 +2103 +2591 +2266 +2468 +1869 +2582 +2674 +2361 +2462 +1748 +2215 +2615 +2236 +2248 +2493 +2342 +2449 +2274 +1824 +1852 +1870 +2441 +2356 +1835 +2694 +2602 +2685 +1893 +2544 +2536 +1994 +1853 +1838 +1786 +1930 +2539 +1892 +2265 +2618 +2486 +2583 +2061 +1796 +1806 +2084 +1933 +2095 +2136 +2078 +1884 +2438 +2286 +2138 +1750 +2184 +1799 +2278 +2410 +2642 +2435 +1956 +2399 +1774 +2129 +1898 +1823 +1938 +2299 +1862 +2420 +2673 +1984 +2204 +1717 +2074 +2213 +2436 +2297 +2592 +2667 +2703 +2511 +1779 +1782 +2625 +2365 +2315 +2381 +1788 +1714 +2302 +1927 +2325 +2506 +2169 +2328 +2629 +2128 +2655 +2282 +2073 +2395 +2247 +2521 +2260 +1868 +1988 +2324 +2705 +2541 +1731 +2681 +2707 +2465 +1785 +2149 +2045 +2505 +2611 +2217 +2180 +1904 +2453 +2484 +1871 +2309 +2349 +2482 +2004 +1965 +2406 +2162 +1805 +2654 +2007 +1947 +1981 +2112 +2141 +1720 +1758 +2080 +2330 +2030 +2432 +2089 +2547 +1820 +1815 +2675 +1840 +2658 +2370 +2251 +1908 +2029 +2068 +2513 +2549 +2267 +2580 +2327 +2351 +2111 +2022 +2321 +2614 +2252 +2104 +1822 +2552 +2243 +1798 +2396 +2663 +2564 +2148 +2562 +2684 +2001 +2151 +2706 +2240 +2474 +2303 +2634 +2680 +2055 +2090 +2503 +2347 +2402 +2238 +1950 +2054 +2016 +1872 +2233 +1710 +2032 +2540 +2628 +1795 +2616 +1903 +2531 +2567 +1946 +1897 +2222 +2227 +2627 +1856 +2464 +2241 +2481 +2130 +2311 +2083 +2223 +2284 +2235 +2097 +1752 +2515 +2527 +2385 +2189 +2283 +2182 +2079 +2375 +2174 +2437 +1993 +2517 +2443 +2224 +2648 +2171 +2290 +2542 +2038 +1855 +1831 +1759 +1848 +2445 +1827 +2429 +2205 +2598 +2657 +1728 +2065 +1918 +2427 +2573 +2620 +2292 +1777 +2008 +1875 +2288 +2256 +2033 +2470 +2585 +2610 +2082 +2230 +1915 +1847 +2337 +2512 +2386 +2006 +2653 +2346 +1951 +2110 +2639 +2520 +1939 +2683 +2139 
+2220 +1910 +2237 +1900 +1836 +2197 +1716 +1860 +2077 +2519 +2538 +2323 +1914 +1971 +1845 +2132 +1802 +1907 +2640 +2496 +2281 +2198 +2416 +2285 +1755 +2431 +2071 +2249 +2123 +1727 +2459 +2304 +2199 +1791 +1809 +1780 +2210 +2417 +1874 +1878 +2116 +1961 +1863 +2579 +2477 +2228 +2332 +2578 +2457 +2024 +1934 +2316 +1841 +1764 +1737 +2322 +2239 +2294 +1729 +2488 +1974 +2473 +2098 +2612 +1834 +2340 +2423 +2175 +2280 +2617 +2208 +2560 +1741 +2600 +2059 +1747 +2242 +2700 +2232 +2057 +2147 +2682 +1792 +1826 +2120 +1895 +2364 +2163 +1851 +2391 +2414 +2452 +1803 +1989 +2623 +2200 +2528 +2415 +1804 +2146 +2619 +2687 +1762 +2172 +2270 +2678 +2593 +2448 +1882 +2257 +2500 +1899 +2478 +2412 +2107 +1746 +2428 +2115 +1800 +1901 +2397 +2530 +1912 +2108 +2206 +2091 +1740 +2219 +1976 +2099 +2142 +2671 +2668 +2216 +2272 +2229 +2666 +2456 +2534 +2697 +2688 +2062 +2691 +2689 +2154 +2590 +2626 +2390 +1813 +2067 +1952 +2518 +2358 +1789 +2076 +2049 +2119 +2013 +2124 +2556 +2105 +2093 +1885 +2305 +2354 +2135 +2601 +1770 +1995 +2504 +1749 +2157 diff --git a/data/ind.cora.tx b/data/ind.cora.tx new file mode 100644 index 0000000..6e856d7 Binary files /dev/null and b/data/ind.cora.tx differ diff --git a/data/ind.cora.ty b/data/ind.cora.ty new file mode 100644 index 0000000..da1734a Binary files /dev/null and b/data/ind.cora.ty differ diff --git a/data/ind.cora.x b/data/ind.cora.x new file mode 100644 index 0000000..c4a91d0 Binary files /dev/null and b/data/ind.cora.x differ diff --git a/data/ind.cora.y b/data/ind.cora.y new file mode 100644 index 0000000..58e30ef Binary files /dev/null and b/data/ind.cora.y differ diff --git a/execute_cora.py b/execute_cora.py new file mode 100644 index 0000000..2a0b498 --- /dev/null +++ b/execute_cora.py @@ -0,0 +1,174 @@ +import time +import numpy as np +import tensorflow as tf + +from models import GAT +from utils import process + +checkpt_file = 'pre_trained/cora/mod_cora.ckpt' + +dataset = 'cora' + +# training params +batch_size = 1 +nb_epochs = 100000 +patience = 100 +lr = 0.005 # learning rate +l2_coef = 0.0005 # weight decay +hid_units = [8] # numbers of hidden units per each attention head in each layer +n_heads = [8, 1] # additional entry for the output layer +residual = False +nonlinearity = tf.nn.elu +model = GAT + +print('Dataset: ' + dataset) +print('----- Opt. hyperparams -----') +print('lr: ' + str(lr)) +print('l2_coef: ' + str(l2_coef)) +print('----- Archi. hyperparams -----') +print('nb. layers: ' + str(len(hid_units))) +print('nb. units per layer: ' + str(hid_units)) +print('nb. 
attention heads: ' + str(n_heads)) +print('residual: ' + str(residual)) +print('nonlinearity: ' + str(nonlinearity)) +print('model: ' + str(model)) + +adj, features, y_train, y_val, y_test, train_mask, val_mask, test_mask = process.load_data(dataset) +features, spars = process.preprocess_features(features) + +nb_nodes = features.shape[0] +ft_size = features.shape[1] +nb_classes = y_train.shape[1] + +adj = adj.todense() + +features = features[np.newaxis] +adj = adj[np.newaxis] +y_train = y_train[np.newaxis] +y_val = y_val[np.newaxis] +y_test = y_test[np.newaxis] +train_mask = train_mask[np.newaxis] +val_mask = val_mask[np.newaxis] +test_mask = test_mask[np.newaxis] + +biases = process.adj_to_bias(adj, [nb_nodes], nhood=1) + +with tf.Graph().as_default(): + with tf.name_scope('input'): + ftr_in = tf.placeholder(dtype=tf.float32, shape=(batch_size, nb_nodes, ft_size)) + bias_in = tf.placeholder(dtype=tf.float32, shape=(batch_size, nb_nodes, nb_nodes)) + lbl_in = tf.placeholder(dtype=tf.int32, shape=(batch_size, nb_nodes, nb_classes)) + msk_in = tf.placeholder(dtype=tf.int32, shape=(batch_size, nb_nodes)) + attn_drop = tf.placeholder(dtype=tf.float32, shape=()) + ffd_drop = tf.placeholder(dtype=tf.float32, shape=()) + is_train = tf.placeholder(dtype=tf.bool, shape=()) + + logits = model.inference(ftr_in, nb_classes, nb_nodes, is_train, + attn_drop, ffd_drop, + bias_mat=bias_in, + hid_units=hid_units, n_heads=n_heads, + residual=residual, activation=nonlinearity) + log_resh = tf.reshape(logits, [-1, nb_classes]) + lab_resh = tf.reshape(lbl_in, [-1, nb_classes]) + msk_resh = tf.reshape(msk_in, [-1]) + loss = model.masked_softmax_cross_entropy(log_resh, lab_resh, msk_resh) + accuracy = model.masked_accuracy(log_resh, lab_resh, msk_resh) + + train_op = model.training(loss, lr, l2_coef) + + saver = tf.train.Saver() + + init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()) + + vlss_mn = np.inf + vacc_mx = 0.0 + curr_step = 0 + + with tf.Session() as sess: + sess.run(init_op) + + train_loss_avg = 0 + train_acc_avg = 0 + val_loss_avg = 0 + val_acc_avg = 0 + + for epoch in range(nb_epochs): + tr_step = 0 + tr_size = features.shape[0] + + while tr_step * batch_size < tr_size: + _, loss_value_tr, acc_tr = sess.run([train_op, loss, accuracy], + feed_dict={ + ftr_in: features[tr_step*batch_size:(tr_step+1)*batch_size], + bias_in: biases[tr_step*batch_size:(tr_step+1)*batch_size], + lbl_in: y_train[tr_step*batch_size:(tr_step+1)*batch_size], + msk_in: train_mask[tr_step*batch_size:(tr_step+1)*batch_size], + is_train: True, + attn_drop: 0.6, ffd_drop: 0.6}) + train_loss_avg += loss_value_tr + train_acc_avg += acc_tr + tr_step += 1 + + vl_step = 0 + vl_size = features.shape[0] + + while vl_step * batch_size < vl_size: + loss_value_vl, acc_vl = sess.run([loss, accuracy], + feed_dict={ + ftr_in: features[vl_step*batch_size:(vl_step+1)*batch_size], + bias_in: biases[vl_step*batch_size:(vl_step+1)*batch_size], + lbl_in: y_val[vl_step*batch_size:(vl_step+1)*batch_size], + msk_in: val_mask[vl_step*batch_size:(vl_step+1)*batch_size], + is_train: False, + attn_drop: 0.0, ffd_drop: 0.0}) + val_loss_avg += loss_value_vl + val_acc_avg += acc_vl + vl_step += 1 + + print('Training: loss = %.5f, acc = %.5f | Val: loss = %.5f, acc = %.5f' % + (train_loss_avg/tr_step, train_acc_avg/tr_step, + val_loss_avg/vl_step, val_acc_avg/vl_step)) + + if val_acc_avg/vl_step >= vacc_mx or val_loss_avg/vl_step <= vlss_mn: + if val_acc_avg/vl_step >= vacc_mx and val_loss_avg/vl_step <= vlss_mn: + 
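# Both validation accuracy and loss improved this epoch: record the metrics and save the checkpoint that is restored below for the final test evaluation.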
vacc_early_model = val_acc_avg/vl_step + vlss_early_model = val_loss_avg/vl_step + saver.save(sess, checkpt_file) + vacc_mx = np.max((val_acc_avg/vl_step, vacc_mx)) + vlss_mn = np.min((val_loss_avg/vl_step, vlss_mn)) + curr_step = 0 + else: + curr_step += 1 + if curr_step == patience: + print('Early stop! Min loss: ', vlss_mn, ', Max accuracy: ', vacc_mx) + print('Early stop model validation loss: ', vlss_early_model, ', accuracy: ', vacc_early_model) + break + + train_loss_avg = 0 + train_acc_avg = 0 + val_loss_avg = 0 + val_acc_avg = 0 + + saver.restore(sess, checkpt_file) + + ts_size = features.shape[0] + ts_step = 0 + ts_loss = 0.0 + ts_acc = 0.0 + + while ts_step * batch_size < ts_size: + loss_value_ts, acc_ts = sess.run([loss, accuracy], + feed_dict={ + ftr_in: features[ts_step*batch_size:(ts_step+1)*batch_size], + bias_in: biases[ts_step*batch_size:(ts_step+1)*batch_size], + lbl_in: y_test[ts_step*batch_size:(ts_step+1)*batch_size], + msk_in: test_mask[ts_step*batch_size:(ts_step+1)*batch_size], + is_train: False, + attn_drop: 0.0, ffd_drop: 0.0}) + ts_loss += loss_value_ts + ts_acc += acc_ts + ts_step += 1 + + print('Test loss:', ts_loss/ts_step, '; Test accuracy:', ts_acc/ts_step) + + sess.close() diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000..a4a9cdf --- /dev/null +++ b/models/__init__.py @@ -0,0 +1 @@ +from .gat import GAT diff --git a/models/base_gattn.py b/models/base_gattn.py new file mode 100644 index 0000000..71edcf9 --- /dev/null +++ b/models/base_gattn.py @@ -0,0 +1,89 @@ +import tensorflow as tf + +class BaseGAttN: + def loss(logits, labels, nb_classes, class_weights): + sample_wts = tf.reduce_sum(tf.multiply(tf.one_hot(labels, nb_classes), class_weights), axis=-1) + xentropy = tf.multiply(tf.nn.sparse_softmax_cross_entropy_with_logits( + labels=labels, logits=logits), sample_wts) + return tf.reduce_mean(xentropy, name='xentropy_mean') + + def training(loss, lr, l2_coef): + # weight decay + vars = tf.trainable_variables() + lossL2 = tf.add_n([tf.nn.l2_loss(v) for v in vars if v.name not + in ['bias', 'gamma', 'b', 'g', 'beta']]) * l2_coef + + # optimizer + opt = tf.train.AdamOptimizer(learning_rate=lr) + + # training op + train_op = opt.minimize(loss+lossL2) + + return train_op + + def preshape(logits, labels, nb_classes): + new_sh_lab = [-1] + new_sh_log = [-1, nb_classes] + log_resh = tf.reshape(logits, new_sh_log) + lab_resh = tf.reshape(labels, new_sh_lab) + return log_resh, lab_resh + + def confmat(logits, labels): + preds = tf.argmax(logits, axis=1) + return tf.confusion_matrix(labels, preds) + +########################## +# Adapted from tkipf/gcn # +########################## + + def masked_softmax_cross_entropy(logits, labels, mask): + """Softmax cross-entropy loss with masking.""" + loss = tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=labels) + mask = tf.cast(mask, dtype=tf.float32) + mask /= tf.reduce_mean(mask) + loss *= mask + return tf.reduce_mean(loss) + + def masked_sigmoid_cross_entropy(logits, labels, mask): + """Softmax cross-entropy loss with masking.""" + labels = tf.cast(labels, dtype=tf.float32) + loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=labels) + loss=tf.reduce_mean(loss,axis=1) + mask = tf.cast(mask, dtype=tf.float32) + mask /= tf.reduce_mean(mask) + loss *= mask + return tf.reduce_mean(loss) + + def masked_accuracy(logits, labels, mask): + """Accuracy with masking.""" + correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(labels, 1)) + accuracy_all = 
tf.cast(correct_prediction, tf.float32) + mask = tf.cast(mask, dtype=tf.float32) + mask /= tf.reduce_mean(mask) + accuracy_all *= mask + return tf.reduce_mean(accuracy_all) + + def micro_f1(logits, labels, mask): + """Accuracy with masking.""" + predicted = tf.round(tf.nn.sigmoid(logits)) + + # Use integers to avoid any nasty FP behaviour + predicted = tf.cast(predicted, dtype=tf.int32) + labels = tf.cast(labels, dtype=tf.int32) + mask = tf.cast(mask, dtype=tf.int32) + + # expand the mask so that broadcasting works ([nb_nodes, 1]) + mask = tf.expand_dims(mask, -1) + + # Count true positives, true negatives, false positives and false negatives. + tp = tf.count_nonzero(predicted * labels * mask) + tn = tf.count_nonzero((predicted - 1) * (labels - 1) * mask) + fp = tf.count_nonzero(predicted * (labels - 1) * mask) + fn = tf.count_nonzero((predicted - 1) * labels * mask) + + # Calculate accuracy, precision, recall and F1 score. + precision = tp / (tp + fp) + recall = tp / (tp + fn) + fmeasure = (2 * precision * recall) / (precision + recall) + fmeasure = tf.cast(fmeasure, tf.float32) + return fmeasure diff --git a/models/gat.py b/models/gat.py new file mode 100644 index 0000000..005e65e --- /dev/null +++ b/models/gat.py @@ -0,0 +1,31 @@ +import numpy as np +import tensorflow as tf + +from utils import layers +from models.base_gattn import BaseGAttN + +class GAT(BaseGAttN): + def inference(inputs, nb_classes, nb_nodes, training, attn_drop, ffd_drop, + bias_mat, hid_units=[16], n_heads=1, n_layers=2, activation=tf.nn.elu, residual=False): + attns = [] + for _ in range(n_heads[0]): + attns.append(layers.attn_head(inputs, bias_mat=bias_mat, + out_sz=hid_units[0], activation=activation, + in_drop=ffd_drop, coef_drop=attn_drop, residual=False)) + h_1 = tf.concat(attns, axis=-1) + for i in range(1, len(hid_units)): + h_old = h_1 + attns = [] + for _ in range(n_heads[i]): + attns.append(layers.attn_head(h_1, bias_mat=bias_mat, + out_sz=hid_units[i], activation=activation, + in_drop=ffd_drop, coef_drop=attn_drop, residual=residual)) + h_1 = tf.concat(attns, axis=-1) + out = [] + for i in range(n_heads[-1]): + out.append(layers.attn_head(h_1, bias_mat=bias_mat, + out_sz=nb_classes, activation=lambda x: x, + in_drop=ffd_drop, coef_drop=attn_drop, residual=False)) + logits = tf.add_n(out) / n_heads[-1] + + return logits diff --git a/pre_trained/cora/checkpoint b/pre_trained/cora/checkpoint new file mode 100644 index 0000000..6d0650d --- /dev/null +++ b/pre_trained/cora/checkpoint @@ -0,0 +1,2 @@ +model_checkpoint_path: "mod_cora.ckpt" +all_model_checkpoint_paths: "mod_cora.ckpt" diff --git a/pre_trained/cora/mod_cora.ckpt.data-00000-of-00001 b/pre_trained/cora/mod_cora.ckpt.data-00000-of-00001 new file mode 100644 index 0000000..7026595 Binary files /dev/null and b/pre_trained/cora/mod_cora.ckpt.data-00000-of-00001 differ diff --git a/pre_trained/cora/mod_cora.ckpt.index b/pre_trained/cora/mod_cora.ckpt.index new file mode 100644 index 0000000..cae74b0 Binary files /dev/null and b/pre_trained/cora/mod_cora.ckpt.index differ diff --git a/pre_trained/cora/mod_cora.ckpt.meta b/pre_trained/cora/mod_cora.ckpt.meta new file mode 100644 index 0000000..7e01120 Binary files /dev/null and b/pre_trained/cora/mod_cora.ckpt.meta differ diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/utils/layers.py b/utils/layers.py new file mode 100644 index 0000000..3703201 --- /dev/null +++ b/utils/layers.py @@ -0,0 +1,79 @@ +import numpy as np +import tensorflow as tf + 
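# --- Editor's illustrative sketch (not part of the original layers.py) ---
# attn_head() below computes coefs = softmax(leaky_relu(f_1 + f_2^T) + bias_mat),
# where bias_mat is produced by utils.process.adj_to_bias: 0.0 for node pairs
# within the chosen neighbourhood and -1e9 elsewhere, so the softmax puts
# (numerically) zero attention on non-neighbours. A minimal NumPy sketch of that
# masking step, using a hypothetical 3-node graph; the helper is not called anywhere:
def _masked_softmax_sketch():
    adj = np.array([[1., 1., 0.],
                    [1., 1., 1.],
                    [0., 1., 1.]])                # adjacency with self-loops
    bias = -1e9 * (1.0 - adj)                     # same form adj_to_bias returns for nhood=1
    scores = np.random.randn(3, 3)                # unnormalised attention logits
    e = np.exp(scores + bias)                     # exp(-1e9) underflows to 0 for non-neighbours
    return e / e.sum(axis=-1, keepdims=True)      # each row: a distribution over neighbours only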
+conv1d = tf.layers.conv1d + +def attn_head(seq, out_sz, bias_mat, activation, in_drop=0.0, coef_drop=0.0, residual=False): + with tf.name_scope('my_attn'): + if in_drop != 0.0: + seq = tf.nn.dropout(seq, 1.0 - in_drop) + + seq_fts = tf.layers.conv1d(seq, out_sz, 1, use_bias=False) + + # simplest self-attention possible + f_1 = tf.layers.conv1d(seq_fts, 1, 1) + f_2 = tf.layers.conv1d(seq_fts, 1, 1) + logits = f_1 + tf.transpose(f_2, [0, 2, 1]) + coefs = tf.nn.softmax(tf.nn.leaky_relu(logits) + bias_mat) + + if coef_drop != 0.0: + coefs = tf.nn.dropout(coefs, 1.0 - coef_drop) + if in_drop != 0.0: + seq_fts = tf.nn.dropout(seq_fts, 1.0 - in_drop) + + vals = tf.matmul(coefs, seq_fts) + ret = tf.contrib.layers.bias_add(vals) + + # residual connection + if residual: + if seq.shape[-1] != ret.shape[-1]: + ret = ret + conv1d(seq, ret.shape[-1], 1) # activation + else: + ret = ret + seq + + return activation(ret) # activation + +# Experimental sparse attention head (for running on datasets such as Pubmed) +# N.B. Because of limitations of current TF implementation, will work _only_ if batch_size = 1! +def sp_attn_head(seq, out_sz, adj_mat, activation, nb_nodes, in_drop=0.0, coef_drop=0.0, residual=False): + with tf.name_scope('sp_attn'): + if in_drop != 0.0: + seq = tf.nn.dropout(seq, 1.0 - in_drop) + + seq_fts = tf.layers.conv1d(seq, out_sz, 1, use_bias=False) + + # simplest self-attention possible + f_1 = tf.layers.conv1d(seq_fts, 1, 1) + f_2 = tf.layers.conv1d(seq_fts, 1, 1) + logits = tf.sparse_add(adj_mat * f_1, adj_mat * tf.transpose(f_2, [0, 2, 1])) + lrelu = tf.SparseTensor(indices=logits.indices, + values=tf.nn.leaky_relu(logits.values), + dense_shape=logits.dense_shape) + coefs = tf.sparse_softmax(lrelu) + + if coef_drop != 0.0: + coefs = tf.SparseTensor(indices=coefs.indices, + values=tf.nn.dropout(coefs.values, 1.0 - coef_drop), + dense_shape=coefs.dense_shape) + if in_drop != 0.0: + seq_fts = tf.nn.dropout(seq_fts, 1.0 - in_drop) + + # As tf.sparse_tensor_dense_matmul expects its arguments to have rank-2, + # here we make an assumption that our input is of batch size 1, and reshape appropriately. + # The method will fail in all other cases! + coefs = tf.sparse_reshape(coefs, [nb_nodes, nb_nodes]) + seq_fts = tf.squeeze(seq_fts) + vals = tf.sparse_tensor_dense_matmul(coefs, seq_fts) + vals = tf.expand_dims(vals, axis=0) + vals.set_shape([1, nb_nodes, out_sz]) + ret = tf.contrib.layers.bias_add(vals) + + # residual connection + if residual: + if seq.shape[-1] != ret.shape[-1]: + ret = ret + conv1d(seq, ret.shape[-1], 1) # activation + else: + ret = ret + seq + + return activation(ret) # activation + diff --git a/utils/process.py b/utils/process.py new file mode 100644 index 0000000..90269ca --- /dev/null +++ b/utils/process.py @@ -0,0 +1,151 @@ +import numpy as np +import pickle as pkl +import networkx as nx +import scipy.sparse as sp +from scipy.sparse.linalg.eigen.arpack import eigsh +import sys + +""" + Prepare adjacency matrix by expanding up to a given neighbourhood. + This will insert loops on every node. + Finally, the matrix is converted to bias vectors.
+ Expected shape: [graph, nodes, nodes] +""" +def adj_to_bias(adj, sizes, nhood=1): + nb_graphs = adj.shape[0] + mt = np.empty(adj.shape) + for g in range(nb_graphs): + mt[g] = np.eye(adj.shape[1]) + for _ in range(nhood): + mt[g] = np.matmul(mt[g], (adj[g] + np.eye(adj.shape[1]))) + for i in range(sizes[g]): + for j in range(sizes[g]): + if mt[g][i][j] > 0.0: + mt[g][i][j] = 1.0 + return -1e9 * (1.0 - mt) + + +############################################### +# This section of code adapted from tkipf/gcn # +############################################### + +def parse_index_file(filename): + """Parse index file.""" + index = [] + for line in open(filename): + index.append(int(line.strip())) + return index + +def sample_mask(idx, l): + """Create mask.""" + mask = np.zeros(l) + mask[idx] = 1 + return np.array(mask, dtype=np.bool) + +def load_data(dataset_str): # {'pubmed', 'citeseer', 'cora'} + """Load data.""" + names = ['x', 'y', 'tx', 'ty', 'allx', 'ally', 'graph'] + objects = [] + for i in range(len(names)): + with open("data/ind.{}.{}".format(dataset_str, names[i]), 'rb') as f: + if sys.version_info > (3, 0): + objects.append(pkl.load(f, encoding='latin1')) + else: + objects.append(pkl.load(f)) + + x, y, tx, ty, allx, ally, graph = tuple(objects) + test_idx_reorder = parse_index_file("data/ind.{}.test.index".format(dataset_str)) + test_idx_range = np.sort(test_idx_reorder) + + if dataset_str == 'citeseer': + # Fix citeseer dataset (there are some isolated nodes in the graph) + # Find isolated nodes, add them as zero-vecs into the right position + test_idx_range_full = range(min(test_idx_reorder), max(test_idx_reorder)+1) + tx_extended = sp.lil_matrix((len(test_idx_range_full), x.shape[1])) + tx_extended[test_idx_range-min(test_idx_range), :] = tx + tx = tx_extended + ty_extended = np.zeros((len(test_idx_range_full), y.shape[1])) + ty_extended[test_idx_range-min(test_idx_range), :] = ty + ty = ty_extended + + features = sp.vstack((allx, tx)).tolil() + features[test_idx_reorder, :] = features[test_idx_range, :] + adj = nx.adjacency_matrix(nx.from_dict_of_lists(graph)) + + labels = np.vstack((ally, ty)) + labels[test_idx_reorder, :] = labels[test_idx_range, :] + + idx_test = test_idx_range.tolist() + idx_train = range(len(y)) + idx_val = range(len(y), len(y)+500) + + train_mask = sample_mask(idx_train, labels.shape[0]) + val_mask = sample_mask(idx_val, labels.shape[0]) + test_mask = sample_mask(idx_test, labels.shape[0]) + + y_train = np.zeros(labels.shape) + y_val = np.zeros(labels.shape) + y_test = np.zeros(labels.shape) + y_train[train_mask, :] = labels[train_mask, :] + y_val[val_mask, :] = labels[val_mask, :] + y_test[test_mask, :] = labels[test_mask, :] + + print(adj.shape) + print(features.shape) + + return adj, features, y_train, y_val, y_test, train_mask, val_mask, test_mask + +def sparse_to_tuple(sparse_mx): + """Convert sparse matrix to tuple representation.""" + def to_tuple(mx): + if not sp.isspmatrix_coo(mx): + mx = mx.tocoo() + coords = np.vstack((mx.row, mx.col)).transpose() + values = mx.data + shape = mx.shape + return coords, values, shape + + if isinstance(sparse_mx, list): + for i in range(len(sparse_mx)): + sparse_mx[i] = to_tuple(sparse_mx[i]) + else: + sparse_mx = to_tuple(sparse_mx) + + return sparse_mx + +def standardize_data(f, train_mask): + """Standardize feature matrix and convert to tuple representation""" + # standardize data + f = f.todense() + mu = f[train_mask == True, :].mean(axis=0) + sigma = f[train_mask == True, :].std(axis=0) + f = f[:, 
np.squeeze(np.array(sigma > 0))] + mu = f[train_mask == True, :].mean(axis=0) + sigma = f[train_mask == True, :].std(axis=0) + f = (f - mu) / sigma + return f + +def preprocess_features(features): + """Row-normalize feature matrix and convert to tuple representation""" + rowsum = np.array(features.sum(1)) + r_inv = np.power(rowsum, -1).flatten() + r_inv[np.isinf(r_inv)] = 0. + r_mat_inv = sp.diags(r_inv) + features = r_mat_inv.dot(features) + return features.todense(), sparse_to_tuple(features) + +def normalize_adj(adj): + """Symmetrically normalize adjacency matrix.""" + adj = sp.coo_matrix(adj) + rowsum = np.array(adj.sum(1)) + d_inv_sqrt = np.power(rowsum, -0.5).flatten() + d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0. + d_mat_inv_sqrt = sp.diags(d_inv_sqrt) + return adj.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt).tocoo() + + +def preprocess_adj(adj): + """Preprocessing of adjacency matrix for simple GCN model and conversion to tuple representation.""" + adj_normalized = normalize_adj(adj + sp.eye(adj.shape[0])) + return sparse_to_tuple(adj_normalized) + diff --git a/utils/process_ppi.py b/utils/process_ppi.py new file mode 100644 index 0000000..739e011 --- /dev/null +++ b/utils/process_ppi.py @@ -0,0 +1,271 @@ +import numpy as np +import json +import networkx as nx +from networkx.readwrite import json_graph +import scipy.sparse as sp +import pdb +import sys +sys.setrecursionlimit(99999) + + +def run_dfs(adj, msk, u, ind, nb_nodes): + if msk[u] == -1: + msk[u] = ind + #for v in range(nb_nodes): + for v in adj[u,:].nonzero()[1]: + #if adj[u,v]== 1: + run_dfs(adj, msk, v, ind, nb_nodes) + +# Use depth-first search to split a graph into subgraphs +def dfs_split(adj): + # Assume adj is of shape [nb_nodes, nb_nodes] + nb_nodes = adj.shape[0] + ret = np.full(nb_nodes, -1, dtype=np.int32) + + graph_id = 0 + + for i in range(nb_nodes): + if ret[i] == -1: + run_dfs(adj, ret, i, graph_id, nb_nodes) + graph_id += 1 + + return ret + +def test(adj, mapping): + nb_nodes = adj.shape[0] + for i in range(nb_nodes): + #for j in range(nb_nodes): + for j in adj[i, :].nonzero()[1]: + if mapping[i] != mapping[j]: + # if adj[i,j] == 1: + return False + return True + + + +def find_split(adj, mapping, ds_label): + nb_nodes = adj.shape[0] + dict_splits={} + for i in range(nb_nodes): + #for j in range(nb_nodes): + for j in adj[i, :].nonzero()[1]: + if mapping[i]==0 or mapping[j]==0: + dict_splits[0]=None + elif mapping[i] == mapping[j]: + if ds_label[i]['val'] == ds_label[j]['val'] and ds_label[i]['test'] == ds_label[j]['test']: + + if mapping[i] not in dict_splits.keys(): + if ds_label[i]['val']: + dict_splits[mapping[i]] = 'val' + + elif ds_label[i]['test']: + dict_splits[mapping[i]]='test' + + else: + dict_splits[mapping[i]] = 'train' + + else: + if ds_label[i]['test']: + ind_label='test' + elif ds_label[i]['val']: + ind_label='val' + else: + ind_label='train' + if dict_splits[mapping[i]]!= ind_label: + print ('inconsistent labels within a graph exiting!!!') + return None + else: + print ('label of both nodes different, exiting!!') + return None + return dict_splits + + + + +def process_p2p(): + + + print ('Loading G...') + with open('p2p_dataset/ppi-G.json') as jsonfile: + g_data = json.load(jsonfile) + print (len(g_data)) + G = json_graph.node_link_graph(g_data) + + #Extracting adjacency matrix + adj=nx.adjacency_matrix(G) + + prev_key='' + for key, value in g_data.items(): + if prev_key!=key: + print (key) + prev_key=key + + print ('Loading id_map...') + with open('p2p_dataset/ppi-id_map.json') as jsonfile: 
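# ppi-id_map.json maps each node id of the PPI graph JSON to its row index in ppi-feats.npy; those indices are used below to select the training rows the scaler is fit on.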
+ id_map = json.load(jsonfile) + print (len(id_map)) + + id_map = {int(k):int(v) for k,v in id_map.items()} + for key, value in id_map.items(): + id_map[key]=[value] + print (len(id_map)) + + print ('Loading features...') + features_=np.load('p2p_dataset/ppi-feats.npy') + print (features_.shape) + + #standarizing features + from sklearn.preprocessing import StandardScaler + + train_ids = np.array([id_map[n] for n in G.nodes() if not G.node[n]['val'] and not G.node[n]['test']]) + train_feats = features_[train_ids[:,0]] + scaler = StandardScaler() + scaler.fit(train_feats) + features_ = scaler.transform(features_) + + features = sp.csr_matrix(features_).tolil() + + + print ('Loading class_map...') + class_map = {} + with open('p2p_dataset/ppi-class_map.json') as jsonfile: + class_map = json.load(jsonfile) + print (len(class_map)) + + #pdb.set_trace() + #Split graph into sub-graphs + print ('Splitting graph...') + splits=dfs_split(adj) + + #Rearrange sub-graph index and append sub-graphs with 1 or 2 nodes to bigger sub-graphs + print ('Re-arranging sub-graph IDs...') + list_splits=splits.tolist() + group_inc=1 + + for i in range(np.max(list_splits)+1): + if list_splits.count(i)>=3: + splits[np.array(list_splits) == i] =group_inc + group_inc+=1 + else: + #splits[np.array(list_splits) == i] = 0 + ind_nodes=np.argwhere(np.array(list_splits) == i) + ind_nodes=ind_nodes[:,0].tolist() + split=None + + for ind_node in ind_nodes: + if g_data['nodes'][ind_node]['val']: + if split is None or split=='val': + splits[np.array(list_splits) == i] = 21 + split='val' + else: + raise ValueError('new node is VAL but previously was {}'.format(split)) + elif g_data['nodes'][ind_node]['test']: + if split is None or split=='test': + splits[np.array(list_splits) == i] = 23 + split='test' + else: + raise ValueError('new node is TEST but previously was {}'.format(split)) + else: + if split is None or split == 'train': + splits[np.array(list_splits) == i] = 1 + split='train' + else: + pdb.set_trace() + raise ValueError('new node is TRAIN but previously was {}'.format(split)) + + #counting number of nodes per sub-graph + list_splits=splits.tolist() + nodes_per_graph=[] + for i in range(1,np.max(list_splits) + 1): + nodes_per_graph.append(list_splits.count(i)) + + #Splitting adj matrix into sub-graphs + subgraph_nodes=np.max(nodes_per_graph) + adj_sub=np.empty((len(nodes_per_graph), subgraph_nodes, subgraph_nodes)) + feat_sub = np.empty((len(nodes_per_graph), subgraph_nodes, features.shape[1])) + labels_sub = np.empty((len(nodes_per_graph), subgraph_nodes, 121)) + + for i in range(1, np.max(list_splits) + 1): + #Creating same size sub-graphs + indexes = np.where(splits == i)[0] + subgraph_=adj[indexes,:][:,indexes] + + if subgraph_.shape[0]
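# --- Editor's illustrative sketch (not part of the original process_ppi.py) ---
# dfs_split()/run_dfs() above label every node with the id of its connected
# component via recursive DFS, which is why the recursion limit is raised at the
# top of the file. The same labelling can be sketched with an explicit stack,
# avoiding deep recursion on large PPI graphs; `adj` is assumed to be the
# scipy.sparse adjacency matrix built in process_p2p(). Hypothetical helper, not
# called anywhere in the repository:
def _dfs_split_iterative(adj):
    nb_nodes = adj.shape[0]
    ret = np.full(nb_nodes, -1, dtype=np.int32)   # component id per node, -1 = unvisited
    graph_id = 0
    for start in range(nb_nodes):
        if ret[start] != -1:
            continue
        stack = [start]
        while stack:
            u = stack.pop()
            if ret[u] != -1:
                continue
            ret[u] = graph_id
            stack.extend(v for v in adj[u, :].nonzero()[1] if ret[v] == -1)
        graph_id += 1
    return ret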