feat: add type hints and tests for plot utils

This commit is contained in:
Sebastian Wehling-Benatelli 2023-04-23 21:37:20 +02:00
parent 4861d33e9a
commit c468bfbe84
2 changed files with 59 additions and 27 deletions

View File

@ -7,6 +7,7 @@ import platform
import re
import subprocess
import warnings
from typing import Literal, Tuple, Type
import numpy as np
from obspy import UTCDateTime, read
@ -18,6 +19,10 @@ from pylot.core.io.inputs import PylotParameter, FilterOptions
from pylot.core.util.obspyDMT_interface import check_obspydmt_eventfolder
from pylot.styles import style_settings
Rgba: Type[tuple] = Tuple[int, int, int, int]
Mplrgba: Type[tuple] = Tuple[float, float, float, float]
Mplrgbastr: Type[tuple] = Tuple[str, str, str, str]
def _pickle_method(m):
if m.im_self is None:
@ -654,32 +659,53 @@ def key_for_set_value(d):
return r
def prepTimeAxis(stime, trace, verbosity=0):
def prep_time_axis(offset, trace, verbosity=0):
"""
takes a starttime and a trace object and returns a valid time axis for
takes an offset and a trace object and returns a valid time axis for
plotting
:param stime: start time of the actual seismogram as UTCDateTime
:type stime: `~obspy.core.utcdatetime.UTCDateTime`
:param offset: offset of the actual seismogram on plotting axis
:type offset: float or int
:param trace: seismic trace object
:type trace: `~obspy.core.trace.Trace`
:param verbosity: if != 0, debug output will be written to console
:type verbosity: int
:return: valid numpy array with time stamps for plotting
:rtype: `~numpy.ndarray`
>>> tr = read()[0]
>>> prep_time_axis(0., tr)
array([0.00000000e+00, 1.00033344e-02, 2.00066689e-02, ...,
2.99799933e+01, 2.99899967e+01, 3.00000000e+01])
>>> prep_time_axis(22.5, tr)
array([22.5 , 22.51000333, 22.52000667, ..., 52.47999333,
52.48999667, 52.5 ])
>>> prep_time_axis(tr.stats.starttime, tr)
Traceback (most recent call last):
...
AssertionError: 'offset' is not of type 'float' or 'int'; type: <class 'obspy.core.utcdatetime.UTCDateTime'>
>>> tr.stats.npts -= 1
>>> prep_time_axis(0, tr)
array([0.00000000e+00, 1.00033356e-02, 2.00066711e-02, ...,
2.99699933e+01, 2.99799967e+01, 2.99900000e+01])
>>> tr.stats.npts += 2
>>> prep_time_axis(0, tr)
array([0.00000000e+00, 1.00033333e-02, 2.00066667e-02, ...,
2.99899933e+01, 2.99999967e+01, 3.00100000e+01])
"""
assert isinstance(offset, (float, int)), "'offset' is not of type 'float' or 'int'; type: {}".format(type(offset))
nsamp = trace.stats.npts
srate = trace.stats.sampling_rate
tincr = trace.stats.delta
etime = stime + nsamp / srate
time_ax = np.linspace(stime, etime, nsamp)
etime = offset + nsamp / srate
time_ax = np.linspace(offset, etime, nsamp)
if len(time_ax) < nsamp:
if verbosity:
print('elongate time axes by one datum')
time_ax = np.arange(stime, etime + tincr, tincr)
time_ax = np.arange(offset, etime + tincr, tincr)
elif len(time_ax) > nsamp:
if verbosity:
print('shorten time axes by one datum')
time_ax = np.arange(stime, etime - tincr, tincr)
time_ax = np.arange(offset, etime - tincr, tincr)
if len(time_ax) != nsamp:
print('Station {0}, {1} samples of data \n '
'{2} length of time vector \n'
@ -712,7 +738,7 @@ def find_horizontals(data):
return rval
def pick_color(picktype, phase, quality=0):
def pick_color(picktype: Literal['manual', 'automatic'], phase: Literal['P', 'S'], quality: int = 0) -> Rgba:
"""
Create pick color by modifying the base color by the quality.
@ -725,7 +751,7 @@ def pick_color(picktype, phase, quality=0):
:param quality: quality of pick. Decides the new intensity of the modifier color
:type quality: int
:return: tuple containing modified rgba color values
:rtype: (int, int, int, int)
:rtype: Rgba
"""
min_quality = 3
bpc = base_phase_colors(picktype, phase) # returns dict like {'modifier': 'g', 'rgba': (0, 0, 255, 255)}
@ -781,17 +807,17 @@ def pick_linestyle_plt(picktype, key):
return linestyles[picktype][key]
def modify_rgba(rgba, modifier, intensity):
def modify_rgba(rgba: Rgba, modifier: Literal['r', 'g', 'b'], intensity: float) -> Rgba:
"""
Modify rgba color by adding the given intensity to the modifier color
:param rgba: tuple containing rgba values
:type rgba: (int, int, int, int)
:param modifier: which color should be modified, eg. 'r', 'g', 'b'
:type modifier: str
:type rgba: Rgba
:param modifier: which color should be modified; options: 'r', 'g', 'b'
:type modifier: Literal['r', 'g', 'b']
:param intensity: intensity to be added to selected color
:type intensity: float
:return: tuple containing rgba values
:rtype: (int, int, int, int)
:rtype: Rgba
"""
rgba = list(rgba)
index = {'r': 0,
@ -825,18 +851,20 @@ def transform_colors_mpl_str(colors, no_alpha=False):
Transforms rgba color values to a matplotlib string of color values with a range of [0, 1]
:param colors: tuple of rgba color values ranging from [0, 255]
:type colors: (float, float, float, float)
:param no_alpha: Wether to return a alpha value in the matplotlib color string
:param no_alpha: Whether to return an alpha value in the matplotlib color string
:type no_alpha: bool
:return: String containing r, g, b values and alpha value if no_alpha is False (default)
:rtype: str
>>> transform_colors_mpl_str((255., 255., 255., 255.), True)
'(1.0, 1.0, 1.0)'
>>> transform_colors_mpl_str((255., 255., 255., 255.))
'(1.0, 1.0, 1.0, 1.0)'
"""
colors = list(colors)
colors_mpl = tuple([color / 255. for color in colors])
if no_alpha:
colors_mpl = '({}, {}, {})'.format(*colors_mpl)
return '({}, {}, {})'.format(*transform_colors_mpl(colors))
else:
colors_mpl = '({}, {}, {}, {})'.format(*colors_mpl)
return colors_mpl
return '({}, {}, {}, {})'.format(*transform_colors_mpl(colors))
def transform_colors_mpl(colors):
@ -846,6 +874,10 @@ def transform_colors_mpl(colors):
:type colors: (float, float, float, float)
:return: tuple of rgba color values ranging from [0, 1]
:rtype: (float, float, float, float)
>>> transform_colors_mpl((127.5, 0., 63.75, 255.))
(0.5, 0.0, 0.25, 1.0)
>>> transform_colors_mpl(())
"""
colors = list(colors)
colors_mpl = tuple([color / 255. for color in colors])

View File

@ -49,7 +49,7 @@ from pylot.core.pick.utils import getSNR, earllatepicker, getnoisewin, \
from pylot.core.pick.compare import Comparison
from pylot.core.pick.autopick import fmpicker
from pylot.core.util.defaults import OUTPUTFORMATS, FILTERDEFAULTS
from pylot.core.util.utils import prepTimeAxis, full_range, demeanTrace, isSorted, findComboBoxIndex, clims, \
from pylot.core.util.utils import prep_time_axis, full_range, demeanTrace, isSorted, findComboBoxIndex, clims, \
pick_linestyle_plt, pick_color_plt, \
check4rotated, check4doubled, merge_stream, identifyPhase, \
loopIdentifyPhase, trim_station_components, transformFilteroptions2String, \
@ -923,10 +923,10 @@ class WaveformWidgetPG(QtWidgets.QWidget):
msg = 'plotting %s channel of station %s' % (channel, station)
print(msg)
stime = trace.stats.starttime - self.wfstart
time_ax = prepTimeAxis(stime, trace)
time_ax = prep_time_axis(stime, trace)
if st_syn:
stime_syn = trace_syn.stats.starttime - self.wfstart
time_ax_syn = prepTimeAxis(stime_syn, trace_syn)
time_ax_syn = prep_time_axis(stime_syn, trace_syn)
if method == 'fast':
trace.data, time_ax = self.minMax(trace, time_ax)
@ -1409,7 +1409,7 @@ class PylotCanvas(FigureCanvas):
msg = 'plotting %s channel of station %s' % (channel, station)
print(msg)
stime = trace.stats.starttime - wfstart
time_ax = prepTimeAxis(stime, trace)
time_ax = prep_time_axis(stime, trace)
if time_ax is not None:
if scaleToChannel:
st_scale = wfdata.select(channel=scaleToChannel)
@ -1447,7 +1447,7 @@ class PylotCanvas(FigureCanvas):
if not scaleddata:
trace.detrend('constant')
trace.normalize(np.max(np.abs(trace.data)) * 2)
time_ax = prepTimeAxis(stime, trace)
time_ax = prep_time_axis(stime, trace)
times = [time for index, time in enumerate(time_ax) if not index % nth_sample]
p_data = compare_stream[0].data
# #normalize
@ -2548,7 +2548,7 @@ class PickDlg(QDialog):
# prepare plotting of data
for trace in data:
t = prepTimeAxis(trace.stats.starttime - stime, trace)
t = prep_time_axis(trace.stats.starttime - stime, trace)
inoise = getnoisewin(t, ini_pick, noise_win, gap_win)
trace = demeanTrace(trace, inoise)
# upscale trace data in a way that each trace is vertically zoomed to noiselevel*factor