[new] added some unit tests for correlation picker (WIP)
This commit is contained in:
		
							parent
							
								
									d5817adc46
								
							
						
					
					
						commit
						8e7bd87711
					
				@ -96,7 +96,7 @@ class CorrelationParameters:
 | 
			
		||||
            self.__parameter[key] = value
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class XcorrPickCorrection:
 | 
			
		||||
class XCorrPickCorrection:
 | 
			
		||||
    def __init__(self, pick1: UTCDateTime, trace1: Trace, pick2: UTCDateTime, trace2: Trace,
 | 
			
		||||
                 t_before: float, t_after: float, cc_maxlag: float, frac_max: float = 0.5):
 | 
			
		||||
        """
 | 
			
		||||
@ -130,6 +130,9 @@ class XcorrPickCorrection:
 | 
			
		||||
        self.trace1 = trace1
 | 
			
		||||
        self.trace2 = trace2
 | 
			
		||||
 | 
			
		||||
        self.tr1_slice = None
 | 
			
		||||
        self.tr2_slice = None
 | 
			
		||||
 | 
			
		||||
        self.pick1 = pick1
 | 
			
		||||
        self.pick2 = pick2
 | 
			
		||||
 | 
			
		||||
@ -140,13 +143,6 @@ class XcorrPickCorrection:
 | 
			
		||||
 | 
			
		||||
        self.samp_rate = 0
 | 
			
		||||
 | 
			
		||||
        # perform some checks on the traces
 | 
			
		||||
        self.check_traces()
 | 
			
		||||
 | 
			
		||||
        # check data and take correct slice of traces
 | 
			
		||||
        self.tr1_slice = self.slice_trace(self.trace1, self.pick1)
 | 
			
		||||
        self.tr2_slice = self.slice_trace(self.trace2, self.pick2)
 | 
			
		||||
 | 
			
		||||
    def check_traces(self) -> None:
 | 
			
		||||
        """
 | 
			
		||||
        Check if the sampling rates of two traces match, raise an exception if they don't.
 | 
			
		||||
@ -184,6 +180,7 @@ class XcorrPickCorrection:
 | 
			
		||||
            logging.debug(f'end: {end}, t_after: {self.t_after}, cc_maxlag: {self.cc_maxlag},'
 | 
			
		||||
                          f'pick: {pick}')
 | 
			
		||||
            raise Exception(msg)
 | 
			
		||||
 | 
			
		||||
        # apply signal processing and take correct slice of data
 | 
			
		||||
        return tr.slice(start, end)
 | 
			
		||||
 | 
			
		||||
@ -290,6 +287,13 @@ class XcorrPickCorrection:
 | 
			
		||||
 | 
			
		||||
            plt.close(fig)
 | 
			
		||||
 | 
			
		||||
        # perform some checks on the traces
 | 
			
		||||
        self.check_traces()
 | 
			
		||||
 | 
			
		||||
        # check data and take correct slice of traces
 | 
			
		||||
        self.tr1_slice = self.slice_trace(self.trace1, self.pick1)
 | 
			
		||||
        self.tr2_slice = self.slice_trace(self.trace2, self.pick2)
 | 
			
		||||
 | 
			
		||||
        # start of cross correlation method
 | 
			
		||||
        shift_len = int(self.cc_maxlag * self.samp_rate)
 | 
			
		||||
        cc = correlate(self.tr1_slice.data, self.tr2_slice.data, shift_len, demean=False)
 | 
			
		||||
@ -352,17 +356,10 @@ class XcorrPickCorrection:
 | 
			
		||||
        # traces. Actually we do not want to shift the trace to align it but we
 | 
			
		||||
        # want to correct the time of `pick2` so that the traces align without
 | 
			
		||||
        # shifting. This is the negative of the cross correlation shift.
 | 
			
		||||
        # MP MP calculate full width at half maximum
 | 
			
		||||
        # MP MP calculate full width at (first by fraction of, now always using) half maximum
 | 
			
		||||
        fmax = coeff
 | 
			
		||||
        fwfm = 2 * (np.sqrt((b / (2 * a)) ** 2 + (self.frac_max * fmax - c) / a))
 | 
			
		||||
 | 
			
		||||
        # # calculated error using a and ccc
 | 
			
		||||
        # if a < 0:
 | 
			
		||||
        #     asqrtcc = np.sqrt(-1. / a) * (1. - coeff)
 | 
			
		||||
        # else:
 | 
			
		||||
        #     # warning already printed above
 | 
			
		||||
        #     asqrtcc = np.nan
 | 
			
		||||
 | 
			
		||||
        # uncertainty is half of two times the fwhm scaled by 1 - maxcc
 | 
			
		||||
        uncert = fwfm * (1. - coeff)
 | 
			
		||||
        if uncert < 0:
 | 
			
		||||
