import numpy as np
import matplotlib.pyplot as plt


def make_fullscreen():  # just for comfort
    """Maximize the current figure's window on the Tk, wx and Qt GUI backends.

    On any other backend (including non-interactive ones such as Agg) this is
    a no-op. The backend name is compared case-insensitively: Matplotlib may
    report it in lower case (e.g. ``'tkagg'`` when the backend is
    auto-selected), and the canonical wx name is ``'WXAgg'``. An exact
    ``== 'TkAgg'`` / ``== 'wxAgg'`` test can therefore silently never match.
    """
    manager = plt.get_current_fig_manager()
    backend = plt.get_backend().lower()
    if backend == 'tkagg':  # Tkinter
        import tkinter  # importable whenever the TkAgg backend is active

        try:
            manager.window.state('zoomed')  # Maximize (Windows / macOS Tk)
        except tkinter.TclError:
            # X11 Tk rejects the 'zoomed' window state; the -zoomed window
            # attribute is the X11 equivalent. Without this fallback a purely
            # cosmetic tweak would abort the whole analysis on Linux.
            manager.window.attributes('-zoomed', True)
    elif backend == 'wxagg':  # wxPython
        manager.frame.Maximize(True)  # Maximize the window
    elif backend in ('qtagg', 'qt5agg'):  # Qt
        manager.window.showMaximized()


# Acquisition / protocol parameters, read as module globals by main_code:
#   fs         - sampling rate [Hz]
#   stim_cycle - length of one stimulus cycle [sec]
#   dc_seg     - duration of the DC segment [sec]; per-cycle spike counts
#                are divided by this to obtain a firing rate
#   threshold  - voltage threshold for spike detection [mV]
fs, stim_cycle, dc_seg, threshold = 10_000, 0.3, 0.2, -30


def _plot_threshold_crossings(t_for_plot, Si, SaTH, SaTHdiff, thr, dataname):
    """Show the raw trace, the threshold line, the binary mask and its edges."""
    plt.figure()
    plt.rc('font', size=12)
    make_fullscreen()
    plt.plot(t_for_plot, Si, label='Si')
    plt.axhline(y=thr, color='green', linestyle='--', label='Threshold')
    # Scaled/offset copies: the 0/1 mask spans -10..0 and the -1/0/+1 edge
    # trace spans -30..-10, so both stay visible next to the voltage trace.
    plt.plot(t_for_plot, 10 * SaTH - 10, label='SaTH', color='black')
    plt.plot(t_for_plot, 10 * SaTHdiff - 20, label='SaTHdiff', color='red')
    plt.xlabel('Time [s]')
    plt.ylabel('Voltage [mV]')
    plt.title(f'{dataname}: Signal and Threshold Crossings')
    plt.legend()
    plt.grid(True)
    plt.show()


def _find_spike_peaks(signal, l2h, h2l):
    """Return (peak indices, peak values) of every complete supra-threshold run.

    ``l2h[k]`` is the last sample at/below threshold before an upward crossing
    and ``h2l[k]`` the last sample above threshold before a downward crossing
    (the positions of +1 / -1 in ``np.diff`` of the binary mask). The
    supra-threshold run of a spike is therefore ``signal[up + 1 : down + 1]``.
    Indices returned are 0-based sample indices into ``signal``.
    """
    ups = np.asarray(l2h, dtype=int)
    downs = np.asarray(h2l, dtype=int)
    # A trace that starts above threshold opens with an unmatched falling edge.
    # Drop it so crossings pair up; a trailing unmatched rising edge (trace
    # ends above threshold) is dropped by zip's truncation below.
    if ups.size and downs.size and downs[0] < ups[0]:
        downs = downs[1:]
    peaks = np.array(
        [up + 1 + np.argmax(signal[up + 1:down + 1]) for up, down in zip(ups, downs)],
        dtype=int,
    )
    return peaks, signal[peaks]


