"""
This module contains PostgresBridge, an implementation of the RepositoryBridge
class that uses Postgres as the data store.
"""
import logging
from typing import Dict, List, Optional, Tuple, TypeVar, Union
from psycopg import Connection, sql
from psycopg.errors import Error, InvalidTextRepresentation, UniqueViolation
from psycopg.types.json import set_json_loads
from ska_oso_pdm import Metadata, SBDefinition
from ska_db_oda.persistence.domain import OSOEntity, get_identifier_or_fetch_from_skuid
from ska_db_oda.persistence.domain.errors import (
ODAError,
ODANotFound,
QueryParameterError,
StatusHistoryException,
UniqueConstraintViolation,
)
from ska_db_oda.persistence.domain.query import (
CustomQuery,
DateQuery,
QueryParams,
StatusQuery,
UserQuery,
)
from ska_db_oda.persistence.domain.repository import RepositoryBridge
from ska_db_oda.persistence.infrastructure.postgres.mapping import PostgresMapping
from ska_db_oda.persistence.infrastructure.postgres.sqlqueries import (
count_identifier_query,
count_query,
get_metadata_query,
insert_query,
result_to_metadata,
select_by_custom_query,
select_by_date_query,
select_by_user_query,
select_latest_query,
select_latest_relationship_query,
select_status_by_id_query,
update_query,
)
LOGGER = logging.getLogger(__name__)

