refactor: rewrite plotWFData to reduce complexity
This commit is contained in:
		
							parent
							
								
									090bfb47d1
								
							
						
					
					
						commit
						a2add3034a
					
				| @ -140,18 +140,6 @@ class LogWidget(QtWidgets.QWidget): | ||||
|         self.stderr.append(60 * '#' + '\n\n') | ||||
| 
 | ||||
| 
 | ||||
| def getDataType(parent): | ||||
|     type = QInputDialog().getItem(parent, "Select phases type", "Type:", | ||||
|                                   ["manual", "automatic"]) | ||||
| 
 | ||||
|     if type[0].startswith('auto'): | ||||
|         type = 'auto' | ||||
|     else: | ||||
|         type = type[0] | ||||
| 
 | ||||
|     return type | ||||
| 
 | ||||
| 
 | ||||
| def plot_pdf(_axes, x, y, annotation, bbox_props, xlabel=None, ylabel=None, | ||||
|              title=None): | ||||
|     # try method or data | ||||
| @ -1009,15 +997,6 @@ class WaveformWidgetPG(QtWidgets.QWidget): | ||||
|         time_ax = np.linspace(time_ax[0], time_ax[-1], num=len(data)) | ||||
|         return data, time_ax | ||||
| 
 | ||||
|     # def getAxes(self): | ||||
|     #     return self.axes | ||||
| 
 | ||||
|     # def getXLims(self): | ||||
|     #     return self.getAxes().get_xlim() | ||||
| 
 | ||||
|     # def getYLims(self): | ||||
|     #     return self.getAxes().get_ylim() | ||||
| 
 | ||||
|     def setXLims(self, lims): | ||||
|         vb = self.plotWidget.getPlotItem().getViewBox() | ||||
|         vb.setXRange(float(lims[0]), float(lims[1]), padding=0) | ||||
| @ -1171,8 +1150,6 @@ class PylotCanvas(FigureCanvas): | ||||
|                 break | ||||
|         if not ax_check: return | ||||
| 
 | ||||
|         # self.updateCurrentLimits() #maybe put this down to else: | ||||
| 
 | ||||
|         # calculate delta (relative values in axis) | ||||
|         old_x, old_y = self.press_rel | ||||
|         xdiff = gui_event.x - old_x | ||||
| @ -1380,83 +1357,96 @@ class PylotCanvas(FigureCanvas): | ||||
|                    component='*', nth_sample=1, iniPick=None, verbosity=0, | ||||
|                    plot_additional=False, additional_channel=None, scaleToChannel=None, | ||||
|                    snr=None): | ||||
|         def get_wf_dict(data: Stream = Stream(), linecolor = 'k', offset: float = 0., **plot_kwargs): | ||||
|             return dict(data=data, linecolor=linecolor, offset=offset, plot_kwargs=plot_kwargs) | ||||
|         ax = self.prepare_plot() | ||||
|         self.clearPlotDict() | ||||
| 
 | ||||
|         wfstart, wfend = self.get_wf_range(wfdata) | ||||
|         compclass = self.get_comp_class() | ||||
|         plot_streams = self.get_plot_streams(wfdata, wfdata_compare, component, compclass) | ||||
| 
 | ||||
|         st_main = plot_streams['wfdata']['data'] | ||||
|         if mapping: | ||||
|             plot_positions = self.calcPlotPositions(st_main, compclass) | ||||
| 
 | ||||
|         nslc = self.get_sorted_nslc(st_main) | ||||
|         nmax = self.plot_traces(ax, plot_streams, nslc, wfstart, mapping, plot_positions, | ||||
|                                 scaleToChannel, noiselevel, scaleddata, nth_sample, verbosity) | ||||
| 
 | ||||
|         if plot_additional and additional_channel: | ||||
|             self.plot_additional_trace(ax, wfdata, additional_channel, scaleToChannel, | ||||
|                                        scaleddata, nth_sample, wfstart) | ||||
| 
 | ||||
|         self.finalize_plot(ax, wfstart, wfend, nmax, zoomx, zoomy, iniPick, title, snr) | ||||
| 
 | ||||
