import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit
from scipy.stats import pearsonr
from scipy import stats
import warnings
from scipy.optimize import OptimizeWarning

# Silence only OptimizeWarning (emitted e.g. when curve_fit cannot estimate the
# parameter covariance); every other warning category is still reported.
warnings.simplefilter("ignore", OptimizeWarning)

# ---- Parameters ----
n_bins = 100            # Number of PSTH histogram bins
trial_duration = 1.28   # Trial duration in seconds
# allow_pickle is required because the array presumably stores Python objects
# (per-trial spike data).
# NOTE(review): unpickling can execute arbitrary code — only load trusted files.
data_arr = np.load("SpikesX10U12D.npy", allow_pickle=True)

# Dataset dimensions; axes are (unit, direction, repetition)
nUnits, nDirections, nRepetitions = data_arr.shape

# ===============================
# 1. Firing Rate Statistics (Unit 1, Direction 1)
# ===============================
# Spike count of every repetition: element [0] of each trial cell holds its
# spike times (presumably a wrapped array — confirm against the data file).
trial_spike_counts = [len(data_arr[0, 0, rep][0]) for rep in range(nRepetitions)]
# Counts -> rates in Hz
direction_rates = [count / trial_duration for count in trial_spike_counts]

# Basic statistics (np.std is the population STD, ddof=0)
direction_mean, direction_std, direction_median = (
    np.mean(direction_rates),
    np.std(direction_rates),
    np.median(direction_rates),
)

print(f"Mean: {direction_mean:.2f}Hz \nSTD: {direction_std:.2f}Hz \nMedian: {direction_median:.2f}Hz\n")

# ===============================
# 2. PSTH Calculation
# ===============================

# Spike counts per (unit, direction, repetition, time bin)
spikes_hist_counts = np.zeros((nUnits, nDirections, nRepetitions, n_bins))
bin_edges = np.linspace(0, trial_duration, n_bins + 1)
bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
bin_width = bin_edges[1] - bin_edges[0]

# Fill spikes_hist_counts
for unit in range(nUnits):
    for direction in range(nDirections):
        for rep in range(nRepetitions):
            spike_times = data_arr[unit, direction, rep][0]
            hist_counts, _ = np.histogram(spike_times, bins=bin_edges)
            spikes_hist_counts[unit, direction, rep, :] = hist_counts

# Plot PSTH for a selected unit.
# The grid is sized from nDirections so every direction gets a panel; a fixed
# 4x3 grid silently dropped any direction beyond the 12th.
selected_unit = 2  # Unit index (0-based)
psth_cols = 3
psth_rows = int(np.ceil(nDirections / psth_cols))
fig, axes = plt.subplots(psth_rows, psth_cols, figsize=(12, 2.5 * psth_rows),
                         sharex=True, sharey=True, squeeze=False)
fig.suptitle(f'Unit #{selected_unit + 1} - PSTH per Direction', fontsize=16)

for i, ax in enumerate(axes.flatten()):
    if i < nDirections:
        # Trial-averaged rate: summed counts / (repetitions * bin width) -> spikes/s
        psth = np.sum(spikes_hist_counts[selected_unit, i], axis=0) / (nRepetitions * bin_width)
        ax.bar(bin_centers, psth, width=bin_width, alpha=0.7)
        # True division keeps the label right when 360 is not a multiple of nDirections
        ax.set_title(f'θ = {360 * i / nDirections:g}\u00b0')
        ax.set_xlim(0, trial_duration)
        ax.set_xlabel('Time (s)')
        ax.set_ylabel('Firing Rate (spikes/s)')
    else:
        ax.axis('off')

plt.tight_layout(rect=[0, 0, 1, 0.98])
plt.show()


# ===============================
# 3. Orientation & Direction Tuning
# ===============================

# Functions to fit
def von_mises_direction(x, A, k, PO):
    """Direction-tuning von Mises curve (period 2*pi).

    Evaluates ``A * exp(k * cos(x - PO))`` for angles ``x`` in radians;
    ``PO`` is the preferred direction and ``k`` the concentration.
    """
    angle_from_pref = x - PO
    return A * np.exp(k * np.cos(angle_from_pref))


def von_mises_orientation(x, A, k, PO):
    """Orientation-tuning von Mises curve (period pi).

    Evaluates ``A * exp(k * cos(2 * (x - PO)))``; doubling the angle makes
    directions ``x`` and ``x + pi`` give the same response.
    """
    doubled_offset = 2 * (x - PO)
    return A * np.exp(k * np.cos(doubled_offset))


