refactor: rewrite plotWFData to reduce complexity
This commit is contained in:
		
							parent
							
								
									090bfb47d1
								
							
						
					
					
						commit
						a2add3034a
					
				| @ -140,18 +140,6 @@ class LogWidget(QtWidgets.QWidget): | |||||||
|         self.stderr.append(60 * '#' + '\n\n') |         self.stderr.append(60 * '#' + '\n\n') | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def getDataType(parent): |  | ||||||
|     type = QInputDialog().getItem(parent, "Select phases type", "Type:", |  | ||||||
|                                   ["manual", "automatic"]) |  | ||||||
| 
 |  | ||||||
|     if type[0].startswith('auto'): |  | ||||||
|         type = 'auto' |  | ||||||
|     else: |  | ||||||
|         type = type[0] |  | ||||||
| 
 |  | ||||||
|     return type |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| def plot_pdf(_axes, x, y, annotation, bbox_props, xlabel=None, ylabel=None, | def plot_pdf(_axes, x, y, annotation, bbox_props, xlabel=None, ylabel=None, | ||||||
|              title=None): |              title=None): | ||||||
|     # try method or data |     # try method or data | ||||||
| @ -1009,15 +997,6 @@ class WaveformWidgetPG(QtWidgets.QWidget): | |||||||
|         time_ax = np.linspace(time_ax[0], time_ax[-1], num=len(data)) |         time_ax = np.linspace(time_ax[0], time_ax[-1], num=len(data)) | ||||||
|         return data, time_ax |         return data, time_ax | ||||||
| 
 | 
 | ||||||
|     # def getAxes(self): |  | ||||||
|     #     return self.axes |  | ||||||
| 
 |  | ||||||
|     # def getXLims(self): |  | ||||||
|     #     return self.getAxes().get_xlim() |  | ||||||
| 
 |  | ||||||
|     # def getYLims(self): |  | ||||||
|     #     return self.getAxes().get_ylim() |  | ||||||
| 
 |  | ||||||
|     def setXLims(self, lims): |     def setXLims(self, lims): | ||||||
|         vb = self.plotWidget.getPlotItem().getViewBox() |         vb = self.plotWidget.getPlotItem().getViewBox() | ||||||
|         vb.setXRange(float(lims[0]), float(lims[1]), padding=0) |         vb.setXRange(float(lims[0]), float(lims[1]), padding=0) | ||||||
| @ -1171,8 +1150,6 @@ class PylotCanvas(FigureCanvas): | |||||||
|                 break |                 break | ||||||
|         if not ax_check: return |         if not ax_check: return | ||||||
| 
 | 
 | ||||||
|         # self.updateCurrentLimits() #maybe put this down to else: |  | ||||||
| 
 |  | ||||||
|         # calculate delta (relative values in axis) |         # calculate delta (relative values in axis) | ||||||
|         old_x, old_y = self.press_rel |         old_x, old_y = self.press_rel | ||||||
|         xdiff = gui_event.x - old_x |         xdiff = gui_event.x - old_x | ||||||
| @ -1380,125 +1357,140 @@ class PylotCanvas(FigureCanvas): | |||||||
|                    component='*', nth_sample=1, iniPick=None, verbosity=0, |                    component='*', nth_sample=1, iniPick=None, verbosity=0, | ||||||
|                    plot_additional=False, additional_channel=None, scaleToChannel=None, |                    plot_additional=False, additional_channel=None, scaleToChannel=None, | ||||||
|                    snr=None): |                    snr=None): | ||||||
|         def get_wf_dict(data: Stream = Stream(), linecolor = 'k', offset: float = 0., **plot_kwargs): |         ax = self.prepare_plot() | ||||||
|             return dict(data=data, linecolor=linecolor, offset=offset, plot_kwargs=plot_kwargs) |         self.clearPlotDict() | ||||||
| 
 | 
 | ||||||
|  |         wfstart, wfend = self.get_wf_range(wfdata) | ||||||
|  |         compclass = self.get_comp_class() | ||||||
|  |         plot_streams = self.get_plot_streams(wfdata, wfdata_compare, component, compclass) | ||||||
| 
 | 
 | ||||||
