
Commit 86a59d7

CircleCI authored and committed
CircleCI update of dev docs (3352).
1 parent e7f8941 commit 86a59d7

315 files changed: +237682 −234870 lines changed

Lines changed: 133 additions & 0 deletions
@@ -0,0 +1,133 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n# Solving Many Optimal Transport Problems in Parallel\n\nIn some situations, one may want to solve many OT problems that share the same\nstructure (same number of samples, same cost function, etc.) at the same time.\n\nIn that case, using a for loop to solve the problems sequentially is inefficient.\nThis example shows how to use the batch solvers implemented in POT to solve\nmany problems in parallel on CPU or GPU (the speedup is even larger on GPU).\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Author: Paul Krzakala <paul.krzakala@gmail.com>\n# License: MIT License\n\n# sphinx_gallery_thumbnail_number = 1"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Computing the Cost Matrices\n\nWe want to create a batch of optimal transport problems with\n$n$ samples in $d$ dimensions.\n\nTo do this, we first need to compute the cost matrices for each problem.\n\n<div class=\"alert alert-info\"><h4>Note</h4><p>A straightforward approach would be to use a Python loop and\n :func:`ot.dist`.\n However, this is inefficient when working with batches.</p></div>\n\nInstead, you can directly use :func:`ot.batch.dist_batch`, which computes\nall cost matrices in parallel.\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import ot\nimport numpy as np\n\nn_problems = 4 # nb problems/batch size\nn_samples = 8 # nb samples\ndim = 2 # nb dimensions\n\nnp.random.seed(0)\nsamples_source = np.random.randn(n_problems, n_samples, dim)\nsamples_target = samples_source + 0.1 * np.random.randn(n_problems, n_samples, dim)\n\n# Naive approach\nM_list = []\nfor i in range(n_problems):\n M_list.append(\n ot.dist(samples_source[i], samples_target[i])\n ) # List of cost matrices n_samples x n_samples\n# Batched approach\nM_batch = ot.batch.dist_batch(\n samples_source, samples_target\n) # Array of cost matrices n_problems x n_samples x n_samples\n\nfor i in range(n_problems):\n assert np.allclose(M_list[i], M_batch[i])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Solving the Problems\n\nOnce the cost matrices are computed, we can solve the corresponding\noptimal transport problems.\n\n<div class=\"alert alert-info\"><h4>Note</h4><p>One option is to solve them sequentially with a Python loop using\n :func:`ot.solve`.\n This is simple but inefficient for large batches.</p></div>\n\nInstead, you can use :func:`ot.batch.solve_batch`, which solves all\nproblems in parallel.\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"reg = 1.0\nmax_iter = 100\ntol = 1e-3\n\n# Naive approach\nresults_values_list = []\nfor i in range(n_problems):\n res = ot.solve(M_list[i], reg=reg, max_iter=max_iter, tol=tol, reg_type=\"entropy\")\n results_values_list.append(res.value_linear)\n\n# Batched approach\nresults_batch = ot.batch.solve_batch(\n M=M_batch, reg=reg, max_iter=max_iter, tol=tol, reg_type=\"entropy\"\n)\nresults_values_batch = results_batch.value_linear\n\nassert np.allclose(np.array(results_values_list), results_values_batch, atol=tol * 10)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Comparing Computation Time\n\nWe now compare the runtime of the two approaches on larger problems.\n\n<div class=\"alert alert-info\"><h4>Note</h4><p>The speedup obtained with :mod:`ot.batch` can be even more\n significant when computations are performed on a GPU.</p></div>\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from time import perf_counter\n\nn_problems = 128\nn_samples = 8\ndim = 2\nreg = 10.0\nmax_iter = 1000\ntol = 1e-3\n\nsamples_source = np.random.randn(n_problems, n_samples, dim)\nsamples_target = samples_source + 0.1 * np.random.randn(n_problems, n_samples, dim)\n\n\ndef benchmark_naive(samples_source, samples_target):\n start = perf_counter()\n for i in range(n_problems):\n M = ot.dist(samples_source[i], samples_target[i])\n res = ot.solve(M, reg=reg, max_iter=max_iter, tol=tol, reg_type=\"entropy\")\n end = perf_counter()\n return end - start\n\n\ndef benchmark_batch(samples_source, samples_target):\n start = perf_counter()\n M_batch = ot.batch.dist_batch(samples_source, samples_target)\n res_batch = ot.batch.solve_batch(\n M=M_batch, reg=reg, max_iter=max_iter, tol=tol, reg_type=\"entropy\"\n )\n end = perf_counter()\n return end - start\n\n\ntime_naive = benchmark_naive(samples_source, samples_target)\ntime_batch = benchmark_batch(samples_source, samples_target)\n\nprint(f\"Naive approach time: {time_naive:.4f} seconds\")\nprint(f\"Batched approach time: {time_batch:.4f} seconds\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Gromov-Wasserstein\n\nThe :mod:`ot.batch` module also provides a batched Gromov-Wasserstein solver.\n\n<div class=\"alert alert-info\"><h4>Note</h4><p>This solver is **not** equivalent to calling :func:`ot.solve_gromov`\n repeatedly in a loop.</p></div>\n\nKey differences:\n\n- :func:`ot.solve_gromov`\n Uses the conditional gradient algorithm. Each inner iteration relies on\n an exact EMD solver.\n\n- :func:`ot.batch.solve_gromov_batch`\n Uses a proximal variant, where each inner iteration applies entropic\n regularization.\n\nAs a result:\n\n- :func:`ot.solve_gromov` is usually faster on CPU\n- :func:`ot.batch.solve_gromov_batch` is slower on CPU, but provides\n better objective values.\n\n.. tip::\n If your data is on a GPU, :func:`ot.batch.solve_gromov_batch`\n is significantly faster AND provides better objective values.\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from ot import solve_gromov\nfrom ot.batch import solve_gromov_batch\n\n\ndef benchmark_naive_gw(samples_source, samples_target):\n start = perf_counter()\n avg_value = 0\n for i in range(n_problems):\n C1 = ot.dist(samples_source[i], samples_source[i])\n C2 = ot.dist(samples_target[i], samples_target[i])\n res = solve_gromov(C1, C2, max_iter=1000, tol=tol)\n avg_value += res.value\n avg_value /= n_problems\n end = perf_counter()\n return end - start, avg_value\n\n\ndef benchmark_batch_gw(samples_source, samples_target):\n start = perf_counter()\n C1_batch = ot.batch.dist_batch(samples_source, samples_source)\n C2_batch = ot.batch.dist_batch(samples_target, samples_target)\n res_batch = solve_gromov_batch(\n C1_batch, C2_batch, reg=1, max_iter=100, max_iter_inner=50, tol=tol\n )\n avg_value = np.mean(res_batch.value)\n end = perf_counter()\n return end - start, avg_value\n\n\ntime_naive_gw, avg_value_naive_gw = benchmark_naive_gw(samples_source, samples_target)\ntime_batch_gw, avg_value_batch_gw = benchmark_batch_gw(samples_source, samples_target)\n\nprint(f\"{'Method':<20}{'Time (s)':<15}{'Avg Value':<15}\")\nprint(f\"{'Naive GW':<20}{time_naive_gw:<15.4f}{avg_value_naive_gw:<15.4f}\")\nprint(f\"{'Batched GW':<20}{time_batch_gw:<15.4f}{avg_value_batch_gw:<15.4f}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## In summary: no more for loops!\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n\nfig, ax = plt.subplots(figsize=(4, 4))\nax.text(0.5, 0.5, \"For\", fontsize=160, ha=\"center\", va=\"center\", zorder=0)\nax.axis(\"off\")\nax.plot([0, 1], [0, 1], color=\"red\", linewidth=10, zorder=1)\nax.plot([0, 1], [1, 0], color=\"red\", linewidth=10, zorder=1)\nplt.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.18"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
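The notebook added above relies on `ot.batch.dist_batch` to build all cost matrices at once. As a rough illustration of what a batched squared-Euclidean cost computation does under the hood, here is a minimal NumPy sketch (an illustration only, not POT's actual implementation):

```python
import numpy as np

def batched_sqeuclidean(X, Y):
    """Squared Euclidean cost matrices for a batch of problems.

    X: (B, n, d), Y: (B, m, d) -> costs of shape (B, n, m).
    """
    x2 = np.sum(X ** 2, axis=-1)[:, :, None]  # (B, n, 1) squared norms
    y2 = np.sum(Y ** 2, axis=-1)[:, None, :]  # (B, 1, m) squared norms
    xy = X @ Y.transpose(0, 2, 1)             # (B, n, m) batched inner products
    # Clamp tiny negative values caused by floating-point cancellation
    return np.maximum(x2 + y2 - 2.0 * xy, 0.0)

rng = np.random.default_rng(0)
X = rng.standard_normal((4, 8, 2))
Y = rng.standard_normal((4, 8, 2))
M = batched_sqeuclidean(X, Y)
print(M.shape)  # (4, 8, 8)
```

A single batched matrix product replaces B separate pairwise-distance calls, which is the reason the batched path can be dispatched efficiently, especially on GPU.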

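Similarly, the example's entropic solves (`reg_type="entropy"`) amount to Sinkhorn-style iterations vectorized over a leading batch dimension. A minimal NumPy sketch of batched Sinkhorn with uniform marginals (again an illustration, not POT's `ot.batch.solve_batch` implementation):

```python
import numpy as np

def sinkhorn_batch(M, reg, n_iter=200):
    """Entropic OT plans for a batch of cost matrices M of shape (B, n, m),
    with uniform marginals. Returns transport plans of shape (B, n, m)."""
    B, n, m = M.shape
    K = np.exp(-M / reg)                  # (B, n, m) Gibbs kernels, one per problem
    a = np.full((B, n), 1.0 / n)          # uniform source weights
    b = np.full((B, m), 1.0 / m)          # uniform target weights
    u = np.ones((B, n))
    for _ in range(n_iter):
        # Each einsum updates the scalings of ALL problems in one call
        v = b / np.einsum("bnm,bn->bm", K, u)
        u = a / np.einsum("bnm,bm->bn", K, v)
    return u[:, :, None] * K * v[:, None, :]

rng = np.random.default_rng(0)
X = rng.standard_normal((4, 8, 2))
# Batched squared-Euclidean self-costs via broadcasting
M = np.sum((X[:, :, None, :] - X[:, None, :, :]) ** 2, axis=-1)
P = sinkhorn_batch(M, reg=1.0)
print(P.shape)  # (4, 8, 8)
```

The per-problem loop disappears entirely: every iteration is two batched contractions, which is what makes this formulation map well onto GPU hardware.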