# -*- coding: utf-8 -*-
#
# This file is part of the SKA PST project
#
# Distributed under the terms of the BSD 3-clause new license.
# See LICENSE for more info.
"""Module for codifying the assertions of values."""
from __future__ import annotations
import functools
import logging
from typing import Any, Callable, Dict, Tuple
import astropy.units as u
import numpy as np
from ska_pst.common.constants import (
BITS_PER_BYTE,
COMPLEX_NDIMS,
REAL_NDIMS,
SIZE_OF_FLOAT32_IN_BYTES,
WEIGHTS_NBITS,
)
from ska_pydada import AsciiHeader, DadaFile
from ska_pst.common import CbfPstConfig
# Duration, in seconds, that a single file is expected to span. It is used by
# assert_obs_offset to turn the expected BYTES_PER_SECOND into an expected
# number of bytes per file.
SECONDS_PER_FILE: float = 10.0
"""Number of seconds each DSP.DISK file is for."""
def assert_const_value(expected_value: str | int | float) -> Callable[..., None]:
    """Build an assertion that checks a header against a fixed value.

    The returned callable is :py:func:`assert_header_value` with its
    ``expected_value`` keyword argument pre-bound by means of
    :py:func:`functools.partial`.

    :param expected_value: the constant the header value is expected to equal.
    :type expected_value: str | int | float
    :return: a callable that performs the assertion when invoked.
    :rtype: Callable[..., None]
    """
    bound_assertion = functools.partial(assert_header_value, expected_value=expected_value)
    return bound_assertion
def assert_equal_to(other_key: str) -> Callable[..., None]:
    """Build an assertion that a header value equals another header's value.

    The callable returned by this factory reads ``header[other_key]`` when it
    is invoked and delegates to :py:func:`assert_header_value`, using that
    value as the expected value.

    :param other_key: the header key whose value the current header value
        must be equal to.
    :type other_key: str
    :return: a callable that asserts the current header value and the value
        stored under ``other_key`` are the same.
    :rtype: Callable[..., None]
    """

    def _assert(*, header: AsciiHeader, **kwargs: Any) -> None:
        # The reference value is looked up at call time, from the same header
        # that is being validated.
        assert_header_value(header=header, expected_value=header[other_key], **kwargs)

    return _assert
def assert_ndim(*, is_weights: bool, **kwargs: Any) -> None:
    """Check that ``NDIM`` matches the kind of file being validated.

    Weights files are expected to hold real valued data (``REAL_NDIMS``) while
    data files are expected to hold complex valued data (``COMPLEX_NDIMS``).
    All remaining keyword arguments are forwarded to
    :py:func:`assert_header_value`.

    :param is_weights: whether asserting against a data or weights file.
    :type is_weights: bool
    :raises AssertionError: if ``NDIM`` is incorrect for file type.
    """
    expected_ndim = REAL_NDIMS if is_weights else COMPLEX_NDIMS
    assert_header_value(expected_value=expected_ndim, **kwargs)
def assert_npol(*, is_weights: bool, cbf_pst_config: CbfPstConfig, **kwargs: Any) -> None:
    """Check that ``NPOL`` matches the file type and frequency band configuration.

    A weights file is expected to have a value of 1, whereas a data file is
    expected to have the ``npol`` value of the CBF/PST configuration. All
    remaining keyword arguments are forwarded to
    :py:func:`assert_header_value`.

    :param is_weights: whether asserting against a data or weights file.
    :type is_weights: bool
    :param cbf_pst_config: the CBF/PST configuration for the current frequency band.
    :type cbf_pst_config: CbfPstConfig
    :raises AssertionError: if ``NPOL`` is incorrect for file type and scan configuration.
    """
    # ``npol`` is only read from the configuration for data files.
    expected_npol = 1 if is_weights else cbf_pst_config.npol
    assert_header_value(expected_value=expected_npol, **kwargs)
