# -*- coding: utf-8 -*-
#
# This file is part of the SKA PST project
#
# Distributed under the terms of the BSD 3-clause new license.
# See LICENSE for more info.
"""Module class for managing DADA files."""
from __future__ import annotations
import itertools
import logging
import pathlib
import struct
from types import TracebackType
from typing import Any, List
import nptyping as npt
import numpy as np
from ska_pst.common.constants import BITS_PER_BYTE
from ska_pydada import AsciiHeader, DadaFile
from .telescope_configuration import get_udp_nsamp_for_format
__all__ = [
"DadaFileManager",
"DadaFileReader",
"WeightsFileReader",
]
# NOTE(review): DEFAULT_HEADER_SIZE, HEADER_SIZE_KEY and SECONDS_PER_FILE are not referenced by the
# classes in this module; presumably they are consumed elsewhere - confirm before changing them.
DEFAULT_HEADER_SIZE = 4096
HEADER_SIZE_KEY = "HDR_SIZE"
SECONDS_PER_FILE = 10
# Bit width that WeightsFileReader asserts for the NBIT header value of a weights file.
NBIT16 = 16
# Array type of the unpacked scale factors (WeightsFileReader allocates them with dtype=np.single).
ScalesType = npt.NDArray[Any, npt.Single]
# Array type of the unpacked weights (WeightsFileReader allocates them with dtype=np.ushort).
WeightsType = npt.NDArray[Any, npt.UShort]
class DadaFileManager:
    """Class that captures PST data files.

    Scans a folder for PST voltage recorder data and weights files and exposes a reader object
    for each of them. Derived quantities, such as the inferred number of dropped/invalid packets,
    are computed by the individual weights file readers.
    """

    def __init__(
        self: DadaFileManager,
        folder: pathlib.Path,
        logger: logging.Logger | None = None,
    ) -> None:
        """
        Initialise the DADA file manager.

        This scans the provided folder for DADA files and instantiates
        corresponding reader objects for both data and weights files.

        :param folder: absolute path to directory containing ``data/`` and
            ``weights/`` subdirectories with DADA files.
        :type folder: pathlib.Path
        :param logger: optional logger instance to use for debug and warning messages.
        :type logger: logging.Logger | None
        :raises AssertionError: if the provided folder does not exist or is not a directory.
        """
        # Raise explicitly rather than using an ``assert`` statement: asserts are removed when
        # Python runs with ``-O``, which would silently disable the documented validation.
        if not folder.is_dir():
            raise AssertionError(f"{folder} does not exist or is not a directory")

        self.folder = folder
        self._data_files: List[DadaFileReader] = []
        self._weights_files: List[WeightsFileReader] = []
        self._logger = logger or logging.getLogger(__name__)
        self._get_dada_files()

    def _get_dada_files(self: DadaFileManager) -> None:
        """
        Populate list of Data and Weights files.

        Files matching ``data/*.dada`` and ``weights/*.dada`` below ``self.folder`` are sorted by
        path and wrapped in a reader each. A mismatch in the number of data and weights files is
        logged as a warning but is not treated as an error.
        """
        data_paths = sorted(self.folder.glob("data/*.dada"))
        weights_paths = sorted(self.folder.glob("weights/*.dada"))

        if len(data_paths) != len(weights_paths):
            self._logger.warning(
                f"WARNING: data_paths ({len(data_paths)}) != weights_paths ({len(weights_paths)})"
            )

        self._data_files = [DadaFileReader(f) for f in data_paths]
        self._weights_files = [WeightsFileReader(f) for f in weights_paths]

    @property
    def data_files(self: DadaFileManager) -> List[DadaFileReader]:
        """
        Get the list of data file readers.

        :return: list of DADA data file reader instances, ordered by file path.
        :rtype: List[DadaFileReader]
        """
        return self._data_files

    @property
    def weights_files(self: DadaFileManager) -> List[WeightsFileReader]:
        """
        Get the list of weights file readers.

        :return: list of DADA weights file reader instances, ordered by file path.
        :rtype: List[WeightsFileReader]
        """
        return self._weights_files
