import numpy as np
import matplotlib.pyplot as plt
from matplotlib import animation

# uncomment the following line for the gif saving function of the last graph to work:
# from PIL import Image

# ---------------------------------------------------------------------------
# Module-level simulation parameters and shared mutable state.
# Every function below reads (and several rebind) these globals.
# ---------------------------------------------------------------------------
excitation = 10  # divides the ring distance inside get_single_weight (width of the weight profile)
num_of_neurons = 200  # number of neurons in each ring
total_time = 20  # [sec]
DELTA = 0.1  # step size for the Euler approximation
size = int(total_time / DELTA)  # number of Euler steps; also the length of every "r" / "I_inj" array
# Integer step indices 1..size.
# NOTE(review): the original comment labelled this [μsec]; the values are step counters - confirm the intended unit.
time = np.linspace(start=1, stop=size, num=size, dtype=int)
pulse_time = 1  # duration of the pulse [sec]
pulse_strength = 20  # pulse current value [nA]
# Pulse start offset. It is divided by DELTA exactly like pulse_time in inject_current.
# NOTE(review): the original comment said [msec] while pulse_time says [sec] - confirm the unit.
start_I_time = 0
weights = np.zeros((num_of_neurons, num_of_neurons))  # placeholder; rebound by the show_* functions before use
neurons_dicts = {}  # per-neuron state ({"r": ..., "I_inj": ...}); filled by build_dicts
euler = {}  # per-neuron Euler step functions; filled by build_dicts
tao = 1  # divisor of the derivative in the euler_approximation_* functions (presumably the time constant tau)
rings = ["main", "left", "right"]
shift_size = 1
shifts = {"main": 0, "left": shift_size, "right": -shift_size}  # weight-profile shift used for each ring
# Placeholders for the external currents of the left/right rings; main() rebinds them.
I_left, I_right = np.zeros((num_of_neurons, num_of_neurons)), np.zeros((num_of_neurons, num_of_neurons))
f = lambda x: x if x >= 0 else 0  # rectifier: passes non-negative values, maps negatives to 0


def get_single_weight(i, j, shift=0):
    """Return the synaptic weight between neurons ``i`` and ``j`` on the ring.

    The weight is a Gaussian of the circular distance between the two
    neurons (scaled by the module-level ``excitation``) minus a constant
    offset of 0.1.

    Args:
        i: Index of the first neuron.
        j: Index of the second neuron.
        shift: Offset added to ``i - j`` before the distance is measured,
            which slides the weight profile around the ring.

    Returns:
        ``exp(-(distance / excitation) ** 2) - 0.1``.
    """
    # Reduce the offset modulo the ring size first, so the wrap-around
    # distance stays within [0, num_of_neurons / 2] even when |shift| is
    # larger than the ring itself (previously it could turn negative).
    offset = abs(i - j + shift) % num_of_neurons
    diff = min(offset, num_of_neurons - offset)
    return np.exp(-pow(diff / excitation, 2)) - 0.1


def get_all_weights(shift=0):
    """Build the full weight matrix of one ring.

    Entry ``[row - 1][col - 1]`` holds ``get_single_weight(col, row, shift)``
    for neuron numbers running from 1 to ``num_of_neurons``.

    Args:
        shift: Shift forwarded to ``get_single_weight``.

    Returns:
        A numpy array built from the nested list of weights.
    """
    neuron_numbers = range(1, num_of_neurons + 1)
    matrix = [
        [get_single_weight(col, row, shift) for col in neuron_numbers]
        for row in neuron_numbers
    ]
    return np.asarray(matrix)


def get_triple_weights():
    """Build one weight matrix per ring, using each ring's entry in ``shifts``.

    Returns:
        A dict mapping every name in ``rings`` to a nested list where entry
        ``[row - 1][col - 1]`` is ``get_single_weight(col, row, shifts[ring])``.
    """
    neuron_numbers = range(1, num_of_neurons + 1)
    return {
        ring_: [
            [get_single_weight(col, row, shifts[ring_]) for col in neuron_numbers]
            for row in neuron_numbers
        ]
        for ring_ in rings
    }


