refactor: rewrite plotWFData to reduce complexity

This commit is contained in:
Sebastian Wehling-Benatelli 2024-07-20 10:40:05 +02:00
parent 090bfb47d1
commit a2add3034a

View File

@ -140,18 +140,6 @@ class LogWidget(QtWidgets.QWidget):
self.stderr.append(60 * '#' + '\n\n')
def getDataType(parent):
type = QInputDialog().getItem(parent, "Select phases type", "Type:",
["manual", "automatic"])
if type[0].startswith('auto'):
type = 'auto'
else:
type = type[0]
return type
def plot_pdf(_axes, x, y, annotation, bbox_props, xlabel=None, ylabel=None,
title=None):
# try method or data
@ -1009,15 +997,6 @@ class WaveformWidgetPG(QtWidgets.QWidget):
time_ax = np.linspace(time_ax[0], time_ax[-1], num=len(data))
return data, time_ax
# def getAxes(self):
# return self.axes
# def getXLims(self):
# return self.getAxes().get_xlim()
# def getYLims(self):
# return self.getAxes().get_ylim()
def setXLims(self, lims):
vb = self.plotWidget.getPlotItem().getViewBox()
vb.setXRange(float(lims[0]), float(lims[1]), padding=0)
@ -1171,8 +1150,6 @@ class PylotCanvas(FigureCanvas):
break
if not ax_check: return
# self.updateCurrentLimits() #maybe put this down to else:
# calculate delta (relative values in axis)
old_x, old_y = self.press_rel
xdiff = gui_event.x - old_x
@ -1380,125 +1357,140 @@ class PylotCanvas(FigureCanvas):
component='*', nth_sample=1, iniPick=None, verbosity=0,
plot_additional=False, additional_channel=None, scaleToChannel=None,
snr=None):
def get_wf_dict(data: Stream = Stream(), linecolor = 'k', offset: float = 0., **plot_kwargs):
return dict(data=data, linecolor=linecolor, offset=offset, plot_kwargs=plot_kwargs)
ax = self.prepare_plot()
self.clearPlotDict()
wfstart, wfend = self.get_wf_range(wfdata)
compclass = self.get_comp_class()
plot_streams = self.get_plot_streams(wfdata, wfdata_compare, component, compclass)
st_main = plot_streams['wfdata']['data']
if mapping:
plot_positions = self.calcPlotPositions(st_main, compclass)
nslc = self.get_sorted_nslc(st_main)
nmax = self.plot_traces(ax, plot_streams, nslc, wfstart, mapping, plot_positions,
scaleToChannel, noiselevel, scaleddata, nth_sample, verbosity)
if plot_additional and additional_channel:
self.plot_additional_trace(ax, wfdata, additional_channel, scaleToChannel,
scaleddata, nth_sample, wfstart)
self.finalize_plot(ax, wfstart, wfend, nmax, zoomx, zoomy, iniPick, title, snr)
def prepare_plot(self):
ax = self.axes[0]
ax.cla()
return ax
self.clearPlotDict()
wfstart, wfend = full_range(wfdata)
nmax = 0
def get_wf_range(self, wfdata):
return full_range(wfdata)
def get_comp_class(self):
settings = QSettings()
compclass = SetChannelComponents.from_qsettings(settings)
return SetChannelComponents.from_qsettings(settings)
def get_plot_streams(self, wfdata, wfdata_compare, component, compclass):
def get_wf_dict(data=Stream(), linecolor='k', offset=0., **plot_kwargs):
return dict(data=data, linecolor=linecolor, offset=offset, plot_kwargs=plot_kwargs)
linecolor = (0., 0., 0., 1.) if not self.style else self.style['linecolor']['rgba_mpl']
plot_streams = {
'wfdata': get_wf_dict(linecolor=linecolor, linewidth=0.7),
'wfdata_comp': get_wf_dict(offset=0.1, linecolor='b', alpha=0.7, linewidth=0.5)
}
plot_streams = dict(wfdata=get_wf_dict(linecolor=linecolor, linewidth=0.7),
wfdata_comp=get_wf_dict(offset=0.1, linecolor='b', alpha=0.7, linewidth=0.5))
if not component == '*':
if component != '*':
alter_comp = compclass.getCompPosition(component)
# alter_comp = str(alter_comp[0])
plot_streams['wfdata']['data'] = wfdata.select(component=component)
plot_streams['wfdata']['data'] += wfdata.select(component=alter_comp)
plot_streams['wfdata']['data'] = wfdata.select(component=component) + wfdata.select(component=alter_comp)
if wfdata_compare:
plot_streams['wfdata_comp']['data'] = wfdata_compare.select(component=component)
plot_streams['wfdata_comp']['data'] += wfdata_compare.select(component=alter_comp)
plot_streams['wfdata_comp']['data'] = wfdata_compare.select(
component=component) + wfdata_compare.select(component=alter_comp)
else:
plot_streams['wfdata']['data'] = wfdata
if wfdata_compare:
plot_streams['wfdata_comp']['data'] = wfdata_compare
st_main = plot_streams['wfdata']['data']
return plot_streams
if mapping:
plot_positions = self.calcPlotPositions(st_main, compclass)
# list containing tuples of network, station, channel and plot position (for sorting)
nslc = []
for plot_pos, trace in enumerate(st_main):
if not trace.stats.channel[-1] in ['Z', 'N', 'E', '1', '2', '3']:
print('Warning: Unrecognized channel {}'.format(trace.stats.channel))
continue
nslc.append(trace.get_id())
nslc.sort()
nslc.reverse()
def get_sorted_nslc(self, st_main):
nslc = [trace.get_id() for trace in st_main if trace.stats.channel[-1] in ['Z', 'N', 'E', '1', '2', '3']]
nslc.sort(reverse=True)
return nslc
def plot_traces(self, ax, plot_streams, nslc, wfstart, mapping, plot_positions, scaleToChannel, noiselevel,
scaleddata, nth_sample, verbosity):
nmax = 0
for n, seed_id in enumerate(nslc):
network, station, location, channel = seed_id.split('.')
for wf_name, wf_dict in plot_streams.items():
st_select = wf_dict.get('data')
if not st_select:
continue
st = st_select.select(id=seed_id)
trace = st[0].copy()
trace = st_select.select(id=seed_id)[0].copy()
if mapping:
n = plot_positions[trace.stats.channel]
if n > nmax:
nmax = n
if verbosity:
msg = 'plotting %s channel of station %s' % (channel, station)
print(msg)
stime = trace.stats.starttime - wfstart
time_ax = prep_time_axis(stime, trace)
if time_ax is not None:
if scaleToChannel:
st_scale = wfdata.select(channel=scaleToChannel)
if st_scale:
tr = st_scale[0]
trace.detrend('constant')
trace.normalize(np.max(np.abs(tr.data)) * 2)
scaleddata = True
if not scaleddata:
trace.detrend('constant')
trace.normalize(np.max(np.abs(trace.data)) * 2)
print(f'plotting {channel} channel of station {station}')
time_ax = prep_time_axis(trace.stats.starttime - wfstart, trace)
self.plot_trace(ax, trace, time_ax, wf_dict, n, scaleToChannel, noiselevel, scaleddata, nth_sample)
self.setPlotDict(n, seed_id)
return nmax
offset = wf_dict.get('offset')
def plot_trace(self, ax, trace, time_ax, wf_dict, n, scaleToChannel, noiselevel, scaleddata, nth_sample):
if time_ax is not None:
if scaleToChannel:
self.scale_trace(trace, scaleToChannel)
scaleddata = True
if not scaleddata:
trace.detrend('constant')
trace.normalize(np.max(np.abs(trace.data)) * 2)
offset = wf_dict.get('offset')
times = [time for index, time in enumerate(time_ax) if not index % nth_sample]
data = [datum + n + offset for index, datum in enumerate(trace.data) if not index % nth_sample]
ax.axhline(n, color="0.5", lw=0.5)
ax.plot(times, data, color=wf_dict.get('linecolor'), **wf_dict.get('plot_kwargs'))
if noiselevel is not None:
for level in [-noiselevel[channel], noiselevel[channel]]:
ax.plot([time_ax[0], time_ax[-1]],
[n + level, n + level],
color=wf_dict.get('linecolor'),
linestyle='dashed')
self.setPlotDict(n, seed_id)
if plot_additional and additional_channel:
compare_stream = wfdata.select(channel=additional_channel)
if compare_stream:
trace = compare_stream[0]
if scaleToChannel:
st_scale = wfdata.select(channel=scaleToChannel)
if st_scale:
tr = st_scale[0]
trace.detrend('constant')
trace.normalize(np.max(np.abs(tr.data)) * 2)
scaleddata = True
if not scaleddata:
trace.detrend('constant')
trace.normalize(np.max(np.abs(trace.data)) * 2)
time_ax = prep_time_axis(stime, trace)
times = [time for index, time in enumerate(time_ax) if not index % nth_sample]
p_data = compare_stream[0].data
# #normalize
# p_max = max(abs(p_data))
# p_data /= p_max
for index in range(3):
ax.plot(times, p_data, color='red', alpha=0.5, linewidth=0.7)
p_data += 1
times = [time for index, time in enumerate(time_ax) if not index % nth_sample]
data = [datum + n + offset for index, datum in enumerate(trace.data) if not index % nth_sample]
ax.axhline(n, color="0.5", lw=0.5)
ax.plot(times, data, color=wf_dict.get('linecolor'), **wf_dict.get('plot_kwargs'))
if noiselevel is not None:
self.plot_noise_level(ax, time_ax, noiselevel, channel, n, wf_dict.get('linecolor'))
def scale_trace(self, trace, scaleToChannel):
st_scale = wfdata.select(channel=scaleToChannel)
if st_scale:
tr = st_scale[0]
trace.detrend('constant')
trace.normalize(np.max(np.abs(tr.data)) * 2)
def plot_noise_level(self, ax, time_ax, noiselevel, channel, n, linecolor):
for level in [-noiselevel[channel], noiselevel[channel]]:
ax.plot([time_ax[0], time_ax[-1]], [n + level, n + level], color=linecolor, linestyle='dashed')
def plot_additional_trace(self, ax, wfdata, additional_channel, scaleToChannel, scaleddata, nth_sample, wfstart):
compare_stream = wfdata.select(channel=additional_channel)
if compare_stream:
trace = compare_stream[0]
if scaleToChannel:
self.scale_trace(trace, scaleToChannel)
scaleddata = True
if not scaleddata:
trace.detrend('constant')
trace.normalize(np.max(np.abs(trace.data)) * 2)
time_ax = prep_time_axis(trace.stats.starttime - wfstart, trace)
self.plot_additional_data(ax, trace, time_ax, nth_sample)
def plot_additional_data(self, ax, trace, time_ax, nth_sample):
times = [time for index, time in enumerate(time_ax) if not index % nth_sample]
p_data = trace.data
for index in range(3):
ax.plot(times, p_data, color='red', alpha=0.5, linewidth=0.7)
p_data += 1
def finalize_plot(self, ax, wfstart, wfend, nmax, zoomx, zoomy, iniPick, title, snr):
if iniPick:
ax.vlines(iniPick, ax.get_ylim()[0], ax.get_ylim()[1],
colors='m', linestyles='dashed',
linewidth=2)
xlabel = 'seconds since {0}'.format(wfstart)
ax.vlines(iniPick, ax.get_ylim()[0], ax.get_ylim()[1], colors='m', linestyles='dashed', linewidth=2)
xlabel = f'seconds since {wfstart}'
ylabel = ''
self.updateWidget(xlabel, ylabel, title)
self.setXLims(ax, [0, wfend - wfstart])
@ -1508,15 +1500,14 @@ class PylotCanvas(FigureCanvas):
if zoomy is not None:
self.setYLims(ax, zoomy)
if snr is not None:
if snr < 2:
warning = 'LOW SNR'
if snr < 1.5:
warning = 'VERY LOW SNR'
ax.text(0.1, 0.9, 'WARNING - {}'.format(warning), ha='center', va='center', transform=ax.transAxes,
color='red')
self.plot_snr_warning(ax, snr)
self.draw()
def plot_snr_warning(self, ax, snr):
if snr < 2:
warning = 'LOW SNR' if snr >= 1.5 else 'VERY LOW SNR'
ax.text(0.1, 0.9, f'WARNING - {warning}', ha='center', va='center', transform=ax.transAxes, color='red')
@staticmethod
def getXLims(ax):
return ax.get_xlim()