[refactor] finished annotations (type hints)
This commit is contained in:
parent
759e7bb848
commit
176e93d833
@ -11,7 +11,7 @@ import traceback
|
||||
import glob
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
@ -19,6 +19,7 @@ import yaml
|
||||
|
||||
from joblib import Parallel, delayed
|
||||
from obspy import read, Stream, UTCDateTime, Trace
|
||||
from obspy.core.event import Event
|
||||
from obspy.taup import TauPyModel
|
||||
from obspy.core.event.base import WaveformStreamID, ResourceIdentifier
|
||||
from obspy.core.event.origin import Pick, Origin
|
||||
@ -28,6 +29,7 @@ from obspy.signal.cross_correlation import correlate, xcorr_max
|
||||
from pylot.core.io.inputs import PylotParameter
|
||||
from pylot.core.io.phases import picks_from_picksdict
|
||||
from pylot.core.pick.autopick import autopickstation
|
||||
from pylot.core.util.dataprocessing import Metadata
|
||||
from pylot.core.util.utils import check4rotated
|
||||
from pylot.core.util.utils import identifyPhaseID
|
||||
from pylot.core.util.event import Event as PylotEvent
|
||||
@ -771,7 +773,7 @@ def correlate_event(eventdir: str, pylot_parameter: PylotParameter, params: dict
|
||||
return False
|
||||
|
||||
|
||||
def remove_outliers(picks, corrected_taupy_picks, threshold):
|
||||
def remove_outliers(picks: list, corrected_taupy_picks: list, threshold: float) -> list:
|
||||
""" delete a pick if difference to corrected taupy pick is larger than threshold"""
|
||||
|
||||
n_picks = len(picks)
|
||||
@ -800,7 +802,7 @@ def remove_outliers(picks, corrected_taupy_picks, threshold):
|
||||
return picks
|
||||
|
||||
|
||||
def remove_invalid_picks(picks):
|
||||
def remove_invalid_picks(picks: list) -> list:
|
||||
""" Remove picks without uncertainty (invalid PyLoT picks)"""
|
||||
count = 0
|
||||
deleted_picks_ids = []
|
||||
@ -815,15 +817,15 @@ def remove_invalid_picks(picks):
|
||||
return picks
|
||||
|
||||
|
||||
def get_picks_median(picks):
|
||||
def get_picks_median(picks: list) -> UTCDateTime:
|
||||
return UTCDateTime(int(np.median([pick.time.timestamp for pick in picks if pick.time])))
|
||||
|
||||
|
||||
def get_picks_mean(picks):
|
||||
def get_picks_mean(picks: list) -> UTCDateTime:
|
||||
return UTCDateTime(np.mean([pick.time.timestamp for pick in picks if pick.time]))
|
||||
|
||||
|
||||
def get_corrected_taupy_picks(picks, taupypicks, all_available=False):
|
||||
def get_corrected_taupy_picks(picks: list, taupypicks: list, all_available: bool = False) -> tuple:
|
||||
""" get mean/median from picks taupy picks, correct latter for the difference """
|
||||
|
||||
def nwst_id_from_wfid(wfid):
|
||||
@ -859,7 +861,7 @@ def get_corrected_taupy_picks(picks, taupypicks, all_available=False):
|
||||
return taupypicks_new, median_diff
|
||||
|
||||
|
||||
def load_stacked_trace(eventdir, min_corr_stack):
|
||||
def load_stacked_trace(eventdir: str, min_corr_stack: float) -> Optional[tuple]:
|
||||
# load stacked stream (miniseed)
|
||||
str_stack_fpaths = glob.glob(os.path.join(eventdir, 'correlation', '*_stacked.mseed'))
|
||||
if not len(str_stack_fpaths) == 1:
|
||||
@ -888,7 +890,8 @@ def load_stacked_trace(eventdir, min_corr_stack):
|
||||
return corr_dict, nwst_id, stacked_trace, nstack
|
||||
|
||||
|
||||
def repick_master_trace(wfdata, trace_master, pylot_parameter, event, event_id, metadata, phase_type, corr_out_dir):
|
||||
def repick_master_trace(wfdata: Stream, trace_master: Trace, pylot_parameter: PylotParameter, event: Event,
|
||||
event_id: str, metadata: Metadata, phase_type: str, corr_out_dir: str) -> Optional[Pick]:
|
||||
rename_lqt = phase_type == 'S'
|
||||
# create an 'artificial' stream object which can be used as input for PyLoT
|
||||
stream_master = modify_master_trace4pick(trace_master.copy(), wfdata, rename_lqt=rename_lqt)
|
||||
@ -925,14 +928,14 @@ def create_correlation_output_dir(eventdir: str, fopts: dict, phase_type: str) -
|
||||
return export_path
|
||||
|
||||
|
||||
def create_correlation_figure_dir(correlation_out_dir):
|
||||
def create_correlation_figure_dir(correlation_out_dir: str) -> str:
|
||||
export_path = os.path.join(correlation_out_dir, 'figures')
|
||||
if not os.path.isdir(export_path):
|
||||
os.mkdir(export_path)
|
||||
return export_path
|
||||
|
||||
|
||||
def add_fpath_extension(fpath, fopts, phase):
|
||||
def add_fpath_extension(fpath: str, fopts: dict, phase: str) -> str:
|
||||
if fopts:
|
||||
freqmin, freqmax = fopts['freqmin'], fopts['freqmax']
|
||||
if freqmin:
|
||||
@ -944,12 +947,12 @@ def add_fpath_extension(fpath, fopts, phase):
|
||||
return fpath
|
||||
|
||||
|
||||
def write_correlation_output(export_path, correlations_dict, correlations_dict_stacked):
|
||||
def write_correlation_output(export_path: str, correlations_dict: dict, correlations_dict_stacked: dict) -> None:
|
||||
write_json(correlations_dict, os.path.join(export_path, 'correlations_for_stacking.json'))
|
||||
write_json(correlations_dict_stacked, os.path.join(export_path, 'correlation_results.json'))
|
||||
|
||||
|
||||
def modify_master_trace4pick(trace_master, wfdata, rename_lqt=True):
|
||||
def modify_master_trace4pick(trace_master: Trace, wfdata: Stream, rename_lqt: bool = True) -> Stream:
|
||||
"""
|
||||
Create an artificial Stream with master trace instead of the trace of its corresponding channel.
|
||||
This is done to find metadata for correct station metadata which were modified when stacking (e.g. loc=ST)
|
||||
@ -985,7 +988,8 @@ def modify_master_trace4pick(trace_master, wfdata, rename_lqt=True):
|
||||
return stream_master
|
||||
|
||||
|
||||
def export_picks(eventdir, correlations_dict, picks, taupypicks, params, phase_type, pf_extension):
|
||||
def export_picks(eventdir: str, correlations_dict: dict, picks: list, taupypicks: list, params: dict,
|
||||
phase_type: str, pf_extension: str) -> None:
|
||||
threshold = params['export_threshold']
|
||||
min_picks_export = params['min_picks_export']
|
||||
# make copy so that modified picks are not exported
|
||||
@ -1057,7 +1061,8 @@ def export_picks(eventdir, correlations_dict, picks, taupypicks, params, phase_t
|
||||
logging.info('Wrote {} correlated picks to file {}'.format(len(event.picks), fpath))
|
||||
|
||||
|
||||
def write_taupy_picks(event, eventdir, taupypicks, time_shift, extension='corrected_taup_times'):
|
||||
def write_taupy_picks(event: Event, eventdir: str, taupypicks: list, time_shift: float,
|
||||
extension: str = 'corrected_taup_times') -> None:
|
||||
# make copies because both objects are being modified
|
||||
event = copy.deepcopy(event)
|
||||
taupypicks = copy.deepcopy(taupypicks)
|
||||
@ -1079,17 +1084,17 @@ def write_taupy_picks(event, eventdir, taupypicks, time_shift, extension='correc
|
||||
write_event(event, eventdir, fname)
|
||||
|
||||
|
||||
def write_event(event, eventdir, fname):
|
||||
def write_event(event: Event, eventdir: str, fname: str) -> str:
|
||||
fpath = os.path.join(eventdir, fname)
|
||||
event.write(fpath, format='QUAKEML')
|
||||
return fpath
|
||||
|
||||
|
||||
def get_pickfile_name(event_id, fname_extension):
|
||||
def get_pickfile_name(event_id: str, fname_extension: str) -> str:
|
||||
return 'PyLoT_{}_{}.xml'.format(event_id, fname_extension)
|
||||
|
||||
|
||||
def get_pickfile_name_correlated(event_id, fopts, phase_type):
|
||||
def get_pickfile_name_correlated(event_id: str, fopts: dict, phase_type: str) -> str:
|
||||
fname_extension = add_fpath_extension('correlated', fopts, phase_type)
|
||||
return get_pickfile_name(event_id, fname_extension)
|
||||
|
||||
@ -1110,8 +1115,9 @@ def get_pickfile_name_correlated(event_id, fopts, phase_type):
|
||||
# return trace_this
|
||||
|
||||
|
||||
def prepare_correlation_input(wfdata, picks, channels, trace_master, pick_master, phase_params, plot=None, fig_dir=None,
|
||||
ncorr=1, wfdata_highf=None, trace_master_highf=None):
|
||||
def prepare_correlation_input(wfdata: Stream, picks: list, channels: list, trace_master: Trace, pick_master: Pick,
|
||||
phase_params: dict, plot: bool = None, fig_dir: str = None,
|
||||
ncorr: int = 1, wfdata_highf: Stream = None, trace_master_highf: Stream = None) -> list:
|
||||
# prepare input for multiprocessing worker for all 'other' picks to correlate with current master-trace
|
||||
input_list = []
|
||||
|
||||
@ -1152,7 +1158,7 @@ def prepare_correlation_input(wfdata, picks, channels, trace_master, pick_master
|
||||
|
||||
|
||||
def stack_mastertrace(wfdata_lowf: Stream, wfdata_highf: Stream, wfdata_raw: Stream, picks: list, params: dict,
|
||||
channels: list, method: str, fig_dir: str):
|
||||
channels: list, method: str, fig_dir: str) -> Optional[tuple]:
|
||||
"""
|
||||
Correlate all stations with the first available station given in station_list, a list containing central,
|
||||
permanent, long operating, low noise stations with descending priority.
|
||||
@ -1187,7 +1193,7 @@ def stack_mastertrace(wfdata_lowf: Stream, wfdata_highf: Stream, wfdata_raw: Str
|
||||
|
||||
dt_pre, dt_post = params['dt_stacking']
|
||||
trace_master, nstack = apply_stacking(trace_master, stations4stack, wfdata_raw, picks, method=method,
|
||||
check_rms=params['check_RMS'], plot=params['plot'], fig_dir=fig_dir,
|
||||
do_rms_check=params['check_RMS'], plot=params['plot'], fig_dir=fig_dir,
|
||||
dt_pre=dt_pre, dt_post=dt_post)
|
||||
|
||||
return correlations_dict, nwst_id_master, trace_master, nstack
|
||||
@ -1261,7 +1267,7 @@ def iterate_correlation(wfdata_lowf: Stream, wfdata_highf: Stream, channels: lis
|
||||
|
||||
|
||||
def apply_stacking(trace_master: Trace, stations4stack: dict, wfdata: Stream, picks: list, method: str,
|
||||
check_rms: bool, dt_pre: float = 250., dt_post: float = 250., plot: bool = False,
|
||||
do_rms_check: bool, dt_pre: float = 250., dt_post: float = 250., plot: bool = False,
|
||||
fig_dir: str = None) -> tuple:
|
||||
|
||||
def check_trace_length_and_cut(trace: Trace, pick_time: UTCDateTime = None):
|
||||
@ -1334,7 +1340,8 @@ def apply_stacking(trace_master: Trace, stations4stack: dict, wfdata: Stream, pi
|
||||
return trace_master, count
|
||||
|
||||
|
||||
def check_rms(traces4stack, plot=False, fig_dir=None, max_it=10, ntimes_std=5.):
|
||||
def check_rms(traces4stack: list, plot: bool = False, fig_dir: str = None, max_it: int = 10,
|
||||
ntimes_std: float = 5.) -> list:
|
||||
rms_list = []
|
||||
trace_names = []
|
||||
|
||||
@ -1396,14 +1403,14 @@ def plot_rms(rms_list: list, trace_names: list, mean: float, std: float, fig_dir
|
||||
plt.close(fig)
|
||||
|
||||
|
||||
def calc_rms(X):
|
||||
def calc_rms(array: np.ndarray) -> float:
|
||||
"""
|
||||
Returns root mean square of a given array LON
|
||||
"""
|
||||
return np.sqrt(np.sum(np.power(X, 2)) / len(X))
|
||||
return np.sqrt(np.sum(np.power(array, 2)) / len(array))
|
||||
|
||||
|
||||
def resample_parallel(stream, freq, ncores):
|
||||
def resample_parallel(stream: Stream, freq: float, ncores: int) -> Stream:
|
||||
input_list = [{'trace': trace, 'freq': freq} for trace in stream.traces]
|
||||
|
||||
parallel = Parallel(n_jobs=ncores)
|
||||
@ -1417,7 +1424,8 @@ def resample_parallel(stream, freq, ncores):
|
||||
return stream
|
||||
|
||||
|
||||
def rotate_stream(stream, metadata, origin, stations_dict, channels, inclination, ncores):
|
||||
def rotate_stream(stream: Stream, metadata: Metadata, origin: Origin, stations_dict: dict, channels: list,
|
||||
inclination: float, ncores: int) -> Stream:
|
||||
""" Rotate stream to LQT. To do this all traces have to be cut to equal lenths"""
|
||||
input_list = []
|
||||
cut_stream_to_same_length(stream)
|
||||
@ -1503,7 +1511,8 @@ def unpack_result(result: list) -> dict:
|
||||
return result_dict
|
||||
|
||||
|
||||
def get_taupy_picks(origin: Origin, stations_dict, pylot_parameter, phase_type, ncores=None):
|
||||
def get_taupy_picks(origin: Origin, stations_dict: dict, pylot_parameter: PylotParameter, phase_type: str,
|
||||
ncores: int = None) -> list:
|
||||
input_list = []
|
||||
|
||||
taup_phases = pylot_parameter['taup_phases'].split(',')
|
||||
@ -1531,7 +1540,7 @@ def get_taupy_picks(origin: Origin, stations_dict, pylot_parameter, phase_type,
|
||||
return taupy_picks
|
||||
|
||||
|
||||
def create_artificial_picks(taupy_result):
|
||||
def create_artificial_picks(taupy_result: dict) -> list:
|
||||
artificial_picks = []
|
||||
for nwst_id, taupy_dict in taupy_result.items():
|
||||
nw, st = nwst_id.split('.')
|
||||
@ -1543,14 +1552,14 @@ def create_artificial_picks(taupy_result):
|
||||
return artificial_picks
|
||||
|
||||
|
||||
def check_taupy_phases(taupy_result):
|
||||
def check_taupy_phases(taupy_result: dict) -> None:
|
||||
test_phase_name = list(taupy_result.values())[0]['phase_name']
|
||||
phase_names_equal = [item['phase_name'] == test_phase_name for item in taupy_result.values()]
|
||||
if not all(phase_names_equal):
|
||||
logging.info('Different first arriving phases detected in TauPy phases for this event.')
|
||||
|
||||
|
||||
def taupy_worker(input_dict):
|
||||
def taupy_worker(input_dict: dict) -> dict:
|
||||
model = input_dict['model']
|
||||
taupy_input = input_dict['taupy_input']
|
||||
source_time = input_dict['source_time']
|
||||
@ -1570,7 +1579,7 @@ def taupy_worker(input_dict):
|
||||
return output_dict
|
||||
|
||||
|
||||
def resample_worker(input_dict):
|
||||
def resample_worker(input_dict: dict) -> Trace:
|
||||
trace = input_dict['trace']
|
||||
freq = input_dict['freq']
|
||||
if freq == trace.stats.sampling_rate:
|
||||
@ -1578,7 +1587,7 @@ def resample_worker(input_dict):
|
||||
return trace.resample(freq, no_filter=freq > trace.stats.sampling_rate)
|
||||
|
||||
|
||||
def rotation_worker(input_dict):
|
||||
def rotation_worker(input_dict: dict) -> Optional[Stream]:
|
||||
stream = input_dict['stream']
|
||||
tstart = max([tr.stats.starttime for tr in stream])
|
||||
tend = min([tr.stats.endtime for tr in stream])
|
||||
@ -1595,7 +1604,7 @@ def rotation_worker(input_dict):
|
||||
return stream
|
||||
|
||||
|
||||
def correlation_worker(input_dict):
|
||||
def correlation_worker(input_dict: dict) -> dict:
|
||||
"""worker function for multiprocessing"""
|
||||
|
||||
# unpack input dictionary
|
||||
@ -1791,7 +1800,7 @@ def plot_stacked_trace_pick(trace: Trace, pick: Pick, pylot_parameter: PylotPara
|
||||
|
||||
|
||||
def plot_stations(stations_dict: dict, stations2compare: dict, coords_this: dict, corr_result: dict, trace: Trace,
|
||||
pick: Pick, window_title: str = None):
|
||||
pick: Pick, window_title: str = None) -> Optional[tuple]:
|
||||
""" Plot routine to check proximity algorithm. """
|
||||
|
||||
title = trace.id
|
||||
@ -1864,8 +1873,7 @@ def get_data(eventdir: str, data_dir: str, headonly: bool = False) -> Stream:
|
||||
return wfdata
|
||||
|
||||
|
||||
def get_closest_stations(coords, stations_dict, n):
|
||||
""" Calculate distances and return n closest stations in stations_dict for station at coords. """
|
||||
def get_station_distances(stations_dict: dict, coords: dict) -> dict:
|
||||
distances = {}
|
||||
for station_id, st_coords in stations_dict.items():
|
||||
dist = gps2dist_azimuth(coords['latitude'], coords['longitude'], st_coords['latitude'], st_coords['longitude'],
|
||||
@ -1875,28 +1883,28 @@ def get_closest_stations(coords, stations_dict, n):
|
||||
continue
|
||||
distances[dist] = station_id
|
||||
|
||||
return distances
|
||||
|
||||
|
||||
def get_closest_stations(coords: dict, stations_dict: dict, n: int) -> dict:
|
||||
""" Calculate distances and return the n closest stations in stations_dict for station at coords. """
|
||||
distances = get_station_distances(stations_dict, coords)
|
||||
|
||||
closest_distances = sorted(list(distances.keys()))[:n + 1]
|
||||
closest_stations = {station: dist for dist, station in distances.items() if dist in closest_distances}
|
||||
return closest_stations
|
||||
|
||||
|
||||
def get_random_stations(coords, stations_dict, n):
|
||||
def get_random_stations(coords: dict, stations_dict: dict, n: int) -> dict:
|
||||
""" Calculate distances and return n randomly selected stations in stations_dict for station at coords. """
|
||||
random_keys = random.sample(list(stations_dict.keys()), n)
|
||||
distances = {}
|
||||
for station_id, st_coords in stations_dict.items():
|
||||
dist = gps2dist_azimuth(coords['latitude'], coords['longitude'], st_coords['latitude'], st_coords['longitude'],
|
||||
a=6.371e6, f=0)[0]
|
||||
# exclude same coordinate (self or other instrument at same location)
|
||||
if dist == 0:
|
||||
continue
|
||||
distances[dist] = station_id
|
||||
distances = get_station_distances(stations_dict, coords)
|
||||
|
||||
random_stations = {station: dist for dist, station in distances.items() if station in random_keys}
|
||||
return random_stations
|
||||
|
||||
|
||||
def prepare_corr_params(parse_args, logger):
|
||||
def prepare_corr_params(parse_args: argparse, logger: logging.Logger) -> dict:
|
||||
with open(parse_args.params) as infile:
|
||||
parameters_dict = yaml.safe_load(infile)
|
||||
|
||||
@ -1921,7 +1929,7 @@ def prepare_corr_params(parse_args, logger):
|
||||
return corr_params
|
||||
|
||||
|
||||
def init_logger():
|
||||
def init_logger() -> logging.Logger:
|
||||
logger = logging.getLogger()
|
||||
handler = logging.StreamHandler(sys.stdout)
|
||||
fhandler = logging.FileHandler('pick_corr.out')
|
||||
|
Loading…
Reference in New Issue
Block a user