Source code for ska_sdp_piper.extensions.sphinx_piper_pipeline

# pragma: exclude from coverage
import inspect
from importlib import import_module

from docutils import nodes
from docutils.parsers.rst import Directive, Parser
from docutils.utils import new_document
from sphinx.util import logging
from sphinx.util.docstrings import prepare_docstring

from .. import __version__
from ..piper.pipeline import Pipeline
from .common import COL_WIDTHS, HEADER_ROW, ConfigTableRow, dataclass_to_list

# The section names below are taken from the napoleon extension documentation:
# https://www.sphinx-doc.org/en/master/usage/extensions/napoleon.html#docstring-sections
# NOTE: the constant keeps the historical "NAPOLEAN" spelling in its name
# because other code in this module refers to it by that name.
NAPOLEAN_DOCSTYLE_DIRECTIVES = (
    "Args",
    "Arguments",
    "Attention",
    "Attributes",
    "Caution",
    "Danger",
    "Error",
    "Example",
    "Examples",
    "Hint",
    "Important",
    "Keyword Args",
    "Keyword Arguments",
    "Methods",
    "Note",
    "Notes",
    "Other Parameters",
    "Parameters",
    "Return",
    "Returns",
    "Raise",
    "Raises",
    "References",
    "See Also",
    "Tip",
    "Todo",
    "Warning",
    "Warnings",
    "Warn",
    "Warns",
    "Yield",
    "Yields",
)
# Leading text of Sphinx Python-domain info fields (":param x:", ":raises E:").
# They are matched as line prefixes, so ":param" also covers ":parameter" and
# ":return" also covers ":returns". Everything after the first five entries is
# an alias listed in the Sphinx "Info field lists" documentation.
SPHINX_REST_DIRECTIVES = (
    ":param",
    ":type",
    ":raises",
    ":return",
    ":rtype",
    ":arg",  # also covers ":argument"
    ":key",  # also covers ":keyword"
    ":raise",
    ":except",  # also covers ":exception"
    ":var",  # also covers ":vartype"
    ":ivar",
    ":cvar",
    ":meta",
)

logger = logging.getLogger(__name__)


