import numpy as np
import os
import matplotlib.pyplot as plt

# Recording of the neuron currently loaded by nueron_data_by_index();
# these module-level arrays start empty and are rebound on every load.
head_direction = np.empty(0)
times = np.empty(0)
spikes = np.empty(0)
# Bar colour of each neuron in show_neuron(), looked up by neuron index.
colors = [
    "blue",
    "green",
    "red",
    "orange",
    "cyan",
    "purple",
    "magenta",
]


def nueron_data_by_index(index=0):
    """Load the recording of neuron ``index`` into the module-level arrays.

    Each sub-folder of ``CSV files`` holds one recording.  Folders are taken
    in sorted order so the index -> neuron mapping does not depend on the
    arbitrary order returned by ``os.listdir``.  Indices 0-3 map to folders
    0-3; indices 4 and 5 both map to folder 4, reading ``spiketrain1.csv``
    and ``spiketrain2.csv`` respectively.

    Side effects: rebinds the globals ``head_direction``, ``spikes`` and
    ``times`` (``times`` is the sample index 0 .. len(spikes) - 1).

    Args:
        index: neuron number in the range 0-5.

    Raises:
        ValueError: if ``index`` is outside 0-5.
    """
    global head_direction, times, spikes
    if not 0 <= index <= 5:
        # Without this check an unknown index produced an empty path and the
        # loader tried to read "/headDirection.csv".
        raise ValueError(f"neuron index must be in 0..5, got {index}")

    folder = sorted(os.listdir("CSV files"))[min(index, 4)]
    path = f"CSV files/{folder}"
    # Folder 4 contains two neurons: index 4 -> "1", index 5 -> "2".
    addon = str(index - 3) if index >= 4 else ""

    head_direction = np.loadtxt(f"{path}/headDirection.csv", delimiter=",", dtype=float)
    spikes = np.loadtxt(f"{path}/spiketrain{addon}.csv", delimiter=",", dtype=float)
    times = np.arange(spikes.shape[0])


def show_neuron(spikes_=None, index=0, permutation_num=0, n_bins=12):
    """Plot a polar histogram of firing rate versus head direction.

    Args:
        spikes_: spike train aligned sample-by-sample with the global
            ``head_direction``; each sample's value is its histogram weight.
            Defaults to the currently loaded global ``spikes``.
        index: neuron index; selects the bar colour and the ``#index+1``
            shown in the title.
        permutation_num: when non-zero, the title is tagged with this
            permutation number.
        n_bins: number of equal-width direction bins covering 0-360 degrees.
    """
    if spikes_ is None:
        # Resolve at call time: a default of ``spikes.copy()`` would be frozen
        # to the empty array that existed when the module was imported.
        spikes_ = spikes
    addon = ""

    bin_width = 360 / n_bins
    bins = np.linspace(0, 360, n_bins + 1)
    directions_deg = np.degrees(head_direction)

    # Spike count (weighted by spikes_) and number of samples per direction bin.
    bin_counts = np.histogram(directions_deg, bins=bins, weights=spikes_)[0]
    samples_per_bin = np.histogram(directions_deg, bins=bins)[0]

    # The /1000 converts sample counts to seconds.
    # NOTE(review): this assumes 1 kHz sampling (1 sample per ms) — confirm.
    seconds_per_bin = samples_per_bin / 1000
    # Directions never visited get rate 0 instead of nan/inf, which would
    # otherwise poison np.max / set_rmax below.
    firing_rates = np.divide(
        bin_counts,
        seconds_per_bin,
        out=np.zeros_like(bin_counts, dtype=float),
        where=seconds_per_bin > 0,
    )

    _, ax = plt.subplots(subplot_kw={'projection': 'polar'}, figsize=(14, 6))
    theta = np.radians(bins[:-1])  # left edge of every bin, in radians
    width = np.radians(bin_width)

    # align="edge" so each bar spans [left edge, left edge + width] instead of
    # being centred on the left edge (which shifted every bar by half a bin).
    ax.bar(theta, firing_rates, width=width, bottom=0, align="edge",
           color=colors[index % len(colors)])
    ax.set_rmax(max(np.max(firing_rates), 10))
    ax.grid(True)
    ax.set_theta_zero_location("N")  # 0 degrees at the top
    ax.set_theta_direction(-1)  # degrees increase clockwise
    if permutation_num:
        addon = f"(permutation #{permutation_num})"
    ax.set_title(f"Spikes frequency as a function of head direction for neuron #{index + 1} {addon}")
    label_position = ax.get_rlabel_position()
    ax.text(np.radians(label_position + 10), ax.get_rmax() / 2, 'Fire Rate [Hz]', rotation=-label_position, ha='center', va='center')
    ax.set_xlabel("Head Direction [degrees]")
    plt.show()


