survBot/utils.py
2022-12-21 16:03:10 +01:00

217 lines
6.9 KiB
Python

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import matplotlib
import numpy as np
from obspy import Stream
def get_bg_color(check_key, status, dt_thresh=None, hex=False):
    """
    Choose the background colour for one status cell.

    :param check_key: name of the check; 'last active', 'temp' and 'mass' get
        dedicated colour functions, every other key is coloured by status
    :param status: object providing ``message``, ``is_warn``, ``is_error`` and
        ``count`` attributes
    :param dt_thresh: delay thresholds, only used for the 'last active' check
    :param hex: if True return a '#rrggbb' string built from the first three
        colour components instead of the colour itself
    :return: colour (RGBA sequence) or hex string; falls back to the
        'undefined' colour whenever the chosen colour is falsy
    """
    message = status.message

    if check_key == 'last active':
        color = get_time_delay_color(message, dt_thresh)
    elif check_key == 'temp':
        color = get_temp_color(message)
    elif check_key == 'mass':
        color = get_mass_color(message)
    elif status.is_warn:
        # warnings get progressively redder with their repeat count
        color = get_warn_color(status.count)
    elif status.is_error:
        color = get_color('FAIL')
    else:
        color = get_color(message)

    # unknown messages make get_color return None -> use neutral colour
    color = color or get_color('undefined')

    if not hex:
        return color
    return '#{:02x}{:02x}{:02x}'.format(*color[:3])
def get_color(key):
    """
    Look up one of the default GUI colours by name.

    :param key: one of 'FAIL', 'NO DATA', 'WARN', 'OK', 'undefined'
    :return: RGBA tuple with components in 0..255, or None for an unknown key
    """
    palette = dict([
        ('FAIL', (255, 50, 0, 255)),
        ('NO DATA', (255, 255, 125, 255)),
        ('WARN', (255, 255, 80, 255)),
        ('OK', (125, 255, 125, 255)),
        ('undefined', (230, 230, 230, 255)),
    ])
    return palette.get(key)
def get_color_mpl(key):
    """
    Return the GUI colour ``key`` as a float array scaled to 0..1 (matplotlib style).

    :param key: colour name understood by get_color
    :return: 1-D float numpy array with the RGBA components divided by 255
    """
    rgba_255 = get_color(key)
    return np.fromiter((component / 255. for component in rgba_255), dtype=float)
def get_time_delay_color(dt, dt_thresh):
    """
    Map a time delay onto a traffic-light colour using the given thresholds.

    :param dt: time delay to classify (compared directly against the thresholds)
    :param dt_thresh: indexable with two ascending thresholds ``(warn, fail)``
    :return: 'OK' colour if ``dt < dt_thresh[0]``, 'WARN' colour if
        ``dt < dt_thresh[1]``, otherwise the 'FAIL' colour (RGBA tuples from get_color)
    """
    if dt < dt_thresh[0]:
        return get_color('OK')
    # reaching here means dt >= dt_thresh[0]; the fail threshold is only read when needed
    if dt < dt_thresh[1]:
        return get_color('WARN')
    return get_color('FAIL')
def get_warn_color(count):
    """
    Return an RGBA warning colour that drifts towards red as warnings repeat.

    The red component is ``200 + count ** 2`` capped at 255; green, blue and
    alpha are fixed at 255, 80 and 255.
    """
    candidate_red = 200 + count ** 2
    red = candidate_red if candidate_red < 255 else 255
    return red, 255, 80, 255
def get_mass_color(message):
    """
    Colour for the mass check.

    Numeric readings are always shown as 'OK' (no warn-count shading); any other
    message is used directly as a get_color key.
    """
    is_numeric_reading = isinstance(message, (float, int))
    return get_color('OK' if is_numeric_reading else message)
