from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Location of the results CSV. This is a relative path, so it is resolved
# against the current working directory, not against this script's folder.
CSV_FILE_PATH = "../../data/rps/rps.csv"

# Read the table and coerce the round number to int and the round outcome to
# float in a single astype call (a missing idRound value would make the int
# cast fail).
df = pd.read_csv(CSV_FILE_PATH).astype({"idRound": int, "outcomeRound": float})

# Opponent strategies that always play the same move ("constant behaviour").
opponent_strategies = ["always_rock", "always_paper", "always_scissor"]

# Keep only the rows played against one of those opponents. The explicit
# .copy() makes df_filtered an independent DataFrame rather than a view-like
# slice of df, so assigning columns on it cannot raise pandas'
# SettingWithCopyWarning.
df_filtered = df.loc[df["opponentStrategy"].isin(opponent_strategies)].copy()

# Line colour (hex RGB) for each model name; the trailing comment names each
# colour.
color_palette = dict(
    [
        ("random", "#63656a"),  # gray
        ("gpt-4.5-preview-2025-02-27", "#7abaff"),  # BlueEscape
        ("llama3", "#32a68c"),  # vertAvenir
        ("mistral-small", "#ff6941"),  # orangeChaleureux
        ("deepseek-r1", "#5862ed"),  # indigoInclusif
    ]
)

# Mean and standard error of the mean (SEM) of outcomeRound for every
# (model, idRound) pair. pandas' built-in "sem" is the sample standard
# deviation (ddof=1) divided by sqrt of the number of NON-missing values, so a
# NaN outcome neither contributes to the spread nor inflates the sample size.
# (The previous hand-written np.std(x, ddof=1) / np.sqrt(len(x)) skipped NaN in
# the std but still counted NaN rows in len(x), understating the SEM.)
# A group with a single observation gets a NaN SEM, as before.
agg_data = (
    df_filtered.groupby(["model", "idRound"])
    .agg(
        mean_outcome=("outcomeRound", "mean"),
        sem_outcome=("outcomeRound", "sem"),
    )
    .reset_index()
)

# Half-width of a normal-approximation 95% confidence interval
# (1.96 = two-sided 95% z-score) around each mean.
agg_data["ci95"] = 1.96 * agg_data["sem_outcome"]

# One figure holding, per model, a line for the mean points per round and a
# shaded band for its 95% confidence interval.
plt.figure(figsize=(10, 6))

# Iterating the groupby yields each model (in sorted order) with its rows;
# agg_data was itself produced by a sorted groupby, so rows within each model
# are already ordered by idRound.
for model, df_model in agg_data.groupby("model"):
    # Models missing from the palette fall back to dark gray.
    color = color_palette.get(model, "#333333")

    # Mean points per round.
    plt.plot(df_model["idRound"], df_model["mean_outcome"], label=model, color=color)

    # 95% confidence band: mean ± ci95, drawn semi-transparent.
    plt.fill_between(
        df_model["idRound"],
        df_model["mean_outcome"] - df_model["ci95"],  # lower bound
        df_model["mean_outcome"] + df_model["ci95"],  # upper bound
        color=color,
        alpha=0.2,
    )

# Axes, labels, legend and grid.
plt.xlim(1, 10)
plt.xlabel("Round Number")
plt.ylabel("Average Points Earned")
plt.title("Average Points Earned per Round Against Constant Behaviour (with 95% Confidence Interval)")
plt.legend()
plt.grid(True)
plt.ylim(0, 2)  # Points are between 0 and 2

# Save as SVG. savefig does not create missing parent directories, so make sure
# the output folder exists first (relative to the current working directory).
output_path = Path("../../figures/rps/rps_constant.svg")
output_path.parent.mkdir(parents=True, exist_ok=True)
plt.savefig(output_path, format="svg")