"""
This module contains implementations of the AbstractRepository class, using
the filesystem as the data store.
"""
import functools
import logging
import operator
import os
import re
from os import W_OK, PathLike, access, environ
from pathlib import Path
from typing import Dict, List, Optional, TypeVar, Union
from ska_oso_pdm import Metadata
from ska_db_oda.domain import (
OSOEntity,
get_identifier,
get_identifier_or_fetch_from_skuid,
)
from ska_db_oda.domain.query import DateQuery, MatchType, QueryParams, UserQuery
from ska_db_oda.domain.repository import RepositoryBridge
from ska_db_oda.infrastructure.filesystem.mapping import FilesystemMapping
# Module-level logger, named after this module.
LOGGER = logging.getLogger(__name__)
# Type of the entities handled by a bridge; bound to OSOEntity.
T = TypeVar("T", bound=OSOEntity)
# Type of the identifier used to look entities up; unconstrained.
U = TypeVar("U")
[docs]
class FilesystemBridge(RepositoryBridge[T, U]):
    """
    Implementation of the Repository bridge which persists entities to a filesystem.

    Entities will be stored under the following filesystem structure:

    `/<base_working_dir>/<entity_type_dir>/<entity_id>/<version>.json`

    For example, by default version 1 of an SBDefinition with sbd_id sbi-mvp01-20200325-00001
    will be stored at: `/var/lib/oda/sbd/sbi-mvp01-20200325-00001/1.json`
    """

    def __init__(
        self,
        filesystem_mapping: FilesystemMapping,
        base_working_dir: Union[str, PathLike] = Path("/var/lib/oda"),
    ):
        """
        Create a bridge that stores one entity type below the working directory.

        :param filesystem_mapping: supplies the entity type directory name, the
            serialise/deserialise functions and the path-to-ID conversion
        :param base_working_dir: root directory of the data store. The
            ODA_DATA_DIR environment variable, when set, overrides this value.
        :raises FileNotFoundError: if the base working directory does not exist
        :raises PermissionError: if the base working directory is not writable
        """
        # The environment variable takes precedence over the argument.
        base_working_dir = Path(environ.get("ODA_DATA_DIR", base_working_dir))
        LOGGER.info(
            "Initialising ODA filesystem backend. Working directory=%s",
            base_working_dir,
        )
        if not base_working_dir.is_dir():
            raise FileNotFoundError(f"Directory {base_working_dir} not found")
        if not access(base_working_dir, W_OK):
            raise PermissionError(f"Directory {base_working_dir} not writable")

        self._base_working_dir = base_working_dir
        self.working_dir = self._base_working_dir / filesystem_mapping.entity_type_dir
        self.working_dir.mkdir(parents=True, exist_ok=True)

        # Pending writes: target file path -> serialised entity. Nothing in this
        # class writes these to disk; that is left to whoever commits them.
        self._transactions: Dict[Path, str] = {}

        self._serialise = filesystem_mapping.serialise
        self._deserialise = filesystem_mapping.deserialise
        self._entity_id_from_path = filesystem_mapping.entity_id_from_path

    def __len__(self) -> int:
        """
        Return the size of this repository.

        Note that this is a naive implementation that simply counts the number
        of JSON files found recursively under the working directory, so every
        stored version is counted. It does not verify that each JSON file is a
        serialised, valid entity.
        """
        return sum(1 for _ in self.working_dir.rglob("*.json"))

    def __contains__(self, entity_id: U) -> bool:
        """
        Return True if a version of an entity with the given ID is present in this repository.

        Only the filesystem is inspected. An entity directory that exists but is
        empty does not count as present.

        :param entity_id: ID to search for
        """
        entity_dir = self._path_for_entity_id_dir(entity_id)
        # is_dir() rather than exists() so that a stray file with the same name
        # as an entity ID yields False instead of raising NotADirectoryError.
        return entity_dir.is_dir() and any(entity_dir.iterdir())

    def create(self, entity: T) -> T:
        """Implementation of the RepositoryBridge method.

        To mimic the real database, entities are added to a list of pending transactions and only
        written to the filesystem when the unit of work is committed.

        See :func:`~ska_db_oda.domain.repository.RepositoryBridge.create` docstring for details
        """
        entity_id = get_identifier_or_fetch_from_skuid(entity)
        entity = self.update_metadata(entity)
        self._stage_entity(entity_id, entity)
        return entity

    def read(self, entity_id: U) -> T:
        """
        Gets the latest version of the entity with the given entity_id.

        As this method will always be accessed in the context of a UnitOfWork, the pending transactions
        also need to be checked for a version to return.
        (Similar to with a database implementation where an entity that was added to a transaction but
        not committed would still be accessible inside the transaction.)

        :param entity_id: ID of the entity to retrieve
        :raises KeyError: if no pending or stored version of the entity exists
        """
        LOGGER.debug("Getting entity with ID %s from the filesystem", entity_id)
        entity_dir_path = self._path_for_entity_id_dir(entity_id)

        # Select pending files by their parent directory. A substring test on
        # the whole path would also pick up other entities whose ID merely
        # contains this one (e.g. 'sbd-1' vs 'sbd-10').
        pending_versions = [
            int(path.stem)
            for path in self._transactions
            if path.parent == entity_dir_path
        ]

        # The metadata checks will mean a version in the pending transactions would always
        # be a newer version than any in the filesystem, so check the pending versions first.
        if pending_versions:
            latest_entity_path = entity_dir_path / f"{max(pending_versions)}.json"
            return self._deserialise(self._transactions[latest_entity_path])

        if entity_dir_path.exists():
            # Filenames are of the form 1.json
            versions = [
                int(entity_path.stem) for entity_path in entity_dir_path.glob("*.json")
            ]
            if versions:
                latest_entity_path = entity_dir_path / f"{max(versions)}.json"
                return self._deserialise(latest_entity_path.read_text())

        raise KeyError(
            f"Not found. The requested entity_id {entity_id} could not be found."
        )

    def update(self, entity: T) -> T:
        """Implementation of the RepositoryBridge method.

        To mimic the real database, entities are added to a list of pending transactions and only
        written to the filesystem when the unit of work is committed.

        See :func:`~ska_db_oda.domain.repository.RepositoryBridge.update` docstring for details
        """
        entity_id = get_identifier_or_fetch_from_skuid(entity)
        entity = self._set_new_metadata(entity)
        self._stage_entity(entity_id, entity)
        return entity

    def query(self, qry_params: QueryParams) -> List[T]:
        """
        Return the latest version of every entity that satisfies the query.

        :param qry_params: the query to evaluate
        :raises ValueError: if the query is not recognised or is invalid
        """
        # strategy for this implementation is to:
        #
        # 1. create a list of filter functions matching the requirements of the query
        # 2. for each entity in the repo, apply the filter functions
        # 3. if the entity passes each test, add it to the list of results
        #
        # With this strategy we can reuse filter functions to build compound
        # complex queries, e.g., entities created by user X after 1/1/2023
        filter_fns = QueryFilterFactory.filter_functions_for_query(qry_params)

        final_result = []
        for entity_id in self._all_entity_ids():
            try:
                entity = self.read(entity_id)
            except KeyError:
                # create()/update() make the entity directory before anything is
                # written to it, so a directory can exist without any version.
                # Such an entity cannot match; it must not abort the whole query.
                continue
            if all(fn(entity) for fn in filter_fns):
                final_result.append(entity)
        return final_result

    def _get_latest_metadata(self, entity: T) -> Optional[Metadata]:
        """Implementation of the abstract MetaDataMixin method for a filesystem backend.

        Returns None when no version of the entity can be read.

        See :func:`~ska_db_oda.domain.metadatamixin.MetadataMixin._get_latest_metadata` docstring for details
        """
        try:
            return self.read(get_identifier(entity)).metadata
        except KeyError:
            return None

    def _stage_entity(self, entity_id: U, entity: T) -> None:
        """
        Serialise the entity and record it in the pending transactions, keyed
        by the path of the file it is destined for.
        """
        self._path_for_entity_id_dir(entity_id).mkdir(parents=True, exist_ok=True)
        entity_path = self._path_for_entity(entity)
        LOGGER.debug(
            "Adding entity with ID %s to the filesystem transactions under path %s",
            entity_id,
            entity_path,
        )
        self._transactions[entity_path] = self._serialise(entity)

    def _path_for_entity(self, entity: T) -> Path:
        """
        Returns the path of the file where the given version of the entity is
        stored, eg `/var/lib/oda/sbd/sbi-mvp01-20200325-00001/2.json`
        """
        return (
            self.working_dir
            / get_identifier(entity)
            / f"{entity.metadata.version}.json"
        )

    def _path_for_entity_id_dir(self, entity_id: U) -> Path:
        """
        Returns the path of the directory that all versions of the entity with the given
        entity_id are stored under, eg `/var/lib/oda/sbd/sbi-mvp01-20200325-00001/`
        """
        return self.working_dir / str(entity_id)

    def _all_entity_ids(self) -> List[U]:
        """
        Return a list of entity IDs, one entity ID for each entity directory
        found in the working directory.
        """
        # path format is <entity ID>/<version>.json so the list of directories
        # in the working directory should give a list of entity IDs.
        # iterdir() gives paths whereas we want entity IDs of the correct type,
        # hence we call _entity_id_from_path for each directory
        return [
            self._entity_id_from_path(f)
            for f in self.working_dir.iterdir()
            if f.is_dir()
        ]