def gaussian_direction(x, A, sigma, PO, B):
    """Wrapped Gaussian direction-tuning curve with an additive baseline.

    The angular distance ``x - PO`` is wrapped into [-pi, pi] via arctan2, so the
    peak stays continuous across 0/360 degrees. Returns
    ``A * exp(-0.5 * (d / sigma)**2) + B``.
    """
    offset = x - PO
    wrapped = np.arctan2(np.sin(offset), np.cos(offset))
    z = wrapped / sigma
    return A * np.exp(-0.5 * z**2) + B

def gaussian_orientation(x, A, sigma, PO, B):
    """Wrapped Gaussian orientation-tuning curve (period pi) with a baseline.

    The doubled angle ``2 * (x - PO)`` is wrapped into [-pi, pi] and halved.
    This gives an orientation distance in [-pi/2, pi/2], so ``x`` and ``x + pi``
    respond identically. Returns ``A * exp(-0.5 * (d / sigma)**2) + B``.
    """
    doubled = 2 * (x - PO)
    ori_distance = np.arctan2(np.sin(doubled), np.cos(doubled)) / 2
    z = ori_distance / sigma
    return A * np.exp(-0.5 * z**2) + B


# Helper: fit function to data and return RMSE
def fit_and_evaluate(fit_func, x_data, y_data, p0=None):
    """Least-squares fit of ``fit_func`` to the data and its RMSE.

    Parameters
    ----------
    fit_func : callable
        Model called as ``fit_func(x, *params)``.
    x_data, y_data : array_like
        Sample points and observed values.
    p0 : sequence of float, optional
        Initial parameter guess forwarded to ``curve_fit``. ``None`` (default)
        keeps curve_fit's own default of all ones.

    Returns
    -------
    params : ndarray
        Best-fit parameters, or an all-NaN array if the optimiser did not
        converge.
    rmse : float
        Root-mean-square residual of the fit, or ``inf`` on failure so that
        callers choosing the lower RMSE never select a failed fit.
    """
    y_data = np.asarray(y_data, dtype=float)
    try:
        params, _ = curve_fit(fit_func, x_data, y_data, p0=p0, maxfev=10000)
    except RuntimeError:
        # curve_fit raises RuntimeError when it exhausts maxfev without converging.
        # Degrade gracefully (as fit_and_evaluate_hybrid does) instead of
        # aborting the whole analysis.
        if p0 is not None:
            n_params = len(p0)
        else:
            import inspect
            # Same rule curve_fit uses: every argument after x is a parameter.
            n_params = len(inspect.signature(fit_func).parameters) - 1
        name = getattr(fit_func, "__name__", repr(fit_func))
        print(f"Warning: {name} fit did not converge, returning NaN parameters")
        return np.full(n_params, np.nan), np.inf
    y_fit = fit_func(x_data, *params)
    rmse = np.sqrt(np.mean((y_data - y_fit) ** 2))
    return params, rmse


# Trial-to-trial statistics for every (unit, direction) condition: mean and
# population STD (np.std, ddof=0) of the per-repetition firing rate in Hz.
mean_rates = np.zeros((nUnits, nDirections))
std_rates = np.zeros((nUnits, nDirections))

for unit in range(nUnits):
    for direction in range(nDirections):
        spike_counts = [len(data_arr[unit, direction, rep][0]) for rep in range(nRepetitions)]
        trial_rates = [n_spikes / trial_duration for n_spikes in spike_counts]
        mean_rates[unit, direction] = np.mean(trial_rates)
        std_rates[unit, direction] = np.std(trial_rates)

# Stimulus directions, evenly spaced over the full circle
x_data_deg = np.linspace(0, 360, nDirections, endpoint=False)
x_data_rad = np.deg2rad(x_data_deg)


def _select_tuning_model(direction_func, orientation_func, y_data):
    """Fit the direction and orientation variants of one model family.

    Returns ``(model, rmse)``. ``model(x)`` evaluates the variant with the lower
    RMSE (ties go to the orientation variant), and ``rmse`` is that lower RMSE.
    """
    dir_params, dir_rmse = fit_and_evaluate(direction_func, x_data_rad, y_data)
    ori_params, ori_rmse = fit_and_evaluate(orientation_func, x_data_rad, y_data)
    if dir_rmse < ori_rmse:
        chosen_func, chosen_params = direction_func, dir_params
    else:
        chosen_func, chosen_params = orientation_func, ori_params

    def model(x):
        return chosen_func(x, *chosen_params)

    return model, min(dir_rmse, ori_rmse)


