Commit af70cf4

add adjoint stations generator and plot histogram of measurements utils
1 parent 1c07427

7 files changed: 308 additions, 18 deletions
Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
#!/usr/bin/env python

"""
This script will generate the STATIONS_ADJOINT file from
measurements file and stations file(stations.json). The
STATIONS_ADJOINT will then be used in adjoint simulations.
"""
from __future__ import print_function, division, absolute_import
import os
import argparse
from pprint import pprint
from .utils import load_json
from pytomo3d.station.generate_adjoint_stations import \
    generate_adjoint_stations


def main():

    parser = argparse.ArgumentParser()
    parser.add_argument('-f', action='store', dest='path_file', required=True,
                        help="path file")
    parser.add_argument('-v', action='store_true', dest='verbose',
                        help="verbose flag")
    args = parser.parse_args()

    paths = load_json(args.path_file)

    print("Path information:")
    pprint(paths)

    # load stations
    station_file = paths["station_file"]
    stations = load_json(station_file)

    # load measurements
    measure_files = paths["measure_files"]
    measurements = {}
    for period, fn in measure_files.iteritems():
        measurements[period] = load_json(fn)

    outputfile = paths["output_file"]
    outputdir = os.path.dirname(outputfile)
    if not os.path.exists(outputdir):
        os.makedirs(outputdir)

    generate_adjoint_stations(measurements, stations, outputfile)


if __name__ == "__main__":
    main()
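For reference, the path file passed to -f only needs the three keys the script reads: station_file, measure_files (a mapping from period band to measurement json) and output_file. A minimal sketch of how such a file could be written, with a made-up event name and directories that are not part of this commit:

import json

# Hypothetical path file for the adjoint-stations generator above; only the
# key names come from the script, the event name and directories are made up.
paths = {
    "station_file": "stations/C201001010000A.stations.json",
    "measure_files": {
        "17_40": "measure/C201001010000A.17_40.measure_adj.json",
        "40_100": "measure/C201001010000A.40_100.measure_adj.json",
        "90_250": "measure/C201001010000A.90_250.measure_adj.json"
    },
    "output_file": "adjoint_stations/STATIONS_ADJOINT"
}

with open("adjoint_stations.path.json", "w") as fh:
    json.dump(paths, fh, indent=2, sort_keys=True)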

pypaw/bins/generate_stations_asdf.py

Lines changed: 7 additions & 4 deletions
@@ -1,10 +1,13 @@
 #!/usr/bin/env python
 """
-Scripts that generate stations file from asdf file. If
-there are stations in waveforms, then a file `STATIONS_waveform`
-will be generated. Or if there are stations in AuxlilaryData,
+Scripts that generate stations file from asdf file.
+1) If there are stations in waveforms, then a file `STATIONS`
+will be generated.
+2) if there are stations in AuxlilaryData.AdjointSource,
 then a file `STATIONS_ADJOINT` will be generated.
 
+The output STATIONS file follows the format in SPECFEM3D_GLOBE.
+
 :copyright:
     Wenjie Lei (lei@princeton.edu), 2016
 :license:
@@ -20,7 +23,7 @@
 
 from pypaw.stations import extract_adjoint_stations
 from pypaw.stations import extract_waveform_stations
-from pypaw.stations import write_stations_file
+from pytomo3d.station.utils import write_stations_file
 
 
 def generate_waveform_stations(asdf, outputfn):

pypaw/stations.py

Lines changed: 0 additions & 13 deletions
@@ -12,19 +12,6 @@
 from __future__ import (print_function, division, absolute_import)
 import pyasdf
 from pytomo3d.station import extract_staxml_info
-import collections
-
-
-def write_stations_file(sta_dict, filename="STATIONS"):
-    """
-    Write station information out to a txt file(in SPECFEM FORMAT)
-    """
-    with open(filename, 'w') as fh:
-        od = collections.OrderedDict(sorted(sta_dict.items()))
-        for _sta_id, _sta in od.iteritems():
-            network, station = _sta_id.split(".")
-            fh.write("%-9s %5s %15.4f %12.4f %10.1f %6.1f\n"
-                     % (station, network, _sta[0], _sta[1], _sta[2], _sta[3]))
 
 
 def extract_station_info_from_asdf(asdf, verbose=False):
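The deleted helper now comes from pytomo3d instead (see the import change in generate_stations_asdf.py above). For context, a sketch of the line its format string writes, using illustrative station values that are not taken from this commit:

# Illustrative only: one line of a STATIONS file in the format used by the
# removed write_stations_file helper (station, network, then the four stored
# values, conventionally latitude, longitude, elevation and burial depth).
print("%-9s %5s %15.4f %12.4f %10.1f %6.1f"
      % ("ANMO", "IU", 34.9459, -106.4572, 1850.0, 100.0))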
Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
### Introduction

This script is used to plot histogram of measurements generated by
pyadjoint, including traveltime(dt) and amplitude(dlnA).

It is very important to monitor the histogram, such as the upper
bound, lower bound, mean and standard deviation values of the
measurements distributions during the inversion.
Lines changed: 72 additions & 0 deletions
@@ -0,0 +1,72 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Generate weights for each window based on the number of windows, location
of stations and receivers

:copyright:
    Wenjie Lei (lei@princeton.edu), 2016
:license:
    GNU Lesser General Public License, version 3 (LGPLv3)
    (http://www.gnu.org/licenses/lgpl-3.0.en.html)
"""
from __future__ import print_function, division
import os
import json
import argparse

# #############################
period_list = ["17_40", "40_100", "90_250"]

superbase = "/lustre/atlas/proj-shared/geo111/Wenjie/DATA_M16"
measurebase = os.path.join(superbase, "measure")
stationbase = os.path.join(superbase, "stations")
# #############################


def load_txt(txtfile):
    with open(txtfile, 'r') as fh:
        return [line.rstrip() for line in fh]


def check_file_exists(filename):
    if not os.path.exists(filename):
        raise ValueError("Missing file: %s" % filename)


def generate_json_paths(eventlist, outputfile, mtype=""):
    paths = {"input": {}, "outputdir": "./output%s" % mtype}

    for event in eventlist:
        event_info = {}
        stationfile = os.path.join(stationbase, "%s.stations.json" % event)
        check_file_exists(stationfile)
        period_info = {}
        for period in period_list:
            measure_file = \
                os.path.join(measurebase, "%s.%s.measure_adj.json%s"
                             % (event, period, mtype))
            check_file_exists(measure_file)
            period_info[period] = {"measure_file": measure_file}
        event_info = {"stationfile": stationfile,
                      "period_info": period_info}

        paths["input"][event] = event_info

    print("Output dir json file: ", outputfile)
    with open(outputfile, 'w') as f:
        json.dump(paths, f, indent=2, sort_keys=True)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('-f', action='store', dest='eventlist_file',
                        required=True)
    args = parser.parse_args()

    eventlist = load_txt(args.eventlist_file)

    generate_json_paths(eventlist, "window_weight.path.json", mtype="")

    generate_json_paths(eventlist, "window_weight.filter.path.json",
                        mtype=".filter")
Lines changed: 169 additions & 0 deletions
@@ -0,0 +1,169 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Plot the histogram of measurements. All the input files
are specified in the path json file.

:copyright:
    Wenjie Lei (lei@princeton.edu), 2016
:license:
    GNU Lesser General Public License, version 3 (LGPLv3)
    (http://www.gnu.org/licenses/lgpl-3.0.en.html)
"""
from __future__ import print_function, division
import os
import json
import argparse
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt


def load_txt(txtfile):
    with open(txtfile, 'r') as fh:
        return [line.rstrip() for line in fh]


def load_json(fn):
    with open(fn) as fh:
        return json.load(fh)


def dump_json(content, fn):
    with open(fn, 'w') as fh:
        json.dump(content, fh, indent=2, sort_keys=True)


def check_file_exists(filename):
    if not os.path.exists(filename):
        raise ValueError("Missing file: %s" % filename)


def load_one_measurefile(measure_file):
    measure = load_json(measure_file)

    dt = {}
    dlna = {}
    for sta, stainfo in measure.iteritems():
        for chan, chaninfo in stainfo.iteritems():
            comp = chan.split(".")[-1]
            if comp not in dt:
                dt[comp] = []
            dt[comp].extend([m["dt"] for m in chaninfo])
            if comp not in dlna:
                dlna[comp] = []
            dlna[comp].extend([m["dlna"] for m in chaninfo])

    return dt, dlna


def update_overall(dict_one, dict_all, pb):
    if pb not in dict_all:
        dict_all[pb] = {}
    for comp in dict_one:
        if comp not in dict_all[pb]:
            dict_all[pb][comp] = []
        dict_all[pb][comp].extend(dict_one[comp])


def get_mean_and_std(dictv):
    mean = {}
    std = {}
    for pb, pbinfo in dictv.iteritems():
        mean[pb] = {}
        std[pb] = {}
        for comp, compinfo in pbinfo.iteritems():
            mean[pb][comp] = np.mean(compinfo)
            std[pb][comp] = np.std(compinfo)

    return mean, std


def stats_analysis(dts, dlnas, outputdir):
    dt_mean, dt_std = get_mean_and_std(dts)
    dlna_mean, dlna_std = get_mean_and_std(dlnas)

    log_content = {"dt": {"mean": dt_mean, "std": dt_std},
                   "dlna": {"mean": dlna_mean, "std": dlna_std}}

    outputfn = os.path.join(outputdir, "measure.log.json")
    print("log file: %s" % outputfn)
    dump_json(log_content, outputfn)


def load_measurements(inputs):
    dts = {}
    dlnas = {}
    for ev, evinfo in inputs.iteritems():
        for pb, pbinfo in evinfo["period_info"].iteritems():
            _dt, _dlna = load_one_measurefile(pbinfo["measure_file"])
            update_overall(_dt, dts, pb)
            update_overall(_dlna, dlnas, pb)

    return dts, dlnas


def plot_hist(data, figname=None):
    period_bands = ["17_40", "40_100", "90_250"]
    components = ["BHR", "BHT", "BHZ"]

    fig = plt.figure(figsize=(20, 20))

    irow = 0
    for pb in period_bands:
        icol = 0
        for comp in components:
            idx = irow * 3 + icol + 1
            plt.subplot(3, 3, idx)
            plt.hist(data[pb][comp], bins=30)
            mean = np.mean(data[pb][comp])
            std = np.std(data[pb][comp])
            xloc = plt.xlim()[0] + 0.05 * (plt.xlim()[1] - plt.xlim()[0])
            plt.text(xloc, plt.ylim()[1] * 0.9, "mean: %.4f" %
                     (mean))
            plt.text(xloc, plt.ylim()[1] * 0.85, "std: %.4f" %
                     (std))
            if icol == 0:
                plt.ylabel(pb)
            if irow == 2:
                plt.xlabel(comp)
            icol += 1
        irow += 1

    print("Save figure to: %s" % figname)
    plt.tight_layout()
    plt.savefig(figname)
    plt.close(fig)


def plot_measures(dts, dlnas, outputdir):

    figname = os.path.join(outputdir, "dt.histogram.pdf")
    plot_hist(dts, figname=figname)

    figname = os.path.join(outputdir, "dlna.histogram.pdf")
    plot_hist(dlnas, figname=figname)


def main(path):
    inputs = path["input"]
    outputdir = path["outputdir"]
    print("Number of events: %d" % len(inputs))
    if not os.path.exists(outputdir):
        os.makedirs(outputdir)

    dts, dlnas = load_measurements(inputs)

    stats_analysis(dts, dlnas, outputdir)
    plot_measures(dts, dlnas, outputdir)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('-f', action='store', dest='path',
                        required=True)
    args = parser.parse_args()

    path = load_json(args.path)
    main(path)
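Assuming the two scripts above are saved as, say, generate_measure_paths.py and plot_measurements_histogram.py (placeholder names, since the filenames are not shown in this diff), a typical run would be:

    python generate_measure_paths.py -f eventlist.txt
    python plot_measurements_histogram.py -f window_weight.path.json

where eventlist.txt lists one event name per line; the second command writes dt.histogram.pdf, dlna.histogram.pdf and measure.log.json into the outputdir recorded in the path file.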

setup.py

Lines changed: 2 additions & 1 deletion
@@ -32,7 +32,8 @@ def run_tests(self):
         'pypaw-convert_adjsrcs_from_asdf=pypaw.bins.convert_adjsrcs_from_asdf:main', # NOQA
         'pypaw-convert_to_asdf=pypaw.bins.convert_to_asdf:main',
         'pypaw-convert_to_sac=pypaw.bins.convert_to_sac:main',
-        'pypaw-generate_stations_asdf=pypaw.bins.generate_stations_asdf:main'
+        'pypaw-generate_stations_asdf=pypaw.bins.generate_stations_asdf:main',
+        'pypaw-generate_adjoint_stations=pypaw.bins.generate_adjoint_stations:main'
     ]
 
 
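With the new console_scripts entry, the adjoint-stations generator can be invoked directly once the package is installed (the path file name below is a placeholder):

    pypaw-generate_adjoint_stations -f adjoint_stations.path.json -v

where -f points at a JSON path file with the station_file, measure_files and output_file keys shown earlier, and -v sets the verbose flag defined in the script's argument parser.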