diff --git a/pylot/correlation/pick_correlation_correction.py b/pylot/correlation/pick_correlation_correction.py
index b0d926e0..714d190b 100644
--- a/pylot/correlation/pick_correlation_correction.py
+++ b/pylot/correlation/pick_correlation_correction.py
@@ -11,7 +11,7 @@ import traceback
 import glob
 import json
 from datetime import datetime
-from typing import Any
+from typing import Any, Optional
 
 import numpy as np
 import matplotlib.pyplot as plt
@@ -19,6 +19,7 @@ import yaml
 from joblib import Parallel, delayed
 
 from obspy import read, Stream, UTCDateTime, Trace
+from obspy.core.event import Event
 from obspy.taup import TauPyModel
 from obspy.core.event.base import WaveformStreamID, ResourceIdentifier
 from obspy.core.event.origin import Pick, Origin
@@ -28,6 +29,7 @@ from obspy.signal.cross_correlation import correlate, xcorr_max
 from pylot.core.io.inputs import PylotParameter
 from pylot.core.io.phases import picks_from_picksdict
 from pylot.core.pick.autopick import autopickstation
+from pylot.core.util.dataprocessing import Metadata
 from pylot.core.util.utils import check4rotated
 from pylot.core.util.utils import identifyPhaseID
 from pylot.core.util.event import Event as PylotEvent
