#!/usr/bin/env python
# -*- coding: utf-8 -*-

import logging
import numbers
import operator

import matplotlib
import numpy as np

from obspy import Stream


# RGBA colours (0-255 per channel) for every known status key.
# get_color() falls back to 'undefined' for keys not listed here;
# 'disc' is used by get_bg_color() for connection errors.
COLORS_DICT = {
    'FAIL': (195, 29, 14, 255),
    'NO DATA': (255, 255, 125, 255),
    'WARN': (250, 192, 63, 255),
    'OK': (185, 245, 145, 255),
    'undefined': (240, 240, 240, 255),
    'disc': (126, 127, 131, 255),
}


def get_bg_color(check_key, status, dt_thresh=None, hex=False):
    """
    Return the background colour for one status cell.

    'last active', 'temp' and 'mass' checks are coloured by dedicated helpers.
    Every other check is coloured from the status flags: a warning uses the
    warn-count gradient, an error uses 'disc' (connection error) or 'FAIL',
    and anything else is looked up by its message. If the chosen helper yields
    a falsy value the 'undefined' colour is used.

    @param check_key: name of the check being coloured
    @param status: status object; .message, .is_warn, .is_error,
        .connection_error and .count are read from it
    @param dt_thresh: thresholds forwarded to get_time_delay_color
        (only used for 'last active')
    @param hex: if True return a '#rrggbb' string instead of the RGBA sequence
    """
    msg = status.message
    if check_key == 'last active':
        color = get_time_delay_color(msg, dt_thresh)
    elif check_key == 'temp':
        color = get_temp_color(msg)
    elif check_key == 'mass':
        color = get_mass_color(msg)
    elif status.is_warn:
        color = get_warn_color(status.count)
    elif not status.is_error:
        color = get_color(msg)
    else:
        color = get_color('disc' if status.connection_error else 'FAIL')

    color = color or get_color('undefined')

    if not hex:
        return color
    # only the RGB part goes into the hex string; alpha is dropped
    return '#{:02x}{:02x}{:02x}'.format(*color[:3])


def get_color(key):
    """
    Return the RGBA (0-255) colour tuple registered in COLORS_DICT for *key*.

    Any key that is not present in COLORS_DICT, including unhashable values
    such as a list passed in as a status message, yields the 'undefined'
    colour instead of raising.
    """
    fallback = COLORS_DICT['undefined']
    try:
        return COLORS_DICT.get(key, fallback)
    except TypeError:
        # unhashable keys can never be in the dict; treat them as unknown
        return fallback


def get_color_mpl(key):
    """Return the status colour for *key* as a float NumPy RGBA array scaled to 0-1."""
    rgba_255 = np.asarray(get_color(key), dtype=float)
    return rgba_255 / 255.


def get_time_delay_color(dt, dt_thresh):
    """
    Colour a time delay against warn/fail thresholds.

    @param dt: time delay (the message of the 'last active' check)
    @param dt_thresh: sequence whose first two items are the warn and the
        fail threshold
    @return: 'OK' colour if dt < dt_thresh[0], 'WARN' colour if
        dt_thresh[0] <= dt < dt_thresh[1], otherwise the 'FAIL' colour.
        A dt that is not comparable with the thresholds (different type, e.g.
        a status string) is also coloured 'FAIL'.
    """
    warn_thresh = dt_thresh[0]
    # Accept any pair of real numbers (int vs float, numpy scalars) in addition
    # to the same-type check, so e.g. an int delay is still compared against
    # float thresholds instead of being shown as FAIL.
    comparable = (isinstance(dt, type(warn_thresh))
                  or (isinstance(dt, numbers.Real) and isinstance(warn_thresh, numbers.Real)))
    if comparable:
        if dt < warn_thresh:
            return get_color('OK')
        # dt_thresh[1] is only read when needed, as before
        if warn_thresh <= dt < dt_thresh[1]:
            return get_color('WARN')
    return get_color('FAIL')


def get_warn_color(count, n_colors=20):
    """
    Return an RGBA (0-255) colour from a yellow -> 'WARN' colour gradient.

    @param count: warning count; higher counts pick colours further along the
        gradient. Counts >= n_colors get the last colour, negative counts the
        first one.
    @param n_colors: number of steps in the gradient
    @return: tuple of four plain Python ints
    """
    # ends at (250, 192, 63, 255), the same value as COLORS_DICT['WARN']
    gradient = np.linspace((240, 245, 110, 255), (250, 192, 63, 255), n_colors, dtype=int)
    # clamp explicitly instead of relying on negative indexing, which wrapped
    # negative counts around to the far end of the gradient
    index = min(max(count, 0), n_colors - 1)
    # plain ints, like the tuples returned by get_color()
    return tuple(int(channel) for channel in gradient[index])


