From 8e7bd8771111974a993108ba09d6cdaee10a4f1f Mon Sep 17 00:00:00 2001 From: Marcel Office Desktop Date: Wed, 7 Aug 2024 17:11:27 +0200 Subject: [PATCH] [new] added some unit tests for correlation picker (WIP) --- .../pick_correlation_correction.py | 51 ++++++------- pylot/tests/__init__.py | 0 .../tests/test_pick_correlation_correction.py | 76 +++++++++++++++++++ 3 files changed, 100 insertions(+), 27 deletions(-) create mode 100644 pylot/tests/__init__.py create mode 100644 pylot/tests/test_pick_correlation_correction.py diff --git a/pylot/correlation/pick_correlation_correction.py b/pylot/correlation/pick_correlation_correction.py index 737c3745..847cfe1c 100644 --- a/pylot/correlation/pick_correlation_correction.py +++ b/pylot/correlation/pick_correlation_correction.py @@ -96,7 +96,7 @@ class CorrelationParameters: self.__parameter[key] = value -class XcorrPickCorrection: +class XCorrPickCorrection: def __init__(self, pick1: UTCDateTime, trace1: Trace, pick2: UTCDateTime, trace2: Trace, t_before: float, t_after: float, cc_maxlag: float, frac_max: float = 0.5): """ @@ -130,6 +130,9 @@ class XcorrPickCorrection: self.trace1 = trace1 self.trace2 = trace2 + self.tr1_slice = None + self.tr2_slice = None + self.pick1 = pick1 self.pick2 = pick2 @@ -140,13 +143,6 @@ class XcorrPickCorrection: self.samp_rate = 0 - # perform some checks on the traces - self.check_traces() - - # check data and take correct slice of traces - self.tr1_slice = self.slice_trace(self.trace1, self.pick1) - self.tr2_slice = self.slice_trace(self.trace2, self.pick2) - def check_traces(self) -> None: """ Check if the sampling rates of two traces match, raise an exception if they don't. @@ -184,6 +180,7 @@ class XcorrPickCorrection: logging.debug(f'end: {end}, t_after: {self.t_after}, cc_maxlag: {self.cc_maxlag},' f'pick: {pick}') raise Exception(msg) + # apply signal processing and take correct slice of data return tr.slice(start, end) @@ -290,6 +287,13 @@ class XcorrPickCorrection: plt.close(fig) + # perform some checks on the traces + self.check_traces() + + # check data and take correct slice of traces + self.tr1_slice = self.slice_trace(self.trace1, self.pick1) + self.tr2_slice = self.slice_trace(self.trace2, self.pick2) + # start of cross correlation method shift_len = int(self.cc_maxlag * self.samp_rate) cc = correlate(self.tr1_slice.data, self.tr2_slice.data, shift_len, demean=False) @@ -352,17 +356,10 @@ class XcorrPickCorrection: # traces. Actually we do not want to shift the trace to align it but we # want to correct the time of `pick2` so that the traces align without # shifting. This is the negative of the cross correlation shift. - # MP MP calculate full width at half maximum + # MP MP calculate full width at (first by fraction of, now always using) half maximum fmax = coeff fwfm = 2 * (np.sqrt((b / (2 * a)) ** 2 + (self.frac_max * fmax - c) / a)) - # # calculated error using a and ccc - # if a < 0: - # asqrtcc = np.sqrt(-1. / a) * (1. - coeff) - # else: - # # warning already printed above - # asqrtcc = np.nan - # uncertainty is half of two times the fwhm scaled by 1 - maxcc uncert = fwfm * (1. - coeff) if uncert < 0: @@ -1416,15 +1413,15 @@ def rotate_stream(stream, metadata, origin, stations_dict, channels, inclination return new_stream - pool = multiprocessing.Pool(ncores, maxtasksperchild=100) - logging.info('Resample_parallel: Generated multiprocessing pool with {} cores.'.format(ncores)) - output_list = pool.map(rotation_worker, input_list, chunksize=10) - pool.close() - logging.info('Closed multiprocessing pool.') - pool.join() - del (pool) - stream.traces = [tr for tr in output_list if tr is not None] - return stream + # pool = multiprocessing.Pool(ncores, maxtasksperchild=100) + # logging.info('Resample_parallel: Generated multiprocessing pool with {} cores.'.format(ncores)) + # output_list = pool.map(rotation_worker, input_list, chunksize=10) + # pool.close() + # logging.info('Closed multiprocessing pool.') + # pool.join() + # del (pool) + # stream.traces = [tr for tr in output_list if tr is not None] + # return stream def correlate_parallel(input_list, ncores): @@ -1583,14 +1580,14 @@ def correlation_worker(input_dict): try: logging.debug(f'Starting Pick correction for {nwst_id}') - xcpc = XcorrPickCorrection(pick_this.time, input_dict['trace1'], other_pick.time, input_dict['trace2'], + xcpc = XCorrPickCorrection(pick_this.time, input_dict['trace1'], other_pick.time, input_dict['trace2'], t_before=t_before, t_after=t_after, cc_maxlag=cc_maxlag) dpick, ccc, uncert, fwm = xcpc.cross_correlation(plot, fig_dir, plot_name='dpick') logging.debug(f'dpick of first correlation: {dpick}') if input_dict['ncorr'] > 1: # and not ccc <= 0: - xcpc2 = XcorrPickCorrection(pick_this.time, input_dict['trace1_highf'], other_pick.time + dpick, + xcpc2 = XCorrPickCorrection(pick_this.time, input_dict['trace1_highf'], other_pick.time + dpick, input_dict['trace2_highf'], t_before=1., t_after=40., cc_maxlag=cc_maxlag2) dpick2, ccc, uncert, fwm = xcpc2.cross_correlation(plot=plot, fig_dir=fig_dir, plot_name='error', diff --git a/pylot/tests/__init__.py b/pylot/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pylot/tests/test_pick_correlation_correction.py b/pylot/tests/test_pick_correlation_correction.py new file mode 100644 index 00000000..f7f98bc5 --- /dev/null +++ b/pylot/tests/test_pick_correlation_correction.py @@ -0,0 +1,76 @@ +import pytest +from obspy import read, Trace, UTCDateTime + +from pylot.correlation.pick_correlation_correction import XCorrPickCorrection + + +class TestXCorrPickCorrection(): + def setup(self): + self.make_test_traces() + self.make_test_picks() + self.t_before = 2. + self.t_after = 2. + self.cc_maxlag = 0.5 + + def make_test_traces(self): + # take first trace of test Stream from obspy + tr1 = read()[0] + # filter trace + tr1.filter('bandpass', freqmin=1, freqmax=20) + # make a copy and shift the copy by 0.1 s + tr2 = tr1.copy() + tr2.stats.starttime += 0.1 + + self.trace1 = tr1 + self.trace2 = tr2 + + def make_test_picks(self): + # create an artificial reference pick on reference trace (trace1) and another one on the 0.1 s shifted trace + self.tpick1 = UTCDateTime('2009-08-24T00:20:07.7') + # shift the second pick by 0.2 s, the correction should be around 0.1 s now + self.tpick2 = self.tpick1 + 0.2 + + def test_slice_trace_okay(self): + + self.setup() + xcpc = XCorrPickCorrection(UTCDateTime(), Trace(), UTCDateTime(), Trace(), + t_before=self.t_before, t_after=self.t_after, cc_maxlag=self.cc_maxlag) + + test_trace = self.trace1 + pick_time = self.tpick2 + + sliced_trace = xcpc.slice_trace(test_trace, pick_time) + assert ((sliced_trace.stats.starttime == pick_time - self.t_before - self.cc_maxlag / 2) + and (sliced_trace.stats.endtime == pick_time + self.t_after + self.cc_maxlag / 2)) + + def test_slice_trace_fails(self): + self.setup() + + test_trace = self.trace1 + pick_time = self.tpick1 + + with pytest.raises(Exception): + xcpc = XCorrPickCorrection(UTCDateTime(), Trace(), UTCDateTime(), Trace(), + t_before=self.t_before - 20, t_after=self.t_after, cc_maxlag=self.cc_maxlag) + xcpc.slice_trace(test_trace, pick_time) + + with pytest.raises(Exception): + xcpc = XCorrPickCorrection(UTCDateTime(), Trace(), UTCDateTime(), Trace(), + t_before=self.t_before, t_after=self.t_after + 50, cc_maxlag=self.cc_maxlag) + xcpc.slice_trace(test_trace, pick_time) + + def test_cross_correlation(self): + self.setup() + + # create XCorrPickCorrection object + xcpc = XCorrPickCorrection(self.tpick1, self.trace1, self.tpick2, self.trace2, t_before=self.t_before, + t_after=self.t_after, cc_maxlag=self.cc_maxlag) + + # execute correlation + correction, cc_max, uncert, fwfm = xcpc.cross_correlation(False, '', '') + + # define awaited test result + test_result = (-0.09983091718314982, 0.9578431835689154, 0.0015285160561610929, 0.03625786256084631) + + # check results + assert (correction, cc_max, uncert, fwfm) == test_result \ No newline at end of file