[update] WIP: Adding type hints, docstrings etc.
parent ce71c549ca
commit 1b074d14ff
@@ -7,6 +7,7 @@ import copy
import logging
import random
import traceback
from typing import Any

import matplotlib.pyplot as plt
import yaml
@@ -14,7 +15,7 @@ import yaml
from code_base.fmtomo_tools.fmtomo_teleseismic_utils import *
from code_base.util.utils import get_metadata
from joblib import Parallel, delayed
from obspy import read, Stream
from obspy import read, Stream, Trace
from obspy.core.event.base import WaveformStreamID, ResourceIdentifier
from obspy.core.event.origin import Pick
from obspy.signal.cross_correlation import correlate, xcorr_max
@@ -27,8 +28,10 @@ from pylot.core.util.utils import identifyPhaseID
from pylot.core.util.event import Event as PylotEvent


class CorrelationParameters(object):
class CorrelationParameters:
    """
    Class to read, store and access correlation parameters from a yaml file.

    :param filter_type: filter type for data (e.g. bandpass)
    :param filter_options: filter options (min/max_freq etc.)
    :param filter_options_final: filter options for second iteration, final pick (min/max_freq etc.)
@@ -59,35 +62,36 @@ class CorrelationParameters(object):
        self.add_parameters(**kwargs)

    # String representation of the object
    def __repr__(self):
        return "CorrelationParameters "
    def __repr__(self) -> str:
        return 'CorrelationParameters '

    # Boolean test
    def __nonzero__(self):
    def __nonzero__(self) -> bool:
        return bool(self.__parameter)

    def __getitem__(self, key):
    def __getitem__(self, key: str) -> Any:
        return self.__parameter.get(key)

    def __setitem__(self, key, value):
    def __setitem__(self, key: str, value: Any) -> None:
        self.__parameter[key] = value

    def __delitem__(self, key):
    def __delitem__(self, key: str):
        del self.__parameter[key]

    def __iter__(self):
    def __iter__(self) -> Any:
        return iter(self.__parameter)

    def __len__(self):
    def __len__(self) -> int:
        return len(self.__parameter.keys())

    def add_parameters(self, **kwargs):
    def add_parameters(self, **kwargs) -> None:
        for key, value in kwargs.items():
            self.__parameter[key] = value
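The dunder methods above give CorrelationParameters a dict-like interface. A minimal usage sketch (parameter names are hypothetical; it assumes `__init__` forwards its kwargs to `add_parameters` as shown):

```python
# Hypothetical usage of the mapping-style interface defined above.
params = CorrelationParameters(filter_type='bandpass',
                               filter_options=dict(freqmin=0.03, freqmax=0.5))

params['cc_maxlag'] = 30.0                   # __setitem__
assert params['filter_type'] == 'bandpass'   # __getitem__
assert len(params) == 3                      # __len__
for key in params:                           # __iter__
    print(key, params[key])
del params['cc_maxlag']                      # __delitem__
```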


class XcorrPickCorrection:
    def __init__(self, pick1, trace1, pick2, trace2, t_before, t_after, cc_maxlag, frac_max=0.5):
    def __init__(self, pick1: UTCDateTime, trace1: Trace, pick2: UTCDateTime, trace2: Trace,
                 t_before: float, t_after: float, cc_maxlag: float, frac_max: float = 0.5):
        """
        MP MP : Modified version of obspy xcorr_pick_correction

@@ -95,20 +99,6 @@ class XcorrPickCorrection:
        correlation of the waveforms in narrow windows around the pick times.
        For details on the fitting procedure refer to [Deichmann1992]_.

        The parameters depend on the epicentral distance and magnitude range. For
        small local earthquakes (Ml ~0-2, distance ~3-10 km) with consistent manual
        picks the following can be tried::

            t_before=0.05, t_after=0.2, cc_maxlag=0.10,

        The appropriate parameter sets can and should be determined/verified
        visually using the option `plot=True` on a representative set of picks.

        To get the corrected differential pick time calculate: ``((pick2 +
        pick2_corr) - pick1)``. To get a corrected differential travel time using
        origin times for both events calculate: ``((pick2 + pick2_corr - ot2) -
        (pick1 - ot1))``

        :type pick1: :class:`~obspy.core.utcdatetime.UTCDateTime`
        :param pick1: Time of pick for `trace1`.
        :type trace1: :class:`~obspy.core.trace.Trace`
@@ -128,9 +118,6 @@ class XcorrPickCorrection:
        :type cc_maxlag: float
        :param cc_maxlag: Maximum lag/shift time tested during cross correlation in
            seconds.
        :rtype: (float, float)
        :returns: Correction time `pick2_corr` for `pick2` pick time as a float and
            corresponding correlation coefficient.
        """
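For orientation, a short end-to-end sketch based only on the docstring above; the return order is assumed from "corrected pick time, correlation coefficient, uncertainty and full width at half maximum" in `cross_correlation`, and the file names and picks are hypothetical:

```python
from obspy import read, UTCDateTime

tr1 = read('event1.mseed')[0]   # hypothetical waveforms
tr2 = read('event2.mseed')[0]
pick1 = UTCDateTime('2021-01-01T00:00:10.00')
pick2 = UTCDateTime('2021-01-01T00:00:12.30')

xcpc = XcorrPickCorrection(pick1, tr1, pick2, tr2,
                           t_before=0.05, t_after=0.2, cc_maxlag=0.1)
pick2_corr, coeff, uncert, fwhm = xcpc.cross_correlation(plot=False, fig_dir='',
                                                         plot_name='')

# corrected differential pick time as described in the docstring
dt_pick = (pick2 + pick2_corr) - pick1
```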

        self.trace1 = trace1
@@ -153,58 +140,123 @@ class XcorrPickCorrection:
        self.tr1_slice = self.slice_trace(self.trace1, self.pick1)
        self.tr2_slice = self.slice_trace(self.trace2, self.pick2)

    def check_traces(self):
    def check_traces(self) -> None:
        """
        Check that the sampling rates of the two traces match and that neither trace
        is empty; raise an exception otherwise. Set the sampling rate attribute.
        """
        if self.trace1.stats.sampling_rate != self.trace2.stats.sampling_rate:
            msg = "Sampling rates do not match: %s != %s" % (
                self.trace1.stats.sampling_rate, self.trace2.stats.sampling_rate)
            msg = f'Sampling rates do not match: {self.trace1.stats.sampling_rate} != {self.trace2.stats.sampling_rate}'
            raise Exception(msg)
        for trace in (self.trace1, self.trace2):
            if len(trace) == 0:
                raise Exception(f'Trace {trace} is empty')
        self.samp_rate = self.trace1.stats.sampling_rate

    def slice_trace(self, tr, pick):
    def slice_trace(self, tr, pick) -> Trace:
        """
        Slice a given trace around a specified pick time.

        Parameters:
        - tr: Trace object representing the seismic data
        - pick: the pick time around which to slice the trace

        Returns:
        - Trace sliced around the specified pick time
        """
        start = pick - self.t_before - (self.cc_maxlag / 2.0)
        end = pick + self.t_after + (self.cc_maxlag / 2.0)
        # check if necessary time spans are present in data
        if tr.stats.starttime > start:
            msg = f"Trace {tr.id} starts too late."
            msg = f"Trace {tr.id} starts too late. Decrease t_before or cc_maxlag."
            logging.debug(f'start: {start}, t_before: {self.t_before}, cc_maxlag: {self.cc_maxlag}, '
                          f'pick: {pick}')
            raise Exception(msg)
        if tr.stats.endtime < end:
            msg = f"Trace {tr.id} ends too early."
            msg = f"Trace {tr.id} ends too early. Decrease t_after or cc_maxlag."
            logging.debug(f'end: {end}, t_after: {self.t_after}, cc_maxlag: {self.cc_maxlag}, '
                          f'pick: {pick}')
            raise Exception(msg)
        # apply signal processing and take correct slice of data
        return tr.slice(start, end)
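The window arithmetic above reserves half of `cc_maxlag` on each side of the correlation window. A worked example with the values suggested in the docstring:

```python
t_before, t_after, cc_maxlag = 0.05, 0.2, 0.1   # docstring's suggested values

start_offset = t_before + cc_maxlag / 2.0   # slice starts 0.10 s before the pick
end_offset = t_after + cc_maxlag / 2.0      # slice ends 0.25 s after the pick
window = t_before + t_after + cc_maxlag     # total slice length: 0.35 s
```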

    def cross_correlation(self, plot, fig_dir, plot_name, min_corr=None):
    def cross_correlation(self, plot: bool, fig_dir: str, plot_name: str, min_corr: float = None):
        """
        Calculate the cross correlation between two traces (self.trace1 and self.trace2) and return
        the corrected pick time, correlation coefficient, uncertainty and full width at half maximum.

        Parameters:
        plot (bool): whether to generate a plot
        fig_dir (str): the directory to save the plot
        plot_name (str): the name to use for the plot
        min_corr (float, optional): the minimum correlation coefficient allowed

        Returns:
        tuple: the corrected pick time, correlation coefficient, uncertainty
        and full width at half maximum
        """

        def get_timeaxis(trace):
        def get_timeaxis(trace: Trace) -> np.ndarray:
            """
            Generate a time axis array for the given trace.

            Parameters:
            trace (Trace): the trace to generate the time axis from

            Returns:
            array: a numpy array representing the time axis
            """
            return np.linspace(0, trace.stats.endtime - trace.stats.starttime, trace.stats.npts)
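As a sanity check on the helper above: for a trace with `npts` samples, the linspace runs from 0 to `(npts - 1) / sampling_rate` and so reproduces ObsPy's own relative time axis. A toy verification:

```python
import numpy as np
from obspy import Trace

tr = Trace(data=np.zeros(5))  # toy trace, default sampling_rate is 1.0 Hz
t = np.linspace(0, tr.stats.endtime - tr.stats.starttime, tr.stats.npts)
assert np.allclose(t, tr.times())  # both give [0., 1., 2., 3., 4.]
```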

        def plot_cc(fig_dir, plot_name):
            if fig_dir and os.path.isdir(fig_dir):
                filename = os.path.join(fig_dir, 'corr_{}_{}.svg'.format(self.trace2.id, plot_name))
        def plot_cc(figure_output_dir: str, plot_filename: str) -> None:
            """
            Generate a plot of the correlation of the two traces and save it to the
            specified directory if that directory exists.

            Parameters:
            - figure_output_dir: str, the directory where the plot will be saved
            - plot_filename: str, the name of the plot file

            Returns:
            - None
            """
            if figure_output_dir and os.path.isdir(figure_output_dir):
                filename = os.path.join(figure_output_dir, 'corr_{}_{}.svg'.format(self.trace2.id, plot_filename))
            else:
                filename = None
            # with MatplotlibBackend(filename and "AGG" or None, sloppy=True):

            # Create figure object, first subplot axis and time axis for both traces
            fig = plt.figure(figsize=(16, 9))
            ax1 = fig.add_subplot(211)
            tmp_t1 = get_timeaxis(self.tr1_slice)
            tmp_t2 = get_timeaxis(self.tr2_slice)
            # MP MP normalize slices (not only by positive maximum!

            # MP MP normalize slices (not only by positive maximum!)
            tr1_slice_norm = self.tr1_slice.copy().normalize()
            tr2_slice_norm = self.tr2_slice.copy().normalize()

            # Plot normalized traces to first subplot: trace to correct, reference trace
            # and trace shifted by correlation maximum
            ax1.plot(tmp_t1, tr1_slice_norm.data, "b", label="Trace 1 (reference)", lw=0.75)
            ax1.plot(tmp_t2, tr2_slice_norm.data, "g--", label="Trace 2 (pick shifted)", lw=0.75)
            ax1.plot(tmp_t2 - dt, tr2_slice_norm.data, "k", label="Trace 2 (correlation shifted)", lw=1.)

            # get the relative pick time from trace 1 (reference trace) for plotting; it is the same
            # for all three traces in the plot, which are aligned by their pick times for comparison
            delta_pick_ref = (self.pick1 - self.tr1_slice.stats.starttime)
            # correct pick time shift in traces for trace1
            ax1.axvline(delta_pick_ref, color='k', linestyle='dashed', label='Pick', lw=0.5)

            # plot uncertainty around pick
            ylims = ax1.get_ylim()
            ax1.fill_between([delta_pick_ref - uncert, delta_pick_ref + uncert], ylims[0], ylims[1], alpha=.25,
                             color='g', label='pick uncertainty')

            # add legend, title, labels
            ax1.legend(loc="lower right", prop={'size': "small"})
            ax1.set_title("Correlated {} with {}".format(self.tr2_slice.id, self.tr1_slice.id))
            ax1.set_xlabel("time [s]")
            ax1.set_ylabel("norm. amplitude")

            # Plot cross correlation in second subplot
            ax2 = fig.add_subplot(212)
            ax2.plot(cc_t, cc_convex, ls="", marker=".", color="k", label="xcorr (convex)")
            ax2.plot(cc_t, cc_concave, ls="", marker=".", color="0.7", label="xcorr (concave)")
@@ -270,6 +322,7 @@ class XcorrPickCorrection:
        if num_samples < 5:
            msg = "Less than 5 samples selected for fit to cross " + "correlation: %s" % num_samples
            logging.debug(msg)
            logging.info('Not enough samples for polyfit. Consider increasing sampling frequency.')

        # quadratic fit for small subwindow
        coeffs, cov_mat = np.polyfit(cc_t[first_sample:last_sample + 1], cc[first_sample:last_sample + 1], deg=2,
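The quadratic fit refines the discrete correlation maximum to subsample precision: the vertex of the fitted parabola gives the lag and the peak value. A toy sketch of that step (synthetic samples, plain `np.polyfit` without the covariance output used above):

```python
import numpy as np

# synthetic correlation samples around the peak
cc_t = np.array([-0.02, -0.01, 0.0, 0.01, 0.02])
cc = np.array([0.80, 0.92, 0.97, 0.95, 0.85])

a, b, c = np.polyfit(cc_t, cc, deg=2)
t_peak = -b / (2 * a)            # subsample lag at the parabola's vertex
cc_peak = c - b ** 2 / (4 * a)   # refined maximum correlation value
```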
@@ -484,9 +537,6 @@ def correlate_event(eventdir, pylot_parameter, params, channel_config, update):
    # get a dictionary containing coordinates for all sources
    stations_dict = metadata.get_all_coordinates()

    # read processed (restituted) data (assume P and S are the same...)
    wfdata_raw = get_data(eventdir, params['P']['data_dir'], headonly=False)

    # create dictionaries for final export of P and S phases together
    # ps_correlation_dict = {}
    # ps_taup_picks = {}
@@ -495,6 +545,9 @@ def correlate_event(eventdir, pylot_parameter, params, channel_config, update):

    # iterate over P and S to create one model each
    for phase_type in params.keys():
        # read processed (restituted) data
        wfdata_raw = get_data(eventdir, params[phase_type]['data_dir'], headonly=False)

        ncores = params[phase_type]['ncores']
        filter_type = params[phase_type]['filter_type']
        filter_options = params[phase_type]['filter_options']
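The restructuring above moves the waveform read into the per-phase loop, so P and S may point at different processed-data directories. A sketch of the resulting flow (directory names hypothetical, `get_data` and `eventdir` as in this module):

```python
params = {'P': dict(data_dir='rest_P'), 'S': dict(data_dir='rest_S')}

for phase_type in params.keys():
    # each phase now reads its own restituted data set
    wfdata_raw = get_data(eventdir, params[phase_type]['data_dir'], headonly=False)
    # ... filtering, picking and correlation continue per phase
```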
@@ -1392,7 +1445,7 @@ def taupy_parallel(input_list, ncores):
    logging.info('Taupy_parallel: Generated {} parallel jobs.'.format(ncores))
    taupy_results = parallel(delayed(taupy_worker)(item) for item in input_list)

    logging.info('Parallel execution finished.')
    logging.info('Parallel execution finished. Unpacking results...')

    return unpack_result(taupy_results)
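The fan-out above follows the usual joblib pattern: wrap the worker in `delayed` and feed a generator of jobs to a `Parallel` instance. A self-contained sketch with a stand-in worker:

```python
from joblib import Parallel, delayed

def square(x):  # stand-in for taupy_worker
    return x * x

with Parallel(n_jobs=4) as parallel:
    results = parallel(delayed(square)(item) for item in range(8))
# results -> [0, 1, 4, 9, 16, 25, 36, 49]
```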

@@ -1402,7 +1455,8 @@ def unpack_result(result):
    nerr = 0
    for item in result_dict.values():
        if item['err']:
            logging.debug(item['err'])
            logging.debug(f'Found error for {item["nwst_id"]}: {item["err"]}.')
            # logging.debug(f'Detailed traceback: {item["exc"]}')
            nerr += 1
    logging.info('Unpack results: Found {} errors after multiprocessing'.format(nerr))
    return result_dict
@@ -1462,11 +1516,16 @@ def taupy_worker(input_dict):

    try:
        arrivals = model.get_travel_times_geo(**taupy_input)
        if len(arrivals) == 0:
            raise Exception(f'No arrivals found for phase {taupy_input["phase_list"]}. Source time: {source_time} -'
                            f' Input: {taupy_input}')
        first_arrival = arrivals[0]
        output_dict = dict(nwst_id=input_dict['nwst_id'], phase_name=first_arrival.name,
                           phase_time=source_time + first_arrival.time, phase_dist=first_arrival.distance, err=None, )
                           phase_time=source_time + first_arrival.time, phase_dist=first_arrival.distance, err=None,
                           exc=None,)
    except Exception as e:
        output_dict = dict(nwst_id=input_dict['nwst_id'], phase_name=None, phase_time=None, err=str(e))
        exc = traceback.format_exc()
        output_dict = dict(nwst_id=input_dict['nwst_id'], phase_name=None, phase_time=None, err=str(e), exc=exc,)
    return output_dict
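The changes above give workers a uniform error contract: they never raise, but return a short `err` string for the summary log and the full traceback in `exc` for debugging (see the commented-out debug line in `unpack_result`). A reduced sketch of that contract with a stand-in body:

```python
import traceback

def safe_worker(input_dict):
    try:
        result = 1.0 / input_dict['value']   # stand-in for the taupy call
        return dict(nwst_id=input_dict['nwst_id'], result=result, err=None, exc=None)
    except Exception as e:
        # short message for the summary log, full traceback for debugging
        return dict(nwst_id=input_dict['nwst_id'], result=None,
                    err=str(e), exc=traceback.format_exc())
```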


@@ -1475,8 +1534,7 @@ def resample_worker(input_dict):
    freq = input_dict['freq']
    if freq == trace.stats.sampling_rate:
        return trace
    else:
        return trace.resample(freq, no_filter=False)
    return trace.resample(freq, no_filter=freq > trace.stats.sampling_rate)
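The rewritten return ties ObsPy's `no_filter` flag to the resampling direction: downsampling keeps the anti-alias lowpass (`no_filter=False`), while upsampling skips it, since no aliasing can occur. A quick illustration on a toy trace:

```python
import numpy as np
from obspy import Trace

tr = Trace(data=np.random.randn(100))
tr.stats.sampling_rate = 10.0

# upsampling: target > current rate -> no_filter=True, lowpass skipped
up = tr.copy().resample(20.0, no_filter=20.0 > tr.stats.sampling_rate)
# downsampling: target < current rate -> no_filter=False, lowpass applied first
down = tr.copy().resample(5.0, no_filter=5.0 > tr.stats.sampling_rate)
```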


def rotation_worker(input_dict):