# Source code for ska_pst.common.dada_file

# -*- coding: utf-8 -*-
#
# This file is part of the SKA PST project
#
# Distributed under the terms of the BSD 3-clause new license.
# See LICENSE for more info.

"""Module with classes for managing and reading PST DADA data and weights files."""

from __future__ import annotations

import itertools
import logging
import pathlib
import struct
from types import TracebackType
from typing import Any, List

import nptyping as npt
import numpy as np
from ska_pst.common.constants import BITS_PER_BYTE
from ska_pydada import AsciiHeader, DadaFile

from .telescope_configuration import get_udp_nsamp_for_format

# names exported by a star import of this module
__all__ = [
    "DadaFileManager",
    "DadaFileReader",
    "WeightsFileReader",
]


# NOTE(review): DEFAULT_HEADER_SIZE, HEADER_SIZE_KEY and SECONDS_PER_FILE are not used by the
# classes in this module; presumably other modules import them -- verify before removing.
# presumably the default size, in bytes, of a DADA ASCII header
DEFAULT_HEADER_SIZE = 4096
# presumably the header key that records the header size
HEADER_SIZE_KEY = "HDR_SIZE"
# presumably the nominal duration, in seconds, of the data written to each file
SECONDS_PER_FILE = 10
# bit width WeightsFileReader requires for weights (it asserts NBIT equals this)
NBIT16 = 16
# array type of the unpacked scale factors (WeightsFileReader allocates them as np.single)
ScalesType = npt.NDArray[Any, npt.Single]
# array type of the unpacked weights (WeightsFileReader allocates them as np.ushort)
WeightsType = npt.NDArray[Any, npt.UShort]


class DadaFileManager:
    """
    Class that captures PST data files.

    Parses attributes from a set of PST voltage recorder data and weights files and
    computes some of the derived quantities from the scales and weights, such as the
    inferred number of dropped/invalid packets.
    """

    def __init__(
        self: DadaFileManager,
        folder: pathlib.Path,
        logger: logging.Logger | None = None,
    ) -> None:
        """
        Initialise the DADA file manager.

        This scans the provided folder for DADA files and instantiates corresponding
        reader objects for both data and weights files.

        :param folder: absolute path to directory containing ``data/`` and ``weights/``
            subdirectories with DADA files.
        :type folder: pathlib.Path
        :param logger: optional logger instance to use for debug and warning messages.
            It is also passed to every file reader created by this manager.
        :type logger: logging.Logger | None
        :raises AssertionError: if the provided folder does not exist or is not a directory.
        """
        # raise explicitly rather than using ``assert`` so the check still runs under ``python -O``;
        # AssertionError is kept because it is the documented exception type.
        # Path.is_dir() is False for a path that does not exist.
        if not folder.is_dir():
            raise AssertionError(f"Expected {folder} to be an existing directory")

        self.folder = folder
        self._data_files: List[DadaFileReader] = []
        self._weights_files: List[WeightsFileReader] = []
        self._logger = logger or logging.getLogger(__name__)
        self._get_dada_files()

    def _get_dada_files(self: DadaFileManager) -> None:
        """Populate the lists of data and weights file readers, each ordered by file path."""
        data_paths = sorted(self.folder.glob("data/*.dada"))
        weights_paths = sorted(self.folder.glob("weights/*.dada"))

        if len(data_paths) != len(weights_paths):
            self._logger.warning(
                f"WARNING: data_paths ({len(data_paths)}) != weights_paths ({len(weights_paths)})"
            )

        # share this manager's logger with the readers so all messages go to one place
        self._data_files = [DadaFileReader(f, logger=self._logger) for f in data_paths]
        self._weights_files = [WeightsFileReader(f, logger=self._logger) for f in weights_paths]

    @property
    def data_files(self: DadaFileManager) -> List[DadaFileReader]:
        """
        Get the list of data file readers.

        :return: list of DADA data file reader instances, ordered by file path.
        :rtype: List[DadaFileReader]
        """
        return self._data_files

    @property
    def weights_files(self: DadaFileManager) -> List[WeightsFileReader]:
        """
        Get the list of weights file readers.

        :return: list of DADA weights file reader instances, ordered by file path.
        :rtype: List[WeightsFileReader]
        """
        return self._weights_files
