[refactor] finished annotations (type hints)

This commit is contained in:
Marcel Paffrath 2024-08-09 16:52:32 +02:00
parent 759e7bb848
commit 176e93d833

View File

@ -11,7 +11,7 @@ import traceback
import glob
import json
from datetime import datetime
from typing import Any
from typing import Any, Optional
import numpy as np
import matplotlib.pyplot as plt
@ -19,6 +19,7 @@ import yaml
from joblib import Parallel, delayed
from obspy import read, Stream, UTCDateTime, Trace
from obspy.core.event import Event
from obspy.taup import TauPyModel
from obspy.core.event.base import WaveformStreamID, ResourceIdentifier
from obspy.core.event.origin import Pick, Origin
@ -28,6 +29,7 @@ from obspy.signal.cross_correlation import correlate, xcorr_max
from pylot.core.io.inputs import PylotParameter
from pylot.core.io.phases import picks_from_picksdict
from pylot.core.pick.autopick import autopickstation
from pylot.core.util.dataprocessing import Metadata
from pylot.core.util.utils import check4rotated
from pylot.core.util.utils import identifyPhaseID
from pylot.core.util.event import Event as PylotEvent
@ -771,7 +773,7 @@ def correlate_event(eventdir: str, pylot_parameter: PylotParameter, params: dict
return False
def remove_outliers(picks, corrected_taupy_picks, threshold):
def remove_outliers(picks: list, corrected_taupy_picks: list, threshold: float) -> list:
""" delete a pick if difference to corrected taupy pick is larger than threshold"""
n_picks = len(picks)
@ -800,7 +802,7 @@ def remove_outliers(picks, corrected_taupy_picks, threshold):
return picks
def remove_invalid_picks(picks):
def remove_invalid_picks(picks: list) -> list:
""" Remove picks without uncertainty (invalid PyLoT picks)"""
count = 0
deleted_picks_ids = []
@ -815,15 +817,15 @@ def remove_invalid_picks(picks):
return picks
def get_picks_median(picks):
def get_picks_median(picks: list) -> UTCDateTime:
return UTCDateTime(int(np.median([pick.time.timestamp for pick in picks if pick.time])))
def get_picks_mean(picks):
def get_picks_mean(picks: list) -> UTCDateTime:
return UTCDateTime(np.mean([pick.time.timestamp for pick in picks if pick.time]))
def get_corrected_taupy_picks(picks, taupypicks, all_available=False):
def get_corrected_taupy_picks(picks: list, taupypicks: list, all_available: bool = False) -> tuple:
""" get mean/median from picks taupy picks, correct latter for the difference """
def nwst_id_from_wfid(wfid):
@ -859,7 +861,7 @@ def get_corrected_taupy_picks(picks, taupypicks, all_available=False):
return taupypicks_new, median_diff
def load_stacked_trace(eventdir, min_corr_stack):
def load_stacked_trace(eventdir: str, min_corr_stack: float) -> Optional[tuple]:
# load stacked stream (miniseed)
str_stack_fpaths = glob.glob(os.path.join(eventdir, 'correlation', '*_stacked.mseed'))
if not len(str_stack_fpaths) == 1:
@ -888,7 +890,8 @@ def load_stacked_trace(eventdir, min_corr_stack):
return corr_dict, nwst_id, stacked_trace, nstack
def repick_master_trace(wfdata, trace_master, pylot_parameter, event, event_id, metadata, phase_type, corr_out_dir):
def repick_master_trace(wfdata: Stream, trace_master: Trace, pylot_parameter: PylotParameter, event: Event,
event_id: str, metadata: Metadata, phase_type: str, corr_out_dir: str) -> Optional[Pick]:
rename_lqt = phase_type == 'S'
# create an 'artificial' stream object which can be used as input for PyLoT
stream_master = modify_master_trace4pick(trace_master.copy(), wfdata, rename_lqt=rename_lqt)
@ -925,14 +928,14 @@ def create_correlation_output_dir(eventdir: str, fopts: dict, phase_type: str) -
return export_path
def create_correlation_figure_dir(correlation_out_dir):
def create_correlation_figure_dir(correlation_out_dir: str) -> str:
export_path = os.path.join(correlation_out_dir, 'figures')
if not os.path.isdir(export_path):
os.mkdir(export_path)
return export_path
def add_fpath_extension(fpath, fopts, phase):
def add_fpath_extension(fpath: str, fopts: dict, phase: str) -> str:
if fopts:
freqmin, freqmax = fopts['freqmin'], fopts['freqmax']
if freqmin:
@ -944,12 +947,12 @@ def add_fpath_extension(fpath, fopts, phase):
return fpath
def write_correlation_output(export_path, correlations_dict, correlations_dict_stacked):
def write_correlation_output(export_path: str, correlations_dict: dict, correlations_dict_stacked: dict) -> None:
write_json(correlations_dict, os.path.join(export_path, 'correlations_for_stacking.json'))
write_json(correlations_dict_stacked, os.path.join(export_path, 'correlation_results.json'))
def modify_master_trace4pick(trace_master, wfdata, rename_lqt=True):
def modify_master_trace4pick(trace_master: Trace, wfdata: Stream, rename_lqt: bool = True) -> Stream:
"""
Create an artificial Stream with master trace instead of the trace of its corresponding channel.
This is done to find metadata for correct station metadata which were modified when stacking (e.g. loc=ST)
@ -985,7 +988,8 @@ def modify_master_trace4pick(trace_master, wfdata, rename_lqt=True):
return stream_master
def export_picks(eventdir, correlations_dict, picks, taupypicks, params, phase_type, pf_extension):
def export_picks(eventdir: str, correlations_dict: dict, picks: list, taupypicks: list, params: dict,
phase_type: str, pf_extension: str) -> None:
threshold = params['export_threshold']
min_picks_export = params['min_picks_export']
# make copy so that modified picks are not exported
@ -1057,7 +1061,8 @@ def export_picks(eventdir, correlations_dict, picks, taupypicks, params, phase_t
logging.info('Wrote {} correlated picks to file {}'.format(len(event.picks), fpath))
def write_taupy_picks(event, eventdir, taupypicks, time_shift, extension='corrected_taup_times'):
def write_taupy_picks(event: Event, eventdir: str, taupypicks: list, time_shift: float,
extension: str = 'corrected_taup_times') -> None:
# make copies because both objects are being modified
event = copy.deepcopy(event)
taupypicks = copy.deepcopy(taupypicks)
@ -1079,17 +1084,17 @@ def write_taupy_picks(event, eventdir, taupypicks, time_shift, extension='correc
write_event(event, eventdir, fname)
def write_event(event, eventdir, fname):
def write_event(event: Event, eventdir: str, fname: str) -> str:
fpath = os.path.join(eventdir, fname)
event.write(fpath, format='QUAKEML')
return fpath
def get_pickfile_name(event_id, fname_extension):
def get_pickfile_name(event_id: str, fname_extension: str) -> str:
return 'PyLoT_{}_{}.xml'.format(event_id, fname_extension)
def get_pickfile_name_correlated(event_id, fopts, phase_type):
def get_pickfile_name_correlated(event_id: str, fopts: dict, phase_type: str) -> str:
fname_extension = add_fpath_extension('correlated', fopts, phase_type)
return get_pickfile_name(event_id, fname_extension)
@ -1110,8 +1115,9 @@ def get_pickfile_name_correlated(event_id, fopts, phase_type):
# return trace_this
def prepare_correlation_input(wfdata, picks, channels, trace_master, pick_master, phase_params, plot=None, fig_dir=None,
ncorr=1, wfdata_highf=None, trace_master_highf=None):
def prepare_correlation_input(wfdata: Stream, picks: list, channels: list, trace_master: Trace, pick_master: Pick,
phase_params: dict, plot: bool = None, fig_dir: str = None,
ncorr: int = 1, wfdata_highf: Stream = None, trace_master_highf: Stream = None) -> list:
# prepare input for multiprocessing worker for all 'other' picks to correlate with current master-trace
input_list = []
@ -1152,7 +1158,7 @@ def prepare_correlation_input(wfdata, picks, channels, trace_master, pick_master
def stack_mastertrace(wfdata_lowf: Stream, wfdata_highf: Stream, wfdata_raw: Stream, picks: list, params: dict,
channels: list, method: str, fig_dir: str):
channels: list, method: str, fig_dir: str) -> Optional[tuple]:
"""
Correlate all stations with the first available station given in station_list, a list containing central,
permanent, long operating, low noise stations with descending priority.
@ -1187,7 +1193,7 @@ def stack_mastertrace(wfdata_lowf: Stream, wfdata_highf: Stream, wfdata_raw: Str
dt_pre, dt_post = params['dt_stacking']
trace_master, nstack = apply_stacking(trace_master, stations4stack, wfdata_raw, picks, method=method,
check_rms=params['check_RMS'], plot=params['plot'], fig_dir=fig_dir,
do_rms_check=params['check_RMS'], plot=params['plot'], fig_dir=fig_dir,
dt_pre=dt_pre, dt_post=dt_post)
return correlations_dict, nwst_id_master, trace_master, nstack
@ -1261,7 +1267,7 @@ def iterate_correlation(wfdata_lowf: Stream, wfdata_highf: Stream, channels: lis
def apply_stacking(trace_master: Trace, stations4stack: dict, wfdata: Stream, picks: list, method: str,
check_rms: bool, dt_pre: float = 250., dt_post: float = 250., plot: bool = False,
do_rms_check: bool, dt_pre: float = 250., dt_post: float = 250., plot: bool = False,
fig_dir: str = None) -> tuple:
def check_trace_length_and_cut(trace: Trace, pick_time: UTCDateTime = None):
@ -1334,7 +1340,8 @@ def apply_stacking(trace_master: Trace, stations4stack: dict, wfdata: Stream, pi
return trace_master, count
def check_rms(traces4stack, plot=False, fig_dir=None, max_it=10, ntimes_std=5.):
def check_rms(traces4stack: list, plot: bool = False, fig_dir: str = None, max_it: int = 10,
ntimes_std: float = 5.) -> list:
rms_list = []
trace_names = []
@ -1396,14 +1403,14 @@ def plot_rms(rms_list: list, trace_names: list, mean: float, std: float, fig_dir
plt.close(fig)
def calc_rms(X):
def calc_rms(array: np.ndarray) -> float:
"""
Returns root mean square of a given array LON
"""
return np.sqrt(np.sum(np.power(X, 2)) / len(X))
return np.sqrt(np.sum(np.power(array, 2)) / len(array))
def resample_parallel(stream, freq, ncores):
def resample_parallel(stream: Stream, freq: float, ncores: int) -> Stream:
input_list = [{'trace': trace, 'freq': freq} for trace in stream.traces]
parallel = Parallel(n_jobs=ncores)
@ -1417,7 +1424,8 @@ def resample_parallel(stream, freq, ncores):
return stream
def rotate_stream(stream, metadata, origin, stations_dict, channels, inclination, ncores):
def rotate_stream(stream: Stream, metadata: Metadata, origin: Origin, stations_dict: dict, channels: list,
inclination: float, ncores: int) -> Stream:
""" Rotate stream to LQT. To do this all traces have to be cut to equal lenths"""
input_list = []
cut_stream_to_same_length(stream)
@ -1503,7 +1511,8 @@ def unpack_result(result: list) -> dict:
return result_dict
def get_taupy_picks(origin: Origin, stations_dict, pylot_parameter, phase_type, ncores=None):
def get_taupy_picks(origin: Origin, stations_dict: dict, pylot_parameter: PylotParameter, phase_type: str,
ncores: int = None) -> list:
input_list = []
taup_phases = pylot_parameter['taup_phases'].split(',')
@ -1531,7 +1540,7 @@ def get_taupy_picks(origin: Origin, stations_dict, pylot_parameter, phase_type,
return taupy_picks
def create_artificial_picks(taupy_result):
def create_artificial_picks(taupy_result: dict) -> list:
artificial_picks = []
for nwst_id, taupy_dict in taupy_result.items():
nw, st = nwst_id.split('.')
@ -1543,14 +1552,14 @@ def create_artificial_picks(taupy_result):
return artificial_picks
def check_taupy_phases(taupy_result):
def check_taupy_phases(taupy_result: dict) -> None:
test_phase_name = list(taupy_result.values())[0]['phase_name']
phase_names_equal = [item['phase_name'] == test_phase_name for item in taupy_result.values()]
if not all(phase_names_equal):
logging.info('Different first arriving phases detected in TauPy phases for this event.')
def taupy_worker(input_dict):
def taupy_worker(input_dict: dict) -> dict:
model = input_dict['model']
taupy_input = input_dict['taupy_input']
source_time = input_dict['source_time']
@ -1570,7 +1579,7 @@ def taupy_worker(input_dict):
return output_dict
def resample_worker(input_dict):
def resample_worker(input_dict: dict) -> Trace:
trace = input_dict['trace']
freq = input_dict['freq']
if freq == trace.stats.sampling_rate:
@ -1578,7 +1587,7 @@ def resample_worker(input_dict):
return trace.resample(freq, no_filter=freq > trace.stats.sampling_rate)
def rotation_worker(input_dict):
def rotation_worker(input_dict: dict) -> Optional[Stream]:
stream = input_dict['stream']
tstart = max([tr.stats.starttime for tr in stream])
tend = min([tr.stats.endtime for tr in stream])
@ -1595,7 +1604,7 @@ def rotation_worker(input_dict):
return stream
def correlation_worker(input_dict):
def correlation_worker(input_dict: dict) -> dict:
"""worker function for multiprocessing"""
# unpack input dictionary
@ -1791,7 +1800,7 @@ def plot_stacked_trace_pick(trace: Trace, pick: Pick, pylot_parameter: PylotPara
def plot_stations(stations_dict: dict, stations2compare: dict, coords_this: dict, corr_result: dict, trace: Trace,
pick: Pick, window_title: str = None):
pick: Pick, window_title: str = None) -> Optional[tuple]:
""" Plot routine to check proximity algorithm. """
title = trace.id
@ -1864,8 +1873,7 @@ def get_data(eventdir: str, data_dir: str, headonly: bool = False) -> Stream:
return wfdata
def get_closest_stations(coords, stations_dict, n):
""" Calculate distances and return n closest stations in stations_dict for station at coords. """
def get_station_distances(stations_dict: dict, coords: dict) -> dict:
distances = {}
for station_id, st_coords in stations_dict.items():
dist = gps2dist_azimuth(coords['latitude'], coords['longitude'], st_coords['latitude'], st_coords['longitude'],
@ -1875,28 +1883,28 @@ def get_closest_stations(coords, stations_dict, n):
continue
distances[dist] = station_id
return distances
def get_closest_stations(coords: dict, stations_dict: dict, n: int) -> dict:
""" Calculate distances and return the n closest stations in stations_dict for station at coords. """
distances = get_station_distances(stations_dict, coords)
closest_distances = sorted(list(distances.keys()))[:n + 1]
closest_stations = {station: dist for dist, station in distances.items() if dist in closest_distances}
return closest_stations
def get_random_stations(coords, stations_dict, n):
def get_random_stations(coords: dict, stations_dict: dict, n: int) -> dict:
""" Calculate distances and return n randomly selected stations in stations_dict for station at coords. """
random_keys = random.sample(list(stations_dict.keys()), n)
distances = {}
for station_id, st_coords in stations_dict.items():
dist = gps2dist_azimuth(coords['latitude'], coords['longitude'], st_coords['latitude'], st_coords['longitude'],
a=6.371e6, f=0)[0]
# exclude same coordinate (self or other instrument at same location)
if dist == 0:
continue
distances[dist] = station_id
distances = get_station_distances(stations_dict, coords)
random_stations = {station: dist for dist, station in distances.items() if station in random_keys}
return random_stations
def prepare_corr_params(parse_args, logger):
def prepare_corr_params(parse_args: argparse, logger: logging.Logger) -> dict:
with open(parse_args.params) as infile:
parameters_dict = yaml.safe_load(infile)
@ -1921,7 +1929,7 @@ def prepare_corr_params(parse_args, logger):
return corr_params
def init_logger():
def init_logger() -> logging.Logger:
logger = logging.getLogger()
handler = logging.StreamHandler(sys.stdout)
fhandler = logging.FileHandler('pick_corr.out')