"""Starts the WebSocket server for the Python Extension Wrapper."""

import asyncio
import http
import logging
import platform
import signal
from argparse import ArgumentParser, Namespace
from pathlib import Path
from typing import Any, Optional

import websockets
from altair_aitools.ext.runtime import RuntimeConfig

from pew.util import setup_logging
from pew.websocket import ExecutionHandler, WebSocketLoggingHandler
from pew.websocket.dto import (
    Client2ServerType,
    Server2ClientType,
    TerminationParams,
    WebSocketRequest,
    WebSocketResponse,
)
from pew.wrapper import Functions

# Shared logger for the wrapper. Propagation is disabled, so records are only
# handled by handlers attached directly to the "PEW" logger (presumably the
# ones installed by setup_logging in __main__ — TODO confirm) and never reach
# the root logger's handlers.
logger: logging.Logger = logging.getLogger("PEW")
logger.propagate = False


async def idle_timer(stop_future: asyncio.Future[Any], timeout: int):
    """Stops the server if it is idle for a certain amount of time.

    Repeatedly consults the module-level ``executor`` (created in ``serve``)
    and resolves ``stop_future`` once its reported idle time reaches
    ``timeout``. The loop exits as soon as ``stop_future`` is done, regardless
    of who resolved it (this timer or e.g. the SIGTERM handler).

    Args:
        stop_future (asyncio.Future): Future to stop the server.
        timeout (int): Timeout in seconds.
    """
    while not stop_future.done():
        # Sleep only for the remaining idle budget. Presumably get_idle_time()
        # returns seconds since the last activity — TODO confirm in
        # ExecutionHandler. Clamp so a late wake-up never yields a negative delay.
        remaining = timeout - executor.get_idle_time()
        await asyncio.sleep(max(remaining, 0))
        # The future may have been resolved while we were sleeping; resolving
        # it a second time would raise asyncio.InvalidStateError.
        if stop_future.done():
            break
        if executor.get_idle_time() >= timeout:
            logger.info("Idle timeout reached, stopping the server.")
            stop_future.set_result(None)


async def _send_error(
    websocket: websockets.WebSocketServerProtocol,
    message: str,
    execution_id=None,
):
    """Sends an ERROR response to the client.

    Args:
        websocket (websockets.WebSocketServerProtocol): WebSocket connection.
        message (str): Human-readable error description.
        execution_id: Execution the error refers to, or None if not applicable.
    """
    await websocket.send(
        WebSocketResponse(
            type=Server2ClientType.ERROR, execution_id=execution_id, message=message
        ).model_dump_json()
    )


async def _handle_start(
    websocket: websockets.WebSocketServerProtocol, request: WebSocketRequest
):
    """Validates a START request, launches the execution and acknowledges it.

    Any failure is logged and reported back to the client as an ERROR response.
    """
    try:
        functions = Functions.model_validate(request.data)
        exec_id = await executor.start(functions, websocket.send)
        await websocket.send(
            WebSocketResponse(
                type=Server2ClientType.STARTED, execution_id=exec_id
            ).model_dump_json()
        )
        logger.info(f"Execution {exec_id} has been started")
    except Exception as ex:
        logger.error(
            f"Start command failed with the following exception:\n{str(ex)}"
        )
        await _send_error(websocket, str(ex))


async def _handle_stop(
    websocket: websockets.WebSocketServerProtocol, request: WebSocketRequest
):
    """Validates a STOP request and asks the executor to stop the execution.

    Invalid payloads, unknown execution ids and failures raised by the
    executor are logged and reported back to the client as ERROR responses.
    """
    try:
        stop_data = TerminationParams.model_validate(request.data)
    except ValueError as ex:
        # pydantic v2 ValidationError is a ValueError subclass.
        logger.error(f"Stop command has an invalid payload:\n{ex}")
        await _send_error(websocket, str(ex))
        return

    if stop_data.execution_id not in executor:
        logger.error(f"No execution with id {stop_data.execution_id}")
        await _send_error(
            websocket,
            f"No execution with id {stop_data.execution_id}",
            stop_data.execution_id,
        )
        return

    logger.info(
        f"Stopping of execution {stop_data.execution_id} has been requested"
    )
    try:
        await executor.stop(stop_data.execution_id, stop_data.patience)
    except Exception as ex:
        logger.error(
            f"Stop command for execution {stop_data.execution_id} failed "
            f"with the following exception:\n{ex}"
        )
        await _send_error(websocket, str(ex), stop_data.execution_id)


async def handler(websocket: websockets.WebSocketServerProtocol):
    """Handles the WebSocket connection.

    Parses every incoming message as a ``WebSocketRequest`` and dispatches
    START/STOP commands. Malformed messages are answered with an ERROR
    response instead of terminating the connection.

    Args:
        websocket (websockets.WebSocketServerProtocol): WebSocket connection.
    """
    async for message in websocket:
        try:
            request = WebSocketRequest.model_validate_json(message)
        except ValueError as ex:
            # An exception escaping this loop would end the handler and close
            # the client's connection; reply with an error and keep serving.
            logger.error(f"Invalid request received:\n{ex}")
            await _send_error(websocket, f"Invalid request: {ex}")
            continue

        if request.type == Client2ServerType.START:
            await _handle_start(websocket, request)
        elif request.type == Client2ServerType.STOP:
            await _handle_stop(websocket, request)