fitted_models_vm, rmses_vm = [], []
fitted_models_gauss, rmses_gauss = [], []

for unit in range(nUnits):
    # Von Mises family
    vm_model, vm_rmse = _select_tuning_model(von_mises_direction, von_mises_orientation, mean_rates[unit])
    rmses_vm.append(vm_rmse)
    fitted_models_vm.append(vm_model)

    # Gaussian family
    gauss_model, gauss_rmse = _select_tuning_model(gaussian_direction, gaussian_orientation, mean_rates[unit])
    rmses_gauss.append(gauss_rmse)
    fitted_models_gauss.append(gauss_model)

print(f"Von Mises Average RMSE: {np.mean(rmses_vm):.3f}")
print(f"Gaussian Average RMSE: {np.mean(rmses_gauss):.3f}\n")


# Plot fits
def plot_fits(models, title, n_cols=3):
    """Plot each unit's direction-tuning data together with its fitted curve.

    Each unit gets one subplot. It shows the module-level ``mean_rates`` with
    ``std_rates`` error bars at ``x_data_deg``, and the fitted model evaluated
    on a dense 0-360 degree grid.

    Parameters
    ----------
    models : sequence of callables
        ``models[i](x_rad)`` returns the fitted rate of unit ``i`` at angles in
        radians.
    title : str
        Figure title.
    n_cols : int, optional
        Number of subplot columns (default 3). Rows are added as needed so every
        unit gets a panel; a fixed 4x3 grid would silently drop units past the 12th.
    """
    x_plot_deg = np.linspace(0, 360, 100)
    x_plot_rad = np.deg2rad(x_plot_deg)
    n_rows = int(np.ceil(nUnits / n_cols))
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 2.5 * n_rows), squeeze=False)
    fig.suptitle(title, fontsize=16)

    for i, ax in enumerate(axes.flatten()):
        if i < nUnits:
            # Data: mean +/- STD across repetitions at each tested direction
            ax.errorbar(x_data_deg, mean_rates[i], yerr=std_rates[i], fmt='o', markerfacecolor="white", zorder=0, label='Data', capsize=5)
            # Fit on a dense grid; the models take radians
            ax.plot(x_plot_deg, models[i](x_plot_rad), '-', label='Fit')
            ax.set_title(f'Unit {i + 1}')
            ax.set_xlabel('Direction (deg)')
            ax.set_ylabel('Firing Rate (Hz)')
            ax.legend()
        else:
            ax.axis('off')

    plt.tight_layout(rect=[0, 0, 1, 0.96])
    plt.show()


plot_fits(fitted_models_vm, 'Von Mises Fits per Unit')
plot_fits(fitted_models_gauss, 'Gaussian Fits per Unit')

# ===============================
# 4. Correlation Between Mean Rate and Variability
# ===============================

# Per-unit summaries, averaged over directions:
#   unit_means - mean firing rate (Hz)
#   unit_stds  - mean trial-to-trial STD across repetitions (Hz)
# The per-direction STDs must be averaged. Taking their std again
# (std_rates.std(axis=1)) measured how much the variability changes between
# directions, not the variability itself that the y-axis label describes.
unit_means = mean_rates.mean(axis=1)
unit_stds = std_rates.mean(axis=1)

correlation, p_value = pearsonr(unit_means, unit_stds)
print(f"Correlation: {correlation:.2f}, p-value: {p_value:.3f}\n")

# Plot: one point per unit plus a least-squares regression line
plt.figure(figsize=(8, 6))
plt.scatter(unit_means, unit_stds, marker='o', alpha=0.7)
plt.plot(np.unique(unit_means), np.poly1d(np.polyfit(unit_means, unit_stds, 1))(np.unique(unit_means)), color='red')
plt.title('Correlation Between Mean Firing Rate and Variability')
plt.xlabel('Mean Firing Rate (Hz)')
plt.ylabel('STD of Firing Rates (Hz)')
plt.grid(True)
plt.text(0.05, 0.95, f'Pearson r={correlation:.2f}\np={p_value:.3f}', transform=plt.gca().transAxes, fontsize=12, verticalalignment='top')
plt.show()

# ===============================
# 5. Hypothesis Testing
# ===============================

