import ska_control_model as scm
import ska_tango_base as stb
import ska_tango_base.long_running_commands as stb_lrc
import tango
import tango.server
import watchfiles
import release

import os
import pwd
import grp
import time
import stat
import threading
import logging
import typing

# Silence watchfiles' INFO-level logging, which can be quite noisy; warnings
# and errors from the watcher still get through.
logging.getLogger("watchfiles.main").setLevel("WARNING")


class FileStats(stb_lrc.LRCMixin, stb.base.BaseInterface):
    """
    Tango device that exposes file statistics as attributes.

    The file is selected with the ``FilePath`` device property.  In
    ``FULL_CONTROL`` a background thread watches the file and refreshes the
    shared ``Metadata`` observer, whose signals back the ``size``,
    ``lastModifiedTime``, ``owner`` and ``mode`` attributes.  The ``Shrink``
    and ``Grow`` commands change the size of the file.
    """

    FilePath = tango.server.device_property(
        dtype="DevString", default_value="dummy", doc="Path of file to monitor"
    )

    def init_device(self) -> None:
        """Initialise the device."""
        super().init_device()

        self._version_id = release.version
        self._build_state = f"{release.name} {release.version}: {release.description}"

        # Both are set by _start_monitor_thread and reset to None by
        # _stop_monitor_thread; None means no monitor thread is running.
        self.stop_event: threading.Event | None = None
        self.monitor_thread: threading.Thread | None = None

        self.init_completed()

    def on_new_shared_bus(self) -> None:
        """Create the sharing observers attached to the new shared bus."""
        super().on_new_shared_bus()
        self.metadata = Metadata(self.FilePath, self.logger)

    def delete_device(self) -> None:
        """De-initialise device, stopping the monitor thread first."""
        self._stop_monitor_thread()
        super().delete_device()

    def change_control_level(self, control_level: stb.base.ControlLevel) -> None:
        """
        Change how the device is interacting with the system under control.

        :param control_level: ``NO_CONTACT`` stops monitoring, clears the
            metadata and reports the component as disconnected;
            ``FULL_CONTROL`` reports the component as on and starts monitoring.
        :raises ValueError: for any other control level
        """
        if control_level == stb.base.ControlLevel.NO_CONTACT:
            self._stop_monitor_thread()
            self.metadata.reset()
            self.component_disconnected()
        elif control_level == stb.base.ControlLevel.FULL_CONTROL:
            self.component_on()
            self._start_monitor_thread()
        else:
            raise ValueError(f"Unknown control_level {control_level}")

    def _start_monitor_thread(self) -> None:
        """
        Start the background thread that watches the file for changes.

        A monitor thread that is already running is stopped first; otherwise
        its stop event would be overwritten and the thread left running with
        no way to stop it.
        """
        self._stop_monitor_thread()

        def target(device: FileStats, stop_event: threading.Event) -> None:
            try:
                device._status = f"Monitoring '{device.FilePath}'"
                device.metadata.monitor_for_updates(stop_event)
            except Exception as ex:
                device.logger.exception("File monitor raised unexpected exception")
                device.component_fault()
                device._status = f"Monitor thread crashed: {ex}"
                return
            device._status = "Monitoring disabled"

        self.stop_event = threading.Event()
        # Hand the event to the thread directly so it never has to read the
        # (later cleared) device attribute.
        self.monitor_thread = threading.Thread(
            target=target, args=(self, self.stop_event)
        )
        self.monitor_thread.start()

    def _stop_monitor_thread(self) -> None:
        """
        Signal the monitor thread to stop and wait for it to exit.

        Does nothing if no monitor thread has been started.
        """
        if self.stop_event is not None:
            assert self.monitor_thread is not None

            self.stop_event.set()
            self.monitor_thread.join()

            self.monitor_thread = None
            self.stop_event = None

    def read_attr_hardware(self, attr_list: list[int]) -> None:
        """
        Prepare for attribute read.

        Unless the device is DISABLE, re-stat the file and wait for the shared
        bus to deliver the resulting signals before the attributes are read.
        """
        _ = attr_list
        if self.get_state() != tango.DevState.DISABLE:
            self.metadata.refresh()
            with self.allow_internal_threads():
                self.shared_bus.wait_for_thread()

    size = stb.software_bus.attribute_from_signal(
        "metadata.size", dtype=int, abs_change=1
    )
    lastModifiedTime = stb.software_bus.attribute_from_signal(
        "metadata.last_modified_time", dtype=str
    )
    owner = stb.software_bus.attribute_from_signal("metadata.owner", dtype=str)
    mode = stb.software_bus.attribute_from_signal("metadata.mode", dtype=str)

    @stb.software_bus.listen_to_signal("metadata.stat_failed")
    def __on_stat_failed(self, ex: Exception | None) -> None:
        """Report FAILED health while ``stat`` fails, OK otherwise."""
        if ex is None:
            self.report_health(scm.HealthState.OK, [])
        else:
            self.report_health(
                scm.HealthState.FAILED, [f"stat('{self.FilePath}'): {ex}"]
            )

    def is_Shrink_allowed(self) -> bool:
        """Return true if the Shrink command is allowed."""
        return typing.cast(bool, self.get_state() != tango.DevState.DISABLE)

    @tango.server.command(dtype_out="DevVarLongStringArray")
    def Shrink(self, new_size: int) -> stb.type_hints.DevVarLongStringArrayType:
        """
        Shrink the file to the provided size.

        :param new_size: size to shrink to, in bytes
        :return: ResultCode.OK, message
        :raises ValueError: if the new size is negative or greater than the
                            current size
        :raises Exception: whatever ``Metadata.get_size_or_raise`` raises when
                           the current size is unknown (e.g. the error from
                           the last failed ``stat``)
        """
        if new_size < 0:
            raise ValueError(f"New size '{new_size}' must not be negative.")

        current_size = self.metadata.get_size_or_raise()

        if new_size > current_size:
            raise ValueError(
                f"New size '{new_size}' is greater than current size '{current_size}'. "
                "Cannot shrink to new size."
            )

        os.truncate(self.metadata.path, new_size)

        return [scm.ResultCode.OK], [f"File shrunk to size '{new_size}'"]

    def is_Grow_allowed(
        self,
        request_type: stb_lrc.LRCReqType = stb_lrc.LRCReqType.ENQUEUE_REQ,
    ) -> bool:
        """
        Return true if the Grow command is allowed.

        Enqueueing is always allowed; execution is refused while DISABLE.
        """
        if request_type == stb_lrc.LRCReqType.EXECUTE_REQ:
            return typing.cast(bool, self.get_state() != tango.DevState.DISABLE)

        return True

    # All three keyword arguments are needed by Grow, and sizes must be whole
    # byte counts (floats would make file.read()/range() fail later).
    Grow_SCHEMA: dict[str, stb.type_hints.JSONData] = {
        "$schema": "https://json-schema.org/draft/2020-12/schema",
        "$id": "artefact.skao.int/mylrc.schema.json",
        "title": "Grow schema",
        "description": "Validates the keyword arguments for the Grow command",
        "type": "object",
        "properties": {
            "new_size": {"type": "integer", "minimum": 0},
            "chunk_size": {"type": "integer", "minimum": 1},
            "source": {"type": "string"},
        },
        "required": ["new_size", "chunk_size", "source"],
    }

    @stb_lrc.long_running_command
    @stb.validators.validate_json_args
    def Grow(
        self, new_size: int, chunk_size: int, source: str
    ) -> stb.type_hints.TaskFunctionType:
        """
        Grow the file.

        The file will be grown in chunks with bytes from a given source.

        The source must be able to provide all of the required bytes.  If not
        enough bytes are available, or the transfer is aborted, the operation
        fails and the file is truncated back to its original size.

        This is a long running command.

        :param new_size: New size to grow to
        :param chunk_size: Size of the chunks (must be positive)
        :param source: path to source of bytes
        :return: task producing (ResultCode, message)
        """

        @stb.executor.TaskExecutor.task
        def task(
            progress_callback: stb.type_hints.ProgressCallbackType,
            task_abort_event: threading.Event,
        ) -> tuple[scm.ResultCode, str]:
            if chunk_size <= 0:
                return (
                    scm.ResultCode.FAILED,
                    f"Chunk size '{chunk_size}' must be positive.",
                )

            try:
                initial_size = self.metadata.get_size_or_raise()
            except Exception as ex:
                return (scm.ResultCode.FAILED, f"{ex}")

            if new_size < initial_size:
                return (
                    scm.ResultCode.FAILED,
                    f"New size '{new_size}' is less than current size '{initial_size}'. "
                    "Cannot grow file.",
                )

            # Use the same absolute path as Shrink and the file monitor.
            dest_path = self.metadata.path
            to_transfer = new_size - initial_size

            try:
                source_file = open(source, "rb")
            except OSError as ex:
                return scm.ResultCode.FAILED, f"Cannot open source '{source}': {ex}"

            with source_file:
                try:
                    dest_file = open(dest_path, "ab")
                except OSError as ex:
                    return (
                        scm.ResultCode.FAILED,
                        f"Cannot open '{dest_path}' for appending: {ex}",
                    )

                try:
                    # dest_file is closed by this ``with`` before either
                    # handler below truncates the file back.
                    with dest_file:
                        chunked_transfer(
                            dest_file,
                            source_file,
                            chunk_size,
                            to_transfer,
                            progress_callback=progress_callback,
                            task_abort_event=task_abort_event,
                        )
                except stb.executor.TaskAborted:
                    os.truncate(dest_path, initial_size)
                    self.logger.error(
                        f"Grow aborted. Shrinking back to initial size '{initial_size}'"
                    )
                    raise
                except Exception as ex:
                    os.truncate(dest_path, initial_size)
                    self.logger.exception(
                        f"Grow failed. Shrinking back to initial size '{initial_size}'"
                    )
                    return scm.ResultCode.FAILED, f"Chunked transfer failed: {ex}"

            return scm.ResultCode.OK, f"File size increased to {new_size}"

        return task


