Source code for ska_pst.testutils.dada.dada_file_reader

# -*- coding: utf-8 -*-
#
# This file is part of the SKA PST project
#
# Distributed under the terms of the BSD 3-clause new license.
# See LICENSE for more info.

"""Module class to read PSR DADA file."""

from __future__ import annotations

# Public API of this module: the names exported by ``from <module> import *``.
__all__ = [
    "DadaFileReader",
    "WeightsFileReader",
]

import itertools
import logging
import mmap
import os
import pathlib
import struct
from types import TracebackType
from typing import Any, Dict, Tuple

import nptyping as npt
import numpy as np

from ska_pst.common import get_udp_nsamp_for_format

# Number of bytes initially memory mapped when reading a file's header; also the
# value a reader's ``header_size`` starts with before the header has been parsed.
DEFAULT_HEADER_SIZE = 4096
# Header key whose integer value gives the actual size of the header in bytes.
HEADER_SIZE_KEY = "HDR_SIZE"
# NOTE(review): not referenced in the visible part of this module; presumably the
# duration of data held in a single file - confirm against callers.
SECONDS_PER_FILE = 10

# Type alias used to annotate the unpacked scales array (allocated with dtype ``np.single``).
ScalesType = npt.NDArray[Any, npt.Single]
# Type alias used to annotate the unpacked weights array (allocated with dtype ``np.ushort``).
WeightsType = npt.NDArray[Any, npt.UShort]


