Source code for ska_sdp_config.backend.etcd3

"""Etcd3 backend for SKA SDP configuration DB."""
# pylint: disable=duplicate-code

from __future__ import annotations

import logging
import os
import time
from typing import TYPE_CHECKING, Any, Callable, Iterable, Optional, cast

import etcd3
import requests
import semantic_version

from .backend import Backend, DbRevision, DbTransaction, Lease, RecurseType
from .common import (
    ConfigCollision,
    ConfigVanished,
    _check_path,
    _tag_depth,
    _untag_depth,
    depth_of_path,
)

if TYPE_CHECKING:
    from .etcd3_watcher import Etcd3Watcher

LOGGER = logging.getLogger(__name__)

# Apply a configurable log level (environment variable
# SDP_CONFIG_ETCD3_LOG_LEVEL, default "INFO") to the third-party 'etcd3'
# client package and to 'urllib3', which it depends on.
lowerLevelLog = os.getenv("SDP_CONFIG_ETCD3_LOG_LEVEL", "INFO")
for _dependency_logger_name in ("etcd3", "urllib3"):
    logging.getLogger(_dependency_logger_name).setLevel(lowerLevelLog)
del _dependency_logger_name


class Etcd3Backend(Backend):
    """
    Highly consistent database backend store.

    See https://github.com/kragniz/python-etcd3
    """

    def __init__(
        self,
        host="localhost",
        port="2379",
        max_retries: int = 15,
        retry_time: float = 0.1,
        **kw_args,
    ):
        """
        Instantiate the database client.

        :param host: etcd server host
        :param port: etcd server port
        :param max_retries: Number of guarded attempts made by
            :py:meth:`_retry_loop` before one final unguarded attempt
        :param retry_time: Initial delay (seconds) between retries; it is
            multiplied by 1.5 after every retry
        :param kw_args: "uses_secure_channel", "creds" and "grpc_options"
            configure the endpoint; all keyword arguments are also
            forwarded to ``etcd3.MultiEndpointEtcd3Client``
        """
        self._max_retries = max_retries
        self._retry_time = retry_time

        # Make endpoint for (presumably singular) host with retry timeout of 10
        # seconds - that is the fastest that a gRPC connection can exit the
        # UNAVAILABLE state apparently.
        endpoint = etcd3.Endpoint(
            host,
            port,
            secure=kw_args.get("uses_secure_channel"),
            creds=kw_args.get("creds"),
            opts=kw_args.get("grpc_options"),
            time_retry=10,
        )

        # Create "multi-endpoint" client with failover so that we get
        # NoServerAvailableError raised.
        self._client = etcd3.MultiEndpointEtcd3Client(
            endpoints=[endpoint], failover=True, **kw_args
        )
        self._verify_server_version()

    def _verify_server_version(self):
        """
        Verify that etcd server release is new enough to guarantee correct
        packet order for progress notifications.

        :raises requests.HTTPError: if the /version endpoint reports an error
        :raises RuntimeError: if the server version is not supported
        """
        # Get version via HTTP
        endpoint = self._client.endpoint_in_use
        response = requests.get(
            f"{endpoint.protocol}://{endpoint.netloc}/version", timeout=0.3
        )  # 300ms will do?
        response.raise_for_status()

        # Progress notifications are handled correctly from 3.4.26 and
        # 3.5.8 forward. We also allow 3.6.0 prerelease so development
        # builds work.
        ver = semantic_version.Version(
            response.json().get("etcdserver", "0.0.0")
        )
        LOGGER.debug("Detected etcd server version %s", ver)
        spec = semantic_version.NpmSpec(
            ">=3.4.26 <3.5 || >=3.5.8 || >=3.6.0-a"
        )
        if ver not in spec:
            raise RuntimeError(
                f"Etcd Server version is {ver}, need 3.4.26 or 3.5.8 for "
                "watcher to work correctly!"
            )

    def _retry_loop(self, code_to_try: Callable) -> Any:
        """
        Helper that retries code if an exception gets thrown that
        typically indicates a loss of connection.

        Note that this *can* rarely mean that the effect of the code in
        question was executed multiple times.

        :param code_to_try: Zero-argument callable to run
        :returns: Whatever ``code_to_try`` returns
        """
        # Exceptions that typically indicate a loss of connection
        retryable = (
            etcd3.exceptions.ConnectionFailedError,
            etcd3.exceptions.ConnectionTimeoutError,
            etcd3.exceptions.NoServerAvailableError,
        )
        retry_time = self._retry_time
        for i in range(self._max_retries):
            try:
                return code_to_try()
            except retryable as ex:
                LOGGER.warning(
                    "Caught %s, retry %d after %gs",
                    repr(ex),
                    i,
                    retry_time,
                )

            # Delay before next iteration
            time.sleep(retry_time)
            retry_time *= 1.5  # back off

        # Attempt one final time - without safety net
        return code_to_try()

    def get(
        self, path: str, revision: Optional[DbRevision] = None
    ) -> tuple[Optional[str], DbRevision]:
        """
        Get the value of a key.

        :param path: Path of key to query
        :param revision: Database revision to read at (latest if None)
        :returns: (value or None if the key does not exist, DbRevision of
            the response header). NOTE(review): etcd's header revision is
            the store's current revision, which for a historical read can
            be newer than ``revision``.
        """
        # Check/prepare parameters
        _check_path(path)
        tagged_path = _tag_depth(path)
        rev = None if revision is None else revision.revision

        # Get value and revision
        range_response = self._retry_loop(
            lambda: self._client.get_response(tagged_path, revision=rev)
        )

        # Handle non-existence of key
        if range_response.count < 1:
            value = None
        else:
            popped_kv_pair = range_response.kvs.pop()
            value = popped_kv_pair.value.decode("utf-8")

        # Revision is set whether the key exists or not
        return value, DbRevision(range_response.header.revision)

    def create(
        self, path: str, value: str, lease: Optional[etcd3.Lease] = None
    ) -> None:
        """
        Create a key, failing if it already exists.

        :param path: Path of key to create
        :param value: Value to store (converted with ``str``)
        :param lease: Optional lease to attach to the key
        :raises ConfigCollision: if the key already exists
        """
        # Prepare parameters
        _check_path(path)
        tagged_path = _tag_depth(path)
        lease_id = 0 if lease is None else lease.id
        value = str(value)

        response = self._retry_loop(
            lambda: self._client.put_if_not_exists(
                tagged_path, value, lease_id
            )
        )
        if not response:
            raise ConfigCollision(
                path, f"Cannot create {path}, as it already exists!"
            )

    def update(
        self,
        path: str,
        value: str,
    ) -> None:
        """
        Update an existing key.

        :param path: Path of key to update
        :param value: New value (converted with ``str``)
        :raises ConfigVanished: if the key does not exist
        """
        # Validate parameters
        _check_path(path)
        tagged_path = _tag_depth(path)
        value = str(value)

        # Execute in a transaction, only putting if the key exists.
        # Supported operators are equality/less/greater (not boolean).
        status, _ = self._retry_loop(
            lambda: self._client.transaction(
                compare=[self._client.transactions.version(tagged_path) > 0],
                success=[self._client.transactions.put(tagged_path, value)],
                failure=[],
            )
        )
        if not status:
            raise ConfigVanished(
                path, f"Cannot update {path}, as it does not exist!"
            )

    # pylint: disable=cell-var-from-loop
    def list_keys(
        self,
        path: str,
        recurse: RecurseType = 0,
        revision: Optional[DbRevision] = None,
        with_values: bool = False,
    ) -> tuple[list[str], DbRevision]:
        """
        List keys under given path.

        :param path: Prefix of keys to query. Append '/' to list child paths.
        :param recurse: Maximum recursion level to query. If iterable,
            cover exactly the recursion levels specified.
        :param revision: Database revision for which to list
        :param with_values: Also return key values (raw bytes) and
            revisions, i.e. sorted list of key-value-rev tuples. Note the
            revision in each tuple is the revision the listing was
            performed at, not the key's own mod revision.
        :returns: (sorted key list, DbRevision object)
        """
        # Prepare parameters
        path_depth = depth_of_path(path)
        rev = None if revision is None else revision.revision
        keys = []
        vals = []
        revs = []

        if isinstance(recurse, Iterable):
            depth_iter = iter(recurse)
        else:
            depth_iter = range(recurse + 1)

        for depth in depth_iter:
            tagged_path = _tag_depth(path, depth + path_depth)
            # The lambda is called immediately, so late binding is harmless
            range_response = self._retry_loop(
                lambda: self._client.get_prefix_response(
                    tagged_path, revision=rev, keys_only=not with_values
                )
            )
            # Bake in the revision of the first query, so remaining
            # depths are read from the same snapshot
            if rev is None:
                rev = range_response.header.revision
            for kv_pair in range_response.kvs:
                keys.append(_untag_depth(kv_pair.key))
                if with_values:
                    vals.append(kv_pair.value)
                    revs.append(DbRevision(rev))

        # If no depth was queried (empty ``recurse`` iterable), the key
        # lists are simply empty; previously this path raised
        # UnboundLocalError.
        revision = DbRevision(rev)
        if with_values:
            return (
                sorted(zip(keys, vals, revs), key=lambda kv: kv[0]),
                revision,
            )
        return sorted(keys), revision

    def lease(self, ttl: float = 10) -> Lease:
        """
        Generate a new lease.

        Once entered, it can be associated with keys which will be kept
        alive until the end of the lease. Note that this involves starting
        a daemon thread that will refresh the lease periodically (default
        seems to be TTL/4).

        :param ttl: Time to live for lease
        :return: lease object
        """
        return self._retry_loop(
            lambda: cast(Lease, self._client.lease(ttl=ttl))
        )

    def txn(self, max_retries: int = 64) -> Iterable["Etcd3Transaction"]:
        """
        Create a transaction loop that retries until commit succeeds.

        Delegating with ``yield from`` means closing this generator (e.g.
        when the caller breaks out of the loop) closes the inner
        transaction loop immediately, so its clean-up runs right away.

        :param max_retries: Maximum number of failed commit attempts
        :returns: Iterator yielding the transaction to fill in
        """
        yield from Etcd3Transaction(self, self._client, max_retries)

    def watcher(
        self,
        timeout=None,
        txn_wrapper: Callable[["Etcd3Transaction"], object] = None,
        requery_progress: float = 0.2,
    ) -> Iterable[Etcd3Watcher]:
        """Create a new watcher.

        Useful for waiting for changes in the configuration. See
        :py:class:`Etcd3Watcher`.

        :param timeout: Timeout for waiting. Watcher will loop after this
            time.
        :param txn_wrapper: Function to wrap transactions returned by the
            wrapper.
        :param requery_progress: How often we "refresh" the current
            database state for watcher transactions even without watcher
            notification (upper bound on how "stale" non-watched values
            retrieved in transactions can be)
        :returns: Watcher iterator
        """
        # To get around cyclic imports
        # pylint: disable=import-outside-toplevel
        from .etcd3_watcher import Etcd3Watcher

        return Etcd3Watcher(
            self, self._client, timeout, txn_wrapper, requery_progress
        )

    # pylint: disable=cell-var-from-loop
    def _delete_recursive(
        self,
        path: str,
        must_exist: bool = True,
        prefix: bool = False,
        max_depth: int = 16,
    ):
        """
        Delete keys below ``path``, one prefix deletion per depth level
        from ``depth + 1`` up to (excluding) ``depth + max_depth``.

        Factored out from delete due to too high cognitive complexity.

        :raises ConfigVanished: if ``must_exist`` and no level counted as
            deleted
        """
        depth = depth_of_path(path)
        delete_count = 0
        for level in range(depth + 1, depth + max_depth):
            dpath = _tag_depth(path if prefix else path + "/", level)
            prefix_response = self._retry_loop(
                lambda: self._client.delete_prefix(dpath)
            )
            # NOTE(review): this tests the truthiness of the response
            # object, not its ``deleted`` count as ``delete`` does. If the
            # response object is always truthy, ConfigVanished is never
            # raised below. Switching to ``.deleted`` would make recursive
            # deletion of a leaf key raise -- confirm intended semantics
            # before changing.
            if prefix_response:
                delete_count += 1
        response = delete_count >= 1
        if not response and must_exist:
            raise ConfigVanished(
                path, f"Cannot delete {path}, as it does not exist!"
            )

    # pylint: disable=too-many-arguments
    def delete(
        self,
        path: str,
        must_exist: bool = True,
        recursive: bool = False,
        prefix: bool = False,
        max_depth: int = 16,
    ):
        """
        Delete a key, or the keys matching ``path`` as a prefix.

        :param path: Path of key (or key prefix) to delete
        :param must_exist: Raise if nothing was deleted
        :param recursive: Also delete keys at deeper levels below ``path``
        :param prefix: Treat ``path`` as a key prefix
        :param max_depth: Depth limit for recursive deletion
        :raises ConfigVanished: if ``must_exist`` and nothing was deleted
        """
        # Prepare parameters
        tagged_path = _tag_depth(path)

        if prefix:
            prefix_response = self._retry_loop(
                lambda: self._client.delete_prefix(tagged_path)
            )
            response = prefix_response.deleted >= 1
        else:
            response = self._retry_loop(
                lambda: self._client.delete(tagged_path)
            )
        if not response and must_exist:
            raise ConfigVanished(
                path, f"Cannot delete {path}, as it does not exist!"
            )

        if recursive:
            self._delete_recursive(path, must_exist, prefix, max_depth)

    def close(self) -> None:
        """Close the underlying etcd client."""
        self._client.close()