|  |         st_main = plot_streams['wfdata']['data'] | ||||||
|  |         if mapping: | ||||||
|  |             plot_positions = self.calcPlotPositions(st_main, compclass) | ||||||
|  | 
 | ||||||
|  |         nslc = self.get_sorted_nslc(st_main) | ||||||
|  |         nmax = self.plot_traces(ax, plot_streams, nslc, wfstart, mapping, plot_positions, | ||||||
|  |                                 scaleToChannel, noiselevel, scaleddata, nth_sample, verbosity) | ||||||
|  | 
 | ||||||
|  |         if plot_additional and additional_channel: | ||||||
|  |             self.plot_additional_trace(ax, wfdata, additional_channel, scaleToChannel, | ||||||
|  |                                        scaleddata, nth_sample, wfstart) | ||||||
|  | 
 | ||||||
|  |         self.finalize_plot(ax, wfstart, wfend, nmax, zoomx, zoomy, iniPick, title, snr) | ||||||
|  | 
 | ||||||
|  |     def prepare_plot(self): | ||||||
|         ax = self.axes[0] |         ax = self.axes[0] | ||||||
|         ax.cla() |         ax.cla() | ||||||
|  |         return ax | ||||||
| 
 | 
 | ||||||
|         self.clearPlotDict() |     def get_wf_range(self, wfdata): | ||||||
|         wfstart, wfend = full_range(wfdata) |         return full_range(wfdata) | ||||||
|         nmax = 0 |  | ||||||
| 
 | 
 | ||||||
|  |     def get_comp_class(self): | ||||||
|         settings = QSettings() |         settings = QSettings() | ||||||
|         compclass = SetChannelComponents.from_qsettings(settings) |         return SetChannelComponents.from_qsettings(settings) | ||||||
|  | 
 | ||||||
|  |     def get_plot_streams(self, wfdata, wfdata_compare, component, compclass): | ||||||
|  |         def get_wf_dict(data=Stream(), linecolor='k', offset=0., **plot_kwargs): | ||||||
|  |             return dict(data=data, linecolor=linecolor, offset=offset, plot_kwargs=plot_kwargs) | ||||||
| 
 | 
 | ||||||
|         linecolor = (0., 0., 0., 1.) if not self.style else self.style['linecolor']['rgba_mpl'] |         linecolor = (0., 0., 0., 1.) if not self.style else self.style['linecolor']['rgba_mpl'] | ||||||
|  |         plot_streams = { | ||||||
|  |             'wfdata': get_wf_dict(linecolor=linecolor, linewidth=0.7), | ||||||
|  |             'wfdata_comp': get_wf_dict(offset=0.1, linecolor='b', alpha=0.7, linewidth=0.5) | ||||||
|  |         } | ||||||
| 
 | 
 | ||||||
|         plot_streams = dict(wfdata=get_wf_dict(linecolor=linecolor, linewidth=0.7), |         if component != '*': | ||||||
|                             wfdata_comp=get_wf_dict(offset=0.1, linecolor='b', alpha=0.7, linewidth=0.5)) |  | ||||||
| 
 |  | ||||||
|         if not component == '*': |  | ||||||
|             alter_comp = compclass.getCompPosition(component) |             alter_comp = compclass.getCompPosition(component) | ||||||
|             # alter_comp = str(alter_comp[0]) |             plot_streams['wfdata']['data'] = wfdata.select(component=component) + wfdata.select(component=alter_comp) | ||||||
| 
 |  | ||||||
|             plot_streams['wfdata']['data'] = wfdata.select(component=component) |  | ||||||
|             plot_streams['wfdata']['data'] += wfdata.select(component=alter_comp) |  | ||||||
|             if wfdata_compare: |             if wfdata_compare: | ||||||
|                 plot_streams['wfdata_comp']['data'] = wfdata_compare.select(component=component) |                 plot_streams['wfdata_comp']['data'] = wfdata_compare.select( | ||||||
|                 plot_streams['wfdata_comp']['data'] += wfdata_compare.select(component=alter_comp) |                     component=component) + wfdata_compare.select(component=alter_comp) | ||||||
|         else: |         else: | ||||||
|             plot_streams['wfdata']['data'] = wfdata |             plot_streams['wfdata']['data'] = wfdata | ||||||
|             if wfdata_compare: |             if wfdata_compare: | ||||||
|                 plot_streams['wfdata_comp']['data'] = wfdata_compare |                 plot_streams['wfdata_comp']['data'] = wfdata_compare | ||||||
| 
 | 
 | ||||||
