|
| 1 | +# -*- coding: utf-8 -*- |
| 2 | +""" |
| 3 | +========================== |
| 4 | +Entropic-regularized semi-relaxed (Fused) Gromov-Wasserstein example |
| 5 | +========================== |
| 6 | +
|
| 7 | +This example is designed to show how to use the entropic semi-relaxed Gromov-Wasserstein |
| 8 | +and the entropic semi-relaxed Fused Gromov-Wasserstein divergences. |
| 9 | +
|
| 10 | +Entropic-regularized sr(F)GW between two graphs G1 and G2 searches for a reweighing of the nodes of |
| 11 | +G2 at a minimal entropic-regularized (F)GW distance from G1. |
| 12 | +
|
| 13 | +First, we generate two graphs following Stochastic Block Models, then show |
| 14 | +how to compute their srGW matchings and illustrate them. These graphs are then |
| 15 | +endowed with node features and we follow the same process with srFGW. |
| 16 | +
|
| 17 | +[48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty. |
| 18 | +"Semi-relaxed Gromov-Wasserstein divergence and applications on graphs" |
| 19 | +International Conference on Learning Representations (ICLR), 2021. |
| 20 | +""" |
| 21 | + |
| 22 | +# Author: Cédric Vincent-Cuaz <cedvincentcuaz@gmail.com> |
| 23 | +# |
| 24 | +# License: MIT License |
| 25 | + |
| 26 | +# sphinx_gallery_thumbnail_number = 1 |
| 27 | + |
| 28 | +import numpy as np |
| 29 | +import matplotlib.pylab as pl |
| 30 | +from ot.gromov import entropic_semirelaxed_gromov_wasserstein, entropic_semirelaxed_fused_gromov_wasserstein, gromov_wasserstein, fused_gromov_wasserstein |
| 31 | +import networkx |
| 32 | +from networkx.generators.community import stochastic_block_model as sbm |
| 33 | + |
| 34 | +############################################################################# |
| 35 | +# |
| 36 | +# Generate two graphs following Stochastic Block models of 2 and 3 clusters. |
| 37 | +# --------------------------------------------- |
| 38 | + |
| 39 | + |
| 40 | +N2 = 20 # 2 communities |
| 41 | +N3 = 30 # 3 communities |
| 42 | +p2 = [[1., 0.1], |
| 43 | + [0.1, 0.9]] |
| 44 | +p3 = [[1., 0.1, 0.], |
| 45 | + [0.1, 0.95, 0.1], |
| 46 | + [0., 0.1, 0.9]] |
| 47 | +G2 = sbm(seed=0, sizes=[N2 // 2, N2 // 2], p=p2) |
| 48 | +G3 = sbm(seed=0, sizes=[N3 // 3, N3 // 3, N3 // 3], p=p3) |
| 49 | + |
| 50 | + |
| 51 | +C2 = networkx.to_numpy_array(G2) |
| 52 | +C3 = networkx.to_numpy_array(G3) |
| 53 | + |
| 54 | +h2 = np.ones(C2.shape[0]) / C2.shape[0] |
| 55 | +h3 = np.ones(C3.shape[0]) / C3.shape[0] |
| 56 | + |
| 57 | +# Add weights on the edges for visualization later on |
| 58 | +weight_intra_G2 = 5 |
| 59 | +weight_inter_G2 = 0.5 |
| 60 | +weight_intra_G3 = 1. |
| 61 | +weight_inter_G3 = 1.5 |
| 62 | + |
| 63 | +weightedG2 = networkx.Graph() |
| 64 | +part_G2 = [G2.nodes[i]['block'] for i in range(N2)] |
| 65 | + |
| 66 | +for node in G2.nodes(): |
| 67 | + weightedG2.add_node(node) |
| 68 | +for i, j in G2.edges(): |
| 69 | + if part_G2[i] == part_G2[j]: |
| 70 | + weightedG2.add_edge(i, j, weight=weight_intra_G2) |
| 71 | + else: |
| 72 | + weightedG2.add_edge(i, j, weight=weight_inter_G2) |
| 73 | + |
| 74 | +weightedG3 = networkx.Graph() |
| 75 | +part_G3 = [G3.nodes[i]['block'] for i in range(N3)] |
| 76 | + |
| 77 | +for node in G3.nodes(): |
| 78 | + weightedG3.add_node(node) |
| 79 | +for i, j in G3.edges(): |
| 80 | + if part_G3[i] == part_G3[j]: |
| 81 | + weightedG3.add_edge(i, j, weight=weight_intra_G3) |
| 82 | + else: |
| 83 | + weightedG3.add_edge(i, j, weight=weight_inter_G3) |
| 84 | + |
| 85 | +############################################################################# |
| 86 | +# |
| 87 | +# Compute their entropic-regularized semi-relaxed Gromov-Wasserstein divergences |
| 88 | +# --------------------------------------------- |
| 89 | + |
| 90 | +# 0) GW(C2, h2, C3, h3) for reference |
| 91 | +OT, log = gromov_wasserstein(C2, C3, h2, h3, symmetric=True, log=True) |
| 92 | +gw = log['gw_dist'] |
| 93 | + |
| 94 | +# 1) srGW_e(C2, h2, C3) |
| 95 | +OT_23, log_23 = entropic_semirelaxed_gromov_wasserstein( |
| 96 | + C2, C3, h2, symmetric=True, epsilon=1., G0=None, log=True) |
| 97 | +srgw_23 = log_23['srgw_dist'] |
| 98 | + |
| 99 | +# 2) srGW_e(C3, h3, C2) |
| 100 | + |
| 101 | +OT_32, log_32 = entropic_semirelaxed_gromov_wasserstein( |
| 102 | + C3, C2, h3, symmetric=None, epsilon=1., G0=None, log=True) |
| 103 | +srgw_32 = log_32['srgw_dist'] |
| 104 | + |
| 105 | +print('GW(C2, C3) = ', gw) |
| 106 | +print('srGW_e(C2, h2, C3) = ', srgw_23) |
| 107 | +print('srGW_e(C3, h3, C2) = ', srgw_32) |
| 108 | + |
| 109 | + |
| 110 | +############################################################################# |
| 111 | +# |
| 112 | +# Visualization of the entropic-regularized semi-relaxed Gromov-Wasserstein matchings |
| 113 | +# --------------------------------------------- |
| 114 | +# |
| 115 | +# We color nodes of the graph on the right - then project its node colors |
| 116 | +# based on the optimal transport plan from the entropic srGW matching. |
| 117 | +# We adjust the intensity of links across domains proportionaly to the mass |
| 118 | +# sent, adding a minimal intensity of 0.1 if mass sent is not zero. |
| 119 | + |
| 120 | + |
| 121 | +def draw_graph(G, C, nodes_color_part, Gweights=None, |
| 122 | + pos=None, edge_color='black', node_size=None, |
| 123 | + shiftx=0, seed=0): |
| 124 | + |
| 125 | + if (pos is None): |
| 126 | + pos = networkx.spring_layout(G, scale=1., seed=seed) |
| 127 | + |
| 128 | + if shiftx != 0: |
| 129 | + for k, v in pos.items(): |
| 130 | + v[0] = v[0] + shiftx |
| 131 | + |
| 132 | + alpha_edge = 0.7 |
| 133 | + width_edge = 1.8 |
| 134 | + if Gweights is None: |
| 135 | + networkx.draw_networkx_edges(G, pos, width=width_edge, alpha=alpha_edge, edge_color=edge_color) |
| 136 | + else: |
| 137 | + # We make more visible connections between activated nodes |
| 138 | + n = len(Gweights) |
| 139 | + edgelist_activated = [] |
| 140 | + edgelist_deactivated = [] |
| 141 | + for i in range(n): |
| 142 | + for j in range(n): |
| 143 | + if Gweights[i] * Gweights[j] * C[i, j] > 0: |
| 144 | + edgelist_activated.append((i, j)) |
| 145 | + elif C[i, j] > 0: |
| 146 | + edgelist_deactivated.append((i, j)) |
| 147 | + |
| 148 | + networkx.draw_networkx_edges(G, pos, edgelist=edgelist_activated, |
| 149 | + width=width_edge, alpha=alpha_edge, |
| 150 | + edge_color=edge_color) |
| 151 | + networkx.draw_networkx_edges(G, pos, edgelist=edgelist_deactivated, |
| 152 | + width=width_edge, alpha=0.1, |
| 153 | + edge_color=edge_color) |
| 154 | + |
| 155 | + if Gweights is None: |
| 156 | + for node, node_color in enumerate(nodes_color_part): |
| 157 | + networkx.draw_networkx_nodes(G, pos, nodelist=[node], |
| 158 | + node_size=node_size, alpha=1, |
| 159 | + node_color=node_color) |
| 160 | + else: |
| 161 | + scaled_Gweights = Gweights / (0.5 * Gweights.max()) |
| 162 | + nodes_size = node_size * scaled_Gweights |
| 163 | + for node, node_color in enumerate(nodes_color_part): |
| 164 | + networkx.draw_networkx_nodes(G, pos, nodelist=[node], |
| 165 | + node_size=nodes_size[node], alpha=1, |
| 166 | + node_color=node_color) |
| 167 | + return pos |
| 168 | + |
| 169 | + |
| 170 | +def draw_transp_colored_srGW(G1, C1, G2, C2, part_G1, |
| 171 | + p1, p2, T, pos1=None, pos2=None, |
| 172 | + shiftx=4, switchx=False, node_size=70, |
| 173 | + seed_G1=0, seed_G2=0): |
| 174 | + starting_color = 0 |
| 175 | + # get graphs partition and their coloring |
| 176 | + part1 = part_G1.copy() |
| 177 | + unique_colors = ['C%s' % (starting_color + i) for i in np.unique(part1)] |
| 178 | + nodes_color_part1 = [] |
| 179 | + for cluster in part1: |
| 180 | + nodes_color_part1.append(unique_colors[cluster]) |
| 181 | + |
| 182 | + nodes_color_part2 = [] |
| 183 | + # T: getting colors assignment from argmin of columns |
| 184 | + for i in range(len(G2.nodes())): |
| 185 | + j = np.argmax(T[:, i]) |
| 186 | + nodes_color_part2.append(nodes_color_part1[j]) |
| 187 | + pos1 = draw_graph(G1, C1, nodes_color_part1, Gweights=p1, |
| 188 | + pos=pos1, node_size=node_size, shiftx=0, seed=seed_G1) |
| 189 | + pos2 = draw_graph(G2, C2, nodes_color_part2, Gweights=p2, pos=pos2, |
| 190 | + node_size=node_size, shiftx=shiftx, seed=seed_G2) |
| 191 | + for k1, v1 in pos1.items(): |
| 192 | + max_Tk1 = np.max(T[k1, :]) |
| 193 | + for k2, v2 in pos2.items(): |
| 194 | + if (T[k1, k2] > 0): |
| 195 | + pl.plot([pos1[k1][0], pos2[k2][0]], |
| 196 | + [pos1[k1][1], pos2[k2][1]], |
| 197 | + '-', lw=0.6, alpha=min(T[k1, k2] / max_Tk1 + 0.1, 1.), |
| 198 | + color=nodes_color_part1[k1]) |
| 199 | + return pos1, pos2 |
| 200 | + |
| 201 | + |
| 202 | +node_size = 40 |
| 203 | +fontsize = 10 |
| 204 | +seed_G2 = 0 |
| 205 | +seed_G3 = 4 |
| 206 | + |
| 207 | +pl.figure(1, figsize=(8, 2.5)) |
| 208 | +pl.clf() |
| 209 | +pl.subplot(121) |
| 210 | +pl.axis('off') |
| 211 | +pl.axis |
| 212 | +pl.title(r'$srGW_e(\mathbf{C_2},\mathbf{h_2},\mathbf{C_3}) =%s$' % (np.round(srgw_23, 3)), fontsize=fontsize) |
| 213 | + |
| 214 | +hbar2 = OT_23.sum(axis=0) |
| 215 | +pos1, pos2 = draw_transp_colored_srGW( |
| 216 | + weightedG2, C2, weightedG3, C3, part_G2, p1=None, p2=hbar2, T=OT_23, |
| 217 | + shiftx=1.5, node_size=node_size, seed_G1=seed_G2, seed_G2=seed_G3) |
| 218 | +pl.subplot(122) |
| 219 | +pl.axis('off') |
| 220 | +hbar3 = OT_32.sum(axis=0) |
| 221 | +pl.title(r'$srGW_e(\mathbf{C_3}, \mathbf{h_3},\mathbf{C_2}) =%s$' % (np.round(srgw_32, 3)), fontsize=fontsize) |
| 222 | +pos1, pos2 = draw_transp_colored_srGW( |
| 223 | + weightedG3, C3, weightedG2, C2, part_G3, p1=None, p2=hbar3, T=OT_32, |
| 224 | + pos1=pos2, pos2=pos1, shiftx=3., node_size=node_size, seed_G1=0, seed_G2=0) |
| 225 | +pl.tight_layout() |
| 226 | + |
| 227 | +pl.show() |
| 228 | + |
| 229 | +############################################################################# |
| 230 | +# |
| 231 | +# Add node features |
| 232 | +# --------------------------------------------- |
| 233 | + |
| 234 | +# We add node features with given mean - by clusters |
| 235 | +# and inversely proportional to clusters' intra-connectivity |
| 236 | + |
| 237 | +F2 = np.zeros((N2, 1)) |
| 238 | +for i, c in enumerate(part_G2): |
| 239 | + F2[i, 0] = np.random.normal(loc=c, scale=0.01) |
| 240 | + |
| 241 | +F3 = np.zeros((N3, 1)) |
| 242 | +for i, c in enumerate(part_G3): |
| 243 | + F3[i, 0] = np.random.normal(loc=2. - c, scale=0.01) |
| 244 | + |
| 245 | +############################################################################# |
| 246 | +# |
| 247 | +# Compute their semi-relaxed Fused Gromov-Wasserstein divergences |
| 248 | +# --------------------------------------------- |
| 249 | + |
| 250 | +alpha = 0.5 |
| 251 | +# Compute pairwise euclidean distance between node features |
| 252 | +M = (F2 ** 2).dot(np.ones((1, N3))) + np.ones((N2, 1)).dot((F3 ** 2).T) - 2 * F2.dot(F3.T) |
| 253 | + |
| 254 | +# 0) FGW_alpha(C2, F2, h2, C3, F3, h3) for reference |
| 255 | + |
| 256 | +OT, log = fused_gromov_wasserstein( |
| 257 | + M, C2, C3, h2, h3, symmetric=True, alpha=alpha, log=True) |
| 258 | +fgw = log['fgw_dist'] |
| 259 | + |
| 260 | +# 1) srFGW_e(C2, F2, h2, C3, F3) |
| 261 | +OT_23, log_23 = entropic_semirelaxed_fused_gromov_wasserstein( |
| 262 | + M, C2, C3, h2, symmetric=True, epsilon=1., alpha=0.5, log=True, G0=None) |
| 263 | +srfgw_23 = log_23['srfgw_dist'] |
| 264 | + |
| 265 | +# 2) srFGW(C3, F3, h3, C2, F2) |
| 266 | + |
| 267 | +OT_32, log_32 = entropic_semirelaxed_fused_gromov_wasserstein( |
| 268 | + M.T, C3, C2, h3, symmetric=None, epsilon=1., alpha=alpha, log=True, G0=None) |
| 269 | +srfgw_32 = log_32['srfgw_dist'] |
| 270 | + |
| 271 | +print('FGW(C2, F2, C3, F3) = ', fgw) |
| 272 | +print(r'$srGW_e$(C2, F2, h2, C3, F3) = ', srfgw_23) |
| 273 | +print(r'$srGW_e$(C3, F3, h3, C2, F2) = ', srfgw_32) |
| 274 | + |
| 275 | +############################################################################# |
| 276 | +# |
| 277 | +# Visualization of the entropic semi-relaxed Fused Gromov-Wasserstein matchings |
| 278 | +# --------------------------------------------- |
| 279 | +# |
| 280 | +# We color nodes of the graph on the right - then project its node colors |
| 281 | +# based on the optimal transport plan from the srFGW matching |
| 282 | +# NB: colors refer to clusters - not to node features |
| 283 | + |
| 284 | +pl.figure(2, figsize=(8, 2.5)) |
| 285 | +pl.clf() |
| 286 | +pl.subplot(121) |
| 287 | +pl.axis('off') |
| 288 | +pl.axis |
| 289 | +pl.title(r'$srFGW_e(\mathbf{C_2},\mathbf{F_2},\mathbf{h_2},\mathbf{C_3},\mathbf{F_3}) =%s$' % (np.round(srfgw_23, 3)), fontsize=fontsize) |
| 290 | + |
| 291 | +hbar2 = OT_23.sum(axis=0) |
| 292 | +pos1, pos2 = draw_transp_colored_srGW( |
| 293 | + weightedG2, C2, weightedG3, C3, part_G2, p1=None, p2=hbar2, T=OT_23, |
| 294 | + shiftx=1.5, node_size=node_size, seed_G1=seed_G2, seed_G2=seed_G3) |
| 295 | +pl.subplot(122) |
| 296 | +pl.axis('off') |
| 297 | +hbar3 = OT_32.sum(axis=0) |
| 298 | +pl.title(r'$srFGW_e(\mathbf{C_3}, \mathbf{F_3}, \mathbf{h_3}, \mathbf{C_2}, \mathbf{F_2}) =%s$' % (np.round(srfgw_32, 3)), fontsize=fontsize) |
| 299 | +pos1, pos2 = draw_transp_colored_srGW( |
| 300 | + weightedG3, C3, weightedG2, C2, part_G3, p1=None, p2=hbar3, T=OT_32, |
| 301 | + pos1=pos2, pos2=pos1, shiftx=3., node_size=node_size, seed_G1=0, seed_G2=0) |
| 302 | +pl.tight_layout() |
| 303 | + |
| 304 | +pl.show() |
0 commit comments