class DadaFileReader:
    """
    Reader class for PSR DADA files.

    This class provides a high-level interface to access metadata and
    header values from a DADA file using the underlying ``ska_pydada``
    library. It supports both header-only and full-file access modes.

    Header keys that have no dedicated property can be read as attributes
    (e.g. ``reader.UDP_FORMAT``); missing keys evaluate to ``"Unknown"``.
    """

    def __init__(
        self: DadaFileReader,
        file: pathlib.Path,
        header_only: bool = True,
        logger: logging.Logger | None = None,
    ) -> None:
        """
        Initialise a DADA file reader.

        :param file: path to the DADA file.
        :type file: pathlib.Path
        :param header_only: if True, only the header is loaded to minimise memory usage.
        :type header_only: bool
        :param logger: optional logger instance.
        :type logger: logging.Logger | None
        """
        self._dada_file = DadaFile.load_from_file(file=file, header_only=header_only)
        self._logger = logger or logging.getLogger(__name__)

    def __enter__(self: DadaFileReader) -> DadaFileReader:
        """
        Enter context manager.

        :return: the current instance.
        :rtype: DadaFileReader
        """
        return self

    def __exit__(
        self: DadaFileReader,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """
        Exit context manager.

        Nothing is released here and exceptions are not suppressed.

        :param exc_type: exception type if raised.
        :type exc_type: type[BaseException] | None
        :param exc_val: exception value.
        :type exc_val: BaseException | None
        :param exc_tb: traceback object.
        :type exc_tb: TracebackType | None
        """

    @property
    def header(self: DadaFileReader) -> AsciiHeader:
        """Get header for file."""
        return self._dada_file.header

    @property
    def file_path(self: DadaFileReader) -> str:
        """Get the resolved path of the file as a string."""
        return str(self._dada_file.file.resolve())

    @property
    def file_size(self: DadaFileReader) -> int:
        """Get size of file in bytes, as currently reported by the file system."""
        stats = self._dada_file.file.stat()
        return stats.st_size

    @property
    def header_size(self: DadaFileReader) -> int:
        """Get the size of the header in bytes."""
        return self.header.header_size

    @property
    def data_size(self: DadaFileReader) -> int:
        """Get the size of the data in bytes, i.e. the file size minus the header size."""
        return self.file_size - self.header_size

    @property
    def obs_offset(self: DadaFileReader) -> int:
        """Get the OBS_OFFSET value."""
        return self._dada_file.obs_offset

    @property
    def file_number(self: DadaFileReader) -> int:
        """Get the FILE_NUMBER value from header."""
        return self._dada_file.file_number

    @property
    def eb_id(self: DadaFileReader) -> str:
        """Get the EB_ID value from the header."""
        return self._get_header_str("EB_ID")

    @property
    def scan_id(self: DadaFileReader) -> int:
        """Get the SCAN_ID value from header."""
        return self._dada_file.scan_id

    @property
    def observer(self: DadaFileReader) -> str:
        """Get the OBSERVER value from header."""
        return self._get_header_str("OBSERVER")

    @property
    def intent(self: DadaFileReader) -> str:
        """Build value using SOURCE from header."""
        return f"Tied-array beam observation of {self.source}"

    @property
    def notes(self: DadaFileReader) -> str:
        """Get the notes for the file, which are read from the INTENT header key."""
        # note this key does not currently exist, so the default value is returned.
        # NOTE(review): the property is called ``notes`` but looks up "INTENT" - confirm
        # that this is the intended key.
        return self._get_header_str("INTENT")

    @property
    def source(self: DadaFileReader) -> str:
        """Get the SOURCE value from header."""
        return self._get_header_str("SOURCE")

    @property
    def utc_start(self: DadaFileReader) -> str:
        """Get the UTC_START value from header."""
        return self._get_header_str("UTC_START")

    @property
    def picoseconds(self: DadaFileReader) -> int:
        """Get the PICOSECONDS value from the header, or 0 if no value is available."""
        return self.header.get_int_opt("PICOSECONDS") or 0

    @property
    def tsamp(self: DadaFileReader) -> float:
        """Get the TSAMP value from header."""
        return self.header.get_float("TSAMP")

    @property
    def udp_nsamp(self: DadaFileReader) -> int:
        """Get the UDP_NSAMP value from header."""
        return self.header.get_int("UDP_NSAMP")

    @property
    def resolution(self: DadaFileReader) -> int:
        """Get the RESOLUTION value from header."""
        return self.header.get_int("RESOLUTION")

    @property
    def resolution_per_sample(self: DadaFileReader) -> int:
        """
        The amount of bytes needed for one time sample of all channels, polarisations and bits.

        Note that this may be different to the ``resolution`` property as that value may include
        a factor ``UDP_NSAMP`` in the resultant value.
        """
        return self.nbit * self.nchan * self.npol * self.ndim // BITS_PER_BYTE

    @property
    def telescope(self: DadaFileReader) -> str:
        """Get the TELESCOPE value from header."""
        return self._get_header_str("TELESCOPE")

    @property
    def nchan(self: DadaFileReader) -> int:
        """Get the NCHAN value from header."""
        return self.header.get_int("NCHAN")

    @property
    def freq(self: DadaFileReader) -> str:
        """Get the FREQ value from header."""
        return self._get_header_str("FREQ")

    @property
    def bw(self: DadaFileReader) -> str:
        """Get the BW value from header."""
        return self._get_header_str("BW")

    @property
    def npol(self: DadaFileReader) -> int:
        """Get the NPOL value from header."""
        return self.header.get_int("NPOL")

    @property
    def poln_ft(self: DadaFileReader) -> str:
        """Get the POLN_FT value from header."""
        return self._get_header_str("POLN_FT")

    @property
    def stt_crd1(self: DadaFileReader) -> str:
        """Get the STT_CRD1 value from header, or ``00:00:00.0`` if not present."""
        return self._get_header_str("STT_CRD1", "00:00:00.0")

    @property
    def stt_crd2(self: DadaFileReader) -> str:
        """Get the STT_CRD2 value from header, or ``00:00:00.0`` if not present."""
        return self._get_header_str("STT_CRD2", "00:00:00.0")

    @property
    def equinox(self: DadaFileReader) -> str:
        """Get the EQUINOX value from header, or ``2000`` if not present."""
        return self._get_header_str("EQUINOX", "2000")

    @property
    def sky_coord_equinox(self: DadaFileReader) -> str:
        """Get the EQUINOX value from header in the format of J<value>."""
        return f"J{self.equinox}"

    @property
    def nbit(self: DadaFileReader) -> int:
        """Get the number of bits the data is encoded in."""
        return self.header.get_int("NBIT")

    @property
    def ndim(self: DadaFileReader) -> int:
        """Get the number of dimensions (2=complex, 1=real)."""
        return self.header.get_int("NDIM")

    def __getattr__(self: DadaFileReader, item: str) -> Any:
        """
        Get an attribute from the DADA file header.

        This is only called when normal attribute lookup fails. The attribute name is used as
        the header key, and ``"Unknown"`` is returned when the header has no such key.

        :param item: name of the attribute, used as the header key.
        :return: the header value for the key, or ``"Unknown"``.
        :raises AttributeError: if the name starts with an underscore.
        """
        # Private and dunder names are never header keys. Forwarding them would recurse without
        # bound while ``_dada_file``/``_logger`` are unset (e.g. during copy or unpickling) and
        # would answer protocol probes such as ``__deepcopy__`` with the string "Unknown".
        if item.startswith("_"):
            raise AttributeError(f"{type(self).__name__!r} object has no attribute {item!r}")
        return self._get_header_str(key=item)

    def _get_header_str(self: DadaFileReader, key: str, default_value: str = "Unknown") -> str:
        """
        Get the header value of the specified key, or the default_value if not available.

        :param key: header key to look for.
        :param default_value: value to return if the key does not exist in the header.
        :return: value of the header key.
        :rtype: str
        """
        try:
            return self.header.get_value(key)
        except KeyError:
            self._logger.debug(
                f"key: {key} not present in self.header. default_value: {default_value} is used for {key}"
            )
            return default_value
class WeightsFileReader(DadaFileReader):
    """Class that can be used to read a Weights PSRDADA file generated by DSP.DISK.

    The scales and weights are only unpacked when the reader is used as a context manager
    (or when ``read_raw_data`` is called after the packet layout attributes have been set).
    """

    def __init__(
        self: WeightsFileReader,
        file: pathlib.Path,
        unpack_scales: bool = True,
        unpack_weights: bool = True,
        logger: logging.Logger | None = None,
    ) -> None:
        """
        Initialise the weights file reader.

        The file is always loaded in full (not header-only) as the data is needed to
        unpack the scales and weights.

        :param file: path to the weights DADA file.
        :type file: pathlib.Path
        :param unpack_scales: whether to unpack scale factors.
        :type unpack_scales: bool
        :param unpack_weights: whether to unpack weights.
        :type unpack_weights: bool
        :param logger: optional logger instance.
        :type logger: logging.Logger | None
        """
        super().__init__(file=file, header_only=False, logger=logger)
        self.unpack_scales = unpack_scales
        self.unpack_weights = unpack_weights
        self._scales: ScalesType | None = None
        self._weights: WeightsType | None = None

    def __enter__(self: WeightsFileReader) -> WeightsFileReader:
        """Enter context manager for this file, reading the scales and weights."""
        self._read_data()
        return self

    def __exit__(
        self: WeightsFileReader,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Exit context manager for this file; nothing is released and exceptions propagate."""

    def _read_data(self: WeightsFileReader) -> None:
        """
        Read the scales and weights in the file.

        Derives the packet layout (``nweight_per_packet``, ``nchan_per_packet`` and
        ``weights_packet_stride``) from the header and then unpacks the raw data of the file.

        :raises AssertionError: if the header values are not consistent with each other.
        """
        # extract the required parameters from the header
        self.nsamp_per_weight = self.header.get_int("NSAMP_PER_WEIGHT")
        self.packet_weights_size = self.header.get_int("PACKET_WEIGHTS_SIZE")
        self.packet_scales_size = self.header.get_int("PACKET_SCALES_SIZE")
        udp_format = self._get_header_str("UDP_FORMAT")

        # the CBF to PSR ICD specifies all weights as 16 bits per sample
        assert self.nbit == NBIT16, f"Expected nbit={self.nbit} to be {NBIT16}"

        # compute the number of relative weights in each packet
        udp_nsamp = get_udp_nsamp_for_format(udp_format)
        assert (
            udp_nsamp % self.nsamp_per_weight == 0
        ), f"Expect {udp_nsamp=} to be a multiple of {self.nsamp_per_weight}"
        self.nweight_per_packet = udp_nsamp // self.nsamp_per_weight
        self._logger.debug(f"computed weights per packet as {self.nweight_per_packet}")

        # compute the number of channels in each packet
        weight_bytes_per_channel = self.nweight_per_packet * self.nbit // 8
        assert self.packet_weights_size % weight_bytes_per_channel == 0, (
            f"Expected packet_weights_size={self.packet_weights_size} to be a "
            f"multiple of {weight_bytes_per_channel}"
        )
        self.nchan_per_packet = self.packet_weights_size // weight_bytes_per_channel
        self._logger.debug(
            f"packet_weights_size={self.packet_weights_size} nweight_per_packet={self.nweight_per_packet} "
            f"nbit={self.nbit} nchan_per_packet={self.nchan_per_packet}"
        )

        # each packet in the file occupies its scales block followed by its weights block
        self.weights_packet_stride = self.packet_weights_size + self.packet_scales_size
        self._logger.debug(
            f"weights_packet_stride={self.weights_packet_stride} packet_scales_size="
            f"{self.packet_scales_size} packet_weights_size={self.packet_weights_size}"
        )

        raw_data = self._dada_file.raw_data
        self.read_raw_data(raw_data)

    def read_raw_data(self: WeightsFileReader, raw_data: bytes) -> None:
        """
        Read the scales and weights from the raw data bytes.

        The packet layout attributes computed by ``_read_data`` must already be set. The number
        of packets is derived from ``data_size`` (the size of the file on disk minus its header).

        :param raw_data: the bytes that follow the header of the weights file.
        :raises AssertionError: if the data size or channel count does not match the packet layout.
        """
        # weights are written to file in the order: for each packet, the scales block
        # (packet_scales_size bytes) followed by the weights block (packet_weights_size bytes)
        assert (
            self.data_size % self.weights_packet_stride == 0
        ), f"Expected data_size={self.data_size} to be multiple of {self.weights_packet_stride}"
        num_packets = self.data_size // self.weights_packet_stride

        # packets are organised into heaps, where a heap contains all the scale factors
        # and weights for all the channels
        assert (
            self.nchan % self.nchan_per_packet == 0
        ), f"Expected nchan={self.nchan} to be multiple of {self.nchan_per_packet}"
        packets_per_heap = self.nchan // self.nchan_per_packet

        # may not get a full heap at the end of the file: allocate one extra heap for the
        # partial one. The packets missing from it keep their zero initial values.
        num_heaps = num_packets // packets_per_heap
        if num_packets % packets_per_heap != 0:
            num_heaps += 1

        self._logger.debug(
            (
                f"data_size={self.data_size}, num_packets={num_packets} "
                f"packets_per_heap={packets_per_heap} num_heaps={num_heaps}"
            )
        )

        # scales exist for each heap and packet
        if self.unpack_scales:
            self._scales = np.zeros((num_heaps, packets_per_heap), dtype=np.single)

        # weights exist for each heap and channel
        if self.unpack_weights:
            self._weights = np.zeros((num_heaps * self.nweight_per_packet, self.nchan), dtype=np.ushort)

        # no need to assert that nbit % 8 is 0 as we have already asserted it is 16
        nbit_as_bytes = self.nbit // 8
        assert (
            self.packet_weights_size % nbit_as_bytes == 0
        ), f"Expected packet_weights_size={self.packet_weights_size} to be a multiple of {nbit_as_bytes}"
        nweights = self.packet_weights_size // nbit_as_bytes

        # if we're not unpacking anything then don't do anything.
        if self.unpack_scales or self.unpack_weights:
            self._unpack_weights_data_from_raw(raw_data, packets_per_heap, num_heaps, nweights)

    def _unpack_weights_data_from_raw(
        self: WeightsFileReader,
        raw_data: bytes,
        packets_per_heap: int,
        num_heaps: int,
        nweights: int,
    ) -> None:
        """
        Unpacks the weights data for current file from raw bytes.

        :param raw_data: the bytes to unpack, a sequence of packets each made of a scales block
            followed by a weights block.
        :param packets_per_heap: number of packets in each heap.
        :param num_heaps: number of heaps to unpack, including a trailing partial heap.
        :param nweights: number of 16-bit weights stored in each packet.
        """
        byte_offset = 0
        data_len = len(raw_data)
        channel_range = range(self.nchan_per_packet)
        weight_range = range(self.nweight_per_packet)

        for heap, packet in itertools.product(range(num_heaps), range(packets_per_heap)):
            # the last heap may be incomplete, stop once all the bytes have been consumed
            if byte_offset >= data_len:
                return

            # read the scale: a native float32 at the start of the scales block
            if self.unpack_scales:
                # struct.unpack_from avoids slicing
                (scale,) = struct.unpack_from("f", raw_data, byte_offset)
                self._scales[heap][packet] = scale  # type: ignore
            # always skip the scales block, even when the scales are not unpacked
            byte_offset += self.packet_scales_size

            # read the weights: ``nweights`` native uint16 values
            if self.unpack_weights:
                packet_weights = struct.unpack_from(f"{nweights}H", raw_data, byte_offset)
                # within a packet the weights are ordered by channel and then by weight, whereas
                # the output array is indexed by (weight sample, channel)
                for idx, (channel, weight) in enumerate(itertools.product(channel_range, weight_range)):
                    osamp = heap * self.nweight_per_packet + weight
                    ochan = packet * self.nchan_per_packet + channel
                    self._weights[osamp][ochan] = packet_weights[idx]  # type: ignore
            # always skip the weights block, even when the weights are not unpacked
            byte_offset += self.packet_weights_size

    @property
    def scales(self: WeightsFileReader) -> ScalesType:
        """
        Return the unpacked scales.

        :raises RuntimeError: if the reader was created with ``unpack_scales=False``.
        """
        if not self.unpack_scales:
            raise RuntimeError("Cannot return scales as they were not unpacked from the file.")
        return self._scales  # type: ignore

    @property
    def weights(self: WeightsFileReader) -> WeightsType:
        """
        Return the unpacked weights.

        :raises RuntimeError: if the reader was created with ``unpack_weights=False``.
        """
        if not self.unpack_weights:
            raise RuntimeError("Cannot return weights as they were not unpacked from the file.")
        return self._weights  # type: ignore

    @property
    def packets_weights(self: WeightsFileReader) -> WeightsType:
        """
        Get the unpacked weights per packet.

        The ``self.weights`` property returns the weights in heaps but
        this property will use packets in the first dimension. The shape
        of the resulting array is: ``(num_packets, nchan_per_packet)``

        :return: the weights grouped by packets
        :rtype: WeightsType
        """
        weights = self.weights
        # the shape is passed positionally as the ``newshape`` keyword is deprecated in recent
        # NumPy releases while the positional form is accepted by every version
        return np.reshape(weights.flatten(), (-1, self.nchan_per_packet))

    @property
    def dropped_packets(self: WeightsFileReader) -> np.ndarray:
        """
        Return a list of the dropped packets by inspecting NaNs in the scales.

        :return: indices of the packets whose scale is NaN, offset by ``packet_offset``.
        """
        # flatten the 2D array
        packet_scales = self.scales.flatten()
        # convert the array of floats to boolean via isnan, then get the indices of the True values
        dropped_packet_list = np.isnan(packet_scales).nonzero()[0]
        self._logger.debug(f"found {len(dropped_packet_list)} dropped packets via scale factor NaNs")
        return dropped_packet_list + self.packet_offset

    @property
    def zeroed_packets(self: WeightsFileReader) -> np.ndarray:
        """
        Return a list of the zeroed out packets by inspecting the weights.

        :return: indices of the rows of ``packets_weights`` that contain only zeros, offset by
            ``packet_offset``.
        """
        packets_weights = self.packets_weights
        # find all packets where all the weights within the packet are zeros
        zero_weights_packet_list = np.nonzero(np.all(packets_weights == 0, axis=1))[0]
        self._logger.debug(f"found {len(zero_weights_packet_list)} zeroed weights packets")
        return zero_weights_packet_list + self.packet_offset

    @property
    def packet_offset(self: WeightsFileReader) -> int:
        """Get the packet offset for current file.

        This converts the obs_offset to a packet offset by dividing the value by the weights_packet_stride.
        This will assert that the obs_offset is a multiple of weights_packet_stride

        :raises AssertionError: if ``obs_offset`` is not a multiple of ``weights_packet_stride``.
        """
        # offset the packet index by the packet_offset deduced from the OBS_OFFSET
        assert (
            self.obs_offset % self.weights_packet_stride == 0
        ), f"Expected obs_offset={self.obs_offset} to be a multiple of {self.weights_packet_stride}"
        return self.obs_offset // self.weights_packet_stride