def show_all_neurons():
    """Load each of the six recorded neurons in turn and plot its tuning curve."""
    for i in range(6):
        nueron_data_by_index(i)
        # Pass both arguments by keyword: a bare positional ``i`` binds to
        # ``spikes_`` (the first parameter), not ``index``.
        show_neuron(spikes_=spikes, index=i)


def permutation(index=0, permutation_num=0):
    """Plot a surrogate of the current neuron with its inter-spike intervals shuffled.

    Spike positions of the global ``spikes`` train (samples equal to 1) are
    turned into inter-spike intervals, the intervals are randomly permuted
    and re-accumulated into a new 0/1 train of the same shape.  The spike
    count and the multiset of intervals are preserved; the result is plotted
    with ``show_neuron``.  A train with no spikes yields an all-zero surrogate.

    Args:
        index: neuron index passed through to ``show_neuron``.
        permutation_num: permutation number shown in the plot title.
    """
    spike_idx = np.nonzero(spikes == 1)[0]
    new_spikes = np.zeros(spikes.shape)
    if spike_idx.size:
        # The first "interval" is the offset of the first spike from t=0, so
        # the intervals sum to the last spike index and every cumulative
        # position stays inside the train.
        intervals = np.diff(spike_idx, prepend=0)
        np.random.shuffle(intervals)
        new_spikes[np.cumsum(intervals)] = 1

    show_neuron(spikes_=new_spikes, index=index, permutation_num=permutation_num)


def show_neurons_and_permutations_of_suspects(suspects, num_of_permutations=2):
    """Plot every recorded neuron and, for suspects, interval-shuffled surrogates.

    Args:
        suspects: 0-based neuron indices that additionally get
            ``num_of_permutations`` surrogate plots via ``permutation``.
        num_of_permutations: number of surrogate plots per suspect.
    """
    for neuron_idx in range(6):
        nueron_data_by_index(neuron_idx)
        show_neuron(spikes_=spikes, index=neuron_idx)
        if neuron_idx not in suspects:
            continue
        for perm_no in range(1, num_of_permutations + 1):
            permutation(neuron_idx, perm_no)


def get_single_weight(row, col, *, n_neurons=None, excitation_width=None,
                      inhibition_width=None, inhibition_ratio=None):
    """Return the ring-network connection weight between neurons ``row`` and ``col``.

    The weight is a difference of Gaussians in the circular distance ``d``:

        w = exp(-d**2 / excitation**2) - w_const * exp(-d**2 / inhibition**2)

    with ``d = min(|row - col|, N - |row - col|)`` so distances wrap around a
    ring of ``N`` neurons and the weight is symmetric in ``row``/``col``.

    The keyword arguments override the module-level ``num_of_neurons``,
    ``excitation``, ``inhibition`` and ``w_const``.  When omitted, those
    globals are read at call time, so a temporarily changed ``w_const`` is
    honoured.
    """
    n = num_of_neurons if n_neurons is None else n_neurons
    exc = excitation if excitation_width is None else excitation_width
    inh = inhibition if inhibition_width is None else inhibition_width
    ratio = w_const if inhibition_ratio is None else inhibition_ratio

    dist = abs(row - col)
    dist = min(dist, n - dist)  # shortest way around the ring
    sq_dist = dist ** 2
    return np.exp(-sq_dist / exc ** 2) - ratio * np.exp(-sq_dist / inh ** 2)


