"""Module for handling WebSocket connections."""

import asyncio
import ctypes
import logging
import threading
import uuid
from typing import Any, Callable, Coroutine, Iterator, Optional

from websockets import ConnectionClosed

from ..io import IOCache
from ..wrapper.wrapper import FunctionWrapper
from ..wrapper.wrapper_json_parser import Functions
from .dto import ExceptionWrapper, LogRecord, Server2ClientType, WebSocketResponse

# Logger used for this module's diagnostics.
logger = logging.getLogger("PEW")

# Per-thread execution context. ExecutorThread.run() stores `exec_id` and
# `ws_conn` here for the duration of a run (and deletes them afterwards);
# WebSocketLoggingHandler.emit and WebSocketConnectionWrapper.send read them
# back to route messages to the client of the execution on the current thread.
thread_local = threading.local()


class WebSocketLoggingHandler(logging.Handler):
    """Logging handler that sends log records to the client.

    A record is forwarded only when it is emitted on a thread whose
    ``thread_local`` carries a ``ws_conn`` (set by ``ExecutorThread.run``);
    on any other thread this handler ignores the record.
    """

    def emit(self, record: logging.LogRecord) -> None:
        """Sends a log record to the client.

        Errors raised while serializing or sending the record are reported via
        ``logging.Handler.handleError`` instead of propagating into the code
        that issued the logging call.

        Args:
            record (logging.LogRecord): Log record to be sent.
        """
        ws_conn = getattr(thread_local, "ws_conn", None)
        if ws_conn is None:
            # Not running inside an execution thread: nothing to forward to.
            return
        try:
            ws_conn.send(Server2ClientType.LOG, data=LogRecord(record).model_dump())
        except Exception:
            # A logging call must never crash the caller; follow the standard
            # Handler contract and delegate to handleError.
            self.handleError(record)


class WebSocketConnectionWrapper:
    """Wrapper class for sending messages to the client.

    ``send`` may be called from any thread: the actual send is scheduled on the
    event loop given at construction and is not awaited by the caller.
    """

    def __init__(
        self,
        send_message: Callable[[str], Coroutine[None, None, None]],
        loop: asyncio.AbstractEventLoop,
    ):
        """Initializes a WebSocketConnectionWrapper.

        Args:
            send_message (Callable[[str], Coroutine[None, None, None]]):
                Callback function to send message to the client.
            loop (asyncio.AbstractEventLoop): Event loop on which sends are scheduled.
        """
        self._send_message = send_message
        self._loop = loop

    def send(
        self,
        type: Server2ClientType,
        *,
        message: Optional[str] = None,
        data: Optional[dict[str, Any]] = None,
    ) -> None:
        """Schedules a message to be sent to the client.

        The message is tagged with the execution id stored in ``thread_local``
        for the calling thread (``None`` when unset). Delivery happens
        asynchronously on the event loop; this method returns without waiting.

        Args:
            type (Server2ClientType): Message type.
            message (Optional[str], optional): Message to be sent. Defaults to None.
            data (Optional[Dict[str, Any]], optional): Additional data to be sent. Defaults to None.
        """
        execution_id = getattr(thread_local, "exec_id", None)
        payload = WebSocketResponse(
            type=type,
            execution_id=execution_id,
            message=message,
            data=data,
        ).model_dump_json()
        # ConnectionClosed is raised while the send is awaited on the event
        # loop, never by run_coroutine_threadsafe itself, so it has to be
        # handled inside the scheduled coroutine.
        asyncio.run_coroutine_threadsafe(
            self._deliver(payload, execution_id),
            self._loop,
        )

    async def _deliver(self, payload: str, execution_id: Optional[str]) -> None:
        """Sends ``payload``; logs instead of raising if the connection is closed.

        Runs on the event loop thread, which has no ``ws_conn`` in
        ``thread_local``, so the error log is not forwarded back through
        this connection.

        Args:
            payload (str): Serialized WebSocketResponse.
            execution_id (Optional[str]): Execution id used in the error message.
        """
        try:
            await self._send_message(payload)
        except ConnectionClosed:
            logger.error(
                f"Websocket connection closed before sending message with execution id '{execution_id}'."
            )


