import os
import asyncio
import csv
import random
import re
import json
import requests
from typing import Dict, Literal, List, Callable
from pydantic import BaseModel, ValidationError
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.messages import TextMessage
from autogen_core import CancellationToken
from autogen_ext.models.openai import OpenAIChatCompletionClient

# Credentials are read from the environment once, at import time; the script
# refuses to start (ValueError) if either one is missing or empty.
def _require_env(name: str) -> str:
    """Return the non-empty value of environment variable *name*, or raise ValueError."""
    value = os.getenv(name)
    if not value:
        raise ValueError(f"Missing {name}. Set it as an environment variable.")
    return value


OPENAI_API_KEY = _require_env("OPENAI_API_KEY")
PAGODA_API_KEY = _require_env("PAGODA_API_KEY")

# Relative location of the RPS dataset (not read anywhere in this script section).
CSV_FILE_PATH = "../../data/rps/rps.csv"

# Define the expected response format as a Pydantic model.
# The model is handed to the OpenAI client as `response_format` and is also used
# to validate replies parsed from the Pagoda/Ollama endpoint. Comments are used
# instead of a class docstring because pydantic copies a class docstring into
# the generated JSON schema ("description"), which would change the schema sent
# to the API.
class AgentResponse(BaseModel):
    # Only the three legal moves are accepted; anything else fails validation.
    move: Literal["Rock", "Paper", "Scissors"]
    # Free-text explanation of the choice; stored as "Motivations" in the round history.
    motivations: str