def assert_nbit(*, is_weights: bool, cbf_pst_config: CbfPstConfig, **kwargs: Any) -> None:
    """Check that ``NBIT`` matches the file type and frequency band configuration.

    A weights file is expected to have a value of ``WEIGHTS_NBITS``. A data
    file is expected to have the ``nbit`` value of the CBF/PST configuration,
    which is used as-is. All remaining keyword arguments are forwarded to
    :py:func:`assert_header_value`.

    :param is_weights: whether asserting against a data or weights file.
    :type is_weights: bool
    :param cbf_pst_config: the CBF/PST configuration for the current frequency band.
    :type cbf_pst_config: CbfPstConfig
    :raises AssertionError: if ``NBIT`` is incorrect for file type and scan configuration.
    """
    # ``nbit`` is only read from the configuration for data files.
    expected_nbit = WEIGHTS_NBITS if is_weights else cbf_pst_config.nbit
    assert_header_value(expected_value=expected_nbit, **kwargs)
def assert_tsamp(**kwargs: Any) -> None:
    """Check that the ``TSAMP`` header value matches the file and configuration.

    The expected value comes from :py:func:`get_expected_tsamp`, which is
    called with the same keyword arguments that are subsequently forwarded to
    :py:func:`assert_header_value`.

    :raises AssertionError: if ``TSAMP`` is incorrect for file type and scan configuration.
    """
    assert_header_value(expected_value=get_expected_tsamp(**kwargs), **kwargs)
def assert_bytes_per_second(**kwargs: Any) -> None:
    """Check that the ``BYTES_PER_SECOND`` header value matches the file and configuration.

    The expected value comes from :py:func:`calculate_bytes_per_second`, which
    is called with the same keyword arguments that are subsequently forwarded
    to :py:func:`assert_header_value`.

    :raises AssertionError: if ``BYTES_PER_SECOND`` is incorrect for file type and scan configuration.
    """
    assert_header_value(expected_value=calculate_bytes_per_second(**kwargs), **kwargs)
def assert_resolution(**kwargs: Any) -> None:
    """Check that the ``RESOLUTION`` header value matches the file and configuration.

    The expected value comes from :py:func:`calculate_resolution`, which is
    called with the same keyword arguments that are subsequently forwarded to
    :py:func:`assert_header_value`.

    :raises AssertionError: if ``RESOLUTION`` is incorrect for file type and scan configuration.
    """
    assert_header_value(expected_value=calculate_resolution(**kwargs), **kwargs)
def _get_expected_channel_range(
    scan_config: dict, cbf_pst_config: CbfPstConfig, **kwargs: Any
) -> Tuple[int, int]:
    """Return the expected ``(start_channel, end_channel)`` pair for a scan.

    The ``centre_freq_mhz`` and ``bandwidth_mhz`` entries of the scan
    configuration are passed to ``cbf_pst_config.calculate_channel_ranges``
    and the start/end channel of the first returned range is used.

    :param scan_config: the scan configuration as a dictionary.
    :type scan_config: dict
    :param cbf_pst_config: the CBF/PST configuration for the current frequency band.
    :type cbf_pst_config: CbfPstConfig
    :param kwargs: accepted so callers can forward their context; not used here.
    :return: the start and end channel of the first channel range.
    :rtype: Tuple[int, int]
    """
    centre_freq, bandwidth = scan_config["centre_freq_mhz"], scan_config["bandwidth_mhz"]
    channel_ranges = cbf_pst_config.calculate_channel_ranges(
        bandwidth_mhz=bandwidth, centre_freq_mhz=centre_freq
    )
    # Only the first range is considered.
    first_range = channel_ranges[0]
    return (first_range.start_channel, first_range.end_channel)
def assert_start_chan(*, scan_config: dict, cbf_pst_config: CbfPstConfig, **kwargs: Any) -> None:
    """Check that a start channel header value matches the scan configuration.

    The expected channel range is obtained from
    :py:func:`_get_expected_channel_range`, which uses the ``centre_freq_mhz``
    and ``bandwidth_mhz`` entries of the scan configuration together with the
    ``cbf_pst_config`` for the current frequency band. The start channel of
    that range is the value the header is asserted against. In
    ``DADA_VALUE_ASSERTIONS`` this function is registered for the
    ``START_CHANNEL`` and ``START_CHANNEL_OUT`` keys.

    :param scan_config: the scan configuration to assert the start channel against
    :type scan_config: dict
    :param cbf_pst_config: the CBF/PST configuration for the current frequency band.
    :type cbf_pst_config: CbfPstConfig
    :raises AssertionError: if the start channel is incorrect for the scan configuration.
    """
    expected_range = _get_expected_channel_range(
        scan_config=scan_config, cbf_pst_config=cbf_pst_config, **kwargs
    )
    assert_header_value(expected_value=expected_range[0], **kwargs)