async def health_check(path, request_headers):
    """Answers HTTP liveness probes before the WebSocket handshake.

    Used as the ``process_request`` hook of the server: a request for
    ``/healthz`` is answered directly with ``200 OK``; anything else yields
    ``None`` so the regular WebSocket handshake proceeds.

    Args:
        path (str): Request path.
        request_headers (List[Tuple[str, str]]): Request headers (unused).

    Returns:
        Optional[tuple]: ``(status, headers, body)`` for the probe path,
        otherwise None.
    """
    if path != "/healthz":
        return None
    status, headers, body = http.HTTPStatus.OK, [], b"OK\n"
    return status, headers, body


async def serve(
    port: int,
    username: Optional[str],
    pw: Optional[str],
    idle_timeout: int,
    host: str = "localhost",
):
    """Starts the WebSocket server.

    Creates the module-level ``executor`` used by ``handler`` and
    ``idle_timer``, then serves WebSocket connections (and the ``/healthz``
    probe) until the stop future is resolved by SIGTERM (non-Windows only) or
    by the idle timer.

    Args:
        port (int): Port number.
        username (Optional[str]): Username for Basic authentication.
        pw (Optional[str]): Password for Basic authentication.
        idle_timeout (int): Timeout in seconds for idle server (<0 disables it).
        host (str): Interface to bind to. Defaults to ``"localhost"``.
    """
    global executor

    loop = asyncio.get_running_loop()
    stop = loop.create_future()

    def request_stop():
        # SIGTERM may arrive after the idle timer already resolved the future;
        # resolving it twice would raise InvalidStateError in this callback.
        if not stop.done():
            stop.set_result(None)

    # Handle SIGTERM on Unix systems (loop.add_signal_handler is not
    # supported by the Windows event loops).
    handle_sigterm = platform.system() != "Windows"
    if handle_sigterm:
        loop.add_signal_handler(signal.SIGTERM, request_stop)

    idle_task: Optional[asyncio.Task] = None
    try:
        # Basic authentication is enabled only when both credentials are set.
        auth_protocol = (
            websockets.basic_auth_protocol_factory(credentials=(username, pw))
            if username and pw
            else None
        )

        executor = ExecutionHandler(loop)

        if idle_timeout >= 0:
            # Keep a reference: the event loop holds only weak references to
            # tasks, so an unreferenced task may vanish mid-execution.
            idle_task = loop.create_task(idle_timer(stop, idle_timeout))

        async with websockets.serve(
            handler,
            host,
            port,
            create_protocol=auth_protocol,
            process_request=health_check,
        ):
            await stop
    finally:
        if idle_task is not None:
            idle_task.cancel()
        if handle_sigterm:
            loop.remove_signal_handler(signal.SIGTERM)


def parse_args() -> Namespace:
    """Builds the command-line interface and parses ``sys.argv``.

    Returns:
        Namespace: Parsed options ``port``, ``client``, ``pw``,
        ``idle_timeout``, ``temp_dir``, ``secret`` and ``verbose``.
    """
    # (flags, add_argument keyword arguments), registered in this order so
    # the generated --help output lists the options in the same sequence.
    # NOTE(review): values passed on the command line (e.g. --pw, --secret)
    # are visible to other local users via the process list.
    option_specs = (
        (
            ("-p", "--port"),
            dict(
                type=int,
                default=8000,
                help="port number where the server will listen",
            ),
        ),
        (
            ("-c", "--client"),
            dict(
                type=str,
                default="PEL",
                help="username to access the API with Basic authentication",
            ),
        ),
        (
            ("--pw",),
            dict(type=str, help="password to access the API"),
        ),
        (
            ("--idle-timeout",),
            dict(
                type=int,
                default=360,
                help="timeout in seconds for idle server (<0 to disable)",
            ),
        ),
        (
            ("--temp-dir",),
            dict(
                type=Path,
                default=None,
                help="temporary directory to store the files",
            ),
        ),
        (
            ("--secret",),
            dict(
                type=str,
                default=None,
                help="Shared secret to decrypt connection values",
            ),
        ),
        (
            ("-v", "--verbose"),
            dict(action="store_true", help="make the operation more descriptive"),
        ),
    )

    cli = ArgumentParser(description="WebSocket API Wrapper for Python extensions.")
    for flags, settings in option_specs:
        cli.add_argument(*flags, **settings)
    return cli.parse_args()


if __name__ == "__main__":
    args = parse_args()
    setup_logging(args.verbose, WebSocketLoggingHandler())
    # Log only once logging is configured: before setup_logging the "PEW"
    # logger (propagate=False, no handlers/level set in this file) would
    # presumably drop this INFO record.
    logger.info("Starting wrapper server...")
    if args.temp_dir:
        RuntimeConfig.set("temp_dir", args.temp_dir)
    if args.secret:
        RuntimeConfig.set("secret", args.secret)
    asyncio.run(serve(args.port, args.client, args.pw, args.idle_timeout))