|     def prepare_plot(self): | ||||
|         ax = self.axes[0] | ||||
|         ax.cla() | ||||
|         return ax | ||||
| 
 | ||||
|         self.clearPlotDict() | ||||
|         wfstart, wfend = full_range(wfdata) | ||||
|         nmax = 0 | ||||
|     def get_wf_range(self, wfdata): | ||||
|         return full_range(wfdata) | ||||
| 
 | ||||
|     def get_comp_class(self): | ||||
|         settings = QSettings() | ||||
|         compclass = SetChannelComponents.from_qsettings(settings) | ||||
|         return SetChannelComponents.from_qsettings(settings) | ||||
| 
 | ||||
|     def get_plot_streams(self, wfdata, wfdata_compare, component, compclass): | ||||
|         def get_wf_dict(data=Stream(), linecolor='k', offset=0., **plot_kwargs): | ||||
|             return dict(data=data, linecolor=linecolor, offset=offset, plot_kwargs=plot_kwargs) | ||||
| 
 | ||||
|         linecolor = (0., 0., 0., 1.) if not self.style else self.style['linecolor']['rgba_mpl'] | ||||
|         plot_streams = { | ||||
|             'wfdata': get_wf_dict(linecolor=linecolor, linewidth=0.7), | ||||
|             'wfdata_comp': get_wf_dict(offset=0.1, linecolor='b', alpha=0.7, linewidth=0.5) | ||||
|         } | ||||
| 
 | ||||
|         plot_streams = dict(wfdata=get_wf_dict(linecolor=linecolor, linewidth=0.7), | ||||
|                             wfdata_comp=get_wf_dict(offset=0.1, linecolor='b', alpha=0.7, linewidth=0.5)) | ||||
| 
 | ||||
|         if not component == '*': | ||||
|         if component != '*': | ||||
|             alter_comp = compclass.getCompPosition(component) | ||||
|             # alter_comp = str(alter_comp[0]) | ||||
| 
 | ||||
|             plot_streams['wfdata']['data'] = wfdata.select(component=component) | ||||
|             plot_streams['wfdata']['data'] += wfdata.select(component=alter_comp) | ||||
|             plot_streams['wfdata']['data'] = wfdata.select(component=component) + wfdata.select(component=alter_comp) | ||||
|             if wfdata_compare: | ||||
|                 plot_streams['wfdata_comp']['data'] = wfdata_compare.select(component=component) | ||||
|                 plot_streams['wfdata_comp']['data'] += wfdata_compare.select(component=alter_comp) | ||||
|                 plot_streams['wfdata_comp']['data'] = wfdata_compare.select( | ||||
|                     component=component) + wfdata_compare.select(component=alter_comp) | ||||
|         else: | ||||
|             plot_streams['wfdata']['data'] = wfdata | ||||
|             if wfdata_compare: | ||||
|                 plot_streams['wfdata_comp']['data'] = wfdata_compare | ||||
| 
 | ||||
|         st_main = plot_streams['wfdata']['data'] | ||||
|         return plot_streams | ||||
| 
 | ||||
|         if mapping: | ||||
|             plot_positions = self.calcPlotPositions(st_main, compclass) | ||||
| 
 | ||||
|         # list containing tuples of network, station, channel and plot position (for sorting) | ||||
|         nslc = [] | ||||
|         for plot_pos, trace in enumerate(st_main): | ||||
|             if not trace.stats.channel[-1] in ['Z', 'N', 'E', '1', '2', '3']: | ||||
|                 print('Warning: Unrecognized channel {}'.format(trace.stats.channel)) | ||||
|                 continue | ||||
|             nslc.append(trace.get_id()) | ||||
|         nslc.sort() | ||||
|         nslc.reverse() | ||||
|     def get_sorted_nslc(self, st_main): | ||||
|         nslc = [trace.get_id() for trace in st_main if trace.stats.channel[-1] in ['Z', 'N', 'E', '1', '2', '3']] | ||||
|         nslc.sort(reverse=True) | ||||
|         return nslc | ||||
| 
 | ||||
