import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# Rock-paper-scissors results file, resolved relative to the current working directory
CSV_FILE_PATH = "../../data/rps/rps.csv"

# Read the raw results and coerce the round id / round outcome columns to
# numeric types in a single astype call
df = pd.read_csv(CSV_FILE_PATH).astype({"idRound": int, "outcomeRound": float})

# Opponent strategies whose games are included in this figure
opponent_strategies = ["R-P-S"]

# Keep only games against the selected strategies. The explicit .copy() makes
# df_filtered an independent DataFrame, so later column writes on it do not
# raise pandas' SettingWithCopyWarning.
strategy_mask = df["opponentStrategy"].isin(opponent_strategies)
df_filtered = df.loc[strategy_mask].copy()

# Plot colour per model label (trailing comments give the palette's colour names).
# Models not listed here fall back to a grey chosen at plot time.
color_palette = dict([
    ('gpt-4.5-preview-2025-02-27', '#7abaff'),           # BlueEscape
    ('gpt-4.5-preview-2025-02-27 strategy', '#000037'),  # BlueHorizon
    ('llama3', '#32a68c'),                               # vertAvenir
    ('mistral-small', '#ff6941'),                        # orangeChaleureux
    ('mistral-small strategy', '#ffd24b'),               # yellow determined
    ('deepseek-r1', '#5862ed'),                          # indigoInclusif
])

# Per-(model, round) statistics of outcomeRound: the mean, the sample standard
# deviation (pandas' default ddof=1) and the number of observations.
summary = (
    df_filtered
    .groupby(["model", "idRound"])["outcomeRound"]
    .agg(mean_outcome="mean", std_outcome="std", count="count")
    .reset_index()
)

# Add the standard error of the mean, then a 95% confidence interval using the
# normal approximation (z = 1.96). assign evaluates its keyword arguments in
# order, so the interval bounds can read the freshly added "sem" column.
# A group with a single observation has a NaN std, so its SEM and interval
# bounds are NaN as well.
summary = summary.assign(
    sem=lambda stats: stats["std_outcome"] / np.sqrt(stats["count"]),
    ci_upper=lambda stats: stats["mean_outcome"] + (1.96 * stats["sem"]),
    ci_lower=lambda stats: stats["mean_outcome"] - (1.96 * stats["sem"]),
)

# Fallback colour for models missing from color_palette. The mean line and its
# CI band share it, so an unlisted model's band always matches its line.
DEFAULT_MODEL_COLOR = '#63656a'  # neutral grey

# Set the figure size
plt.figure(figsize=(10, 6))

# One mean-outcome line plus a shaded 95% CI band per model. With
# sort=False, groups come in order of first appearance, the same order
# summary["model"].unique() yields.
for model, df_model in summary.groupby("model", sort=False):
    model_color = color_palette.get(model, DEFAULT_MODEL_COLOR)

    # Mean outcome per round
    plt.plot(df_model["idRound"], df_model["mean_outcome"],
             label=model,
             color=model_color)

    # Confidence interval as a translucent band in the same colour as the line
    plt.fill_between(df_model["idRound"],
                     df_model["ci_lower"], df_model["ci_upper"],
                     color=model_color,
                     alpha=0.2)  # transparency keeps the mean line visible

# Add legend and labels
plt.xlabel("Round Number")
plt.ylabel("Average Points Earned")
plt.title("Average Points Earned per Round Against 3-Loop Behaviour (95% CI)")
plt.legend()
plt.grid(True)
plt.ylim(0, 2)  # Points are between 0 and 2

# Save the figure as an SVG file
plt.savefig('../../figures/rps/rps_3loop.svg', format='svg')