class Metadata(stb.software_bus.SharingObserver):
    """
    File metadata.

    Holds the most recent ``os.stat`` result for one file and publishes the
    derived size, mode, owner and last-modified time as attribute signals.
    ``stat_failed`` carries the ``OSError`` from the last failed ``stat``, or
    ``None`` after a successful refresh or a reset.
    """

    size_signal = stb.software_bus.AttrSignal[int](name="size")
    mode_signal = stb.software_bus.AttrSignal[str](name="mode")
    owner_signal = stb.software_bus.AttrSignal[str](name="owner")
    last_modified_time_signal = stb.software_bus.AttrSignal[str](
        name="last_modified_time"
    )

    stat_failed = stb.software_bus.Signal[Exception | None](stored=True)

    def __init__(self, path: str, logger: logging.Logger) -> None:
        """
        Create metadata for a file; no ``stat`` is performed yet.

        :param path: path of the file; stored as an absolute path
        :param logger: logger used for debug output
        """
        self.path = os.path.abspath(path)
        self.logger = logger
        # Last successful stat result and the time.time() when it was taken;
        # both are None while no valid data is available.
        self.data: os.stat_result | None = None
        self.timestamp: float | None = None
        # Serialises reset(), refresh() and get_size_or_raise().
        self.lock = threading.Lock()

    def reset(self) -> None:
        """Reset metadata to None and publish the cleared values."""
        with self.lock:
            self.data = None
            self.timestamp = None
            self.stat_failed = None

            self.logger.debug("Resetting metadata")
            self._emit_signals()

    def refresh(self) -> None:
        """
        Refresh metadata by re-running ``os.stat`` on the file.

        If refreshing fails, the metadata is set to None and the ``OSError``
        is published through ``stat_failed``.
        """
        with self.lock:
            try:
                self.data = os.stat(self.path)
                self.timestamp = time.time()
                self.stat_failed = None
            except OSError as ex:
                self.data = None
                self.timestamp = None
                self.stat_failed = ex

            self.logger.debug("Refreshing metadata: %s", self.data)
            self._emit_signals()

    def _emit_signals(self) -> None:
        """
        Publish the current values on every attribute signal.

        With data, each signal gets (value, timestamp, ATTR_VALID); without
        data, each signal is set to None.  Called by reset() and refresh()
        while holding ``self.lock``.
        """
        if self.timestamp is not None:
            self.size_signal = (
                typing.cast(int, self.size),
                self.timestamp,
                tango.AttrQuality.ATTR_VALID,
            )
            self.mode_signal = (
                typing.cast(str, self.mode),
                self.timestamp,
                tango.AttrQuality.ATTR_VALID,
            )
            self.owner_signal = (
                typing.cast(str, self.owner),
                self.timestamp,
                tango.AttrQuality.ATTR_VALID,
            )
            self.last_modified_time_signal = (
                typing.cast(str, self.last_modified_time),
                self.timestamp,
                tango.AttrQuality.ATTR_VALID,
            )
        else:
            self.size_signal = None
            self.mode_signal = None
            self.owner_signal = None
            self.last_modified_time_signal = None

    def monitor_for_updates(self, stop_event: threading.Event) -> None:
        """
        Watch the file for metadata changes, refreshing on every change.

        Runs until the watch iterator ends (``stop_event`` is handed to
        ``watchfiles.watch`` to stop it).

        :param stop_event: event that stops the watch when set
        """
        parent = os.path.dirname(self.path)

        def only_filename(change: watchfiles.Change, name: str) -> bool:
            _ = change
            return os.path.abspath(name) == self.path

        # Monitor from parent directory to know if the file is created or deleted.
        # TODO: Support missing parent directory
        watch = watchfiles.watch(
            parent, watch_filter=only_filename, stop_event=stop_event
        )

        # Refresh after we start watch to ensure we are up-to-date and
        # don't miss a change
        self.refresh()

        self.logger.debug("Watching %s", parent)
        for changes in watch:
            self.logger.debug("Got changes=%s", changes)
            self.refresh()

    @property
    def size(self) -> int | None:
        """Return the size of the file in bytes, or None if unknown."""
        if self.data is not None:
            return self.data.st_size
        return None

    @property
    def mode(self) -> str | None:
        """
        Return the mode (permission) of the file.

        :return: mode in the same format as ``ls -l``, or None if unknown.
        """
        if self.data is not None:
            return stat.filemode(self.data.st_mode)
        return None

    @property
    def owner(self) -> str | None:
        """
        Return the owner of the file.

        Like ``ls -l``, a uid or gid with no passwd/group entry (common inside
        containers) is shown numerically instead of raising ``KeyError``.

        :return: has the format "<user>:<group>", or None if unknown.
        """
        if self.data is None:
            return None

        uid = self.data.st_uid
        gid = self.data.st_gid
        try:
            user = pwd.getpwuid(uid).pw_name
        except KeyError:
            user = str(uid)
        try:
            group = grp.getgrgid(gid).gr_name
        except KeyError:
            group = str(gid)
        return f"{user}:{group}"

    @property
    def last_modified_time(self) -> str | None:
        """
        Return the time the file was last modified.

        :return: last modified time in the ``ctime`` format, or None if unknown.
        """
        if self.data is not None:
            return time.ctime(self.data.st_mtime)
        return None

    def get_size_or_raise(self) -> int:
        """
        Return the size of the file.

        If the size is not available raises the exception that was generated
        when the last stat failed.

        :return: size of file in bytes
        :raises Exception: the stored ``stat_failed`` error if there is one
        :raises ValueError: if the size is unknown for any other reason
        """

        with self.lock:
            size = self.size
            if size is not None:
                return size

            ex = self.stat_failed
            if ex is not None:
                raise ex

        raise ValueError(
            f"Unable to determine size of '{self.path}' for unknown reason."
        )


