From a068bb84577c10d9af592f16f5b3e61448cf6f10 Mon Sep 17 00:00:00 2001
From: Marcel
Date: Thu, 8 Aug 2024 16:49:15 +0200
Subject: [PATCH] [update] refactoring, added type hints

---
 .../pick_correlation_correction.py           | 384 ++++++++++--------
 1 file changed, 219 insertions(+), 165 deletions(-)

diff --git a/pylot/correlation/pick_correlation_correction.py b/pylot/correlation/pick_correlation_correction.py
index 8719ad75..737e2853 100644
--- a/pylot/correlation/pick_correlation_correction.py
+++ b/pylot/correlation/pick_correlation_correction.py
@@ -21,7 +21,7 @@ from joblib import Parallel, delayed
 from obspy import read, Stream, UTCDateTime, Trace
 from obspy.taup import TauPyModel
 from obspy.core.event.base import WaveformStreamID, ResourceIdentifier
-from obspy.core.event.origin import Pick
+from obspy.core.event.origin import Pick, Origin
 from obspy.geodetics.base import gps2dist_azimuth
 from obspy.signal.cross_correlation import correlate, xcorr_max
 
@@ -34,6 +34,12 @@ from pylot.core.util.event import Event as PylotEvent
 from pylot.correlation.utils import (get_event_id, get_event_pylot, get_event_obspy_dmt, get_picks, write_json,
                                      get_metadata)
 
+DEBUG_LEVELS = {'debug': logging.DEBUG,
+                'info': logging.INFO,
+                'warn': logging.WARNING,
+                'error': logging.ERROR,
+                'critical': logging.CRITICAL}
+
 
 class CorrelationParameters:
     """
@@ -210,7 +216,7 @@ class XCorrPickCorrection:
         Returns:
             array: A numpy array representing the time axis.
         """
-        return np.linspace(0,trace.stats.endtime -trace.stats.starttime, trace.stats.npts)
+        return np.linspace(0, trace.stats.endtime - trace.stats.starttime, trace.stats.npts)
 
     def plot_cc(figure_output_dir: str, plot_filename: str) -> None:
         """
@@ -316,7 +322,7 @@ class XCorrPickCorrection:
         peak_index = cc.argmax()
         first_sample = peak_index
         # XXX this could be improved..
-        while first_sample > 0 and cc_curvature[first_sample - 1] <= 0:
+        while first_sample > 0 >= cc_curvature[first_sample - 1]:
             first_sample -= 1
         last_sample = peak_index
         while last_sample < len(cc) - 1 and cc_curvature[last_sample + 1] <= 0:
@@ -337,7 +343,7 @@ class XCorrPickCorrection:
 
         # quadratic fit for small subwindow
         coeffs, cov_mat = np.polyfit(cc_t[first_sample:last_sample + 1],
                                      cc[first_sample:last_sample + 1], deg=2,
-                                         full=False, cov=True)[:2]
+                                     full=False, cov=True)[:2]
 
         a, b, c = coeffs
 
@@ -380,9 +386,10 @@ class XCorrPickCorrection:
     return pick2_corr, coeff, uncert, fwfm
 
 
-def correlation_main(database_path_dmt, pylot_infile_path, params, channel_config, istart=0, istop=1e9, update=False,
-                     event_blacklist=None):
-    '''
+def correlation_main(database_path_dmt: str, pylot_infile_path: str, params: dict, channel_config: dict,
+                     istart: int = 0, istop: int = int(1e9), update: bool = False,
+                     event_blacklist: str = None, select_events: list = None) -> None:
+    """
     Main function of this program, correlates waveforms around theoretical, or other initial (e.g. automatic cf)
     picks on one or more reference stations.
-    All stations with a correlation higher "min_corr_stacking" are stacked onto the station with the highest mean
+    All stations with a correlation higher than "min_corr_stacking" are stacked onto the station with the highest mean
     correlation. The stacked waveform will be re-picked using PyLoT.
     Finally, all other stations will be correlated with the re-picked, stacked seismogram around their theoretical
     (or initial) onsets and the shifted pick time of the stacked seismogram will be assigned as their new pick time.
-    :param database_path_dmt: obspy_dmt databse directory
-    :param pylot_infile_path: path containing input file for autoPyLoT (automatic picking parameters)
-    :param params: dictionary with parameter class objects
+    Args:
+        database_path_dmt (str): The path to the obspy_dmt database directory.
+        pylot_infile_path (str): The path to the autoPyLoT input file (automatic picking parameters).
+        params (dict): One CorrelationParameters object per phase.
+        channel_config (dict): Channel names to use per phase type.
+        istart (int, optional): The starting index for events. Defaults to 0.
+        istop (int, optional): The stopping index for events. Defaults to 1e9.
+        update (bool, optional): Whether to update existing picks. Defaults to False.
+        event_blacklist (str, optional): Path to the event blacklist file. Defaults to None.
+        select_events (list, optional): If given, process only the event directories listed. Defaults to None.
+
+    Returns:
+        None
+    """
-
-        :return:
-    '''
 
     assert os.path.isdir(database_path_dmt), 'Unrecognized directory {}'.format(database_path_dmt)
 
     tstart = datetime.now()
@@ -405,39 +421,38 @@ def correlation_main(database_path_dmt, pylot_infile_path, params, channel_confi
     logging.info(50 * '#')
 
     for phase_type in params.keys():
-        if params[phase_type]['plot_detailed']: params[phase_type]['plot'] = True
+        if params[phase_type]['plot_detailed']:
+            params[phase_type]['plot'] = True
 
     eventdirs = glob.glob(os.path.join(database_path_dmt, '*.?'))
 
     pylot_parameter = PylotParameter(pylot_infile_path)
 
     if event_blacklist:
-        with open(event_blacklist, 'r') as infile:
-            event_blacklist = [line.split('\n')[0] for line in infile.readlines()]
+        with open(event_blacklist, 'r') as fid:
+            event_blacklist = [line.split('\n')[0] for line in fid.readlines()]
 
     # iterate over all events in "database_path_dmt"
     for eventindex, eventdir in enumerate(eventdirs):
         if not istart <= eventindex < istop:
             continue
 
-        # MP MP testing +++
-        # ids_filter = ['20181229_033914.a']
-        # if not os.path.split(eventdir)[-1] in ids_filter:
-        #     continue
-        # MP MP testing ---
+        if select_events and os.path.split(eventdir)[-1] not in select_events:
+            continue
 
         logging.info('\n' + 100 * '#')
         logging.info('Working on event {} ({}/{})'.format(eventdir, eventindex + 1, len(eventdirs)))
 
         if event_blacklist and get_event_id(eventdir) in event_blacklist:
             logging.info('Event on blacklist. Continue')
+            continue
 
-        correlate_event(eventdir, pylot_parameter, params=params, channel_config=channel_config, update=update)
+        correlate_event(eventdir, pylot_parameter, params=params, channel_config=channel_config,
+                        update=update)
 
     logging.info('Finished script after {} at {}'.format(datetime.now() - tstart, datetime.now()))
 
 
-def get_estimated_inclination(stations_dict, origin, phases, model='ak135'):
-    ''' calculate a mean inclination angle for all stations for seismometer rotation to spare computation time'''
+def get_estimated_inclination(stations_dict: dict, origin: Origin, phases: str, model: str = 'ak135') -> float:
+    """ calculate a mean inclination angle for all stations for seismometer rotation to save computation time"""
     model = TauPyModel(model)
     phases = [*phases.split(',')]
     avg_lat = np.median([item['latitude'] for item in stations_dict.values()])
@@ -448,9 +463,9 @@ def get_estimated_inclination(stations_dict, origin, phases, model='ak135'):
     return arrivals[0].incident_angle
 
 
-def modify_horizontal_name(wfdata):
-    ''' This function only renames (does not rotate!) numeric channel IDs to be able to rotate them to LQT. It is
-    therefore not accurate and only used for picking with correlation. Returns a copy of the original stream. '''
+def modify_horizontal_name(wfdata: Stream) -> Stream:
+    """ This function only renames (does not rotate!) numeric channel IDs to be able to rotate them to LQT. It is
+    therefore not accurate and only used for picking with correlation. The stream is modified in place and returned. """
     # wfdata = wfdata.copy()
     stations = np.unique([tr.stats.station for tr in wfdata])
     for station in stations:
@@ -460,13 +475,14 @@ def modify_horizontal_name(wfdata):
         st = wfdata.select(station=station, location=location)
         channels = [tr.stats.channel for tr in st]
         check_numeric = [channel[-1].isnumeric() for channel in channels]
-        check_nonnumeric = [not (cn) for cn in check_numeric]
+        check_nonnumeric = [not cn for cn in check_numeric]
         if not any(check_numeric):
             continue
         numeric_channels = np.array(channels)[check_numeric].tolist()
         nonnumeric_channels = np.array(channels)[check_nonnumeric].tolist()
-        if not 'Z' in [nc[-1] for nc in nonnumeric_channels]:
-            logging.warning('Modify_horizontal_name failed: Only implemented for existing Z component! Return original data.')
+        if 'Z' not in [nc[-1] for nc in nonnumeric_channels]:
+            logging.warning(
+                'Modify_horizontal_name failed: Only implemented for existing Z component! Return original data.')
             return wfdata
         numeric_characters = sorted([nc[-1] for nc in numeric_channels])
         if numeric_characters == ['1', '2']:
@@ -474,7 +490,8 @@ def modify_horizontal_name(wfdata):
             channel_dict = {'1': 'N', '2': 'E'}
         elif numeric_characters == ['2', '3']:
             channel_dict = {'2': 'N', '3': 'E'}
         else:
-            logging.warning('Modify_horizontal_name failed: Channel order not implemented/unknown. Return original data.')
+            logging.warning(
+                'Modify_horizontal_name failed: Channel order not implemented/unknown. Return original data.')
             return wfdata
         for tr in st:
             channel = tr.stats.channel
@@ -486,10 +503,10 @@ def modify_horizontal_name(wfdata):
 
     return wfdata
 
 
-def cut_stream_to_same_length(wfdata):
-    def remove(wfdata, st):
-        for tr in st:
-            wfdata.remove(tr)
+def cut_stream_to_same_length(wfdata: Stream) -> None:
+    def remove(data, stream):
+        for tr in stream:
+            data.remove(tr)
 
     stations = np.unique([tr.stats.station for tr in wfdata])
     for station in stations:
@@ -515,7 +532,7 @@ def cut_stream_to_same_length(wfdata):
 
         st.trim(tstart, tend, pad=True, fill_value=0.)
 
 
-def get_unique_phase(picks, rename_phases=None):
+def get_unique_phase(picks: list, rename_phases: dict = None) -> str:
     phases = [pick.phase_hint for pick in picks]
     if rename_phases:
         phases = [rename_phases[phase] if phase in rename_phases else phase for phase in phases]
@@ -523,10 +540,10 @@ def get_unique_phase(picks, rename_phases=None):
     if len(set(phases)) == 1:
         return phases[0]
 
-    return False
-
-def correlate_event(eventdir, pylot_parameter, params, channel_config, update):
+# TODO: simplify this function
+def correlate_event(eventdir: str, pylot_parameter: PylotParameter, params: dict, channel_config: dict,
+                    update: bool) -> bool:
     rename_phases = {'Pdiff': 'P'}
 
     # create ObsPy event from .pkl file in dmt eventdir
@@ -588,7 +605,8 @@ def correlate_event(eventdir, pylot_parameter, params, channel_config, update):
         wfdata = wfdata_raw.copy()
         # resample and filter
         if params[phase_type]['sampfreq']:
-            wfdata = resample_parallel(wfdata, params[phase_type]['sampfreq'], ncores=params[phase_type]['ncores'])
+            wfdata = resample_parallel(wfdata, params[phase_type]['sampfreq'],
+                                       ncores=params[phase_type]['ncores'])
         else:
            logging.warning('Resampling deactivated! '
                            'Make sure that the sampling rate of all input waveforms is identical')
 
@@ -602,7 +620,8 @@ def correlate_event(eventdir, pylot_parameter, params, channel_config, update):
         else:
             method = 'auto'
         # get all picks from PyLoT *.xml file
-        if not pfe.startswith('_'): pfe = '_' + pfe
+        if not pfe.startswith('_'):
+            pfe = '_' + pfe
         picks = get_picks(eventdir, extension=pfe)
         # take picks of selected phase only
         picks = [pick for pick in picks if pick.phase_hint == phase_type]
@@ -630,7 +649,7 @@ def correlate_event(eventdir, pylot_parameter, params, channel_config, update):
                                                    model='ak135')
             # rotate ZNE -> LQT (also cut to same length)
             wfdata = rotate_stream(wfdata, metadata=metadata, origin=origin, stations_dict=stations_dict,
-                                   channels=channels[phase_type], inclination=inclination, ncores=ncores)
+                                   channels=channel_config[phase_type], inclination=inclination, ncores=ncores)
 
             # make copies before filtering
             wfdata_lowf = wfdata.copy()
@@ -644,7 +663,8 @@ def correlate_event(eventdir, pylot_parameter, params, channel_config, update):
 
         # make directory for correlation output
         correlation_out_dir = create_correlation_output_dir(eventdir, filter_options_final, true_phase_type)
-        fig_dir = create_correlation_figure_dir(correlation_out_dir) if params[phase_type]['save_fig'] == True else ''
+        fig_dir = create_correlation_figure_dir(correlation_out_dir) \
+            if params[phase_type]['save_fig'] else ''
 
         if not params[phase_type]['use_stacked_trace']:
             # first stack mastertrace by using stations with high correlation on that station in station list with the
@@ -666,13 +686,14 @@ def correlate_event(eventdir, pylot_parameter, params, channel_config, update):
             if params[phase_type]['plot']:
                 # plot correlations of traces used to generate stacked trace
                 plot_mastertrace_corr(nwst_id, correlations_dict, wfdata_lowf, stations_dict, picks, trace_master, method,
-                                      min_corr=params[phase_type]['min_corr_stacking'], title=eventdir + '_before_stacking',
-                                      fig_dir=fig_dir)
+                                      min_corr=params[phase_type]['min_corr_stacking'],
+                                      title=eventdir + '_before_stacking', fig_dir=fig_dir)
 
-            # continue if there is not enough traces for stacking
+            # continue if there are not enough traces for stacking
            if nstack < params[phase_type]['min_stack']:
-                logging.warning('Not enough traces to stack: {} (min_stack = {}). Skip this event'.format(nstack,
-                                                                                                          params[phase_type][
-                                                                                                              'min_stack']))
+                logging.warning('Not enough traces to stack: {} (min_stack = {}). '
+                                'Skip this event'.format(nstack, params[phase_type]['min_stack']))
                 continue
 
@@ -680,8 +701,8 @@ def correlate_event(eventdir, pylot_parameter, params, channel_config, update):
             trace_master.write(os.path.join(correlation_out_dir, '{}_stacked.mseed'.format(trace_master.id)))
 
         # now pick stacked trace with PyLoT for a more precise pick (use raw trace, gets filtered by autoPyLoT)
-        pick_stacked = repick_mastertrace(wfdata_lowf, trace_master, pylot_parameter, event, event_id, metadata,
-                                          phase_type, correlation_out_dir)
+        pick_stacked = repick_master_trace(wfdata_lowf, trace_master, pylot_parameter, event, event_id, metadata,
+                                           phase_type, correlation_out_dir)
         if not pick_stacked:
             continue
 
@@ -696,7 +717,7 @@ def correlate_event(eventdir, pylot_parameter, params, channel_config, update):
             trace_master_highf.filter(filter_type, **filter_options_final)
 
         input_list = prepare_correlation_input(wfdata_lowf, picks, channels_list, trace_master_lowf, pick_stacked,
-                                               params=params[phase_type], plot=params[phase_type]['plot'],
+                                               phase_params=params[phase_type], plot=params[phase_type]['plot'],
                                                fig_dir=fig_dir_traces, ncorr=2, wfdata_highf=wfdata_highf,
                                                trace_master_highf=trace_master_highf)
 
@@ -712,8 +733,8 @@ def correlate_event(eventdir, pylot_parameter, params, channel_config, update):
             export_picks(eventdir=eventdir,
                          correlations_dict=get_stations_min_corr(correlations_dict_stacked,
                                                                  params[phase_type]['min_corr_export']),
-                         params=params[phase_type], picks=picks, taupypicks=taupypicks_orig, phase_type=true_phase_type,
-                         pf_extension=pfe)
+                         params=params[phase_type], picks=picks, taupypicks=taupypicks_orig,
+                         phase_type=true_phase_type, pf_extension=pfe)
 
         # plot results
         if params[phase_type]['plot']:
@@ -748,10 +769,12 @@ def correlate_event(eventdir, pylot_parameter, params, channel_config, update):
 
     if len(correlations_dict_stacked) > 0:
         return True
+    else:
+        return False
 
 
 def remove_outliers(picks, corrected_taupy_picks, threshold):
-    ''' delete a pick if difference to corrected taupy pick is larger than threshold'''
+    """ delete a pick if difference to corrected taupy pick is larger than threshold"""
     n_picks = len(picks)
     n_outl = 0
 
@@ -780,7 +803,7 @@ def remove_outliers(picks, corrected_taupy_picks, threshold):
 
 
 def remove_invalid_picks(picks):
-    ''' Remove picks without uncertainty (invalid PyLoT picks)'''
+    """ Remove picks without uncertainty (invalid PyLoT picks)"""
     count = 0
     deleted_picks_ids = []
     for index, pick in list(reversed(list(enumerate(picks)))):
@@ -803,7 +826,7 @@ def get_picks_mean(picks):
 
 
 def get_corrected_taupy_picks(picks, taupypicks, all_available=False):
-    ''' get mean/median from picks taupy picks, correct latter for the difference '''
+    """ get mean/median difference between picks and taupy picks, and correct the taupy picks for it """
 
     def nwst_id_from_wfid(wfid):
         return '{}.{}'.format(wfid.network_code if wfid.network_code else '',
 
@@ -824,8 +847,9 @@ def get_corrected_taupy_picks(picks, taupypicks, all_available=False):
     mean_diff = taupy_mean - picks_mean
 
     logging.info(f'Calculated {len(taupypicks_new)} TauPy theoretical picks.')
-    logging.info('Calculated median difference from TauPy theoretical picks of {} seconds. (mean: {})'.format(median_diff,
-                                                                                                    mean_diff))
+    logging.info('Calculated median difference from TauPy theoretical picks of {} seconds. '
+                 '(mean: {})'.format(median_diff, mean_diff))
 
     # return all available taupypicks corrected for median difference to autopicks
     if all_available:
@@ -851,7 +875,8 @@ def load_stacked_trace(eventdir, min_corr_stack):
     if not os.path.isfile(corr_dict_fpath):
         logging.warning('No correlations_for_stacking dict found for event {}!'.format(eventdir))
         return
-    corr_dict = json.load(corr_dict_fpath)  # TODO: Check this line
+    with open(corr_dict_fpath) as fid:
+        corr_dict = json.load(fid)
 
     # get stations for stacking and nstack
     stations4stack = get_stations_min_corr(corr_dict, min_corr_stack)
@@ -866,10 +891,10 @@ def load_stacked_trace(eventdir, min_corr_stack):
     return corr_dict, nwst_id, stacked_trace, nstack
 
 
-def repick_mastertrace(wfdata, trace_master, pylot_parameter, event, event_id, metadata, phase_type, corr_out_dir):
-    rename_LQT = phase_type == 'S'
+def repick_master_trace(wfdata, trace_master, pylot_parameter, event, event_id, metadata, phase_type, corr_out_dir):
+    rename_lqt = phase_type == 'S'
     # create an 'artificial' stream object which can be used as input for PyLoT
-    stream_master = modify_master_trace4pick(trace_master.copy(), wfdata, rename_LQT=rename_LQT)
-    stream_master.write(os.path.join(corr_out_dir, 'stacked_master_trace_pylot.mseed'.format(trace_master.id)))
+    stream_master = modify_master_trace4pick(trace_master.copy(), wfdata, rename_lqt=rename_lqt)
+    stream_master.write(os.path.join(corr_out_dir, 'stacked_master_trace_pylot.mseed'))
     try:
         picksdict = autopickstation(stream_master, pylot_parameter, verbose=True, metadata=metadata,
@@ -894,10 +919,10 @@ def repick_mastertrace(wfdata, trace_master, pylot_parameter, event, event_id, m
 
     return pick_stacked
 
 
-def create_correlation_output_dir(eventdir, fopts, phase_type):
+def create_correlation_output_dir(eventdir: str, fopts: dict, phase_type: str) -> str:
     folder = 'correlation'
     folder = add_fpath_extension(folder, fopts, phase_type)
-    export_path = os.path.join(eventdir, folder)
+    export_path = str(os.path.join(eventdir, folder))
     if not os.path.isdir(export_path):
         os.mkdir(export_path)
     return export_path
@@ -927,12 +952,12 @@ def write_correlation_output(export_path, correlations_dict, correlations_dict_s
     write_json(correlations_dict_stacked, os.path.join(export_path, 'correlation_results.json'))
 
 
-def modify_master_trace4pick(trace_master, wfdata, rename_LQT=True):
-    '''
+def modify_master_trace4pick(trace_master, wfdata, rename_lqt=True):
+    """
     Create an artificial Stream with master trace instead of the trace of its corresponding channel. This is done
-    to find metadata for correct station metadata which were modified when stacking (e.g. loc=ST)
+    to recover the correct station metadata, which were modified when stacking (e.g. loc=ST).
     USE THIS ONLY FOR PICKING!
- ''' + """ # fake ZNE coordinates for autopylot lqt_zne = {'L': 'Z', 'Q': 'N', 'T': 'E'} @@ -949,7 +974,7 @@ def modify_master_trace4pick(trace_master, wfdata, rename_LQT=True): # take location from old trace and overwrite trace_master.stats.location = trace.stats.location trace = trace_master - if rename_LQT: + if rename_lqt: channel = trace.stats.channel component_new = lqt_zne.get(channel[-1]) if not component_new: @@ -1088,7 +1113,7 @@ def get_pickfile_name_correlated(event_id, fopts, phase_type): # return trace_this -def prepare_correlation_input(wfdata, picks, channels, trace_master, pick_master, params, plot=None, fig_dir=None, +def prepare_correlation_input(wfdata, picks, channels, trace_master, pick_master, phase_params, plot=None, fig_dir=None, ncorr=1, wfdata_highf=None, trace_master_highf=None): # prepare input for multiprocessing worker for all 'other' picks to correlate with current master-trace input_list = [] @@ -1096,7 +1121,7 @@ def prepare_correlation_input(wfdata, picks, channels, trace_master, pick_master assert (ncorr in [1, 2]), 'ncorr has to be 1 or 2' for pick_other in picks: - trace_other_highf = None + stream_other = stream_other_high_f = trace_other_high_f = channel = None network_other = pick_other.waveform_id.network_code station_other = pick_other.waveform_id.station_code nwst_id_other = '{nw}.{st}'.format(nw=network_other, st=station_other) @@ -1106,7 +1131,7 @@ def prepare_correlation_input(wfdata, picks, channels, trace_master, pick_master for channel in channels: stream_other = wfdata.select(network=network_other, station=station_other, channel=channel) if ncorr == 2: - stream_other_highf = wfdata_highf.select(network=network_other, station=station_other, channel=channel) + stream_other_high_f = wfdata_highf.select(network=network_other, station=station_other, channel=channel) if stream_other: break if not stream_other: @@ -1114,37 +1139,39 @@ def prepare_correlation_input(wfdata, picks, channels, trace_master, pick_master continue trace_other = stream_other[0] if ncorr == 2: - trace_other_highf = stream_other_highf[0] + trace_other_high_f = stream_other_high_f[0] if trace_other == stream_other: continue input_dict = {'nwst_id': nwst_id_other, 'trace1': trace_master, 'pick1': pick_master, 'trace2': trace_other, - 'pick2': pick_other, 'channel': channel, 't_before': params['t_before'], - 't_after': params['t_after'], 'cc_maxlag': params['cc_maxlag'], - 'cc_maxlag2': params['cc_maxlag2'], 'plot': plot, 'fig_dir': fig_dir, 'ncorr': ncorr, - 'trace1_highf': trace_master_highf, 'trace2_highf': trace_other_highf, - 'min_corr': params['min_corr_export']} + 'pick2': pick_other, 'channel': channel, 't_before': phase_params['t_before'], + 't_after': phase_params['t_after'], 'cc_maxlag': phase_params['cc_maxlag'], + 'cc_maxlag2': phase_params['cc_maxlag2'], 'plot': plot, 'fig_dir': fig_dir, 'ncorr': ncorr, + 'trace1_highf': trace_master_highf, 'trace2_highf': trace_other_high_f, + 'min_corr': phase_params['min_corr_export']} input_list.append(input_dict) return input_list -def stack_mastertrace(wfdata_lowf, wfdata_highf, wfdata_raw, picks, params, channels, method, fig_dir): - ''' +def stack_mastertrace(wfdata_lowf: Stream, wfdata_highf: Stream, wfdata_raw: Stream, picks: list, params: dict, + channels: list, method: str, fig_dir: str): + """ Correlate all stations with the first available station given in station_list, a list containing central, permanent, long operating, low noise stations with descending priority. 
A master trace will be created by stacking well correlating traces onto this station. - ''' + """ - def get_best_station4stack(station_results): - ''' return station with maximum mean_ccc''' - ccc_means = {nwst_id: value['mean_ccc'] for nwst_id, value in station_results.items() if + def get_best_station4stack(sta_result): + """ return station with maximum mean_ccc""" + ccc_means = {nwst_id: value['mean_ccc'] for nwst_id, value in sta_result.items() if not np.isnan(value['mean_ccc'])} if len(ccc_means) < 1: logging.warning('No valid station found for stacking! Return.') return best_station_id = max(ccc_means, key=ccc_means.get) - logging.info('Found highest mean correlation for station {} ({})'.format(best_station_id, max(ccc_means.values()))) + logging.info( + 'Found highest mean correlation for station {} ({})'.format(best_station_id, max(ccc_means.values()))) return best_station_id station_results = iterate_correlation(wfdata_lowf, wfdata_highf, channels, picks, method, params, fig_dir=fig_dir) @@ -1153,7 +1180,7 @@ def stack_mastertrace(wfdata_lowf, wfdata_highf, wfdata_raw, picks, params, chan # in case no stream with a valid pick is found if not nwst_id_master: logging.info('No mastertrace found! Will skip this event.') - return False + return None trace_master = station_results[nwst_id_master]['trace'] stations4stack = station_results[nwst_id_master]['stations4stack'] @@ -1163,22 +1190,27 @@ def stack_mastertrace(wfdata_lowf, wfdata_highf, wfdata_raw, picks, params, chan dt_pre, dt_post = params['dt_stacking'] trace_master, nstack = apply_stacking(trace_master, stations4stack, wfdata_raw, picks, method=method, - check_RMS=params['check_RMS'], plot=params['plot'], fig_dir=fig_dir, + check_rms=params['check_RMS'], plot=params['plot'], fig_dir=fig_dir, dt_pre=dt_pre, dt_post=dt_post) return correlations_dict, nwst_id_master, trace_master, nstack -def iterate_correlation(wfdata_lowf, wfdata_highf, channels, picks, method, params, fig_dir=None): - '''iterate over possible stations for master-trace store them and return a dictionary''' +def iterate_correlation(wfdata_lowf: Stream, wfdata_highf: Stream, channels: list, picks: list, method: str, + params: dict, fig_dir: str = None) -> dict: + """iterate over possible stations for master-trace store them and return a dictionary""" - station_results = {nwst_id: {'mean_ccc': np.nan, 'correlations_dict': None, 'stations4stack': None, 'trace': None} + station_results = {nwst_id: {'mean_ccc': np.nan, 'correlations_dict': dict(), 'stations4stack': dict(), + 'trace': None} for nwst_id in params['station_list']} for nwst_id_master in params['station_list']: logging.info(20 * '#') logging.info('Starting correlation for station: {}'.format(nwst_id_master)) nw, st = nwst_id_master.split('.') + + # get master-trace + stream_master_lowf = stream_master_highf = None for channel in channels: stream_master_lowf = wfdata_lowf.select(network=nw, station=st, channel=channel) stream_master_highf = wfdata_highf.select(network=nw, station=st, channel=channel) @@ -1211,8 +1243,8 @@ def iterate_correlation(wfdata_lowf, wfdata_highf, channels, picks, method, para input_list = prepare_correlation_input(wfdata_lowf, picks, channels, trace_master_lowf, pick_master, trace_master_highf=trace_master_highf, wfdata_highf=wfdata_highf, - params=params, plot=params['plot_detailed'], fig_dir=fig_dir_traces, - ncorr=2) + phase_params=params, plot=params['plot_detailed'], + fig_dir=fig_dir_traces, ncorr=2) if params['plot_detailed'] and not fig_dir_traces: # serial @@ -1231,9 
+1263,11 @@ def iterate_correlation(wfdata_lowf, wfdata_highf, channels, picks, method, para
     return station_results
 
 
-def apply_stacking(trace_master, stations4stack, wfdata, picks, method, check_RMS, dt_pre=250., dt_post=250.,
-                   plot=False, fig_dir=None):
-    def check_trace_length_and_cut(trace, correlated_midtime, dt_pre, dt_post):
+def apply_stacking(trace_master: Trace, stations4stack: dict, wfdata: Stream, picks: list, method: str,
+                   check_rms: bool, dt_pre: float = 250., dt_post: float = 250., plot: bool = False,
+                   fig_dir: str = None) -> tuple:
+
+    def check_trace_length_and_cut(trace: Trace, correlated_midtime: UTCDateTime):
         starttime = correlated_midtime - dt_pre
         endtime = correlated_midtime + dt_post
 
@@ -1248,7 +1282,7 @@ def apply_stacking(trace_master, stations4stack, wfdata, picks, method, check_RM
 
     trace_master = trace_master.copy()
     trace_master.stats.location = 'ST'
-    check_trace_length_and_cut(trace_master, pick.time, dt_pre=dt_pre, dt_post=dt_post)
+    check_trace_length_and_cut(trace_master, pick.time)
 
     # empty trace so that it does not stack twice
     trace_master.data = np.zeros(len(trace_master.data))
@@ -1262,15 +1296,16 @@ def apply_stacking(trace_master, stations4stack, wfdata, picks, method, check_RM
         stream_other = wfdata.select(network=nw, station=st, channel=channel)
         correlated_midtime = pick_other.time - dpick
         trace_other = stream_other[0].copy()
-        check_trace_length_and_cut(trace_other, correlated_midtime, dt_pre=dt_pre, dt_post=dt_post)
+        check_trace_length_and_cut(trace_other, correlated_midtime)
 
         if not len(trace_other) == len(trace_master):
-            logging.warning('Can not stack trace on master trace because of different lengths: {}. Continue'.format(nwst_id))
+            logging.warning(
+                'Cannot stack trace on master trace because of different lengths: {}. Continue'.format(nwst_id))
             continue
 
         traces4stack.append(trace_other)
 
-    if check_RMS:
+    if check_rms:
         traces4stack = checkRMS(traces4stack, plot=plot, fig_dir=fig_dir)
 
     if plot:
@@ -1307,12 +1342,14 @@ def checkRMS(traces4stack, plot=False, fig_dir=None, max_it=10, ntimes_std=5.):
     traces4stack = sorted(traces4stack, key=lambda x: x.id)
 
     for trace in traces4stack:
-        rms_list.append(RMS(trace.data))
+        rms_list.append(calc_rms(trace.data))
         trace_names.append(trace.id)
 
     # iterative elimination of RMS outliers
     iterate = True
     count = 0
+    std = 0
+    mean = 0
     while iterate:
         count += 1
         if count >= max_it:
@@ -1341,7 +1378,8 @@ def checkRMS(traces4stack, plot=False, fig_dir=None, max_it=10, ntimes_std=5.):
 
     return traces4stack
 
 
-def plot_rms(rms_list, trace_names, mean, std, fig_dir, count, ntimes_std):
+def plot_rms(rms_list: list, trace_names: list, mean: float, std: float, fig_dir: str, count: int,
+             ntimes_std: float) -> None:
     fig = plt.figure(figsize=(16, 9))
     ax = fig.add_subplot(111)
     ax.semilogy(rms_list, 'b.')
@@ -1359,7 +1397,7 @@ def plot_rms(rms_list, trace_names, mean, std, fig_dir, count, ntimes_std):
     plt.close(fig)
 
 
-def RMS(X):
+def calc_rms(X):
     """
-    Returns root mean square of a given array LON
+    Returns root mean square of a given array
     """
@@ -1381,7 +1419,7 @@ def resample_parallel(stream, freq, ncores):
 
 
 def rotate_stream(stream, metadata, origin, stations_dict, channels, inclination, ncores):
-    ''' Rotate stream to LQT. To do this all traces have to be cut to equal lenths'''
+    """ Rotate stream to LQT. To do this all traces have to be cut to equal lengths"""
     input_list = []
     cut_stream_to_same_length(stream)
     new_stream = Stream()
@@ -1424,7 +1462,7 @@ def rotate_stream(stream, metadata, origin, stations_dict, channels, inclination
     # return stream
 
 
-def correlate_parallel(input_list, ncores):
+def correlate_parallel(input_list: list, ncores: int) -> dict:
     parallel = Parallel(n_jobs=ncores)
 
     logging.info('Correlate_parallel: Generated {} parallel jobs.'.format(ncores))
@@ -1435,7 +1473,7 @@ def correlate_parallel(input_list, ncores):
     return unpack_result(correlation_result)
 
 
-def correlate_serial(input_list, plot=False):
+def correlate_serial(input_list: list, plot: bool = False) -> dict:
     correlation_result = []
     for input_dict in input_list:
         input_dict['plot'] = plot
@@ -1443,7 +1481,7 @@ def correlate_serial(input_list, plot=False):
     return unpack_result(correlation_result)
 
 
-def taupy_parallel(input_list, ncores):
+def taupy_parallel(input_list: list, ncores: int) -> dict:
     parallel = Parallel(n_jobs=ncores)
 
     logging.info('Taupy_parallel: Generated {} parallel jobs.'.format(ncores))
@@ -1454,19 +1492,19 @@ def taupy_parallel(input_list, ncores):
     return unpack_result(taupy_results)
 
 
-def unpack_result(result):
+def unpack_result(result: list) -> dict:
     result_dict = {item['nwst_id']: {key: item[key] for key in item.keys()} for item in result}
     nerr = 0
     for item in result_dict.values():
         if item['err']:
             logging.debug(f'Found error for {item["nwst_id"]}: {item["err"]}.')
-            #logging.debug(f'Detailed traceback: {item["exc"]}')
+            # logging.debug(f'Detailed traceback: {item["exc"]}')
             nerr += 1
     logging.info('Unpack results: Found {} errors after multiprocessing'.format(nerr))
     return result_dict
 
 
-def get_taupy_picks(origin, stations_dict, pylot_parameter, phase_type, ncores=None):
+def get_taupy_picks(origin: Origin, stations_dict, pylot_parameter, phase_type, ncores=None):
     input_list = []
 
     taup_phases = pylot_parameter['taup_phases'].split(',')
@@ -1526,10 +1564,10 @@ def taupy_worker(input_dict):
         first_arrival = arrivals[0]
         output_dict = dict(nwst_id=input_dict['nwst_id'], phase_name=first_arrival.name,
                            phase_time=source_time + first_arrival.time, phase_dist=first_arrival.distance, err=None,
-                           exc=None,)
+                           exc=None)
     except Exception as e:
         exc = traceback.format_exc()
-        output_dict = dict(nwst_id=input_dict['nwst_id'], phase_name=None, phase_time=None, err=str(e), exc=exc,)
+        output_dict = dict(nwst_id=input_dict['nwst_id'], phase_name=None, phase_time=None, err=str(e), exc=exc)
     return output_dict
 
 
@@ -1559,7 +1597,7 @@ def rotation_worker(input_dict):
 
 
 def correlation_worker(input_dict):
-    '''worker function for multiprocessing'''
+    """worker function for multiprocessing"""
 
     # unpack input dictionary
     nwst_id = input_dict['nwst_id']
@@ -1613,7 +1651,7 @@ def correlation_worker(input_dict):
             'channel': channel, }
 
 
-def get_pick4station(picks, network_code, station_code, method='auto'):
+def get_pick4station(picks: list, network_code: str, station_code: str, method: str = 'auto') -> Pick:
     for pick in picks:
         if pick.waveform_id.network_code == network_code:
             if pick.waveform_id.station_code == station_code:
@@ -1621,13 +1659,14 @@ def get_pick4station(picks, network_code, station_code, method='auto'):
 
     return pick
 
 
-def get_stations_min_corr(corr_results, min_corr):
+def get_stations_min_corr(corr_results: dict, min_corr: float) -> dict:
     corr_results = {nwst_id: result for nwst_id, result in corr_results.items() if result['ccc'] > min_corr}
     return corr_results
 
 
-def plot_mastertrace_corr(nwst_id, corr_results, wfdata, stations_dict, picks, trace_master, method, min_corr, title,
-                          fig_dir=None):
+def plot_mastertrace_corr(nwst_id: str, corr_results: dict, wfdata: Stream, stations_dict: dict, picks: list,
+                          trace_master: Trace, method: str, min_corr: float, title: str,
+                          fig_dir: str = None) -> None:
     nw, st = nwst_id.split('.')
     coords_master = stations_dict[nwst_id]
     pick_master = get_pick4station(picks, nw, st, method)
@@ -1650,7 +1689,7 @@ def plot_mastertrace_corr(nwst_id, corr_results, wfdata, stations_dict, picks, t
         plt.show()
 
 
-def make_figure_dirs(fig_dir, trace_master_id):
+def make_figure_dirs(fig_dir: str, trace_master_id: str) -> str:
     fig_dir_traces = ''
     if fig_dir and os.path.isdir(fig_dir):
         fig_dir_traces = os.path.join(fig_dir, 'corr_with_{}'.format(trace_master_id))
@@ -1659,15 +1698,18 @@ def make_figure_dirs(fig_dir, trace_master_id):
     return fig_dir_traces
 
 
-def plot_section(wfdata, trace_this, pick_this, picks, stations2compare, channels, corr_results, method, dt_pre=20.,
-                 dt_post=50, axes=None, max_stations=50.):
-    '''Plot a section with all stations used for correlation on ax'''
+def plot_section(wfdata: Stream, trace_this: Trace, pick_this: Pick, picks: list, stations2compare: dict,
+                 channels: list, corr_results: dict, method: str, dt_pre: float = 20., dt_post: float = 50.,
+                 axes: dict = None, max_stations: int = 50) -> None:
+    """Plot a section with all stations used for correlation on ax"""
     ylabels = []
     yticks = []
 
     trace_this = trace_this.copy()
     trace_this.trim(starttime=pick_this.time - dt_pre, endtime=pick_this.time + dt_post)
 
+    ax_sec = None
+
     # iterate over all closest stations ("other" stations)
     for index, nwst_id in enumerate(stations2compare.keys()):
         if index >= max_stations:
@@ -1700,6 +1742,8 @@ def plot_section(wfdata, trace_this, pick_this, picks, stations2compare, channel
         if np.isnan(dpick) or not pick_other or not pick_other.time:
             continue
 
+        stream = None
+
         # continue if there are no data for station for whatever reason
         for channel in channels:
             stream = wfdata.select(station=st, network=nw, channel=channel)
@@ -1718,13 +1762,14 @@ def plot_section(wfdata, trace_this, pick_this, picks, stations2compare, channel
                         index + 0.5, color='r', lw=0.5)
 
-    # Plot desciption
-    ax_sec.set_yticks(yticks)
-    ax_sec.set_yticklabels(ylabels)
-    ax_sec.set_title('Section with corresponding picks.')
-    ax_sec.set_xlabel('Samples. Traces are shifted in time.')
+    # Plot description
+    if ax_sec:
+        ax_sec.set_yticks(yticks)
+        ax_sec.set_yticklabels(ylabels)
+        ax_sec.set_title('Section with corresponding picks.')
+        ax_sec.set_xlabel('Samples. Traces are shifted in time.')
 
 
-def plot_stacked_trace_pick(trace, pick, pylot_parameter):
+def plot_stacked_trace_pick(trace: Trace, pick: Pick, pylot_parameter: PylotParameter) -> plt.Figure:
     trace_filt = trace.copy()
     if pylot_parameter['filter_type'] and pylot_parameter['filter_options']:
         ftype = pylot_parameter['filter_type'][0]
@@ -1746,8 +1791,9 @@ def plot_stacked_trace_pick(trace, pick, pylot_parameter):
     return fig
 
 
-def plot_stations(stations_dict, stations2compare, coords_this, corr_result, trace, pick, window_title=None):
-    ''' Plot routine to check proximity algorithm. '''
+def plot_stations(stations_dict: dict, stations2compare: dict, coords_this: dict, corr_result: dict, trace: Trace,
+                  pick: Pick, window_title: str = None):
+    """ Plot routine to check proximity algorithm. """
 
     title = trace.id
@@ -1812,15 +1858,15 @@ def plot_stations(stations_dict, stations2compare, coords_this, corr_result, tra
 
     return fig, axes
 
 
-def get_data(eventdir, data_dir, headonly=False):
-    ''' Read and return waveformdata read from eventdir/data_dir. '''
+def get_data(eventdir: str, data_dir: str, headonly: bool = False) -> Stream:
+    """ Read and return waveform data from eventdir/data_dir. """
    wfdata_path = os.path.join(eventdir, data_dir)
     wfdata = read(os.path.join(wfdata_path, '*'), headonly=headonly)
     return wfdata
 
 
 def get_closest_stations(coords, stations_dict, n):
-    ''' Calculate distances and return n closest stations in stations_dict for station at coords. '''
+    """ Calculate distances and return n closest stations in stations_dict for station at coords. """
     distances = {}
     for station_id, st_coords in stations_dict.items():
         dist = gps2dist_azimuth(coords['latitude'], coords['longitude'], st_coords['latitude'], st_coords['longitude'],
@@ -1836,7 +1882,7 @@ def get_closest_stations(coords, stations_dict, n):
 
 
 def get_random_stations(coords, stations_dict, n):
-    ''' Calculate distances and return n randomly selected stations in stations_dict for station at coords. '''
+    """ Calculate distances and return n randomly selected stations in stations_dict for station at coords. """
     random_keys = random.sample(list(stations_dict.keys()), n)
     distances = {}
     for station_id, st_coords in stations_dict.items():
@@ -1851,11 +1897,41 @@ def get_random_stations(coords, stations_dict, n):
 
     return random_stations
 
 
-# Notes: if use_master_trace is set True, traces above a specified ccc threshold will be stacked onto one 'ideal'
-# station into a master-trace. This trace will be picked and used again for cross correlation with all other stations
-# to find a pick for these stations.
+def prepare_corr_params(parse_args, logger): + with open(parse_args.params) as infile: + parameters_dict = yaml.safe_load(infile) -if __name__ == "__main__": + # logging + logger.setLevel(DEBUG_LEVELS.get(parameters_dict['logging'])) + + # number of cores + if parse_args.ncores is not None: + ncores = int(parse_args.ncores) + else: + ncores = None + + # plot options + plot_dict = dict(plot=parse_args.plot, plot_detailed=parse_args.plot_detailed, + save_fig=not parse_args.show_fig) + + corr_params = {phase: CorrelationParameters(ncores=ncores) for phase in parameters_dict['pick_phases']} + for phase, params_phase in corr_params.items(): + params_phase.add_parameters(**plot_dict) + params_phase.add_parameters(**parameters_dict[phase]) + + return corr_params + + +def init_logger(): + logger = logging.getLogger() + handler = logging.StreamHandler(sys.stdout) + fhandler = logging.FileHandler('pick_corr.out') + logger.addHandler(handler) + logger.addHandler(fhandler) + return logger + + +def setup_parser(): parser = argparse.ArgumentParser(description='Correlate picks from PyLoT.') parser.add_argument('dmt_path', help='path containing dmt_database with PyLoT picks') parser.add_argument('pylot_infile', help='path to autoPyLoT inputfile (pylot.in)') @@ -1881,44 +1957,22 @@ if __name__ == "__main__": parser.add_argument('-istart', default=0, help='first event index') parser.add_argument('-istop', default=1e9, help='last event index') - args = parser.parse_args() + return parser.parse_args() - # PARAMETER DEFINITIONS ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ - # plot options - plot_dict = dict(plot=args.plot, plot_detailed=args.plot_detailed, save_fig=not (args.show_fig)) +if __name__ == "__main__": + ARGS = setup_parser() + + # initialize logging + LOGGER = init_logger() # alparray configuration: ZNE ~ 123 - channels = {'P': ['*Z', '*1'], 'S': ['*Q', '*T']} # '*N', '*E', '*2', '*3', - debug_levels = {'debug': logging.DEBUG, - 'info': logging.INFO, - 'warn': logging.WARNING, - 'error': logging.ERROR, - 'critical': logging.CRITICAL} + CHANNELS = {'P': ['*Z', '*1'], 'S': ['*Q', '*T']} # '*N', '*E', '*2', '*3', + # initialize parameters from yaml, set logging level + CORR_PARAMS = prepare_corr_params(ARGS, LOGGER) - if args.ncores is not None: - _ncores = int(args.ncores) - else: - _ncores = None - - with open(args.params) as infile: - parameters_dict = yaml.safe_load(infile) - - params = {phase: None for phase in parameters_dict['pick_phases']} - - logger = logging.getLogger() - logger.setLevel(debug_levels.get(parameters_dict['logging'])) - handler = logging.StreamHandler(sys.stdout) - fhandler = logging.FileHandler('pick_corr.out') - logger.addHandler(handler) - logger.addHandler(fhandler) - - for phase in params.keys(): - params_phase = CorrelationParameters(ncores=_ncores) - params_phase.add_parameters(**plot_dict) - params_phase.add_parameters(**parameters_dict[phase]) - params[phase] = params_phase - - correlation_main(args.dmt_path, args.pylot_infile, params=params, istart=int(args.istart), istop=int(args.istop), - channel_config=channels, update=args.update, event_blacklist=args.blacklist) + # MAIN +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + correlation_main(ARGS.dmt_path, ARGS.pylot_infile, params=CORR_PARAMS, istart=int(ARGS.istart), + istop=int(ARGS.istop), + channel_config=CHANNELS, update=ARGS.update, event_blacklist=ARGS.blacklist)
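
Usage sketch (not part of the patch itself): the refactored entry points can also be driven
from Python instead of the shell. All paths and the YAML file name below are hypothetical
placeholders; the YAML is assumed to hold a 'logging' level, a 'pick_phases' list and one
parameter block per phase, which is the layout prepare_corr_params() reads. Note that parts
of the module may still assume the CHANNELS global defined in its __main__ block.

    from argparse import Namespace

    from pylot.correlation.pick_correlation_correction import (
        correlation_main, init_logger, prepare_corr_params)

    logger = init_logger()

    # stand-in for the argparse result normally produced by setup_parser()
    args = Namespace(params='parameters.yaml', ncores=4,
                     plot=False, plot_detailed=False, show_fig=False)

    # one CorrelationParameters object per phase listed in the YAML
    corr_params = prepare_corr_params(args, logger)

    correlation_main('dmt_database/',  # obspy_dmt database directory (placeholder)
                     'pylot.in',       # autoPyLoT input file (placeholder)
                     params=corr_params,
                     channel_config={'P': ['*Z', '*1'], 'S': ['*Q', '*T']},
                     istart=0, istop=int(1e9), update=False)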