[refactor] finished annotations (type hints)
This commit is contained in:
		
							parent
							
								
									759e7bb848
								
							
						
					
					
						commit
						176e93d833
					
				| @ -11,7 +11,7 @@ import traceback | ||||
| import glob | ||||
| import json | ||||
| from datetime import datetime | ||||
| from typing import Any | ||||
| from typing import Any, Optional | ||||
| 
 | ||||
| import numpy as np | ||||
| import matplotlib.pyplot as plt | ||||
| @ -19,6 +19,7 @@ import yaml | ||||
| 
 | ||||
| from joblib import Parallel, delayed | ||||
| from obspy import read, Stream, UTCDateTime, Trace | ||||
| from obspy.core.event import Event | ||||
| from obspy.taup import TauPyModel | ||||
| from obspy.core.event.base import WaveformStreamID, ResourceIdentifier | ||||
| from obspy.core.event.origin import Pick, Origin | ||||
| @ -28,6 +29,7 @@ from obspy.signal.cross_correlation import correlate, xcorr_max | ||||
| from pylot.core.io.inputs import PylotParameter | ||||
| from pylot.core.io.phases import picks_from_picksdict | ||||
| from pylot.core.pick.autopick import autopickstation | ||||
| from pylot.core.util.dataprocessing import Metadata | ||||
| from pylot.core.util.utils import check4rotated | ||||
| from pylot.core.util.utils import identifyPhaseID | ||||
| from pylot.core.util.event import Event as PylotEvent | ||||
| @ -771,7 +773,7 @@ def correlate_event(eventdir: str, pylot_parameter: PylotParameter, params: dict | ||||
|         return False | ||||
| 
 | ||||
| 
 | ||||
| def remove_outliers(picks, corrected_taupy_picks, threshold): | ||||
| def remove_outliers(picks: list, corrected_taupy_picks: list, threshold: float) -> list: | ||||
|     """ delete a pick if difference to corrected taupy pick is larger than threshold""" | ||||
| 
 | ||||
|     n_picks = len(picks) | ||||
| @ -800,7 +802,7 @@ def remove_outliers(picks, corrected_taupy_picks, threshold): | ||||
|     return picks | ||||
| 
 | ||||
| 
 | ||||
| def remove_invalid_picks(picks): | ||||
| def remove_invalid_picks(picks: list) -> list: | ||||
|     """ Remove picks without uncertainty (invalid PyLoT picks)""" | ||||
|     count = 0 | ||||
|     deleted_picks_ids = [] | ||||
| @ -815,15 +817,15 @@ def remove_invalid_picks(picks): | ||||
|     return picks | ||||
| 
 | ||||
| 
 | ||||
| def get_picks_median(picks): | ||||
| def get_picks_median(picks: list) -> UTCDateTime: | ||||
|     return UTCDateTime(int(np.median([pick.time.timestamp for pick in picks if pick.time]))) | ||||
| 
 | ||||
| 
 | ||||
| def get_picks_mean(picks): | ||||
| def get_picks_mean(picks: list) -> UTCDateTime: | ||||
|     return UTCDateTime(np.mean([pick.time.timestamp for pick in picks if pick.time])) | ||||
| 
 | ||||
| 
 | ||||
| def get_corrected_taupy_picks(picks, taupypicks, all_available=False): | ||||
| def get_corrected_taupy_picks(picks: list, taupypicks: list, all_available: bool = False) -> tuple: | ||||
|     """ get mean/median from picks taupy picks, correct latter for the difference """ | ||||
| 
 | ||||
|     def nwst_id_from_wfid(wfid): | ||||
| @ -859,7 +861,7 @@ def get_corrected_taupy_picks(picks, taupypicks, all_available=False): | ||||
|     return taupypicks_new, median_diff | ||||
| 
 | ||||
| 
 | ||||
| def load_stacked_trace(eventdir, min_corr_stack): | ||||
| def load_stacked_trace(eventdir: str, min_corr_stack: float) -> Optional[tuple]: | ||||
|     # load stacked stream (miniseed) | ||||
|     str_stack_fpaths = glob.glob(os.path.join(eventdir, 'correlation', '*_stacked.mseed')) | ||||
|     if not len(str_stack_fpaths) == 1: | ||||
| @ -888,7 +890,8 @@ def load_stacked_trace(eventdir, min_corr_stack): | ||||
|     return corr_dict, nwst_id, stacked_trace, nstack | ||||
| 
 | ||||
| 
 | ||||
| def repick_master_trace(wfdata, trace_master, pylot_parameter, event, event_id, metadata, phase_type, corr_out_dir): | ||||
| def repick_master_trace(wfdata: Stream, trace_master: Trace, pylot_parameter: PylotParameter, event: Event, | ||||
|                         event_id: str, metadata: Metadata, phase_type: str, corr_out_dir: str) -> Optional[Pick]: | ||||
|     rename_lqt = phase_type == 'S' | ||||
|     # create an 'artificial' stream object which can be used as input for PyLoT | ||||
|     stream_master = modify_master_trace4pick(trace_master.copy(), wfdata, rename_lqt=rename_lqt) | ||||
| @ -925,14 +928,14 @@ def create_correlation_output_dir(eventdir: str, fopts: dict, phase_type: str) - | ||||
|     return export_path | ||||
| 
 | ||||
| 
 | ||||
| def create_correlation_figure_dir(correlation_out_dir): | ||||
| def create_correlation_figure_dir(correlation_out_dir: str) -> str: | ||||
|     export_path = os.path.join(correlation_out_dir, 'figures') | ||||
|     if not os.path.isdir(export_path): | ||||
|         os.mkdir(export_path) | ||||
|     return export_path | ||||
| 
 | ||||
| 
 | ||||
| def add_fpath_extension(fpath, fopts, phase): | ||||
| def add_fpath_extension(fpath: str, fopts: dict, phase: str) -> str: | ||||
|     if fopts: | ||||
|         freqmin, freqmax = fopts['freqmin'], fopts['freqmax'] | ||||
|         if freqmin: | ||||
| @ -944,12 +947,12 @@ def add_fpath_extension(fpath, fopts, phase): | ||||
|     return fpath | ||||
| 
 | ||||
| 
 | ||||
| def write_correlation_output(export_path, correlations_dict, correlations_dict_stacked): | ||||
| def write_correlation_output(export_path: str, correlations_dict: dict, correlations_dict_stacked: dict) -> None: | ||||
|     write_json(correlations_dict, os.path.join(export_path, 'correlations_for_stacking.json')) | ||||
|     write_json(correlations_dict_stacked, os.path.join(export_path, 'correlation_results.json')) | ||||
| 
 | ||||
| 
 | ||||
| def modify_master_trace4pick(trace_master, wfdata, rename_lqt=True): | ||||
| def modify_master_trace4pick(trace_master: Trace, wfdata: Stream, rename_lqt: bool = True) -> Stream: | ||||
|     """ | ||||
|     Create an artificial Stream with master trace instead of the trace of its corresponding channel. | ||||
|     This is done to find metadata for correct station metadata which were modified when stacking (e.g. loc=ST) | ||||
| @ -985,7 +988,8 @@ def modify_master_trace4pick(trace_master, wfdata, rename_lqt=True): | ||||
|     return stream_master | ||||
| 
 | ||||
| 
 | ||||
| def export_picks(eventdir, correlations_dict, picks, taupypicks, params, phase_type, pf_extension): | ||||
| def export_picks(eventdir: str, correlations_dict: dict, picks: list, taupypicks: list, params: dict, | ||||
|                  phase_type: str, pf_extension: str) -> None: | ||||
|     threshold = params['export_threshold'] | ||||
|     min_picks_export = params['min_picks_export'] | ||||
|     # make copy so that modified picks are not exported | ||||
| @ -1057,7 +1061,8 @@ def export_picks(eventdir, correlations_dict, picks, taupypicks, params, phase_t | ||||
|     logging.info('Wrote {} correlated picks to file {}'.format(len(event.picks), fpath)) | ||||
| 
 | ||||
| 
 | ||||
| def write_taupy_picks(event, eventdir, taupypicks, time_shift, extension='corrected_taup_times'): | ||||
| def write_taupy_picks(event: Event, eventdir: str, taupypicks: list, time_shift: float, | ||||
|                       extension: str = 'corrected_taup_times') -> None: | ||||
|     # make copies because both objects are being modified | ||||
|     event = copy.deepcopy(event) | ||||
|     taupypicks = copy.deepcopy(taupypicks) | ||||
| @ -1079,17 +1084,17 @@ def write_taupy_picks(event, eventdir, taupypicks, time_shift, extension='correc | ||||
|     write_event(event, eventdir, fname) | ||||
| 
 | ||||
| 
 | ||||
| def write_event(event, eventdir, fname): | ||||
| def write_event(event: Event, eventdir: str, fname: str) -> str: | ||||
|     fpath = os.path.join(eventdir, fname) | ||||
|     event.write(fpath, format='QUAKEML') | ||||
|     return fpath | ||||
| 
 | ||||
| 
 | ||||
| def get_pickfile_name(event_id, fname_extension): | ||||
| def get_pickfile_name(event_id: str, fname_extension: str) -> str: | ||||
|     return 'PyLoT_{}_{}.xml'.format(event_id, fname_extension) | ||||
| 
 | ||||
| 
 | ||||
| def get_pickfile_name_correlated(event_id, fopts, phase_type): | ||||
| def get_pickfile_name_correlated(event_id: str, fopts: dict, phase_type: str) -> str: | ||||
|     fname_extension = add_fpath_extension('correlated', fopts, phase_type) | ||||
|     return get_pickfile_name(event_id, fname_extension) | ||||
| 
 | ||||
| @ -1110,8 +1115,9 @@ def get_pickfile_name_correlated(event_id, fopts, phase_type): | ||||
| #     return trace_this | ||||
| 
 | ||||
| 
 | ||||
| def prepare_correlation_input(wfdata, picks, channels, trace_master, pick_master, phase_params, plot=None, fig_dir=None, | ||||
|                               ncorr=1, wfdata_highf=None, trace_master_highf=None): | ||||
| def prepare_correlation_input(wfdata: Stream, picks: list, channels: list, trace_master: Trace, pick_master: Pick, | ||||
|                               phase_params: dict, plot: bool = None, fig_dir: str = None, | ||||
|                               ncorr: int = 1, wfdata_highf: Stream = None, trace_master_highf: Stream = None) -> list: | ||||
|     # prepare input for multiprocessing worker for all 'other' picks to correlate with current master-trace | ||||
|     input_list = [] | ||||
| 
 | ||||
| @ -1152,7 +1158,7 @@ def prepare_correlation_input(wfdata, picks, channels, trace_master, pick_master | ||||
| 
 | ||||
| 
 | ||||
| def stack_mastertrace(wfdata_lowf: Stream, wfdata_highf: Stream, wfdata_raw: Stream, picks: list, params: dict, | ||||
|                       channels: list, method: str, fig_dir: str): | ||||
|                       channels: list, method: str, fig_dir: str) -> Optional[tuple]: | ||||
|     """ | ||||
|     Correlate all stations with the first available station given in station_list, a list containing central, | ||||
|     permanent, long operating, low noise stations with descending priority. | ||||
| @ -1187,7 +1193,7 @@ def stack_mastertrace(wfdata_lowf: Stream, wfdata_highf: Stream, wfdata_raw: Str | ||||
| 
 | ||||
|     dt_pre, dt_post = params['dt_stacking'] | ||||
|     trace_master, nstack = apply_stacking(trace_master, stations4stack, wfdata_raw, picks, method=method, | ||||
|                                           check_rms=params['check_RMS'], plot=params['plot'], fig_dir=fig_dir, | ||||
|                                           do_rms_check=params['check_RMS'], plot=params['plot'], fig_dir=fig_dir, | ||||
|                                           dt_pre=dt_pre, dt_post=dt_post) | ||||
| 
 | ||||
|     return correlations_dict, nwst_id_master, trace_master, nstack | ||||
| @ -1261,7 +1267,7 @@ def iterate_correlation(wfdata_lowf: Stream, wfdata_highf: Stream, channels: lis | ||||
| 
 | ||||
| 
 | ||||
| def apply_stacking(trace_master: Trace, stations4stack: dict, wfdata: Stream, picks: list, method: str, | ||||
|                    check_rms: bool, dt_pre: float = 250., dt_post: float = 250., plot: bool = False, | ||||
|                    do_rms_check: bool, dt_pre: float = 250., dt_post: float = 250., plot: bool = False, | ||||
|                    fig_dir: str = None) -> tuple: | ||||
| 
 | ||||
|     def check_trace_length_and_cut(trace: Trace, pick_time: UTCDateTime = None): | ||||
| @ -1334,7 +1340,8 @@ def apply_stacking(trace_master: Trace, stations4stack: dict, wfdata: Stream, pi | ||||
|     return trace_master, count | ||||
| 
 | ||||
| 
 | ||||
| def check_rms(traces4stack, plot=False, fig_dir=None, max_it=10, ntimes_std=5.): | ||||
| def check_rms(traces4stack: list, plot: bool = False, fig_dir: str = None, max_it: int = 10, | ||||
|               ntimes_std: float = 5.) -> list: | ||||
|     rms_list = [] | ||||
|     trace_names = [] | ||||
| 
 | ||||
| @ -1396,14 +1403,14 @@ def plot_rms(rms_list: list, trace_names: list, mean: float, std: float, fig_dir | ||||
|     plt.close(fig) | ||||
| 
 | ||||
| 
 | ||||
| def calc_rms(X): | ||||
| def calc_rms(array: np.ndarray) -> float: | ||||
|     """ | ||||
|     Returns root mean square of a given array LON | ||||
|     """ | ||||
|     return np.sqrt(np.sum(np.power(X, 2)) / len(X)) | ||||
|     return np.sqrt(np.sum(np.power(array, 2)) / len(array)) | ||||
| 
 | ||||
| 
 | ||||
| def resample_parallel(stream, freq, ncores): | ||||
| def resample_parallel(stream: Stream, freq: float, ncores: int) -> Stream: | ||||
|     input_list = [{'trace': trace, 'freq': freq} for trace in stream.traces] | ||||
| 
 | ||||
|     parallel = Parallel(n_jobs=ncores) | ||||
| @ -1417,7 +1424,8 @@ def resample_parallel(stream, freq, ncores): | ||||
|     return stream | ||||
| 
 | ||||
| 
 | ||||
| def rotate_stream(stream, metadata, origin, stations_dict, channels, inclination, ncores): | ||||
| def rotate_stream(stream: Stream, metadata: Metadata, origin: Origin, stations_dict: dict, channels: list, | ||||
|                   inclination: float, ncores: int) -> Stream: | ||||
|     """ Rotate stream to LQT. To do this all traces have to be cut to equal lenths""" | ||||
|     input_list = [] | ||||
|     cut_stream_to_same_length(stream) | ||||
| @ -1503,7 +1511,8 @@ def unpack_result(result: list) -> dict: | ||||
|     return result_dict | ||||
| 
 | ||||
| 
 | ||||
| def get_taupy_picks(origin: Origin, stations_dict, pylot_parameter, phase_type, ncores=None): | ||||
| def get_taupy_picks(origin: Origin, stations_dict: dict, pylot_parameter: PylotParameter, phase_type: str, | ||||
|                     ncores: int = None) -> list: | ||||
|     input_list = [] | ||||
| 
 | ||||
|     taup_phases = pylot_parameter['taup_phases'].split(',') | ||||
| @ -1531,7 +1540,7 @@ def get_taupy_picks(origin: Origin, stations_dict, pylot_parameter, phase_type, | ||||
|     return taupy_picks | ||||
| 
 | ||||
| 
 | ||||
| def create_artificial_picks(taupy_result): | ||||
| def create_artificial_picks(taupy_result: dict) -> list: | ||||
|     artificial_picks = [] | ||||
|     for nwst_id, taupy_dict in taupy_result.items(): | ||||
|         nw, st = nwst_id.split('.') | ||||
| @ -1543,14 +1552,14 @@ def create_artificial_picks(taupy_result): | ||||
|     return artificial_picks | ||||
| 
 | ||||
| 
 | ||||
| def check_taupy_phases(taupy_result): | ||||
| def check_taupy_phases(taupy_result: dict) -> None: | ||||
|     test_phase_name = list(taupy_result.values())[0]['phase_name'] | ||||
|     phase_names_equal = [item['phase_name'] == test_phase_name for item in taupy_result.values()] | ||||
|     if not all(phase_names_equal): | ||||
|         logging.info('Different first arriving phases detected in TauPy phases for this event.') | ||||
| 
 | ||||
| 
 | ||||
| def taupy_worker(input_dict): | ||||
| def taupy_worker(input_dict: dict) -> dict: | ||||
|     model = input_dict['model'] | ||||
|     taupy_input = input_dict['taupy_input'] | ||||
|     source_time = input_dict['source_time'] | ||||
| @ -1570,7 +1579,7 @@ def taupy_worker(input_dict): | ||||
|     return output_dict | ||||
| 
 | ||||
| 
 | ||||
| def resample_worker(input_dict): | ||||
| def resample_worker(input_dict: dict) -> Trace: | ||||
|     trace = input_dict['trace'] | ||||
|     freq = input_dict['freq'] | ||||
|     if freq == trace.stats.sampling_rate: | ||||
| @ -1578,7 +1587,7 @@ def resample_worker(input_dict): | ||||
|     return trace.resample(freq, no_filter=freq > trace.stats.sampling_rate) | ||||
| 
 | ||||
| 
 | ||||
| def rotation_worker(input_dict): | ||||
| def rotation_worker(input_dict: dict) -> Optional[Stream]: | ||||
|     stream = input_dict['stream'] | ||||
|     tstart = max([tr.stats.starttime for tr in stream]) | ||||
|     tend = min([tr.stats.endtime for tr in stream]) | ||||
| @ -1595,7 +1604,7 @@ def rotation_worker(input_dict): | ||||
|     return stream | ||||
| 
 | ||||
| 
 | ||||
| def correlation_worker(input_dict): | ||||
| def correlation_worker(input_dict: dict) -> dict: | ||||
|     """worker function for multiprocessing""" | ||||
| 
 | ||||
|     # unpack input dictionary | ||||
| @ -1791,7 +1800,7 @@ def plot_stacked_trace_pick(trace: Trace, pick: Pick, pylot_parameter: PylotPara | ||||
| 
 | ||||
| 
 | ||||
| def plot_stations(stations_dict: dict, stations2compare: dict, coords_this: dict, corr_result: dict, trace: Trace, | ||||
|                   pick: Pick, window_title: str = None): | ||||
|                   pick: Pick, window_title: str = None) -> Optional[tuple]: | ||||
|     """ Plot routine to check proximity algorithm. """ | ||||
| 
 | ||||
|     title = trace.id | ||||
| @ -1864,8 +1873,7 @@ def get_data(eventdir: str, data_dir: str, headonly: bool = False) -> Stream: | ||||
|     return wfdata | ||||
| 
 | ||||
| 
 | ||||
| def get_closest_stations(coords, stations_dict, n): | ||||
|     """ Calculate distances and return n closest stations in stations_dict for station at coords. """ | ||||
| def get_station_distances(stations_dict: dict, coords: dict) -> dict: | ||||
|     distances = {} | ||||
|     for station_id, st_coords in stations_dict.items(): | ||||
|         dist = gps2dist_azimuth(coords['latitude'], coords['longitude'], st_coords['latitude'], st_coords['longitude'], | ||||
| @ -1875,28 +1883,28 @@ def get_closest_stations(coords, stations_dict, n): | ||||
|             continue | ||||
|         distances[dist] = station_id | ||||
| 
 | ||||
|     return distances | ||||
| 
 | ||||
| 
 | ||||
| def get_closest_stations(coords: dict, stations_dict: dict, n: int) -> dict: | ||||
|     """ Calculate distances and return the n closest stations in stations_dict for station at coords. """ | ||||
|     distances = get_station_distances(stations_dict, coords) | ||||
| 
 | ||||
|     closest_distances = sorted(list(distances.keys()))[:n + 1] | ||||
|     closest_stations = {station: dist for dist, station in distances.items() if dist in closest_distances} | ||||
|     return closest_stations | ||||
| 
 | ||||
| 
 | ||||
| def get_random_stations(coords, stations_dict, n): | ||||
| def get_random_stations(coords: dict, stations_dict: dict, n: int) -> dict: | ||||
|     """ Calculate distances and return n randomly selected stations in stations_dict for station at coords. """ | ||||
|     random_keys = random.sample(list(stations_dict.keys()), n) | ||||
|     distances = {} | ||||
|     for station_id, st_coords in stations_dict.items(): | ||||
|         dist = gps2dist_azimuth(coords['latitude'], coords['longitude'], st_coords['latitude'], st_coords['longitude'], | ||||
|                                 a=6.371e6, f=0)[0] | ||||
|         # exclude same coordinate (self or other instrument at same location) | ||||
|         if dist == 0: | ||||
|             continue | ||||
|         distances[dist] = station_id | ||||
|     distances = get_station_distances(stations_dict, coords) | ||||
| 
 | ||||
|     random_stations = {station: dist for dist, station in distances.items() if station in random_keys} | ||||
|     return random_stations | ||||
| 
 | ||||
| 
 | ||||
| def prepare_corr_params(parse_args, logger): | ||||
| def prepare_corr_params(parse_args: argparse, logger: logging.Logger) -> dict: | ||||
|     with open(parse_args.params) as infile: | ||||
|         parameters_dict = yaml.safe_load(infile) | ||||
| 
 | ||||
| @ -1921,7 +1929,7 @@ def prepare_corr_params(parse_args, logger): | ||||
|     return corr_params | ||||
| 
 | ||||
| 
 | ||||
| def init_logger(): | ||||
| def init_logger() -> logging.Logger: | ||||
|     logger = logging.getLogger() | ||||
|     handler = logging.StreamHandler(sys.stdout) | ||||
|     fhandler = logging.FileHandler('pick_corr.out') | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user