def build_dicts(ring="single"):
    """Populate the global ``neurons_dicts`` and ``euler`` tables.

    Args:
        ring: ``"single"`` creates one state entry and one Euler step
            function per neuron, keyed by neuron index. ``"triple"`` first
            recomputes the globals ``size``, ``time`` and ``shifts`` from the
            current ``total_time`` / ``shift_size`` and then creates the same
            per-neuron entries nested under each ring name. Any other value
            leaves everything untouched.

    Every neuron gets zero-filled ``"r"`` and ``"I_inj"`` arrays of length
    ``size`` with ``"r"[0]`` drawn from ``np.random.rand()``.
    """
    global time, size, shifts

    if ring == "single":
        for idx in range(num_of_neurons):
            state = {"r": np.zeros(size), "I_inj": np.zeros(size)}
            state["r"][0] = np.random.rand()
            neurons_dicts[idx] = state
            euler[idx] = euler_approximation_single_ring(idx)
        return

    if ring != "triple":
        return

    # The caller may have changed total_time / shift_size, so refresh the
    # globals that are derived from them.
    size = int(total_time / DELTA)
    time = np.linspace(start=1, stop=size, num=size, dtype=int)  # [μsec]
    shifts = {"main": 0, "left": shift_size, "right": -shift_size}

    for ring_type in rings:
        ring_states = {}
        ring_steppers = {}
        neurons_dicts[ring_type] = ring_states
        euler[ring_type] = ring_steppers
        for idx in range(num_of_neurons):
            state = {"r": np.zeros(size), "I_inj": np.zeros(size)}
            state["r"][0] = np.random.rand()
            ring_states[idx] = state
            ring_steppers[idx] = euler_approximation_triple_ring(idx, ring_type)


def inject_current(neurons_to_inject, strength=pulse_strength, duration=pulse_time, when=start_I_time):
    """Write a constant current pulse into the ``"I_inj"`` array of the given neurons.

    Args:
        neurons_to_inject: Keys of ``neurons_dicts`` to inject into; an empty
            collection is a no-op.
        strength: Value written into the pulse window.
        duration: Length of the pulse, in the same unit as ``when``.
        when: Start of the pulse; it is divided by ``DELTA`` to get a step index.
    """
    if len(neurons_to_inject) == 0:
        return

    first_step = int(when / DELTA)
    last_step = int((when + duration) / DELTA)
    for neuron in neurons_to_inject:
        neurons_dicts[neuron]["I_inj"][first_step:last_step] = strength


def inject_current_to_triple_ring(neurons_to_inject, ring_to_inject="main", strength=pulse_strength, duration=pulse_time, when=start_I_time):
    """Write a constant current pulse into neurons of one ring of the triple model.

    Args:
        neurons_to_inject: Neuron indexes to inject into; an empty collection
            means every neuron of the ring.
        ring_to_inject: Key of the ring inside ``neurons_dicts``.
        strength: Value written into the pulse window.
        duration: Length of the pulse, in the same unit as ``when``.
        when: Start of the pulse; it is divided by ``DELTA`` to get a step index.
    """
    targets = neurons_to_inject
    if len(targets) == 0:
        targets = list(range(num_of_neurons))

    pulse_window = slice(int(when / DELTA), int((when + duration) / DELTA))
    for neuron in targets:
        neurons_dicts[ring_to_inject][neuron]["I_inj"][pulse_window] = strength


