"""WebSocket handler for the server."""

import asyncio
import http
import logging
import platform
import signal
from typing import Any, Callable, Coroutine

import websockets

from ..io import IOCache
from ..wrapper import Functions
from ._core import ExecutionHandler
from .dto import (
    Client2ServerType,
    IOCacheEvictionParams,
    Server2ClientType,
    TerminationParams,
    WebSocketRequest,
    WebSocketResponse,
)

# Module-level logger, registered under the fixed name "PEW".
logger = logging.getLogger("PEW")
# Keep records from being passed on to the handlers of ancestor loggers
# (e.g. the root logger); only handlers attached to "PEW" itself see them.
logger.propagate = False


class WebSocketHandler:
    """Handles WebSocket connections.

    Owns one ``ExecutionHandler`` and one ``IOCache`` that are shared by every
    connection served through :meth:`handler`.
    """

    def __init__(self, loop: asyncio.AbstractEventLoop):
        """Initializes a WebSocketHandler.

        Args:
            loop (asyncio.AbstractEventLoop): Event loop.
        """
        self.executor = ExecutionHandler(loop)
        self.io_cache = IOCache()

    async def handler(self, websocket: websockets.WebSocketServerProtocol):
        """Handles the WebSocket connection.

        Each incoming message is parsed as a ``WebSocketRequest`` and dispatched
        on its type (start, stop, evict IO cache). Any failure while processing
        a message is logged and reported back to the client as an ``ERROR``
        response; the loop then continues with the next message. If the
        connection is closed while a response is being sent, the handler
        returns instead of trying to report the failure on the dead socket.

        Args:
            websocket (websockets.WebSocketServerProtocol): WebSocket connection.
        """
        async for message in websocket:
            try:
                request = WebSocketRequest.model_validate_json(message)
                if request.type == Client2ServerType.START:
                    functions = Functions.model_validate(request.data)
                    exec_id = await self._start_execution(functions, websocket.send)
                    await websocket.send(
                        WebSocketResponse(
                            type=Server2ClientType.STARTED, execution_id=exec_id
                        ).model_dump_json()
                    )
                elif request.type == Client2ServerType.STOP:
                    stop_data = TerminationParams.model_validate(request.data)
                    await self._stop_execution(stop_data)
                elif request.type == Client2ServerType.EVICT_IO_CACHE:
                    eviction_data = IOCacheEvictionParams.model_validate(request.data)
                    await self._evict_io_cache(eviction_data)
                else:
                    logger.warning(f"Unknown request type: {request.type}")
            except websockets.ConnectionClosed:
                # The peer is gone; an error response could not be delivered
                # and there is nothing left to read.
                logger.info("Connection closed while handling a request")
                return
            except Exception as ex:
                logger.error(f"Request failed with the following exception:\n{str(ex)}")
                try:
                    await websocket.send(
                        WebSocketResponse(
                            type=Server2ClientType.ERROR,
                            execution_id=None,
                            message=str(ex),
                        ).model_dump_json()
                    )
                except websockets.ConnectionClosed:
                    # Reporting the error failed because the client left.
                    logger.info("Connection closed while handling a request")
                    return

    async def idle_timer(self, stop_future: asyncio.Future[Any], timeout: int):
        """Stops the server if it is idle for a certain amount of time.

        Sleeps for the time remaining until the executor's idle time would
        reach ``timeout``, then re-checks it. The loop ends as soon as
        ``stop_future`` is done, whoever resolved it.

        Args:
            stop_future (asyncio.Future): Future to stop the server.
            timeout (int): Timeout in seconds.
        """
        while not stop_future.done():
            remaining = timeout - self.executor.get_idle_time()
            # Never pass a negative delay; 0 just yields to the event loop.
            await asyncio.sleep(max(remaining, 0))
            if stop_future.done():
                # Resolved elsewhere while we slept (e.g. by a signal handler);
                # calling set_result() again would raise InvalidStateError.
                break
            if self.executor.get_idle_time() >= timeout:
                logger.info("Idle timeout reached, stopping the server.")
                stop_future.set_result(None)

    async def _start_execution(
        self,
        functions: Functions,
        send_message: Callable[[str], Coroutine[None, None, None]],
    ) -> str:
        """Starts an execution through the executor and returns its id.

        Args:
            functions (Functions): Validated payload of the start request.
            send_message: Coroutine function handed to the executor together
                with the shared IO cache.

        Returns:
            str: Id of the started execution as returned by the executor.
        """
        exec_id = await self.executor.start(functions, self.io_cache, send_message)
        logger.info(f"Execution {exec_id} has been started")
        return exec_id

    async def _stop_execution(self, stop_data: TerminationParams) -> None:
        """Requests the executor to stop a known execution.

        Args:
            stop_data (TerminationParams): Carries the execution id and the
                patience value forwarded to the executor.

        Raises:
            ValueError: If the executor does not contain the execution id.
        """
        if stop_data.execution_id not in self.executor:
            raise ValueError(f"No execution with id {stop_data.execution_id}")
        logger.info(
            f"Stopping of execution {stop_data.execution_id} has been requested"
        )
        await self.executor.stop(stop_data.execution_id, stop_data.patience)

    async def _evict_io_cache(self, eviction_data: IOCacheEvictionParams) -> None:
        """Clears the IO cache or removes selected keys from it.

        ``clear_all`` takes precedence over ``evict_keys``. Keys that are not
        in the cache are skipped. If neither field is set, nothing happens.

        Args:
            eviction_data (IOCacheEvictionParams): Eviction request parameters.
        """
        if eviction_data.clear_all:
            self.io_cache.clear()
            logger.info("IO cache has been cleared")
        elif eviction_data.evict_keys:
            for key in eviction_data.evict_keys:
                if key in self.io_cache:
                    del self.io_cache[key]
            logger.info(f"Keys {eviction_data.evict_keys} have been evicted")


