# Source code for ska_pst.testutils.stats.assert_stats

# -*- coding: utf-8 -*-
#
# This file is part of the SKA PST project
#
# Distributed under the terms of the BSD 3-clause new license.
# See LICENSE for more info.

"""
Module for asserting that estimated sample statistics are consistent with known population statistics.

In particular this will assert that the sample mean and variance are close to known
population mean and variance within a given tolerance.
"""

from __future__ import annotations

import dataclasses
from typing import List

import numpy as np
import pandas as pd


@dataclasses.dataclass
class SampleStatistics:
    """Data class that models the statistics of a sample.

    :ivar mean: the mean of the sample
    :vartype mean: float
    :ivar variance: the variance of the sample
    :vartype variance: float
    :ivar num_samples: the number of samples used to calculate the statistics
    :vartype num_samples: int
    """

    mean: float
    variance: float
    num_samples: int


def assert_statistics(
    population_mean: float,
    population_var: float,
    sample_stats: SampleStatistics,
    channel: int,
    pol: int,
    tolerance: float = 6.0,
) -> None:
    """Assert that sample mean and var are within a given tolerance of population stats.

    The sample mean is compared against the population mean in units of the
    standard error of the mean, ``sqrt(var / N)``. The sample variance is compared
    against the population variance in units of the standard deviation of the
    unbiased sample variance estimator, ``sqrt((mu_4 - (N - 3) / (N - 1) * var**2) / N)``,
    where the fourth central moment ``mu_4`` is taken to be that of a Gaussian
    distribution (``3 * var**2``).

    :param population_mean: the mean of the population
    :type population_mean: float
    :param population_var: the variance of the population
    :type population_var: float
    :param sample_stats: the sample statistics (mean, variance and number of
        samples) to assert against the population stats.
    :type sample_stats: SampleStatistics
    :param channel: the channel that is being tested, used in the failure message
    :type channel: int
    :param pol: the polarisation that is being tested, used in the failure message
    :type pol: int
    :param tolerance: the number of sigma to allow being away from population value,
        defaults to 6.0
    :type tolerance: float, optional
    :raises AssertionError: if fewer than 2 samples were used, or if either the
        sample mean or the sample variance is more than ``tolerance`` sigma away
        from the corresponding population value.
    """
    N = sample_stats.num_samples
    if N < 2:
        # The (N - 1) term in the variance of the sample variance would otherwise
        # divide by zero (or silently produce inf/nan for numpy scalars).
        raise AssertionError(
            f"Expected at least 2 samples to assert statistics for {channel=} and {pol=}, "
            f"got num_samples={N}"
        )

    S = population_var
    mu = population_mean
    # This is the 4th central moment of a gaussian distribution with variance S
    mu_4 = 3.0 * S**2

    E = sample_stats.mean
    V = sample_stats.variance

    # expected standard deviation of the sample mean E
    sigma_e = np.sqrt(S / N)
    # expected standard deviation of the (unbiased) sample variance V
    sigma_v = np.sqrt((mu_4 - (N - 3) / (N - 1) * S**2) / N)

    n_sigma_e = np.fabs(E - mu) / sigma_e
    n_sigma_v = np.fabs(V - S) / sigma_v

    assert n_sigma_e <= tolerance and n_sigma_v <= tolerance, (
        f"Expected sample mean ({E:0.6f}) and variance ({V:0.6f}) to be within {tolerance:0.6f} sigma"
        f" of {mu:0.6f} and {S:0.6f} respectively for {channel=} and {pol=}. n_sigma_e={n_sigma_e:0.6f}, "
        f"n_sigma_v={n_sigma_v:0.6f}"
    )


NBIT_POPULATION_MEAN_VAR: dict[int, tuple[float, float]] = {
    1: (0.0, 1.0),
    2: (0.0, 0.8608),
    4: (0.0, 0.9789),
    8: (0.0, 1.0),
    16: (0.0, 1.0),
}
"""
Empirically calculated population mean and variance for digitised data.

Keys are the number of bits used in the digitisation and values are the
``(mean, variance)`` of the renormalised digitised data.

The following code was used to generate the values. The values are estimated
from 10 million random samples and rounded to 4 decimal places, so they are
approximate. Only the 2 and 4 bit values show significant quantisation effects;
they should be recalculated if there is a change to the digitisation scales and
means used (``nbit_scales_and_means`` below).

.. code-block:: python

    import numpy as np
    from typing import Tuple

    # update these values if there is a change to the scales and means applied
    nbit_scales_and_means = {
        1: (0.5, 0.5),
        2: (1.03, -0.5),
        4: (3.14, -0.5),
        8: (10.1, -0.5),
        16: (1106.4, -0.5),
    }

    def calc_renorm_mean_var(nbit: int) -> Tuple[float, float]:
        (scale, mean) = nbit_scales_and_means[nbit]
        if nbit == 1:
            clip_min = 0
            clip_max = 1
        else:
            clip_min = -pow(2, nbit - 1)
            clip_max = -clip_min - 1

        norm_data = np.random.randn(10_000_000)
        scaled_data = scale * norm_data + mean
        digi_data = np.clip(np.round(scaled_data).astype(np.int32), clip_min, clip_max)
        renorm_data = (digi_data - mean) / scale
        renorm_mean = np.round(np.mean(renorm_data), 4)
        renorm_var = np.round(np.var(renorm_data, ddof=1), 4)

        return (renorm_mean, renorm_var)

    nbit_renorm_mean_var = {
        nbit: calc_renorm_mean_var(nbit) for nbit in nbit_scales_and_means.keys()
    }
    print(nbit_renorm_mean_var)
"""