class PiperPipelineConfigDirective(Directive):
    """Document the stages and configuration of a piper ``Pipeline``.

    Usage in reST::

        .. pipelineconfig:: package.module.pipeline_object

    The single argument is a dotted path. Its last component is an attribute
    of the module named by the rest of the path. For every stage of the
    resolved pipeline, a section is emitted containing:

    * the stage name,
    * the descriptive part of the stage definition's docstring, and
    * a table of the stage's configurable parameters, when it has any.
    """

    has_content = False
    required_arguments = 1  # e.g., app.pipeline

    # Lower-cased napoleon section names. A docstring line consisting of one
    # of these, optionally followed by ':', opens a docstring section.
    _section_headers = frozenset(
        word.lower() for word in NAPOLEAN_DOCSTYLE_DIRECTIVES
    )
    # Lower-cased reST info-field prefixes such as ':param'.
    _field_prefixes = tuple(word.lower() for word in SPHINX_REST_DIRECTIVES)

    def run(self):
        """Render one section per pipeline stage.

        Returns:
            list: One docutils ``section`` node per stage, in pipeline order.
            The list is empty, and a warning is logged, when the argument
            cannot be resolved or does not resolve to a ``Pipeline``.
        """
        dotted_path = self.arguments[0]
        try:
            pipeline = self._resolve_object(dotted_path)
        except Exception as e:
            # Any resolution failure becomes a warning, not a build error.
            logger.warning(f"Could not resolve pipeline: {dotted_path!r}: {e}")
            return []

        if not isinstance(pipeline, Pipeline):
            logger.warning(
                f"The object {dotted_path} is not an instance of Pipeline"
            )
            return []

        return [self._build_stage_section(stage) for stage in pipeline._stages]

    def _build_stage_section(self, stage):
        """Build the section node documenting a single pipeline stage.

        Args:
            stage: A pipeline stage. Only its ``name``, its
                ``stage_definition`` and its
                ``generate_config_rows_for_stage()`` are used.

        Returns:
            nodes.section: A section titled with the stage name. It contains
            the parsed docstring summary and, if the stage has config rows,
            the config table.
        """
        stage_name = stage.name
        section = nodes.section(ids=[nodes.make_id(stage_name)])
        section += nodes.title(text=stage_name)

        summary = self._extract_summary(stage.stage_definition)
        section.extend(self._parse_rst(summary))

        config_rows = stage.generate_config_rows_for_stage()
        if config_rows:
            section += self._build_config_table(config_rows)
        return section

    def _parse_rst(self, text):
        """Parse reST text into a list of docutils nodes.

        The text is parsed into a scratch document. That document shares the
        settings of the document currently being built, and its top-level
        children are returned.
        """
        source = self.state.document.current_source or "<auto>"
        temp_doc = new_document(source, self.state.document.settings)
        Parser().parse(text, temp_doc)
        return temp_doc.children

    def _resolve_object(self, dotted_path: str):
        """Return attribute ``attr`` of module ``pkg.mod`` for ``pkg.mod.attr``.

        Raises:
            Exception: Whatever ``import_module`` or ``getattr`` raises
                (e.g. ``ImportError``, ``AttributeError``, or ``ValueError``
                for a path without a dot) propagates to the caller.
        """
        module_path, _, attr = dotted_path.rpartition(".")
        module = import_module(module_path)
        return getattr(module, attr)

    def _extract_summary(self, function) -> str:
        """
        Given a python function, extract the description part of its docstring.

        Lines are collected until the first line that opens a docstring
        section. Such a line is either a napoleon (Google / NumPy style)
        section header like ``Args:`` or ``Parameters``, or a Sphinx reST info
        field like ``:param x:``. A section header must make up the whole line,
        as napoleon itself requires. Ordinary prose such as "Returns the
        dataset" therefore stays in the summary.
        """
        docstring = inspect.getdoc(function) or ""
        summary_lines = []
        for line in prepare_docstring(docstring):
            if self._opens_docstring_section(line):
                break
            summary_lines.append(line)
        return "\n".join(summary_lines).strip()

    @classmethod
    def _opens_docstring_section(cls, line: str) -> bool:
        """Return True if ``line`` starts a napoleon section or a reST field."""
        text = line.strip().lower()
        if text.startswith(cls._field_prefixes):
            return True
        # Google style headers end with ':' ("Args:"). NumPy style headers
        # are the bare name, underlined on the following line ("Parameters").
        return text.rstrip(":") in cls._section_headers

    def _build_config_table(self, data_rows: list[ConfigTableRow]):
        """Build a table node with a header row and one row per config entry.

        Column widths come from ``COL_WIDTHS``. Header cells are plain-text
        paragraphs built from ``HEADER_ROW``. Each data cell is parsed as reST.
        """
        table = nodes.table()
        tgroup = nodes.tgroup(cols=len(COL_WIDTHS))
        table += tgroup
        for col_width in COL_WIDTHS:
            tgroup += nodes.colspec(colwidth=col_width)

        # Header row: cell text is used verbatim, not parsed as reST.
        thead = nodes.thead()
        header_row_node = nodes.row()
        for col in dataclass_to_list(HEADER_ROW):
            entry = nodes.entry()
            entry += nodes.paragraph(text=col)
            header_row_node += entry
        thead += header_row_node
        tgroup += thead

        # Data rows: each cell is parsed as reST.
        tbody = nodes.tbody()
        for data_row in data_rows:
            data_row_node = nodes.row()
            for col in dataclass_to_list(data_row):
                entry = nodes.entry()
                entry.extend(self._parse_rst(col))
                data_row_node += entry
            tbody += data_row_node
        tgroup += tbody
        return table
def setup(app):
    """Sphinx extension entry point.

    Registers :class:`PiperPipelineConfigDirective` under the directive name
    ``pipelineconfig``.

    Returns:
        dict: Extension metadata. It holds the package version and declares
        the extension safe for parallel reading.
    """
    app.add_directive("pipelineconfig", PiperPipelineConfigDirective)
    extension_metadata = {
        "version": __version__,
        "parallel_read_safe": True,
    }
    return extension_metadata