import ska_control_model as scm
import ska_tango_base as stb
import tango
import tango.server
import watchfiles
import release

import os
import pwd
import grp
import time
import stat
import threading
import logging
import typing

# watchfiles' INFO-level messages can be quite noisy for a long-running
# device, so only let its warnings and errors through.
logging.getLogger("watchfiles.main").setLevel("WARNING")


class FileStats(stb.base.BaseInterface):
    """
    Tango device that exposes file statistics as attributes.

    The monitored file is given by the ``FilePath`` device property. While the
    device is in ``FULL_CONTROL`` a background thread watches the file for
    changes; attribute reads additionally refresh the metadata on demand
    (unless the device is ``DISABLE``).
    """

    FilePath = tango.server.device_property(
        dtype="DevString", default_value="dummy", doc="Path of file to monitor"
    )

    def init_device(self) -> None:
        """Initialise the device."""
        super().init_device()

        self._version_id = release.version
        self._build_state = f"{release.name} {release.version}: {release.description}"

        # Both are set together by _start_monitor_thread and cleared together
        # by _stop_monitor_thread; None means "no monitor running".
        self.stop_event: threading.Event | None = None
        self.monitor_thread: threading.Thread | None = None

        self.init_completed()

    def on_new_shared_bus(self) -> None:
        """Create sharing observers."""
        super().on_new_shared_bus()
        self.metadata = Metadata(self.FilePath, self.logger)

    def delete_device(self) -> None:
        """De-initialise device."""
        self._stop_monitor_thread()
        super().delete_device()

    def change_control_level(self, control_level: stb.base.ControlLevel) -> None:
        """
        Change how the device is interacting with the system under control.

        :param control_level: ``NO_CONTACT`` stops monitoring and clears the
            metadata; ``FULL_CONTROL`` switches the component on and starts
            monitoring.
        :raises ValueError: for any other control level
        """
        if control_level == stb.base.ControlLevel.NO_CONTACT:
            self._stop_monitor_thread()
            self.metadata.reset()
            self.component_disconnected()
        elif control_level == stb.base.ControlLevel.FULL_CONTROL:
            self.component_on()
            self._start_monitor_thread()
        else:
            raise ValueError(f"Unknown control_level {control_level}")

    def _start_monitor_thread(self) -> None:
        """
        Start the background thread that watches the file for changes.

        Any monitor thread that is already running is stopped first, so that
        repeated calls never leave an orphaned thread (and stop event) behind.
        """

        def target(device: FileStats, stop_event: threading.Event) -> None:
            try:
                device._status = f"Monitoring '{device.FilePath}'"
                device.metadata.monitor_for_updates(stop_event)
            except Exception as ex:
                # Top-level boundary of the thread: report and mark the
                # component as faulted instead of dying silently.
                device.logger.exception("File monitor raised unexpected exception")
                device.component_fault()
                device._status = f"Monitor thread crashed: {ex}"
                return
            device._status = "Monitoring disabled"

        # No-op when nothing is running; otherwise prevents leaking a thread.
        self._stop_monitor_thread()

        stop_event = threading.Event()
        self.stop_event = stop_event
        # The event is passed explicitly so the thread never depends on the
        # (mutable) self.stop_event attribute.
        self.monitor_thread = threading.Thread(target=target, args=(self, stop_event))
        self.monitor_thread.start()

    def _stop_monitor_thread(self) -> None:
        """Signal the monitor thread to stop and wait for it, if one is running."""
        if self.stop_event is not None:
            assert self.monitor_thread is not None

            self.stop_event.set()
            self.monitor_thread.join()

            self.monitor_thread = None
            self.stop_event = None

    def read_attr_hardware(self, attr_list: list[int]) -> None:
        """Prepare for attribute read by refreshing the file metadata."""
        _ = attr_list
        if self.get_state() != tango.DevState.DISABLE:
            self.metadata.refresh()
            with self.allow_internal_threads():
                self.shared_bus.wait_for_thread()

    size = stb.software_bus.attribute_from_signal(
        "metadata.size", dtype=int, abs_change=1
    )
    lastModifiedTime = stb.software_bus.attribute_from_signal(
        "metadata.last_modified_time", dtype=str
    )
    owner = stb.software_bus.attribute_from_signal("metadata.owner", dtype=str)
    mode = stb.software_bus.attribute_from_signal("metadata.mode", dtype=str)

    @stb.software_bus.listen_to_signal("metadata.stat_failed")
    def __on_stat_failed(self, ex: Exception | None) -> None:
        # Map the result of the last stat() onto the device health.
        if ex is None:
            self.report_health(scm.HealthState.OK, [])
        else:
            self.report_health(
                scm.HealthState.FAILED, [f"stat('{self.FilePath}'): {ex}"]
            )

    def is_Shrink_allowed(self) -> bool:
        """Return true if the Shrink command is allowed."""
        return typing.cast(bool, self.get_state() != tango.DevState.DISABLE)

    @tango.server.command(dtype_out="DevVarLongStringArray")
    def Shrink(self, new_size: int) -> stb.type_hints.DevVarLongStringArrayType:
        """
        Shrink the file to the provided size.

        :param new_size: size to shrink to, in bytes
        :return: ResultCode.OK, message
        :raises ValueError: if the new size is negative, greater than the
                            current size, or the current size cannot be
                            determined for an unknown reason
        :raises OSError: if the last stat of the file failed or the
                         truncation itself fails
        """
        # Reject this up front: os.truncate would otherwise fail with an
        # opaque EINVAL.
        if new_size < 0:
            raise ValueError(f"New size '{new_size}' must not be negative.")

        current_size = self.metadata.get_size_or_raise()

        if new_size > current_size:
            raise ValueError(
                f"New size '{new_size}' is greater than current size '{current_size}'. "
                "Cannot shrink to new size."
            )

        os.truncate(self.FilePath, new_size)

        return [scm.ResultCode.OK], [f"File shrunk to size '{new_size}'"]