|     def plot_traces(self, ax, plot_streams, nslc, wfstart, mapping, plot_positions, scaleToChannel, noiselevel, | ||||
|                     scaleddata, nth_sample, verbosity): | ||||
|         nmax = 0 | ||||
|         for n, seed_id in enumerate(nslc): | ||||
|             network, station, location, channel = seed_id.split('.') | ||||
|             for wf_name, wf_dict in plot_streams.items(): | ||||
|                 st_select = wf_dict.get('data') | ||||
|                 if not st_select: | ||||
|                     continue | ||||
|                 st = st_select.select(id=seed_id) | ||||
|                 trace = st[0].copy() | ||||
|                 trace = st_select.select(id=seed_id)[0].copy() | ||||
|                 if mapping: | ||||
|                     n = plot_positions[trace.stats.channel] | ||||
|                 if n > nmax: | ||||
|                     nmax = n | ||||
|                 if verbosity: | ||||
|                     msg = 'plotting %s channel of station %s' % (channel, station) | ||||
|                     print(msg) | ||||
|                 stime = trace.stats.starttime - wfstart | ||||
|                 time_ax = prep_time_axis(stime, trace) | ||||
|                     print(f'plotting {channel} channel of station {station}') | ||||
|                 time_ax = prep_time_axis(trace.stats.starttime - wfstart, trace) | ||||
|                 self.plot_trace(ax, trace, time_ax, wf_dict, n, scaleToChannel, noiselevel, scaleddata, nth_sample) | ||||
|                 self.setPlotDict(n, seed_id) | ||||
|         return nmax | ||||
| 
 | ||||
|     def plot_trace(self, ax, trace, time_ax, wf_dict, n, scaleToChannel, noiselevel, scaleddata, nth_sample): | ||||
|         if time_ax is not None: | ||||
|             if scaleToChannel: | ||||
|                         st_scale = wfdata.select(channel=scaleToChannel) | ||||
|                         if st_scale: | ||||
|                             tr = st_scale[0] | ||||
|                             trace.detrend('constant') | ||||
|                             trace.normalize(np.max(np.abs(tr.data)) * 2) | ||||
|                 self.scale_trace(trace, scaleToChannel) | ||||
|                 scaleddata = True | ||||
|             if not scaleddata: | ||||
|                 trace.detrend('constant') | ||||
|                 trace.normalize(np.max(np.abs(trace.data)) * 2) | ||||
| 
 | ||||
|             offset = wf_dict.get('offset') | ||||
| 
 | ||||
|             times = [time for index, time in enumerate(time_ax) if not index % nth_sample] | ||||
| @ -1464,41 +1454,43 @@ class PylotCanvas(FigureCanvas): | ||||
|             ax.axhline(n, color="0.5", lw=0.5) | ||||
|             ax.plot(times, data, color=wf_dict.get('linecolor'), **wf_dict.get('plot_kwargs')) | ||||
|             if noiselevel is not None: | ||||
|                         for level in [-noiselevel[channel], noiselevel[channel]]: | ||||
|                             ax.plot([time_ax[0], time_ax[-1]], | ||||
|                                     [n + level, n + level], | ||||
|                                     color=wf_dict.get('linecolor'), | ||||
|                                     linestyle='dashed') | ||||
|                     self.setPlotDict(n, seed_id) | ||||
|         if plot_additional and additional_channel: | ||||
|             compare_stream = wfdata.select(channel=additional_channel) | ||||
|             if compare_stream: | ||||
|                 trace = compare_stream[0] | ||||
|                 if scaleToChannel: | ||||
|                 self.plot_noise_level(ax, time_ax, noiselevel, channel, n, wf_dict.get('linecolor')) | ||||
| 
 | ||||
|     def scale_trace(self, trace, scaleToChannel): | ||||
|         st_scale = wfdata.select(channel=scaleToChannel) | ||||
|         if st_scale: | ||||
|             tr = st_scale[0] | ||||
|             trace.detrend('constant') | ||||
|             trace.normalize(np.max(np.abs(tr.data)) * 2) | ||||
| 
 | ||||