def assert_end_chan(*, scan_config: dict, cbf_pst_config: CbfPstConfig, **kwargs: Any) -> None:
    """Check that an end channel header value matches the scan configuration.

    The expected channel range is obtained from
    :py:func:`_get_expected_channel_range`, which uses the ``centre_freq_mhz``
    and ``bandwidth_mhz`` entries of the scan configuration together with the
    ``cbf_pst_config`` for the current frequency band. The end channel of
    that range is the value the header is asserted against. In
    ``DADA_VALUE_ASSERTIONS`` this function is registered for the
    ``END_CHANNEL`` and ``END_CHANNEL_OUT`` keys.

    :param scan_config: the scan configuration to assert the end channel against
    :type scan_config: dict
    :param cbf_pst_config: the CBF/PST configuration for the current frequency band.
    :type cbf_pst_config: CbfPstConfig
    :raises AssertionError: if the end channel is incorrect for the scan configuration.
    """
    expected_range = _get_expected_channel_range(
        scan_config=scan_config, cbf_pst_config=cbf_pst_config, **kwargs
    )
    assert_header_value(expected_value=expected_range[1], **kwargs)
def assert_obs_offset(
    *,
    file: DadaFile,
    header: AsciiHeader,
    scan_config: dict,
    is_weights: bool,
    cbf_pst_config: CbfPstConfig,
    **kwargs: Any,
) -> None:
    """
    Check the ``OBS_OFFSET`` of a file, both in its name and in its header.

    The expected offset is computed from the *data* stream: the data
    ``BYTES_PER_SECOND`` is multiplied by ``SECONDS_PER_FILE``, truncated to
    an integer and rounded up to a multiple of the data ``RESOLUTION`` to give
    the bytes per file, which is then multiplied by the ``FILE_NUMBER`` read
    from the header.

    For weights files the data offset is converted to a number of data
    resolution blocks and multiplied by the weights ``RESOLUTION``.

    The second-to-last ``_`` separated part of the file name stem must equal
    the expected offset, and so must the header value, which is checked via
    :py:func:`assert_header_value`.

    :param file: the current file that the ``OBS_OFFSET`` is being asserted
    :type file: DadaFile
    :param header: the header of the current file; ``FILE_NUMBER`` is read from it.
    :type header: AsciiHeader
    :param scan_config: the scan configuration as a dictionary.
    :type scan_config: dict
    :param is_weights: whether the current file is a weights or data file.
    :type is_weights: bool
    :param cbf_pst_config: the CBF/PST configuration for the current frequency band.
    :type cbf_pst_config: CbfPstConfig
    :raises AssertionError: if ``OBS_OFFSET`` is incorrect for the current file.
    """
    # Both quantities are deliberately computed for the data stream
    # (is_weights=False), even when a weights file is being checked.
    data_stream_args = dict(scan_config=scan_config, is_weights=False, cbf_pst_config=cbf_pst_config)
    resolution = calculate_resolution(**data_stream_args, **kwargs)
    bytes_per_second = calculate_bytes_per_second(**data_stream_args, **kwargs)

    # Round the per-file byte count up to a whole number of resolution blocks.
    bytes_per_file = int(bytes_per_second * SECONDS_PER_FILE)
    excess = bytes_per_file % resolution
    if excess > 0:
        bytes_per_file = bytes_per_file - excess + resolution

    file_number = header.get_int("FILE_NUMBER")
    obs_offset = file_number * bytes_per_file

    if is_weights:
        # Rescale from data resolution blocks to weights resolution blocks.
        weights_resolution = calculate_resolution(
            scan_config=scan_config, is_weights=True, cbf_pst_config=cbf_pst_config, **kwargs
        )
        obs_offset = (obs_offset // resolution) * weights_resolution

    # The file name stem carries the offset as its second-to-last "_" part.
    file_obs_offset = int(file.file.stem.split("_")[-2])
    assert obs_offset == file_obs_offset, (
        f"expected file {file.file}'s obs_offset suffix to be {obs_offset} but was {file_obs_offset}. "
        f"{resolution=}, {bytes_per_second=}, {bytes_per_file=}, {file_number=}"
    )

    assert_header_value(expected_value=obs_offset, header=header, file=file, **kwargs)
def assert_nant(*, scan_config: dict, **kwargs: Any) -> None:
    """Check that the ``NANT`` header value matches the scan configuration.

    The expected value is the number of entries in the ``receptors`` value of
    the scan configuration. All remaining keyword arguments are forwarded to
    :py:func:`assert_header_value`.

    :param scan_config: the scan configuration as a dictionary.
    :type scan_config: dict
    :raises AssertionError: if ``NANT`` is incorrect.
    """
    assert_header_value(expected_value=len(scan_config["receptors"]), **kwargs)
def assert_file_number(*, file: DadaFile, **kwargs: Any) -> None:
    """Check that the ``FILE_NUMBER`` header agrees with the file name.

    The expected value is the integer found in the last ``_`` separated part
    of the file name stem. The file and all remaining keyword arguments are
    forwarded to :py:func:`assert_header_value`.

    :param file: the current file that the ``FILE_NUMBER`` is being asserted
    :type file: DadaFile
    :raises AssertionError: if ``FILE_NUMBER`` doesn't match the file name.
    """
    # Only the text after the final underscore of the stem is needed.
    number_suffix = file.file.stem.rsplit("_", 1)[-1]
    assert_header_value(file=file, expected_value=int(number_suffix), **kwargs)
# Dispatch table mapping a DADA header key to the callable used to assert that
# header's value. Entries are either module-level assertion functions or
# callables produced by assert_const_value for fixed values.
# NOTE(review): assert_udp_format is not defined in the visible part of this
# module - confirm it is defined before this table is built.
DADA_VALUE_ASSERTIONS: Dict[str, Callable[..., None]] = {
    "UDP_FORMAT": assert_udp_format,
    # Headers that are expected to hold a fixed value.
    "NSUBBAND": assert_const_value(1),
    "COORD_MD": assert_const_value("J2000"),
    "TRK_MODE": assert_const_value("TRACK"),
    # Channel range derived from the scan configuration.
    "START_CHANNEL": assert_start_chan,
    "END_CHANNEL": assert_end_chan,
    # The "_OUT" start/end channels currently reuse the same assertions as the
    # input channels; this will change when we use subbands.
    "START_CHANNEL_OUT": assert_start_chan,
    "END_CHANNEL_OUT": assert_end_chan,
    # Values that depend on whether the file is a data or a weights file.
    "NDIM": assert_ndim,
    "NPOL": assert_npol,
    "NBIT": assert_nbit,
    "TSAMP": assert_tsamp,
    "BYTES_PER_SECOND": assert_bytes_per_second,
    "RESOLUTION": assert_resolution,
    "OBS_OFFSET": assert_obs_offset,
    # Values checked against the file name or the scan configuration.
    "FILE_NUMBER": assert_file_number,
    "NANT": assert_nant,
}
def get_expected_tsamp(*, is_weights: bool, cbf_pst_config: CbfPstConfig, **kwargs: Any) -> float:
    """
    Get the expected ``TSAMP`` for given file type and frequency band.

    The base value is the ``tsamp`` of the ``cbf_pst_config``, which relates
    to the telescope and frequency band. Data files use that value directly.
    For weights files the data tsamp is multiplied by the ``udp_nsamp`` of the
    configuration, as the weights are valid for each sample within a packet.

    The configuration is never modified: the scaled value is always a new
    object, even if ``tsamp`` is held as a mutable numeric type.

    :param is_weights: whether the current file is a weights or data file.
    :type is_weights: bool
    :param cbf_pst_config: the CBF/PST configuration for the current frequency band.
    :type cbf_pst_config: CbfPstConfig
    :param kwargs: accepted so callers can forward their context; not used here.
    :return: the time per sample in microseconds.
    :rtype: float
    """
    data_tsamp = cbf_pst_config.tsamp
    if not is_weights:
        return data_tsamp

    # Build a new product instead of using ``*=``: an augmented assignment
    # would scale a mutable numeric (e.g. a numpy array) in place and thereby
    # corrupt the value stored on the configuration.
    return data_tsamp * cbf_pst_config.udp_nsamp
def calculate_bytes_per_second(
    *, is_weights: bool, scan_config: dict, cbf_pst_config: CbfPstConfig, **kwargs: Any
) -> float:
    """Calculate the expected bytes per second for a file type and scan configuration.

    The number of bytes per sample is worked out from the number of channels
    for the configured bandwidth: for weights files it is based on
    ``WEIGHTS_NBITS`` per channel, for data files on ``npol * nbit * ndim``
    bits per channel. That value is divided by the expected tsamp for the
    file type (see :py:func:`get_expected_tsamp`). The tsamp is treated as
    microseconds and the rate is converted to SI units so the result is per
    second rather than per microsecond.

    :param is_weights: whether the current file is a weights or data file.
    :type is_weights: bool
    :param scan_config: the scan configuration as a dictionary.
    :type scan_config: dict
    :param cbf_pst_config: the CBF/PST configuration for the current frequency band.
    :type cbf_pst_config: CbfPstConfig
    :return: the bytes per second for the given file type.
    :rtype: float
    """
    # Attach the microsecond unit so the conversion to seconds is done by astropy.
    tsamp_microseconds = get_expected_tsamp(
        is_weights=is_weights,
        scan_config=scan_config,
        cbf_pst_config=cbf_pst_config,
        **kwargs,
    )
    tsamp = u.Quantity(tsamp_microseconds, unit=u.microsecond)

    nchan = cbf_pst_config.nchan_for_bandwidth(bandwidth_mhz=scan_config["bandwidth_mhz"])

    if is_weights:
        bits_per_sample = nchan * WEIGHTS_NBITS
    else:
        bits_per_sample = nchan * cbf_pst_config.npol * cbf_pst_config.nbit * cbf_pst_config.ndim
    bytes_per_sample = bits_per_sample // BITS_PER_BYTE

    # ``.si`` rescales bytes per microsecond to bytes per second.
    rate = bytes_per_sample / tsamp
    return rate.si.value
def calculate_resolution(
    *, is_weights: bool, scan_config: dict, cbf_pst_config: CbfPstConfig, **kwargs: Any
) -> int:
    """
    Calculate the ``RESOLUTION`` for a given file and scan configuration.

    For data files the value is the number of bytes needed to hold
    ``udp_nsamp`` samples of every channel, i.e.
    ``udp_nsamp * nchan * nbit * ndim * npol`` bits expressed in bytes.

    For weights files the value is the sum of two parts: one float32 scale
    factor per packet (``nchan // udp_nchan`` packets), and
    ``udp_nsamp // wt_nsamp`` weights of ``WEIGHTS_NBITS`` bits for every
    channel, expressed in bytes.

    :param is_weights: whether the current file is a weights or data file.
    :type is_weights: bool
    :param scan_config: the scan configuration as a dictionary.
    :type scan_config: dict
    :param cbf_pst_config: the CBF/PST configuration for the current frequency band.
    :type cbf_pst_config: CbfPstConfig
    :param kwargs: accepted so callers can forward their context; not used here.
    :return: the expected ``RESOLUTION`` for the given file and scan configuration.
    :rtype: int
    """
    nchan = cbf_pst_config.nchan_for_bandwidth(bandwidth_mhz=scan_config["bandwidth_mhz"])
    udp_nsamp = cbf_pst_config.udp_nsamp
    wt_nsamp = cbf_pst_config.wt_nsamp
    udp_nchan = cbf_pst_config.udp_nchan

    if is_weights:
        # One float32 scale factor for each packet needed to cover all channels.
        packets_scale_stride = nchan // udp_nchan * SIZE_OF_FLOAT32_IN_BYTES
        # The weights themselves: (udp_nsamp // wt_nsamp) weights per channel.
        weights_stride = (udp_nsamp // wt_nsamp) * nchan * WEIGHTS_NBITS // BITS_PER_BYTE
        return packets_scale_stride + weights_stride

    npol = cbf_pst_config.npol
    nbit = cbf_pst_config.nbit
    ndim = cbf_pst_config.ndim
    return (udp_nsamp * nchan * nbit * ndim * npol) // BITS_PER_BYTE