def euler_approximation_single_ring(neuron_index, delta=DELTA):
    """Create the Euler step function for one neuron of the single ring.

    Args:
        neuron_index: Key of the neuron in ``neurons_dicts`` / row of ``weights``.
        delta: Euler step size.

    Returns:
        ``euler_func(prev_y, prev_t)``, which returns
        ``prev_y + delta * derivative(prev_t)``. The derivative is read from
        the global state at call time, not when this factory runs.
    """

    def rate_of_change(step):
        # Weighted sum of every neuron's rate at this step, rectified by f.
        recurrent_input = sum(
            [neurons_dicts[source]["r"][step] * weights[neuron_index][source] for source in range(num_of_neurons)]
        )
        own_state = neurons_dicts[neuron_index]
        return (-own_state["r"][step] + own_state["I_inj"][step] + f(recurrent_input)) / tao

    def euler_func(prev_y, prev_t):
        return prev_y + delta * rate_of_change(prev_t)

    return euler_func


def euler_approximation_triple_ring(neuron_index, ring_type, delta=DELTA):
    """Create the Euler step function for one neuron of the triple-ring model.

    Args:
        neuron_index: Index of the neuron inside ``neurons_dicts[ring_type]``.
        ring_type: ``"main"``, ``"left"`` or ``"right"``; selects which
            derivative is built. For any other value no derivative is defined
            and the returned function fails when it is called.
        delta: Euler step size.

    Returns:
        ``euler_func(prev_y, prev_t)``, which returns
        ``prev_y + delta * derivative(prev_t)``. All globals (``neurons_dicts``,
        ``weights``, ``I_left``, ``I_right``, ``tao``) are looked up at call
        time, so values rebound after this factory runs are honoured.
    """
    # Each derivative is (-r + I_inj [+ external current] + f(sum of three weighted rate sums)) / tao.
    if ring_type == "main":
        # Rates of the main, left and right rings, all weighted with weights["main"].
        derivative = lambda t: (-neurons_dicts[ring_type][neuron_index]["r"][t] + neurons_dicts[ring_type][neuron_index]["I_inj"][t] + f(
            sum([(neurons_dicts["main"][index]["r"][t] * weights["main"][neuron_index][index]) for index in range(num_of_neurons)]) +
            sum([(neurons_dicts["left"][index]["r"][t] * weights["main"][neuron_index][index]) for index in range(num_of_neurons)]) +
            sum([(neurons_dicts["right"][index]["r"][t] * weights["main"][neuron_index][index]) for index in range(num_of_neurons)]))) / tao
    elif ring_type == "left":
        # Adds the external current I_left outside the rectifier.
        # NOTE(review): the third sum pairs the *main* ring rates with weights["right"];
        # confirm this is intended and not meant to read neurons_dicts["right"].
        derivative = lambda t: (-neurons_dicts[ring_type][neuron_index]["r"][t] + neurons_dicts[ring_type][neuron_index]["I_inj"][t] + I_left[neuron_index][t] + f(
            sum([(neurons_dicts["main"][index]["r"][t] * weights["main"][neuron_index][index]) for index in range(num_of_neurons)]) +
            sum([(neurons_dicts["left"][index]["r"][t] * weights["left"][neuron_index][index]) for index in range(num_of_neurons)]) +
            sum([(neurons_dicts["main"][index]["r"][t] * weights["right"][neuron_index][index]) for index in range(num_of_neurons)]))) / tao
    elif ring_type == "right":
        # Adds the external current I_right outside the rectifier.
        # NOTE(review): the second sum pairs the *main* ring rates with weights["left"];
        # confirm this is intended and not meant to read neurons_dicts["left"].
        derivative = lambda t: (-neurons_dicts[ring_type][neuron_index]["r"][t] + neurons_dicts[ring_type][neuron_index]["I_inj"][t] + I_right[neuron_index][t] + f(
            sum([(neurons_dicts["main"][index]["r"][t] * weights["main"][neuron_index][index]) for index in range(num_of_neurons)]) +
            sum([(neurons_dicts["main"][index]["r"][t] * weights["left"][neuron_index][index]) for index in range(num_of_neurons)]) +
            sum([(neurons_dicts["right"][index]["r"][t] * weights["right"][neuron_index][index]) for index in range(num_of_neurons)]))) / tao

    # One explicit Euler step: y_next = y_prev + delta * dy/dt evaluated at the previous step.
    def euler_func(prev_y, prev_t):
        return prev_y + delta * derivative(prev_t)

    return euler_func


