import numpy as np
import os
import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter

# State of the currently loaded neuron. neuron_data_by_index() rebinds all
# five arrays; until then they are empty float placeholders.
head_direction = np.empty(0)
x_pos = np.empty(0)
y_pos = np.empty(0)
times = np.empty(0)
spikes = np.empty(0)

# Plot colour per neuron index (the six neurons use the first six entries).
colors = [
    "blue",
    "green",
    "red",
    "orange",
    "cyan",
    "purple",
    "magenta",
]


def neuron_data_by_index(index=0):
    """Load one neuron's recording into the module-level arrays.

    The recordings live in sub-directories of ``CSV files/``. Each holds
    ``headDirection.csv``, ``posx.csv``, ``posy.csv`` and a spike train.
    Indices 0-3 select the first four sub-directories and read
    ``spiketrain.csv``. Indices 4 and 5 both select the fifth
    sub-directory, reading ``spiketrain1.csv`` and ``spiketrain2.csv``
    respectively.

    Sub-directories are taken in sorted name order, so a given index always
    maps to the same recording regardless of ``os.listdir`` ordering.

    Side effects: rebinds the module globals ``head_direction``, ``x_pos``,
    ``y_pos``, ``spikes`` and ``times``. ``times`` holds one integer sample
    index per spike-train entry.

    Args:
        index: Neuron index in ``range(6)``.

    Raises:
        ValueError: If ``index`` is not in ``range(6)``.
    """
    global head_direction, x_pos, y_pos, times, spikes
    if not 0 <= index <= 5:
        raise ValueError(f"neuron index must be in range(6), got {index!r}")

    root = "CSV files"
    # os.listdir returns entries in arbitrary order. Sort them and skip
    # non-directories so the index -> recording mapping is deterministic.
    recordings = sorted(
        entry
        for entry in os.listdir(root)
        if os.path.isdir(os.path.join(root, entry))
    )
    if index < 4:
        path = f"{root}/{recordings[index]}"
        addon = ""
    else:
        # The fifth recording directory contains two spike trains.
        path = f"{root}/{recordings[4]}"
        addon = "1" if index == 4 else "2"

    head_direction = np.loadtxt(f"{path}/headDirection.csv", delimiter=",", dtype=float)
    x_pos = np.loadtxt(f"{path}/posx.csv", delimiter=",", dtype=float)
    y_pos = np.loadtxt(f"{path}/posy.csv", delimiter=",", dtype=float)
    spikes = np.loadtxt(f"{path}/spiketrain{addon}.csv", delimiter=",", dtype=float)
    # Sample indices 0..n-1, later used to index the position arrays at spike times.
    times = np.arange(spikes.shape[0])


# Populate the module-level arrays with neuron #1 (index 0) at import time,
# so code below that reads them without loading a neuron first sees real
# data rather than the empty placeholders.
neuron_data_by_index(index=0)


def show_spikes_and_path(index, permutation_num=0, spikes_=None):
    """Plot the rat's path and mark the positions at which the neuron spiked.

    Args:
        index: Neuron index. It selects the marker colour and appears
            1-based in the title.
        permutation_num: When non-zero, the title is suffixed with
            ``"(permutation #N)"``.
        spikes_: Spike train aligned sample-by-sample with ``x_pos`` and
            ``y_pos``. Defaults to the module-level ``spikes`` of the
            currently loaded neuron, looked up at call time.
    """
    # Resolve the default at call time. A ``spikes_=spikes`` default would be
    # evaluated once, when the function is defined, and stay frozen to
    # whichever neuron was loaded then.
    if spikes_ is None:
        spikes_ = spikes
    addon = f"(permutation #{permutation_num})" if permutation_num else ""

    plt.figure(figsize=(10, 8))
    plt.plot(x_pos, y_pos, color="black", zorder=1, label="path")
    # The check I did on the second neuron:
    # spikes_times = times[(head_direction <= np.radians(75)) & (head_direction >= np.radians(45))]
    # NOTE(review): only entries exactly equal to 1 are marked; if a sample
    # can hold more than one spike, `spikes_ > 0` may be what is intended.
    spikes_times = times[spikes_ == 1]
    plt.scatter(
        x_pos[spikes_times],
        y_pos[spikes_times],
        color=colors[index % len(colors)],  # wrap so any index gets a colour
        linewidths=3,
        zorder=2,
        label="neuron fired",
    )
    plt.title(f"Path of the rat and the locations where neuron #{index + 1} fired {addon}")
    # plt.title(f"Path of the rat and the locations where neuron #{index + 1} turned its head to around 60 degrees (between 45deg and 75deg)")
    plt.xlabel('X Position [cm]')
    plt.ylabel('Y Position [cm]')
    plt.legend(loc="best")
    plt.show()


