"""Module for handling data exchange between the Altair and the PEW."""

from __future__ import annotations

import tempfile
import zipfile
from abc import ABC, abstractmethod
from contextlib import contextmanager
from io import BytesIO, IOBase, TextIOWrapper
from pathlib import Path
from typing import IO, TYPE_CHECKING, Any, Callable, Generator

import numpy
import pandas
import rapidminer
from altair_aitools.ext.io import SerializableType
from altair_aitools.ext.io.data_objects import DataObjectWrapper
from altair_aitools.ext.metadata import ColumnRole, get_role, set_role

from ..util import check_directory, get_input_file

try:
    import rapidminer.core.serdeutils as serdeutils
except (ImportError, ModuleNotFoundError):
    import rapidminer.rapidminer.core.serdeutils as serdeutils

if TYPE_CHECKING:
    from ..wrapper.wrapper_json_parser import Input, Output
    from ._types import FileType


class DataExchangeHandler(ABC):
    """Abstract base for strategies that move inputs/outputs between Altair and the PEW.

    Concrete handlers supply parameter validation plus one loader and one
    writer per supported data type ("file", "data_object", "tensor",
    "dataframe"). ``load_input`` and ``return_output`` look the matching
    callable up by the ``data_type`` of the input/output description.
    """

    def __init__(self, parameters: dict[str, str]):
        """Store and validate the exchange parameters, then build the dispatch tables.

        Args:
            parameters (dict[str, str]): Data exchange parameters. They are
                checked by the subclass's ``_validate_parameters`` before the
                handler tables are created.
        """
        self.parameters = parameters
        self._validate_parameters()
        # data_type -> loader, called as loader(input, param_type=...)
        self.input_handlers = dict(
            file=self._load_file,
            data_object=self._load_object,
            tensor=self._load_tensor,
            dataframe=self._load_dataframe,
        )
        # data_type -> writer, called as writer(output_meta, output)
        self.output_handlers = dict(
            file=self._return_file,
            data_object=self._return_object,
            tensor=self._return_tensor,
            dataframe=self._return_dataframe,
        )

    def load_input(self, input: Input, param_type: type) -> Any:
        """Load one input by dispatching on its ``data_type``.

        Args:
            input (Input): Properties of the input (``data_type``, ``value``,
                ``name`` are read here).
            param_type (type): Expected type of the input, forwarded to the
                loader.

        Returns:
            Any: The loaded object, or None when ``input.value`` is falsy
            (the input was not provided).

        Raises:
            ValueError: If ``input.data_type`` has no registered loader. This
                is checked even when the value is empty.
        """
        loader = self.input_handlers.get(input.data_type)
        if loader is None:
            raise ValueError(
                f"Unsupported input type '{input.data_type}' found for data exchange (parameter: '{input.name}')"
            )
        # A falsy value (None, empty string, ...) means "not provided".
        return loader(input, param_type=param_type) if input.value else None

    def return_output(self, output_meta: Output, output: Any) -> None:
        """Hand one output to the writer registered for its ``data_type``.

        Args:
            output_meta (Output): Properties of the output.
            output (Any): The object to write.

        Raises:
            ValueError: If ``output_meta.data_type`` has no registered writer.
        """
        writer = self.output_handlers.get(output_meta.data_type)
        if writer is None:
            raise ValueError(
                f"Unsupported output type '{output_meta.data_type}' found for data exchange"
            )
        writer(output_meta, output)

    @abstractmethod
    def _validate_parameters(self) -> None:
        """Raise if ``self.parameters`` lacks anything this handler needs."""

    @abstractmethod
    def _load_file(self, input: Input, param_type: type) -> FileType | list[FileType]:
        """Load a "file" input (or a list of them for collections)."""

    @abstractmethod
    def _load_object(
        self, input: Input, param_type: type
    ) -> SerializableType | list[SerializableType]:
        """Load a "data_object" input (or a list of them for collections)."""

    @abstractmethod
    def _load_tensor(
        self, input: Input, param_type: type
    ) -> numpy.ndarray | list[numpy.ndarray]:
        """Load a "tensor" input (or a list of them for collections)."""

    @abstractmethod
    def _load_dataframe(
        self, input: Input, param_type: type
    ) -> pandas.DataFrame | list[pandas.DataFrame]:
        """Load a "dataframe" input (or a list of them for collections)."""

    @abstractmethod
    def _return_file(self, output_meta: Output, output: Any) -> None:
        """Write a "file" output."""

    @abstractmethod
    def _return_object(self, output_meta: Output, output: Any) -> None:
        """Write a "data_object" output."""

    @abstractmethod
    def _return_tensor(self, output_meta: Output, output: Any) -> None:
        """Write a "tensor" output."""

    @abstractmethod
    def _return_dataframe(self, output_meta: Output, output: Any) -> None:
        """Write a "dataframe" output."""


