Skip to content

Commit 635ae9d

Browse files
committed
move some window weights functions into pytomo3d(for better test purpose)
1 parent de5ba12 commit 635ae9d

File tree

2 files changed

+58
-109
lines changed

2 files changed

+58
-109
lines changed

examples/adjoint_sources/parfile/multitaper.adjoint.50_100.config.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ process_config:
3737
# for summing multiple instruments, like "II.AAK.00.BHZ" and "II.AAK.10.BHZ". If you set
3838
# the weight_flag to true, then you also need to provide the weight_dict in the code
3939
sum_over_comp_flag: False
40-
weight_flag: True
40+
weight_flag: False
4141

4242
# filter the adjoint source
4343
filter_flag: True

pypaw/window_weights.py

Lines changed: 57 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,18 @@
1414

1515
import os
1616
from collections import defaultdict
17+
from copy import deepcopy
1718
import numpy as np
1819
import logging
1920
import matplotlib.pyplot as plt
2021
plt.switch_backend('agg') # NOQA
2122

2223
from pyasdf import ASDFDataSet
2324
from pytomo3d.adjoint.sum_adjoint import check_events_consistent
24-
from pytomo3d.window.window_weights import determine_receiver_weighting, \
25-
determine_category_weighting
25+
from pytomo3d.window.window_weights import \
26+
calculate_receiver_weights_interface, \
27+
calculate_category_weights_interface,\
28+
combine_receiver_and_category_weights
2629
from pypaw.bins.utils import load_json, dump_json, load_yaml
2730

2831

@@ -112,7 +115,7 @@ def validate_param(param):
112115

113116
def extract_receiver_locations(station_file, windows):
114117
"""
115-
Extract receiver location information from asdf file
118+
Extract receiver location information from station json file
116119
"""
117120
station_info = load_json(station_file)
118121
return station_info
@@ -139,7 +142,7 @@ def extract_source_location(input_info):
139142
origin = event_base.preferred_origin()
140143
src_info = {
141144
"latitude": origin.latitude, "longitude": origin.longitude,
142-
"depth:": origin.depth}
145+
"depth_in_m": origin.depth}
143146

144147
return src_info
145148

@@ -165,32 +168,17 @@ def plot_histogram(figname, array, nbins=50):
165168
plt.savefig(figname)
166169

167170

168-
def validate_overall_weights(weights_array, nwins_array):
    """
    Validate the overall weights.

    The weights are accepted when the window-count-weighted sum,
    sum(nwins_array[i] * weights_array[i]), is close (np.isclose) to the
    total number of windows, sum(nwins_array).

    :param weights_array: sequence of final weights
    :param nwins_array: sequence of window counts, in the same order
        as weights_array
    :raises ValueError: if the weighted sum does not match the total
        number of windows
    """
    wsum = np.dot(nwins_array, weights_array)
    logger.info("Summation of weights*nwindows: %.5e" % wsum)
    nwins_total = np.sum(nwins_array)
    if not np.isclose(wsum, nwins_total):
        # One placeholder per value: the message used to contain a single
        # "%f" for a two-element tuple, so formatting it raised TypeError
        # before the ValueError could be created.
        raise ValueError("The sum of all weights(%f) does not add "
                         "up to total number of windows(%f)"
                         % (wsum, nwins_total))
194182

195183

196184
def analyze_overall_weights(weights, rec_wcounts, logdir):
@@ -199,10 +187,11 @@ def analyze_overall_weights(weights, rec_wcounts, logdir):
199187
# validate that the sum of weights*nwindows equals the total number of windows
200188
for _p, _pw in weights.iteritems():
201189
for _chan, _chanw in _pw.iteritems():
202-
nwins_array.append(rec_wcounts[_p][_chan])
190+
comp = _chan.split(".")[-1]
191+
nwins_array.append(rec_wcounts[_p][comp][_chan])
203192
weights_array.append(_chanw["weight"])
204193

205-
validate_overall_weights(weights)
194+
validate_overall_weights(weights_array, nwins_array)
206195

207196
figname = os.path.join(logdir, "weights.hist.png")
208197
plot_histogram(figname, weights_array)
@@ -211,13 +200,19 @@ def analyze_overall_weights(weights, rec_wcounts, logdir):
211200

212201
maxw = max(weights_array)
213202
minw = min(weights_array)
214-
logger.info("Total number of receivers: %d" % len(weights_array))
215-
logger.info("Total number of windows: %d" % np.sum(nwins_array))
203+
nreceivers = len(weights_array)
204+
nwindows = np.sum(nwins_array)
205+
logger.info("Total number of receivers: %d" % nreceivers)
206+
logger.info("Total number of windows: %d" % nwindows)
216207
logger.info("Weight max, min, max/min: %f, %f, %f"
217208
% (maxw, minw, maxw/minw))
218209

219-
return {"max_weights": maxw, "min_weights": minw,
220-
"total_nwindows": np.sum(nwins_array)}
210+
logfile = os.path.join(logdir, "weights.summary.json")
211+
content = {"max_weights": maxw, "min_weights": minw,
212+
"total_nwindows": np.sum(nwins_array),
213+
"windows": nwindows, "receivers": nreceivers}
214+
logger.info("Overall log file: %s" % logfile)
215+
dump_json(content, logfile)
221216

222217

223218
class WindowWeight(object):
@@ -235,15 +230,15 @@ def __init__(self, path, param):
235230

