Skip to content

Commit fb5417e

Browse files
committed
refine code
1 parent ef9c7d7 commit fb5417e

File tree

4 files changed

+40
-51
lines changed

4 files changed

+40
-51
lines changed

pypaw/bins/filter_windows.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from __future__ import (absolute_import, division, print_function)
1717
import os
1818
import argparse
19+
from pprint import pprint
1920
from pytomo3d.window.filter_windows import filter_windows, count_windows
2021
from .utils import load_json, dump_json, load_yaml
2122

@@ -35,6 +36,9 @@ def check_path(paths):
3536
if not check_keys(paths, keys):
3637
raise ValueError("Path file is bad!")
3738

39+
print("=" * 10 + " Path info " + "=" * 10)
40+
pprint(paths)
41+
3842

3943
def check_param(params):
4044
keys = ["sensor", "measurement"]
@@ -49,6 +53,9 @@ def check_param(params):
4953
if not check_keys(params["measurement"], keys):
5054
raise ValueError("Param['measurement'] is bad!")
5155

56+
print("=" * 10 + " Path info " + "=" * 10)
57+
pprint(params)
58+
5259

5360
def run_window_filter(paths, params, verbose=False):
5461
check_path(paths)
@@ -59,12 +66,8 @@ def run_window_filter(paths, params, verbose=False):
5966
output_file = paths["output_file"]
6067
measurement_file = paths["measurement_file"]
6168

62-
print("window file: %s" % window_file)
63-
print("station_file: %s" % station_file)
64-
print("measurement_file: %s" % measurement_file)
65-
print("output filtered window file: %s" % output_file)
66-
6769
windows = load_json(window_file)
70+
# count the number of windows in the original window file
6871
nchans_old, nwins_old = count_windows(windows)
6972
stations = load_json(station_file)
7073
measurements = load_json(measurement_file)
@@ -74,17 +77,18 @@ def run_window_filter(paths, params, verbose=False):
7477
windows, stations, measurements, params, verbose=verbose)
7578

7679
nchans_new, nwins_new = count_windows(windows_new)
80+
7781
# dump the new windows file to replace the original one
7882
dump_json(windows_new, output_file)
7983

8084
# dump the log file
8185
logfile = os.path.join(os.path.dirname(output_file), "filter.log")
82-
print("Log file located at: %s" % logfile)
8386
dump_json(log, logfile)
8487

8588
print("=" * 10 + " Summary " + "=" * 10)
8689
print("channels: %d --> %d" % (nchans_old, nchans_new))
8790
print("windows: %d -- > %d" % (nwins_old, nwins_new))
91+
print("Log file located at: %s" % logfile)
8892

8993

9094
def main():

pypaw/sum_adjoint.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,18 +15,12 @@
1515
from __future__ import print_function, division, absolute_import
1616
import os
1717
from pprint import pprint
18-
import yaml
1918
from pyasdf import ASDFDataSet
20-
from pypaw.bins import load_json, dump_json
2119
from pytomo3d.adjoint.sum_adjoint import load_to_adjsrc, dump_adjsrc, \
2220
check_events_consistent, \
2321
create_weighted_adj, sum_adj_to_base, check_station_consistent, \
2422
rotate_adjoint_sources
25-
26-
27-
def load_yaml(fn):
28-
with open(fn) as fh:
29-
return yaml.load(fh)
23+
from .utils import read_json_file, dump_json, read_yaml_file
3024

3125

3226
def validate_path(path):
@@ -72,6 +66,10 @@ def validate_param(param):
7266

7367

7468
def save_adjoint_to_asdf(outputfile, events, adjoint_sources, stations):
69+
"""
70+
Save events(obspy.Catalog) and adjoint sources, together with
71+
staiton information, to asdf file on disk.
72+
"""
7573
print("="*15 + "\nWrite to file: %s" % outputfile)
7674
outputdir = os.path.dirname(outputfile)
7775
if not os.path.exists(outputdir):
@@ -211,7 +209,7 @@ def sum_asdf(self):
211209
filename = _file_info["asdf_file"]
212210
ds = ASDFDataSet(filename, mode='r')
213211
weight_file = _file_info["weight_file"]
214-
weights = load_json(weight_file)
212+
weights = read_json_file(weight_file)
215213
print("-" * 20)
216214
print("Adding asdf file(%s) using assigned weight_file(%s)"
217215
% (filename, weight_file))
@@ -239,7 +237,7 @@ def dump_to_asdf(self, outputfile):
239237

240238
def _parse_path(self):
241239
if isinstance(self.path, str):
242-
path = load_json(self.path)
240+
path = read_json_file(self.path)
243241
elif isinstance(self.path, dict):
244242
path = self.path
245243
else:
@@ -248,7 +246,7 @@ def _parse_path(self):
248246

