Skip to content

Commit d42107e

Browse files
committed
Renamed class.measure members to match paper
1 parent b6cf33c commit d42107e

File tree

4 files changed

+38
-39
lines changed

4 files changed

+38
-39
lines changed

src/classes.py

Lines changed: 23 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -340,33 +340,32 @@ class measure:
340340
List of member curves.
341341
weights : numpy.ndarray
342342
Array of positive weights associated to each curve.
343-
intensities : numpy.ndarray
344-
Array of the intensities associated to each curve.
343+
energies : numpy.ndarray
344+
Array of stored Benamou-Brenier energy associated to each curve.
345345
main_energy : float
346346
The Tikhonov energy of the measure.
347347
"""
348348
def __init__(self):
349349
self.curves = []
350350
self.energies = np.array([])
351-
self.intensities = np.array([])
351+
self.weights = np.array([])
352352
self.main_energy = None
353353

354-
def add(self, new_curve, new_intensity):
355-
# Input: new_curve is a curve class object. new_intensity > 0 real.
356-
if new_intensity > config.measure_coefficient_too_low:
354+
def add(self, new_curve, new_weight):
355+
# Input: new_curve is a curve class object. new_weight > 0 real.
356+
if new_weight > config.measure_coefficient_too_low:
357357
self.curves.extend([new_curve])
358358
self.energies = np.append(self.energies,
359359
new_curve.energy())
360-
self.intensities = np.append(self.intensities, new_intensity)
360+
self.weights = np.append(self.weights, new_weight)
361361
self.main_energy = None
362362

363363
def __add__(self, measure2):
364364
new_measure = copy.deepcopy(self)
365365
new_measure.curves.extend(copy.deepcopy(measure2.curves))
366366
new_measure.energies = np.append(new_measure.energies,
367367
measure2.energies)
368-
new_measure.intensities = np.append(new_measure.intensities,
369-
measure2.intensities)
368+
new_measure.weights = np.append(new_measure.weights, measure2.weights)
370369
new_measure.main_energy = None
371370
return new_measure
372371

@@ -375,46 +374,46 @@ def __mul__(self, factor):
375374
raise Exception('Cannot use a negative factor for a measure')
376375
new_measure = copy.deepcopy(self)
377376
new_measure.main_energy = None
378-
for i in range(len(self.intensities)):
379-
new_measure.intensities[i] = new_measure.intensities[i]*factor
377+
for i in range(len(self.weights)):
378+
new_measure.weights[i] = new_measure.weights[i]*factor
380379
return new_measure
381380

382381
def __rmul__(self, factor):
383382
return self*factor
384383

385-
def modify_intensity(self, curve_index, new_intensity):
384+
def modify_weight(self, curve_index, new_weight):
386385
self.main_energy = None
387386
if curve_index >= len(self.curves):
388387
raise Exception('Trying to modify an unexistant curve! The given'
389388
+ 'curve index is too high for the current array')
390-
if new_intensity < config.measure_coefficient_too_low:
389+
if new_weight < config.measure_coefficient_too_low:
391390
del self.curves[curve_index]
392-
self.intensities = np.delete(self.intensities, curve_index)
391+
self.weights = np.delete(self.weights, curve_index)
393392
self.energies = np.delete(self.energies, curve_index)
394393
else:
395-
self.intensities[curve_index] = new_intensity
394+
self.weights[curve_index] = new_weight
396395

397396
def integrate_against(self, w_t):
398397
assert isinstance(w_t, op.w_t)
399398
# Method to integrate against this measure.
400399
integral = 0
401400
for i, curv in enumerate(self.curves):
402-
integral += self.intensities[i]/self.energies[i] * \
401+
integral += self.weights[i]/self.energies[i] * \
403402
curv.integrate_against(w_t)
404403
return integral
405404

406405
def spatial_integrate(self, t, target):
407406
# Method to integrate against this measure for a fixed time, target is
408407
# a function handle
409408
val = 0
410-
for i in range(len(self.intensities)):
411-
val = val + self.intensities[i]/self.energies[i] * \
409+
for i in range(len(self.weights)):
410+
val = val + self.weights[i]/self.energies[i] * \
412411
target(self.curves[i].eval_discrete(t))
413412
return val
414413

415414
def to_curve_product(self):
416415
# transforms the measure to a curve product object
417-
return curve_product(self.curves, self.intensities)
416+
return curve_product(self.curves, self.weights)
418417

419418
def get_main_energy(self):
420419
if self.main_energy is None:
@@ -425,8 +424,8 @@ def get_main_energy(self):
425424

426425

427426
def draw(self, ax=None):
428-
num_plots = len(self.intensities)
429-
total_intensities = self.intensities/self.energies
427+
num_plots = len(self.weights)
428+
total_intensities = self.weights/self.energies
430429
'get the brg colormap for the intensities of curves'
431430
colors = plt.cm.brg(total_intensities/max(total_intensities))
432431
ax = ax or plt.gca()
@@ -467,13 +466,13 @@ def animate(self, filename=None, show=True, block=False):
467466
def reorder(self):
468467
# Script to reorder the curves inside the measure with an increasing
469468
# total energy.
470-
total_intensities = self.intensities/self.energies
469+
total_intensities = self.weights/self.energies
471470
new_order = np.argsort(total_intensities)
472471
new_measure = measure()
473472
for idx in new_order:
474-
new_measure.add(self.curves[idx], self.intensities[idx])
473+
new_measure.add(self.curves[idx], self.weights[idx])
475474
self.curves = new_measure.curves
476-
self.intensities = new_measure.intensities
475+
self.weights = new_measure.weights
477476
self.energies = new_measure.energies
478477

479478

src/misc.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,13 @@ def __init__(self, measure, **kwargs):
3939
#
4040
measure.reorder()
4141
# Define the colors, these depends on the intensities
42-
total_intensities = measure.intensities/measure.energies
42+
total_intensities = measure.weights/measure.energies
4343
brg_cmap = plt.cm.get_cmap('brg')
4444
colors = brg_cmap(np.array(total_intensities)/max(total_intensities))
4545
# Get the family of segments and times
4646
segments = []
4747
times = []
48-
for i in range(len(measure.intensities)):
48+
for i in range(len(measure.weights)):
4949
supsamp_t, supsamp_x = supersample(measure.curves[i],
5050
max_jump=0.01)
5151
# Get segments and use as time the last part of each segment
@@ -74,12 +74,12 @@ def __init__(self, measure, **kwargs):
7474
self.show = show
7575
self.block = block
7676
# For colorbar
77-
norm = mpl.colors.Normalize(vmin=0, vmax=max(measure.intensities))
77+
norm = mpl.colors.Normalize(vmin=0, vmax=max(measure.weights))
7878
cmap = plt.get_cmap('brg', 100)
7979
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
8080
sm.set_array([])
8181
self.fig.colorbar(sm,
82-
ticks=np.linspace(0, max(measure.intensities), 9))
82+
ticks=np.linspace(0, max(measure.weights), 9))
8383

8484

8585
def animate(self, i):

src/operators.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ class w_t:
124124
def __init__(self, rho_t):
125125
assert isinstance(rho_t, classes.measure)
126126
# take the difference between the current curve and the problem's data.
127-
if rho_t.intensities.size == 0:
127+
if rho_t.weights.size == 0:
128128
if config.f_t is None:
129129
# Case in which the data has not yet been set
130130
self.data = None
@@ -213,8 +213,8 @@ def density_transformation(self, x):
213213

214214
def as_density_get_params(self, t):
215215
if np.isnan(self.as_predensity_mass[t]):
216-
# Produce, and store, the parameters needed to define a density with
217-
# the dual variable. These parameters change for each time t.
216+
# Produce, and store, the parameters needed to define a density
217+
# with the dual variable. These parameters change for each time t.
218218
evaluations = misc.grid_evaluate(lambda x: self.eval(t, x),
219219
resolution=0.01)
220220
# extracting the epsilon support for rejection sampling
@@ -259,5 +259,5 @@ def main_energy(measure, f):
259259
# Output: positive number
260260
forward = K_t_star_full(measure)
261261
diff = forward - f
262-
return int_time_H_t_product(diff, diff)/2 + sum(measure.intensities)
262+
return int_time_H_t_product(diff, diff)/2 + sum(measure.weights)
263263

src/optimization.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -109,14 +109,14 @@ def after_optimization_sparsifier(current_measure, energy_curves=None):
109109
curve_2 = output_measure.curves[id2]
110110
if (curve_1 - curve_2).H1_norm() < config.H1_tolerance:
111111
# if the curves are close, we have 3 alternatives to test
112-
weight_1 = output_measure.intensities[id1]
113-
weight_2 = output_measure.intensities[id2]
112+
weight_1 = output_measure.weights[id1]
113+
weight_2 = output_measure.weights[id2]
114114
measure_1 = copy.deepcopy(output_measure)
115-
measure_1.modify_intensity(id1, weight_1 + weight_2)
116-
measure_1.modify_intensity(id2, 0)
115+
measure_1.modify_weight(id1, weight_1 + weight_2)
116+
measure_1.modify_weight(id2, 0)
117117
measure_2 = copy.deepcopy(output_measure)
118-
measure_2.modify_intensity(id2, weight_1 + weight_2)
119-
measure_2.modify_intensity(id1, 0)
118+
measure_2.modify_weight(id2, weight_1 + weight_2)
119+
measure_2.modify_weight(id1, 0)
120120
energy_0 = output_measure.get_main_energy()
121121
energy_1 = measure_1.get_main_energy()
122122
energy_2 = measure_2.get_main_energy()
@@ -253,7 +253,7 @@ def full_gradient(current_measure):
253253
curve_list = []
254254
for curve in current_measure.curves:
255255
curve_list.append(grad_F(curve, w_t))
256-
return classes.curve_product(curve_list, current_measure.intensities)
256+
return classes.curve_product(curve_list, current_measure.weights)
257257
# Stop when stepsize get smaller than
258258
limit_stepsize = config.g_flow_limit_stepsize
259259

0 commit comments

Comments
 (0)