[update] WIP: Adding type hints, docstrings etc.

This commit is contained in:
Marcel Paffrath 2024-08-06 16:03:50 +02:00
parent ce71c549ca
commit 1b074d14ff

View File

@ -7,6 +7,7 @@ import copy
import logging
import random
import traceback
from typing import Any
import matplotlib.pyplot as plt
import yaml
@ -14,7 +15,7 @@ import yaml
from code_base.fmtomo_tools.fmtomo_teleseismic_utils import *
from code_base.util.utils import get_metadata
from joblib import Parallel, delayed
from obspy import read, Stream
from obspy import read, Stream, Trace
from obspy.core.event.base import WaveformStreamID, ResourceIdentifier
from obspy.core.event.origin import Pick
from obspy.signal.cross_correlation import correlate, xcorr_max
@ -27,8 +28,10 @@ from pylot.core.util.utils import identifyPhaseID
from pylot.core.util.event import Event as PylotEvent
class CorrelationParameters(object):
class CorrelationParameters:
"""
Class to read, store and access correlation parameters from yaml file.
:param filter_type: filter type for data (e.g. bandpass)
:param filter_options: filter options (min/max_freq etc.)
:param filter_options_final: filter options for second iteration, final pick (min/max_freq etc.)
@ -59,35 +62,36 @@ class CorrelationParameters(object):
self.add_parameters(**kwargs)
# String representation of the object
def __repr__(self):
return "CorrelationParameters "
def __repr__(self) -> str:
return 'CorrelationParameters '
# Boolean test
def __nonzero__(self):
def __nonzero__(self) -> bool:
return bool(self.__parameter)
def __getitem__(self, key):
def __getitem__(self, key: str) -> Any:
return self.__parameter.get(key)
def __setitem__(self, key, value):
def __setitem__(self, key: str, value: Any) -> None:
self.__parameter[key] = value
def __delitem__(self, key):
def __delitem__(self, key: str):
del self.__parameter[key]
def __iter__(self):
def __iter__(self) -> Any:
return iter(self.__parameter)
def __len__(self):
def __len__(self) -> int:
return len(self.__parameter.keys())
def add_parameters(self, **kwargs):
def add_parameters(self, **kwargs) -> None:
for key, value in kwargs.items():
self.__parameter[key] = value
class XcorrPickCorrection:
def __init__(self, pick1, trace1, pick2, trace2, t_before, t_after, cc_maxlag, frac_max=0.5):
def __init__(self, pick1: UTCDateTime, trace1: Trace, pick2: UTCDateTime, trace2: Trace,
t_before: float, t_after: float, cc_maxlag: float, frac_max: float = 0.5):
"""
MP MP : Modified version of obspy xcorr_pick_correction
@ -95,20 +99,6 @@ class XcorrPickCorrection:
correlation of the waveforms in narrow windows around the pick times.
For details on the fitting procedure refer to [Deichmann1992]_.
The parameters depend on the epicentral distance and magnitude range. For
small local earthquakes (Ml ~0-2, distance ~3-10 km) with consistent manual
picks the following can be tried::
t_before=0.05, t_after=0.2, cc_maxlag=0.10,
The appropriate parameter sets can and should be determined/verified
visually using the option `plot=True` on a representative set of picks.
To get the corrected differential pick time calculate: ``((pick2 +
pick2_corr) - pick1)``. To get a corrected differential travel time using
origin times for both events calculate: ``((pick2 + pick2_corr - ot2) -
(pick1 - ot1))``
:type pick1: :class:`~obspy.core.utcdatetime.UTCDateTime`
:param pick1: Time of pick for `trace1`.
:type trace1: :class:`~obspy.core.trace.Trace`
@ -128,9 +118,6 @@ class XcorrPickCorrection:
:type cc_maxlag: float
:param cc_maxlag: Maximum lag/shift time tested during cross correlation in
seconds.
:rtype: (float, float)
:returns: Correction time `pick2_corr` for `pick2` pick time as a float and
corresponding correlation coefficient.
"""
self.trace1 = trace1
@ -153,58 +140,123 @@ class XcorrPickCorrection:
self.tr1_slice = self.slice_trace(self.trace1, self.pick1)
self.tr2_slice = self.slice_trace(self.trace2, self.pick2)
def check_traces(self):
def check_traces(self) -> None:
"""
Check if the sampling rates of two traces match, raise an exception if they don't.
Raise an exception if any of the traces is empty. Set the sampling rate attribute.
"""
if self.trace1.stats.sampling_rate != self.trace2.stats.sampling_rate:
msg = "Sampling rates do not match: %s != %s" % (
self.trace1.stats.sampling_rate, self.trace2.stats.sampling_rate)
msg = f'Sampling rates do not match: {self.trace1.stats.sampling_rate} != {self.trace2.stats.sampling_rate}'
raise Exception(msg)
for trace in (self.trace1, self.trace2):
if len(trace) == 0:
raise Exception(f'Trace {trace} is empty')
self.samp_rate = self.trace1.stats.sampling_rate
def slice_trace(self, tr, pick):
def slice_trace(self, tr, pick) -> Trace:
"""
Method to slice a given trace around a specified pick time.
Parameters:
- tr: Trace object representing the seismic data
- pick: The pick time around which to slice the trace
Returns:
- Trace sliced around the specified pick time
"""
start = pick - self.t_before - (self.cc_maxlag / 2.0)
end = pick + self.t_after + (self.cc_maxlag / 2.0)
# check if necessary time spans are present in data
if tr.stats.starttime > start:
msg = f"Trace {tr.id} starts too late."
msg = f"Trace {tr.id} starts too late. Decrease t_before or cc_maxlag."
logging.debug(f'start: {start}, t_before: {self.t_before}, cc_maxlag: {self.cc_maxlag},'
f'pick: {pick}')
raise Exception(msg)
if tr.stats.endtime < end:
msg = f"Trace {tr.id} ends too early."
msg = f"Trace {tr.id} ends too early. Deacrease t_after or cc_maxlag."
logging.debug(f'end: {end}, t_after: {self.t_after}, cc_maxlag: {self.cc_maxlag},'
f'pick: {pick}')
raise Exception(msg)
# apply signal processing and take correct slice of data
return tr.slice(start, end)
def cross_correlation(self, plot, fig_dir, plot_name, min_corr=None):
def cross_correlation(self, plot: bool, fig_dir: str, plot_name: str, min_corr: float = None):
"""
Calculate the cross correlation between two traces (self.trace1 and self.trace2) and return
the corrected pick time, correlation coefficient, uncertainty, and full width at half maximum.
def get_timeaxis(trace):
Parameters:
plot (bool): A boolean indicating whether to generate a plot or not.
fig_dir (str): The directory to save the plot.
plot_name (str): The name to use for the plot.
min_corr (float, optional): The minimum correlation coefficient allowed.
Returns:
tuple: A tuple containing the corrected pick time, correlation coefficient, uncertainty
and full width at half maximum.
"""
def get_timeaxis(trace: Trace) -> np.ndarray:
"""
Generate a time axis array based on the given trace object.
Parameters:
trace (object): The trace object to generate the time axis from.
Returns:
array: A numpy array representing the time axis.
"""
return np.linspace(0,trace.stats.endtime -trace.stats.starttime, trace.stats.npts)
def plot_cc(fig_dir, plot_name):
if fig_dir and os.path.isdir(fig_dir):
filename = os.path.join(fig_dir, 'corr_{}_{}.svg'.format(self.trace2.id, plot_name))
def plot_cc(figure_output_dir: str, plot_filename: str) -> None:
"""
Generate a plot for the correlation of two traces and save it to a specified file if the directory exists.
Parameters:
- figure_output_dir: str, the directory where the plot will be saved
- plot_filename: str, the name of the plot file
Returns:
- None
"""
if figure_output_dir and os.path.isdir(figure_output_dir):
filename = os.path.join(figure_output_dir, 'corr_{}_{}.svg'.format(self.trace2.id, plot_filename))
else:
filename = None
# with MatplotlibBackend(filename and "AGG" or None, sloppy=True):
# Create figure object, first subplot axis and timeaxis for both traces
fig = plt.figure(figsize=(16, 9))
ax1 = fig.add_subplot(211)
tmp_t1 = get_timeaxis(self.tr1_slice)
tmp_t2 = get_timeaxis(self.tr2_slice)
# MP MP normalize slices (not only by positive maximum!
# MP MP normalize slices (not only by positive maximum!)
tr1_slice_norm = self.tr1_slice.copy().normalize()
tr2_slice_norm = self.tr2_slice.copy().normalize()
# Plot normalized traces to first subplot: Trace to correct, reference trace
# and trace shifted by correlation maximum
ax1.plot(tmp_t1, tr1_slice_norm.data, "b", label="Trace 1 (reference)", lw=0.75)
ax1.plot(tmp_t2, tr2_slice_norm.data, "g--", label="Trace 2 (pick shifted)", lw=0.75)
ax1.plot(tmp_t2 - dt, tr2_slice_norm.data, "k", label="Trace 2 (correlation shifted)", lw=1.)
# get relative pick time from trace 1 (reference trace) for plotting which is the same for all three
# traces in the plot which are aligned by their pick times for comparison
delta_pick_ref = (self.pick1 - self.tr1_slice.stats.starttime)
# correct pick time shift in traces for trace1
ax1.axvline(delta_pick_ref, color='k', linestyle='dashed', label='Pick', lw=0.5)
# plot uncertainty around pick
ylims = ax1.get_ylim()
ax1.fill_between([delta_pick_ref - uncert, delta_pick_ref + uncert], ylims[0], ylims[1], alpha=.25,
color='g', label='pick uncertainty)'.format(self.frac_max))
# add legend, title, labels
ax1.legend(loc="lower right", prop={'size': "small"})
ax1.set_title("Correlated {} with {}".format(self.tr2_slice.id, self.tr1_slice.id))
ax1.set_xlabel("time [s]")
ax1.set_ylabel("norm. amplitude")
# Plot cross correlation to second subplot
ax2 = fig.add_subplot(212)
ax2.plot(cc_t, cc_convex, ls="", marker=".", color="k", label="xcorr (convex)")
ax2.plot(cc_t, cc_concave, ls="", marker=".", color="0.7", label="xcorr (concave)")
@ -270,6 +322,7 @@ class XcorrPickCorrection:
if num_samples < 5:
msg = "Less than 5 samples selected for fit to cross " + "correlation: %s" % num_samples
logging.debug(msg)
logging.info('Not enough samples for polyfit. Consider increasing sampling frequency.')
# quadratic fit for small subwindow
coeffs, cov_mat = np.polyfit(cc_t[first_sample:last_sample + 1], cc[first_sample:last_sample + 1], deg=2,
@ -484,9 +537,6 @@ def correlate_event(eventdir, pylot_parameter, params, channel_config, update):
# get a dictionary containing coordinates for all sources
stations_dict = metadata.get_all_coordinates()
# read processed (restituted) data (assume P and S are the same...)
wfdata_raw = get_data(eventdir, params['P']['data_dir'], headonly=False)
# create dictionaries for final export of P and S phases together
# ps_correlation_dict = {}
# ps_taup_picks = {}
@ -495,6 +545,9 @@ def correlate_event(eventdir, pylot_parameter, params, channel_config, update):
# iterate over P and S to create one model each
for phase_type in params.keys():
# read processed (restituted) data
wfdata_raw = get_data(eventdir, params[phase_type]['data_dir'], headonly=False)
ncores = params[phase_type]['ncores']
filter_type = params[phase_type]['filter_type']
filter_options = params[phase_type]['filter_options']
@ -1392,7 +1445,7 @@ def taupy_parallel(input_list, ncores):
logging.info('Taupy_parallel: Generated {} parallel jobs.'.format(ncores))
taupy_results = parallel(delayed(taupy_worker)(item) for item in input_list)
logging.info('Parallel execution finished.')
logging.info('Parallel execution finished. Unpacking results...')
return unpack_result(taupy_results)
@ -1402,7 +1455,8 @@ def unpack_result(result):
nerr = 0
for item in result_dict.values():
if item['err']:
logging.debug(item['err'])
logging.debug(f'Found error for {item["nwst_id"]}: {item["err"]}.')
#logging.debug(f'Detailed traceback: {item["exc"]}')
nerr += 1
logging.info('Unpack results: Found {} errors after multiprocessing'.format(nerr))
return result_dict
@ -1462,11 +1516,16 @@ def taupy_worker(input_dict):
try:
arrivals = model.get_travel_times_geo(**taupy_input)
if len(arrivals) == 0:
raise Exception(f'No arrivals found for phase {taupy_input["phase_list"]}. Source time: {source_time} -'
f' Input: {taupy_input}')
first_arrival = arrivals[0]
output_dict = dict(nwst_id=input_dict['nwst_id'], phase_name=first_arrival.name,
phase_time=source_time + first_arrival.time, phase_dist=first_arrival.distance, err=None, )
phase_time=source_time + first_arrival.time, phase_dist=first_arrival.distance, err=None,
exc=None,)
except Exception as e:
output_dict = dict(nwst_id=input_dict['nwst_id'], phase_name=None, phase_time=None, err=str(e))
exc = traceback.format_exc()
output_dict = dict(nwst_id=input_dict['nwst_id'], phase_name=None, phase_time=None, err=str(e), exc=exc,)
return output_dict
@ -1475,8 +1534,7 @@ def resample_worker(input_dict):
freq = input_dict['freq']
if freq == trace.stats.sampling_rate:
return trace
else:
return trace.resample(freq, no_filter=False)
return trace.resample(freq, no_filter=freq > trace.stats.sampling_rate)
def rotation_worker(input_dict):