class Metadata(stb.software_bus.SharingObserver):
    """
    File metadata.

    Holds the most recent ``os.stat`` result for a single path and publishes
    derived values (size, mode, owner, last modified time) on software-bus
    signals. ``stat_failed`` carries the ``OSError`` from the last failed
    refresh, or ``None`` after a successful refresh or a reset.
    """

    size_signal = stb.software_bus.AttrSignal[int](name="size")
    mode_signal = stb.software_bus.AttrSignal[str](name="mode")
    owner_signal = stb.software_bus.AttrSignal[str](name="owner")
    last_modified_time_signal = stb.software_bus.AttrSignal[str](
        name="last_modified_time"
    )

    stat_failed = stb.software_bus.Signal[Exception | None](stored=True)

    def __init__(self, path: str, logger: logging.Logger) -> None:
        # NOTE(review): SharingObserver.__init__ is not called here; confirm
        # the base class does not require it.
        self.path = os.path.abspath(path)
        self.logger = logger
        self.data: os.stat_result | None = None
        # time.time() of the last successful stat; None when data is None.
        self.timestamp: float | None = None
        # Serialises refresh/reset/get_size_or_raise across threads.
        self.lock = threading.Lock()

    def reset(self) -> None:
        """Reset metadata to None and publish the cleared values."""
        with self.lock:
            self.data = None
            self.timestamp = None
            self.stat_failed = None

            self.logger.debug("Resetting metadata")
            self._emit_signals()

    def refresh(self) -> None:
        """
        Refresh metadata.

        If refreshing fails, the metadata is set to None and the ``OSError``
        is published on ``stat_failed``.
        """
        with self.lock:
            try:
                self.data = os.stat(self.path)
                self.timestamp = time.time()
                self.stat_failed = None
            except OSError as ex:
                self.data = None
                self.timestamp = None
                self.stat_failed = ex

            self.logger.debug("Refreshing metadata: %s", self.data)
            self._emit_signals()

    def _emit_signals(self) -> None:
        # Publish (value, timestamp, quality) for every attribute signal, or
        # None for all of them when no stat result is available.
        # Called with self.lock held.
        if self.timestamp is not None:
            self.size_signal = (
                typing.cast(int, self.size),
                self.timestamp,
                tango.AttrQuality.ATTR_VALID,
            )
            self.mode_signal = (
                typing.cast(str, self.mode),
                self.timestamp,
                tango.AttrQuality.ATTR_VALID,
            )
            self.owner_signal = (
                typing.cast(str, self.owner),
                self.timestamp,
                tango.AttrQuality.ATTR_VALID,
            )
            self.last_modified_time_signal = (
                typing.cast(str, self.last_modified_time),
                self.timestamp,
                tango.AttrQuality.ATTR_VALID,
            )
        else:
            self.size_signal = None
            self.mode_signal = None
            self.owner_signal = None
            self.last_modified_time_signal = None

    def monitor_for_updates(self, stop_event: threading.Event) -> None:
        """
        Watch the file for changes and refresh the metadata on each one.

        Returns once ``stop_event`` is set (or propagates any error raised by
        the watcher, e.g. when the parent directory does not exist).

        :param stop_event: event that, once set, ends the monitoring
        """
        parent = os.path.dirname(self.path)

        def only_filename(change: watchfiles.Change, name: str) -> bool:
            _ = change
            return os.path.abspath(name) == self.path

        # Monitor from parent directory to know if the file is created or deleted.
        # TODO: Support missing parent directory
        #
        # watchfiles.watch is a generator: the underlying watcher is only
        # created on the first iteration, so a refresh done before the loop
        # cannot by itself guarantee that no change is missed. With
        # yield_on_timeout the generator also yields an empty set after each
        # idle rust_timeout period, so the loop refreshes again once the
        # watcher is actually running (and periodically afterwards).
        watch = watchfiles.watch(
            parent,
            watch_filter=only_filename,
            stop_event=stop_event,
            yield_on_timeout=True,
        )

        # Refresh right away so values are available before the first event.
        self.refresh()

        self.logger.debug("Watching %s", parent)
        for changes in watch:
            if changes:
                self.logger.debug("Got changes=%s", changes)
            self.refresh()

    @property
    def size(self) -> int | None:
        """Return the size of the file in bytes."""
        if self.data is not None:
            return self.data.st_size
        return None

    @property
    def mode(self) -> str | None:
        """
        Return the mode (permission) of the file.

        :return: mode in the same format as ``ls -l``.
        """
        if self.data is not None:
            return stat.filemode(self.data.st_mode)
        return None

    @property
    def owner(self) -> str | None:
        """
        Return the owner of the file.

        IDs without an entry in the user/group database are shown
        numerically, like ``ls -l`` does.

        :return: has the format "<user>:<group>"
        """
        if self.data is None:
            return None

        # getpwuid/getgrgid raise KeyError for unknown IDs (common for files
        # created inside containers); that must not break refresh().
        try:
            user = pwd.getpwuid(self.data.st_uid).pw_name
        except KeyError:
            user = str(self.data.st_uid)
        try:
            group = grp.getgrgid(self.data.st_gid).gr_name
        except KeyError:
            group = str(self.data.st_gid)
        return f"{user}:{group}"

    @property
    def last_modified_time(self) -> str | None:
        """
        Return the time the file was last modified.

        :return: last modified time in the ``ctime`` format
        """
        if self.data is not None:
            return time.ctime(self.data.st_mtime)
        return None

    def get_size_or_raise(self) -> int:
        """
        Return the size of the file.

        If the size is not available raises the exception that was generated
        when the last stat failed.

        :return: size of file in bytes
        :raises OSError: the error from the last failed stat
        :raises ValueError: if the size is unavailable for another reason
        """

        with self.lock:
            size = self.size
            if size is not None:
                return size

            ex = self.stat_failed
            if ex is not None:
                raise ex

        raise ValueError(
            f"Unable to determine size of '{self.path}' for unknown reason."
        )