def get_all_weights():
    """Build the symmetric ``num_of_neurons x num_of_neurons`` recurrent weight matrix.

    Entry ``[i, j]`` (0-based) holds ``get_single_weight`` for neurons
    ``i + 1`` and ``j + 1``.  The diagonal is 0, so there are no
    self-connections.  The result is a plain float array.
    """
    n = num_of_neurons
    result = np.zeros((n, n))
    for i in range(n):
        for j in range(i + 1, n):
            # Compute each pair once and mirror it across the diagonal.
            result[i, j] = result[j, i] = get_single_weight(j + 1, i + 1)
    return result


def f(x):
    """Rectify ``x``: return it when non-negative, otherwise 0 (NaN also gives 0)."""
    if x >= 0:
        return x
    return 0


# --- Ring ("circular") network parameters, read by get_single_weight ---
excitation = 30  # width of the excitatory Gaussian, in neuron-index units
inhibition = 60  # width of the inhibitory Gaussian, in neuron-index units
w_const = 0.2  # inhibition ratio scaling the inhibitory Gaussian
num_of_neurons = 50

total_time = 1000  # [msec]
DELTA = 0.0001  # delta for euler approximation
size = int(total_time)  # number of simulated time steps
# Sample times 0 .. size-1 [msec].  (linspace(0, size, size, dtype=int)
# produced 0..998 followed by 1000, skipping 999.)
time = np.arange(size)
pulse_time = 50  # duration of the pulse [msec]
pulse_strength = 250000  # pulse current value [nA]
start_I_time = 10  # pulse starts after x [msec]
weights = np.zeros((num_of_neurons, num_of_neurons))  # rebuilt by generate_data()
neurons_dicts = {}  # neuron -> {"r": rate trace, "I_inj": injected current trace}
euler = {}  # neuron -> Euler step function from euler_approximation()
tao = 1  # time constant dividing the rate derivative in euler_approximation()


def build_dicts():
    """(Re)initialise the state and Euler stepper of every network neuron.

    For each neuron this stores zeroed ``"r"`` (rate) and ``"I_inj"``
    (injected current) traces of length ``size`` in ``neurons_dicts``.  It
    seeds ``r[0]`` with a random integer in [0, 500) and stores the neuron's
    step function in ``euler``.
    """
    for idx in range(num_of_neurons):
        rates = np.zeros(size)
        rates[0] = np.random.randint(0, 500)
        neurons_dicts[idx] = {"r": rates, "I_inj": np.zeros(size)}
        euler[idx] = euler_approximation(idx)


def inject_current(neurons_to_inject, strength=pulse_strength, duration=pulse_time, when=start_I_time):
    """Apply a rectangular current pulse to some neurons' injected-current traces.

    For every neuron in ``neurons_to_inject``, samples
    ``int(when)`` up to (not including) ``int(when + duration)`` of its
    ``"I_inj"`` array in ``neurons_dicts`` are overwritten with ``strength``.
    """
    if not len(neurons_to_inject):
        return
    start, stop = int(when), int(when + duration)
    for neuron in neurons_to_inject:
        neurons_dicts[neuron]["I_inj"][start:stop] = strength