class ExecutorThread(threading.Thread):
    """Thread implementation for function chain executors.

    The thread drives ``operator_iter`` to exhaustion and records the outcome
    in ``status`` (FINISHED, STOPPED or ERROR), ``message`` and ``exception``
    before invoking the optional finished hook.
    """

    def __init__(
        self,
        exec_id: str,
        operator_iter: Iterator[str],
        ws_conn: WebSocketConnectionWrapper,
        operator_completed_hook: Optional[Callable[[int, str], None]] = None,
        finished_hook: Optional[Callable[[], None]] = None,
    ):
        """Initializes an ExecutorThread.

        Args:
            exec_id (str): Execution identifier.
            operator_iter (Iterator[str]):
                Iterator object that executes the business logic in cooperative multitasking.
            ws_conn (WebSocketConnectionWrapper):
                Wrapper object for sending messages to the client.
            operator_completed_hook (Optional[Callable[[int, str], None]], optional):
                Hook function to be called after each operator is executed. Defaults to None.
            finished_hook (Optional[Callable[[], None]], optional):
                Hook function to be called when the thread finishes. Defaults to None.
        """
        super().__init__(daemon=True)
        self.exec_id = exec_id
        self.operator_iter = operator_iter
        self._ws_conn = ws_conn
        self._operator_completed_hook = operator_completed_hook
        self._finished_hook = finished_hook
        self._stop_event = threading.Event()
        # Outcome of the run: written by run(), read by the finished hook.
        self.status: Server2ClientType = Server2ClientType.STARTED
        self.message: Optional[str] = None
        self.exception: Optional[ExceptionWrapper] = None

    def run(self):
        """Method representing the thread's activity.

        Each value yielded by ``operator_iter`` marks one completed operator.
        A pending stop request is honoured between operators; ``kill()``
        interrupts via an asynchronous SystemExit.
        """
        # Expose the execution context to the logging handler / send().
        thread_local.exec_id = self.exec_id
        thread_local.ws_conn = self._ws_conn
        try:
            for i, op_name in enumerate(self.operator_iter):
                if self._operator_completed_hook:
                    self._operator_completed_hook(i, op_name)
                if self._stop_event.is_set():
                    raise InterruptedError()
            self.status = Server2ClientType.FINISHED
        except (InterruptedError, SystemExit):
            # InterruptedError: graceful stop(); SystemExit: injected by kill().
            self.status = Server2ClientType.STOPPED
            logger.error("Execution was stopped")
        except Exception as ex:
            self.status = Server2ClientType.ERROR
            self.message = str(ex)
            self.exception = ExceptionWrapper(
                ex, skip_frames=1
            )  # Skip the wrapper frame
        finally:
            try:
                logger.debug(f"Execution ended, status = {self.status}")
                if self._finished_hook:
                    self._finished_hook()
            finally:
                # Clear the per-thread context even if the finished hook raised.
                del thread_local.exec_id
                del thread_local.ws_conn

    def stop(self):
        """Terminates the thread gracefully.

        Only sets a flag; the thread notices it after the current operator
        completes.
        """
        self._stop_event.set()

    def kill(self):
        """Kills the thread raising a SystemExit exception.

        Does nothing if the thread is not running. The exception is delivered
        asynchronously, the next time the thread executes Python bytecode.
        """
        if not self.is_alive() or not self.ident:
            return
        thread_id = ctypes.c_ulong(self.ident)
        # Asynchronously raise an exception in the thread. The return value is
        # the number of thread states modified: 0 if the id was not found
        # (thread already gone), >1 should never happen.
        res = ctypes.pythonapi.PyThreadState_SetAsyncExc(
            thread_id, ctypes.py_object(SystemExit)
        )
        if res > 1:
            # Error -> revert the effect
            ctypes.pythonapi.PyThreadState_SetAsyncExc(thread_id, None)
            logger.error("Cannot kill execution")