|     def plot_noise_level(self, ax, time_ax, noiselevel, channel, n, linecolor): | ||||
|         for level in [-noiselevel[channel], noiselevel[channel]]: | ||||
|             ax.plot([time_ax[0], time_ax[-1]], [n + level, n + level], color=linecolor, linestyle='dashed') | ||||
| 
 | ||||
|     def plot_additional_trace(self, ax, wfdata, additional_channel, scaleToChannel, scaleddata, nth_sample, wfstart): | ||||
|         compare_stream = wfdata.select(channel=additional_channel) | ||||
|         if compare_stream: | ||||
|             trace = compare_stream[0] | ||||
|             if scaleToChannel: | ||||
|                 self.scale_trace(trace, scaleToChannel) | ||||
|                 scaleddata = True | ||||
|             if not scaleddata: | ||||
|                 trace.detrend('constant') | ||||
|                 trace.normalize(np.max(np.abs(trace.data)) * 2) | ||||
|                 time_ax = prep_time_axis(stime, trace) | ||||
|             time_ax = prep_time_axis(trace.stats.starttime - wfstart, trace) | ||||
|             self.plot_additional_data(ax, trace, time_ax, nth_sample) | ||||
| 
 | ||||
|     def plot_additional_data(self, ax, trace, time_ax, nth_sample): | ||||
|         times = [time for index, time in enumerate(time_ax) if not index % nth_sample] | ||||
|                 p_data = compare_stream[0].data | ||||
|                 # #normalize | ||||
|                 # p_max = max(abs(p_data)) | ||||
|                 # p_data /= p_max | ||||
|         p_data = trace.data | ||||
|         for index in range(3): | ||||
|             ax.plot(times, p_data, color='red', alpha=0.5, linewidth=0.7) | ||||
|             p_data += 1 | ||||
| 
 | ||||
|     def finalize_plot(self, ax, wfstart, wfend, nmax, zoomx, zoomy, iniPick, title, snr): | ||||
|         if iniPick: | ||||
|             ax.vlines(iniPick, ax.get_ylim()[0], ax.get_ylim()[1], | ||||
|                       colors='m', linestyles='dashed', | ||||
|                       linewidth=2) | ||||
|         xlabel = 'seconds since {0}'.format(wfstart) | ||||
|             ax.vlines(iniPick, ax.get_ylim()[0], ax.get_ylim()[1], colors='m', linestyles='dashed', linewidth=2) | ||||
|         xlabel = f'seconds since {wfstart}' | ||||
|         ylabel = '' | ||||
|         self.updateWidget(xlabel, ylabel, title) | ||||
|         self.setXLims(ax, [0, wfend - wfstart]) | ||||
| @ -1508,15 +1500,14 @@ class PylotCanvas(FigureCanvas): | ||||
|         if zoomy is not None: | ||||
|             self.setYLims(ax, zoomy) | ||||
|         if snr is not None: | ||||
|             if snr < 2: | ||||
|                 warning = 'LOW SNR' | ||||
|                 if snr < 1.5: | ||||
|                     warning = 'VERY LOW SNR' | ||||
|                 ax.text(0.1, 0.9, 'WARNING - {}'.format(warning), ha='center', va='center', transform=ax.transAxes, | ||||
|                         color='red') | ||||
| 
 | ||||
|             self.plot_snr_warning(ax, snr) | ||||
|         self.draw() | ||||
| 
 | ||||
|     def plot_snr_warning(self, ax, snr): | ||||
|         if snr < 2: | ||||
|             warning = 'LOW SNR' if snr >= 1.5 else 'VERY LOW SNR' | ||||
|             ax.text(0.1, 0.9, f'WARNING - {warning}', ha='center', va='center', transform=ax.transAxes, color='red') | ||||
| 
 | ||||
|     @staticmethod | ||||
|     def getXLims(ax): | ||||
|         return ax.get_xlim() | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user