"""Common schema loading functions"""
import pyarrow
import ska_sdp_dal
# Names of the complex data types, which are taken from ska_sdp_dal rather
# than pyarrow.
_COMPLEX_TYPES = ("complex64", "complex128")

# Map from the data type names used in schema specifications to Arrow data
# types. Complex types are the ska_sdp_dal attributes of the same name, used
# as-is; every other entry is produced by calling the pyarrow factory
# function of the same name.
_DATA_TYPES = {
    **{name: getattr(ska_sdp_dal, name) for name in _COMPLEX_TYPES},
    **{
        name: getattr(pyarrow, name)()
        for name in (
            "bool_",
            "int8",
            "int16",
            "int32",
            "int64",
            "uint8",
            "uint16",
            "uint32",
            "float16",
            "float32",
            "float64",
            "string",
        )
    },
}
def _get_arrow_field(field_spec, support_complex):
    """
    Convert a single field specification into a ``pyarrow.Field``.

    :param field_spec: Mapping with the field name under ``"field"`` and its
        type name under ``"type"``. A ``"struct"`` type lists its member
        field specifications under ``"subfields"``. A ``"list"`` type names
        its element type under ``"item"`` and may give a ``"size"``; the
        default of ``-1`` is passed to ``pyarrow.list_`` unchanged.
    :param support_complex: Whether complex data types are allowed, either
        directly or as the element type of a list.
    :raises ValueError: If a complex data type is used while
        ``support_complex`` is false.
    :raises KeyError: If a type name is unknown or a required key is missing.
    """

    def _check_supported(type_name):
        # Complex types are rejected wherever they appear (field or list
        # item), so a list cannot be used to smuggle them into the schema.
        if not support_complex and type_name in _COMPLEX_TYPES:
            raise ValueError(
                f"Data type {type_name} not supported in this schema"
            )

    data_type = field_spec["type"]
    _check_supported(data_type)
    if data_type == "struct":
        data_type = pyarrow.struct(
            [
                _get_arrow_field(subfield, support_complex=support_complex)
                for subfield in field_spec["subfields"]
            ]
        )
    elif data_type == "list":
        item_type = field_spec["item"]
        _check_supported(item_type)
        size = int(field_spec.get("size", -1))
        data_type = pyarrow.list_(_DATA_TYPES[item_type], size)
    else:
        data_type = _DATA_TYPES[data_type]
    return pyarrow.field(field_spec["field"], data_type)
[docs]
def load_table_schemas(table_schemas):
    """
    Build Arrow schemas from a collection of Table specifications.

    Complex data types are not allowed in table schemas.

    :param table_schemas: A dictionary mapping each table name to an iterable
        of field specifications (mappings with at least ``"field"`` and
        ``"type"`` entries).
    :return: A dictionary mapping each table name to its ``pyarrow.Schema``.
    """
    schemas = {}
    for table_name, field_specs in table_schemas.items():
        arrow_fields = [
            _get_arrow_field(spec, support_complex=False)
            for spec in field_specs
        ]
        schemas[table_name] = pyarrow.schema(arrow_fields)
    return schemas
def _load_in_param(param_spec, table_schemas):
name = param_spec["name"]
otype = param_spec["object"]
if otype not in ("table", "tensor"):
raise ValueError(f'object {otype} should be "table" or "tensor"')
if otype == "table":
return ska_sdp_dal.make_table_input_par(
name, table_schemas[param_spec["schema"]]
)
# a tensor
dim_names = param_spec.get("dimensions", [])
return ska_sdp_dal.make_tensor_input_par(
name, _DATA_TYPES[param_spec["type"]], dim_names
)
def _load_out_param(param_spec):
# only tensors are supported currently
dim_names = param_spec.get("dimensions", [])
return ska_sdp_dal.make_tensor_output_par(
param_spec["name"], _DATA_TYPES[param_spec["type"]], dim_names
)
[docs]
def load_procedure_schemas(procedures_spec, table_schemas):
    """
    Build call schemas for a collection of remote procedure specifications.

    :param procedures_spec: A dictionary mapping each procedure name to its
        specification, which must provide ``"inputs"`` and ``"outputs"``
        parameter lists.
    :param table_schemas: Map of Arrow Table schemas; it should contain every
        table referred to by the procedures' inputs.
    :return: A dictionary mapping each procedure name to the schema returned
        by ``ska_sdp_dal.make_call_schema``. Inputs precede outputs in the
        parameter list.
    """
    schemas = {}
    for proc_name, proc_spec in procedures_spec.items():
        call_params = [
            _load_in_param(spec, table_schemas) for spec in proc_spec["inputs"]
        ]
        call_params.extend(
            _load_out_param(spec) for spec in proc_spec["outputs"]
        )
        schemas[proc_name] = ska_sdp_dal.make_call_schema(
            proc_name, call_params
        )
    return schemas