def _shuffled_spike_train(spike_train, rng=None):
    """Return a surrogate spike train whose inter-spike intervals are shuffled.

    The gaps between consecutive spikes are permuted at random and laid out
    again from time 0. The first gap is the offset of the first spike. The
    surrogate has the same interval distribution and spike count (see the
    note below), but the spike timing is decoupled from the animal's
    behaviour. Assumes ``spike_train`` is 1-D — TODO confirm for all inputs.

    Args:
        spike_train: Array in which entries equal to 1 mark spikes.
        rng: Optional object with a ``shuffle`` method, e.g. a
            ``np.random.Generator``. When ``None``, the global
            ``np.random`` state is used.

    Returns:
        A float array of the same shape, with 1 at the surrogate spike times.
        It is all zeros if there are no spikes.
    """
    fire = np.flatnonzero(spike_train == 1)
    shuffled = np.zeros(spike_train.shape)
    if fire.size == 0:
        return shuffled
    # Interval preceding each spike; the first one is measured from time 0.
    spaces = np.diff(fire, prepend=0)
    if rng is None:
        np.random.shuffle(spaces)
    else:
        rng.shuffle(spaces)
    # The intervals sum to fire[-1], so every surrogate index stays in bounds.
    # NOTE: if the original first spike is at time 0, a zero interval lands
    # mid-sequence and two surrogate spikes coincide (one spike is lost).
    shuffled[np.cumsum(spaces)] = 1
    return shuffled


def permutation(index=0, permutation_num=0, graph="path", rng=None):
    """Plot a shuffled-interval surrogate of the currently loaded spike train.

    Args:
        index: Neuron index, forwarded to the plotting function.
        permutation_num: Permutation number shown in the plot title.
        graph: ``"path"`` for the spike-on-path plot or ``"heatmap"`` for
            the firing-rate map.
        rng: Optional random generator for reproducible shuffles. ``None``
            uses the global ``np.random`` state.

    Raises:
        ValueError: If ``graph`` is neither ``"path"`` nor ``"heatmap"``.
    """
    if graph == "path":
        show = show_spikes_and_path
    elif graph == "heatmap":
        show = plot_neuron_heat_map
    else:
        raise ValueError(f"graph must be 'path' or 'heatmap', got {graph!r}")

    new_spikes = _shuffled_spike_train(spikes, rng)
    show(spikes_=new_spikes, index=index, permutation_num=permutation_num)


def show_neurons_and_permutations_of_suspects(suspects, num_of_permutations=2, graph="path"):
    """Plot all six neurons, adding shuffled-interval permutations for suspects.

    Each neuron is loaded in turn and plotted. For every neuron index in
    ``suspects``, ``num_of_permutations`` surrogate plots follow, numbered
    from 1.

    Args:
        suspects: Container of neuron indices that also get permutation plots.
        num_of_permutations: Number of surrogate plots per suspect.
        graph: ``"path"`` or ``"heatmap"``.

    Raises:
        ValueError: If ``graph`` is neither ``"path"`` nor ``"heatmap"``.
    """
    if graph == "path":
        show = show_spikes_and_path
    elif graph == "heatmap":
        show = plot_neuron_heat_map
    else:
        raise ValueError(f"graph must be 'path' or 'heatmap', got {graph!r}")

    for i in range(6):
        neuron_data_by_index(i)
        # Read the module-level spike train after loading, so it belongs to neuron i.
        show(i, spikes_=spikes)
        if i in suspects:
            for num in range(1, num_of_permutations + 1):
                permutation(i, num, graph)