def assert_statistics_for_digitised_data(
    data: np.ndarray,
    nbit: int,
    tolerance: float = 9.0,
    max_nsamp: int = 8192,
) -> None:
    """
    Assert that sample mean and var are within a given tolerance of population stats for TFP data.

    This function asserts that the given Numpy array of data has a mean and variance
    within a given tolerance of the population mean and variance based on the number
    of bits used in the digitisation of the data.

    The statistics are computed independently for each channel and polarisation,
    using at most the first ``max_nsamp`` time samples. For complex data the real
    and imaginary components are pooled together when computing the statistics.

    :param data: an array of either real or complex value floating point data, with
        shape ``(ndat, nchan, npol)``.
    :type data: np.ndarray
    :param nbit: the number of bits used in the digitisation of the data. This must
        be a key of ``NBIT_POPULATION_MEAN_VAR``.
    :type nbit: int
    :param tolerance: the number of sigma to allow being away from population value,
        defaults to 9.0
    :type tolerance: float, optional
    :param max_nsamp: the maximum number of time samples to assert statistics on,
        defaults to 8192
    :type max_nsamp: int, optional
    :raises KeyError: if there are no population statistics for ``nbit``.
    :raises AssertionError: if the statistics for any channel and polarisation are
        not within tolerance of the population statistics.
    """
    population_mean, population_var = NBIT_POPULATION_MEAN_VAR[nbit]
    ndat, nchan, npol = data.shape

    # only assert statistics on a limited number of samples
    ndat = min(ndat, max_nsamp)
    data = data[:ndat]

    # Add a trailing "component" axis so that reducing over axes (0, -1) below always
    # leaves a (nchan, npol) array. Complex data (of any complex dtype) is split into
    # its real and imaginary parts; real data just gets a length-1 axis. Without this,
    # axis -1 of real (ndat, nchan, npol) data would be the polarisation axis.
    if np.iscomplexobj(data):
        data = np.stack((data.real, data.imag), axis=-1)
    else:
        data = data[..., np.newaxis]

    means = np.mean(data, axis=(0, -1))
    variances = np.var(data, axis=(0, -1), ddof=1)

    for chan in range(nchan):
        for pol in range(npol):
            # NOTE: num_samples is ndat even for complex data, where 2 * ndat values
            # contribute to each statistic; this makes the check more lenient.
            sample_stats = SampleStatistics(
                mean=means[chan, pol], variance=variances[chan, pol], num_samples=ndat
            )
            assert_statistics(
                population_mean=population_mean,
                population_var=population_var,
                sample_stats=sample_stats,
                channel=chan,
                pol=pol,
                tolerance=tolerance,
            )


def assert_statistics_for_channels(
    channel_data: pd.DataFrame,
    population_mean: float,
    population_var: float,
    pol: str,
    tolerance: float = 6.0,
) -> None:
    """Assert that sample mean and var are within a given tolerance of population stats for each channel.

    Every row of ``channel_data`` is checked before failing, so that the final
    assertion message lists all out-of-tolerance channels rather than only the first.

    :param channel_data: a data frame with statistics split by channel, where the row
        index is used as the channel number. This must include the following columns:
        "Mean", "Var.", "Num. Samples". This should also be specific for a given
        polarisation and complex data dimension (e.g. for Pol A real data).
    :type channel_data: pd.DataFrame
    :param population_mean: the mean of the population
    :type population_mean: float
    :param population_var: the variance of the population
    :type population_var: float
    :param pol: the polarisation to be tested, A or B. Any value other than "A" is
        reported as polarisation index 1.
    :type pol: str
    :param tolerance: the number of sigma to allow being away from population value,
        defaults to 6.0
    :type tolerance: float, optional
    :raises AssertionError: if the statistics of any channel are not within tolerance.
    """
    errors: List[str] = []
    # the polarisation index is only used to label failure messages
    ipol = 0 if pol == "A" else 1

    for channel, (sample_mean, sample_var, num_samples) in channel_data[
        ["Mean", "Var.", "Num. Samples"]
    ].iterrows():
        sample_stats = SampleStatistics(
            num_samples=num_samples,
            mean=sample_mean,
            variance=sample_var,
        )
        try:
            assert_statistics(
                population_mean=population_mean,
                population_var=population_var,
                sample_stats=sample_stats,
                channel=int(channel),  # type: ignore
                pol=ipol,
                tolerance=tolerance,
            )
        except AssertionError as e:
            # collect rather than re-raise so that every channel gets checked
            errors.append(str(e))

    assert len(errors) == 0, f"Expected no errors. Error messages = {errors}"