[new] added some unit tests for correlation picker (WIP)
This commit is contained in:
parent
d5817adc46
commit
8e7bd87711
@ -96,7 +96,7 @@ class CorrelationParameters:
|
|||||||
self.__parameter[key] = value
|
self.__parameter[key] = value
|
||||||
|
|
||||||
|
|
||||||
class XcorrPickCorrection:
|
class XCorrPickCorrection:
|
||||||
def __init__(self, pick1: UTCDateTime, trace1: Trace, pick2: UTCDateTime, trace2: Trace,
|
def __init__(self, pick1: UTCDateTime, trace1: Trace, pick2: UTCDateTime, trace2: Trace,
|
||||||
t_before: float, t_after: float, cc_maxlag: float, frac_max: float = 0.5):
|
t_before: float, t_after: float, cc_maxlag: float, frac_max: float = 0.5):
|
||||||
"""
|
"""
|
||||||
@ -130,6 +130,9 @@ class XcorrPickCorrection:
|
|||||||
self.trace1 = trace1
|
self.trace1 = trace1
|
||||||
self.trace2 = trace2
|
self.trace2 = trace2
|
||||||
|
|
||||||
|
self.tr1_slice = None
|
||||||
|
self.tr2_slice = None
|
||||||
|
|
||||||
self.pick1 = pick1
|
self.pick1 = pick1
|
||||||
self.pick2 = pick2
|
self.pick2 = pick2
|
||||||
|
|
||||||
@ -140,13 +143,6 @@ class XcorrPickCorrection:
|
|||||||
|
|
||||||
self.samp_rate = 0
|
self.samp_rate = 0
|
||||||
|
|
||||||
# perform some checks on the traces
|
|
||||||
self.check_traces()
|
|
||||||
|
|
||||||
# check data and take correct slice of traces
|
|
||||||
self.tr1_slice = self.slice_trace(self.trace1, self.pick1)
|
|
||||||
self.tr2_slice = self.slice_trace(self.trace2, self.pick2)
|
|
||||||
|
|
||||||
def check_traces(self) -> None:
|
def check_traces(self) -> None:
|
||||||
"""
|
"""
|
||||||
Check if the sampling rates of two traces match, raise an exception if they don't.
|
Check if the sampling rates of two traces match, raise an exception if they don't.
|
||||||
@ -184,6 +180,7 @@ class XcorrPickCorrection:
|
|||||||
logging.debug(f'end: {end}, t_after: {self.t_after}, cc_maxlag: {self.cc_maxlag},'
|
logging.debug(f'end: {end}, t_after: {self.t_after}, cc_maxlag: {self.cc_maxlag},'
|
||||||
f'pick: {pick}')
|
f'pick: {pick}')
|
||||||
raise Exception(msg)
|
raise Exception(msg)
|
||||||
|
|
||||||
# apply signal processing and take correct slice of data
|
# apply signal processing and take correct slice of data
|
||||||
return tr.slice(start, end)
|
return tr.slice(start, end)
|
||||||
|
|
||||||
@ -290,6 +287,13 @@ class XcorrPickCorrection:
|
|||||||
|
|
||||||
plt.close(fig)
|
plt.close(fig)
|
||||||
|
|
||||||
|
# perform some checks on the traces
|
||||||
|
self.check_traces()
|
||||||
|
|
||||||
|
# check data and take correct slice of traces
|
||||||
|
self.tr1_slice = self.slice_trace(self.trace1, self.pick1)
|
||||||
|
self.tr2_slice = self.slice_trace(self.trace2, self.pick2)
|
||||||
|
|
||||||
# start of cross correlation method
|
# start of cross correlation method
|
||||||
shift_len = int(self.cc_maxlag * self.samp_rate)
|
shift_len = int(self.cc_maxlag * self.samp_rate)
|
||||||
cc = correlate(self.tr1_slice.data, self.tr2_slice.data, shift_len, demean=False)
|
cc = correlate(self.tr1_slice.data, self.tr2_slice.data, shift_len, demean=False)
|
||||||
@ -352,17 +356,10 @@ class XcorrPickCorrection:
|
|||||||
# traces. Actually we do not want to shift the trace to align it but we
|
# traces. Actually we do not want to shift the trace to align it but we
|
||||||
# want to correct the time of `pick2` so that the traces align without
|
# want to correct the time of `pick2` so that the traces align without
|
||||||
# shifting. This is the negative of the cross correlation shift.
|
# shifting. This is the negative of the cross correlation shift.
|
||||||
# MP MP calculate full width at half maximum
|
# MP MP calculate full width at (first by fraction of, now always using) half maximum
|
||||||
fmax = coeff
|
fmax = coeff
|
||||||
fwfm = 2 * (np.sqrt((b / (2 * a)) ** 2 + (self.frac_max * fmax - c) / a))
|
fwfm = 2 * (np.sqrt((b / (2 * a)) ** 2 + (self.frac_max * fmax - c) / a))
|
||||||
|
|
||||||
# # calculated error using a and ccc
|
|
||||||
# if a < 0:
|
|
||||||
# asqrtcc = np.sqrt(-1. / a) * (1. - coeff)
|
|
||||||
# else:
|
|
||||||
# # warning already printed above
|
|
||||||
# asqrtcc = np.nan
|
|
||||||
|
|
||||||
# uncertainty is half of two times the fwhm scaled by 1 - maxcc
|
# uncertainty is half of two times the fwhm scaled by 1 - maxcc
|
||||||
uncert = fwfm * (1. - coeff)
|
uncert = fwfm * (1. - coeff)
|
||||||
if uncert < 0:
|
if uncert < 0:
|
||||||
@ -1416,15 +1413,15 @@ def rotate_stream(stream, metadata, origin, stations_dict, channels, inclination
|
|||||||
|
|
||||||
return new_stream
|
return new_stream
|
||||||
|
|
||||||
pool = multiprocessing.Pool(ncores, maxtasksperchild=100)
|
# pool = multiprocessing.Pool(ncores, maxtasksperchild=100)
|
||||||
logging.info('Resample_parallel: Generated multiprocessing pool with {} cores.'.format(ncores))
|
# logging.info('Resample_parallel: Generated multiprocessing pool with {} cores.'.format(ncores))
|
||||||
output_list = pool.map(rotation_worker, input_list, chunksize=10)
|
# output_list = pool.map(rotation_worker, input_list, chunksize=10)
|
||||||
pool.close()
|
# pool.close()
|
||||||
logging.info('Closed multiprocessing pool.')
|
# logging.info('Closed multiprocessing pool.')
|
||||||
pool.join()
|
# pool.join()
|
||||||
del (pool)
|
# del (pool)
|
||||||
stream.traces = [tr for tr in output_list if tr is not None]
|
# stream.traces = [tr for tr in output_list if tr is not None]
|
||||||
return stream
|
# return stream
|
||||||
|
|
||||||
|
|
||||||
def correlate_parallel(input_list, ncores):
|
def correlate_parallel(input_list, ncores):
|
||||||
@ -1583,14 +1580,14 @@ def correlation_worker(input_dict):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
logging.debug(f'Starting Pick correction for {nwst_id}')
|
logging.debug(f'Starting Pick correction for {nwst_id}')
|
||||||
xcpc = XcorrPickCorrection(pick_this.time, input_dict['trace1'], other_pick.time, input_dict['trace2'],
|
xcpc = XCorrPickCorrection(pick_this.time, input_dict['trace1'], other_pick.time, input_dict['trace2'],
|
||||||
t_before=t_before, t_after=t_after, cc_maxlag=cc_maxlag)
|
t_before=t_before, t_after=t_after, cc_maxlag=cc_maxlag)
|
||||||
|
|
||||||
dpick, ccc, uncert, fwm = xcpc.cross_correlation(plot, fig_dir, plot_name='dpick')
|
dpick, ccc, uncert, fwm = xcpc.cross_correlation(plot, fig_dir, plot_name='dpick')
|
||||||
logging.debug(f'dpick of first correlation: {dpick}')
|
logging.debug(f'dpick of first correlation: {dpick}')
|
||||||
|
|
||||||
if input_dict['ncorr'] > 1: # and not ccc <= 0:
|
if input_dict['ncorr'] > 1: # and not ccc <= 0:
|
||||||
xcpc2 = XcorrPickCorrection(pick_this.time, input_dict['trace1_highf'], other_pick.time + dpick,
|
xcpc2 = XCorrPickCorrection(pick_this.time, input_dict['trace1_highf'], other_pick.time + dpick,
|
||||||
input_dict['trace2_highf'], t_before=1., t_after=40., cc_maxlag=cc_maxlag2)
|
input_dict['trace2_highf'], t_before=1., t_after=40., cc_maxlag=cc_maxlag2)
|
||||||
|
|
||||||
dpick2, ccc, uncert, fwm = xcpc2.cross_correlation(plot=plot, fig_dir=fig_dir, plot_name='error',
|
dpick2, ccc, uncert, fwm = xcpc2.cross_correlation(plot=plot, fig_dir=fig_dir, plot_name='error',
|
||||||
|
0
pylot/tests/__init__.py
Normal file
0
pylot/tests/__init__.py
Normal file
76
pylot/tests/test_pick_correlation_correction.py
Normal file
76
pylot/tests/test_pick_correlation_correction.py
Normal file
@ -0,0 +1,76 @@
|
|||||||
|
import pytest
|
||||||
|
from obspy import read, Trace, UTCDateTime
|
||||||
|
|
||||||
|
from pylot.correlation.pick_correlation_correction import XCorrPickCorrection
|
||||||
|
|
||||||
|
|
||||||
|
class TestXCorrPickCorrection():
|
||||||
|
def setup(self):
|
||||||
|
self.make_test_traces()
|
||||||
|
self.make_test_picks()
|
||||||
|
self.t_before = 2.
|
||||||
|
self.t_after = 2.
|
||||||
|
self.cc_maxlag = 0.5
|
||||||
|
|
||||||
|
def make_test_traces(self):
|
||||||
|
# take first trace of test Stream from obspy
|
||||||
|
tr1 = read()[0]
|
||||||
|
# filter trace
|
||||||
|
tr1.filter('bandpass', freqmin=1, freqmax=20)
|
||||||
|
# make a copy and shift the copy by 0.1 s
|
||||||
|
tr2 = tr1.copy()
|
||||||
|
tr2.stats.starttime += 0.1
|
||||||
|
|
||||||
|
self.trace1 = tr1
|
||||||
|
self.trace2 = tr2
|
||||||
|
|
||||||
|
def make_test_picks(self):
|
||||||
|
# create an artificial reference pick on reference trace (trace1) and another one on the 0.1 s shifted trace
|
||||||
|
self.tpick1 = UTCDateTime('2009-08-24T00:20:07.7')
|
||||||
|
# shift the second pick by 0.2 s, the correction should be around 0.1 s now
|
||||||
|
self.tpick2 = self.tpick1 + 0.2
|
||||||
|
|
||||||
|
def test_slice_trace_okay(self):
|
||||||
|
|
||||||
|
self.setup()
|
||||||
|
xcpc = XCorrPickCorrection(UTCDateTime(), Trace(), UTCDateTime(), Trace(),
|
||||||
|
t_before=self.t_before, t_after=self.t_after, cc_maxlag=self.cc_maxlag)
|
||||||
|
|
||||||
|
test_trace = self.trace1
|
||||||
|
pick_time = self.tpick2
|
||||||
|
|
||||||
|
sliced_trace = xcpc.slice_trace(test_trace, pick_time)
|
||||||
|
assert ((sliced_trace.stats.starttime == pick_time - self.t_before - self.cc_maxlag / 2)
|
||||||
|
and (sliced_trace.stats.endtime == pick_time + self.t_after + self.cc_maxlag / 2))
|
||||||
|
|
||||||
|
def test_slice_trace_fails(self):
|
||||||
|
self.setup()
|
||||||
|
|
||||||
|
test_trace = self.trace1
|
||||||
|
pick_time = self.tpick1
|
||||||
|
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
xcpc = XCorrPickCorrection(UTCDateTime(), Trace(), UTCDateTime(), Trace(),
|
||||||
|
t_before=self.t_before - 20, t_after=self.t_after, cc_maxlag=self.cc_maxlag)
|
||||||
|
xcpc.slice_trace(test_trace, pick_time)
|
||||||
|
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
xcpc = XCorrPickCorrection(UTCDateTime(), Trace(), UTCDateTime(), Trace(),
|
||||||
|
t_before=self.t_before, t_after=self.t_after + 50, cc_maxlag=self.cc_maxlag)
|
||||||
|
xcpc.slice_trace(test_trace, pick_time)
|
||||||
|
|
||||||
|
def test_cross_correlation(self):
|
||||||
|
self.setup()
|
||||||
|
|
||||||
|
# create XCorrPickCorrection object
|
||||||
|
xcpc = XCorrPickCorrection(self.tpick1, self.trace1, self.tpick2, self.trace2, t_before=self.t_before,
|
||||||
|
t_after=self.t_after, cc_maxlag=self.cc_maxlag)
|
||||||
|
|
||||||
|
# execute correlation
|
||||||
|
correction, cc_max, uncert, fwfm = xcpc.cross_correlation(False, '', '')
|
||||||
|
|
||||||
|
# define awaited test result
|
||||||
|
test_result = (-0.09983091718314982, 0.9578431835689154, 0.0015285160561610929, 0.03625786256084631)
|
||||||
|
|
||||||
|
# check results
|
||||||
|
assert (correction, cc_max, uncert, fwfm) == test_result
|
Loading…
Reference in New Issue
Block a user