# pylint: disable=too-many-instance-attributes
class Etcd3Transaction(DbTransaction):
    """
    A series of queries and updates to be executed atomically.

    Key reads via :py:meth:`get` are performed at a single database
    revision (baked in by the first read). Writes are buffered locally, and
    :py:meth:`commit` submits them in one etcd transaction. That
    transaction is guarded by comparisons verifying that nothing read has
    changed since.
    """

    def __init__(
        self,
        backend: Etcd3Backend,
        client: etcd3.client,
        max_retries: int = 64,
    ):
        """Initialise transaction."""
        super().__init__(backend)
        self._client = client
        self._max_retries = max_retries
        # Revision baked in after first read
        self._revision = None
        # Get query log: path -> (value or None, revision from backend.get)
        self._get_queries: dict[str, tuple[str, DbRevision]] = {}
        # Delayed updates: path -> (value, or None for deletion; lease)
        self._updates: dict[
            str, tuple[Optional[str], Optional[Lease]]
        ] = {}
        # List query log: (path, absolute depth) -> (keys, revision)
        self._list_queries: dict[
            tuple[str, int], tuple[list[str], DbRevision]
        ] = {}
        self._committed = False
        self._retries = 0
        self._commit_callbacks: list[Callable[[], None]] = []

    @property
    def revision(self) -> int:
        """The database revision this transaction read from.

        This is the revision baked in by the first read (or passed to
        :py:meth:`reset`); committing does not update it. Only valid to
        call after the transaction has been committed.

        :returns: revision from DbRevision
        """
        if not self._committed:
            raise RuntimeError(
                "Revision is undefined on an uncommitted transaction!"
            )
        return self._revision.revision

    def _ensure_uncommitted(self) -> None:
        """Raise RuntimeError if the transaction was already committed."""
        if self._committed:
            raise RuntimeError("Attempted to modify committed transaction!")

    # pylint: disable=duplicate-code
    def get(self, path: str) -> Optional[str]:
        """
        Get value of a key.

        :param path: Path of key to query
        :returns: Key value. None if it doesn't exist.
        """
        self._ensure_uncommitted()

        # Check whether it was written as part of this transaction
        if path in self._updates:
            return self._updates[path][0]

        # Check whether we already have the request response
        if path in self._get_queries:
            return self._get_queries[path][0]

        # Perform get request at the baked-in revision (latest if none yet)
        val, rev = self._get_queries[path] = self.backend.get(
            path, revision=self._revision
        )
        if self._revision is None:
            self._revision = rev
        return val

    def create(
        self, path: str, value: str, lease: Optional[etcd3.Lease] = None
    ) -> None:
        """
        Create a key (buffered until commit).

        :param path: Path of key to create
        :param value: Value to store (converted with ``str``)
        :param lease: Optional lease to attach to the key
        :raises ConfigCollision: if the key already exists
        """
        self._ensure_uncommitted()
        value = str(value)

        # Attempt to get the value - mainly to check whether it exists
        # and put it into the query log
        result = self.get(path)
        if result is not None:
            raise ConfigCollision(
                path, f"Cannot create {path}, as it already exists!"
            )

        # Add update request
        self._updates[path] = (value, lease)

    # pylint: disable=duplicate-code
    def update(self, path: str, value: str) -> None:
        """
        Update an existing key (buffered until commit).

        :param path: Path of key to update
        :param value: New value (converted with ``str``)
        :raises ConfigVanished: if the key does not exist
        """
        self._ensure_uncommitted()
        value = str(value)

        result = self.get(path)
        if result is None:
            raise ConfigVanished(
                path, f"Cannot update {path}, as it does not exist!"
            )

        # Add update request
        self._updates[path] = (value, None)

    # pylint: disable=too-many-arguments
    def delete(
        self,
        path: str,
        must_exist: bool = True,
        recursive: bool = False,
        max_depth: int = 16,
        prefix: bool = False,
    ) -> None:
        """
        Delete a key, or keys matching a prefix (buffered until commit).

        :param path: Path of key (or key prefix) to delete
        :param must_exist: Raise if no matching key exists
        :param recursive: Also delete keys below ``path``
        :param max_depth: Recursion depth used when ``recursive`` is set
        :param prefix: Treat ``path`` as a key prefix
        :raises ConfigVanished: if ``must_exist`` and nothing matched
        """
        keys = []
        if prefix:
            keys = self.list_keys(path, recurse=max_depth if recursive else 0)
        else:
            if self.get(path) is not None:
                keys = [path]
            if recursive:
                keys += self.list_keys(path + "/", recurse=max_depth)

        if must_exist and not keys:
            raise ConfigVanished(
                path, f"Cannot delete {path}, it does not exist!"
            )

        # Add delete request
        for key in keys:
            self._updates[key] = (None, None)

    def _compare_list(self, txn: etcd3.Transactions) -> list:
        """
        Build the comparisons that must hold for the commit to apply.

        :param txn: etcd transaction helper used to build comparisons
        :returns: list of comparisons verifying that nothing read by this
            transaction has changed since the read revision
        """
        compare_list = []
        if self._revision is None:
            # No read has been performed, so there is nothing to verify
            return compare_list

        # All gets after the first use this revision. The per-key
        # revisions stored by ``get`` come from the response header, which
        # etcd sets to its *current* revision. For historical reads that
        # can be newer than the snapshot, so it must not be used here.
        read_rev = self._revision.revision

        # For every get call, add comparisons to the compare list
        for path, (value, _) in self._get_queries.items():
            tagged_path = _tag_depth(path)
            if value is None:
                # Key did not exist? Verify it still doesn't exist.
                # Note that the key could have been created and
                # deleted in the meantime.
                compare_list.append(txn.version(tagged_path) == 0)
            else:
                # Key existed: verify it still exists. etcd compares a
                # missing key as an empty record (mod revision 0), so the
                # mod check alone would not detect a concurrent delete.
                compare_list.append(txn.version(tagged_path) > 0)
                # ... and that it has not been modified since we read it.
                compare_list.append(txn.mod(tagged_path) < read_rev + 1)

        # Verify list_keys calls from the query log
        for (path, depth), (result, _) in self._list_queries.items():
            tagged_path = _tag_depth(path, depth)

            # check returned list of keys still exist
            for res_path in result:
                tagged_res_path = _tag_depth(res_path)
                compare_list.append(txn.version(tagged_res_path) > 0)

            # check no new keys have been added to the returned list
            # by checking whether the request contains any keys with
            # create revisions newer than the embedded revision of the
            # request
            tagged_path_end = etcd3.utils.prefix_range_end(tagged_path)
            compare_list.append(
                txn.create(tagged_path, tagged_path_end) < read_rev + 1
            )

        return compare_list

    def _success_list(self, txn: etcd3.Transactions) -> list:
        """
        Build the put/delete operations for the buffered updates.

        :param txn: etcd transaction helper used to build operations
        :returns: list of operations to run if all comparisons hold
        """
        success_list = []
        for path, (value, lease) in self._updates.items():
            tagged_path = _tag_depth(path)
            if value is None:
                # Single-key delete (no range end)
                success_list.append(txn.delete(tagged_path))
            else:
                lease_id = None if lease is None else lease.id
                success_list.append(
                    txn.put(tagged_path, value, lease=lease_id)
                )
        return success_list

    # pylint: disable=protected-access
    def commit(self) -> bool:
        """
        Commit buffered updates atomically.

        Commit callbacks registered with :py:meth:`on_commit` are called
        if the transaction applied.

        :returns: True if committed (or there was nothing to write), False
            if a value read by this transaction changed concurrently
        """
        self._ensure_uncommitted()

        # If we have made no updates, we don't need to verify the get query log
        if not self._updates:
            self._committed = True
            return True

        # Use the transaction from the etcd3 client
        txn: etcd3.Transactions = self._client.transactions

        # The client transaction method carries out the actions
        # in the success_list if all assertions in the compare_list
        # are true.
        succeeded, _ = self.backend._retry_loop(
            lambda: self._client.transaction(
                compare=self._compare_list(txn),
                success=self._success_list(txn),
                failure=[],
            )
        )

        # Done
        self._committed = True
        if succeeded:
            for callback in self._commit_callbacks:
                callback()
            self._commit_callbacks = []
        return succeeded

    def on_commit(self, callback: Callable[[], None]) -> None:
        """Register a callback to call when the transaction succeeds.

        Exists mostly to enable test cases.

        :param callback: Callback to call
        """
        self._commit_callbacks.append(callback)

    def reset(self, revision: Optional[DbRevision] = None) -> None:
        """
        Clear query logs and buffered updates so the transaction can be
        re-run.

        :param revision: Revision to read at next (None: latest)
        :raises RuntimeError: if the transaction was not committed
        """
        if not self._committed:
            raise RuntimeError("Called reset on an uncommitted transaction!")
        self._revision = revision
        self._get_queries.clear()
        self._list_queries.clear()
        self._updates.clear()
        self._committed = False

    def list_keys(self, path: str, recurse: RecurseType = 0) -> list[str]:
        """
        List keys under given path, including the effect of updates made
        in this transaction.

        :param path: Prefix of keys to query. Append '/' to list child paths.
        :param recurse: Maximum recursion level to query. If iterable,
            cover exactly the recursion levels specified.
        :returns: sorted key list
        """
        self._ensure_uncommitted()
        path_depth = depth_of_path(path)

        # Walk through depths, collecting known keys
        if isinstance(recurse, Iterable):
            depth_iter = iter(recurse)
        else:
            depth_iter = range(recurse + 1)

        keys: list[str] = []
        for depth in depth_iter:
            tagged_path = _tag_depth(path, path_depth + depth)

            # Updates buffered in this transaction that match this level
            matching_vals = [
                kv_pair
                for kv_pair in self._updates.items()
                if _tag_depth(kv_pair[0]).startswith(tagged_path)
            ]
            added_keys = {
                key for key, val in matching_vals if val[0] is not None
            }
            removed_keys = {
                key for key, val in matching_vals if val[0] is None
            }

            # Query the database once per (path, absolute depth)
            query = (path, path_depth + depth)
            if query not in self._list_queries:
                self._list_queries[query] = self.backend.list_keys(
                    path, recurse=(depth,)
                )

            # Add to key set
            result, rev = self._list_queries[query]
            keys.extend(set(result) - removed_keys | added_keys)

            # Bake in revision if not already done so
            if self._revision is None:
                self._revision = rev

        # Sort
        return sorted(keys)

    def __iter__(self) -> Iterable["Etcd3Transaction"]:
        """
        Iterate transaction until it succeeds.

        Each iteration yields this transaction for the caller to fill in,
        then commits it. On a failed commit the transaction is reset and
        yielded again, until it has failed more than ``max_retries``
        times.

        :raises RuntimeError: if the transaction never succeeded
        """
        try:
            while self._retries <= self._max_retries:
                # Should build up a transaction
                yield self

                # Try to commit, count how many times we have tried
                if not self.commit():
                    self._retries += 1
                else:
                    self._retries = 0
                    return
                self.reset()
        finally:
            # Reached when the loop is abandoned (e.g. exception or break
            # in the caller) while updates were still pending
            if self._updates and not self._committed:
                LOGGER.warning(
                    "Transaction loop aborted - dropping updates to %s!",
                    list(self._updates.keys()),
                )
        raise RuntimeError(
            f"Transaction did not succeed after {self._max_retries} retries!"
        )