def get_mass_color(message):
    """
    Return the colour for a 'mass' check.

    Numeric messages (int/float) always map to the 'OK' colour, i.e. the
    warn-count gradient is never used here; any other message is looked up
    directly with get_color().
    """
    status_key = 'OK' if isinstance(message, (float, int)) else message
    return get_color(status_key)


def get_temp_color(temp, vmin=-10, vmax=60, cmap='coolwarm'):
    """
    Map a temperature onto an RGBA (0-255) colour taken from a matplotlib colormap.

    String values are treated as status keys (e.g. 'NO DATA') and resolved with
    get_color(), which returns the 'undefined' colour for unknown keys.
    Numeric values are scaled linearly so that vmin -> 0 and vmax -> 1 before
    sampling the colormap.

    @param temp: temperature value or status string
    @param vmin: temperature mapped to the start of the colormap
    @param vmax: temperature mapped to the end of the colormap
    @param cmap: colormap name or matplotlib Colormap instance
    @return: list of four ints (R, G, B, A) in 0-255
    """
    if isinstance(temp, str):
        return get_color(temp)
    if isinstance(cmap, str):
        # matplotlib.cm.get_cmap was deprecated in 3.7 and removed in 3.9;
        # matplotlib.colormaps is the supported registry (matplotlib >= 3.5).
        registry = getattr(matplotlib, 'colormaps', None)
        cmap = registry[cmap] if registry is not None else matplotlib.cm.get_cmap(cmap)
    val = (temp - vmin) / (vmax - vmin)
    return [int(255 * c) for c in cmap(val)]


def get_font_color(bg_color, hex=False):
    """
    Choose black or white text for readability on *bg_color*.

    The decision uses the HSL lightness of the background: below 0.6 the font
    is white, otherwise black.

    @param bg_color: if hex is True, any matplotlib colour spec (e.g. '#rrggbb');
        otherwise an RGB(A) sequence. Sequences in the 0-255 form returned by
        get_bg_color()/get_color() are rescaled to 0-1; sequences whose
        channels are all <= 1 are taken as already normalised.
    @param hex: interpret bg_color as a colour string and return '#rrggbb'
    @return: (255, 255, 255, 255) or (0, 0, 0, 255), or the '#rrggbb' form
    """
    if hex:
        rgb = matplotlib.colors.to_rgb(bg_color)
    else:
        # Only RGB matters for lightness; alpha must be dropped, and 0-255
        # values must be normalised (rgb_to_hsv rejected both before).
        rgb = [float(channel) for channel in bg_color[:3]]
        if max(rgb) > 1:
            rgb = [channel / 255. for channel in rgb]
    # HSL lightness is the mean of the largest and smallest channel
    # (same value as hsv_to_hsl(rgb_to_hsv(rgb))[2]).
    lightness = (max(rgb) + min(rgb)) / 2
    font_color = (255, 255, 255, 255) if lightness < 0.6 else (0, 0, 0, 255)
    if hex:
        font_color = '#{:02x}{:02x}{:02x}'.format(*font_color[:3])
    return font_color


def hsv_to_hsl(hsv):
    """
    Convert a (hue, saturation, value) triple to (hue, saturation, lightness).

    Hue passes through unchanged. The HSL saturation is defined as 0 when the
    lightness is exactly 0 or 1 (black or white).
    """
    hue, sat_v, val = hsv
    light = val * (1 - sat_v / 2)
    if light == 0 or light == 1:
        sat_l = 0
    else:
        sat_l = (val - light) / min(light, 1 - light)
    return hue, sat_l, light


