[new] added some unit tests for correlation picker (WIP)

This commit is contained in:
Marcel Paffrath 2024-08-07 17:11:27 +02:00
parent d5817adc46
commit 8e7bd87711
3 changed files with 100 additions and 27 deletions

View File

@ -96,7 +96,7 @@ class CorrelationParameters:
self.__parameter[key] = value
class XcorrPickCorrection:
class XCorrPickCorrection:
def __init__(self, pick1: UTCDateTime, trace1: Trace, pick2: UTCDateTime, trace2: Trace,
t_before: float, t_after: float, cc_maxlag: float, frac_max: float = 0.5):
"""
@ -130,6 +130,9 @@ class XcorrPickCorrection:
self.trace1 = trace1
self.trace2 = trace2
self.tr1_slice = None
self.tr2_slice = None
self.pick1 = pick1
self.pick2 = pick2
@ -140,13 +143,6 @@ class XcorrPickCorrection:
self.samp_rate = 0
# perform some checks on the traces
self.check_traces()
# check data and take correct slice of traces
self.tr1_slice = self.slice_trace(self.trace1, self.pick1)
self.tr2_slice = self.slice_trace(self.trace2, self.pick2)
def check_traces(self) -> None:
"""
Check if the sampling rates of two traces match, raise an exception if they don't.
@ -184,6 +180,7 @@ class XcorrPickCorrection:
logging.debug(f'end: {end}, t_after: {self.t_after}, cc_maxlag: {self.cc_maxlag},'
f'pick: {pick}')
raise Exception(msg)
# apply signal processing and take correct slice of data
return tr.slice(start, end)
@ -290,6 +287,13 @@ class XcorrPickCorrection:
plt.close(fig)
# perform some checks on the traces
self.check_traces()
# check data and take correct slice of traces
self.tr1_slice = self.slice_trace(self.trace1, self.pick1)
self.tr2_slice = self.slice_trace(self.trace2, self.pick2)
# start of cross correlation method
shift_len = int(self.cc_maxlag * self.samp_rate)
cc = correlate(self.tr1_slice.data, self.tr2_slice.data, shift_len, demean=False)
@ -352,17 +356,10 @@ class XcorrPickCorrection:
# traces. Actually we do not want to shift the trace to align it but we
# want to correct the time of `pick2` so that the traces align without
# shifting. This is the negative of the cross correlation shift.
# MP MP calculate full width at half maximum
# MP MP calculate full width at (first by fraction of, now always using) half maximum
fmax = coeff
fwfm = 2 * (np.sqrt((b / (2 * a)) ** 2 + (self.frac_max * fmax - c) / a))
# # calculated error using a and ccc
# if a < 0:
# asqrtcc = np.sqrt(-1. / a) * (1. - coeff)
# else:
# # warning already printed above
# asqrtcc = np.nan
# uncertainty is half of two times the fwhm scaled by 1 - maxcc
uncert = fwfm * (1. - coeff)
if uncert < 0:
@ -1416,15 +1413,15 @@ def rotate_stream(stream, metadata, origin, stations_dict, channels, inclination
return new_stream
pool = multiprocessing.Pool(ncores, maxtasksperchild=100)
logging.info('Resample_parallel: Generated multiprocessing pool with {} cores.'.format(ncores))
output_list = pool.map(rotation_worker, input_list, chunksize=10)
pool.close()
logging.info('Closed multiprocessing pool.')
pool.join()
del (pool)
stream.traces = [tr for tr in output_list if tr is not None]
return stream
# pool = multiprocessing.Pool(ncores, maxtasksperchild=100)
# logging.info('Resample_parallel: Generated multiprocessing pool with {} cores.'.format(ncores))
# output_list = pool.map(rotation_worker, input_list, chunksize=10)
# pool.close()
# logging.info('Closed multiprocessing pool.')
# pool.join()
# del (pool)
# stream.traces = [tr for tr in output_list if tr is not None]
# return stream
def correlate_parallel(input_list, ncores):
@ -1583,14 +1580,14 @@ def correlation_worker(input_dict):
try:
logging.debug(f'Starting Pick correction for {nwst_id}')
xcpc = XcorrPickCorrection(pick_this.time, input_dict['trace1'], other_pick.time, input_dict['trace2'],
xcpc = XCorrPickCorrection(pick_this.time, input_dict['trace1'], other_pick.time, input_dict['trace2'],
t_before=t_before, t_after=t_after, cc_maxlag=cc_maxlag)
dpick, ccc, uncert, fwm = xcpc.cross_correlation(plot, fig_dir, plot_name='dpick')
logging.debug(f'dpick of first correlation: {dpick}')
if input_dict['ncorr'] > 1: # and not ccc <= 0:
xcpc2 = XcorrPickCorrection(pick_this.time, input_dict['trace1_highf'], other_pick.time + dpick,
xcpc2 = XCorrPickCorrection(pick_this.time, input_dict['trace1_highf'], other_pick.time + dpick,
input_dict['trace2_highf'], t_before=1., t_after=40., cc_maxlag=cc_maxlag2)
dpick2, ccc, uncert, fwm = xcpc2.cross_correlation(plot=plot, fig_dir=fig_dir, plot_name='error',

0
pylot/tests/__init__.py Normal file
View File

View File

@ -0,0 +1,76 @@
import pytest
from obspy import read, Trace, UTCDateTime
from pylot.correlation.pick_correlation_correction import XCorrPickCorrection
class TestXCorrPickCorrection():
def setup(self):
self.make_test_traces()
self.make_test_picks()
self.t_before = 2.
self.t_after = 2.
self.cc_maxlag = 0.5
def make_test_traces(self):
# take first trace of test Stream from obspy
tr1 = read()[0]
# filter trace
tr1.filter('bandpass', freqmin=1, freqmax=20)
# make a copy and shift the copy by 0.1 s
tr2 = tr1.copy()
tr2.stats.starttime += 0.1
self.trace1 = tr1
self.trace2 = tr2
def make_test_picks(self):
# create an artificial reference pick on reference trace (trace1) and another one on the 0.1 s shifted trace
self.tpick1 = UTCDateTime('2009-08-24T00:20:07.7')
# shift the second pick by 0.2 s, the correction should be around 0.1 s now
self.tpick2 = self.tpick1 + 0.2
def test_slice_trace_okay(self):
self.setup()
xcpc = XCorrPickCorrection(UTCDateTime(), Trace(), UTCDateTime(), Trace(),
t_before=self.t_before, t_after=self.t_after, cc_maxlag=self.cc_maxlag)
test_trace = self.trace1
pick_time = self.tpick2
sliced_trace = xcpc.slice_trace(test_trace, pick_time)
assert ((sliced_trace.stats.starttime == pick_time - self.t_before - self.cc_maxlag / 2)
and (sliced_trace.stats.endtime == pick_time + self.t_after + self.cc_maxlag / 2))
def test_slice_trace_fails(self):
self.setup()
test_trace = self.trace1
pick_time = self.tpick1
with pytest.raises(Exception):
xcpc = XCorrPickCorrection(UTCDateTime(), Trace(), UTCDateTime(), Trace(),
t_before=self.t_before - 20, t_after=self.t_after, cc_maxlag=self.cc_maxlag)
xcpc.slice_trace(test_trace, pick_time)
with pytest.raises(Exception):
xcpc = XCorrPickCorrection(UTCDateTime(), Trace(), UTCDateTime(), Trace(),
t_before=self.t_before, t_after=self.t_after + 50, cc_maxlag=self.cc_maxlag)
xcpc.slice_trace(test_trace, pick_time)
def test_cross_correlation(self):
self.setup()
# create XCorrPickCorrection object
xcpc = XCorrPickCorrection(self.tpick1, self.trace1, self.tpick2, self.trace2, t_before=self.t_before,
t_after=self.t_after, cc_maxlag=self.cc_maxlag)
# execute correlation
correction, cc_max, uncert, fwfm = xcpc.cross_correlation(False, '', '')
# define awaited test result
test_result = (-0.09983091718314982, 0.9578431835689154, 0.0015285160561610929, 0.03625786256084631)
# check results
assert (correction, cc_max, uncert, fwfm) == test_result