Source code for ska_db_oda.infrastructure.postgres.repository

"""
This module contains an implementation of the RepositoryBridge class, using
Postgres as the data store.
"""
import logging
from typing import Dict, List, Optional, Tuple, TypeVar, Union

from psycopg import Connection, sql
from psycopg.errors import Error
from psycopg.types.json import set_json_loads
from ska_oso_pdm import Metadata, SBDefinition

from ska_db_oda.domain import OSOEntity, get_identifier_or_fetch_from_skuid
from ska_db_oda.domain.query import DateQuery, QueryParams, UserQuery
from ska_db_oda.domain.repository import RepositoryBridge
from ska_db_oda.infrastructure.postgres.mapping import PostgresMapping
from ska_db_oda.infrastructure.postgres.sqlqueries import (
    count_identifier_query,
    count_query,
    get_metadata_query,
    insert_query,
    result_to_metadata,
    select_by_date_query,
    select_by_user_query,
    select_latest_query,
    update_query,
)

LOGGER = logging.getLogger(__name__)

T = TypeVar("T", bound=OSOEntity)
U = TypeVar("U")
SqlTypes = Union[str, int]


[docs] class PostgresBridge(RepositoryBridge[T, U]): """ Implementation of the Repository bridge which persists entities in a PostgreSQL instance. """ def __init__(self, postgres_mapping: PostgresMapping, connection: Connection): """ The initialisation of a Repository is the responsibility of the UoW, which should inject the connection. The lifespan of a PostgresRepository instance is the same as a connection. """ self._connection = connection self._postgres_mapping = postgres_mapping
[docs] def create(self, entity: T) -> T: """Implementation of the RepositoryBridge method. See :func:`~ska_db_oda.domain.repository.RepositoryBridge.create` docstring for details """ entity_id = get_identifier_or_fetch_from_skuid(entity) LOGGER.debug("Creating version of entity with ID %s in postgres", entity_id) entity = self.update_metadata(entity) query, params = insert_query(self._postgres_mapping.table_details, entity) self._execute(query, params) return entity
[docs] def read(self, entity_id: U) -> T: """Implementation of the RepositoryBridge method. See :func:`~ska_db_oda.domain.repository.RepositoryBridge.read` docstring for details """ LOGGER.debug("Getting entity with ID %s from postgres", entity_id) query, params = select_latest_query( self._postgres_mapping.table_details, entity_id ) result = self._execute_and_return_row(query, params) if result is None: raise KeyError( f"Not found. The requested sbd_id {entity_id} could not be found." ) return self._postgres_mapping.result_to_entity(result)
[docs] def update(self, entity: T) -> T: """Implementation of the RepositoryBridge method. See :func:`~ska_db_oda.domain.repository.RepositoryBridge.update` docstring for details """ entity_id = get_identifier_or_fetch_from_skuid(entity) LOGGER.debug("Updating entity with ID %s in postgres", entity_id) entity = self._set_new_metadata(entity) query, params = update_query(self._postgres_mapping.table_details, entity) # Postgres upsert equivalent doesn't quite work for us as there is no database constraint # on the identifier. Instead, we do a conditional update, and if this doesn't return a result # (ie no rows match the identifier + version) then we create a new row result = self._execute_and_return_row(query, params) if not result: self.create(entity) return entity
[docs] def query(self, qry_params: QueryParams) -> List[T]: match qry_params: case DateQuery(): query, params = select_by_date_query( self._postgres_mapping.table_details, qry_params=qry_params ) case UserQuery(): query, params = select_by_user_query( self._postgres_mapping.table_details, qry_params=qry_params ) case _: raise ValueError( f"Unsupported query type {qry_params.__class__.__name__}" ) result = self._execute_and_return_rows(query, params) if result is None: raise KeyError( f"Not found. The requested sbd_id {result} could not be found." ) # The jsonb column 'info' contains the full entity, which will have been deserialised into a PDM object. return [row["info"] for row in result]
def _get_latest_metadata(self, entity: T) -> Optional[Metadata]: """Implementation of the abstract MetaDataMixin method for a Postgres backend. See :func:`~ska_db_oda.domain.metadatamixin.MetadataMixin._get_latest_metadata` docstring for details """ query, params = get_metadata_query(self._postgres_mapping.table_details, entity) result = self._execute_and_return_row(query, params) if result: if isinstance(entity, SBDefinition): return result_to_metadata(result) return result_to_metadata(result) return None def _execute_and_return_row( self, query: sql.Composed, params: Tuple ) -> Optional[Dict]: """ Executes a query which returns a single row. """ try: set_json_loads(self._postgres_mapping.jsonb_load, self._connection) LOGGER.info( "Executing query: %s", query.as_string(self._connection) % params ) result = self._connection.execute( query, params, ).fetchone() return result except Error as err: msg = f"Error whilst executing query: {err.args}" LOGGER.exception(msg) raise OSError(msg) from err def _execute_and_return_rows( self, query: sql.Composed, params: Tuple ) -> Optional[List[Dict]]: """ Executes a query which returns multiple rows. 
""" try: set_json_loads(self._postgres_mapping.jsonb_load, self._connection) LOGGER.info( "Executing query: %s", query.as_string(self._connection) % params ) result = self._connection.execute(query, params).fetchall() return result except Error as err: msg = f"Error whilst executing query: {err.args}" LOGGER.exception(msg) raise OSError(msg) from err def _execute(self, query: sql.Composed, params: Tuple) -> None: """ Executes a query without returning a value """ try: set_json_loads(self._postgres_mapping.jsonb_load, self._connection) LOGGER.debug( "Executing query: %s", query.as_string(self._connection) % params ) self._connection.execute(query, params) except Error as err: msg = f"Error whilst executing query: {err.args}" LOGGER.exception(msg) raise OSError(msg) from err def __len__(self): result = self._connection.execute( *count_query(self._postgres_mapping.table_details) ).fetchone() return result["count"] def __contains__(self, entity_id: str) -> bool: result = self._connection.execute( *count_identifier_query(self._postgres_mapping.table_details, entity_id) ).fetchone() return result["count"] > 0