def modify_stream_for_plot(input_stream, parameters):
    """
    Build a new Stream with modified copies of the configured channels.

    For every entry of parameters['CHANNELS'] (in dict order), the single trace
    of input_stream matching that channel code is copied. The copy is scaled by
    the optional 'unit' factor and passed through the optional 'transform'
    operations. Its id is then prefixed with the position and 'name' so the
    configured plotting order is kept. Channels matching zero or several traces
    are skipped. Traces in input_stream are not modified; copies are made.

    @param input_stream: obspy Stream to take the traces from
    @param parameters: dict with a 'CHANNELS' mapping of channel code -> options
    @return: new obspy Stream
    """
    plot_stream = Stream()

    channels = parameters.get('CHANNELS')
    for position, (channel, options) in enumerate(channels.items(), start=1):
        matches = input_stream.select(channel=channel)
        # anything other than exactly one match would be ambiguous
        if len(matches) != 1:
            continue

        # copy so the caller's stream stays untouched
        trace = matches[0].copy()

        factor = options.get('unit')
        if factor:
            trace.data = trace.data * float(factor)

        operations = options.get('transform')
        if operations:
            trace.data = transform_trace(trace.data, operations)

        # NOTE(review): presumably obspy's Trace.id setter splits this string on
        # '.' into network/station/location/channel, so a 'name' containing '.'
        # would likely raise - confirm.
        label = options.get('name')
        trace.id = f'{position}: {label} - {trace.id}'

        plot_stream.append(trace)

    return plot_stream


def transform_trace(data, transf):
    """
    Apply a sequence of arithmetic operations to *data*, in order.

    @param data: numpy array
    @param transf: iterable of (operator, value) pairs, e.g. [['*', '20'], ['+', 1]]
        multiplies data by 20 and then adds 1. Supported operators are '+',
        '-', '*' and '/'. Values given as strings are converted with float();
        numeric values are used as-is (so integer arrays keep their dtype for
        '+', '-', '*').
    @return: the transformed data (the input object itself if transf is empty)
    @raise IOError: for an unsupported operator string
    """
    # Explicit operator table instead of e.g. eval(): only these four operations can run.
    operations = {'+': operator.add, '-': operator.sub, '*': operator.mul, '/': operator.truediv}
    for operator_str, val in transf:
        try:
            apply_op = operations[operator_str]
        except (KeyError, TypeError):
            raise IOError(f'Unknown arithmethic operator string: {operator_str}') from None
        if isinstance(val, str):
            # config values may arrive as strings (like 'unit', which is float()-ed);
            # numpy cannot combine an array with a str operand
            val = float(val)
        data = apply_op(data, val)

    return data


def set_axis_ylabels(fig, parameters):
    """
    Label each axis of *fig* with the 'name' of the corresponding channel.

    Channels from parameters['CHANNELS'] are paired with fig.axes in order.
    Nothing is done when no channels are configured, or when the number of
    channels differs from the number of axes (this case is logged). Channels
    without a name leave their axis label unchanged.
    """
    labels = [options.get('name') for options in parameters.get('CHANNELS').values()]
    if not labels:
        return
    if len(labels) != len(fig.axes):
        logging.info('Mismatch in axis and label lengths. Not adding plot labels')
        return
    for label, axis in zip(labels, fig.axes):
        if label:
            axis.set_ylabel(label)


def set_axis_color(fig, color='0.8', shade_color='0.95'):
    """
    Colour the frame of every axis of *fig* and shade alternate axes.

    @param color: colour applied to all four spines of every axis
    @param shade_color: face colour applied to the axes with an odd index
        (2nd, 4th, ...)
    """
    for position, axis in enumerate(fig.axes):
        for side in ('bottom', 'top', 'right', 'left'):
            axis.spines[side].set_color(color)
        if position % 2 == 1:
            axis.set_facecolor(shade_color)


def set_axis_yticks(fig, parameters):
    """
    Apply the configured y-ticks and y-limits to the axes of *fig*.

    Each channel in parameters['CHANNELS'] may define 'ticks' as
    (ymin, ymax, step). Ticks are np.arange(ymin, ymax + step, step). The
    y-limits are set to ymin - 0.33 * step .. ymax + 0.33 * step. Channels are
    paired with fig.axes in order, and channels without 'ticks' are skipped.
    Nothing is done when no channels are configured, or when the channel and
    axis counts differ (this case is logged).
    """
    tick_specs = [options.get('ticks') for options in parameters.get('CHANNELS').values()]
    if not tick_specs:
        return
    if len(tick_specs) != len(fig.axes):
        logging.info('Mismatch in axis tick and label lengths. Not changing plot ticks.')
        return
    for spec, axis in zip(tick_specs, fig.axes):
        if spec:
            ymin, ymax, step = spec
            # NOTE(review): with float steps np.arange may yield one tick past
            # ymax; it falls outside the y-limits set below.
            axis.set_yticks(list(np.arange(ymin, ymax + step, step)))
            axis.set_ylim(ymin - 0.33 * step, ymax + 0.33 * step)