@@ -771,7 +773,7 @@ def correlate_event(eventdir: str, pylot_parameter: PylotParameter, params: dict
     return False
 
 
-def remove_outliers(picks, corrected_taupy_picks, threshold):
+def remove_outliers(picks: list, corrected_taupy_picks: list, threshold: float) -> list:
     """ delete a pick if difference to corrected taupy pick is larger than threshold"""
 
     n_picks = len(picks)
@@ -800,7 +802,7 @@ def remove_outliers(picks, corrected_taupy_picks, threshold):
     return picks
 
 
-def remove_invalid_picks(picks):
+def remove_invalid_picks(picks: list) -> list:
     """ Remove picks without uncertainty (invalid PyLoT picks)"""
     count = 0
     deleted_picks_ids = []
@@ -815,15 +817,15 @@ def remove_invalid_picks(picks):
     return picks
 
 
-def get_picks_median(picks):
+def get_picks_median(picks: list) -> UTCDateTime:
     return UTCDateTime(int(np.median([pick.time.timestamp for pick in picks if pick.time])))
 
 
-def get_picks_mean(picks):
+def get_picks_mean(picks: list) -> UTCDateTime:
     return UTCDateTime(np.mean([pick.time.timestamp for pick in picks if pick.time]))
 
 
-def get_corrected_taupy_picks(picks, taupypicks, all_available=False):
+def get_corrected_taupy_picks(picks: list, taupypicks: list, all_available: bool = False) -> tuple:
     """ get mean/median from picks taupy picks, correct latter for the difference """
 
     def nwst_id_from_wfid(wfid):
@@ -859,7 +861,7 @@ def get_corrected_taupy_picks(picks, taupypicks, all_available=False):
     return taupypicks_new, median_diff
 
 
-def load_stacked_trace(eventdir, min_corr_stack):
+def load_stacked_trace(eventdir: str, min_corr_stack: float) -> Optional[tuple]:
     # load stacked stream (miniseed)
     str_stack_fpaths = glob.glob(os.path.join(eventdir, 'correlation', '*_stacked.mseed'))
     if not len(str_stack_fpaths) == 1:
@@ -888,7 +890,8 @@ def load_stacked_trace(eventdir, min_corr_stack):
     return corr_dict, nwst_id, stacked_trace, nstack
 
 
-def repick_master_trace(wfdata, trace_master, pylot_parameter, event, event_id, metadata, phase_type, corr_out_dir):
+def repick_master_trace(wfdata: Stream, trace_master: Trace, pylot_parameter: PylotParameter, event: Event,
+                        event_id: str, metadata: Metadata, phase_type: str, corr_out_dir: str) -> Optional[Pick]:
     rename_lqt = phase_type == 'S'
     # create an 'artificial' stream object which can be used as input for PyLoT
     stream_master = modify_master_trace4pick(trace_master.copy(), wfdata, rename_lqt=rename_lqt)
@@ -925,14 +928,14 @@ def create_correlation_output_dir(eventdir: str, fopts: dict, phase_type: str) -
     return export_path
 
 
-def create_correlation_figure_dir(correlation_out_dir):
+def create_correlation_figure_dir(correlation_out_dir: str) -> str:
     export_path = os.path.join(correlation_out_dir, 'figures')
     if not os.path.isdir(export_path):
         os.mkdir(export_path)
     return export_path
 
 
-def add_fpath_extension(fpath, fopts, phase):
+def add_fpath_extension(fpath: str, fopts: dict, phase: str) -> str:
     if fopts:
         freqmin, freqmax = fopts['freqmin'], fopts['freqmax']
         if freqmin:
@@ -944,12 +947,12 @@ def add_fpath_extension(fpath, fopts, phase):
     return fpath
 
 
-def write_correlation_output(export_path, correlations_dict, correlations_dict_stacked):
+def write_correlation_output(export_path: str, correlations_dict: dict, correlations_dict_stacked: dict) -> None:
     write_json(correlations_dict, os.path.join(export_path, 'correlations_for_stacking.json'))
     write_json(correlations_dict_stacked, os.path.join(export_path, 'correlation_results.json'))
 
 
-def modify_master_trace4pick(trace_master, wfdata, rename_lqt=True):
+def modify_master_trace4pick(trace_master: Trace, wfdata: Stream, rename_lqt: bool = True) -> Stream:
     """
     Create an artificial Stream with master trace instead of the trace of its corresponding channel.
     This is done to find metadata for correct station metadata which were modified when stacking (e.g. loc=ST)
@@ -985,7 +988,8 @@ def modify_master_trace4pick(trace_master, wfdata, rename_lqt=True):
     return stream_master
 
 
-def export_picks(eventdir, correlations_dict, picks, taupypicks, params, phase_type, pf_extension):
+def export_picks(eventdir: str, correlations_dict: dict, picks: list, taupypicks: list, params: dict,
+                 phase_type: str, pf_extension: str) -> None:
     threshold = params['export_threshold']
     min_picks_export = params['min_picks_export']
     # make copy so that modified picks are not exported
@@ -1057,7 +1061,8 @@ def export_picks(eventdir, correlations_dict, picks, taupypicks, params, phase_t
     logging.info('Wrote {} correlated picks to file {}'.format(len(event.picks), fpath))
 
 
-def write_taupy_picks(event, eventdir, taupypicks, time_shift, extension='corrected_taup_times'):
+def write_taupy_picks(event: Event, eventdir: str, taupypicks: list, time_shift: float,
+                      extension: str = 'corrected_taup_times') -> None:
     # make copies because both objects are being modified
     event = copy.deepcopy(event)
     taupypicks = copy.deepcopy(taupypicks)
@@ -1079,17 +1084,17 @@ def write_taupy_picks(event, eventdir, taupypicks, time_shift, extension='correc
     write_event(event, eventdir, fname)
 
 
-def write_event(event, eventdir, fname):
+def write_event(event: Event, eventdir: str, fname: str) -> str:
     fpath = os.path.join(eventdir, fname)
     event.write(fpath, format='QUAKEML')
     return fpath
 
 
-def get_pickfile_name(event_id, fname_extension):
+def get_pickfile_name(event_id: str, fname_extension: str) -> str:
     return 'PyLoT_{}_{}.xml'.format(event_id, fname_extension)
 
 
-def get_pickfile_name_correlated(event_id, fopts, phase_type):
+def get_pickfile_name_correlated(event_id: str, fopts: dict, phase_type: str) -> str:
     fname_extension = add_fpath_extension('correlated', fopts, phase_type)
     return get_pickfile_name(event_id, fname_extension)
 
@@ -1110,8 +1115,9 @@ def get_pickfile_name_correlated(event_id, fopts, phase_type):
 # return trace_this
 
 
-def prepare_correlation_input(wfdata, picks, channels, trace_master, pick_master, phase_params, plot=None, fig_dir=None,
-                              ncorr=1, wfdata_highf=None, trace_master_highf=None):
+def prepare_correlation_input(wfdata: Stream, picks: list, channels: list, trace_master: Trace, pick_master: Pick,
+                              phase_params: dict, plot: bool = None, fig_dir: str = None,
+                              ncorr: int = 1, wfdata_highf: Stream = None, trace_master_highf: Stream = None) -> list:
     # prepare input for multiprocessing worker for all 'other' picks to correlate with current master-trace
     input_list = []
 
@@ -1152,7 +1158,7 @@ def prepare_correlation_input(wfdata, picks, channels, trace_master, pick_master
 
 
 def stack_mastertrace(wfdata_lowf: Stream, wfdata_highf: Stream, wfdata_raw: Stream, picks: list, params: dict,
-                      channels: list, method: str, fig_dir: str):
+                      channels: list, method: str, fig_dir: str) -> Optional[tuple]:
     """
     Correlate all stations with the first available station given in station_list, a list containing central,
     permanent, long operating, low noise stations with descending priority.
@@ -1187,7 +1193,7 @@ def stack_mastertrace(wfdata_lowf: Stream, wfdata_highf: Stream, wfdata_raw: Str
     dt_pre, dt_post = params['dt_stacking']
 
     trace_master, nstack = apply_stacking(trace_master, stations4stack, wfdata_raw, picks, method=method,
-                                          check_rms=params['check_RMS'], plot=params['plot'], fig_dir=fig_dir,
+                                          do_rms_check=params['check_RMS'], plot=params['plot'], fig_dir=fig_dir,
                                           dt_pre=dt_pre, dt_post=dt_post)
 
     return correlations_dict, nwst_id_master, trace_master, nstack
@@ -1261,7 +1267,7 @@ def iterate_correlation(wfdata_lowf: Stream, wfdata_highf: Stream, channels: lis
 
 
 def apply_stacking(trace_master: Trace, stations4stack: dict, wfdata: Stream, picks: list, method: str,
-                   check_rms: bool, dt_pre: float = 250., dt_post: float = 250., plot: bool = False,
+                   do_rms_check: bool, dt_pre: float = 250., dt_post: float = 250., plot: bool = False,
                    fig_dir: str = None) -> tuple:
 
     def check_trace_length_and_cut(trace: Trace, pick_time: UTCDateTime = None):
@@ -1334,7 +1340,8 @@ def apply_stacking(trace_master: Trace, stations4stack: dict, wfdata: Stream, pi
     return trace_master, count
 
 
-def check_rms(traces4stack, plot=False, fig_dir=None, max_it=10, ntimes_std=5.):
+def check_rms(traces4stack: list, plot: bool = False, fig_dir: str = None, max_it: int = 10,
+              ntimes_std: float = 5.) -> list:
     rms_list = []
     trace_names = []
 
@@ -1396,14 +1403,14 @@ def plot_rms(rms_list: list, trace_names: list, mean: float, std: float, fig_dir
     plt.close(fig)
 
 
-def calc_rms(X):
+def calc_rms(array: np.ndarray) -> float:
     """
     Returns root mean square of a given array
     LON
     """
-    return np.sqrt(np.sum(np.power(X, 2)) / len(X))
+    return np.sqrt(np.sum(np.power(array, 2)) / len(array))
 
 
-def resample_parallel(stream, freq, ncores):
+def resample_parallel(stream: Stream, freq: float, ncores: int) -> Stream:
     input_list = [{'trace': trace, 'freq': freq} for trace in stream.traces]
     parallel = Parallel(n_jobs=ncores)
@@ -1417,7 +1424,8 @@ def resample_parallel(stream, freq, ncores):
     return stream
 
 
-def rotate_stream(stream, metadata, origin, stations_dict, channels, inclination, ncores):
+def rotate_stream(stream: Stream, metadata: Metadata, origin: Origin, stations_dict: dict, channels: list,
+                  inclination: float, ncores: int) -> Stream:
     """ Rotate stream to LQT. To do this all traces have to be cut to equal lenths"""
     input_list = []
     cut_stream_to_same_length(stream)
@@ -1503,7 +1511,8 @@ def unpack_result(result: list) -> dict:
     return result_dict
 
 
-def get_taupy_picks(origin: Origin, stations_dict, pylot_parameter, phase_type, ncores=None):
+def get_taupy_picks(origin: Origin, stations_dict: dict, pylot_parameter: PylotParameter, phase_type: str,
+                    ncores: int = None) -> list:
     input_list = []
     taup_phases = pylot_parameter['taup_phases'].split(',')
 
@@ -1531,7 +1540,7 @@ def get_taupy_picks(origin: Origin, stations_dict, pylot_parameter, phase_type,
     return taupy_picks
 
 
-def create_artificial_picks(taupy_result):
+def create_artificial_picks(taupy_result: dict) -> list:
     artificial_picks = []
     for nwst_id, taupy_dict in taupy_result.items():
         nw, st = nwst_id.split('.')
@@ -1543,14 +1552,14 @@ def create_artificial_picks(taupy_result):
     return artificial_picks
 
 
-def check_taupy_phases(taupy_result):
+def check_taupy_phases(taupy_result: dict) -> None:
     test_phase_name = list(taupy_result.values())[0]['phase_name']
     phase_names_equal = [item['phase_name'] == test_phase_name for item in taupy_result.values()]
     if not all(phase_names_equal):
         logging.info('Different first arriving phases detected in TauPy phases for this event.')
 
 
-def taupy_worker(input_dict):
+def taupy_worker(input_dict: dict) -> dict:
     model = input_dict['model']
     taupy_input = input_dict['taupy_input']
     source_time = input_dict['source_time']
@@ -1570,7 +1579,7 @@ def taupy_worker(input_dict):
     return output_dict
 
 
-def resample_worker(input_dict):
+def resample_worker(input_dict: dict) -> Trace:
     trace = input_dict['trace']
     freq = input_dict['freq']
     if freq == trace.stats.sampling_rate:
@@ -1578,7 +1587,7 @@ def resample_worker(input_dict):
     return trace.resample(freq, no_filter=freq > trace.stats.sampling_rate)
 
 
-def rotation_worker(input_dict):
+def rotation_worker(input_dict: dict) -> Optional[Stream]:
     stream = input_dict['stream']
     tstart = max([tr.stats.starttime for tr in stream])
     tend = min([tr.stats.endtime for tr in stream])
@@ -1595,7 +1604,7 @@ def rotation_worker(input_dict):
     return stream
 
 
-def correlation_worker(input_dict):
+def correlation_worker(input_dict: dict) -> dict:
     """worker function for multiprocessing"""
 
     # unpack input dictionary
@@ -1791,7 +1800,7 @@ def plot_stacked_trace_pick(trace: Trace, pick: Pick, pylot_parameter: PylotPara
 
 
 def plot_stations(stations_dict: dict, stations2compare: dict, coords_this: dict, corr_result: dict, trace: Trace,
-                  pick: Pick, window_title: str = None):
+                  pick: Pick, window_title: str = None) -> Optional[tuple]:
     """ Plot routine to check proximity algorithm. """
     title = trace.id
 
@@ -1864,8 +1873,7 @@ def get_data(eventdir: str, data_dir: str, headonly: bool = False) -> Stream:
     return wfdata
 
 
-def get_closest_stations(coords, stations_dict, n):
-    """ Calculate distances and return n closest stations in stations_dict for station at coords. """
+def get_station_distances(stations_dict: dict, coords: dict) -> dict:
     distances = {}
     for station_id, st_coords in stations_dict.items():
         dist = gps2dist_azimuth(coords['latitude'], coords['longitude'], st_coords['latitude'], st_coords['longitude'],
@@ -1875,28 +1883,28 @@ def get_closest_stations(coords, stations_dict, n):
             continue
         distances[dist] = station_id
 
+    return distances
+
+
+def get_closest_stations(coords: dict, stations_dict: dict, n: int) -> dict:
+    """ Calculate distances and return the n closest stations in stations_dict for station at coords. """
+    distances = get_station_distances(stations_dict, coords)
+
     closest_distances = sorted(list(distances.keys()))[:n + 1]
     closest_stations = {station: dist for dist, station in distances.items() if dist in closest_distances}
     return closest_stations
 
 
-def get_random_stations(coords, stations_dict, n):
+def get_random_stations(coords: dict, stations_dict: dict, n: int) -> dict:
     """ Calculate distances and return n randomly selected stations in stations_dict for station at coords. """
     random_keys = random.sample(list(stations_dict.keys()), n)
 
-    distances = {}
-    for station_id, st_coords in stations_dict.items():
-        dist = gps2dist_azimuth(coords['latitude'], coords['longitude'], st_coords['latitude'], st_coords['longitude'],
-                                a=6.371e6, f=0)[0]
-        # exclude same coordinate (self or other instrument at same location)
-        if dist == 0:
-            continue
-        distances[dist] = station_id
+    distances = get_station_distances(stations_dict, coords)
 
     random_stations = {station: dist for dist, station in distances.items() if station in random_keys}
     return random_stations
 
 
-def prepare_corr_params(parse_args, logger):
+def prepare_corr_params(parse_args: argparse, logger: logging.Logger) -> dict:
     with open(parse_args.params) as infile:
         parameters_dict = yaml.safe_load(infile)
 
@@ -1921,7 +1929,7 @@ def prepare_corr_params(parse_args, logger):
     return corr_params
 
 
-def init_logger():
+def init_logger() -> logging.Logger:
     logger = logging.getLogger()
     handler = logging.StreamHandler(sys.stdout)
     fhandler = logging.FileHandler('pick_corr.out')