class RPS:
    """A Rock-Paper-Scissors game between a model-driven player and a scripted opponent.

    Each call to ``play_round`` asks ``opponent_strategy_fn`` for the opponent's
    move, gets the player's move either from a language model
    (``strategy=False``) or from a hand-written heuristic selected by ``model``
    (``strategy=True``), scores the round (win = 2, tie = 1, loss = 0) and
    appends the round record to ``self.history``.
    """

    def __init__(
        self,
        model: str,
        temperature: float,
        game_id: int,
        opponent_strategy_fn: Callable[[List[Dict]], str],
        strategy=False,
        max_retries: int = 3,
    ):
        """Configure the game and build the model client.

        Args:
            model: Model name. Names starting with "gpt" use the OpenAI API,
                names containing ":" use the Pagoda Ollama endpoint, anything
                else a local Ollama server. With ``strategy=True`` the name
                instead selects the heuristic in ``apply_strategy``.
            temperature: Sampling temperature forwarded to the model.
            game_id: Identifier stored on the instance (not used by this class).
            opponent_strategy_fn: Called with the history (list of round dicts)
                and returns the opponent's move for the coming round.
            strategy: If True, use the heuristic player instead of a model.
            max_retries: Attempts per round before giving up on a model reply.
        """
        self.model = model
        self.temperature = temperature
        self.game_id = game_id
        self.strategy = strategy
        self.max_retries = max_retries
        self.history: List[Dict] = []
        self.player_score_game = 0
        self.opponent_strategy_fn = opponent_strategy_fn

        is_openai_model = model.startswith("gpt")
        is_pagoda_model = ":" in model

        self.base_url = (
            "https://api.openai.com/v1" if is_openai_model else
            "https://ollama-ui.pagoda.liris.cnrs.fr/ollama/api/generate" if is_pagoda_model else
            "http://localhost:11434/v1"
        )

        # Capability metadata reported to autogen (presumably required because
        # autogen has no built-in info for non-OpenAI model names). Temperature is
        # not a capability: it is passed to the client as a create argument below,
        # otherwise it would never reach the API.
        model_info = {
            "function_calling": True,
            "parallel_tool_calls": True,
            "family": "unknown",
            "json_output": True,
            "vision": False,
        }

        # For Pagoda models this client is built but unused:
        # model_based_prediction routes those requests through run_pagoda.
        self.model_client = OpenAIChatCompletionClient(
            timeout=60,
            model=self.model,
            base_url=self.base_url,
            api_key=OPENAI_API_KEY,
            model_info=model_info,
            response_format=AgentResponse,
            temperature=self.temperature,
        )

    async def play_round(self, round_id: int) -> Dict:
        """Play one round, append it to ``self.history`` and return the round record.

        The record has keys "Agent Move", "Opponent Move", "Motivations" and
        "Outcome" (0 = loss, 1 = tie, 2 = win). The score is updated exactly
        once per round, here.
        """
        # The opponent chooses from the history *before* this round, so it
        # cannot see the player's current move.
        opponent_move = self.opponent_strategy_fn(self.history)
        if self.strategy:
            move, reasoning = self.apply_strategy()
        else:
            move, reasoning = await self.model_based_prediction()
        outcome = self.determine_winner(move, opponent_move)
        self.update_score(outcome)
        round_result = {
            "Agent Move": move,
            "Opponent Move": opponent_move,
            "Motivations": reasoning,
            "Outcome": outcome
        }
        self.history.append(round_result)
        print(f"Round {round_id}: {self.player_score_game}")
        return round_result

    async def model_based_prediction(self) -> tuple[str, str]:
        """Ask the language model for the next move.

        Returns:
            A ``(move, motivations)`` pair taken from the validated reply.

        Raises:
            ValueError: If no valid reply is obtained within ``max_retries``
                attempts (raised here, or by ``run_pagoda`` for Pagoda models).
        """
        history_summary = self.get_history_summary()
        instruction = f"""
        You are playing Rock-Paper-Scissors.

        Rules:
        - Rock beats Scissors
        - Scissors beats Paper
        - Paper beats Rock
        - Tie = same move

        Win = 2 points, Tie = 1 point, Loss = 0 points.

        Game History:
        {history_summary}

        Choose your next move: Rock, Paper, or Scissors.

        Respond ONLY with JSON format: {{
            "move": "Rock" | "Paper" | "Scissors",
            "motivations": "why you chose it"
        }}
        """

        # Model names containing ":" are served by the Pagoda Ollama endpoint,
        # which is called directly instead of through autogen.
        if ":" in self.model:
            return await self.run_pagoda(instruction)

        for attempt in range(1, self.max_retries + 1):
            # A fresh agent per attempt, presumably so a rejected reply is not
            # carried into the next attempt's conversation context.
            agent = AssistantAgent(
                name="Player",
                model_client=self.model_client,
                system_message="You are a helpful assistant."
            )
            response = await agent.on_messages(
                [TextMessage(content=instruction, source="user")],
                cancellation_token=CancellationToken(),
            )
            try:
                response_data = response.chat_message.content
                agent_response = AgentResponse.model_validate_json(response_data)
                return agent_response.move, agent_response.motivations
            except (ValidationError, json.JSONDecodeError) as e:
                print(f"Attempt {attempt}: Failed to parse model response. Error: {e}")
        raise ValueError("Model failed to provide a valid response after multiple attempts.")

    async def run_pagoda(self, instruction: str) -> tuple[str, str]:
        """Send *instruction* to the Pagoda Ollama ``/api/generate`` endpoint.

        Returns:
            A ``(move, motivations)`` pair from the first reply that contains
            JSON validating as ``AgentResponse``.

        Raises:
            ValueError: If ``max_retries`` attempts all fail (HTTP error,
                unparsable reply or failed validation); the last error is chained.
        """
        headers = {
            "Authorization": f"Bearer {PAGODA_API_KEY}",
            "Content-Type": "application/json"
        }
        payload = {
            "model": self.model,
            "prompt": instruction,
            "stream": False,
            # Ollama reads sampling parameters from "options"; a top-level
            # "temperature" field is ignored by /api/generate.
            "options": {"temperature": self.temperature},
        }

        last_error = None
        for attempt in range(1, self.max_retries + 1):
            try:
                # requests is blocking: run it in a worker thread so the event loop
                # is not stalled, and bound the wait like the OpenAI client (60 s).
                response = await asyncio.to_thread(
                    requests.post, self.base_url, headers=headers, json=payload, timeout=60
                )
                response.raise_for_status()
                response_data = response.json()
                raw_response = response_data.get("response", "") if isinstance(response_data, dict) else ""
                parsed_json = self.extract_json_from_response(raw_response)

                if not parsed_json:
                    print(f"Attempt {attempt}: Could not parse JSON - Raw response: {raw_response}")
                    continue

                # Validation enforces the Literal move set, so no extra check is needed.
                agent_response = AgentResponse(**parsed_json)
                return agent_response.move, agent_response.motivations
            # ValueError covers JSON decoding errors and pydantic's ValidationError.
            except (requests.RequestException, ValueError) as e:
                last_error = e
                print(f"Attempt {attempt}: Error in run_pagoda - {e}")
        raise ValueError("run_pagoda failed to get a valid response.") from last_error

    def extract_json_from_response(self, text: str) -> dict:
        """Pull the JSON object out of a free-text model reply.

        First tries the span from the first "{" to the last "}". If that is not
        valid JSON (e.g. the reply contains several brace groups), every "{" is
        tried as the start of an object. In that fallback the last object having
        a "move" key wins; otherwise the first decodable object is returned.

        Returns:
            The parsed dict, or ``{}`` when *text* is empty/None or has no JSON object.
        """
        if not text:
            return {}

        greedy = re.search(r"\{.*\}", text, re.DOTALL)
        if greedy:
            try:
                parsed = json.loads(greedy.group())
            except json.JSONDecodeError as e:
                print(f"Error extracting JSON: {e}")
            else:
                if isinstance(parsed, dict):
                    return parsed

        decoder = json.JSONDecoder()
        candidates = []
        for brace in re.finditer(r"\{", text):
            try:
                obj, _ = decoder.raw_decode(text, brace.start())
            except json.JSONDecodeError:
                continue
            if isinstance(obj, dict):
                candidates.append(obj)

        with_move = [obj for obj in candidates if "move" in obj]
        if with_move:
            return with_move[-1]
        return candidates[0] if candidates else {}

    def _most_frequent_opponent_move(self) -> tuple[str, int]:
        """Return the opponent's most-played move and its count (ties: Rock, then Paper, then Scissors)."""
        opponent_moves = [round_data["Opponent Move"] for round_data in self.history]
        counts = {m: opponent_moves.count(m) for m in ("Rock", "Paper", "Scissors")}
        most_common = max(counts, key=counts.get)
        return most_common, counts[most_common]

    def apply_strategy(self) -> tuple[str, str]:
        """Choose the next move with the heuristic associated with ``self.model``.

        Does not consult the opponent and does not touch the score; both are
        handled once per round by ``play_round``.

        Returns:
            A ``(move, motivations)`` pair. The "llama3" and "deepseek-r1"
            heuristics return the placeholder move "None" (scored as a loss).

        Raises:
            ValueError: If no heuristic is defined for ``self.model``.
        """
        counter_moves = {"Rock": "Paper", "Paper": "Scissors", "Scissors": "Rock"}

        if self.model == "random":
            move = random.choice(["Rock", "Paper", "Scissors"])
            return move, "No history available. Choosing randomly."

        if self.model == "gpt-4.5-preview-2025-02-27":
            if not self.history:
                move = random.choice(["Rock", "Paper", "Scissors"])
                return move, "No history available. Choosing randomly."
            most_common_move, _ = self._most_frequent_opponent_move()
            move = counter_moves[most_common_move]
            return move, f"Based on history, the opponent most frequently played {most_common_move}, so I play {move}."

        if self.model == "mistral-small":
            if not self.history:
                # Opening move is fixed; previously this was overwritten by the
                # frequency logic below and "Paper" was played instead.
                return "Scissors", "No game history available."
            max_move, times = self._most_frequent_opponent_move()
            move = counter_moves[max_move]
            return move, f"Play {move} since predicted {max_move} because it has been played {times} times."

        if self.model in ("llama3", "deepseek-r1"):
            return "None", "error"

        raise ValueError(f"No heuristic strategy defined for model {self.model!r}.")

    @staticmethod
    def determine_winner(player_move: str, opponent_move: str) -> int:
        """Return 2 if *player_move* wins, 1 on a tie, 0 on a loss.

        The placeholder move "None" always loses. Any other value outside
        Rock/Paper/Scissors raises KeyError.
        """
        if player_move == "None":
            return 0
        win_conditions = {"Rock": "Scissors", "Scissors": "Paper", "Paper": "Rock"}
        if player_move == opponent_move:
            return 1  # Tie
        elif win_conditions[player_move] == opponent_move:
            return 2  # Win
        else:
            return 0  # Loss

    @staticmethod
    def simple_opponent_strategy(history: List[Dict]) -> str:
        """Sample opponent: cycle Rock -> Paper -> Scissors by the number of rounds played."""
        moves = ["Rock", "Paper", "Scissors"]
        return moves[len(history) % 3]

    def update_score(self, outcome: int):
        """Add the points for one round: 2 for a win, 1 for a tie, nothing otherwise."""
        if outcome == 2:  # Win
            self.player_score_game += 2
        elif outcome == 1:  # Tie
            self.player_score_game += 1
        # A loss (0) scores nothing.

    def get_history_summary(self) -> str:
        """Render the round history and current score as prompt text."""
        if not self.history:
            return "This is the first round."
        summary = "\n".join(
            f"Round {i + 1}: You played {r['Agent Move']}, Opponent played {r['Opponent Move']}. Outcome: {r['Outcome']}"
            for i, r in enumerate(self.history)
        )
        summary += f"\nCurrent Score - You: {self.player_score_game}\n"
        return summary


# Runner
async def main(num_rounds: int = 10):
    """Play *num_rounds* rounds of a single game and print the final score."""
    game = RPS(
        model="mixtral:8x7b",  # alternatives: "llama3.3:latest", "mixtral:8x7b", "deepseek-r1:7b"
        temperature=0.7,
        game_id=1,
        # The sample opponent is defined on RPS, so it must be referenced through
        # the class; the bare name `simple_opponent_strategy` is a NameError here.
        opponent_strategy_fn=RPS.simple_opponent_strategy,
        strategy=False,  # or True for rule-based
    )
    for round_id in range(1, num_rounds + 1):
        await game.play_round(round_id)
    print(f"Final Score: {game.player_score_game}")


if __name__ == "__main__":
    asyncio.run(main())