def plot_axis_thresholds(fig, parameters):
    """
    Draw the configured 'warn' and 'fail' threshold lines on the axes of *fig*.

    Warn lines are dashed and fail lines are solid. Both use the respective
    status colour (as a 0-1 RGBA array) scaled by 0.8, with alpha 0.5 and
    linewidth 0.7. The drawing itself is delegated to plot_threshold_lines,
    once per threshold kind.
    """
    logging.info('Plotting trace thresholds')

    line_styles = {
        'warn': dict(color=0.8 * get_color_mpl('WARN'), linestyle=(0, (5, 10)), alpha=0.5, linewidth=0.7),
        'fail': dict(color=0.8 * get_color_mpl('FAIL'), linestyle='solid', alpha=0.5, linewidth=0.7),
    }

    for threshold_key, style in line_styles.items():
        per_channel = [options.get(threshold_key) for options in parameters.get('CHANNELS').values()]
        if per_channel:
            plot_threshold_lines(fig, per_channel, parameters, **style)


def plot_threshold_lines(fig, channel_threshold_list, parameters, **kwargs):
    """
    Draw horizontal threshold lines on the axes of *fig*, one entry per axis.

    Each entry of channel_threshold_list is a single threshold or a list/tuple
    of thresholds. A threshold is a number or the name of an entry in
    parameters['THRESHOLDS']; only values that end up as int/float are drawn.
    The special entries 'pb_SOH2' / 'pb_SOH3' first annotate the power-box
    voltage states on the axis. Lines are then drawn at the levels returned by
    get_warn_states_pbox. kwargs are forwarded to axhline.
    """
    def resolve(level):
        # string thresholds name an entry in parameters['THRESHOLDS']
        if isinstance(level, str):
            return parameters.get('THRESHOLDS').get(level)
        return level

    for entry, axis in zip(channel_threshold_list, fig.axes):
        if entry in ['pb_SOH2', 'pb_SOH3']:
            annotate_voltage_states(axis, parameters, entry)
            entry = get_warn_states_pbox(entry, parameters)

        if not entry:
            continue

        levels = entry if isinstance(entry, (list, tuple)) else [entry]
        for level in map(resolve, levels):
            if isinstance(level, (float, int)):
                axis.axhline(level, **kwargs)


def get_warn_states_pbox(soh_key: str, parameters: dict) -> list:
    """
    Return the power-box voltage levels of *soh_key* that get threshold lines.

    Looks up parameters['POWBOX'][soh_key] and returns its keys that are
    greater than 1, in dict order. Returns an empty list when the 'POWBOX'
    section or the soh_key entry is missing or empty.
    """
    # `or {}` guards a missing/None POWBOX section, which used to raise AttributeError
    pb_dict = (parameters.get('POWBOX') or {}).get(soh_key)
    if not pb_dict:
        return []
    # NOTE(review): annotate_voltage_states keeps voltages >= 1 (float-converted)
    # while this keeps > 1 - confirm the boundary is intended.
    return [voltage for voltage in pb_dict if voltage > 1]


def annotate_voltage_states(ax, parameters, pb_key, color='0.75'):
    """
    Write the non-OK power-box states next to each voltage level on *ax*.

    For every voltage >= 1 in parameters['POWBOX'][pb_key], the entries of its
    state dict whose value is not 'OK' are joined as "key: value | key: value".
    The text is placed at the right edge of the current x-range, at that
    voltage. Nothing is drawn when the 'POWBOX' section or the pb_key entry is
    missing or empty (previously this raised AttributeError).

    @param ax: matplotlib axis to annotate
    @param parameters: dict containing the 'POWBOX' section
    @param pb_key: key of the power-box state mapping (e.g. 'pb_SOH2')
    @param color: text colour
    """
    pb_dict = (parameters.get('POWBOX') or {}).get(pb_key)
    if not pb_dict:
        return
    for voltage, voltage_dict in pb_dict.items():
        if float(voltage) < 1:
            continue
        out_string = ' | '.join(f'{key}: {val}' for key, val in voltage_dict.items() if val != 'OK')

        ax.annotate(out_string, (ax.get_xlim()[-1], voltage), color=color, fontsize='xx-small',
                    horizontalalignment='right')