[update] refactoring, added type hints

This commit is contained in:
Marcel Paffrath 2024-08-08 16:49:15 +02:00
parent 452f2a2e18
commit a068bb8457

View File

@ -21,7 +21,7 @@ from joblib import Parallel, delayed
from obspy import read, Stream, UTCDateTime, Trace
from obspy.taup import TauPyModel
from obspy.core.event.base import WaveformStreamID, ResourceIdentifier
from obspy.core.event.origin import Pick
from obspy.core.event.origin import Pick, Origin
from obspy.geodetics.base import gps2dist_azimuth
from obspy.signal.cross_correlation import correlate, xcorr_max
@ -34,6 +34,12 @@ from pylot.core.util.event import Event as PylotEvent
from pylot.correlation.utils import (get_event_id, get_event_pylot, get_event_obspy_dmt, get_picks, write_json,
get_metadata)
# Lookup table from textual log-level names to the matching ``logging`` constants.
DEBUG_LEVELS = dict(
    debug=logging.DEBUG,
    info=logging.INFO,
    warn=logging.WARNING,
    error=logging.ERROR,
    critical=logging.CRITICAL,
)
class CorrelationParameters:
"""
@ -210,7 +216,7 @@ class XCorrPickCorrection:
Returns:
array: A numpy array representing the time axis.
"""
return np.linspace(0,trace.stats.endtime -trace.stats.starttime, trace.stats.npts)
return np.linspace(0, trace.stats.endtime - trace.stats.starttime, trace.stats.npts)
def plot_cc(figure_output_dir: str, plot_filename: str) -> None:
"""
@ -316,7 +322,7 @@ class XCorrPickCorrection:
peak_index = cc.argmax()
first_sample = peak_index
# XXX this could be improved..
while first_sample > 0 and cc_curvature[first_sample - 1] <= 0:
while first_sample > 0 >= cc_curvature[first_sample - 1]:
first_sample -= 1
last_sample = peak_index
while last_sample < len(cc) - 1 and cc_curvature[last_sample + 1] <= 0:
@ -337,7 +343,7 @@ class XCorrPickCorrection:
# quadratic fit for small subwindow
coeffs, cov_mat = np.polyfit(cc_t[first_sample:last_sample + 1], cc[first_sample:last_sample + 1], deg=2,
full=False, cov=True)[:2]
full=False, cov=True)[:2]
a, b, c = coeffs
@ -380,9 +386,10 @@ class XCorrPickCorrection:
return pick2_corr, coeff, uncert, fwfm
def correlation_main(database_path_dmt, pylot_infile_path, params, channel_config, istart=0, istop=1e9, update=False,
event_blacklist=None):
'''
def correlation_main(database_path_dmt: str, pylot_infile_path: str, params: dict, channel_config: dict,
istart: int = 0, istop: int = 1e9, update: bool = False,
event_blacklist: str = None, select_events: list = None) -> None:
"""
Main function of this program, correlates waveforms around theoretical, or other initial (e.g. automatic cf) picks
on one or more reference stations.
All stations with a correlation higher "min_corr_stacking" are stacked onto the station with the highest mean
@ -390,12 +397,21 @@ def correlation_main(database_path_dmt, pylot_infile_path, params, channel_confi
Finally, all other stations will be correlated with the re-picked, stacked seismogram around their theoretical (or
initial) onsets and the shifted pick time of the stacked seismogram will be assigned as their new pick time.
:param database_path_dmt: obspy_dmt databse directory
:param pylot_infile_path: path containing input file for autoPyLoT (automatic picking parameters)
:param params: dictionary with parameter class objects
Args:
database_path_dmt (str): The path to the obspydmt database.
pylot_infile_path (str): The path to the Pylot infile for autoPyLoT (automatic picking parameters).
params (dict): Parameters for the correlation script.
channel_config (dict): Configuration for channels.
istart (int, optional): The starting index for events. Defaults to 0.
istop (float, optional): The stopping index for events. Defaults to 1e9.
update (bool, optional): Whether to update. Defaults to False.
event_blacklist (str, optional): Path to the event blacklist file. Defaults to None.
select_events (list, optional): List of selected events. Defaults to None.
Returns:
None
:return:
'''
"""
assert os.path.isdir(database_path_dmt), 'Unrecognized directory {}'.format(database_path_dmt)
tstart = datetime.now()
@ -405,39 +421,38 @@ def correlation_main(database_path_dmt, pylot_infile_path, params, channel_confi
logging.info(50 * '#')
for phase_type in params.keys():
if params[phase_type]['plot_detailed']: params[phase_type]['plot'] = True
if params[phase_type]['plot_detailed']:
params[phase_type]['plot'] = True
eventdirs = glob.glob(os.path.join(database_path_dmt, '*.?'))
pylot_parameter = PylotParameter(pylot_infile_path)
if event_blacklist:
with open(event_blacklist, 'r') as infile:
event_blacklist = [line.split('\n')[0] for line in infile.readlines()]
with open(event_blacklist, 'r') as fid:
event_blacklist = [line.split('\n')[0] for line in fid.readlines()]
# iterate over all events in "database_path_dmt"
for eventindex, eventdir in enumerate(eventdirs):
if not istart <= eventindex < istop:
continue
# MP MP testing +++
# ids_filter = ['20181229_033914.a']
# if not os.path.split(eventdir)[-1] in ids_filter:
# continue
# MP MP testing ---
if select_events and not os.path.split(eventdir)[-1] in select_events:
continue
logging.info('\n' + 100 * '#')
logging.info('Working on event {} ({}/{})'.format(eventdir, eventindex + 1, len(eventdirs)))
if event_blacklist and get_event_id(eventdir) in event_blacklist:
logging.info('Event on blacklist. Continue')
correlate_event(eventdir, pylot_parameter, params=params, channel_config=channel_config, update=update)
correlate_event(eventdir, pylot_parameter, params=params, channel_config=channel_config,
update=update)
logging.info('Finished script after {} at {}'.format(datetime.now() - tstart, datetime.now()))
def get_estimated_inclination(stations_dict, origin, phases, model='ak135'):
''' calculate a mean inclination angle for all stations for seismometer rotation to spare computation time'''
def get_estimated_inclination(stations_dict: dict, origin: Origin, phases: str, model: str = 'ak135') -> float:
""" calculate a mean inclination angle for all stations for seismometer rotation to spare computation time"""
model = TauPyModel(model)
phases = [*phases.split(',')]
avg_lat = np.median([item['latitude'] for item in stations_dict.values()])
@ -448,9 +463,9 @@ def get_estimated_inclination(stations_dict, origin, phases, model='ak135'):
return arrivals[0].incident_angle
def modify_horizontal_name(wfdata):
''' This function only renames (does not rotate!) numeric channel IDs to be able to rotate them to LQT. It is
therefore not accurate and only used for picking with correlation. Returns a copy of the original stream. '''
def modify_horizontal_name(wfdata: Stream) -> Stream:
""" This function only renames (does not rotate!) numeric channel IDs to be able to rotate them to LQT. It is
therefore not accurate and only used for picking with correlation. Returns a copy of the original stream. """
# wfdata = wfdata.copy()
stations = np.unique([tr.stats.station for tr in wfdata])
for station in stations:
@ -460,13 +475,14 @@ def modify_horizontal_name(wfdata):
st = wfdata.select(station=station, location=location)
channels = [tr.stats.channel for tr in st]
check_numeric = [channel[-1].isnumeric() for channel in channels]
check_nonnumeric = [not (cn) for cn in check_numeric]
check_nonnumeric = [not cn for cn in check_numeric]
if not any(check_numeric):
continue
numeric_channels = np.array(channels)[check_numeric].tolist()
nonnumeric_channels = np.array(channels)[check_nonnumeric].tolist()
if not 'Z' in [nc[-1] for nc in nonnumeric_channels]:
logging.warning('Modify_horizontal_name failed: Only implemented for existing Z component! Return original data.')
if 'Z' not in [nc[-1] for nc in nonnumeric_channels]:
logging.warning(
'Modify_horizontal_name failed: Only implemented for existing Z component! Return original data.')
return wfdata
numeric_characters = sorted([nc[-1] for nc in numeric_channels])
if numeric_characters == ['1', '2']:
@ -474,7 +490,8 @@ def modify_horizontal_name(wfdata):
elif numeric_characters == ['2', '3']:
channel_dict = {'2': 'N', '3': 'E'}
else:
logging.warning('Modify_horizontal_name failed: Channel order not implemented/unknown. Return original data.')
logging.warning(
'Modify_horizontal_name failed: Channel order not implemented/unknown. Return original data.')
return wfdata
for tr in st:
channel = tr.stats.channel
@ -486,10 +503,10 @@ def modify_horizontal_name(wfdata):
return wfdata
def cut_stream_to_same_length(wfdata):
def remove(wfdata, st):
for tr in st:
wfdata.remove(tr)
def cut_stream_to_same_length(wfdata: Stream) -> None:
def remove(data, stream):
for tr in stream:
data.remove(tr)
stations = np.unique([tr.stats.station for tr in wfdata])
for station in stations:
@ -515,7 +532,7 @@ def cut_stream_to_same_length(wfdata):
st.trim(tstart, tend, pad=True, fill_value=0.)
def get_unique_phase(picks, rename_phases=None):
def get_unique_phase(picks: list, rename_phases: dict = None) -> str:
phases = [pick.phase_hint for pick in picks]
if rename_phases:
phases = [rename_phases[phase] if phase in rename_phases else phase for phase in phases]
@ -523,10 +540,10 @@ def get_unique_phase(picks, rename_phases=None):
if len(set(phases)) == 1:
return phases[0]
return False
def correlate_event(eventdir, pylot_parameter, params, channel_config, update):
# TODO: simplify this function
def correlate_event(eventdir: str, pylot_parameter: PylotParameter, params: dict, channel_config: dict,
update: bool) -> bool:
rename_phases = {'Pdiff': 'P'}
# create ObsPy event from .pkl file in dmt eventdir
@ -588,7 +605,8 @@ def correlate_event(eventdir, pylot_parameter, params, channel_config, update):
wfdata = wfdata_raw.copy()
# resample and filter
if params[phase_type]['sampfreq']:
wfdata = resample_parallel(wfdata, params[phase_type]['sampfreq'], ncores=params[phase_type]['ncores'])
wfdata = resample_parallel(wfdata, params[phase_type]['sampfreq'],
ncores=params[phase_type]['ncores'])
else:
logging.warning('Resampling deactivated! '
'Make sure that the sampling rate of all input waveforms is identical')
@ -602,7 +620,8 @@ def correlate_event(eventdir, pylot_parameter, params, channel_config, update):
else:
method = 'auto'
# get all picks from PyLoT *.xml file
if not pfe.startswith('_'): pfe = '_' + pfe
if not pfe.startswith('_'):
pfe = '_' + pfe
picks = get_picks(eventdir, extension=pfe)
# take picks of selected phase only
picks = [pick for pick in picks if pick.phase_hint == phase_type]
@ -630,7 +649,7 @@ def correlate_event(eventdir, pylot_parameter, params, channel_config, update):
model='ak135')
# rotate ZNE -> LQT (also cut to same length)
wfdata = rotate_stream(wfdata, metadata=metadata, origin=origin, stations_dict=stations_dict,
channels=channels[phase_type], inclination=inclination, ncores=ncores)
channels=CHANNELS[phase_type], inclination=inclination, ncores=ncores)
# make copies before filtering
wfdata_lowf = wfdata.copy()
@ -644,7 +663,8 @@ def correlate_event(eventdir, pylot_parameter, params, channel_config, update):
# make directory for correlation output
correlation_out_dir = create_correlation_output_dir(eventdir, filter_options_final, true_phase_type)
fig_dir = create_correlation_figure_dir(correlation_out_dir) if params[phase_type]['save_fig'] == True else ''
fig_dir = create_correlation_figure_dir(correlation_out_dir) \
if params[phase_type]['save_fig'] is True else ''
if not params[phase_type]['use_stacked_trace']:
# first stack mastertrace by using stations with high correlation on that station in station list with the
@ -666,13 +686,14 @@ def correlate_event(eventdir, pylot_parameter, params, channel_config, update):
if params[phase_type]['plot']:
# plot correlations of traces used to generate stacked trace
plot_mastertrace_corr(nwst_id, correlations_dict, wfdata_lowf, stations_dict, picks, trace_master, method,
min_corr=params[phase_type]['min_corr_stacking'], title=eventdir + '_before_stacking',
fig_dir=fig_dir)
min_corr=params[phase_type]['min_corr_stacking'],
title=eventdir + '_before_stacking', fig_dir=fig_dir)
# continue if there is not enough traces for stacking
if nstack < params[phase_type]['min_stack']:
logging.warning('Not enough traces to stack: {} (min_stack = {}). Skip this event'.format(nstack,
params[phase_type][
params[
phase_type][
'min_stack']))
continue
@ -680,8 +701,8 @@ def correlate_event(eventdir, pylot_parameter, params, channel_config, update):
trace_master.write(os.path.join(correlation_out_dir, '{}_stacked.mseed'.format(trace_master.id)))
# now pick stacked trace with PyLoT for a more precise pick (use raw trace, gets filtered by autoPyLoT)
pick_stacked = repick_mastertrace(wfdata_lowf, trace_master, pylot_parameter, event, event_id, metadata,
phase_type, correlation_out_dir)
pick_stacked = repick_master_trace(wfdata_lowf, trace_master, pylot_parameter, event, event_id, metadata,
phase_type, correlation_out_dir)
if not pick_stacked:
continue
@ -696,7 +717,7 @@ def correlate_event(eventdir, pylot_parameter, params, channel_config, update):
trace_master_highf.filter(filter_type, **filter_options_final)
input_list = prepare_correlation_input(wfdata_lowf, picks, channels_list, trace_master_lowf, pick_stacked,
params=params[phase_type], plot=params[phase_type]['plot'],
phase_params=params[phase_type], plot=params[phase_type]['plot'],
fig_dir=fig_dir_traces, ncorr=2, wfdata_highf=wfdata_highf,
trace_master_highf=trace_master_highf)
@ -712,8 +733,8 @@ def correlate_event(eventdir, pylot_parameter, params, channel_config, update):
export_picks(eventdir=eventdir,
correlations_dict=get_stations_min_corr(correlations_dict_stacked,
params[phase_type]['min_corr_export']),
params=params[phase_type], picks=picks, taupypicks=taupypicks_orig, phase_type=true_phase_type,
pf_extension=pfe)
params=params[phase_type], picks=picks, taupypicks=taupypicks_orig,
phase_type=true_phase_type, pf_extension=pfe)
# plot results
if params[phase_type]['plot']:
@ -748,10 +769,12 @@ def correlate_event(eventdir, pylot_parameter, params, channel_config, update):
if len(correlations_dict_stacked) > 0:
return True
else:
return False
def remove_outliers(picks, corrected_taupy_picks, threshold):
''' delete a pick if difference to corrected taupy pick is larger than threshold'''
""" delete a pick if difference to corrected taupy pick is larger than threshold"""
n_picks = len(picks)
n_outl = 0
@ -780,7 +803,7 @@ def remove_outliers(picks, corrected_taupy_picks, threshold):
def remove_invalid_picks(picks):
''' Remove picks without uncertainty (invalid PyLoT picks)'''
""" Remove picks without uncertainty (invalid PyLoT picks)"""
count = 0
deleted_picks_ids = []
for index, pick in list(reversed(list(enumerate(picks)))):
@ -803,7 +826,7 @@ def get_picks_mean(picks):
def get_corrected_taupy_picks(picks, taupypicks, all_available=False):
''' get mean/median from picks taupy picks, correct latter for the difference '''
""" get mean/median from picks taupy picks, correct latter for the difference """
def nwst_id_from_wfid(wfid):
return '{}.{}'.format(wfid.network_code if wfid.network_code else '',
@ -824,8 +847,9 @@ def get_corrected_taupy_picks(picks, taupypicks, all_available=False):
mean_diff = taupy_mean - picks_mean
logging.info(f'Calculated {len(taupypicks_new)} TauPy theoretical picks.')
logging.info('Calculated median difference from TauPy theoretical picks of {} seconds. (mean: {})'.format(median_diff,
mean_diff))
logging.info(
'Calculated median difference from TauPy theoretical picks of {} seconds. (mean: {})'.format(median_diff,
mean_diff))
# return all available taupypicks corrected for median difference to autopicks
if all_available:
@ -851,7 +875,8 @@ def load_stacked_trace(eventdir, min_corr_stack):
if not os.path.isfile(corr_dict_fpath):
logging.warning('No correlations_for_stacking dict found for event {}!'.format(eventdir))
return
corr_dict = json.load(corr_dict_fpath) # TODO: Check this line
with open(corr_dict_fpath) as fid:
corr_dict = json.load(fid)
# get stations for stacking and nstack
stations4stack = get_stations_min_corr(corr_dict, min_corr_stack)
@ -866,10 +891,10 @@ def load_stacked_trace(eventdir, min_corr_stack):
return corr_dict, nwst_id, stacked_trace, nstack
def repick_mastertrace(wfdata, trace_master, pylot_parameter, event, event_id, metadata, phase_type, corr_out_dir):
rename_LQT = phase_type == 'S'
def repick_master_trace(wfdata, trace_master, pylot_parameter, event, event_id, metadata, phase_type, corr_out_dir):
rename_lqt = phase_type == 'S'
# create an 'artificial' stream object which can be used as input for PyLoT
stream_master = modify_master_trace4pick(trace_master.copy(), wfdata, rename_LQT=rename_LQT)
stream_master = modify_master_trace4pick(trace_master.copy(), wfdata, rename_lqt=rename_lqt)
stream_master.write(os.path.join(corr_out_dir, 'stacked_master_trace_pylot.mseed'.format(trace_master.id)))
try:
picksdict = autopickstation(stream_master, pylot_parameter, verbose=True, metadata=metadata,
@ -894,10 +919,10 @@ def repick_mastertrace(wfdata, trace_master, pylot_parameter, event, event_id, m
return pick_stacked
def create_correlation_output_dir(eventdir: str, fopts: dict, phase_type: str) -> str:
    """
    Return the correlation output directory of an event, creating it if it does not exist yet.

    The directory name starts with 'correlation' and is extended by ``add_fpath_extension`` using the filter
    options and the phase type; it is placed directly inside ``eventdir``.

    Args:
        eventdir (str): Path of the event directory that will contain the output folder.
        fopts (dict): Filter options, passed on to ``add_fpath_extension``.
        phase_type (str): Phase type, passed on to ``add_fpath_extension``.

    Returns:
        str: Path of the (existing or newly created) output directory.
    """
    folder = add_fpath_extension('correlation', fopts, phase_type)
    export_path = str(os.path.join(eventdir, folder))
    # os.mkdir creates one level only: eventdir itself has to exist already
    if not os.path.isdir(export_path):
        os.mkdir(export_path)
    return export_path
@ -927,12 +952,12 @@ def write_correlation_output(export_path, correlations_dict, correlations_dict_s
write_json(correlations_dict_stacked, os.path.join(export_path, 'correlation_results.json'))
def modify_master_trace4pick(trace_master, wfdata, rename_LQT=True):
'''
def modify_master_trace4pick(trace_master, wfdata, rename_lqt=True):
"""
Create an artificial Stream with master trace instead of the trace of its corresponding channel.
This is done to find metadata for correct station metadata which were modified when stacking (e.g. loc=ST)
USE THIS ONLY FOR PICKING!
'''
"""
# fake ZNE coordinates for autopylot
lqt_zne = {'L': 'Z', 'Q': 'N', 'T': 'E'}
@ -949,7 +974,7 @@ def modify_master_trace4pick(trace_master, wfdata, rename_LQT=True):
# take location from old trace and overwrite
trace_master.stats.location = trace.stats.location
trace = trace_master
if rename_LQT:
if rename_lqt:
channel = trace.stats.channel
component_new = lqt_zne.get(channel[-1])
if not component_new:
@ -1088,7 +1113,7 @@ def get_pickfile_name_correlated(event_id, fopts, phase_type):
# return trace_this
def prepare_correlation_input(wfdata, picks, channels, trace_master, pick_master, params, plot=None, fig_dir=None,
def prepare_correlation_input(wfdata, picks, channels, trace_master, pick_master, phase_params, plot=None, fig_dir=None,
ncorr=1, wfdata_highf=None, trace_master_highf=None):
# prepare input for multiprocessing worker for all 'other' picks to correlate with current master-trace
input_list = []
@ -1096,7 +1121,7 @@ def prepare_correlation_input(wfdata, picks, channels, trace_master, pick_master
assert (ncorr in [1, 2]), 'ncorr has to be 1 or 2'
for pick_other in picks:
trace_other_highf = None
stream_other = stream_other_high_f = trace_other_high_f = channel = None
network_other = pick_other.waveform_id.network_code
station_other = pick_other.waveform_id.station_code
nwst_id_other = '{nw}.{st}'.format(nw=network_other, st=station_other)
@ -1106,7 +1131,7 @@ def prepare_correlation_input(wfdata, picks, channels, trace_master, pick_master
for channel in channels:
stream_other = wfdata.select(network=network_other, station=station_other, channel=channel)
if ncorr == 2:
stream_other_highf = wfdata_highf.select(network=network_other, station=station_other, channel=channel)
stream_other_high_f = wfdata_highf.select(network=network_other, station=station_other, channel=channel)
if stream_other:
break
if not stream_other:
@ -1114,37 +1139,39 @@ def prepare_correlation_input(wfdata, picks, channels, trace_master, pick_master
continue
trace_other = stream_other[0]
if ncorr == 2:
trace_other_highf = stream_other_highf[0]
trace_other_high_f = stream_other_high_f[0]
if trace_other == stream_other:
continue
input_dict = {'nwst_id': nwst_id_other, 'trace1': trace_master, 'pick1': pick_master, 'trace2': trace_other,
'pick2': pick_other, 'channel': channel, 't_before': params['t_before'],
't_after': params['t_after'], 'cc_maxlag': params['cc_maxlag'],
'cc_maxlag2': params['cc_maxlag2'], 'plot': plot, 'fig_dir': fig_dir, 'ncorr': ncorr,
'trace1_highf': trace_master_highf, 'trace2_highf': trace_other_highf,
'min_corr': params['min_corr_export']}
'pick2': pick_other, 'channel': channel, 't_before': phase_params['t_before'],
't_after': phase_params['t_after'], 'cc_maxlag': phase_params['cc_maxlag'],
'cc_maxlag2': phase_params['cc_maxlag2'], 'plot': plot, 'fig_dir': fig_dir, 'ncorr': ncorr,
'trace1_highf': trace_master_highf, 'trace2_highf': trace_other_high_f,
'min_corr': phase_params['min_corr_export']}
input_list.append(input_dict)
return input_list
def stack_mastertrace(wfdata_lowf, wfdata_highf, wfdata_raw, picks, params, channels, method, fig_dir):
'''
def stack_mastertrace(wfdata_lowf: Stream, wfdata_highf: Stream, wfdata_raw: Stream, picks: list, params: dict,
channels: list, method: str, fig_dir: str):
"""
Correlate all stations with the first available station given in station_list, a list containing central,
permanent, long operating, low noise stations with descending priority.
A master trace will be created by stacking well correlating traces onto this station.
'''
"""
def get_best_station4stack(station_results):
''' return station with maximum mean_ccc'''
ccc_means = {nwst_id: value['mean_ccc'] for nwst_id, value in station_results.items() if
def get_best_station4stack(sta_result):
""" return station with maximum mean_ccc"""
ccc_means = {nwst_id: value['mean_ccc'] for nwst_id, value in sta_result.items() if
not np.isnan(value['mean_ccc'])}
if len(ccc_means) < 1:
logging.warning('No valid station found for stacking! Return.')
return
best_station_id = max(ccc_means, key=ccc_means.get)
logging.info('Found highest mean correlation for station {} ({})'.format(best_station_id, max(ccc_means.values())))
logging.info(
'Found highest mean correlation for station {} ({})'.format(best_station_id, max(ccc_means.values())))
return best_station_id
station_results = iterate_correlation(wfdata_lowf, wfdata_highf, channels, picks, method, params, fig_dir=fig_dir)
@ -1153,7 +1180,7 @@ def stack_mastertrace(wfdata_lowf, wfdata_highf, wfdata_raw, picks, params, chan
# in case no stream with a valid pick is found
if not nwst_id_master:
logging.info('No mastertrace found! Will skip this event.')
return False
return None
trace_master = station_results[nwst_id_master]['trace']
stations4stack = station_results[nwst_id_master]['stations4stack']
@ -1163,22 +1190,27 @@ def stack_mastertrace(wfdata_lowf, wfdata_highf, wfdata_raw, picks, params, chan
dt_pre, dt_post = params['dt_stacking']
trace_master, nstack = apply_stacking(trace_master, stations4stack, wfdata_raw, picks, method=method,
check_RMS=params['check_RMS'], plot=params['plot'], fig_dir=fig_dir,
check_rms=params['check_RMS'], plot=params['plot'], fig_dir=fig_dir,
dt_pre=dt_pre, dt_post=dt_post)
return correlations_dict, nwst_id_master, trace_master, nstack
def iterate_correlation(wfdata_lowf, wfdata_highf, channels, picks, method, params, fig_dir=None):
'''iterate over possible stations for master-trace store them and return a dictionary'''
def iterate_correlation(wfdata_lowf: Stream, wfdata_highf: Stream, channels: list, picks: list, method: str,
params: dict, fig_dir: str = None) -> dict:
"""iterate over possible stations for master-trace store them and return a dictionary"""
station_results = {nwst_id: {'mean_ccc': np.nan, 'correlations_dict': None, 'stations4stack': None, 'trace': None}
station_results = {nwst_id: {'mean_ccc': np.nan, 'correlations_dict': dict(), 'stations4stack': dict(),
'trace': None}
for nwst_id in params['station_list']}
for nwst_id_master in params['station_list']:
logging.info(20 * '#')
logging.info('Starting correlation for station: {}'.format(nwst_id_master))
nw, st = nwst_id_master.split('.')
# get master-trace
stream_master_lowf = stream_master_highf = None
for channel in channels:
stream_master_lowf = wfdata_lowf.select(network=nw, station=st, channel=channel)
stream_master_highf = wfdata_highf.select(network=nw, station=st, channel=channel)
@ -1211,8 +1243,8 @@ def iterate_correlation(wfdata_lowf, wfdata_highf, channels, picks, method, para
input_list = prepare_correlation_input(wfdata_lowf, picks, channels, trace_master_lowf, pick_master,
trace_master_highf=trace_master_highf, wfdata_highf=wfdata_highf,
params=params, plot=params['plot_detailed'], fig_dir=fig_dir_traces,
ncorr=2)
phase_params=params, plot=params['plot_detailed'],
fig_dir=fig_dir_traces, ncorr=2)
if params['plot_detailed'] and not fig_dir_traces:
# serial
@ -1231,9 +1263,11 @@ def iterate_correlation(wfdata_lowf, wfdata_highf, channels, picks, method, para
return station_results
def apply_stacking(trace_master, stations4stack, wfdata, picks, method, check_RMS, dt_pre=250., dt_post=250.,
plot=False, fig_dir=None):
def check_trace_length_and_cut(trace, correlated_midtime, dt_pre, dt_post):
def apply_stacking(trace_master: Trace, stations4stack: dict, wfdata: Stream, picks: list, method: str,
check_rms: bool, dt_pre: float = 250., dt_post: float = 250., plot: bool = False,
fig_dir: str = None) -> tuple:
def check_trace_length_and_cut(trace: Trace):
starttime = correlated_midtime - dt_pre
endtime = correlated_midtime + dt_post
@ -1248,7 +1282,7 @@ def apply_stacking(trace_master, stations4stack, wfdata, picks, method, check_RM
trace_master = trace_master.copy()
trace_master.stats.location = 'ST'
check_trace_length_and_cut(trace_master, pick.time, dt_pre=dt_pre, dt_post=dt_post)
check_trace_length_and_cut(trace_master)
# empty trace so that it does not stack twice
trace_master.data = np.zeros(len(trace_master.data))
@ -1262,15 +1296,16 @@ def apply_stacking(trace_master, stations4stack, wfdata, picks, method, check_RM
stream_other = wfdata.select(network=nw, station=st, channel=channel)
correlated_midtime = pick_other.time - dpick
trace_other = stream_other[0].copy()
check_trace_length_and_cut(trace_other, correlated_midtime, dt_pre=dt_pre, dt_post=dt_post)
check_trace_length_and_cut(trace_other)
if not len(trace_other) == len(trace_master):
logging.warning('Can not stack trace on master trace because of different lengths: {}. Continue'.format(nwst_id))
logging.warning(
'Can not stack trace on master trace because of different lengths: {}. Continue'.format(nwst_id))
continue
traces4stack.append(trace_other)
if check_RMS:
if check_rms:
traces4stack = checkRMS(traces4stack, plot=plot, fig_dir=fig_dir)
if plot:
@ -1307,12 +1342,14 @@ def checkRMS(traces4stack, plot=False, fig_dir=None, max_it=10, ntimes_std=5.):
traces4stack = sorted(traces4stack, key=lambda x: x.id)
for trace in traces4stack:
rms_list.append(RMS(trace.data))
rms_list.append(calc_rms(trace.data))
trace_names.append(trace.id)
# iterative elimination of RMS outliers
iterate = True
count = 0
std = 0
mean = 0
while iterate:
count += 1
if count >= max_it:
@ -1341,7 +1378,8 @@ def checkRMS(traces4stack, plot=False, fig_dir=None, max_it=10, ntimes_std=5.):
return traces4stack
def plot_rms(rms_list, trace_names, mean, std, fig_dir, count, ntimes_std):
def plot_rms(rms_list: list, trace_names: list, mean: float, std: float, fig_dir: str, count: int,
ntimes_std: float) -> None:
fig = plt.figure(figsize=(16, 9))
ax = fig.add_subplot(111)
ax.semilogy(rms_list, 'b.')
@ -1359,7 +1397,7 @@ def plot_rms(rms_list, trace_names, mean, std, fig_dir, count, ntimes_std):
plt.close(fig)
def RMS(X):
def calc_rms(X):
"""
Returns root mean square of a given array LON
"""
@ -1381,7 +1419,7 @@ def resample_parallel(stream, freq, ncores):
def rotate_stream(stream, metadata, origin, stations_dict, channels, inclination, ncores):
''' Rotate stream to LQT. To do this all traces have to be cut to equal lenths'''
""" Rotate stream to LQT. To do this all traces have to be cut to equal lenths"""
input_list = []
cut_stream_to_same_length(stream)
new_stream = Stream()
@ -1424,7 +1462,7 @@ def rotate_stream(stream, metadata, origin, stations_dict, channels, inclination
# return stream
def correlate_parallel(input_list, ncores):
def correlate_parallel(input_list: list, ncores: int) -> dict:
parallel = Parallel(n_jobs=ncores)
logging.info('Correlate_parallel: Generated {} parallel jobs.'.format(ncores))
@ -1435,7 +1473,7 @@ def correlate_parallel(input_list, ncores):
return unpack_result(correlation_result)
def correlate_serial(input_list, plot=False):
def correlate_serial(input_list: list, plot: bool = False) -> dict:
correlation_result = []
for input_dict in input_list:
input_dict['plot'] = plot
@ -1443,7 +1481,7 @@ def correlate_serial(input_list, plot=False):
return unpack_result(correlation_result)
def taupy_parallel(input_list, ncores):
def taupy_parallel(input_list: list, ncores: int) -> dict:
parallel = Parallel(n_jobs=ncores)
logging.info('Taupy_parallel: Generated {} parallel jobs.'.format(ncores))
@ -1454,19 +1492,19 @@ def taupy_parallel(input_list, ncores):
return unpack_result(taupy_results)
def unpack_result(result: list) -> dict:
    """
    Re-index a list of worker output dictionaries by their 'nwst_id' and report the number of failed jobs.

    Args:
        result (list): Output dictionaries of the worker functions. Each one must contain the keys
            'nwst_id' and 'err'; a truthy 'err' marks a failed job.

    Returns:
        dict: Maps each 'nwst_id' to a shallow copy of the corresponding output dictionary. If an id occurs
            more than once, the last entry wins.
    """
    # copy every item so that the returned mapping does not alias the worker output
    result_dict = {item['nwst_id']: dict(item) for item in result}
    nerr = 0
    for item in result_dict.values():
        if item['err']:
            logging.debug(f'Found error for {item["nwst_id"]}: {item["err"]}.')
            # logging.debug(f'Detailed traceback: {item["exc"]}')
            nerr += 1
    logging.info('Unpack results: Found {} errors after multiprocessing'.format(nerr))
    return result_dict
def get_taupy_picks(origin, stations_dict, pylot_parameter, phase_type, ncores=None):
def get_taupy_picks(origin: Origin, stations_dict, pylot_parameter, phase_type, ncores=None):
input_list = []
taup_phases = pylot_parameter['taup_phases'].split(',')
@ -1526,10 +1564,10 @@ def taupy_worker(input_dict):
first_arrival = arrivals[0]
output_dict = dict(nwst_id=input_dict['nwst_id'], phase_name=first_arrival.name,
phase_time=source_time + first_arrival.time, phase_dist=first_arrival.distance, err=None,
exc=None,)
exc=None, )
except Exception as e:
exc = traceback.format_exc()
output_dict = dict(nwst_id=input_dict['nwst_id'], phase_name=None, phase_time=None, err=str(e), exc=exc,)
output_dict = dict(nwst_id=input_dict['nwst_id'], phase_name=None, phase_time=None, err=str(e), exc=exc, )
return output_dict
@ -1559,7 +1597,7 @@ def rotation_worker(input_dict):
def correlation_worker(input_dict):
'''worker function for multiprocessing'''
"""worker function for multiprocessing"""
# unpack input dictionary
nwst_id = input_dict['nwst_id']
@ -1613,7 +1651,7 @@ def correlation_worker(input_dict):
'channel': channel, }
def get_pick4station(picks, network_code, station_code, method='auto'):
def get_pick4station(picks: list, network_code: str, station_code: str, method: str = 'auto') -> Pick:
for pick in picks:
if pick.waveform_id.network_code == network_code:
if pick.waveform_id.station_code == station_code:
@ -1621,13 +1659,14 @@ def get_pick4station(picks, network_code, station_code, method='auto'):
return pick
def get_stations_min_corr(corr_results: dict, min_corr: float) -> dict:
    """
    Select all stations whose cross-correlation coefficient exceeds a threshold.

    Args:
        corr_results (dict): Maps station ids to result dictionaries that contain the key 'ccc'.
        min_corr (float): Threshold; only entries with 'ccc' strictly greater than it are kept
            (NaN values never pass the comparison).

    Returns:
        dict: New dictionary with the selected entries; the input dictionary is not modified.
    """
    return {nwst_id: result for nwst_id, result in corr_results.items() if result['ccc'] > min_corr}
def plot_mastertrace_corr(nwst_id, corr_results, wfdata, stations_dict, picks, trace_master, method, min_corr, title,
fig_dir=None):
def plot_mastertrace_corr(nwst_id: str, corr_results: dict, wfdata: Stream, stations_dict: dict, picks: list,
trace_master: Trace, method: str, min_corr: float, title: str,
fig_dir: str = None) -> None:
nw, st = nwst_id.split('.')
coords_master = stations_dict[nwst_id]
pick_master = get_pick4station(picks, nw, st, method)
@ -1650,7 +1689,7 @@ def plot_mastertrace_corr(nwst_id, corr_results, wfdata, stations_dict, picks, t
plt.show()
def make_figure_dirs(fig_dir, trace_master_id):
def make_figure_dirs(fig_dir: str, trace_master_id: str) -> str:
fig_dir_traces = ''
if fig_dir and os.path.isdir(fig_dir):
fig_dir_traces = os.path.join(fig_dir, 'corr_with_{}'.format(trace_master_id))
@ -1659,15 +1698,18 @@ def make_figure_dirs(fig_dir, trace_master_id):
return fig_dir_traces
def plot_section(wfdata, trace_this, pick_this, picks, stations2compare, channels, corr_results, method, dt_pre=20.,
dt_post=50, axes=None, max_stations=50.):
'''Plot a section with all stations used for correlation on ax'''
def plot_section(wfdata: Stream, trace_this: Trace, pick_this: Pick, picks: list, stations2compare: dict,
channels: list, corr_results: dict, method: str, dt_pre: float = 20., dt_post: float = 50.,
axes: dict = None, max_stations: int = 50) -> None:
"""Plot a section with all stations used for correlation on ax"""
ylabels = []
yticks = []
trace_this = trace_this.copy()
trace_this.trim(starttime=pick_this.time - dt_pre, endtime=pick_this.time + dt_post)
ax_sec = None
# iterate over all closest stations ("other" stations)
for index, nwst_id in enumerate(stations2compare.keys()):
if index >= max_stations:
@ -1700,6 +1742,8 @@ def plot_section(wfdata, trace_this, pick_this, picks, stations2compare, channel
if np.isnan(dpick) or not pick_other or not pick_other.time:
continue
stream = None
# continue if there are no data for station for whatever reason
for channel in channels:
stream = wfdata.select(station=st, network=nw, channel=channel)
@ -1718,13 +1762,14 @@ def plot_section(wfdata, trace_this, pick_this, picks, stations2compare, channel
index + 0.5, color='r', lw=0.5)
# Plot description
ax_sec.set_yticks(yticks)
ax_sec.set_yticklabels(ylabels)
ax_sec.set_title('Section with corresponding picks.')
ax_sec.set_xlabel('Samples. Traces are shifted in time.')
if ax_sec:
ax_sec.set_yticks(yticks)
ax_sec.set_yticklabels(ylabels)
ax_sec.set_title('Section with corresponding picks.')
ax_sec.set_xlabel('Samples. Traces are shifted in time.')
def plot_stacked_trace_pick(trace, pick, pylot_parameter):
def plot_stacked_trace_pick(trace: Trace, pick: Pick, pylot_parameter: PylotParameter) -> plt.Figure:
trace_filt = trace.copy()
if pylot_parameter['filter_type'] and pylot_parameter['filter_options']:
ftype = pylot_parameter['filter_type'][0]
@ -1746,8 +1791,9 @@ def plot_stacked_trace_pick(trace, pick, pylot_parameter):
return fig
def plot_stations(stations_dict, stations2compare, coords_this, corr_result, trace, pick, window_title=None):
''' Plot routine to check proximity algorithm. '''
def plot_stations(stations_dict: dict, stations2compare: dict, coords_this: dict, corr_result: dict, trace: Trace,
pick: Pick, window_title: str = None):
""" Plot routine to check proximity algorithm. """
title = trace.id
@ -1812,15 +1858,15 @@ def plot_stations(stations_dict, stations2compare, coords_this, corr_result, tra
return fig, axes
def get_data(eventdir: str, data_dir: str, headonly: bool = False) -> Stream:
    """
    Read all waveform files found in eventdir/data_dir.

    Args:
        eventdir: Path of the event directory.
        data_dir: Name of the sub-directory of eventdir that holds the waveform files.
        headonly: Passed on to obspy's read function.

    Returns:
        Stream: The waveform data returned by read for the pattern eventdir/data_dir/*.
    """
    # the wildcard makes read pick up every file inside the data directory
    wildcard_path = os.path.join(eventdir, data_dir, '*')
    return read(wildcard_path, headonly=headonly)
def get_closest_stations(coords, stations_dict, n):
''' Calculate distances and return n closest stations in stations_dict for station at coords. '''
""" Calculate distances and return n closest stations in stations_dict for station at coords. """
distances = {}
for station_id, st_coords in stations_dict.items():
dist = gps2dist_azimuth(coords['latitude'], coords['longitude'], st_coords['latitude'], st_coords['longitude'],
@ -1836,7 +1882,7 @@ def get_closest_stations(coords, stations_dict, n):
def get_random_stations(coords, stations_dict, n):
''' Calculate distances and return n randomly selected stations in stations_dict for station at coords. '''
""" Calculate distances and return n randomly selected stations in stations_dict for station at coords. """
random_keys = random.sample(list(stations_dict.keys()), n)
distances = {}
for station_id, st_coords in stations_dict.items():
@ -1851,11 +1897,41 @@ def get_random_stations(coords, stations_dict, n):
return random_stations
# Notes: if use_master_trace is set True, traces above a specified ccc threshold will be stacked onto one 'ideal'
# station into a master-trace. This trace will be picked and used again for cross correlation with all other stations
# to find a pick for these stations.
def prepare_corr_params(parse_args, logger):
    """
    Read the YAML parameter file, set the logging level and build the correlation parameters for each phase.

    Args:
        parse_args: Parsed command line arguments. The attributes params (path of the YAML parameter file), ncores,
            plot, plot_detailed and show_fig are read.
        logger: Logger whose level is set from the 'logging' entry of the parameter file.

    Returns:
        dict: One CorrelationParameters instance per phase listed under 'pick_phases', keyed by phase name.

    Raises:
        KeyError: If 'pick_phases' or the section of one of the listed phases is missing in the parameter file.
    """
    with open(parse_args.params) as infile:
        parameters_dict = yaml.safe_load(infile)

    # logging: an unknown or missing level name would make DEBUG_LEVELS.get() return None and logger.setLevel(None)
    # raise a TypeError, so the lookup is case-insensitive and falls back to INFO with a warning
    level_name = parameters_dict.get('logging')
    level = DEBUG_LEVELS.get(str(level_name).lower())
    logger.setLevel(logging.INFO if level is None else level)
    if level is None:
        logger.warning('Unknown logging level %r in %s, using "info". Valid levels: %s',
                       level_name, parse_args.params, ', '.join(DEBUG_LEVELS))

    # number of cores (None is handed on unchanged to CorrelationParameters)
    ncores = int(parse_args.ncores) if parse_args.ncores is not None else None

    # plot options: figures are saved unless they are requested to be shown
    plot_dict = dict(plot=parse_args.plot, plot_detailed=parse_args.plot_detailed,
                     save_fig=not parse_args.show_fig)

    corr_params = {}
    for phase in parameters_dict['pick_phases']:
        params_phase = CorrelationParameters(ncores=ncores)
        params_phase.add_parameters(**plot_dict)
        params_phase.add_parameters(**parameters_dict[phase])
        corr_params[phase] = params_phase

    return corr_params
def init_logger():
    """
    Attach a stdout handler and a file handler to the root logger.

    The file handler writes to 'pick_corr.out' in the current working directory.

    Returns:
        logging.Logger: The root logger with both handlers added.
    """
    root_logger = logging.getLogger()
    new_handlers = (logging.StreamHandler(sys.stdout), logging.FileHandler('pick_corr.out'))
    for log_handler in new_handlers:
        root_logger.addHandler(log_handler)
    return root_logger
def setup_parser():
parser = argparse.ArgumentParser(description='Correlate picks from PyLoT.')
parser.add_argument('dmt_path', help='path containing dmt_database with PyLoT picks')
parser.add_argument('pylot_infile', help='path to autoPyLoT inputfile (pylot.in)')
@ -1881,44 +1957,22 @@ if __name__ == "__main__":
parser.add_argument('-istart', default=0, help='first event index')
parser.add_argument('-istop', default=1e9, help='last event index')
args = parser.parse_args()
return parser.parse_args()
# PARAMETER DEFINITIONS ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
# plot options
plot_dict = dict(plot=args.plot, plot_detailed=args.plot_detailed, save_fig=not (args.show_fig))
if __name__ == "__main__":
    # parse command line arguments
    ARGS = setup_parser()

    # initialize logging
    LOGGER = init_logger()

    # alparray configuration: ZNE ~ 123
    CHANNELS = {'P': ['*Z', '*1'], 'S': ['*Q', '*T']}  # '*N', '*E', '*2', '*3',

    # initialize parameters from yaml, set logging level
    CORR_PARAMS = prepare_corr_params(ARGS, LOGGER)

    # MAIN +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    # the event indices are converted to int here before being handed to correlation_main
    correlation_main(
        ARGS.dmt_path,
        ARGS.pylot_infile,
        params=CORR_PARAMS,
        istart=int(ARGS.istart),
        istop=int(ARGS.istop),
        channel_config=CHANNELS,
        update=ARGS.update,
        event_blacklist=ARGS.blacklist,
    )