def euler_approximation(neuron_index, delta=DELTA):
    """Build a forward-Euler step function for the firing rate of one neuron.

    The rate of neuron i follows

        tao * dr_i/dt = -r_i + f(I_inj_i + sum_j weights[i][j] * r_j)

    where ``f`` is this module's rectifier.  ``weights`` is the global
    recurrent matrix; it is looked up each time the step function runs, so
    it sees the matrix rebuilt by ``generate_data``.

    Args:
        neuron_index: key of the neuron in ``neurons_dicts`` and row of ``weights``.
        delta: Euler step size.

    Returns:
        ``euler_func(prev_y, prev_t)``, which returns
        ``prev_y + delta * dr_i/dt``.  The derivative is evaluated from every
        neuron's state at time index ``prev_t``.
    """
    def derivative(t):
        state = neurons_dicts[neuron_index]
        # Recurrent drive: each presynaptic neuron j contributes its own rate
        # r_j(t).  Using r_i(t) for every j collapses the term to
        # r_i * sum_j w_ij and leaves the neurons uncoupled.
        recurrent = sum(weights[neuron_index][j] * neurons_dicts[j]["r"][t]
                        for j in range(num_of_neurons))
        return (-state["r"][t] + f(state["I_inj"][t] + recurrent)) / tao

    def euler_func(prev_y, prev_t):
        return prev_y + delta * derivative(prev_t)

    return euler_func


def generate_data(mseconds=total_time):
    """Rebuild the weight matrix and integrate every neuron's rate for ``mseconds`` steps.

    Step ``t`` of each neuron's ``"r"`` trace is computed from step ``t - 1``
    with that neuron's function in ``euler``.  Index 0 (the initial rate set
    by ``build_dicts``) is left untouched.
    """
    global weights
    weights = get_all_weights()
    for step in range(1, mseconds):
        prev = step - 1
        for neuron, state in neurons_dicts.items():
            rate_trace = state["r"]
            rate_trace[step] = euler[neuron](rate_trace[prev], prev)


def show_some_neurons_from_circular_network(inject=False, inhib=w_const):
    """Simulate the ring network once and plot the rates of 10 random neurons.

    Args:
        inject: if True, 5 of the plotted neurons receive the default current
            pulse from ``inject_current``.
        inhib: inhibition ratio used as the global ``w_const`` for this
            simulation.  The previous ``w_const`` is restored afterwards, even
            if the simulation raises.
    """
    global w_const
    # Draw distinct neurons: randint could repeat an index, plotting the same
    # curve twice and letting the injection below target one neuron twice.
    n_show = min(10, num_of_neurons)
    neurons_to_show = np.random.choice(num_of_neurons, size=n_show, replace=False)
    build_dicts()
    addon = ""

    if inject:
        to_inject = np.random.choice(neurons_to_show, size=min(5, n_show), replace=False)
        inject_current(to_inject)
        addon = f"with current injections at {to_inject} "

    previous_w_const = w_const
    # w_const is read (via get_single_weight) when generate_data rebuilds the weights.
    w_const = inhib
    addon += f"(when inhibition ratio is {w_const})"
    try:
        generate_data()
    finally:
        w_const = previous_w_const

    plt.figure(figsize=(14, 6))
    for i in neurons_to_show:
        plt.plot(time, neurons_dicts[i]["r"], label=f"Neuron #{i}", alpha=0.8)

    plt.ylim(-5, 8000)
    plt.title(f"Firing rate of {n_show} randomly selected neurons from the circular network, as a function of time {addon}")
    plt.xlabel("Time [msec]")
    plt.ylabel("Fire Rate [Hz]")
    plt.legend(loc="best")
    plt.show()


def main():
    """Run the whole analysis.

    First plot the recorded neurons (with surrogates for suspects 1-3).  Then
    simulate the ring network without injection at the default and zero
    inhibition ratio, and with injection across a sweep of ratios.
    """
    show_neurons_and_permutations_of_suspects(suspects=[1, 2, 3], num_of_permutations=3)
    show_some_neurons_from_circular_network()
    show_some_neurons_from_circular_network(inhib=0)
    for ratio in (0, 0.5, 1, 10, 100, 1000):
        show_some_neurons_from_circular_network(inject=True, inhib=ratio)


# Run only when executed as a script, so importing this module has no side effects.
if __name__ == "__main__":
    main()