# Choosing a unit and two opposite directions (unit 1, directions 0° and 180°)
unit_index = 0  # First unit
dir1_index = 0  # 0 degrees
# Direction opposite to dir1 (180° away), derived from nDirections instead of a
# hard-coded 6, which is only 180° when there are exactly 12 directions.
# With an odd nDirections there is no exact opposite; this is nDirections // 2
# steps away.
dir2_index = (dir1_index + nDirections // 2) % nDirections
deg_per_step = 360 / nDirections

# Collect firing rates (Hz) for all repetitions
rates_dir1 = [len(data_arr[unit_index, dir1_index, rep][0]) / trial_duration for rep in range(nRepetitions)]
rates_dir2 = [len(data_arr[unit_index, dir2_index, rep][0]) / trial_duration for rep in range(nRepetitions)]

# Paired t-test: repetition r of both directions is treated as a matched pair.
# NOTE(review): only valid if repetitions are genuinely matched (e.g. presented
# in the same block). Otherwise stats.ttest_ind(..., equal_var=False) fits better.
t_stat, p_val = stats.ttest_rel(rates_dir1, rates_dir2)

print(f"Hypothesis Testing Results (Unit {unit_index + 1}):")
print(f"Comparing Direction {dir1_index * deg_per_step:g}\u00b0 vs {dir2_index * deg_per_step:g}\u00b0")
print(f"t-statistic = {t_stat:.3f}")
print(f"p-value = {p_val:.5f}")

# Conclusion
if p_val < 0.05:
    print("Result: Significant difference (p < 0.05)\n")
else:
    print("Result: No significant difference (p >= 0.05)\n")


# ===============================
# 6. Bonus
# ===============================

# Hybrid model combining von Mises and Gaussian

def hybrid_function(x, A_vm, k, PO_vm, A_g, sigma, PO_g, B, w):
    """
    Weighted mix of a von Mises curve and a wrapped Gaussian (period 2*pi).

    Evaluates ``w * VM(x) + (1 - w) * (G(x) + B)``, where
    ``VM(x) = A_vm * exp(k * cos(x - PO_vm))`` and
    ``G(x) = A_g * exp(-0.5 * (d / sigma)**2)``, with ``d`` the distance
    ``x - PO_g`` wrapped into [-pi, pi].

    Both components repeat every 2*pi, so this is a *direction* model and
    cannot adapt its periodicity during fitting. The pi-periodic counterpart
    is ``hybrid_function_orientation``.

    Parameters
    ----------
    x : array
        Input angles in radians
    A_vm : float
        Amplitude of the von Mises component
    k : float
        Concentration of the von Mises component
    PO_vm : float
        Preferred direction of the von Mises component, in radians
    A_g : float
        Amplitude of the Gaussian component
    sigma : float
        Width (standard deviation, radians) of the Gaussian component
    PO_g : float
        Preferred direction of the Gaussian component, in radians
    B : float
        Offset added to the Gaussian component only. It is therefore scaled by
        ``(1 - w)``, and no baseline remains at ``w = 1``.
    w : float
        Mixing weight: w=0 gives the Gaussian (plus B), w=1 the pure von Mises

    Returns
    -------
    array : The hybrid model values
    """
    vm_part = A_vm * np.exp(k * np.cos(x - PO_vm))

    shifted = x - PO_g
    wrapped = np.arctan2(np.sin(shifted), np.cos(shifted))
    gauss_part = A_g * np.exp(-0.5 * (wrapped / sigma)**2) + B

    return w * vm_part + (1 - w) * gauss_part


def hybrid_function_orientation(x, A_vm, k, PO_vm, A_g, sigma, PO_g, B, w):
    """
    Orientation (pi-periodic) variant of the hybrid von Mises / Gaussian model.

    Same parameters and weighting as ``hybrid_function``, but each angle
    difference is doubled so the curve repeats every pi radians:
    ``VM(x) = A_vm * exp(k * cos(2 * (x - PO_vm)))``. The Gaussian uses the
    orientation distance ``wrap(2 * (x - PO_g)) / 2``, which lies in
    [-pi/2, pi/2]. Returns ``w * VM(x) + (1 - w) * (G(x) + B)``.
    """
    vm_part = A_vm * np.exp(k * np.cos(2 * (x - PO_vm)))

    doubled = 2 * (x - PO_g)
    ori_distance = np.arctan2(np.sin(doubled), np.cos(doubled)) / 2
    gauss_part = A_g * np.exp(-0.5 * (ori_distance / sigma)**2) + B

    return w * vm_part + (1 - w) * gauss_part

def _fit_hybrid_variant(model, x_data, y_data, p0, bounds):
    """Fit one hybrid variant; return ``(params, rmse)``, or ``(None, inf)`` if it fails to converge."""
    try:
        params, _ = curve_fit(model, x_data, y_data, p0=p0, bounds=bounds, maxfev=10000)
    except RuntimeError:
        return None, np.inf
    y_fit = model(x_data, *params)
    return params, np.sqrt(np.mean((y_data - y_fit) ** 2))


def fit_and_evaluate_hybrid(x_data, y_data):
    """
    Fit the direction and orientation hybrid models and return the better one.

    Each variant is fitted independently, so a convergence failure in one no
    longer throws away a successful fit of the other.

    Parameters
    ----------
    x_data : array
        Stimulus angles in radians.
    y_data : array
        Response (mean firing rate) at each angle.

    Returns
    -------
    params : ndarray
        Fitted (A_vm, k, PO_vm, A_g, sigma, PO_g, B, w); all NaN if both fits
        failed.
    rmse : float
        RMSE of the returned fit (``inf`` if both fits failed).
    func : callable
        ``hybrid_function`` or ``hybrid_function_orientation`` (ties favour the
        direction model).
    """
    x_data = np.asarray(x_data, dtype=float)
    y_data = np.asarray(y_data, dtype=float)

    # Data-driven start: split the peak between both components, put the
    # baseline at the minimum, and start both preferred angles at the angle of
    # the strongest response rather than 0, so fits do not get stuck in the
    # wrong lobe. Values are clipped/wrapped into the bounds below, because
    # curve_fit rejects an infeasible p0.
    peak_angle = x_data[np.argmax(y_data)]
    po_dir = np.angle(np.exp(1j * peak_angle))        # wrapped into [-pi, pi]
    po_ori = np.angle(np.exp(2j * peak_angle)) / 2    # wrapped into [-pi/2, pi/2]
    half_peak = max(np.max(y_data) / 2, 0.0)
    baseline = max(np.min(y_data), 0.0)
    p0_dir = [half_peak, 2.0, po_dir, half_peak, 0.5, po_dir, baseline, 0.5]
    p0_ori = [half_peak, 2.0, po_ori, half_peak, 0.5, po_ori, baseline, 0.5]

    dir_bounds = ([0, 0, -np.pi, 0, 0.01, -np.pi, 0, 0],
                  [np.inf, 20, np.pi, np.inf, np.pi, np.pi, np.inf, 1])
    ori_bounds = ([0, 0, -np.pi / 2, 0, 0.01, -np.pi / 2, 0, 0],
                  [np.inf, 20, np.pi / 2, np.inf, np.pi / 2, np.pi / 2, np.inf, 1])

    params_dir, rmse_dir = _fit_hybrid_variant(hybrid_function, x_data, y_data, p0_dir, dir_bounds)
    params_ori, rmse_ori = _fit_hybrid_variant(hybrid_function_orientation, x_data, y_data, p0_ori, ori_bounds)

    if params_dir is None and params_ori is None:
        print("Warning: Hybrid model fitting failed, returning fallback values")
        # NaN parameters make a failed fit plot as an empty curve. The former
        # all-zero fallback divided by sigma=0 and drew a spurious flat line.
        return np.full(8, np.nan), np.inf, hybrid_function

    # A failed variant has rmse=inf, so the successful one always wins here.
    if rmse_dir <= rmse_ori:
        return params_dir, rmse_dir, hybrid_function
    return params_ori, rmse_ori, hybrid_function_orientation

# Fit the hybrid model to every unit's tuning curve
fitted_models_hybrid = []
rmses_hybrid = []

for unit, unit_rates in enumerate(mean_rates):
    params_hybrid, rmse_hybrid, func_hybrid = fit_and_evaluate_hybrid(x_data_rad, unit_rates)
    rmses_hybrid.append(rmse_hybrid)
    # Default arguments freeze this iteration's function and parameters
    fitted_models_hybrid.append(
        lambda x, func=func_hybrid, params=params_hybrid: func(x, *params)
    )

# Side-by-side comparison of the three model families
print(f"Von Mises Average RMSE: {np.mean(rmses_vm):.3f}")
print(f"Gaussian Average RMSE: {np.mean(rmses_gauss):.3f}")
print(f"Hybrid Model Average RMSE: {np.mean(rmses_hybrid):.3f}\n")

# Also plot the hybrid fits
plot_fits(fitted_models_hybrid, 'Hybrid Model Fits per Unit')