@ -1416,15 +1413,15 @@ def rotate_stream(stream, metadata, origin, stations_dict, channels, inclination
 | 
			
		||||
 | 
			
		||||
    return new_stream
 | 
			
		||||
 | 
			
		||||
    pool = multiprocessing.Pool(ncores, maxtasksperchild=100)
 | 
			
		||||
    logging.info('Resample_parallel: Generated multiprocessing pool with {} cores.'.format(ncores))
 | 
			
		||||
    output_list = pool.map(rotation_worker, input_list, chunksize=10)
 | 
			
		||||
    pool.close()
 | 
			
		||||
    logging.info('Closed multiprocessing pool.')
 | 
			
		||||
    pool.join()
 | 
			
		||||
    del (pool)
 | 
			
		||||
    stream.traces = [tr for tr in output_list if tr is not None]
 | 
			
		||||
    return stream
 | 
			
		||||
    # pool = multiprocessing.Pool(ncores, maxtasksperchild=100)
 | 
			
		||||
    # logging.info('Resample_parallel: Generated multiprocessing pool with {} cores.'.format(ncores))
 | 
			
		||||
    # output_list = pool.map(rotation_worker, input_list, chunksize=10)
 | 
			
		||||
    # pool.close()
 | 
			
		||||
    # logging.info('Closed multiprocessing pool.')
 | 
			
		||||
    # pool.join()
 | 
			
		||||
    # del (pool)
 | 
			
		||||
    # stream.traces = [tr for tr in output_list if tr is not None]
 | 
			
		||||
    # return stream
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def correlate_parallel(input_list, ncores):
 | 
			
		||||
@ -1583,14 +1580,14 @@ def correlation_worker(input_dict):
 | 
			
		||||
 | 
			
		||||
    try:
 | 
			
		||||
        logging.debug(f'Starting Pick correction for {nwst_id}')
 | 
			
		||||
        xcpc = XcorrPickCorrection(pick_this.time, input_dict['trace1'], other_pick.time, input_dict['trace2'],
 | 
			
		||||
        xcpc = XCorrPickCorrection(pick_this.time, input_dict['trace1'], other_pick.time, input_dict['trace2'],
 | 
			
		||||
                                   t_before=t_before, t_after=t_after, cc_maxlag=cc_maxlag)
 | 
			
		||||
 | 
			
		||||
        dpick, ccc, uncert, fwm = xcpc.cross_correlation(plot, fig_dir, plot_name='dpick')
 | 
			
		||||
        logging.debug(f'dpick of first correlation: {dpick}')
 | 
			
		||||
 | 
			
		||||
        if input_dict['ncorr'] > 1:  # and not ccc <= 0:
 | 
			
		||||
            xcpc2 = XcorrPickCorrection(pick_this.time, input_dict['trace1_highf'], other_pick.time + dpick,
 | 
			
		||||
            xcpc2 = XCorrPickCorrection(pick_this.time, input_dict['trace1_highf'], other_pick.time + dpick,
 | 
			
		||||
                                        input_dict['trace2_highf'], t_before=1., t_after=40., cc_maxlag=cc_maxlag2)
 | 
			
		||||
 | 
			
		||||
            dpick2, ccc, uncert, fwm = xcpc2.cross_correlation(plot=plot, fig_dir=fig_dir, plot_name='error',
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										0
									
								
								pylot/tests/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								pylot/tests/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										76
									
								
								pylot/tests/test_pick_correlation_correction.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										76
									
								
								pylot/tests/test_pick_correlation_correction.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,76 @@
 | 
			
		||||
import pytest
 | 
			
		||||
from obspy import read, Trace, UTCDateTime
 | 
			
		||||
 | 
			
		||||
from pylot.correlation.pick_correlation_correction import XCorrPickCorrection
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestXCorrPickCorrection():
 | 
			
		||||
    def setup(self):
 | 
			
		||||
        self.make_test_traces()
 | 
			
		||||
        self.make_test_picks()
 | 
			
		||||
        self.t_before = 2.
 | 
			
		||||
        self.t_after = 2.
 | 
			
		||||
        self.cc_maxlag = 0.5
 | 
			
		||||
 | 
			
		||||
    def make_test_traces(self):
 | 
			
		||||
        # take first trace of test Stream from obspy
 | 
			
		||||
        tr1 = read()[0]
 | 
			
		||||
        # filter trace
 | 
			
		||||
        tr1.filter('bandpass', freqmin=1, freqmax=20)
 | 
			
		||||
        # make a copy and shift the copy by 0.1 s
 | 
			
		||||
        tr2 = tr1.copy()
 | 
			
		||||
        tr2.stats.starttime += 0.1
 | 
			
		||||
 | 
			
		||||
        self.trace1 = tr1
 | 
			
		||||
        self.trace2 = tr2
 | 
			
		||||
 | 
			
		||||
    def make_test_picks(self):
 | 
			
		||||
        # create an artificial reference pick on reference trace (trace1) and another one on the 0.1 s shifted trace
 | 
			
		||||
        self.tpick1 = UTCDateTime('2009-08-24T00:20:07.7')
 | 
			
		||||
        # shift the second pick by 0.2 s, the correction should be around 0.1 s now
 | 
			
		||||
        self.tpick2 = self.tpick1 + 0.2
 | 
			
		||||
 | 
			
		||||
    def test_slice_trace_okay(self):
 | 
			
		||||
 | 
			
		||||
        self.setup()
 | 
			
		||||
        xcpc = XCorrPickCorrection(UTCDateTime(), Trace(), UTCDateTime(), Trace(),
 | 
			
		||||
                                   t_before=self.t_before, t_after=self.t_after, cc_maxlag=self.cc_maxlag)
 | 
			
		||||
 | 
			
		||||
        test_trace = self.trace1
 | 
			
		||||
        pick_time = self.tpick2
 | 
			
		||||
 | 
			
		||||
        sliced_trace = xcpc.slice_trace(test_trace, pick_time)
 | 
			
		||||
        assert ((sliced_trace.stats.starttime == pick_time - self.t_before - self.cc_maxlag / 2)
 | 
			
		||||
                and (sliced_trace.stats.endtime == pick_time + self.t_after + self.cc_maxlag / 2))
 | 
			
		||||
 | 
			
		||||
    def test_slice_trace_fails(self):
 | 
			
		||||
        self.setup()
 | 
			
		||||
 | 
			
		||||
        test_trace = self.trace1
 | 
			
		||||
        pick_time = self.tpick1
 | 
			
		||||
 | 
			
		||||
        with pytest.raises(Exception):
 | 
			
		||||
            xcpc = XCorrPickCorrection(UTCDateTime(), Trace(), UTCDateTime(), Trace(),
 | 
			
		||||
                                       t_before=self.t_before - 20, t_after=self.t_after, cc_maxlag=self.cc_maxlag)
 | 
			
		||||
            xcpc.slice_trace(test_trace, pick_time)
 | 
			
		||||
 | 
			
		||||
        with pytest.raises(Exception):
 | 
			
		||||
            xcpc = XCorrPickCorrection(UTCDateTime(), Trace(), UTCDateTime(), Trace(),
 | 
			
		||||
                                       t_before=self.t_before, t_after=self.t_after + 50, cc_maxlag=self.cc_maxlag)
 | 
			
		||||
            xcpc.slice_trace(test_trace, pick_time)
 | 
			
		||||
 | 
			
		||||
    def test_cross_correlation(self):
 | 
			
		||||
        self.setup()
 | 
			
		||||
 | 
			
		||||
        # create XCorrPickCorrection object
 | 
			
		||||
        xcpc = XCorrPickCorrection(self.tpick1, self.trace1, self.tpick2, self.trace2, t_before=self.t_before,
 | 
			
		||||
                                   t_after=self.t_after, cc_maxlag=self.cc_maxlag)
 | 
			
		||||
 | 
			
		||||
        # execute correlation
 | 
			
		||||
        correction, cc_max, uncert, fwfm = xcpc.cross_correlation(False, '', '')
 | 
			
		||||
 | 
			
		||||
        # define awaited test result
 | 
			
		||||
        test_result = (-0.09983091718314982, 0.9578431835689154, 0.0015285160561610929, 0.03625786256084631)
 | 
			
		||||
 | 
			
		||||
        # check results
 | 
			
		||||
        assert (correction, cc_max, uncert, fwfm) == test_result
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user