[docs]
class QueryFilterFactory:
    """
    Factory class that returns a list of Python functions equivalent to a user
    query. Each function processes an entity, returning True if the entity
    passes the query test.
    """

    @staticmethod
    def filter_functions_for_query(query: QueryParams):
        """
        Return the list of filter functions that together implement the query.

        :param query: a UserQuery or a DateQuery
        :raises ValueError: if the query is of any other type, or is invalid
        """
        if isinstance(query, UserQuery):
            return [QueryFilterFactory.match_editor(query)]
        if isinstance(query, DateQuery):
            return [QueryFilterFactory.filter_between_dates(query)]
        raise ValueError(f"Unrecognised query: {query}")

    @staticmethod
    def match_editor(query: UserQuery):
        """
        Returns a function that returns True if a document editor, or the
        document identifier, matches a (sub)string.

        When ``query.user`` is set, the entity's ``created_by`` and
        ``last_modified_by`` metadata are tested; otherwise the entity
        identifier is tested against ``query.entity_id``.

        :raises ValueError: if neither a user nor an entity ID is given, or the
            match type is not recognised
        """
        if not query.user and not query.entity_id:
            raise ValueError(
                f"User or Entity match must to be specified. Got {query.user!r} or"
                f" {query.entity_id!r}"
            )

        # A user search takes precedence when both terms are supplied.
        search_term = query.user if query.user else query.entity_id

        if query.match_type == MatchType.EQUALS:
            match_fn = functools.partial(operator.eq, search_term)
        elif query.match_type == MatchType.STARTS_WITH:
            match_fn = operator.methodcaller("startswith", search_term)
        elif query.match_type == MatchType.CONTAINS:
            # Build the pattern from the same term as the other match types
            # (it used to be built from entity_id even for user queries) and
            # compile it once rather than once per entity.
            # NOTE(review): the term is interpreted as a regular expression, as
            # before; use re.escape if a literal substring match is intended.
            match_fn = re.compile(search_term).search
        else:
            raise ValueError(f"Invalid match type: {query.match_type}")

        def matches(value) -> bool:
            # A missing value can never match; guarding here also prevents
            # startswith/regex from failing on None.
            return value is not None and bool(match_fn(value))

        def match(obj) -> bool:
            if query.user:
                return matches(obj.metadata.created_by) or matches(
                    obj.metadata.last_modified_by
                )
            # The identifier attribute differs per entity type (SBDefinition
            # has sbd_id, Project has prj_id), hence get_identifier.
            return matches(get_identifier(obj))

        return match

    @staticmethod
    def filter_between_dates(query: DateQuery):
        """
        Returns a function that returns True if a date is between a given range.

        The range is half-open: start <= date < end. Either bound may be None,
        in which case that side of the range is unbounded.

        :raises ValueError: if the query type is not recognised, if both bounds
            are None, or if start is not earlier than end
        """
        if query.query_type == DateQuery.QueryType.CREATED_BETWEEN:
            accessor = operator.attrgetter("metadata.created_on")
        elif query.query_type == DateQuery.QueryType.MODIFIED_BETWEEN:
            accessor = operator.attrgetter("metadata.last_modified_on")
        else:
            raise ValueError(f"Unrecognised date query type: {query.query_type}")

        if query.start is None and query.end is None:
            raise ValueError("Query start and query end can not be None")
        if (
            query.start is not None
            and query.end is not None
            and query.start >= query.end
        ):
            raise ValueError("Query end date must be later than query start date")

        # Bounds are compared as POSIX timestamps; convert them once up front.
        start_ts = None if query.start is None else query.start.timestamp()
        end_ts = None if query.end is None else query.end.timestamp()

        def match(obj) -> bool:
            value = accessor(obj)
            if value is None:
                # An entity without the relevant date cannot fall in the range.
                return False
            timestamp = value.timestamp()
            # An unspecified bound always matches.
            after_start = start_ts is None or timestamp >= start_ts
            before_end = end_ts is None or timestamp < end_ts
            return after_start and before_end

        return match