def get_temp_color(temp, vmin=-10, vmax=60, cmap='coolwarm'):
    """
    Get an RGBA colour for a temperature from a matplotlib colormap.

    The temperature is normalised linearly so that ``vmin`` maps to 0 and ``vmax``
    to 1 before being looked up in the colormap; values outside that range are
    handled by the colormap's under/over colours.

    :param temp: temperature value; a string or None yields the 'undefined' colour
    :param vmin: temperature mapped to the low end of the colormap
    :param vmax: temperature mapped to the high end of the colormap
    :param cmap: colormap name (a matplotlib Colormap instance is accepted as well)
    :return: list of four ints in 0..255 (RGBA), or the 'undefined' colour tuple
    """
    if temp is None or isinstance(temp, str):
        return get_color('undefined')

    # matplotlib.cm.get_cmap was deprecated in 3.7 and removed in 3.9; prefer the
    # global colormap registry and fall back for older matplotlib versions.
    registry = getattr(matplotlib, 'colormaps', None)
    if registry is not None and hasattr(registry, 'get_cmap'):
        colormap = registry.get_cmap(cmap)
    else:
        colormap = matplotlib.cm.get_cmap(cmap)

    val = (temp - vmin) / (vmax - vmin)
    rgba = [int(255 * c) for c in colormap(val)]
    return rgba
def modify_stream_for_plot(input_stream, parameters):
    """
    Build a plot-ready copy of ``input_stream`` following the channel configuration.

    For each entry of ``parameters['CHANNELS']`` (in dict order) the single trace
    with that channel code is copied, multiplied by its optional 'unit' factor,
    passed through its optional 'transform' list and renamed to
    '<position>: <name> - <original id>' so the plot keeps the configured order.
    Channels matching zero or several traces are skipped. The input stream itself
    is never modified.
    """
    plot_stream = Stream()
    channel_config = parameters.get('CHANNELS')

    for position, (channel, channel_cfg) in enumerate(channel_config.items(), start=1):
        matches = input_stream.select(channel=channel)
        # zero or multiple matches would be ambiguous -> leave this channel out
        if len(matches) != 1:
            continue

        trace = matches[0].copy()  # work on a copy so the caller's stream is untouched

        unit_factor = channel_cfg.get('unit')
        if unit_factor:
            trace.data = trace.data * float(unit_factor)

        transform = channel_cfg.get('transform')
        if transform:
            trace.data = transform_trace(trace.data, transform)

        label = channel_cfg.get('name')
        trace.id = f'{position}: {label} - {trace.id}'
        plot_stream.append(trace)

    return plot_stream
def transform_trace(data, transf):
    """
    Apply a sequence of arithmetic operations to trace data, in order.

    A fixed operator table is used instead of ``eval`` so configuration values
    can never execute arbitrary code.

    @param data: numpy array with the trace samples (not modified in place)
    @param transf: iterable of (operator, value) pairs, e.g. [['*', '20'], ['+', 1]]
        multiplies the data by 20 and then adds 1. The operator is one of
        '+', '-', '*', '/'; the value may be a number or a numeric string.
        None or an empty list leaves the data unchanged.
    @return: transformed data
    @raise IOError: if an operator string is not one of the four supported ones
    @raise ValueError: if a string value cannot be converted to float
    """
    import operator

    operations = {
        '+': operator.add,
        '-': operator.sub,
        '*': operator.mul,
        '/': operator.truediv,
    }
    for operator_str, val in transf or ():
        operation = operations.get(operator_str)
        if operation is None:
            raise IOError(f'Unknown arithmethic operator string: {operator_str}')
        if isinstance(val, str):
            # values read from configuration (see docstring example) may be strings;
            # numpy cannot multiply/add a float array by a str
            val = float(val)
        data = operation(data, val)
    return data
def set_axis_ylabels(fig, parameters, verbosity=0):
    """
    Label each y-axis of ``fig`` with the configured channel name.

    Names are taken from ``parameters['CHANNELS']`` in dict order and paired with
    ``fig.axes`` by position. Nothing happens if no channels are configured or if
    the channel count differs from the axis count; channels without a name leave
    their axis label untouched.
    """
    labels = [cfg.get('name') for cfg in parameters.get('CHANNELS').values()]
    if not labels:
        return
    if len(labels) != len(fig.axes):
        if verbosity:
            print('Mismatch in axis and label lengths. Not adding plot labels')
        return
    for axis, label in zip(fig.axes, labels):
        if label:
            axis.set_ylabel(label)