def show_weights_to_diff(relative_neuron=100, shift=0):
    """Plot the weights between one neuron and every neuron of the ring.

    Args:
        relative_neuron: Neuron number the weights are measured against; it
            is also marked with a dashed vertical line.
        shift: Shift forwarded to ``get_single_weight``. It selects the curve
            colour (blue for 0, green for negative, red for positive) and the
            direction text appended to the title.
    """
    if shift == 0:
        color, addon = "blue", ""
    elif shift < 0:
        color, addon = "green", f"with shift of {shift} (to the right)"
    else:
        color, addon = "red", f"with shift of {shift} (to the left)"

    plt.figure(figsize=(14, 6))

    neuron_numbers = list(range(1, num_of_neurons + 1))
    weight_profile = [get_single_weight(number, relative_neuron, shift=shift) for number in neuron_numbers]

    plt.title(f"Weights of the relative neuron (#{relative_neuron}) as a function of distance to the rest of the neurons " + addon)
    plt.grid()
    plt.xlabel('Neurons Indexes')
    plt.ylabel('Synaptic Weight')
    # Dashed guides: the reference neuron and the zero-weight level.
    plt.axvline(x=relative_neuron, linestyle='--', color='black')
    plt.axhline(y=0, linestyle='--', color='black')
    plt.plot(neuron_numbers, weight_profile, color=color)
    plt.show()


def show_animation_of_rate_vs_time(where_to_inject=-1, shift=0):
    """Animate the firing rates of the single ring over the simulation steps.

    Rebinds the global ``weights`` to ``get_all_weights(shift)``, optionally
    injects the default current pulse into one neuron, and then advances
    every neuron by one Euler step per animation frame. After each step the
    rates are divided by the largest rate of that step before being stored
    and drawn.

    Args:
        where_to_inject: Neuron key that receives the current pulse; ``-1``
            skips the injection. A dashed vertical line is drawn at this x
            value in either case.
        shift: Shift of the weight profile; also selects the curve colour.
    """
    global weights
    fig, ax = plt.subplots(figsize=(14, 6))
    weights = get_all_weights(shift)

    if shift == 0:
        color = "blue"
    elif shift < 0:
        color = "green"
    else:
        color = "red"

    addon = ""
    if where_to_inject != -1:
        inject_current([where_to_inject])
        addon = f",with current injection of {pulse_strength}[nA] for {pulse_time}[sec] at neuron #{where_to_inject}"

    initial_rates = [state["r"][0] for state in neurons_dicts.values()]
    initial_peak = max(initial_rates)
    title = "All of the neurons activity as a function of neurons index " + addon
    plt.title(title)
    plt.axvline(x=where_to_inject, linestyle='--', color='black')
    plt.xlabel('Neurons Indexes')
    plt.ylabel('Firing Rate')
    ax.grid()
    line = ax.plot(np.asarray(initial_rates) / initial_peak, color=color)[0]

    def update(frame):
        # Frame 0 only shows the initial state that was drawn above.
        if frame <= 0:
            return line

        previous = frame - 1
        # Step every neuron from the previous frame before storing anything,
        # so all derivatives see the same (previous) state.
        stepped = [euler[neuron](neurons_dicts[neuron]["r"][previous], previous) for neuron in neurons_dicts]
        step_peak = max(stepped)

        normalised = []
        for neuron in neurons_dicts:
            neurons_dicts[neuron]["r"][frame] = stepped[neuron] / step_peak
            normalised.append(neurons_dicts[neuron]["r"][frame])

        line.set_ydata(normalised)
        plt.title(title + f" (time: {frame}[deci-sec])")
        return line

    ani = animation.FuncAnimation(fig=fig, func=update, frames=int(total_time / DELTA), interval=1)
    # To store the animation as a gif, uncomment:
    # writer = animation.PillowWriter(fps=30, metadata=dict(artist='Anton'), bitrate=1800)
    # ani.save(f"ring_animation_shift={shift}.gif", writer=writer)

    plt.show()


