# pragma: exclude from coverage
import inspect
from importlib import import_module
from docutils import nodes
from docutils.parsers.rst import Directive, Parser
from docutils.utils import new_document
from sphinx.util import logging
from sphinx.util.docstrings import prepare_docstring
from .. import __version__
from ..piper.pipeline import Pipeline
from .common import COL_WIDTHS, HEADER_ROW, ConfigTableRow, dataclass_to_list
# Section headers recognised by Sphinx's Napoleon extension, taken from its
# documentation:
# https://www.sphinx-doc.org/en/master/usage/extensions/napoleon.html#docstring-sections
# (The "NAPOLEAN" spelling of the constant name is kept for backward
# compatibility.)
NAPOLEAN_DOCSTYLE_DIRECTIVES = (
    "Args",
    "Arguments",
    "Attention",
    "Attributes",
    "Caution",
    "Danger",
    "Error",
    "Example",
    "Examples",
    "Hint",
    "Important",
    "Keyword Args",
    "Keyword Arguments",
    "Methods",
    "Note",
    "Notes",
    "Other Parameters",
    "Parameters",
    "Return",
    "Returns",
    "Raise",
    "Raises",
    "References",
    "See Also",
    "Tip",
    "Todo",
    "Warning",
    "Warnings",
    "Warn",
    "Warns",
    "Yield",
    "Yields",
)
# Prefixes of Sphinx info fields (reST field lists) that begin the structured
# part of a reST-style docstring. The summary extractor matches them with
# ``str.startswith``, so ":raise" also covers ":raises", ":arg" covers
# ":argument", ":except" covers ":exception", ":var" covers ":vartype", etc.
SPHINX_REST_DIRECTIVES = (
    ":param",
    ":type",
    ":raises",
    ":return",
    ":rtype",
    # Aliases accepted by Sphinx's Python domain that the entries above miss.
    ":raise",
    ":except",
    ":arg",
    ":key",
    ":var",
    ":ivar",
    ":cvar",
)
logger = logging.getLogger(__name__)
class PiperPipelineConfigDirective(Directive):
    """Render the stages of a piper ``Pipeline`` as documentation sections.

    The single required argument is a dotted path ``module.attribute`` that
    resolves to a ``Pipeline`` instance. For every stage the directive emits
    a section titled with the stage name, containing the description part of
    the stage definition's docstring and, when the stage has any, a table of
    its configuration rows.
    """

    has_content = False
    required_arguments = 1  # dotted path to the pipeline, e.g. app.pipeline

    def run(self) -> list[nodes.Node]:
        """Resolve the pipeline argument and build one section per stage.

        Returns:
            The list of section nodes, or an empty list (after logging a
            warning) when the argument cannot be resolved to a ``Pipeline``.
        """
        dotted_path = self.arguments[0]
        try:
            pipeline = self._resolve_object(dotted_path)
        except Exception as e:
            # Resolution failures degrade to a build warning and no output
            # rather than propagating out of the directive.
            logger.warning("Could not resolve pipeline: %r: %s", dotted_path, e)
            return []

        if not isinstance(pipeline, Pipeline):
            logger.warning(
                "The object %s is not an instance of Pipeline", dotted_path
            )
            return []

        return [self._build_stage_section(stage) for stage in pipeline._stages]

    def _build_stage_section(self, stage) -> nodes.section:
        """Build the section (title, docstring summary, config table) for one stage."""
        stage_name = stage.name
        section = nodes.section(ids=[nodes.make_id(stage_name)])
        section += nodes.title(text=stage_name)

        # Only the free-text description of the stage definition's docstring
        # is rendered; its structured sections/fields are cut off.
        summary = self._extract_summary(stage.stage_definition)
        section.extend(self._parse_rst(summary))

        config_rows = stage.generate_config_rows_for_stage()
        if len(config_rows) > 0:
            section += self._build_config_table(config_rows)
        return section

    def _parse_rst(self, text: str) -> list[nodes.Node]:
        """Parse reST ``text`` into a list of docutils nodes.

        The text is parsed into a throw-away document that reuses the current
        document's settings; that document's top-level children are returned
        so the caller can graft them into its own node tree.
        """
        # NOTE(review): targets/footnotes defined inside ``text`` are
        # registered on the temporary document, not the one being built, so
        # references to them may not resolve; consider
        # ``self.state.nested_parse`` if that becomes a problem.
        source = self.state.document.current_source or "<auto>"
        temp_doc = new_document(source, self.state.document.settings)
        Parser().parse(text, temp_doc)
        return temp_doc.children

    def _resolve_object(self, dotted_path: str):
        """Import and return the object named by ``dotted_path``.

        ``"pkg.mod.attr"`` imports ``pkg.mod`` and returns its ``attr``.

        Raises:
            ValueError: If ``dotted_path`` has no module part (no ``.``).
            ImportError: If the module cannot be imported (any exception
                raised while executing the module also propagates).
            AttributeError: If the module has no such attribute.
        """
        module_path, _, attr = dotted_path.rpartition(".")
        if not module_path:
            # import_module("") would fail with an opaque "Empty module name".
            raise ValueError(
                f"{dotted_path!r} is not a dotted path of the form "
                "'module.attribute'"
            )
        module = import_module(module_path)
        return getattr(module, attr)

    def _extract_summary(self, function) -> str:
        """Return the free-text description part of ``function``'s docstring.

        Lines are collected until the first line that opens a structured
        part of the docstring:

        * a Napoleon section header occupying the whole line, e.g. ``Args:``
          (Google style) or ``Parameters`` (NumPy style, above its underline);
        * a Sphinx info field, e.g. ``:param x:`` or ``:rtype:``.

        Args:
            function: Any object accepted by ``inspect.getdoc``; here it is a
                stage's ``stage_definition``.

        Returns:
            The description with surrounding whitespace stripped, or ``""``
            when there is no docstring.
        """
        docstring = inspect.getdoc(function) or ""
        section_names = {name.lower() for name in NAPOLEAN_DOCSTYLE_DIRECTIVES}
        field_prefixes = tuple(prefix.lower() for prefix in SPHINX_REST_DIRECTIVES)

        summary_lines = []
        for line in prepare_docstring(docstring):
            stripped = line.strip().lower()
            # Compare the whole line (minus an optional trailing colon) with
            # the section names instead of using a prefix test: otherwise an
            # ordinary sentence such as "Return the cleaned rows." would be
            # mistaken for a "Return" section and truncate the summary.
            if stripped.rstrip(":").rstrip() in section_names:
                break
            if stripped.startswith(field_prefixes):
                break
            summary_lines.append(line)
        return "\n".join(summary_lines).strip()

    def _build_config_table(self, data_rows: list[ConfigTableRow]) -> nodes.table:
        """Build a docutils table: a header row plus one row per config entry.

        Header cells are inserted as plain text; data cells are parsed as
        reST so that inline markup inside them is rendered.
        """
        table = nodes.table()
        tgroup = nodes.tgroup(cols=len(COL_WIDTHS))
        table += tgroup
        for col_width in COL_WIDTHS:
            tgroup += nodes.colspec(colwidth=col_width)

        # Header row
        thead = nodes.thead()
        header_row_node = nodes.row()
        for col in dataclass_to_list(HEADER_ROW):
            entry = nodes.entry()
            entry += nodes.paragraph(text=col)
            header_row_node += entry
        thead += header_row_node
        tgroup += thead

        # Data rows
        tbody = nodes.tbody()
        for data_row in data_rows:
            row_node = nodes.row()
            for col in dataclass_to_list(data_row):
                entry = nodes.entry()
                entry.extend(self._parse_rst(col))
                row_node += entry
            tbody += row_node
        tgroup += tbody
        return table
def setup(app):
    """Sphinx extension entry point: register the ``pipelineconfig`` directive.

    Args:
        app: The Sphinx application object passed to extension ``setup``.

    Returns:
        Extension metadata: the package version and a flag declaring the
        extension safe for parallel reading.
    """
    app.add_directive("pipelineconfig", PiperPipelineConfigDirective)
    return {"version": __version__, "parallel_read_safe": True}