def _segment_rate_stats(peaks, n_samples, samples_per_cycle, dt):
    """Per stimulus cycle: spike count and median/std of instantaneous rate.

    Cycle ``k`` covers samples ``[k * samples_per_cycle, (k + 1) * samples_per_cycle)``;
    a trailing partial cycle is ignored. Rates are ``1 / ISI`` over intervals
    between spikes of the same cycle; cycles with fewer than 2 spikes get 0.

    Returns ``(counts, rate_median, rate_std)``: an int array and two lists.
    """
    counts, rate_median, rate_std = [], [], []
    for seg in range(n_samples // samples_per_cycle):
        start = seg * samples_per_cycle
        in_seg = peaks[(peaks >= start) & (peaks < start + samples_per_cycle)]
        counts.append(in_seg.size)
        if in_seg.size < 2:
            rate_median.append(0)
            rate_std.append(0)
            continue
        # Only intra-cycle intervals: the pause to the next cycle's first
        # spike is not an inter-spike interval of this burst.
        inst_rate = 1 / (np.diff(in_seg) * dt)
        rate_median.append(np.median(inst_rate))
        rate_std.append(np.std(inst_rate))
    return np.array(counts, dtype=int), rate_median, rate_std


def _plot_spike_panel(t_for_plot, signal, peaks, peak_values, l2h, h2l):
    """Draw the trace with peak and crossing markers on the current axes."""
    plt.ylim([np.min(signal) - 5, np.max(signal) + 10])
    plt.plot(t_for_plot, signal, label='Si', zorder=0)
    plt.scatter(t_for_plot[peaks], peak_values, facecolors='none', edgecolors='black',
                label='Local Maxima', zorder=1)
    plt.scatter(t_for_plot[l2h], signal[l2h], facecolors='none', edgecolors='green',
                label='Low to High', zorder=1)
    plt.scatter(t_for_plot[h2l], signal[h2l], facecolors='none', edgecolors='red',
                label='High to Low', zorder=1)
    plt.xlabel('Time [s]')
    plt.ylabel('Voltage [mV]')


def main_code(data, dataname):
    """Detect spikes in one voltage trace, compute per-cycle firing rates, plot.

    Reads the module-level parameters ``fs``, ``threshold``, ``stim_cycle`` and
    ``dc_seg``. Produces two figures, each followed by ``plt.show()``:

    1. the trace with its threshold, binary mask and crossing edges;
    2. the trace with detected peaks and crossings, annotated per stimulus
       cycle with spike count / ``dc_seg`` (top panel) and the median ± std
       of the instantaneous rate 1/ISI (bottom panel).

    Parameters
    ----------
    data : array_like
        1-D voltage trace [mV] sampled at ``fs``. Not modified.
    dataname : str
        Label used in the figure titles.

    Raises
    ------
    ValueError
        If ``data`` is not a non-empty 1-D numeric array.
    """
    Si = np.array(data, dtype=float)  # private float copy; caller's array untouched
    if Si.ndim != 1 or Si.size == 0:
        raise ValueError('main_code expects a non-empty 1-D voltage trace')

    N = Si.shape[0]  # Number of samples
    dt = 1 / fs  # Time step [sec]
    # Sample i sits at time i*dt, so plot positions agree with the 0-based
    # peak and cycle indices used below.
    t_for_plot = np.arange(N) * dt

    # Binary supra-threshold mask and its edges (+1 / -1), padded with a
    # trailing 0 so the edge trace keeps length N.
    SaTH = (Si > threshold).astype(int)
    SaTHdiff = np.append(np.diff(SaTH), 0)

    _plot_threshold_crossings(t_for_plot, Si, SaTH, SaTHdiff, threshold, dataname)

    L2H = np.flatnonzero(SaTHdiff > 0)  # Low-to-high transitions
    H2L = np.flatnonzero(SaTHdiff < 0)  # High-to-low transitions
    LM, LM_y_values = _find_spike_peaks(Si, L2H, H2L)  # Peak indices / values

    # Round rather than truncate, so a product such as 2999.9999... samples
    # cannot silently become 2999.
    samples_per_cycle = max(1, int(round(stim_cycle * fs)))
    SC, Rmedian, Rstd = _segment_rate_stats(LM, N, samples_per_cycle, dt)
    R = SC / dc_seg  # Spike count per cycle divided by DC-segment duration [Hz]

    plt.figure()
    make_fullscreen()

    # Top panel: count-based firing rate per cycle.
    plt.subplot(2, 1, 1)
    _plot_spike_panel(t_for_plot, Si, LM, LM_y_values, L2H, H2L)
    for iR, rate in enumerate(R):
        plt.text(iR * stim_cycle + 0.075, np.max(Si) + 3, f'{rate:.1f}')
    plt.title(f'{dataname}: Signal with Spike Detection and Firing Rate [Hz]')
    plt.legend(loc='center left')
    plt.grid(True)

    # Bottom panel: median ± std of the instantaneous rate per cycle.
    plt.subplot(2, 1, 2)
    _plot_spike_panel(t_for_plot, Si, LM, LM_y_values, L2H, H2L)
    plt.title(f'{dataname}: Signal with Median Firing Rate ± Std Dev [Hz]')
    plt.legend(loc='center left')
    # Per-call fontsize instead of plt.rc, so global rcParams are not left at 8.
    for iR, (med, std) in enumerate(zip(Rmedian, Rstd)):
        plt.text(iR * stim_cycle + 0.07, np.max(Si) + 4, f'{med:.1f}±{std:.1f}', fontsize=8)
    plt.grid(True)
    plt.tight_layout()
    plt.show()


if __name__ == '__main__':
    # Read both recordings before any analysis, so a missing or unreadable
    # file aborts the run before the first figure is shown.
    # NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects,
    # which can execute code from a crafted file -- only load trusted .npy files.
    s1, s2 = (np.load(path, allow_pickle=True).flatten() for path in ('S1.npy', 'S2.npy'))

    # Analyse and plot each recording in turn.
    for label, recording in (("s1", s1), ("s2", s2)):
        main_code(recording, label)