# Type of the domain entity handled by a bridge; bound to OSOEntity.
T = TypeVar("T", bound=OSOEntity)
# Type parameter used for identifier-like arguments (entity_id, version, ...)
# of PostgresBridge methods.
U = TypeVar("U")
# Union of scalar Python types; not referenced elsewhere in this module's visible code.
SqlTypes = Union[str, int]
class PostgresBridge(RepositoryBridge[T, U]):
"""
Implementation of the Repository bridge which persists entities in a PostgreSQL instance.
"""
def __init__(self, postgres_mapping: PostgresMapping, connection: Connection):
"""
The initialization of a Repository is the responsibility of the UoW, which should inject the connection.
The lifespan of a PostgresRepository instance is the same as a connection.
"""
self._connection = connection
self._postgres_mapping = postgres_mapping
[docs]
def create(self, entity: T, user: str = None) -> T:
"""Implementation of the RepositoryBridge method.
See :func:`~ska_db_oda.persistence.domain.repository.RepositoryBridge.create` docstring for details
"""
entity_id = get_identifier_or_fetch_from_skuid(entity)
LOGGER.debug("Creating version of entity with ID %s in postgres", entity_id)
entity = self.update_metadata_increment_version(entity, user)
query, params = insert_query(self._postgres_mapping.table_details, entity)
self._execute(query, params)
return entity
[docs]
def read(
self,
entity_id: U,
version: U = None,
) -> T:
"""Implementation of the RepositoryBridge method.
See :func:`~ska_db_oda.persistence.domain.repository.RepositoryBridge.read` docstring for details
:param entity_id: provided entity_id for filter records
:result: result based on mapped entity
"""
LOGGER.debug("Getting entity with ID %s from postgres", entity_id)
query, params = select_latest_query(
self._postgres_mapping.table_details, entity_id, version
)
result = self._execute_and_return_row(query, params)
if result is None:
raise ODANotFound(identifier=entity_id)
return self._postgres_mapping.result_to_entity(result)
[docs]
def read_relationship(
self, entity_id: U, parent_entity: U, associated_entity: U
) -> T:
"""Implementation of the RepositoryBridge method.
Generate query based on parent_entity, associated_entity and execute the query
to get the result.
See :func:`~ska_db_oda.persistence.domain.repository.RepositoryBridge.read` docstring for details
:param entity_id: provided entity_id for filter records
:param parent_entity: relational primary entity table name
:param associated_entity: relational secondary entity table name
:result: Result based on parent and associated entity relationship on given entity id
"""
LOGGER.debug("Getting entity with ID %s from postgres", entity_id)
query, params = select_latest_relationship_query(
entity_id,
self._postgres_mapping.get_mapping(parent_entity),
self._postgres_mapping.get_mapping(associated_entity),
)
result = self._execute_and_return_row(
query, params, parent_entity, associated_entity
)
if result is None:
raise ODANotFound(identifier=entity_id)
return self._postgres_mapping.result_to_entity_relationship(
result, parent_entity, associated_entity
)
[docs]
def update(self, entity: T) -> T:
"""Implementation of the RepositoryBridge method.
See :func:`~ska_db_oda.persistence.domain.repository.RepositoryBridge.update` docstring for details
"""
entity_id = get_identifier_or_fetch_from_skuid(entity)
LOGGER.debug("Updating entity with ID %s in postgres", entity_id)
entity = self.update_metadata_maintain_version(entity)
query, params = update_query(self._postgres_mapping.table_details, entity)
# Postgres upsert equivalent doesn't quite work for us as there is no database constraint
# on the identifier. Instead, we do a conditional update, and if this doesn't return a result
# (ie no rows match the identifier + version) then we create a new row
result = self._execute_and_return_row(query, params)
if not result:
self.create(entity)
return entity
[docs]
def query(
self,
qry_params: QueryParams,
) -> List[T]:
match qry_params:
case DateQuery():
query, params = select_by_date_query(
self._postgres_mapping.table_details, qry_params=qry_params
)
case UserQuery():
query, params = select_by_user_query(
self._postgres_mapping.table_details, qry_params=qry_params
)
case StatusQuery():
query, params = select_status_by_id_query(
self._postgres_mapping.table_details, qry_params=qry_params
)
case CustomQuery():
query, params = select_by_custom_query(
self._postgres_mapping.table_details, qry_params=qry_params
)
case _:
raise QueryParameterError(qry_params=qry_params)
result = self._execute_and_return_rows(query, params)
if result is None:
return []
# In status history tables, we want to return the full row, not just the entity. and info is not a column
# The jsonb column 'info' contains the full entity, which will have been deserialised into a PDM object.
return [
(
row["info"]
if "info" in row.keys()
else self._postgres_mapping.result_to_entity(row)
)
for row in result
]
def _get_latest_metadata(self, entity: T) -> Optional[Metadata]:
"""Implementation of the abstract MetaDataMixin method for a Postgres backend.
See :func:`~ska_db_oda.persistence.domain.metadatamixin.MetadataMixin._get_latest_metadata` docstring for details
"""
query, params = get_metadata_query(self._postgres_mapping.table_details, entity)
result = self._execute_and_return_row(query, params)
if result:
if isinstance(entity, SBDefinition):
return result_to_metadata(result)
return result_to_metadata(result)
return None
def _execute_and_return_row(
self,
query: sql.Composed,
params: Tuple,
parent_entity: str = None,
associated_entity: str = None,
) -> Optional[Dict]:
"""
Executes a query which returns a single row.
"""
try:
if not parent_entity and not associated_entity:
set_json_loads(self._postgres_mapping.jsonb_load, self._connection)
LOGGER.info(
"Executing query: %s", query.as_string(self._connection) % params
)
result = self._connection.execute(
query,
params,
).fetchone()
return result
except InvalidTextRepresentation as err:
LOGGER.exception(err)
raise StatusHistoryException(err.args) from err
except UniqueViolation as err:
LOGGER.exception(err)
raise UniqueConstraintViolation(err.args[0]) from err
except Error as err:
msg = f"Error whilst executing query: {err.args}"
LOGGER.exception(msg)
raise ODAError(msg) from err
def _execute_and_return_rows(
self, query: sql.Composed, params: Tuple
) -> Optional[List[Dict]]:
"""
Executes a query which returns multiple rows.
"""
try:
set_json_loads(self._postgres_mapping.jsonb_load, self._connection)
sql_query = query.as_string(self._connection) % params
msg = f"Executing query: {sql_query}"
LOGGER.info(msg)
result = self._connection.execute(query, params).fetchall()
return result
except InvalidTextRepresentation as err:
LOGGER.exception(err)
raise StatusHistoryException(err.args) from err
except UniqueViolation as err:
LOGGER.exception(err)
raise UniqueConstraintViolation(err.args[0]) from err
except Error as err:
msg = f"Error whilst executing query: {err.args}"
LOGGER.exception(msg)
raise ODAError(msg) from err
def _execute(self, query: sql.Composed, params: Tuple) -> None:
"""
Executes a query without returning a value
"""
try:
set_json_loads(self._postgres_mapping.jsonb_load, self._connection)
LOGGER.debug(
"Executing query: %s", query.as_string(self._connection) % params
)
self._connection.execute(query, params)
except InvalidTextRepresentation as err:
LOGGER.exception(err)
raise StatusHistoryException(err.args) from err
except UniqueViolation as err:
LOGGER.exception(err)
raise UniqueConstraintViolation(err.args[0]) from err
except Error as err:
msg = f"Error whilst executing query: {err.args}"
LOGGER.exception(msg)
raise ODAError(msg) from err
def __len__(self):
result = self._connection.execute(
*count_query(self._postgres_mapping.table_details)
).fetchone()
return result["count"]
def __contains__(self, entity_id: str) -> bool:
result = self._connection.execute(
*count_identifier_query(self._postgres_mapping.table_details, entity_id)
).fetchone()
return result["count"] > 0