feat: add type hints and tests for plot utils

This commit is contained in:
Sebastian Wehling-Benatelli 2023-04-23 21:37:20 +02:00 committed by Sebastian Wehling-Benatelli
parent fb32d5e0c5
commit 2bbb84190c
2 changed files with 60 additions and 28 deletions

View File

@ -20,6 +20,10 @@ from pylot.core.io.inputs import PylotParameter, FilterOptions
from pylot.core.util.obspyDMT_interface import check_obspydmt_eventfolder from pylot.core.util.obspyDMT_interface import check_obspydmt_eventfolder
from pylot.styles import style_settings from pylot.styles import style_settings
Rgba: Type[tuple] = Tuple[int, int, int, int]
Mplrgba: Type[tuple] = Tuple[float, float, float, float]
Mplrgbastr: Type[tuple] = Tuple[str, str, str, str]
def _pickle_method(m): def _pickle_method(m):
if m.im_self is None: if m.im_self is None:
@ -121,7 +125,7 @@ def gen_Pool(ncores=0):
print('gen_Pool: Generated multiprocessing Pool with {} cores\n'.format(ncores)) print('gen_Pool: Generated multiprocessing Pool with {} cores\n'.format(ncores))
pool = multiprocessing.Pool(ncores, maxtasksperchild=100) pool = multiprocessing.Pool(ncores)
return pool return pool
@ -382,6 +386,7 @@ def get_bool(value):
else: else:
return False return False
def four_digits(year): def four_digits(year):
""" """
takes a two digit year integer and returns the correct four digit equivalent takes a two digit year integer and returns the correct four digit equivalent
@ -655,32 +660,53 @@ def key_for_set_value(d):
return r return r
def prepTimeAxis(stime, trace, verbosity=0): def prep_time_axis(offset, trace, verbosity=0):
""" """
takes a starttime and a trace object and returns a valid time axis for takes an offset and a trace object and returns a valid time axis for
plotting plotting
:param stime: start time of the actual seismogram as UTCDateTime :param offset: offset of the actual seismogram on plotting axis
:type stime: `~obspy.core.utcdatetime.UTCDateTime` :type offset: float or int
:param trace: seismic trace object :param trace: seismic trace object
:type trace: `~obspy.core.trace.Trace` :type trace: `~obspy.core.trace.Trace`
:param verbosity: if != 0, debug output will be written to console :param verbosity: if != 0, debug output will be written to console
:type verbosity: int :type verbosity: int
:return: valid numpy array with time stamps for plotting :return: valid numpy array with time stamps for plotting
:rtype: `~numpy.ndarray` :rtype: `~numpy.ndarray`
>>> tr = read()[0]
>>> prep_time_axis(0., tr)
array([0.00000000e+00, 1.00033344e-02, 2.00066689e-02, ...,
2.99799933e+01, 2.99899967e+01, 3.00000000e+01])
>>> prep_time_axis(22.5, tr)
array([22.5 , 22.51000333, 22.52000667, ..., 52.47999333,
52.48999667, 52.5 ])
>>> prep_time_axis(tr.stats.starttime, tr)
Traceback (most recent call last):
...
AssertionError: 'offset' is not of type 'float' or 'int'; type: <class 'obspy.core.utcdatetime.UTCDateTime'>
>>> tr.stats.npts -= 1
>>> prep_time_axis(0, tr)
array([0.00000000e+00, 1.00033356e-02, 2.00066711e-02, ...,
2.99699933e+01, 2.99799967e+01, 2.99900000e+01])
>>> tr.stats.npts += 2
>>> prep_time_axis(0, tr)
array([0.00000000e+00, 1.00033333e-02, 2.00066667e-02, ...,
2.99899933e+01, 2.99999967e+01, 3.00100000e+01])
""" """
assert isinstance(offset, (float, int)), "'offset' is not of type 'float' or 'int'; type: {}".format(type(offset))
nsamp = trace.stats.npts nsamp = trace.stats.npts
srate = trace.stats.sampling_rate srate = trace.stats.sampling_rate
tincr = trace.stats.delta tincr = trace.stats.delta
etime = stime + nsamp / srate etime = offset + nsamp / srate
time_ax = np.linspace(stime, etime, nsamp) time_ax = np.linspace(offset, etime, nsamp)
if len(time_ax) < nsamp: if len(time_ax) < nsamp:
if verbosity: if verbosity:
print('elongate time axes by one datum') print('elongate time axes by one datum')
time_ax = np.arange(stime, etime + tincr, tincr) time_ax = np.arange(offset, etime + tincr, tincr)
elif len(time_ax) > nsamp: elif len(time_ax) > nsamp:
if verbosity: if verbosity:
print('shorten time axes by one datum') print('shorten time axes by one datum')
time_ax = np.arange(stime, etime - tincr, tincr) time_ax = np.arange(offset, etime - tincr, tincr)
if len(time_ax) != nsamp: if len(time_ax) != nsamp:
print('Station {0}, {1} samples of data \n ' print('Station {0}, {1} samples of data \n '
'{2} length of time vector \n' '{2} length of time vector \n'
@ -713,7 +739,7 @@ def find_horizontals(data):
return rval return rval
def pick_color(picktype, phase, quality=0): def pick_color(picktype: Literal['manual', 'automatic'], phase: Literal['P', 'S'], quality: int = 0) -> Rgba:
""" """
Create pick color by modifying the base color by the quality. Create pick color by modifying the base color by the quality.
@ -726,7 +752,7 @@ def pick_color(picktype, phase, quality=0):
:param quality: quality of pick. Decides the new intensity of the modifier color :param quality: quality of pick. Decides the new intensity of the modifier color
:type quality: int :type quality: int
:return: tuple containing modified rgba color values :return: tuple containing modified rgba color values
:rtype: (int, int, int, int) :rtype: Rgba
""" """
min_quality = 3 min_quality = 3
bpc = base_phase_colors(picktype, phase) # returns dict like {'modifier': 'g', 'rgba': (0, 0, 255, 255)} bpc = base_phase_colors(picktype, phase) # returns dict like {'modifier': 'g', 'rgba': (0, 0, 255, 255)}
@ -782,17 +808,17 @@ def pick_linestyle_plt(picktype, key):
return linestyles[picktype][key] return linestyles[picktype][key]
def modify_rgba(rgba, modifier, intensity): def modify_rgba(rgba: Rgba, modifier: Literal['r', 'g', 'b'], intensity: float) -> Rgba:
""" """
Modify rgba color by adding the given intensity to the modifier color Modify rgba color by adding the given intensity to the modifier color
:param rgba: tuple containing rgba values :param rgba: tuple containing rgba values
:type rgba: (int, int, int, int) :type rgba: Rgba
:param modifier: which color should be modified, eg. 'r', 'g', 'b' :param modifier: which color should be modified; options: 'r', 'g', 'b'
:type modifier: str :type modifier: Literal['r', 'g', 'b']
:param intensity: intensity to be added to selected color :param intensity: intensity to be added to selected color
:type intensity: float :type intensity: float
:return: tuple containing rgba values :return: tuple containing rgba values
:rtype: (int, int, int, int) :rtype: Rgba
""" """
rgba = list(rgba) rgba = list(rgba)
index = {'r': 0, index = {'r': 0,
@ -826,18 +852,20 @@ def transform_colors_mpl_str(colors, no_alpha=False):
Transforms rgba color values to a matplotlib string of color values with a range of [0, 1] Transforms rgba color values to a matplotlib string of color values with a range of [0, 1]
:param colors: tuple of rgba color values ranging from [0, 255] :param colors: tuple of rgba color values ranging from [0, 255]
:type colors: (float, float, float, float) :type colors: (float, float, float, float)
:param no_alpha: Wether to return a alpha value in the matplotlib color string :param no_alpha: Whether to return an alpha value in the matplotlib color string
:type no_alpha: bool :type no_alpha: bool
:return: String containing r, g, b values and alpha value if no_alpha is False (default) :return: String containing r, g, b values and alpha value if no_alpha is False (default)
:rtype: str :rtype: str
>>> transform_colors_mpl_str((255., 255., 255., 255.), True)
'(1.0, 1.0, 1.0)'
>>> transform_colors_mpl_str((255., 255., 255., 255.))
'(1.0, 1.0, 1.0, 1.0)'
""" """
colors = list(colors)
colors_mpl = tuple([color / 255. for color in colors])
if no_alpha: if no_alpha:
colors_mpl = '({}, {}, {})'.format(*colors_mpl) return '({}, {}, {})'.format(*transform_colors_mpl(colors))
else: else:
colors_mpl = '({}, {}, {}, {})'.format(*colors_mpl) return '({}, {}, {}, {})'.format(*transform_colors_mpl(colors))
return colors_mpl
def transform_colors_mpl(colors): def transform_colors_mpl(colors):
@ -847,6 +875,10 @@ def transform_colors_mpl(colors):
:type colors: (float, float, float, float) :type colors: (float, float, float, float)
:return: tuple of rgba color values ranging from [0, 1] :return: tuple of rgba color values ranging from [0, 1]
:rtype: (float, float, float, float) :rtype: (float, float, float, float)
>>> transform_colors_mpl((127.5, 0., 63.75, 255.))
(0.5, 0.0, 0.25, 1.0)
>>> transform_colors_mpl(())
""" """
colors = list(colors) colors = list(colors)
colors_mpl = tuple([color / 255. for color in colors]) colors_mpl = tuple([color / 255. for color in colors])

View File

@ -49,7 +49,7 @@ from pylot.core.pick.utils import getSNR, earllatepicker, getnoisewin, \
from pylot.core.pick.compare import Comparison from pylot.core.pick.compare import Comparison
from pylot.core.pick.autopick import fmpicker from pylot.core.pick.autopick import fmpicker
from pylot.core.util.defaults import OUTPUTFORMATS, FILTERDEFAULTS from pylot.core.util.defaults import OUTPUTFORMATS, FILTERDEFAULTS
from pylot.core.util.utils import prepTimeAxis, full_range, demeanTrace, isSorted, findComboBoxIndex, clims, \ from pylot.core.util.utils import prep_time_axis, full_range, demeanTrace, isSorted, findComboBoxIndex, clims, \
pick_linestyle_plt, pick_color_plt, \ pick_linestyle_plt, pick_color_plt, \
check4rotated, check4doubled, check_for_gaps_and_merge, check_for_nan, identifyPhase, \ check4rotated, check4doubled, check_for_gaps_and_merge, check_for_nan, identifyPhase, \
loopIdentifyPhase, trim_station_components, transformFilteroptions2String, \ loopIdentifyPhase, trim_station_components, transformFilteroptions2String, \
@ -923,10 +923,10 @@ class WaveformWidgetPG(QtWidgets.QWidget):
msg = 'plotting %s channel of station %s' % (channel, station) msg = 'plotting %s channel of station %s' % (channel, station)
print(msg) print(msg)
stime = trace.stats.starttime - self.wfstart stime = trace.stats.starttime - self.wfstart
time_ax = prepTimeAxis(stime, trace) time_ax = prep_time_axis(stime, trace)
if st_syn: if st_syn:
stime_syn = trace_syn.stats.starttime - self.wfstart stime_syn = trace_syn.stats.starttime - self.wfstart
time_ax_syn = prepTimeAxis(stime_syn, trace_syn) time_ax_syn = prep_time_axis(stime_syn, trace_syn)
if method == 'fast': if method == 'fast':
trace.data, time_ax = self.minMax(trace, time_ax) trace.data, time_ax = self.minMax(trace, time_ax)
@ -1409,7 +1409,7 @@ class PylotCanvas(FigureCanvas):
msg = 'plotting %s channel of station %s' % (channel, station) msg = 'plotting %s channel of station %s' % (channel, station)
print(msg) print(msg)
stime = trace.stats.starttime - wfstart stime = trace.stats.starttime - wfstart
time_ax = prepTimeAxis(stime, trace) time_ax = prep_time_axis(stime, trace)
if time_ax is not None: if time_ax is not None:
if scaleToChannel: if scaleToChannel:
st_scale = wfdata.select(channel=scaleToChannel) st_scale = wfdata.select(channel=scaleToChannel)
@ -1447,7 +1447,7 @@ class PylotCanvas(FigureCanvas):
if not scaleddata: if not scaleddata:
trace.detrend('constant') trace.detrend('constant')
trace.normalize(np.max(np.abs(trace.data)) * 2) trace.normalize(np.max(np.abs(trace.data)) * 2)
time_ax = prepTimeAxis(stime, trace) time_ax = prep_time_axis(stime, trace)
times = [time for index, time in enumerate(time_ax) if not index % nth_sample] times = [time for index, time in enumerate(time_ax) if not index % nth_sample]
p_data = compare_stream[0].data p_data = compare_stream[0].data
# #normalize # #normalize
@ -2548,7 +2548,7 @@ class PickDlg(QDialog):
# prepare plotting of data # prepare plotting of data
for trace in data: for trace in data:
t = prepTimeAxis(trace.stats.starttime - stime, trace) t = prep_time_axis(trace.stats.starttime - stime, trace)
inoise = getnoisewin(t, ini_pick, noise_win, gap_win) inoise = getnoisewin(t, ini_pick, noise_win, gap_win)
trace = demeanTrace(trace, inoise) trace = demeanTrace(trace, inoise)
# upscale trace data in a way that each trace is vertically zoomed to noiselevel*factor # upscale trace data in a way that each trace is vertically zoomed to noiselevel*factor