|         st_main = plot_streams['wfdata']['data'] |         return plot_streams | ||||||
| 
 | 
 | ||||||
|         if mapping: |     def get_sorted_nslc(self, st_main): | ||||||
|             plot_positions = self.calcPlotPositions(st_main, compclass) |         nslc = [trace.get_id() for trace in st_main if trace.stats.channel[-1] in ['Z', 'N', 'E', '1', '2', '3']] | ||||||
| 
 |         nslc.sort(reverse=True) | ||||||
|         # list containing tuples of network, station, channel and plot position (for sorting) |         return nslc | ||||||
|         nslc = [] |  | ||||||
|         for plot_pos, trace in enumerate(st_main): |  | ||||||
|             if not trace.stats.channel[-1] in ['Z', 'N', 'E', '1', '2', '3']: |  | ||||||
|                 print('Warning: Unrecognized channel {}'.format(trace.stats.channel)) |  | ||||||
|                 continue |  | ||||||
|             nslc.append(trace.get_id()) |  | ||||||
|         nslc.sort() |  | ||||||
|         nslc.reverse() |  | ||||||
| 
 | 
 | ||||||
|  |     def plot_traces(self, ax, plot_streams, nslc, wfstart, mapping, plot_positions, scaleToChannel, noiselevel, | ||||||
|  |                     scaleddata, nth_sample, verbosity): | ||||||
|  |         nmax = 0 | ||||||
|         for n, seed_id in enumerate(nslc): |         for n, seed_id in enumerate(nslc): | ||||||
|             network, station, location, channel = seed_id.split('.') |             network, station, location, channel = seed_id.split('.') | ||||||
|             for wf_name, wf_dict in plot_streams.items(): |             for wf_name, wf_dict in plot_streams.items(): | ||||||
|                 st_select = wf_dict.get('data') |                 st_select = wf_dict.get('data') | ||||||
|                 if not st_select: |                 if not st_select: | ||||||
|                     continue |                     continue | ||||||
|                 st = st_select.select(id=seed_id) |                 trace = st_select.select(id=seed_id)[0].copy() | ||||||
|                 trace = st[0].copy() |  | ||||||
|                 if mapping: |                 if mapping: | ||||||
|                     n = plot_positions[trace.stats.channel] |                     n = plot_positions[trace.stats.channel] | ||||||
|                 if n > nmax: |                 if n > nmax: | ||||||
|                     nmax = n |                     nmax = n | ||||||
|                 if verbosity: |                 if verbosity: | ||||||
|                     msg = 'plotting %s channel of station %s' % (channel, station) |                     print(f'plotting {channel} channel of station {station}') | ||||||
|                     print(msg) |                 time_ax = prep_time_axis(trace.stats.starttime - wfstart, trace) | ||||||
|                 stime = trace.stats.starttime - wfstart |                 self.plot_trace(ax, trace, time_ax, wf_dict, n, scaleToChannel, noiselevel, scaleddata, nth_sample) | ||||||
|                 time_ax = prep_time_axis(stime, trace) |                 self.setPlotDict(n, seed_id) | ||||||
|                 if time_ax is not None: |         return nmax | ||||||
|                     if scaleToChannel: |  | ||||||
|                         st_scale = wfdata.select(channel=scaleToChannel) |  | ||||||
|                         if st_scale: |  | ||||||
|                             tr = st_scale[0] |  | ||||||
|                             trace.detrend('constant') |  | ||||||
|                             trace.normalize(np.max(np.abs(tr.data)) * 2) |  | ||||||
|                             scaleddata = True |  | ||||||
|                     if not scaleddata: |  | ||||||
|                         trace.detrend('constant') |  | ||||||
|                         trace.normalize(np.max(np.abs(trace.data)) * 2) |  | ||||||
| 
 | 
 | ||||||