class ExecutionHandler:
    """Handles ExecutorThread management.

    Tracks running executions by id and remembers when the last one finished,
    so that idle time can be reported.
    """

    def __init__(self, loop: asyncio.AbstractEventLoop):
        """Initializes an ExecutionHandler.

        Args:
            loop (asyncio.AbstractEventLoop): Event loop used to schedule
                client messages and as the clock for idle-time tracking.
        """
        self.executions: dict[str, ExecutorThread] = {}
        self._loop = loop
        self._get_time = loop.time
        self._last_activity_time = self._get_time()

    def get_idle_time(self) -> float:
        """Returns the time elapsed since the last activity.

        Returns 0.0 while any execution is still tracked. Otherwise returns the
        time on the loop's clock since the last execution finished, or since
        construction if none has finished yet.
        """
        if self.executions:
            return 0.0
        return self._get_time() - self._last_activity_time

    async def start(
        self,
        functions: Functions,
        io_cache: IOCache,
        send_message: Callable[[str], Coroutine[None, None, None]],
    ) -> str:
        """Starts the execution in a worker thread.

        Args:
            functions (Functions): Functions to be executed.
            io_cache (IOCache): Cache for storing Python operator I/O.
            send_message (Callable[[str], Coroutine]): Callback function to send message to client.

        Returns:
            str: Execution id.
        """
        exec_id = str(uuid.uuid4())
        ws_conn = WebSocketConnectionWrapper(send_message, self._loop)
        function_executor = FunctionWrapper().executor(functions, io_cache)

        # Both hooks are invoked from ExecutorThread.run, i.e. on the worker thread.
        def operator_completed_hook(i: int, op_name: str):
            ws_conn.send(
                Server2ClientType.OPERATOR_COMPLETED,
                message=f"Operator {i} '{op_name}' completed",
                data={"operator_index": i, "operator_name": op_name},
            )

        def finished_hook():
            # Report the final status, then stop tracking this execution.
            thread = self.executions[exec_id]
            try:
                ws_conn.send(
                    thread.status,
                    message=thread.message,
                    data=thread.exception.model_dump() if thread.exception else None,
                )
            except Exception:
                ws_conn.send(
                    Server2ClientType.ERROR,
                    message="An error occurred while sending the final message",
                )
            finally:
                self._last_activity_time = self._get_time()
                del self.executions[exec_id]

        thread = ExecutorThread(
            exec_id, function_executor, ws_conn, operator_completed_hook, finished_hook
        )
        # Register before starting so finished_hook can always find the thread.
        self.executions[exec_id] = thread
        thread.start()
        return exec_id

    async def stop(self, exec_id: str, patience: float) -> None:
        """Terminates the execution identified by exec_id.

        With a positive ``patience`` a graceful stop is requested first, and
        the thread is killed only if it is still alive after waiting up to
        ``patience`` seconds. The wait ends early as soon as the thread exits.
        Otherwise the thread is killed immediately.

        Args:
            exec_id (str): Execution identifier.
            patience (float): Seconds before killing the execution.

        Raises:
            KeyError: If no execution with ``exec_id`` is tracked.
        """
        thread = self.executions[exec_id]
        if patience > 0.0:
            thread.stop()
            # Block in a helper thread (not on the loop) until the worker ends
            # or the patience runs out, instead of always sleeping the full time.
            await asyncio.to_thread(thread.join, patience)
        if thread.is_alive():
            thread.kill()

    def __contains__(self, exec_id: str) -> bool:
        """Returns True if an execution with ``exec_id`` is currently tracked."""
        return exec_id in self.executions