236231
self.weights = None
237232

238-
self.rec_weights = None
239-
self.rec_wcounts = None
240-
self.rec_ref_dists = None
241-
self.rec_cond_nums = None
233+
self.rec_weights = {}
234+
self.rec_wcounts = {}
235+
self.rec_ref_dists = {}
236+
self.rec_cond_nums = {}
242237

243-
self.cat_wcounts = None
238+
self.cat_wcounts = {}
244239
self.cat_weights = None
245240

246-
def analysis_receiver(self, logfile):
241+
def analyze_receiver(self, logfile):
247242
log = {}
248243
for _p, _pw in self.weights.iteritems():
249244
log[_p] = {}
@@ -264,32 +259,7 @@ def analysis_receiver(self, logfile):
264259

265260
dump_json(log, logfile)
266261

267-
def analysis_source(self, logfile):
268-
"""
269-
dump source weights and some statistic information
270-
"""
271-
log = {"source_weights": self.src_weights}
272-
summary = {}
273-
for _p, _pw in self.src_weights.iteritems():
274-
summary[_p] = {}
275-
for _comp, _compw in _pw.iteritems():
276-
maxw = 0
277-
minw = 10**9
278-
for _ev, _evw in _compw.iteritems():
279-
if _evw > maxw:
280-
maxw = _evw
281-
if _evw < minw:
282-
minw = _evw
283-
summary[_p][_comp] = \
284-
{"maxw": maxw, "minw": minw,
285-
"ref_distance": self.src_ref_dists[_p][_comp],
286-
"cond_num": self.src_cond_nums[_p][_comp],
287-
"nwindows": self.cat_wcounts[_p][_comp]}
288-
289-
log["summary"] = summary
290-
dump_json(log, logfile)
291-
292-
def analysis_category(self, logfile):
262+
def analyze_category(self, logfile):
293263
log = {"category_weights": self.cat_weights}
294264
maxw = 0
295265
minw = 10**9
@@ -304,97 +274,76 @@ def analysis_category(self, logfile):
304274

305275
dump_json(log, logfile)
306276

307-
def analyze(self):
    """
    Analyze the final weights and generate the log files.

    The receiver and category logs are written next to
    self.path["logfile"]; afterwards the combined weights are passed
    to analyze_overall_weights() together with the receiver window
    counts.
    """
    logger_block("Summary")
    logdir = os.path.dirname(self.path["logfile"])

    receiver_log = os.path.join(logdir, "log.receiver_weights.json")
    logger.info("receiver log file: %s" % receiver_log)
    self.analyze_receiver(receiver_log)

    category_log = os.path.join(logdir, "log.category_weights.json")
    logger.info("category log file: %s" % category_log)
    self.analyze_category(category_log)

    analyze_overall_weights(self.weights, self.rec_wcounts, logdir)
320292

321293
def dump_weights(self):
    """ Write the weights of every period band to its "output_file" """
    for band, band_weights in self.weights.iteritems():
        destination = self.path['input'][band]["output_file"]
        dump_json(band_weights, destination)
326298

327-
def calculate_receiver_weights_asdf(self):
    """
    Calculate the receiver weights for every period band listed in
    self.path["input"].

    For each band, a copy of its path entry (without the "asdf_file"
    key) is handed to calculate_receiver_weights_interface() together
    with the source information and the "receiver_weighting"
    parameters. The returned "rec_weights", "rec_wcounts",
    "rec_ref_dists", "rec_cond_nums" and "cat_wcounts" entries are
    stored on the instance, keyed by period.
    """
    logger_block("Receiver Weighting")
    input_info = self.path["input"]
    weighting_param = self.param["receiver_weighting"]
    nperiods = len(input_info)

    # every result key is stored in the instance attribute of the
    # same name
    result_keys = ("rec_weights", "rec_wcounts", "rec_ref_dists",
                   "rec_cond_nums", "cat_wcounts")

    # determine receiver weightings for each period band
    for period_idx, (period, period_info) in enumerate(
            input_info.iteritems(), 1):
        logger.info("-" * 15 + "[%d/%d]Period band: %s"
                    % (period_idx, nperiods, period) + "-" * 15)

        # work on a copy so the stored path information keeps its
        # "asdf_file" entry
        _path_info = deepcopy(period_info)
        _path_info.pop("asdf_file", None)

        _results = calculate_receiver_weights_interface(
            self.src_info, _path_info, weighting_param)

        for key in result_keys:
            getattr(self, key)[period] = _results[key]
383-
384330
def smart_run(self):
    """
    Run the complete weighting workflow: validate the inputs, compute
    receiver and category weights, combine them, write the analysis
    logs and dump the final weights.
    """
    validate_path(self.path)
    validate_param(self.param)

    # source location, read later by the receiver weighting step
    self.src_info = extract_source_location(self.path["input"])

    # receiver weights (also fills the per-category window counts)
    self.calculate_receiver_weights_asdf()

    # category weights
    category_param = self.param["category_weighting"]
    self.cat_weights = calculate_category_weights_interface(
        category_param, self.cat_wcounts)

    # combine the receiver weights with category weights
    logger_block("Combine Weights")
    self.weights = combine_receiver_and_category_weights(
        self.rec_weights, self.cat_weights)

    # statistical analysis, then write the results out
    self.analyze()
    self.dump_weights()

0 commit comments

Comments
 (0)