def plot_neuron_heat_map(index=0, spikes_=None, permutation_num=0):
    """Plot a smoothed spatial firing-rate map for the currently loaded neuron.

    Positions (``x_pos``, ``y_pos``) are binned into a 20 x 20 grid. Spikes
    are summed per cell and divided by the time spent in the cell. The
    result is smoothed with a Gaussian filter (sigma = 1 cell) and drawn
    with the colour scale fixed to 0-30 Hz.

    Args:
        index: Neuron index, shown 1-based in the title.
        spikes_: Spike train aligned sample-by-sample with the position
            arrays. Defaults to the module-level ``spikes`` of the currently
            loaded neuron, looked up at call time.
        permutation_num: When non-zero, the title is suffixed with
            ``"(permutation #N)"``.
    """
    # Resolve the default at call time. A ``spikes_=spikes`` default would be
    # frozen to whichever neuron was loaded when the function was defined.
    if spikes_ is None:
        spikes_ = spikes
    addon = f"(permutation #{permutation_num})" if permutation_num else ""
    # Side length of each square cell [cm]. Chosen to be close to the average
    # length of a rat's head, about 45.3 mm (Table 2 of
    # https://www.sciencedirect.com/science/article/pii/S187603411730237X#tbl0010).
    size_of_cells = 5
    num_of_cells_x = int(100 / size_of_cells)  # 100 cm is the side of the rat's arena (per the earlier Word document we were given)
    num_of_cells_y = num_of_cells_x  # The arena is square (same source)

    # The manipulation I did on the second neuron (removing the spikes when the head is at the preferred direction)
    # NOTE: if re-enabled, copy spikes_ first; the assignment below mutates the caller's array in place.
    # spikes_times = times[(head_direction <= np.radians(75)) & (head_direction >= np.radians(45))]
    # spikes_[spikes_times] = 0

    bins = (num_of_cells_y, num_of_cells_x)
    cell_counts, edges = np.histogramdd([y_pos, x_pos], bins=bins, weights=spikes_)
    time_in_cells, _ = np.histogramdd([y_pos, x_pos], bins=bins)
    # Samples per cell are divided by 1000 to get seconds. This assumes 1 ms
    # per position sample — TODO confirm the sampling rate.
    # np.clip avoids division by zero in unvisited cells. Those cells also have
    # no spikes, so they get rate 0 instead of NaN (which left blank cells).
    firing_rate = cell_counts / (np.clip(time_in_cells, a_min=1, a_max=None) / 1000)
    firing_rate = gaussian_filter(firing_rate, sigma=1)

    # With integer bins, histogramdd spans each axis from the data's min to
    # max. Take the image extent from the returned edges instead of assuming
    # the grid covers exactly 0-100 cm. Cells are therefore exactly
    # size_of_cells wide only when the positions span the full 0-100 cm.
    y_edges, x_edges = edges
    plt.figure(figsize=(10, 8))
    plt.imshow(
        firing_rate,
        origin='lower',
        cmap="inferno",
        extent=(x_edges[0], x_edges[-1], y_edges[0], y_edges[-1]),
        vmin=0,
        vmax=30,
    )
    plt.colorbar(label='Firing Rate (Hz)')
    plt.xlabel('X Position [cm]')
    plt.xticks([i * size_of_cells for i in range(num_of_cells_x + 1)])
    plt.ylabel('Y Position [cm]')
    plt.yticks([i * size_of_cells for i in range(num_of_cells_y + 1)])
    plt.title(f"Heatmap of neuron #{index + 1} {addon}")
    plt.show()


def show_all_neurons_graphs(graph="path"):
    """Load each of the six neurons in turn and plot it.

    Args:
        graph: ``"path"`` for the spike-on-path plot or ``"heatmap"`` for
            the firing-rate map.

    Raises:
        ValueError: If ``graph`` is neither ``"path"`` nor ``"heatmap"``.
    """
    if graph == "path":
        show = show_spikes_and_path
    elif graph == "heatmap":
        show = plot_neuron_heat_map
    else:
        raise ValueError(f"graph must be 'path' or 'heatmap', got {graph!r}")

    for i in range(6):
        neuron_data_by_index(i)
        # Pass the just-loaded spike train explicitly rather than relying on the
        # plotting function's default, so neuron i's spikes are paired with
        # neuron i's positions.
        show(i, spikes_=spikes)


def main():
    """Plot all six neurons with permutations for neurons #1-#5.

    The full set is produced first as path plots, then as heatmaps.
    """
    for graph_kind in ("path", "heatmap"):
        show_neurons_and_permutations_of_suspects(list(range(5)), graph=graph_kind)


# Run the interactive analysis only when executed as a script, not on import.
if __name__ == "__main__":
    main()