def save_animation_in_chunks(ani, total_frames, chunk_size=1000, fps=30):
    """Save a long animation as one GIF by rendering it in chunks.

    The frame range is split into chunks of at most ``chunk_size`` frames
    (the number of frames was too large to fit in one save). Each chunk is
    rendered to ``triple_ring_chunk_<i>.gif`` and the chunk files are then
    combined into ``triple_ring.gif``.

    Args:
        ani: The ``FuncAnimation`` to save; its figure, frame function and
            interval are reused for every chunk.
        total_frames: Total number of frames to render. Nothing is written
            when there are no frames.
        chunk_size: Maximum number of frames per intermediate GIF.
        fps: Frames per second of the written GIFs.
    """
    # Imported here so Pillow is only needed when a GIF is actually saved
    # and no manual edit of the module imports is required.
    from PIL import Image

    # Split the total frames into chunks
    chunks = [range(start, min(start + chunk_size, total_frames)) for start in range(0, total_frames, chunk_size)]
    if not chunks:
        return

    # Save each chunk as a separate GIF
    chunk_paths = []
    for i, chunk in enumerate(chunks):
        chunk_path = f'triple_ring_chunk_{i}.gif'
        chunk_ani = animation.FuncAnimation(ani._fig, ani._func, frames=chunk, interval=ani._interval)
        chunk_ani.save(chunk_path, writer=animation.PillowWriter(fps=fps))
        chunk_paths.append(chunk_path)

    # Combine all chunks into a single GIF, closing every file handle even
    # if opening or saving fails part-way through.
    chunk_images = []
    try:
        for chunk_path in chunk_paths:
            chunk_images.append(Image.open(chunk_path))
        chunk_images[0].save('triple_ring.gif', save_all=True, append_images=chunk_images[1:], loop=0, duration=1000 / fps)
    finally:
        for image in chunk_images:
            image.close()