class DadaFileReader:
    """
    Reader class for PSR DADA files.

    This class provides a high-level interface to access metadata and header values from a DADA
    file using the underlying ``ska_pydada`` library. It supports both header-only and full-file
    access modes.

    Header keys that have no dedicated property can be read as attributes (for example
    ``reader.UDP_FORMAT``); these return ``"Unknown"`` when the key is absent from the header.
    """

    def __init__(
        self: DadaFileReader,
        file: pathlib.Path,
        header_only: bool = True,
        logger: logging.Logger | None = None,
    ) -> None:
        """
        Initialise a DADA file reader.

        :param file: path to the DADA file.
        :type file: pathlib.Path
        :param header_only: if True, only the header is loaded to minimise memory usage.
        :type header_only: bool
        :param logger: optional logger instance.
        :type logger: logging.Logger | None
        """
        self._dada_file = DadaFile.load_from_file(file=file, header_only=header_only)
        self._logger = logger or logging.getLogger(__name__)

    def __enter__(self: DadaFileReader) -> DadaFileReader:
        """
        Enter context manager.

        :return: the current instance.
        :rtype: DadaFileReader
        """
        return self

    def __exit__(
        self: DadaFileReader,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """
        Exit context manager; nothing needs releasing.

        :param exc_type: exception type if raised.
        :type exc_type: type[BaseException] | None
        :param exc_val: exception value.
        :type exc_val: BaseException | None
        :param exc_tb: traceback object.
        :type exc_tb: TracebackType | None
        """

    @property
    def header(self: DadaFileReader) -> AsciiHeader:
        """Get header for file."""
        return self._dada_file.header

    @property
    def file_path(self: DadaFileReader) -> str:
        """Get the resolved path of the file as a string."""
        return str(self._dada_file.file.resolve())

    @property
    def file_size(self: DadaFileReader) -> int:
        """Get size of file in bytes."""
        return self._dada_file.file.stat().st_size

    @property
    def header_size(self: DadaFileReader) -> int:
        """Get the size of the header in bytes."""
        return self.header.header_size

    @property
    def data_size(self: DadaFileReader) -> int:
        """Get the size of the data in bytes (file size minus header size)."""
        return self.file_size - self.header_size

    @property
    def obs_offset(self: DadaFileReader) -> int:
        """Get the OBS_OFFSET value."""
        return self._dada_file.obs_offset

    @property
    def file_number(self: DadaFileReader) -> int:
        """Get the FILE_NUMBER value from header."""
        return self._dada_file.file_number

    @property
    def eb_id(self: DadaFileReader) -> str:
        """Get the EB_ID value from the header."""
        return self._get_header_str("EB_ID")

    @property
    def scan_id(self: DadaFileReader) -> int:
        """Get the SCAN_ID value from header."""
        return self._dada_file.scan_id

    @property
    def observer(self: DadaFileReader) -> str:
        """Get the OBSERVER value from header."""
        return self._get_header_str("OBSERVER")

    @property
    def intent(self: DadaFileReader) -> str:
        """Build the observation intent from the SOURCE header value."""
        return f"Tied-array beam observation of {self.source}"

    @property
    def notes(self: DadaFileReader) -> str:
        """Get the observation notes, read from the INTENT header key."""
        # note: this key does not currently exist
        return self._get_header_str("INTENT")

    @property
    def source(self: DadaFileReader) -> str:
        """Get the SOURCE value from header."""
        return self._get_header_str("SOURCE")

    @property
    def utc_start(self: DadaFileReader) -> str:
        """Get the UTC_START value from header."""
        return self._get_header_str("UTC_START")

    @property
    def picoseconds(self: DadaFileReader) -> int:
        """Get the PICOSECONDS value from the header, or 0 if it is not present."""
        return self.header.get_int_opt("PICOSECONDS") or 0

    @property
    def tsamp(self: DadaFileReader) -> float:
        """Get the TSAMP value from header."""
        return self.header.get_float("TSAMP")

    @property
    def udp_nsamp(self: DadaFileReader) -> int:
        """Get the UDP_NSAMP value from header."""
        return self.header.get_int("UDP_NSAMP")

    @property
    def resolution(self: DadaFileReader) -> int:
        """Get the RESOLUTION value from header."""
        return self.header.get_int("RESOLUTION")

    @property
    def resolution_per_sample(self: DadaFileReader) -> int:
        """
        Get the number of bytes needed for one time sample across all channels, polarisations and dimensions.

        Note that this may be different to the ``resolution`` property as that value may include a
        factor ``UDP_NSAMP`` in the resultant value.
        """
        return self.nbit * self.nchan * self.npol * self.ndim // BITS_PER_BYTE

    @property
    def telescope(self: DadaFileReader) -> str:
        """Get the TELESCOPE value from header."""
        return self._get_header_str("TELESCOPE")

    @property
    def nchan(self: DadaFileReader) -> int:
        """Get the NCHAN value from header."""
        return self.header.get_int("NCHAN")

    @property
    def freq(self: DadaFileReader) -> str:
        """Get the FREQ value from header."""
        return self._get_header_str("FREQ")

    @property
    def bw(self: DadaFileReader) -> str:
        """Get the BW value from header."""
        return self._get_header_str("BW")

    @property
    def npol(self: DadaFileReader) -> int:
        """Get the NPOL value from header."""
        return self.header.get_int("NPOL")

    @property
    def poln_ft(self: DadaFileReader) -> str:
        """Get the POLN_FT value from header."""
        return self._get_header_str("POLN_FT")

    @property
    def stt_crd1(self: DadaFileReader) -> str:
        """Get the STT_CRD1 value from header, defaulting to ``00:00:00.0``."""
        return self._get_header_str("STT_CRD1", "00:00:00.0")

    @property
    def stt_crd2(self: DadaFileReader) -> str:
        """Get the STT_CRD2 value from header, defaulting to ``00:00:00.0``."""
        return self._get_header_str("STT_CRD2", "00:00:00.0")

    @property
    def equinox(self: DadaFileReader) -> str:
        """Get the EQUINOX value from header, defaulting to ``2000``."""
        return self._get_header_str("EQUINOX", "2000")

    @property
    def sky_coord_equinox(self: DadaFileReader) -> str:
        """Get the EQUINOX value from header in the format of J<value>."""
        return f"J{self.equinox}"

    @property
    def nbit(self: DadaFileReader) -> int:
        """Get the number of bits the data is encoded in."""
        return self.header.get_int("NBIT")

    @property
    def ndim(self: DadaFileReader) -> int:
        """Get the number of dimensions (2=complex, 1=real)."""
        return self.header.get_int("NDIM")

    def __getattr__(self: DadaFileReader, item: str) -> Any:
        """
        Get an attribute from the DADA file header.

        Python only calls this after normal attribute lookup fails.

        :param item: the header key to look up.
        :return: the header value, or ``"Unknown"`` if the key is not in the header.
        :raises AttributeError: for private and dunder names, which are never header keys.
        """
        # Forwarding private/dunder names to the header would recurse forever when _dada_file is
        # not yet set (self.header -> self._dada_file -> __getattr__ -> ...), e.g. on an instance
        # created without __init__, and would make protocol probes such as
        # getattr(obj, "__deepcopy__", None) return a string instead of None.
        if item.startswith("_"):
            raise AttributeError(f"{type(self).__name__!r} object has no attribute {item!r}")
        return self._get_header_str(key=item)

    def _get_header_str(self: DadaFileReader, key: str, default_value: str = "Unknown") -> str:
        """
        Get the header value of the specified key, or the default_value if not available.

        :param key: header key to look for.
        :param default_value: value to return if the key does not exist in the header.
        :return: value of the header key, or ``default_value`` when the header raises KeyError.
        :rtype: str
        """
        try:
            return self.header.get_value(key)
        except KeyError:
            self._logger.debug(
                "key: %s not present in self.header. default_value: %s is used for %s",
                key,
                default_value,
                key,
            )
            return default_value
class WeightsFileReader(DadaFileReader):
    """
    Class that can be used to read a Weights PSRDADA file generated by DSP.DISK.

    The data section is a sequence of packets. Each packet holds ``PACKET_SCALES_SIZE`` bytes of
    scales followed by ``PACKET_WEIGHTS_SIZE`` bytes of 16-bit weights. Packets are grouped into
    heaps, where one heap covers all ``NCHAN`` channels.

    The scales and weights are unpacked when the reader is entered as a context manager, or
    lazily on first access of a property that needs them.
    """

    def __init__(
        self: WeightsFileReader,
        file: pathlib.Path,
        unpack_scales: bool = True,
        unpack_weights: bool = True,
        logger: logging.Logger | None = None,
    ) -> None:
        """
        Initialise the weights file reader.

        The whole file (not only the header) is loaded, because the scales and weights are
        stored in the data section.

        :param file: path to the weights DADA file.
        :type file: pathlib.Path
        :param unpack_scales: whether to unpack scale factors.
        :type unpack_scales: bool
        :param unpack_weights: whether to unpack weights.
        :type unpack_weights: bool
        :param logger: optional logger instance.
        :type logger: logging.Logger | None
        """
        super().__init__(file=file, header_only=False, logger=logger)
        self.unpack_scales = unpack_scales
        self.unpack_weights = unpack_weights
        self._scales: ScalesType | None = None
        self._weights: WeightsType | None = None
        # number of packets present in the data section; set by read_raw_data
        self._num_packets = 0
        # becomes True once read_raw_data has completed, so properties can unpack lazily
        self._data_read = False

    def __enter__(self: WeightsFileReader) -> WeightsFileReader:
        """Enter context manager for this file, unpacking its scales and weights."""
        self._read_data()
        return self

    def __exit__(
        self: WeightsFileReader,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Exit context manager for this file."""

    def _ensure_data_read(self: WeightsFileReader) -> None:
        """Unpack the scales and weights if that has not already been done."""
        if not self._data_read:
            self._read_data()

    def _read_data(self: WeightsFileReader) -> None:
        """Derive the packet geometry from the header, then read the scales and weights."""
        # extract the required parameters from the header
        self.nsamp_per_weight = self.header.get_int("NSAMP_PER_WEIGHT")
        self.packet_weights_size = self.header.get_int("PACKET_WEIGHTS_SIZE")
        self.packet_scales_size = self.header.get_int("PACKET_SCALES_SIZE")
        udp_format = self._get_header_str("UDP_FORMAT")

        # the CBF to PSR ICD specifies all weights a 16 bits per sample
        assert self.nbit == NBIT16, f"Expected nbit={self.nbit} to be {NBIT16}"
        bytes_per_weight = self.nbit // BITS_PER_BYTE

        # compute the number of relative weights in each packet
        udp_nsamp = get_udp_nsamp_for_format(udp_format)
        assert (
            udp_nsamp % self.nsamp_per_weight == 0
        ), f"Expect {udp_nsamp=} to be a multiple of {self.nsamp_per_weight}"
        self.nweight_per_packet = udp_nsamp // self.nsamp_per_weight
        self._logger.debug(f"computed weights per packet as {self.nweight_per_packet}")

        # compute the number of channels in each packet
        weight_bytes_per_channel = self.nweight_per_packet * bytes_per_weight
        assert self.packet_weights_size % weight_bytes_per_channel == 0, (
            f"Expected packet_weights_size={self.packet_weights_size} to be a "
            f"multiple of {weight_bytes_per_channel}"
        )
        self.nchan_per_packet = self.packet_weights_size // weight_bytes_per_channel
        self._logger.debug(
            f"packet_weights_size={self.packet_weights_size} nweight_per_packet={self.nweight_per_packet} "
            f"nbit={self.nbit} nchan_per_packet={self.nchan_per_packet}"
        )

        self.weights_packet_stride = self.packet_weights_size + self.packet_scales_size
        self._logger.debug(
            f"weights_packet_stride={self.weights_packet_stride} packet_scales_size="
            f"{self.packet_scales_size} packet_weights_size={self.packet_weights_size}"
        )

        self.read_raw_data(self._dada_file.raw_data)

    def read_raw_data(self: WeightsFileReader, raw_data: bytes) -> None:
        """
        Read the scales and weights from the raw data bytes.

        Requires the packet geometry attributes that ``_read_data`` derives from the header
        (``weights_packet_stride``, ``nchan_per_packet``, ``nweight_per_packet``,
        ``packet_scales_size`` and ``packet_weights_size``).

        :param raw_data: the data section of the weights file.
        """
        # each packet is written as PACKET_SCALES_SIZE bytes of scales followed by
        # PACKET_WEIGHTS_SIZE bytes of weights
        assert (
            self.data_size % self.weights_packet_stride == 0
        ), f"Expected data_size={self.data_size} to be multiple of {self.weights_packet_stride}"
        num_packets = self.data_size // self.weights_packet_stride

        # packets are organised into heaps, where a heap contains all the scale factors
        # and weights for all the channels
        assert (
            self.nchan % self.nchan_per_packet == 0
        ), f"Expected nchan={self.nchan} to be multiple of {self.nchan_per_packet}"
        packets_per_heap = self.nchan // self.nchan_per_packet

        # may not get a full heap at the end of the file: round up so that partial heap is kept
        num_heaps, remainder = divmod(num_packets, packets_per_heap)
        if remainder != 0:
            num_heaps += 1
        self._logger.debug(
            (
                f"data_size={self.data_size}, num_packets={num_packets} "
                f"packets_per_heap={packets_per_heap} num_heaps={num_heaps}"
            )
        )

        # scales exist for each heap and packet
        if self.unpack_scales:
            self._scales = np.zeros((num_heaps, packets_per_heap), dtype=np.single)

        # weights exist for each heap and channel
        if self.unpack_weights:
            self._weights = np.zeros((num_heaps * self.nweight_per_packet, self.nchan), dtype=np.ushort)

        # no need to assert that nbit % 8 is 0 as _read_data has already asserted it is 16
        nbit_as_bytes = self.nbit // BITS_PER_BYTE
        assert (
            self.packet_weights_size % nbit_as_bytes == 0
        ), f"Expected packet_weights_size={self.packet_weights_size} to be a multiple of {nbit_as_bytes}"
        nweights = self.packet_weights_size // nbit_as_bytes

        # if we're not unpacking anything then don't do anything.
        if self.unpack_scales or self.unpack_weights:
            self._unpack_weights_data_from_raw(raw_data, packets_per_heap, num_heaps, nweights)

        self._num_packets = num_packets
        self._data_read = True

    def _unpack_weights_data_from_raw(
        self: WeightsFileReader,
        raw_data: bytes,
        packets_per_heap: int,
        num_heaps: int,
        nweights: int,
    ) -> None:
        """
        Unpack the scales and weights for the current file from raw bytes.

        :param raw_data: the data section of the weights file.
        :param packets_per_heap: number of packets needed to cover all channels.
        :param num_heaps: number of heaps allocated in the output arrays.
        :param nweights: number of 16-bit weights in each packet
            (equal to ``nchan_per_packet * nweight_per_packet``).
        """
        nweight_per_packet = self.nweight_per_packet
        nchan_per_packet = self.nchan_per_packet
        data_len = len(raw_data)
        byte_offset = 0

        for heap, packet in itertools.product(range(num_heaps), range(packets_per_heap)):
            # a partial final heap runs out of data before all of its packets are read
            if byte_offset >= data_len:
                return

            if self.unpack_scales:
                # only the first 32-bit float of the packet's scales block is used
                (scale,) = struct.unpack_from("f", raw_data, byte_offset)
                self._scales[heap, packet] = scale  # type: ignore
            # always step over the scales block so the weights are read from the right offset
            byte_offset += self.packet_scales_size

            if self.unpack_weights:
                # weights are stored channel-major within a packet
                # (index = channel * nweight_per_packet + weight); transpose them into the
                # (sample, channel) layout of the output array
                packet_weights = np.frombuffer(
                    raw_data, dtype=np.ushort, count=nweights, offset=byte_offset
                ).reshape(nchan_per_packet, nweight_per_packet)
                sample_slice = slice(heap * nweight_per_packet, (heap + 1) * nweight_per_packet)
                chan_slice = slice(packet * nchan_per_packet, (packet + 1) * nchan_per_packet)
                self._weights[sample_slice, chan_slice] = packet_weights.T  # type: ignore
            byte_offset += self.packet_weights_size

    @property
    def scales(self: WeightsFileReader) -> ScalesType:
        """
        Return the unpacked scales, with shape ``(num_heaps, packets_per_heap)``.

        :raises RuntimeError: if the reader was created with ``unpack_scales=False``.
        """
        if not self.unpack_scales:
            raise RuntimeError("Cannot return scales as they were not unpacked from the file.")
        self._ensure_data_read()
        return self._scales  # type: ignore

    @property
    def weights(self: WeightsFileReader) -> WeightsType:
        """
        Return the unpacked weights, with shape ``(num_heaps * nweight_per_packet, nchan)``.

        :raises RuntimeError: if the reader was created with ``unpack_weights=False``.
        """
        if not self.unpack_weights:
            raise RuntimeError("Cannot return weights as they were not unpacked from the file.")
        self._ensure_data_read()
        return self._weights  # type: ignore

    @property
    def packets_weights(self: WeightsFileReader) -> WeightsType:
        """
        Get the unpacked weights per packet.

        The ``self.weights`` property returns the weights in heaps, but this property uses packets
        in the first dimension. The resulting array has shape
        ``(num_heaps * packets_per_heap * nweight_per_packet, nchan_per_packet)``. This equals
        ``(num_packets, nchan_per_packet)`` when ``nweight_per_packet`` is 1 and the file holds
        whole heaps.

        :return: a copy of the weights grouped by packets
        :rtype: WeightsType
        """
        weights = self.weights
        return weights.flatten().reshape(-1, self.nchan_per_packet)

    @property
    def dropped_packets(self: WeightsFileReader) -> np.ndarray:
        """Return a list of the dropped packets by inspecting NaNs in the scales."""
        scales = self.scales
        # flatten (heap, packet) so the index is the packet number within the file, and ignore the
        # zero-filled slots beyond the last packet of a partial final heap
        packet_scales = scales.flatten()[: self._num_packets]
        # convert the array of floats to boolean via isnan, then get the indices of the True values
        dropped_packet_list = np.isnan(packet_scales).nonzero()[0]
        self._logger.debug(f"found {len(dropped_packet_list)} dropped packets via scale factor NaNs")
        return dropped_packet_list + self.packet_offset

    @property
    def zeroed_packets(self: WeightsFileReader) -> np.ndarray:
        """Return a list of the zeroed out packets by inspecting the weights."""
        weights = self.weights
        packets_per_heap = self.nchan // self.nchan_per_packet
        num_heaps = weights.shape[0] // self.nweight_per_packet
        # view the weights as (heap, weight, packet, channel-in-packet) so that all weights of a
        # packet are reduced together, whatever the value of nweight_per_packet
        per_packet = weights.reshape(num_heaps, self.nweight_per_packet, packets_per_heap, self.nchan_per_packet)
        # find all packets where all the weights within the packet are zeros, ignoring the
        # zero-filled slots beyond the last packet of a partial final heap
        all_zero = np.all(per_packet == 0, axis=(1, 3)).flatten()[: self._num_packets]
        zero_weights_packet_list = np.nonzero(all_zero)[0]
        self._logger.debug(f"found {len(zero_weights_packet_list)} zeroed weights packets")
        return zero_weights_packet_list + self.packet_offset

    @property
    def packet_offset(self: WeightsFileReader) -> int:
        """Get the packet offset for current file.

        This converts the obs_offset to a packet offset by dividing the value by the
        weights packet stride (PACKET_WEIGHTS_SIZE + PACKET_SCALES_SIZE). This will
        assert that the obs_offset is a multiple of that stride.
        """
        # read the stride straight from the header so the offset is available without
        # unpacking the data; it equals weights_packet_stride computed in _read_data
        stride = self.header.get_int("PACKET_WEIGHTS_SIZE") + self.header.get_int("PACKET_SCALES_SIZE")
        assert (
            self.obs_offset % stride == 0
        ), f"Expected obs_offset={self.obs_offset} to be a multiple of {stride}"
        return self.obs_offset // stride