def set_axis_color(fig, color='grey'):
    """
    Paint the four spines (bottom, top, right, left) of every axis in ``fig``.

    :param fig: figure whose ``axes`` are modified in place
    :param color: any colour specification accepted by the spines' set_color
    """
    spine_names = ('bottom', 'top', 'right', 'left')
    for axis in fig.axes:
        for spine in (axis.spines[name] for name in spine_names):
            spine.set_color(color)
def set_axis_yticks(fig, parameters, verbosity=0):
    """
    Set y-axis ticks and limits from the 'ticks' entry of each configured channel.

    Each channel in ``parameters['CHANNELS']`` may define ``ticks`` as
    ``(ymin, ymax, step)``. Ticks are placed at ``ymin, ymin + step, ...`` up to
    and including ``ymax`` (never beyond it). The y-limits are padded by a third
    of a step on both sides. Channels are paired with ``fig.axes`` by position;
    nothing is changed when the channel and axis counts differ.
    """
    ticks = [channel.get('ticks') for channel in parameters.get('CHANNELS').values()]
    if not ticks:
        return
    if len(ticks) != len(fig.axes):
        if verbosity:
            print('Mismatch in axis tick and label lengths. Not changing plot ticks.')
        return
    for ytick_tripple, ax in zip(ticks, fig.axes):
        if not ytick_tripple:
            continue
        ymin, ymax, step = ytick_tripple
        # np.arange(ymin, ymax + step, step) can overshoot ymax with float steps
        # (rounding in the length computation); use an explicit tick count instead.
        # The small tolerance keeps ymax itself when the ratio is off by rounding.
        n_ticks = max(int(np.floor((ymax - ymin) / step + 1e-9)) + 1, 0)
        yticks = list(ymin + step * np.arange(n_ticks))
        ax.set_yticks(yticks)
        ax.set_ylim(ymin - 0.33 * step, ymax + 0.33 * step)
def plot_axis_thresholds(fig, parameters, verbosity=0):
    """
    Draw the 'warn' and 'fail' threshold lines configured per channel.

    Warn lines are dashed and fail lines solid, both in a darkened version of the
    corresponding GUI colour. The per-channel entries of ``parameters['CHANNELS']``
    are handed to plot_threshold_lines, which pairs them with the figure's axes.
    """
    if verbosity > 0:
        print('Plotting trace thresholds')

    shared_style = dict(alpha=0.5, linewidth=0.7)
    threshold_styles = (
        ('warn', dict(color=0.8 * get_color_mpl('WARN'), linestyle=(0, (5, 10)), **shared_style)),
        ('fail', dict(color=0.8 * get_color_mpl('FAIL'), linestyle='solid', **shared_style)),
    )

    for threshold_key, line_kwargs in threshold_styles:
        per_channel = [cfg.get(threshold_key) for cfg in parameters.get('CHANNELS').values()]
        if per_channel:
            plot_threshold_lines(fig, per_channel, parameters, **line_kwargs)
def plot_threshold_lines(fig, channel_threshold_list, parameters, **kwargs):
    """
    Draw horizontal threshold lines on the axes of ``fig``.

    ``channel_threshold_list`` is paired with ``fig.axes`` by position (surplus
    entries on either side are ignored). Each entry may be None (no lines), a
    single threshold, or a list/tuple of thresholds. A threshold is either a number
    or a string key looked up in ``parameters['THRESHOLDS']``; keys that cannot be
    resolved to a number are skipped.

    :param kwargs: forwarded unchanged to ``ax.axhline`` (colour, line style, ...)
    """
    named_thresholds = parameters.get('THRESHOLDS') or {}
    for channel_thresholds, ax in zip(channel_threshold_list, fig.axes):
        # test for None explicitly: a scalar threshold of 0 is a valid line
        if channel_thresholds is None:
            continue
        if not isinstance(channel_thresholds, (list, tuple)):
            channel_thresholds = [channel_thresholds]
        for threshold in channel_thresholds:
            if isinstance(threshold, str):
                threshold = named_thresholds.get(threshold)
            # only real numbers are drawn; unresolved names (None) and bools are skipped
            is_number = (isinstance(threshold, (int, float, np.integer, np.floating))
                         and not isinstance(threshold, bool))
            if is_number:
                ax.axhline(threshold, **kwargs)