def show_complete_triple_ring_model():
    """Animate the three coupled rings, one subplot per ring.

    Rebinds the global ``weights`` to ``get_triple_weights()``, draws the
    initial rates of every ring and then advances all rings by one Euler
    step per animation frame. After each step the rates of a ring are
    divided by that ring's largest rate before being stored and drawn, and a
    dashed vertical line is moved to the neuron with the largest rate. The
    figure title reports which external current (``I_left`` / ``I_right``)
    is active at the current frame.

    Expects ``build_dicts("triple")`` to have been called first, and
    ``I_left`` / ``I_right`` to have one column per frame.
    """
    global weights
    weights = get_triple_weights()
    fig, axes = plt.subplots(3, 1, figsize=(17, 10))
    title = "All of the rings are connected to each other now and act according to: "
    addon = ""
    x_to_plot = np.arange(1, num_of_neurons + 1)
    curves = {}  # ring name -> rate curve
    peak_markers = {}  # ring name -> dashed vertical line marking the peak neuron

    for ring, ax in zip(rings, axes):
        shift = shifts[ring]
        if shift == 0:
            color = "blue"
        else:
            color = "green" if shift < 0 else "red"

        y_to_plot = np.asarray([neurons_dicts[ring][j]["r"][0] for j in range(num_of_neurons)])
        max_val = np.max(y_to_plot)

        label = " ring neurons activity" if ring == "main" else " shift ring neurons activity"
        ax.set_title(ring.capitalize() + label)
        peak_markers[ring] = ax.axvline(x=100, linestyle='--', color='black')
        ax.grid()
        curves[ring], = ax.plot(x_to_plot, y_to_plot / max_val, color=color)

    def update(frame):
        nonlocal addon
        print(frame)
        if frame > 0:
            for ring in rings:
                ring_state = neurons_dicts[ring]
                # euler approximation of the neuron firing rate; every neuron
                # is stepped from the previous frame before anything is stored
                data = [euler[ring][neuron](ring_state[neuron]["r"][frame - 1], frame - 1) for neuron in ring_state]
                max_ = max(data)
                max_indx = data.index(max_) + 1  # the x axis is 1-based
                data_to_plot = []
                for neuron in ring_state:
                    # normalize
                    ring_state[neuron]["r"][frame] = data[neuron] / max_
                    data_to_plot.append(ring_state[neuron]["r"][frame])

                # set_xdata requires a sequence (a bare scalar is deprecated
                # since Matplotlib 3.7); an axvline carries two x values.
                peak_markers[ring].set_xdata([max_indx, max_indx])
                curves[ring].set_ydata(np.asarray(data_to_plot))

            # Any non-zero external current counts as a head turn, whatever
            # pulse strength was used to build I_left / I_right.
            if np.any(I_left[:, frame] != 0):
                addon = "left head direction"
            elif np.any(I_right[:, frame] != 0):
                addon = "right head direction"
            else:
                addon = "no head direction"
            plt.suptitle(title + addon + f" (time: {frame}[deci-sec])")

        return tuple(curves[ring] for ring in rings)

    ani = animation.FuncAnimation(fig=fig, func=update, frames=int(total_time / DELTA), interval=0.1)
    # uncomment the following line if you want to save the animation
    # save_animation_in_chunks(ani, int(total_time / DELTA), chunk_size=1000, fps=30)

    plt.show()


def main():
    """Run the full demo: three single-ring runs, then the coupled triple-ring model.

    The order of the statements matters: ``total_time`` and ``pulse_strength``
    are rebound before ``build_dicts("triple")`` so the triple-ring state is
    sized for the longer run, and ``I_left`` / ``I_right`` are rebound before
    the final animation reads them.
    """
    global total_time, pulse_strength, I_left, I_right, shift_size
    build_dicts("single")  # create the per-neuron state and Euler functions for the single ring
    for ring_ in rings:  # for each ring's shift: plot the weight profile, then animate the firing rates
        show_weights_to_diff(shift=shifts[ring_])
        show_animation_of_rate_vs_time(where_to_inject=100, shift=shifts[ring_])

    # Change the basic parameters for the triple-ring run.
    # NOTE(review): the default arguments of inject_current* were bound when those
    # functions were defined, so they keep the earlier pulse_strength value.
    total_time = 500
    pulse_strength = 1
    # it takes some time to see the movement of the graph, so you should uncomment the following line in order to see it a little faster
    # shift_size= 10
    build_dicts("triple")
    # One row per neuron, one column per step: +1 / -1 mark the two head-turn periods, 0 means no turn.
    head_turns = np.zeros((num_of_neurons, len(time)))
    head_turns[:, int(3 / DELTA):int(230 / DELTA)] = 1
    head_turns[:, int(275 / DELTA):int(500 / DELTA)] = -1

    # Seed the main ring's initial rates (and a one-step current on neuron 99)
    # so that the 100th neuron is the leading one when the graph starts.
    for n in range(95, 105):
        neurons_dicts["main"][n]["r"][0] = 1
    neurons_dicts["main"][99]["I_inj"][0:1] = 20

    # Turn the head turns into currents: positive entries drive I_right, negative entries drive I_left.
    I_right = np.where(head_turns > 0, pulse_strength, 0)
    I_left = np.where(head_turns < 0, pulse_strength, 0)

    show_complete_triple_ring_model()


# Run the demo only when the file is executed as a script, so importing the
# module (e.g. to reuse the weight helpers) does not open any plot windows.
if __name__ == "__main__":
    main()