async def health_check(
    path: str, request_headers: websockets.Headers
) -> tuple[int, websockets.Headers, bytes] | None:
    """Answers health probes sent to ``/healthz``.

    Args:
        path (str): Request path.
        request_headers (websockets.Headers): Request headers (not inspected).

    Returns:
        A ``(status, headers, body)`` triple with status 200, empty headers and
        the body ``b"OK\\n"`` when the path is ``/healthz``; ``None`` for every
        other path.
    """
    if path != "/healthz":
        return None

    status = http.HTTPStatus.OK
    body = b"OK\n"
    return status, websockets.Headers(), body


async def serve(port: int, username: str | None, pw: str | None, idle_timeout: int):
    """Starts the WebSocket server.

    The server listens on ``localhost`` and runs until the internal stop
    future is resolved, either by SIGTERM (non-Windows only) or by the idle
    timer. Basic authentication is enabled only when both ``username`` and
    ``pw`` are non-empty.

    Args:
        port (int): Port number.
        username (str | None): Username for Basic authentication.
        pw (str | None): Password for Basic authentication.
        idle_timeout (int): Timeout in seconds for idle server. A negative
            value disables the idle timer.
    """
    loop = asyncio.get_running_loop()
    stop = loop.create_future()

    def request_stop() -> None:
        # Guarded so a repeated signal, or a signal arriving after the idle
        # timer already resolved the future, does not raise InvalidStateError.
        if not stop.done():
            stop.set_result(None)

    # Handle SIGTERM on Unix systems
    if platform.system() != "Windows":
        loop.add_signal_handler(signal.SIGTERM, request_stop)

    auth_protocol = (
        websockets.basic_auth_protocol_factory(credentials=(username, pw))
        if username and pw
        else None
    )

    ws_handler = WebSocketHandler(loop)

    # Keep a strong reference to the task: the event loop only holds weak
    # references, so an unreferenced task may be garbage-collected mid-run.
    idle_task: asyncio.Task[Any] | None = None
    if idle_timeout >= 0:
        idle_task = loop.create_task(ws_handler.idle_timer(stop, idle_timeout))

    try:
        async with websockets.serve(
            ws_handler.handler,
            "localhost",
            port,
            create_protocol=auth_protocol,
            process_request=health_check,
        ):
            await stop
    finally:
        # Do not leave the timer running once the server is shutting down;
        # cancel() is a no-op if the task has already finished.
        if idle_task is not None:
            idle_task.cancel()