|                     offset = wf_dict.get('offset') |     def plot_trace(self, ax, trace, time_ax, wf_dict, n, scaleToChannel, noiselevel, scaleddata, nth_sample): | ||||||
|  |         if time_ax is not None: | ||||||
|  |             if scaleToChannel: | ||||||
|  |                 self.scale_trace(trace, scaleToChannel) | ||||||
|  |                 scaleddata = True | ||||||
|  |             if not scaleddata: | ||||||
|  |                 trace.detrend('constant') | ||||||
|  |                 trace.normalize(np.max(np.abs(trace.data)) * 2) | ||||||
|  |             offset = wf_dict.get('offset') | ||||||
| 
 | 
 | ||||||
|                     times = [time for index, time in enumerate(time_ax) if not index % nth_sample] |             times = [time for index, time in enumerate(time_ax) if not index % nth_sample] | ||||||
|                     data = [datum + n + offset for index, datum in enumerate(trace.data) if not index % nth_sample] |             data = [datum + n + offset for index, datum in enumerate(trace.data) if not index % nth_sample] | ||||||
|                     ax.axhline(n, color="0.5", lw=0.5) |             ax.axhline(n, color="0.5", lw=0.5) | ||||||
|                     ax.plot(times, data, color=wf_dict.get('linecolor'), **wf_dict.get('plot_kwargs')) |             ax.plot(times, data, color=wf_dict.get('linecolor'), **wf_dict.get('plot_kwargs')) | ||||||
|                     if noiselevel is not None: |             if noiselevel is not None: | ||||||
|                         for level in [-noiselevel[channel], noiselevel[channel]]: |                 self.plot_noise_level(ax, time_ax, noiselevel, channel, n, wf_dict.get('linecolor')) | ||||||
|                             ax.plot([time_ax[0], time_ax[-1]], |  | ||||||
|                                     [n + level, n + level], |  | ||||||
|                                     color=wf_dict.get('linecolor'), |  | ||||||
|                                     linestyle='dashed') |  | ||||||
|                     self.setPlotDict(n, seed_id) |  | ||||||
|         if plot_additional and additional_channel: |  | ||||||
|             compare_stream = wfdata.select(channel=additional_channel) |  | ||||||
|             if compare_stream: |  | ||||||
|                 trace = compare_stream[0] |  | ||||||
|                 if scaleToChannel: |  | ||||||
|                     st_scale = wfdata.select(channel=scaleToChannel) |  | ||||||
|                     if st_scale: |  | ||||||
|                         tr = st_scale[0] |  | ||||||
|                         trace.detrend('constant') |  | ||||||
|                         trace.normalize(np.max(np.abs(tr.data)) * 2) |  | ||||||
|                         scaleddata = True |  | ||||||
|                 if not scaleddata: |  | ||||||
|                     trace.detrend('constant') |  | ||||||
|                     trace.normalize(np.max(np.abs(trace.data)) * 2) |  | ||||||
|                 time_ax = prep_time_axis(stime, trace) |  | ||||||
|                 times = [time for index, time in enumerate(time_ax) if not index % nth_sample] |  | ||||||
|                 p_data = compare_stream[0].data |  | ||||||
|                 # #normalize |  | ||||||
|                 # p_max = max(abs(p_data)) |  | ||||||
|                 # p_data /= p_max |  | ||||||
|                 for index in range(3): |  | ||||||
|                     ax.plot(times, p_data, color='red', alpha=0.5, linewidth=0.7) |  | ||||||
|                     p_data += 1 |  | ||||||
| 
 | 
 | ||||||
|  |     def scale_trace(self, trace, scaleToChannel): | ||||||
|  |         st_scale = wfdata.select(channel=scaleToChannel) | ||||||
|  |         if st_scale: | ||||||
|  |             tr = st_scale[0] | ||||||
|  |             trace.detrend('constant') | ||||||
|  |             trace.normalize(np.max(np.abs(tr.data)) * 2) | ||||||
|  | 
 | ||||||