def chunked_transfer(
    destination: typing.BinaryIO,
    source: typing.BinaryIO,
    chunk_size: int,
    total_bytes: int,
    *,
    progress_callback: stb.type_hints.ProgressCallbackType,
    task_abort_event: threading.Event,
) -> None:
    """
    Copy exactly ``total_bytes`` bytes from ``source`` to ``destination``.

    Data is copied in chunks of at most ``chunk_size`` bytes. After each chunk
    the destination is flushed and ``os.fsync``'d. ``progress_callback`` is
    called with ``progress=0`` up front, and again whenever the integer
    percentage complete increases.

    :param destination: binary stream to write to; must support ``fileno()``
    :param source: binary stream to read from
    :param chunk_size: maximum number of bytes per chunk; must be positive
    :param total_bytes: number of bytes to copy; must not be negative
    :param progress_callback: called with keyword ``progress`` (0-100)
    :param task_abort_event: checked after each full chunk; the transfer
        stops when it is set
    :raises ValueError: if ``chunk_size`` or ``total_bytes`` is out of range
    :raises RuntimeError: if ``source`` ends before ``total_bytes`` were read
    :raises stb.executor.TaskAborted: if ``task_abort_event`` is set
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if total_bytes < 0:
        raise ValueError(f"total_bytes must not be negative, got {total_bytes}")

    progress_callback(progress=0)
    written = 0
    last_progress = 0

    def read_up_to(size: int) -> bytes:
        # read() may legitimately return fewer bytes than requested (raw
        # streams, pipes), so keep reading until `size` bytes or EOF.
        parts: list[bytes] = []
        remaining = size
        while remaining > 0:
            part = source.read(remaining)
            if not part:  # b"" at EOF, or None from a non-blocking raw stream
                break
            parts.append(part)
            remaining -= len(part)
        return b"".join(parts)

    def transfer_chunk(size: int) -> None:
        nonlocal written, last_progress

        data = read_up_to(size)
        if len(data) != size:
            raise RuntimeError(
                f"Not enough data. Found {written + len(data)} bytes, "
                f"expected at least {total_bytes} bytes."
            )

        destination.write(data)
        destination.flush()
        os.fsync(destination.fileno())

        written += len(data)
        # Integer arithmetic keeps the percentage exact.
        progress = 100 * written // total_bytes
        if progress > last_progress:
            progress_callback(progress=progress)
            last_progress = progress

    n_complete_chunks, remainder = divmod(total_bytes, chunk_size)

    for _ in range(n_complete_chunks):
        transfer_chunk(chunk_size)

        if task_abort_event.is_set():
            raise stb.executor.TaskAborted()

    if remainder > 0:
        transfer_chunk(remainder)

    assert written == total_bytes


if __name__ == "__main__":
    # Run a Tango device server that hosts the FileStats device class.
    device_classes = (FileStats,)
    tango.server.run(device_classes)