249247
def _parse_param(self):
250248
if isinstance(self.param, str):
251-
param = load_yaml(self.param)
249+
param = read_yaml_file(self.param)
252250
elif isinstance(self.param, dict):
253251
param = self.param
254252
else:

pypaw/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def __init__(self, d):
5252
self.__dict__ = d
5353

5454

55-
def read_json_file(parfile, obj_hook=True):
55+
def read_json_file(parfile, obj_hook=False):
5656
"""
5757
Hook json to an JSONObject instance
5858
"""

pypaw/window_weights.py

Lines changed: 21 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -147,27 +147,28 @@ def extract_source_location(input_info):
147147
return src_info
148148

149149

150-
def check_cat_consistency(cat_ratio, cat_wcounts):
151-
err = 0
152-
# check consistency
153-
for p, pinfo in cat_ratio:
154-
for c in pinfo:
155-
try:
156-
cat_wcounts[p][c]
157-
except KeyError:
158-
err = 1
159-
print("Missing %s.%s" % (p, c))
160-
if err:
161-
raise ValueError("category weighting ratio information is not "
162-
"consistent with window information")
163-
164-
165150
def plot_histogram(figname, array, nbins=50):
166151
# plot histogram of weights
167152
plt.hist(array, nbins)
168153
plt.savefig(figname)
169154

170155

156+
def analyze_category_weights(cat_weights, logfile):
157+
log = {"category_weights": cat_weights}
158+
maxw = 0
159+
minw = 10**9
160+
for _p, _pw in cat_weights.iteritems():
161+
for _comp, _compw in _pw.iteritems():
162+
if _compw > maxw:
163+
maxw = _compw
164+
if _compw < minw:
165+
minw = _compw
166+
log["summary"] = {"maxw": maxw, "minw": minw,
167+
"cond_num": maxw/minw}
168+
169+
dump_json(log, logfile)
170+
171+
171172
def validate_overall_weights(weights_array, nwins_array):
172173
"""
173174
Validate the overall weights.
@@ -238,7 +239,7 @@ def __init__(self, path, param):
238239
self.cat_wcounts = {}
239240
self.cat_weights = None
240241

241-
def analyze_receiver(self, logfile):
242+
def analyze_receiver_weights(self, logfile):
242243
log = {}
243244
for _p, _pw in self.weights.iteritems():
244245
log[_p] = {}
@@ -259,21 +260,6 @@ def analyze_receiver(self, logfile):
259260

260261
dump_json(log, logfile)
261262

262-
def analyze_category(self, logfile):
263-
log = {"category_weights": self.cat_weights}
264-
maxw = 0
265-
minw = 10**9
266-
for _p, _pw in self.cat_weights.iteritems():
267-
for _comp, _compw in _pw.iteritems():
268-
if _compw > maxw:
269-
maxw = _compw
270-
if _compw < minw:
271-
minw = _compw
272-
log["summary"] = {"maxw": maxw, "minw": minw,
273-
"cond_num": maxw/minw}
274-
275-
dump_json(log, logfile)
276-
277263
def analyze(self):
278264
"""
279265
Analyze the final weight and generate log file
@@ -282,11 +268,11 @@ def analyze(self):
282268
logdir = os.path.dirname(self.path["logfile"])
283269
logfile = os.path.join(logdir, "log.receiver_weights.json")
284270
logger.info("receiver log file: %s" % logfile)
285-
self.analyze_receiver(logfile)
271+
self.analyze_receiver_weights(logfile)
286272

287273
logfile = os.path.join(logdir, "log.category_weights.json")
288274
logger.info("category log file: %s" % logfile)
289-
self.analyze_category(logfile)
275+
analyze_category_weights(self.cat_weights, logfile)
290276

291277
analyze_overall_weights(self.weights, self.rec_wcounts, logdir)
292278

@@ -304,10 +290,10 @@ def calculate_receiver_weights_asdf(self):
304290
and normalized)separately.
305291
"""
306292
logger_block("Receiver Weighting")
307-
input_info = self.path["input"]
308293

309294
weighting_param = self.param["receiver_weighting"]
310295

296+
input_info = self.path["input"]
311297
nperiods = len(input_info)
312298
period_idx = 0
313299
# determine receiver weightings for each asdf file
@@ -335,6 +321,7 @@ def smart_run(self):
335321
# calculate receiver weights
336322
self.calculate_receiver_weights_asdf()
337323
# calculate category weights
324+
logger_block("Category Weighting")
338325
self.cat_weights = calculate_category_weights_interface(
339326
self.param["category_weighting"], self.cat_wcounts)
340327

0 commit comments

Comments
 (0)