class DadaFileReader:
    """Class that can be used to read a PSR DADA file.

    The reader parses the ASCII header found at the start of the file into a
    dictionary of string keys to string values. The header is read when the
    instance is used as a context manager, or lazily on first access of the
    ``header`` property.
    """

    def __init__(self: DadaFileReader, file: pathlib.Path, logger: logging.Logger | None = None) -> None:
        """Create instance of file reader.

        :param file: path of the file to read, it must exist and be a regular file.
        :param logger: the logger to use, defaults to a logger named after this module.
        :raises AssertionError: if ``file`` does not exist or is not a regular file.
        """
        assert file.exists() and file.is_file(), f"Expected {file} to be an existing file"
        self.file = file
        # this is replaced by the HDR_SIZE value once the header has been read
        self.header_size = DEFAULT_HEADER_SIZE
        self._header: Dict[str, str] = {}
        self._logger = logger or logging.getLogger(__name__)

    def __enter__(self: DadaFileReader) -> DadaFileReader:
        """Enter context manager for this file, this reads the header."""
        self._read_header()
        return self

    def __exit__(
        self: DadaFileReader,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Exit context manager, there are no resources held open to release."""

    def _read_header(self: DadaFileReader) -> Dict[str, str]:
        """Read the header of file.

        The first ``DEFAULT_HEADER_SIZE`` bytes are parsed to find the
        ``HDR_SIZE`` value. If that value differs from the default then the
        header is parsed again using the size given in the file.

        :return: the parsed header, which is also cached on this instance.
        :raises KeyError: if the header has no ``HDR_SIZE`` key.
        :raises ValueError: if the ``HDR_SIZE`` value is not an integer.
        """
        self._logger.debug(f"Loading header for {self}")
        with open(self.file, "rb") as f:
            # memory map file - just want the first 4096 bytes.
            # access=ACCESS_READ is the portable spelling of prot=PROT_READ
            with mmap.mmap(f.fileno(), DEFAULT_HEADER_SIZE, access=mmap.ACCESS_READ) as mm:
                header, header_str = self._read_header_from_mmap(mm)

            # if key doesn't exist or its not an int we expect this to fail
            self.header_size = int(header[HEADER_SIZE_KEY])
            if self.header_size != DEFAULT_HEADER_SIZE:
                with mmap.mmap(f.fileno(), self.header_size, access=mmap.ACCESS_READ) as mm:
                    header, header_str = self._read_header_from_mmap(mm)

        self._logger.debug(f"Header from file {self.file}:\n{header_str}")
        self._header = header
        return header

    def _read_header_from_mmap(self: DadaFileReader, file: mmap.mmap) -> Tuple[Dict[str, str], str]:
        """Read the lines of the memory mapped file into a dictionary.

        Each line is expected to be a key followed by whitespace (spaces or
        tabs) and then the value. Blank lines and lines starting with ``#``
        are ignored, and NUL padding bytes are treated as whitespace. A key
        that has no value is stored with an empty string as its value.

        :param file: the memory mapped header, only ``readline`` is used.
        :return: a tuple of the parsed header and a printable version of the
            header lines that were parsed (used only for logging).
        """
        header: Dict[str, str] = {}
        # this is only used for logging
        header_lines = []

        for currline in iter(file.readline, b""):
            line = currline.decode().replace("\0", " ").strip()

            # ignore comments and blank lines (including the NUL padding)
            if not line or line.startswith("#"):
                continue

            header_lines.append(f"{line}\n")

            # split on any run of whitespace so that tab separated values are
            # handled; the line is non-empty and stripped so there is always a key
            key, *rest = line.split(maxsplit=1)
            header[key] = rest[0] if rest else ""

        return header, "".join(header_lines)

    @property
    def file_name(self: DadaFileReader) -> str:
        """Get name of file."""
        return self.file.name

    @property
    def header(self: DadaFileReader) -> Dict[str, str]:
        """Get header for file.

        The header is read from the file if it has not yet been loaded. A copy
        is returned so that callers cannot modify the cached header.
        """
        if not self._header:
            self._read_header()
        return {**self._header}

    @property
    def file_size(self: DadaFileReader) -> int:
        """Get size of file in bytes."""
        return self.file.stat().st_size

    @property
    def data_size(self: DadaFileReader) -> int:
        """Get the size of the data, this is the file size less the header size."""
        return self.file_size - self.header_size

    @property
    def obs_offset(self: DadaFileReader) -> int:
        """Get the OBS_OFFSET value."""
        return int(self.header["OBS_OFFSET"])

    @property
    def file_number(self: DadaFileReader) -> int:
        """Get the FILE_NUMBER value from header."""
        return int(self.header["FILE_NUMBER"])

    @property
    def scan_id(self: DadaFileReader) -> int:
        """Get the SCAN_ID value from header."""
        return int(self.header["SCAN_ID"])

    def __repr__(self: DadaFileReader) -> str:
        """Get string representation of file."""
        return f'DadaFileReader(file="{self.file.parent.name}/{self.file_name}")'
class WeightsFileReader(DadaFileReader):
    """Class that can be used to read a Weights PSRDADA file generated by ska_pst_dsp_disk.

    After the header, the file data is a sequence of packets. Each packet is a
    scale factor (``PACKET_SCALES_SIZE`` bytes) followed by the weights
    (``PACKET_WEIGHTS_SIZE`` bytes). Packets are grouped into heaps, where a
    heap holds the packets for all ``NCHAN`` channels.

    The scales and weights are unpacked when the instance is used as a context
    manager.
    """

    def __init__(
        self: WeightsFileReader,
        file: pathlib.Path,
        unpack_scales: bool = True,
        unpack_weights: bool = True,
        logger: logging.Logger | None = None,
    ) -> None:
        """Create instance of weights file reader.

        :param file: path of the weights file to read.
        :param unpack_scales: whether the scale factors should be unpacked.
        :param unpack_weights: whether the weights should be unpacked.
        :param logger: the logger to use, defaults to a logger named after this module.
        """
        super().__init__(file, logger=logger)
        self.unpack_scales = unpack_scales
        self.unpack_weights = unpack_weights
        self._scales: ScalesType | None = None
        self._weights: WeightsType | None = None

    def __enter__(self: WeightsFileReader) -> WeightsFileReader:
        """Enter context manager for this file, this reads the header and the data."""
        self._read_data()
        return self

    def __exit__(
        self: WeightsFileReader,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Exit context manager for this file."""

    def _read_data(self: WeightsFileReader) -> None:
        """Read the scales and weights in the file.

        This reads the header, derives the packet layout from it and then
        unpacks the data that follows the header.

        :raises KeyError: if a required key is missing from the header.
        :raises AssertionError: if the header values are not consistent with each other.
        """
        self._read_header()

        # extract the required parameters from the header
        self.nchan = int(self._header["NCHAN"])
        self.nbit = int(self._header["NBIT"])
        self.nsamp_per_weight = int(self._header["NSAMP_PER_WEIGHT"])
        self.packet_weights_size = int(self._header["PACKET_WEIGHTS_SIZE"])
        self.packet_scales_size = int(self._header["PACKET_SCALES_SIZE"])
        udp_format = self._header["UDP_FORMAT"]

        # the CBF to PSR ICD specifies all weights a 16 bits per sample
        assert self.nbit == 16, f"Expected nbit={self.nbit} to be 16"

        # compute the number of relative weights in each packet
        udp_nsamp = get_udp_nsamp_for_format(udp_format)
        assert (
            udp_nsamp % self.nsamp_per_weight == 0
        ), f"Expect {udp_nsamp=} to be a multiple of {self.nsamp_per_weight}"
        self.nweight_per_packet = udp_nsamp // self.nsamp_per_weight
        self._logger.debug(f"computed weights per packet as {self.nweight_per_packet}")

        # compute the number of channels in each packet
        assert self.packet_weights_size % (self.nweight_per_packet * self.nbit // 8) == 0, (
            f"Expected packet_weights_size={self.packet_weights_size} to be a "
            f"multiple of {self.nweight_per_packet * self.nbit // 8}"
        )
        self.nchan_per_packet = self.packet_weights_size // (self.nweight_per_packet * self.nbit // 8)
        self._logger.debug(
            f"packet_weights_size={self.packet_weights_size} nweight_per_packet={self.nweight_per_packet} "
            f"nbit={self.nbit} nchan_per_packet={self.nchan_per_packet}"
        )

        self.weights_packet_stride = self.packet_weights_size + self.packet_scales_size
        self._logger.debug(
            f"weights_packet_stride={self.weights_packet_stride} packet_scales_size="
            f"{self.packet_scales_size} packet_weights_size={self.packet_weights_size}"
        )

        with open(self.file, "rb") as f:
            # Map the whole file and skip over the header rather than mapping from an
            # offset: an mmap offset must be a multiple of mmap.ALLOCATIONGRANULARITY,
            # which the header size is not guaranteed to be, and mapping from an offset
            # also fails for a file that has a header but no data.
            with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm:
                mm.seek(self.header_size)
                self._read_data_from_mmap(mm)

    def _read_data_from_mmap(self: WeightsFileReader, file: mmap.mmap) -> None:
        """Read the scales and weights in the memory mapped file.

        :param file: the memory mapped file, positioned at the first byte after the header.
        :raises AssertionError: if the data size is not a whole number of packets
            or the number of channels is not a whole number of packets.
        """
        # the data is a whole number of packets (scale factor followed by weights)
        assert (
            self.data_size % self.weights_packet_stride == 0
        ), f"Expected data_size={self.data_size} to be multiple of {self.weights_packet_stride}"
        num_packets = self.data_size // self.weights_packet_stride

        # packets are organised into heaps, where a heap contains all the scale factors
        # and weights for all the channels
        assert (
            self.nchan % self.nchan_per_packet == 0
        ), f"Expected nchan={self.nchan} to be multiple of {self.nchan_per_packet}"
        packets_per_heap = self.nchan // self.nchan_per_packet

        # may not get a full heap at the end of the file, in which case room is still
        # needed for the packets that are present in the partial heap
        num_heaps = num_packets // packets_per_heap
        if num_packets % packets_per_heap != 0:
            num_heaps += 1

        self._logger.debug(
            (
                f"data_size={self.data_size}, num_packets={num_packets} "
                f"packets_per_heap={packets_per_heap} num_heaps={num_heaps}"
            )
        )

        # scales exist for each heap and packet. Entries for packets that are not
        # in the file (the end of a partial heap) keep their initial value of zero.
        if self.unpack_scales:
            self._scales = np.zeros((num_heaps, packets_per_heap), dtype=np.single)

        # weights exist for each heap and channel
        if self.unpack_weights:
            self._weights = np.zeros((num_heaps * self.nweight_per_packet, self.nchan), dtype=np.ushort)

        # no need to assert that nbit % 8 is 0 as we have already asserted it is 16
        nbit_as_bytes = self.nbit // 8
        assert (
            self.packet_weights_size % nbit_as_bytes == 0
        ), f"Expected packet_weights_size={self.packet_weights_size} to be a multiple of {nbit_as_bytes}"
        nweights = self.packet_weights_size // nbit_as_bytes

        # if we're not unpacking anything then don't do anything.
        if self.unpack_scales or self.unpack_weights:
            self._unpack_weights_data(file, packets_per_heap, num_heaps, nweights)

    def _unpack_weights_data(
        self: WeightsFileReader,
        file: mmap.mmap,
        packets_per_heap: int,
        num_heaps: int,
        nweights: int,
    ) -> None:
        """Unpacks the weights data for current file.

        :param file: the memory mapped file, positioned at the first packet.
        :param packets_per_heap: the number of packets in each heap.
        :param num_heaps: the number of heaps, including a trailing partial heap.
        :param nweights: the number of weight values stored in each packet.
        """
        byte_offset = 0

        # the (channel, weight) pair for each value in a packet is the same for every
        # packet, so it is computed once. Values in the packet are ordered by channel
        # then weight, while the output is indexed by weight then channel.
        packet_indices = list(
            itertools.product(range(self.nchan_per_packet), range(self.nweight_per_packet))
        )

        for heap, packet in itertools.product(range(num_heaps), range(packets_per_heap)):
            # the last heap may be partial, stop once all the data has been consumed
            if byte_offset >= self.data_size:
                return

            if self.unpack_scales:
                # packet scale factor is stored as 32-bit float
                self._scales[heap][packet] = struct.unpack(  # type: ignore
                    "f", file.read(self.packet_scales_size)
                )[0]
            else:
                file.seek(self.packet_scales_size, os.SEEK_CUR)
            byte_offset += self.packet_scales_size

            if self.unpack_weights:
                # weights are stored as unsigned 16-bit integers
                packet_weights = struct.unpack(f"{nweights}H", file.read(self.packet_weights_size))

                # transpose is required
                for idx, (channel, weight) in enumerate(packet_indices):
                    osamp = heap * self.nweight_per_packet + weight
                    ochan = packet * self.nchan_per_packet + channel
                    self._weights[osamp][ochan] = packet_weights[idx]  # type: ignore
            else:
                file.seek(self.packet_weights_size, os.SEEK_CUR)
            byte_offset += self.packet_weights_size

    @property
    def scales(self: WeightsFileReader) -> ScalesType:
        """Return the unpacked scales.

        :raises RuntimeError: if the reader was created with ``unpack_scales=False``.
        """
        if not self.unpack_scales:
            raise RuntimeError("Cannot return scales as they were not unpacked from the file.")
        return self._scales  # type: ignore

    @property
    def weights(self: WeightsFileReader) -> WeightsType:
        """Return the unpacked weights.

        :raises RuntimeError: if the reader was created with ``unpack_weights=False``.
        """
        if not self.unpack_weights:
            raise RuntimeError("Cannot return weights as they were not unpacked from the file.")
        return self._weights  # type: ignore

    @property
    def packets_weights(self: WeightsFileReader) -> WeightsType:
        """
        Get the unpacked weights per packet.

        The ``self.weights`` property returns the weights in heaps but this property
        will use packets in the first dimension. The shape of the resulting array is:
        ``(num_packets, nchan_per_packet)``

        NOTE(review): the weights are flattened and split into rows of
        ``nchan_per_packet`` values, so the number of rows only equals the number
        of packets when ``nweight_per_packet`` is 1 - confirm this always holds.

        :return: the weights grouped by packets
        :rtype: WeightsType
        """
        weights = self.weights
        # the shape is passed positionally as the ``newshape`` keyword is deprecated
        # in newer releases of NumPy
        return np.reshape(weights.flatten(), (-1, self.nchan_per_packet))

    @property
    def dropped_packets(self: WeightsFileReader) -> np.ndarray:
        """Return a list of the dropped packets by inspecting NaNs in the scales."""
        # flatten the 2D array
        packet_scales = self.scales.flatten()

        # convert the array of floats to boolean via isnan, then get the indices of the True values
        dropped_packet_list = np.isnan(packet_scales).nonzero()[0]
        self._logger.debug(f"found {len(dropped_packet_list)} dropped packets via scale factor NaNs")
        return dropped_packet_list + self.packet_offset

    @property
    def zeroed_packets(self: WeightsFileReader) -> np.ndarray:
        """Return a list of the zeroed out packets by inspecting the weights."""
        packets_weights = self.packets_weights

        # find all packets where all the weights within the packet are zeros
        zero_weights_packet_list = np.nonzero(np.all(packets_weights == 0, axis=1))[0]
        self._logger.debug(f"found {len(zero_weights_packet_list)} zeroed weights packets")
        return zero_weights_packet_list + self.packet_offset

    @property
    def packet_offset(self: WeightsFileReader) -> int:
        """Get the packet offset for current file.

        This converts the obs_offset to a packet offset by dividing the value by the
        weights_packet_stride. This will assert that the obs_offset is a multiple of
        weights_packet_stride.
        """
        # offset the packet index by the packet_offset deduced from the OBS_OFFSET
        assert (
            self.obs_offset % self.weights_packet_stride == 0
        ), f"Expected obs_offset={self.obs_offset} to be a multiple of {self.weights_packet_stride}"
        return self.obs_offset // self.weights_packet_stride