def chunked_transfer(
    destination: typing.BinaryIO,
    source: typing.BinaryIO,
    chunk_size: int,
    total_bytes: int,
    *,
    progress_callback: stb.type_hints.ProgressCallbackType,
    task_abort_event: threading.Event,
) -> None:
    """
    Copy exactly ``total_bytes`` bytes from ``source`` to ``destination``.

    Data is moved in chunks of at most ``chunk_size`` bytes.  Each chunk is
    flushed and ``fsync``-ed before the next one is read.  Progress (0-100)
    is reported through ``progress_callback`` at the start and whenever the
    integer percentage increases.  The abort event is checked before every
    chunk, so a transfer whose last chunk has been written always completes.

    :param destination: writable binary file backed by a file descriptor
    :param source: readable binary file
    :param chunk_size: maximum bytes per chunk; must be positive
    :param total_bytes: number of bytes to copy; must not be negative
    :param progress_callback: called as ``progress_callback(progress=<int>)``
    :param task_abort_event: when set, the transfer stops before the next chunk
    :raises ValueError: if ``chunk_size`` is not positive or ``total_bytes``
        is negative
    :raises RuntimeError: if ``source`` runs out of data before
        ``total_bytes`` bytes have been read
    :raises stb.executor.TaskAborted: if ``task_abort_event`` is set
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if total_bytes < 0:
        raise ValueError(f"total_bytes must not be negative, got {total_bytes}")

    progress_callback(progress=0)
    written = 0
    last_progress = 0

    while written < total_bytes:
        if task_abort_event.is_set():
            raise stb.executor.TaskAborted()

        wanted = min(chunk_size, total_bytes - written)
        data = source.read(wanted)
        if len(data) != wanted:
            raise RuntimeError(
                f"Not enough data. Found {written + len(data)} bytes, "
                f"expected at least {total_bytes} bytes."
            )

        destination.write(data)
        destination.flush()
        # Push each chunk to disk so reported progress reflects durable data.
        os.fsync(destination.fileno())

        written += len(data)
        # Integer arithmetic keeps the percentage exact even for huge files.
        progress = 100 * written // total_bytes
        if progress > last_progress:
            progress_callback(progress=progress)
            last_progress = progress


def main() -> None:
    """Run the Tango device server hosting the FileStats device class."""
    tango.server.run((FileStats,))


if __name__ == "__main__":
    main()
