[new] added some unit tests for correlation picker (WIP)
This commit is contained in:
parent
d5817adc46
commit
8e7bd87711
@ -96,7 +96,7 @@ class CorrelationParameters:
|
||||
self.__parameter[key] = value
|
||||
|
||||
|
||||
class XcorrPickCorrection:
|
||||
class XCorrPickCorrection:
|
||||
def __init__(self, pick1: UTCDateTime, trace1: Trace, pick2: UTCDateTime, trace2: Trace,
|
||||
t_before: float, t_after: float, cc_maxlag: float, frac_max: float = 0.5):
|
||||
"""
|
||||
@ -130,6 +130,9 @@ class XcorrPickCorrection:
|
||||
self.trace1 = trace1
|
||||
self.trace2 = trace2
|
||||
|
||||
self.tr1_slice = None
|
||||
self.tr2_slice = None
|
||||
|
||||
self.pick1 = pick1
|
||||
self.pick2 = pick2
|
||||
|
||||
@ -140,13 +143,6 @@ class XcorrPickCorrection:
|
||||
|
||||
self.samp_rate = 0
|
||||
|
||||
# perform some checks on the traces
|
||||
self.check_traces()
|
||||
|
||||
# check data and take correct slice of traces
|
||||
self.tr1_slice = self.slice_trace(self.trace1, self.pick1)
|
||||
self.tr2_slice = self.slice_trace(self.trace2, self.pick2)
|
||||
|
||||
def check_traces(self) -> None:
|
||||
"""
|
||||
Check if the sampling rates of two traces match, raise an exception if they don't.
|
||||
@ -184,6 +180,7 @@ class XcorrPickCorrection:
|
||||
logging.debug(f'end: {end}, t_after: {self.t_after}, cc_maxlag: {self.cc_maxlag},'
|
||||
f'pick: {pick}')
|
||||
raise Exception(msg)
|
||||
|
||||
# apply signal processing and take correct slice of data
|
||||
return tr.slice(start, end)
|
||||
|
||||
@ -290,6 +287,13 @@ class XcorrPickCorrection:
|
||||
|
||||
plt.close(fig)
|
||||
|
||||
# perform some checks on the traces
|
||||
self.check_traces()
|
||||
|
||||
# check data and take correct slice of traces
|
||||
self.tr1_slice = self.slice_trace(self.trace1, self.pick1)
|
||||
self.tr2_slice = self.slice_trace(self.trace2, self.pick2)
|
||||
|
||||
# start of cross correlation method
|
||||
shift_len = int(self.cc_maxlag * self.samp_rate)
|
||||
cc = correlate(self.tr1_slice.data, self.tr2_slice.data, shift_len, demean=False)
|
||||
@ -352,17 +356,10 @@ class XcorrPickCorrection:
|
||||
# traces. Actually we do not want to shift the trace to align it but we
|
||||
# want to correct the time of `pick2` so that the traces align without
|
||||
# shifting. This is the negative of the cross correlation shift.
|
||||
# MP MP calculate full width at half maximum
|
||||
# MP MP calculate full width at (first by fraction of, now always using) half maximum
|
||||
fmax = coeff
|
||||
fwfm = 2 * (np.sqrt((b / (2 * a)) ** 2 + (self.frac_max * fmax - c) / a))
|
||||
|
||||
# # calculated error using a and ccc
|
||||
# if a < 0:
|
||||
# asqrtcc = np.sqrt(-1. / a) * (1. - coeff)
|
||||
# else:
|
||||
# # warning already printed above
|
||||
# asqrtcc = np.nan
|
||||
|
||||
# uncertainty is half of two times the fwhm scaled by 1 - maxcc
|
||||
uncert = fwfm * (1. - coeff)
|
||||
if uncert < 0:
|
||||
@ -1416,15 +1413,15 @@ def rotate_stream(stream, metadata, origin, stations_dict, channels, inclination
|
||||
|
||||
return new_stream
|
||||
|
||||
pool = multiprocessing.Pool(ncores, maxtasksperchild=100)
|
||||
logging.info('Resample_parallel: Generated multiprocessing pool with {} cores.'.format(ncores))
|
||||
output_list = pool.map(rotation_worker, input_list, chunksize=10)
|
||||
pool.close()
|
||||
logging.info('Closed multiprocessing pool.')
|
||||
pool.join()
|
||||
del (pool)
|
||||
stream.traces = [tr for tr in output_list if tr is not None]
|
||||
return stream
|
||||
# pool = multiprocessing.Pool(ncores, maxtasksperchild=100)
|
||||
# logging.info('Resample_parallel: Generated multiprocessing pool with {} cores.'.format(ncores))
|
||||
# output_list = pool.map(rotation_worker, input_list, chunksize=10)
|
||||
# pool.close()
|
||||
# logging.info('Closed multiprocessing pool.')
|
||||
# pool.join()
|
||||
# del (pool)
|
||||
# stream.traces = [tr for tr in output_list if tr is not None]
|
||||
# return stream
|
||||
|
||||
|
||||
def correlate_parallel(input_list, ncores):
|
||||
@ -1583,14 +1580,14 @@ def correlation_worker(input_dict):
|
||||
|
||||
try:
|
||||
logging.debug(f'Starting Pick correction for {nwst_id}')
|
||||
xcpc = XcorrPickCorrection(pick_this.time, input_dict['trace1'], other_pick.time, input_dict['trace2'],
|
||||
xcpc = XCorrPickCorrection(pick_this.time, input_dict['trace1'], other_pick.time, input_dict['trace2'],
|
||||
t_before=t_before, t_after=t_after, cc_maxlag=cc_maxlag)
|
||||
|
||||
dpick, ccc, uncert, fwm = xcpc.cross_correlation(plot, fig_dir, plot_name='dpick')
|
||||
logging.debug(f'dpick of first correlation: {dpick}')
|
||||
|
||||
if input_dict['ncorr'] > 1: # and not ccc <= 0:
|
||||
xcpc2 = XcorrPickCorrection(pick_this.time, input_dict['trace1_highf'], other_pick.time + dpick,
|
||||
xcpc2 = XCorrPickCorrection(pick_this.time, input_dict['trace1_highf'], other_pick.time + dpick,
|
||||
input_dict['trace2_highf'], t_before=1., t_after=40., cc_maxlag=cc_maxlag2)
|
||||
|
||||
dpick2, ccc, uncert, fwm = xcpc2.cross_correlation(plot=plot, fig_dir=fig_dir, plot_name='error',
|
||||
|
0
pylot/tests/__init__.py
Normal file
0
pylot/tests/__init__.py
Normal file
76
pylot/tests/test_pick_correlation_correction.py
Normal file
76
pylot/tests/test_pick_correlation_correction.py
Normal file
@ -0,0 +1,76 @@
|
||||
import pytest
|
||||
from obspy import read, Trace, UTCDateTime
|
||||
|
||||
from pylot.correlation.pick_correlation_correction import XCorrPickCorrection
|
||||
|
||||
|
||||
class TestXCorrPickCorrection():
|
||||
def setup(self):
|
||||
self.make_test_traces()
|
||||
self.make_test_picks()
|
||||
self.t_before = 2.
|
||||
self.t_after = 2.
|
||||
self.cc_maxlag = 0.5
|
||||
|
||||
def make_test_traces(self):
|
||||
# take first trace of test Stream from obspy
|
||||
tr1 = read()[0]
|
||||
# filter trace
|
||||
tr1.filter('bandpass', freqmin=1, freqmax=20)
|
||||
# make a copy and shift the copy by 0.1 s
|
||||
tr2 = tr1.copy()
|
||||
tr2.stats.starttime += 0.1
|
||||
|
||||
self.trace1 = tr1
|
||||
self.trace2 = tr2
|
||||
|
||||
def make_test_picks(self):
|
||||
# create an artificial reference pick on reference trace (trace1) and another one on the 0.1 s shifted trace
|
||||
self.tpick1 = UTCDateTime('2009-08-24T00:20:07.7')
|
||||
# shift the second pick by 0.2 s, the correction should be around 0.1 s now
|
||||
self.tpick2 = self.tpick1 + 0.2
|
||||
|
||||
def test_slice_trace_okay(self):
|
||||
|
||||
self.setup()
|
||||
xcpc = XCorrPickCorrection(UTCDateTime(), Trace(), UTCDateTime(), Trace(),
|
||||
t_before=self.t_before, t_after=self.t_after, cc_maxlag=self.cc_maxlag)
|
||||
|
||||
test_trace = self.trace1
|
||||
pick_time = self.tpick2
|
||||
|
||||
sliced_trace = xcpc.slice_trace(test_trace, pick_time)
|
||||
assert ((sliced_trace.stats.starttime == pick_time - self.t_before - self.cc_maxlag / 2)
|
||||
and (sliced_trace.stats.endtime == pick_time + self.t_after + self.cc_maxlag / 2))
|
||||
|
||||
def test_slice_trace_fails(self):
|
||||
self.setup()
|
||||
|
||||
test_trace = self.trace1
|
||||
pick_time = self.tpick1
|
||||
|
||||
with pytest.raises(Exception):
|
||||
xcpc = XCorrPickCorrection(UTCDateTime(), Trace(), UTCDateTime(), Trace(),
|
||||
t_before=self.t_before - 20, t_after=self.t_after, cc_maxlag=self.cc_maxlag)
|
||||
xcpc.slice_trace(test_trace, pick_time)
|
||||
|
||||
with pytest.raises(Exception):
|
||||
xcpc = XCorrPickCorrection(UTCDateTime(), Trace(), UTCDateTime(), Trace(),
|
||||
t_before=self.t_before, t_after=self.t_after + 50, cc_maxlag=self.cc_maxlag)
|
||||
xcpc.slice_trace(test_trace, pick_time)
|
||||
|
||||
def test_cross_correlation(self):
|
||||
self.setup()
|
||||
|
||||
# create XCorrPickCorrection object
|
||||
xcpc = XCorrPickCorrection(self.tpick1, self.trace1, self.tpick2, self.trace2, t_before=self.t_before,
|
||||
t_after=self.t_after, cc_maxlag=self.cc_maxlag)
|
||||
|
||||
# execute correlation
|
||||
correction, cc_max, uncert, fwfm = xcpc.cross_correlation(False, '', '')
|
||||
|
||||
# define awaited test result
|
||||
test_result = (-0.09983091718314982, 0.9578431835689154, 0.0015285160561610929, 0.03625786256084631)
|
||||
|
||||
# check results
|
||||
assert (correction, cc_max, uncert, fwfm) == test_result
|
Loading…
Reference in New Issue
Block a user