class FileSystemDataExchangeHandler(DataExchangeHandler):
    """Handles data exchange in 'file_system' mode using serialization.

    Inputs are read from ``<working_dir>/inputs`` (relative input paths are
    resolved against it) and every output is written to
    ``<working_dir>/outputs/<output name>``. Collection inputs and outputs
    are ZIP archives holding one member per element; output members are
    named ``<index><suffix>`` (for example ``0.json``).
    """

    def __init__(self, parameters: dict[str, str]):
        """Initializes I/O locations.

        Args:
            parameters (dict[str, str]): Data exchange parameters; must
                contain ``working_dir``.

        Raises:
            KeyError: If ``working_dir`` is missing (raised by
                ``_validate_parameters`` during base-class initialization).
        """
        super().__init__(parameters)
        working_directory = Path(parameters["working_dir"])
        inputs_location = working_directory / "inputs"
        outputs_location = working_directory / "outputs"

        # Project helper; presumably ensures the directories exist (TODO confirm).
        check_directory(inputs_location)
        check_directory(outputs_location)

        self.INPUTS_LOCATION = inputs_location
        self.OUTPUTS_LOCATION = outputs_location

    def _validate_parameters(self) -> None:
        """Ensure the mandatory ``working_dir`` parameter is present.

        Raises:
            KeyError: If ``working_dir`` is missing.
        """
        if "working_dir" not in self.parameters:
            raise KeyError(
                "Missing parameter 'working_dir' for setting up data exchange handler in 'file_system' mode"
            )

    def _get_file_name(self, input: Input) -> Path:
        """Resolve the file path referenced by ``input.value``.

        Absolute paths are used as-is; relative ones are resolved against
        ``INPUTS_LOCATION`` by the project helper ``get_input_file``.

        Raises:
            TypeError: If ``input.value`` is not a string.
        """
        if not isinstance(input.value, str):
            raise TypeError(
                f"The input value '{input.name}' ({type(input.value)}) is not compatible with 'str'."
            )
        file_path = Path(input.value)
        if not file_path.is_absolute():
            file_path = get_input_file(self.INPUTS_LOCATION, file_path)
        return file_path

    @contextmanager
    def _open_input_file(self, input: Input) -> Generator[IO[bytes], None, None]:
        """Open the single input file referenced by ``input`` for binary reading."""
        with open(self._get_file_name(input), "rb") as file:
            yield file

    @contextmanager
    def _open_input_collection(
        self, input: Input
    ) -> Generator[zipfile.ZipFile, None, None]:
        """Open the ZIP archive referenced by ``input`` for reading."""
        with zipfile.ZipFile(self._get_file_name(input), "r") as collection:
            yield collection

    @contextmanager
    def _open_output_file(
        self, output_meta: Output
    ) -> Generator[IO[bytes], None, None]:
        """Create (or truncate) ``OUTPUTS_LOCATION / output_meta.name`` for binary writing."""
        with open(self.OUTPUTS_LOCATION / output_meta.name, "wb") as file:
            yield file

    @contextmanager
    def _open_output_collection(
        self, output_meta: Output
    ) -> Generator[zipfile.ZipFile, None, None]:
        """Create (or truncate) the ZIP archive ``OUTPUTS_LOCATION / output_meta.name``."""
        with zipfile.ZipFile(self.OUTPUTS_LOCATION / output_meta.name, "w") as collection:
            yield collection

    def _read_collection(
        self, input: Input, reader: Callable[[IO[bytes], str], Any]
    ) -> list[Any]:
        """Apply ``reader(member_stream, member_name)`` to every file in the input archive.

        Directory entries are skipped. Each member stream is closed as soon
        as ``reader`` returns, so ``reader`` must consume it fully. Results
        keep the archive's member order.
        """
        with self._open_input_collection(input) as collection:
            results = []
            for info in collection.infolist():
                if info.is_dir():
                    continue
                with collection.open(info) as member:
                    results.append(reader(member, info.filename))
            return results

    def _handle_output_collection(
        self,
        output_meta: Output,
        output: Any,
        output_writer: Callable[[Any, Any, str], None],
        suffix: str = ".file",
        use_temp_dir: bool = False,
    ) -> None:
        """Write a list output as a ZIP archive with one member per element.

        Args:
            output_meta (Output): Output description; the archive is written
                to ``OUTPUTS_LOCATION / output_meta.name``.
            output (Any): The list of elements to write.
            output_writer: Called as ``output_writer(element, target, member_name)``.
                ``target`` is a writable binary member stream, or a filesystem
                ``Path`` when ``use_temp_dir`` is set.
            suffix (str): Member-name suffix including its leading dot;
                members are named ``<index><suffix>``.
            use_temp_dir (bool): Write every element to a temporary file
                first and then add it to the archive, for writers that need a
                real path. The temporary files are removed afterwards.

        Raises:
            TypeError: If ``output`` is not a list.
        """
        if not isinstance(output, list):
            raise TypeError(
                f"The output value '{output_meta.name}' ({type(output)}) is not compatible with 'list'."
            )
        member_names = [f"{idx}{suffix}" for idx in range(len(output))]
        if use_temp_dir:
            # A dedicated temporary directory cannot collide with the archive
            # path and is cleaned up even if a writer fails.
            with tempfile.TemporaryDirectory(dir=self.OUTPUTS_LOCATION) as temp_name:
                temp_dir = Path(temp_name)
                temp_files: list[Path] = []
                for out, member_name in zip(output, member_names):
                    file_path = temp_dir / member_name
                    output_writer(out, file_path, member_name)
                    temp_files.append(file_path)
                with self._open_output_collection(output_meta) as collection:
                    for file_path in temp_files:
                        collection.write(file_path, file_path.name)
        else:
            with self._open_output_collection(output_meta) as collection:
                for out, member_name in zip(output, member_names):
                    with collection.open(member_name, "w") as output_file:
                        output_writer(out, output_file, member_name)

    def _load_file_from_file(self, file: IO[bytes], param_type: type) -> FileType:
        """Copy ``file`` into memory as a stream matching ``param_type``.

        Returns a ``BytesIO``, wrapped in a UTF-8 ``TextIOWrapper`` when
        ``param_type`` is a ``TextIOWrapper`` subclass. The content is read
        eagerly, so ``file`` may be closed afterwards.
        """
        content = BytesIO(file.read())
        if issubclass(param_type, TextIOWrapper):
            content = TextIOWrapper(content, encoding="utf-8")
        return content

    def _load_object_from_file(
        self, file: IO[bytes], param_type: type, file_name: str
    ) -> SerializableType:
        """Deserialize a ``DataObjectWrapper`` payload and check it is a ``param_type``.

        Raises:
            ValueError: If the deserialized object is not a ``param_type``.
        """
        data_object = DataObjectWrapper.deserialize(file).object
        if not isinstance(data_object, param_type):
            raise ValueError(
                f"Input '{file_name}' is not a valid '{param_type.__name__}' object."
            )
        return data_object

    def _load_tensor_from_file(
        self, file: IO[bytes], param_type: type, file_name: str
    ) -> numpy.ndarray:
        """Deserialize a ``DataObjectWrapper`` payload holding a numpy array.

        Raises:
            ValueError: If ``param_type`` is not an ``ndarray`` type or the
                payload is not an instance of it.
        """
        if not issubclass(param_type, numpy.ndarray):
            raise ValueError(
                f"Expected type for input '{file_name}' is 'numpy.ndarray'."
            )
        data_object = DataObjectWrapper.deserialize(file).object
        if not isinstance(data_object, param_type):
            raise ValueError(
                f"Input '{file_name}' is not a valid '{param_type.__name__}' object."
            )
        return data_object

    def _load_dataframe_from_file(
        self, file: IO[bytes], param_type: type, file_name: str
    ) -> pandas.DataFrame:
        """Read a DataFrame via ``rapidminer.Project().read`` and restore column roles.

        The role of each column is taken from ``df.rm_metadata``; any column
        whose role cannot be read or mapped to ``ColumnRole`` falls back to
        ``ColumnRole.REGULAR``.

        Raises:
            ValueError: If ``param_type`` is not a ``pandas.DataFrame`` type.
        """
        if not issubclass(param_type, pandas.DataFrame):
            raise ValueError(
                f"Expected type for input '{file_name}' is 'pandas.DataFrame'."
            )
        df = rapidminer.Project().read(file)
        for col in df.columns:
            try:
                _, role = df.rm_metadata[col]
                set_role(df, col, ColumnRole(role.lower()))
            except Exception:
                # Deliberate best-effort: missing metadata, a None role or an
                # unknown role name all degrade to a regular column.
                set_role(df, col, ColumnRole.REGULAR)
        return df

    def _return_file_to_file(
        self, output: Any, output_file: IO[bytes], file_name: str | None = None
    ) -> None:
        """Write a file-like output's content to ``output_file``.

        Accepts bytes-like objects, ``str`` (encoded as UTF-8), readable
        streams (their remaining content is written) and ``Path`` objects
        (the referenced file's bytes are copied).

        Args:
            output (Any): The value produced for the output.
            output_file (IO[bytes]): Binary stream to write to.
            file_name (str | None): Name used in error messages; supplied by
                the collection writer, optional for backward compatibility.

        Raises:
            TypeError: If ``output`` cannot be converted to bytes.
        """
        if isinstance(output, Path):
            output = output.read_bytes()
        elif isinstance(output, IOBase) and output.readable():
            output = output.read()
        if isinstance(output, str):
            output = output.encode("utf-8")
        if not isinstance(output, (bytes, bytearray)):
            try:
                # Any other buffer-protocol object is still written as raw bytes.
                output = memoryview(output)
            except TypeError:
                raise TypeError(
                    f"The output value '{file_name}' ({type(output)}) is not compatible with 'bytes'."
                ) from None
        output_file.write(output)

    def _return_object_to_file(
        self, output: Any, output_file: IO[bytes], file_name: str
    ) -> None:
        """Serialize a ``SerializableType`` with ``DataObjectWrapper`` into ``output_file``.

        Raises:
            TypeError: If ``output`` is not a ``SerializableType``.
        """
        if not isinstance(output, SerializableType):
            raise TypeError(
                f"The output value '{file_name}' ({type(output)}) is not compatible with 'SerializableType'."
            )
        output_file.write(DataObjectWrapper(object=output).serialize())

    def _return_tensor_to_file(
        self, output: Any, output_file: IO[bytes], file_name: str
    ) -> None:
        """Serialize a numpy array with ``DataObjectWrapper`` into ``output_file``.

        Raises:
            TypeError: If ``output`` is not a ``numpy.ndarray``.
        """
        if not isinstance(output, numpy.ndarray):
            raise TypeError(
                f"The output value '{file_name}' ({type(output)}) is not compatible with 'numpy.ndarray'."
            )
        output_file.write(DataObjectWrapper(object=output).serialize())

    def _return_dataframe_to_file(
        self, output: Any, output_file: Path, file_name: str
    ) -> None:
        """Write a DataFrame to the path ``output_file`` via ``rapidminer.Project().write``.

        Column roles (from ``get_role``) are passed on as rapidminer
        metadata. NOTE(review): ``set_metadata_without_warning`` is called on
        ``output`` itself, so it presumably mutates the caller's DataFrame;
        confirm this is acceptable.

        Raises:
            TypeError: If ``output`` is not a ``pandas.DataFrame``.
        """
        if not isinstance(output, pandas.DataFrame):
            raise TypeError(
                f"The output value '{file_name}' ({type(output)}) is not compatible with 'pandas.DataFrame'."
            )
        # None as the type lets the rapidminer package decide the column type.
        rm_metadata = {col: (None, get_role(output, col).value) for col in output.columns}
        serdeutils.set_metadata_without_warning(output, rm_metadata)
        rapidminer.Project().write(output, output_file)

    def _load_file(self, input: Input, param_type: type) -> FileType | list[FileType]:
        """Load a "file" input as in-memory stream(s) or as filesystem path(s).

        Stream-typed parameters (``BytesIO``/``TextIOWrapper``) get the
        content in memory. Any other parameter type gets a ``Path``; for
        collections, the archive is extracted to
        ``INPUTS_LOCATION / <input name up to the first '.'>`` and the
        extracted file paths are returned.
        """
        as_stream = issubclass(param_type, (BytesIO, TextIOWrapper))
        if input.collection:
            if as_stream:
                return self._read_collection(
                    input,
                    lambda member, _name: self._load_file_from_file(member, param_type),
                )
            temp_dir = self.INPUTS_LOCATION / input.name.split(".")[0]
            with self._open_input_collection(input) as collection:
                collection.extractall(temp_dir)
                return [
                    temp_dir / info.filename
                    for info in collection.infolist()
                    if not info.is_dir()
                ]
        if as_stream:
            with self._open_input_file(input) as file:
                return self._load_file_from_file(file, param_type)
        return self._get_file_name(input)

    def _load_object(
        self, input: Input, param_type: type
    ) -> SerializableType | list[SerializableType]:
        """Load a "data_object" input (a list of objects for collections)."""
        if input.collection:
            return self._read_collection(
                input,
                lambda member, name: self._load_object_from_file(
                    member, param_type, name
                ),
            )
        with self._open_input_file(input) as file:
            return self._load_object_from_file(file, param_type, input.name)

    def _load_tensor(self, input: Input, param_type: type) -> Any:
        """Load a "tensor" input (a list of arrays for collections)."""
        if input.collection:
            return self._read_collection(
                input,
                lambda member, name: self._load_tensor_from_file(
                    member, param_type, name
                ),
            )
        with self._open_input_file(input) as file:
            return self._load_tensor_from_file(file, param_type, input.name)

    def _load_dataframe(self, input: Input, param_type: type) -> Any:
        """Load a "dataframe" input (a list of DataFrames for collections)."""
        if input.collection:
            return self._read_collection(
                input,
                lambda member, name: self._load_dataframe_from_file(
                    member, param_type, name
                ),
            )
        with self._open_input_file(input) as file:
            return self._load_dataframe_from_file(file, param_type, input.name)

    def _return_file(self, output_meta: Output, output: Any) -> None:
        """Write a "file" output, or a ZIP of ``.file`` members for collections."""
        if output_meta.collection:
            self._handle_output_collection(
                output_meta, output, self._return_file_to_file, ".file"
            )
        else:
            with self._open_output_file(output_meta) as output_file:
                self._return_file_to_file(output, output_file, output_meta.name)

    def _return_object(self, output_meta: Output, output: Any) -> None:
        """Write a "data_object" output, or a ZIP of ``.json`` members for collections."""
        if output_meta.collection:
            self._handle_output_collection(
                output_meta, output, self._return_object_to_file, ".json"
            )
        else:
            with self._open_output_file(output_meta) as output_file:
                self._return_object_to_file(output, output_file, output_meta.name)

    def _return_tensor(self, output_meta: Output, output: Any) -> None:
        """Write a "tensor" output, or a ZIP of ``.json`` members for collections."""
        if output_meta.collection:
            self._handle_output_collection(
                output_meta, output, self._return_tensor_to_file, ".json"
            )
        else:
            with self._open_output_file(output_meta) as output_file:
                self._return_tensor_to_file(output, output_file, output_meta.name)

    def _return_dataframe(self, output_meta: Output, output: Any) -> None:
        """Write a "dataframe" output, or a ZIP of ``.rmhdf5table`` members for collections.

        The rapidminer writer needs a filesystem path, so collection members
        are staged in a temporary directory first.
        """
        if output_meta.collection:
            self._handle_output_collection(
                output_meta,
                output,
                self._return_dataframe_to_file,
                ".rmhdf5table",
                use_temp_dir=True,
            )
        else:
            file_path = self.OUTPUTS_LOCATION / output_meta.name
            self._return_dataframe_to_file(output, file_path, output_meta.name)