|  |     def plot_noise_level(self, ax, time_ax, noiselevel, channel, n, linecolor): | ||||||
|  |         for level in [-noiselevel[channel], noiselevel[channel]]: | ||||||
|  |             ax.plot([time_ax[0], time_ax[-1]], [n + level, n + level], color=linecolor, linestyle='dashed') | ||||||
|  | 
 | ||||||
|  |     def plot_additional_trace(self, ax, wfdata, additional_channel, scaleToChannel, scaleddata, nth_sample, wfstart): | ||||||
|  |         compare_stream = wfdata.select(channel=additional_channel) | ||||||
|  |         if compare_stream: | ||||||
|  |             trace = compare_stream[0] | ||||||
|  |             if scaleToChannel: | ||||||
|  |                 self.scale_trace(trace, scaleToChannel) | ||||||
|  |                 scaleddata = True | ||||||
|  |             if not scaleddata: | ||||||
|  |                 trace.detrend('constant') | ||||||
|  |                 trace.normalize(np.max(np.abs(trace.data)) * 2) | ||||||
|  |             time_ax = prep_time_axis(trace.stats.starttime - wfstart, trace) | ||||||
|  |             self.plot_additional_data(ax, trace, time_ax, nth_sample) | ||||||
|  | 
 | ||||||
|  |     def plot_additional_data(self, ax, trace, time_ax, nth_sample): | ||||||
|  |         times = [time for index, time in enumerate(time_ax) if not index % nth_sample] | ||||||
|  |         p_data = trace.data | ||||||
|  |         for index in range(3): | ||||||
|  |             ax.plot(times, p_data, color='red', alpha=0.5, linewidth=0.7) | ||||||
|  |             p_data += 1 | ||||||
|  | 
 | ||||||
|  |     def finalize_plot(self, ax, wfstart, wfend, nmax, zoomx, zoomy, iniPick, title, snr): | ||||||
|         if iniPick: |         if iniPick: | ||||||
|             ax.vlines(iniPick, ax.get_ylim()[0], ax.get_ylim()[1], |             ax.vlines(iniPick, ax.get_ylim()[0], ax.get_ylim()[1], colors='m', linestyles='dashed', linewidth=2) | ||||||
|                       colors='m', linestyles='dashed', |         xlabel = f'seconds since {wfstart}' | ||||||
|                       linewidth=2) |  | ||||||
|         xlabel = 'seconds since {0}'.format(wfstart) |  | ||||||
|         ylabel = '' |         ylabel = '' | ||||||
|         self.updateWidget(xlabel, ylabel, title) |         self.updateWidget(xlabel, ylabel, title) | ||||||
|         self.setXLims(ax, [0, wfend - wfstart]) |         self.setXLims(ax, [0, wfend - wfstart]) | ||||||
| @ -1508,15 +1500,14 @@ class PylotCanvas(FigureCanvas): | |||||||
|         if zoomy is not None: |         if zoomy is not None: | ||||||
|             self.setYLims(ax, zoomy) |             self.setYLims(ax, zoomy) | ||||||
|         if snr is not None: |         if snr is not None: | ||||||
|             if snr < 2: |             self.plot_snr_warning(ax, snr) | ||||||
|                 warning = 'LOW SNR' |  | ||||||
|                 if snr < 1.5: |  | ||||||
|                     warning = 'VERY LOW SNR' |  | ||||||
|                 ax.text(0.1, 0.9, 'WARNING - {}'.format(warning), ha='center', va='center', transform=ax.transAxes, |  | ||||||
|                         color='red') |  | ||||||
| 
 |  | ||||||
|         self.draw() |         self.draw() | ||||||
| 
 | 
 | ||||||
|  |     def plot_snr_warning(self, ax, snr): | ||||||
|  |         if snr < 2: | ||||||
|  |             warning = 'LOW SNR' if snr >= 1.5 else 'VERY LOW SNR' | ||||||
|  |             ax.text(0.1, 0.9, f'WARNING - {warning}', ha='center', va='center', transform=ax.transAxes, color='red') | ||||||
|  | 
 | ||||||
|     @staticmethod |     @staticmethod | ||||||
|     def getXLims(ax): |     def getXLims(ax): | ||||||
|         return ax.get_xlim() |         return ax.get_xlim() | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user