# coding: utf-8
"""
Basic objects for defining statistical inference models.
"""
from __future__ import annotations
import enum
import copy as _copy
import law
import order as od
import yaml
from columnflow.types import Generator, Callable, TextIO, Sequence, Any, Hashable, Type, T
from columnflow.util import (
CachedDerivableMeta, Derivable, DotDict, is_pattern, is_regex, pattern_matcher, get_docs_url, freeze,
)
# module-level logger, named after this module
logger = law.logger.get_logger(__name__)
# value of the "default_dataset" option in the "analysis" section of the law config, read via get_expanded; used as the
# fallback in InferenceModel.used_datasets
default_dataset = law.config.get_expanded("analysis", "default_dataset")
[docs]
class ParameterType(enum.Enum):
    """
    Parameter type flag.

    :cvar rate_gauss: Gaussian rate parameter.
    :cvar rate_uniform: Uniform rate parameter.
    :cvar rate_unconstrained: Unconstrained rate parameter.
    :cvar shape: Shape parameter.
    """

    rate_gauss = "rate_gauss"
    rate_uniform = "rate_uniform"
    rate_unconstrained = "rate_unconstrained"
    shape = "shape"

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}.{self.value}>"

    def __str__(self) -> str:
        return self.value

    @property
    def is_rate(self) -> bool:
        """
        Checks if the parameter type is a rate type.

        :returns: *True* if the parameter type is a rate type, *False* otherwise.
        """
        # look members up on the class rather than on the instance, since accessing one member through another emits a
        # DeprecationWarning on Python 3.11
        members = type(self)
        return self in (
            members.rate_gauss,
            members.rate_uniform,
            members.rate_unconstrained,
        )

    @property
    def is_shape(self) -> bool:
        """
        Checks if the parameter type is a shape type.

        :returns: *True* if the parameter type is a shape type, *False* otherwise.
        """
        # same class-level lookup as in is_rate
        return self in (type(self).shape,)
[docs]
class FlowStrategy(enum.Enum):
    """
    Flag selecting how contents of under- and overflow bins are handled.

    :cvar ignore: Ignore over- and underflow bins.
    :cvar warn: Issue a warning when over- and underflow bins are encountered.
    :cvar remove: Remove over- and underflow bins.
    :cvar move: Move over- and underflow bin contents into the outermost regular bins.
    """

    ignore = "ignore"
    warn = "warn"
    remove = "remove"
    move = "move"

    def __str__(self) -> str:
        # the plain value doubles as the string representation
        text = self.value
        return text
[docs]
class InferenceModel(Derivable, metaclass=InferenceModelMeta):
"""
Interface to statistical inference models with connections to config objects (such as :py:class:`order.Config` or
:py:class:`order.Dataset`).
The internal structure to describe a model looks as follows (in yaml style) and is accessible through
:py:attr:`model` as well as property access to its top-level objects.
.. code-block:: yaml
categories:
- name: cat1
postfix: null
config_data:
22pre_v14:
category: 1e
variable: ht
data_datasets: [data_mu_a]
data_from_processes: []
mc_stats: 10
flow_strategy: warn
rate_precision: 5
empty_bin_value: 1e-05
processes:
- name: HH
is_signal: True
config_data:
22pre_v14:
process: hh
mc_datasets: [hh_ggf]
scale: 1.0
is_dynamic: False
parameters:
- name: lumi
type: rate_gauss
effect: 1.02
effect_precision: 4
config_data: {}
transformations: []
- name: pu
type: rate_gauss
effect: [0.97, 1.02]
effect_precision: 4
config_data: {}
transformations: []
- name: pileup
type: shape
effect: 1.0
effect_precision: 4
config_data:
22pre_v14:
shift_source: minbias_xs
transformations: []
- name: tt
is_signal: False
config_data:
22pre_v14:
process: tt
mc_datasets: [tt_sl, tt_dl, tt_fh]
scale: 1.0
is_dynamic: False
parameters:
- name: lumi
type: rate_gauss
effect: 1.02
effect_precision: 4
config_data: {}
transformations: []
- name: cat2
...
parameter_groups:
- name: rates
parameter_names: [lumi, pu]
- ...
.. py:attribute:: name
type: str
The unique name of this model.
.. py:attribute:: config_insts
type: list[order.Config]
Reference to :py:class:`order.Config` objects.
.. py:attribute:: model
type: DotDict
The internal data structure representing the model, see :py:meth:`InferenceModel.model_spec`.
"""
# optional initialization hook; when it is callable (e.g. after being attached by the inference_model decorator), it is
# invoked at the end of __init__
init_func = None
[docs]
class YamlDumper(yaml.SafeDumper):
    """
    YAML dumper for statistical inference models with amended representers to serialize
    internal, structured objects as safe, standard objects.
    """

    @classmethod
    def _map_repr(cls, dumper: yaml.Dumper, data: dict) -> yaml.Node:
        # represent mapping-like objects as plain dictionaries
        return dumper.represent_dict(data if isinstance(data, dict) else dict(data))

    @classmethod
    def _list_repr(cls, dumper: yaml.Dumper, data: list) -> yaml.Node:
        # represent sequence-like objects (e.g. tuples) as plain lists
        return dumper.represent_list(data if isinstance(data, list) else list(data))

    @classmethod
    def _str_repr(cls, dumper: yaml.Dumper, data: str) -> yaml.Node:
        # represent objects such as enum members through their string conversion
        return dumper.represent_str(data if isinstance(data, str) else str(data))

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        # amend representers; this is done at instantiation time, so the referenced types only need to exist once a
        # dumper is actually created
        representers = (
            (DotDict, self._map_repr),
            (tuple, self._list_repr),
            (ParameterType, self._str_repr),
            (ParameterTransformation, self._str_repr),
            (ParameterTransformations, self._list_repr),
            (FlowStrategy, self._str_repr),
        )
        for data_type, representer in representers:
            self.add_representer(data_type, representer)

    def ignore_aliases(self, *args, **kwargs) -> bool:
        # always report that aliases are to be ignored, independent of the data
        return True
[docs]
@classmethod
def inference_model(
    cls: T,
    func: Callable | None = None,
    bases: tuple[type] = (),
    **kwargs,
) -> Type[T] | Callable:
    """
    Decorator that derives a new :py:class:`InferenceModel` subclass, optionally with additional *bases*, and
    registers the decorated function as its ``init_func``. Any further *kwargs* end up as class members of the
    new subclass.

    :param func: The function to be decorated and attached as ``init_func``.
    :param bases: Optional tuple of base classes for the new subclass.
    :returns: The new subclass or a decorator function.
    """
    def decorator(func: Callable) -> Type[T]:
        # class members of the subclass: all extra kwargs, with the decorated function set as init_func
        members = dict(kwargs)
        members["init_func"] = func

        # the subclass is named after the decorated function
        return cls.derive(func.__name__, bases=bases, cls_dict=members)

    # without a function, hand back the decorator itself so that it can be applied later
    if func:
        return decorator(func)
    return decorator
[docs]
@classmethod
def used_datasets(cls, config_inst: od.Config) -> list[str]:
    """
    Names of the datasets for which `upstream_task_cls.resolve_instances` will be called.

    This base implementation does not inspect *config_inst* and returns only the default dataset.

    :param config_inst: The config object the datasets refer to.
    :returns: A list containing the default dataset name.
    """
    dataset_names = [default_dataset]
    return dataset_names
[docs]
@classmethod
def model_spec(cls) -> DotDict:
    """
    Creates the empty top-level structure of a model, consisting of the following entries.

    - *categories*: List of :py:meth:`category_spec` objects.
    - *parameter_groups*: List of :py:meth:`parameter_group_spec` objects.

    :returns: A dictionary with one empty list per top-level entry.
    """
    top_level_keys = ("categories", "parameter_groups")

    # every key receives its own, initially empty list
    return DotDict([(key, []) for key in top_level_keys])
[docs]
@classmethod
def category_spec(
    cls,
    name: str,
    config_data: dict[str, DotDict] | None = None,
    data_from_processes: Sequence[str] | None = None,
    flow_strategy: FlowStrategy | str = FlowStrategy.warn,
    mc_stats: int | float | tuple | None = None,
    postfix: str | None = None,
    empty_bin_value: float = 1e-5,
    rate_precision: int = 5,
) -> DotDict:
    """
    Builds the dictionary describing a category (interchangeably called bin or channel in other tools) from the
    given arguments. The list of processes of the new category starts out empty.

    :param name: The name of the category in the model.
    :param config_data: Dictionary mapping names of :py:class:`order.Config` objects to dictionaries following the
        :py:meth:`category_config_spec` that wrap settings like category, variable and real datasets in that config.
    :param data_from_processes: Optional list of names of :py:meth:`process_spec` objects that are meant to make up
        a fake data contribution when no data datasets are defined in *config_data*.
    :param flow_strategy: A :py:class:`FlowStrategy` instance, or the name of one, describing the strategy to handle
        over- and underflow bin contents.
    :param mc_stats: Either *None* to disable MC stat uncertainties, or an integer, a float or a tuple thereof to
        control the options of MC stat options.
    :param postfix: A postfix that is appended to (e.g.) file names created for this model.
    :param empty_bin_value: When bins have no content, they are filled with this value.
    :param rate_precision: The precision of reported rates.
    :returns: A dictionary representing the category.
    """
    # wrap the settings of each config through category_config_spec
    wrapped_config_data = {}
    if config_data:
        for config_name, settings in config_data.items():
            wrapped_config_data[config_name] = cls.category_config_spec(**settings)

    # strings are translated to the enum member of the same name
    if not isinstance(flow_strategy, FlowStrategy):
        flow_strategy = FlowStrategy[flow_strategy]

    spec = DotDict([
        ("name", str(name)),
        ("config_data", wrapped_config_data),
        ("data_from_processes", [str(proc_name) for proc_name in (data_from_processes or [])]),
        ("flow_strategy", flow_strategy),
        ("mc_stats", mc_stats),
        # falsy postfixes (e.g. empty strings) are normalized to None
        ("postfix", str(postfix) if postfix else None),
        ("empty_bin_value", empty_bin_value),
        ("rate_precision", rate_precision),
        ("processes", []),
    ])
    return spec
[docs]
@classmethod
def process_spec(
    cls,
    name: str,
    is_signal: bool = False,
    config_data: dict[str, DotDict] | None = None,
    scale: float | int = 1.0,
    is_dynamic: bool = False,
) -> DotDict:
    """
    Builds the dictionary describing a process from the given arguments. The list of parameters of the new
    process starts out empty.

    :param name: The name of the process in the model.
    :param is_signal: A boolean flag deciding whether this process describes signal.
    :param config_data: Dictionary mapping names of :py:class:`order.Config` objects to dictionaries following the
        :py:meth:`process_config_spec` that wrap settings like process and mc datasets in that config.
    :param scale: A float value to scale the process, defaulting to 1.0.
    :param is_dynamic: A boolean flag deciding whether this process is dynamic, i.e., whether it is created
        on-the-fly.
    :returns: A dictionary representing the process.
    """
    # wrap the settings of each config through process_config_spec
    wrapped_config_data = {}
    if config_data:
        for config_name, settings in config_data.items():
            wrapped_config_data[config_name] = cls.process_config_spec(**settings)

    spec = DotDict([
        ("name", str(name)),
        ("is_signal", bool(is_signal)),
        ("config_data", wrapped_config_data),
        ("scale", float(scale)),
        ("is_dynamic", bool(is_dynamic)),
        ("parameters", []),
    ])
    return spec
[docs]
@classmethod
def parameter_spec(
    cls,
    name: str,
    type: ParameterType | str,
    transformations: Sequence[ParameterTransformation | str] = (),
    config_data: dict[str, DotDict] | None = None,
    effect: Any | None = 1.0,
    effect_precision: int = 4,
) -> DotDict:
    """
    Builds the dictionary describing a (nuisance) parameter from the given arguments.

    :param name: The name of the parameter in the model.
    :param type: A :py:class:`ParameterType` instance, or the name of one, describing the type of this parameter.
    :param transformations: A sequence of :py:class:`ParameterTransformation` instances describing transformations
        to be applied to the effect of this parameter.
    :param config_data: Dictionary mapping names of :py:class:`order.Config` objects to dictionaries following the
        :py:meth:`parameter_config_spec` that wrap settings like corresponding shift source in that config.
    :param effect: An arbitrary object describing the effect of the parameter (e.g. float for symmetric rate
        effects, 2-tuple for down/up variation, etc).
    :param effect_precision: The precision of reported effects.
    :returns: A dictionary representing the parameter.
    """
    # strings are translated to the enum member of the same name
    if not isinstance(type, ParameterType):
        type = ParameterType[type]

    # wrap the settings of each config through parameter_config_spec
    wrapped_config_data = {}
    if config_data:
        for config_name, settings in config_data.items():
            wrapped_config_data[config_name] = cls.parameter_config_spec(**settings)

    spec = DotDict([
        ("name", str(name)),
        ("type", type),
        ("transformations", ParameterTransformations(transformations)),
        ("config_data", wrapped_config_data),
        ("effect", effect),
        ("effect_precision", effect_precision),
    ])
    return spec
[docs]
@classmethod
def parameter_group_spec(
    cls,
    name: str,
    parameter_names: Sequence[str] | None = None,
) -> DotDict:
    """
    Builds the dictionary describing a named group of parameter names.

    :param name: The name of the parameter group in the model.
    :param parameter_names: Names of parameter objects this group contains.
    :returns: A dictionary representing the group of parameter names.
    """
    # a missing sequence of names results in an empty list
    names = [str(parameter_name) for parameter_name in (parameter_names or [])]

    return DotDict([
        ("name", str(name)),
        ("parameter_names", names),
    ])
[docs]
@classmethod
def category_config_spec(
    cls,
    category: str | None = None,
    variable: str | None = None,
    data_datasets: Sequence[str] | None = None,
) -> DotDict:
    """
    Builds the dictionary holding the config specific settings of a category.

    :param category: The name of the source category in the config to use.
    :param variable: The name of the variable in the config to use.
    :param data_datasets: List of names or patterns of datasets in the config to use for real data.
    :returns: A dictionary representing category specific settings.
    """
    # falsy values are normalized to None, everything else is converted to a string
    category_name = str(category) if category else None
    variable_name = str(variable) if variable else None
    dataset_names = [str(dataset) for dataset in (data_datasets or [])]

    return DotDict([
        ("category", category_name),
        ("variable", variable_name),
        ("data_datasets", dataset_names),
    ])
[docs]
@classmethod
def process_config_spec(
    cls,
    process: str | None = None,
    mc_datasets: Sequence[str] | None = None,
) -> DotDict:
    """
    Builds the dictionary holding the config specific settings of a process.

    :param process: The name of the process in the config to use.
    :param mc_datasets: List of names or patterns of datasets in the config to use for mc.
    :returns: A dictionary representing process specific settings.
    """
    # a falsy process is normalized to None, a missing dataset sequence to an empty list
    process_name = str(process) if process else None
    dataset_names = [str(dataset) for dataset in (mc_datasets or [])]

    return DotDict([
        ("process", process_name),
        ("mc_datasets", dataset_names),
    ])
[docs]
@classmethod
def parameter_config_spec(
    cls,
    shift_source: str | None = None,
) -> DotDict:
    """
    Builds the dictionary holding the config specific settings of a parameter.

    :param shift_source: The name of a systematic shift source in the config.
    :returns: A dictionary representing parameter specific settings.
    """
    # a falsy shift source is normalized to None
    source_name = None
    if shift_source:
        source_name = str(shift_source)

    return DotDict([
        ("shift_source", source_name),
    ])
def __init__(self, config_insts: list[od.Config] | None = None) -> None:
    """
    Sets up an empty model and runs the optional ``init_func`` hook.

    :param config_insts: The :py:class:`order.Config` objects this model refers to.
    """
    super().__init__()

    # referenced config objects, a falsy value results in a new, empty list
    self.config_insts = config_insts if config_insts else []

    # backing field of the deprecated config_inst property, kept for as long as deprecation warnings are issued
    self.__config_inst = None

    # start from the empty top-level structure
    self.model = self.model_spec()

    # invoke the custom init function when one is set
    hook = self.init_func
    if callable(hook):
        hook()
[docs]
def to_yaml(self, stream: TextIO | None = None) -> str | None:
    """
    Serializes the content of the :py:attr:`model` to yaml using :py:class:`YamlDumper`. The output is written into
    the file-like object *stream* when given, and returned as a string otherwise.

    :param stream: A file-like object to write the model content into.
    :returns: A string representation of the model content if *stream* is not provided.
    """
    dumper_cls = self.YamlDumper
    return yaml.dump(self.model, stream=stream, Dumper=dumper_cls)
[docs]
def pprint(self) -> None:
    """
    Prints the yaml representation of the :py:attr:`model` as created by :py:meth:`to_yaml`.
    """
    yaml_text = self.to_yaml()
    print(yaml_text)
#
# property access to top-level objects
#
# !! to be removed in a future release
@property
def config_inst(self) -> od.Config:
    """
    Deprecated access to a single config object. The value is only returned if a truthy one was assigned through
    the setter before. Otherwise, a warning pointing to ``config_insts`` is logged once and an *AttributeError* is
    raised.

    :raises AttributeError: If no config object was assigned before.
    """
    if not self.__config_inst:
        # trigger a verbose warning in case the deprecated attribute is accessed
        docs_url = get_docs_url("user_guide", "02_03_transition.html")
        api_url = get_docs_url("api", "inference", "index.html", anchor="columnflow.inference.InferenceModel")
        logger.warning_once(
            "inference_model_deprected_config_inst",
            "access to attribute 'config_inst' in inference model was removed; use 'config_insts' instead; also, make "
            "sure to use the new 'config_data' attribute in 'add_category()' for a more fine-grained control over the "
            f"composition of your inference model categories; see {docs_url} and {api_url} for more info",
        )

        # behave as if the attribute did not exist
        raise AttributeError(f"'{self.__class__.__name__}' object has no attribute 'config_inst'")

    return self.__config_inst

@config_inst.setter
def config_inst(self, config_inst: od.Config) -> None:
    # store the value in the backing field read by the getter
    self.__config_inst = config_inst
@property
def categories(self) -> list[DotDict]:
    """
    The list of category objects stored under the *categories* key of the :py:attr:`model`. The list itself is
    returned, not a copy, so changes made to it affect the model.
    """
    return self.model.categories
@property
def parameter_groups(self) -> list[DotDict]:
    """
    The list of parameter group objects stored under the *parameter_groups* key of the :py:attr:`model`. The list
    itself is returned, not a copy, so changes made to it affect the model.
    """
    return self.model.parameter_groups
#
# handling of categories
#
[docs]
def get_categories(
    self,
    category: str | Sequence[str] | None = None,
    only_names: bool = False,
    match_mode: Callable = any,
) -> list[DotDict | str]:
    """
    Collects all categories whose names match *category*, which can be a string, a pattern, or a sequence of them.
    A missing or empty *category* selects all categories. When *only_names* is *True*, the names of the categories
    are returned instead of the structured dictionaries.

    :param category: A string, pattern, or sequence of them to match category names.
    :param only_names: A boolean flag to return only names of categories if set to *True*.
    :param match_mode: Either ``any`` or ``all`` to control the category matching behavior (see
        :py:func:`pattern_matcher`).
    :returns: A list of matching categories or their names.
    """
    # fall back to the catch-all pattern when nothing is requested
    is_match = pattern_matcher(category or "*", mode=match_mode)

    selected = []
    for cat in self.categories:
        if not is_match(cat.name):
            continue
        selected.append(cat.name if only_names else cat)

    return selected
[docs]
def get_category(
    self,
    category: str | Sequence[str],
    only_name: bool = False,
    match_mode: Callable = any,
    silent: bool = False,
) -> DotDict | str | None:
    """
    Returns a single category whose name matches *category*. *category* can be a string, a
    pattern, or sequence of them. An exception is raised if no or more than one category is
    found, unless *silent* is *True* in which case *None* is returned. When *only_name* is
    *True*, only the name of the category is returned rather than a structured dictionary.

    :param category: A string, pattern, or sequence of them to match category names.
    :param only_name: A boolean flag to return only the name of the category if set to *True*.
    :param match_mode: Either ``any`` or ``all`` to control the category matching behavior (see
        :py:func:`pattern_matcher`).
    :param silent: A boolean flag to return *None* instead of raising an exception if no or more than one category
        is found.
    :returns: A single matching category or its name, or *None* when the match is not unique and *silent* is
        *True*.
    :raises ValueError: If no category or more than one category is found and *silent* is *False*.
    """
    matches = self.get_categories(category, only_names=only_name, match_mode=match_mode)

    # the only successful outcome is exactly one match
    if len(matches) == 1:
        return matches[0]

    # both failure modes are reported the same way in silent mode
    if silent:
        return None

    if not matches:
        raise ValueError(f"no category named '{category}' found")
    raise ValueError(f"category name '{category}' matches more than one category")
[docs]
def has_category(
    self,
    category: str | Sequence[str],
    match_mode: Callable = any,
) -> bool:
    """
    Checks whether at least one category with a name matching *category* exists. *category* can be a string, a
    pattern, or sequence of them.

    :param category: A string, pattern, or sequence of them to match category names.
    :param match_mode: Either ``any`` or ``all`` to control the category matching behavior (see
        :py:func:`pattern_matcher`).
    :returns: *True* if a matching category exists, *False* otherwise.
    """
    # names are sufficient for an existence check
    matching_names = self.get_categories(category, only_names=True, match_mode=match_mode)
    return bool(matching_names)
[docs]
def add_category(self, *args, **kwargs) -> None:
    """
    Creates a category via :py:meth:`category_spec`, forwarding all *args* and *kwargs*, and appends it to the
    categories of the model. Names of categories must be unique.

    :raises ValueError: If a category with the same name already exists.
    """
    new_category = self.category_spec(*args, **kwargs)

    # register the category only if its name is not taken yet
    if not self.has_category(new_category.name):
        self.categories.append(new_category)
        return

    raise ValueError(f"category named '{new_category.name}' already registered")
[docs]
def remove_category(
    self,
    category: str | Sequence[str],
    match_mode: Callable = any,
) -> bool:
    """
    Drops all categories whose names match *category* from the model. The list of categories is updated in place.

    :param category: A string, pattern, or sequence of them to match category names.
    :param match_mode: Either ``any`` or ``all`` to control the category matching behavior (see
        :py:func:`pattern_matcher`).
    :returns: *True* if at least one category was removed, *False* otherwise.
    """
    is_match = pattern_matcher(category, mode=match_mode)

    # determine the categories to keep and compare their number to the current one
    kept = [cat for cat in self.categories if not is_match(cat.name)]
    removed_any = len(kept) != len(self.categories)

    # replace the content of the existing list rather than the list object
    self.categories[:] = kept

    return removed_any
#
# handling of processes
#
[docs]
def get_processes(
    self,
    process: str | Sequence[str] | None = None,
    category: str | Sequence[str] | None = None,
    only_names: bool = False,
    flat: bool = False,
    match_mode: Callable = any,
    category_match_mode: Callable = any,
) -> dict[str, DotDict | str] | list[str]:
    """
    Collects processes whose names match *process* and returns them in a dictionary that maps category names to
    lists of matching processes. Categories without any matching process are left out. The categories to look into
    can be restricted through *category*. Both *process* and *category* can be a string, a pattern, or sequence of
    them.

    When *only_names* is *True*, names of processes are returned instead of structured dictionaries. When *flat* is
    *True*, *only_names* is implied and a flat, unique list of process names is returned.

    :param process: A string, pattern, or sequence of them to match process names.
    :param category: A string, pattern, or sequence of them to filter categories.
    :param only_names: A boolean flag to return only names of processes if set to *True*.
    :param flat: A boolean flag to return a flat, unique list of process names if set to *True*.
    :param match_mode: Either ``any`` or ``all`` to control the process matching behavior (see
        :py:func:`pattern_matcher`).
    :param category_match_mode: Either ``any`` or ``all`` to control the category matching behavior (see
        :py:func:`pattern_matcher`).
    :returns: A dictionary of processes mapped to the category name, or a list of process names.
    """
    # a flat list can only consist of names
    only_names = only_names or flat

    # fall back to the catch-all pattern when no process is requested
    is_match = pattern_matcher(process or "*", mode=match_mode)

    def _iter_pairs():
        # yields (category name, matching processes) pairs, skipping categories without matches
        for cat in self.get_categories(category, match_mode=category_match_mode):
            selected = [
                (proc.name if only_names else proc)
                for proc in cat.processes
                if is_match(proc.name)
            ]
            if selected:
                yield (cat.name, selected)

    processes = DotDict.wrap(_iter_pairs())

    # merge the per-category lists into a single list of unique names
    if flat:
        processes = law.util.make_unique([
            proc_name
            for proc_names in processes.values()
            for proc_name in proc_names
        ])

    return processes
[docs]
def get_process(
    self,
    process: str | Sequence[str],
    category: str | Sequence[str] | None = None,
    only_name: bool = False,
    match_mode: Callable = any,
    category_match_mode: Callable = any,
    silent: bool = False,
) -> DotDict | str | None:
    """
    Returns a single process whose name matches *process*, and optionally, whose category's
    name matches *category*. Both *process* and *category* can be a string, a pattern, or
    sequence of them.

    An exception is raised if no or more than one process is found, unless *silent* is *True*
    in which case *None* is returned. When *only_name* is *True*, only the name of the
    process is returned rather than a structured dictionary.

    :param process: A string, pattern, or sequence of them to match process names.
    :param category: A string, pattern, or sequence of them to match category names.
    :param only_name: A boolean flag to return only the name of the process if set to *True*.
    :param match_mode: Either ``any`` or ``all`` to control the process matching behavior (see
        :py:func:`pattern_matcher`).
    :param category_match_mode: Either ``any`` or ``all`` to control the category matching behavior (see
        :py:func:`pattern_matcher`).
    :param silent: A boolean flag to return *None* instead of raising an exception if no or more than one process is
        found.
    :returns: A single matching process or its name, or *None* when the match is not unique and *silent* is *True*.
    :raises ValueError: If no process or more than one process is found and *silent* is *False*.
    """
    # matching processes, mapped to the names of the categories they were found in
    found = self.get_processes(
        process,
        category=category,
        only_names=only_name,
        match_mode=match_mode,
        category_match_mode=category_match_mode,
    )

    # no match at all
    if not found:
        if silent:
            return None
        err = f"no process named '{process}' found"
        if category:
            err += f" in category '{category}'"
        raise ValueError(err)

    # matches spread over multiple categories
    if len(found) > 1:
        if silent:
            return None
        raise ValueError(
            f"process '{process}' found in more than one category: " + ",".join(found),
        )

    # exactly one category is left, take its entry without shadowing the category argument
    category_name, matches = next(iter(found.items()))

    # multiple matches within that category
    if len(matches) > 1:
        if silent:
            return None
        names = matches if only_name else [p.name for p in matches]
        raise ValueError(
            f"process '{process}' found more than once in category '{category_name}': " +
            ",".join(names),
        )

    return matches[0]
[docs]
def has_process(
    self,
    process: str | Sequence[str],
    category: str | Sequence[str] | None = None,
    match_mode: Callable = any,
    category_match_mode: Callable = any,
) -> bool:
    """
    Checks whether at least one process with a name matching *process* exists, optionally restricted to categories
    whose names match *category*. Both *process* and *category* can be a string, a pattern, or sequence of them.

    :param process: A string, pattern, or sequence of them to match process names.
    :param category: A string, pattern, or sequence of them to match category names.
    :param match_mode: Either ``any`` or ``all`` to control the process matching behavior (see
        :py:func:`pattern_matcher`).
    :param category_match_mode: Either ``any`` or ``all`` to control the category matching behavior (see
        :py:func:`pattern_matcher`).
    :returns: *True* if a matching process exists, *False* otherwise.
    """
    # names are sufficient for an existence check
    matching_names = self.get_processes(
        process,
        category=category,
        only_names=True,
        match_mode=match_mode,
        category_match_mode=category_match_mode,
    )
    return bool(matching_names)
[docs]
def add_process(
    self,
    *args,
    category: str | Sequence[str] | None = None,
    category_match_mode: Callable = any,
    silent: bool = False,
    **kwargs,
) -> None:
    """
    Creates a process via :py:meth:`process_spec`, forwarding all *args* and *kwargs*, and adds an independent copy
    of it to every category whose name matches *category*. *category* can be a string, a pattern, or sequence of
    them.

    Categories that already contain a process with the same name lead to an exception, unless *silent* is *True*
    in which case they are skipped. All categories are checked before the first copy is added.

    :param args: Positional arguments used to create the process.
    :param category: A string, pattern, or sequence of them to match category names.
    :param category_match_mode: Either ``any`` or ``all`` to control the category matching behavior (see
        :py:func:`pattern_matcher`).
    :param silent: A boolean flag to suppress exceptions if a process with the same name already exists.
    :param kwargs: Keyword arguments used to create the process.
    :raises ValueError: If a process with the same name already exists in one of the categories and *silent* is
        *False*.
    """
    new_process = self.process_spec(*args, **kwargs)

    # categories that are candidates for receiving the process
    candidates = self.get_categories(category, match_mode=category_match_mode)

    # select those that do not know the process yet
    targets = []
    for cat in candidates:
        if not self.has_process(new_process.name, category=cat.name):
            targets.append(cat)
        elif not silent:
            raise ValueError(
                f"process named '{new_process.name}' already registered in category " +
                f"'{cat.name}'",
            )

    # each category gets its own deep copy so that later changes do not propagate between them
    for cat in targets:
        cat.processes.append(_copy.deepcopy(new_process))
[docs]
def remove_process(
    self,
    process: str | Sequence[str],
    category: str | Sequence[str] | None = None,
    match_mode: Callable = any,
    category_match_mode: Callable = any,
) -> bool:
    """
    Drops all processes whose names match *process* from the categories whose names match *category*. Both
    *process* and *category* can be a string, a pattern, or sequence of them. The process lists of the categories
    are updated in place.

    :param process: A string, pattern, or sequence of them to match process names.
    :param category: A string, pattern, or sequence of them to match category names.
    :param match_mode: Either ``any`` or ``all`` to control the process matching behavior (see
        :py:func:`pattern_matcher`).
    :param category_match_mode: Either ``any`` or ``all`` to control the category matching behavior (see
        :py:func:`pattern_matcher`).
    :returns: *True* if at least one process was removed, *False* otherwise.
    """
    # categories to remove the process from
    selected_categories = self.get_categories(category, match_mode=category_match_mode)

    is_match = pattern_matcher(process, mode=match_mode)

    removed_any = False
    for cat in selected_categories:
        kept = [proc for proc in cat.processes if not is_match(proc.name)]
        if len(kept) != len(cat.processes):
            removed_any = True
        # replace the content of the existing list rather than the list object
        cat.processes[:] = kept

    return removed_any
#
# handling of parameters
#
[docs]
def get_parameters(
    self,
    parameter: str | Sequence[str] | None = None,
    process: str | Sequence[str] | None = None,
    category: str | Sequence[str] | None = None,
    match_mode: Callable = any,
    category_match_mode: Callable = any,
    process_match_mode: Callable = any,
    only_names: bool = False,
    flat: bool = False,
) -> dict[str, dict[str, DotDict | str]] | list[str]:
    """
    Collects parameters whose names match *parameter* and returns them in a nested dictionary that maps category
    names to process names to lists of matching parameters. Processes and categories without any matching
    parameter are left out. The categories and processes to look into can be restricted through *category* and
    *process*. All three, *parameter*, *process* and *category* can be a string, a pattern, or sequence of them.

    When *only_names* is *True*, names of parameters are returned instead of structured dictionaries. When *flat*
    is *True*, *only_names* is implied and a flat, unique list of parameter names is returned.

    :param parameter: A string, pattern, or sequence of them to match parameter names.
    :param process: A string, pattern, or sequence of them to match process names.
    :param category: A string, pattern, or sequence of them to match category names.
    :param match_mode: Either ``any`` or ``all`` to control the parameter matching behavior (see
        :py:func:`pattern_matcher`).
    :param category_match_mode: Either ``any`` or ``all`` to control the category matching behavior (see
        :py:func:`pattern_matcher`).
    :param process_match_mode: Either ``any`` or ``all`` to control the process matching behavior (see
        :py:func:`pattern_matcher`).
    :param only_names: A boolean flag to return only names of parameters if set to *True*.
    :param flat: A boolean flag to return a flat, unique list of parameter names if set to *True*.
    :returns: A dictionary of parameters mapped to category and process names, or a list of parameter names.
    """
    # a flat list can only consist of names
    only_names = only_names or flat

    # matching processes, mapped to the names of matching categories
    processes_per_category = self.get_processes(
        process=process,
        category=category,
        match_mode=process_match_mode,
        category_match_mode=category_match_mode,
    )

    # fall back to the catch-all pattern when no parameter is requested
    is_match = pattern_matcher(parameter or "*", mode=match_mode)

    def _iter_pairs(procs):
        # yields (process name, matching parameters) pairs, skipping processes without matches
        for proc in procs:
            selected = [
                (param.name if only_names else param)
                for param in proc.parameters
                if is_match(param.name)
            ]
            if selected:
                yield (proc.name, selected)

    parameters = DotDict()
    for category_name, procs in processes_per_category.items():
        per_process = DotDict.wrap(_iter_pairs(procs))
        # only store categories with at least one entry
        if per_process:
            parameters[category_name] = per_process

    # merge the nested lists into a single list of unique names
    if flat:
        parameters = law.util.make_unique([
            param_name
            for per_process in parameters.values()
            for param_names in per_process.values()
            for param_name in param_names
        ])

    return parameters
[docs]
def get_parameter(
    self,
    parameter: str | Sequence[str],
    process: str | Sequence[str] | None = None,
    category: str | Sequence[str] | None = None,
    match_mode: Callable = any,
    category_match_mode: Callable = any,
    process_match_mode: Callable = any,
    only_name: bool = False,
    silent: bool = False,
) -> DotDict | str:
    """
    Looks up exactly one parameter via :py:meth:`get_parameters` and returns it.

    The lookup result is a nested mapping (category name -> process name -> list of parameters). The parameter is
    considered unique only if that mapping contains a single category, holding a single process, holding a single
    parameter. In any other case (nothing found, or an ambiguity on one of the three levels) a *ValueError* is
    raised, unless *silent* is *True*, in which case *None* is returned instead.

    :param parameter: A string, pattern, or sequence of them to match parameter names.
    :param process: A string, pattern, or sequence of them to match process names.
    :param category: A string, pattern, or sequence of them to match category names.
    :param match_mode: Either ``any`` or ``all`` to control the parameter matching behavior (see
        :py:func:`pattern_matcher`).
    :param category_match_mode: Either ``any`` or ``all`` to control the category matching behavior (see
        :py:func:`pattern_matcher`).
    :param process_match_mode: Either ``any`` or ``all`` to control the process matching behavior (see
        :py:func:`pattern_matcher`).
    :param only_name: When *True*, the name of the parameter is returned instead of the parameter object.
    :param silent: When *True*, *None* is returned instead of raising an exception.
    :returns: The single matching parameter or its name, or *None* in silent mode when the match is not unique.
    :raises ValueError: If no parameter or more than one parameter is found and *silent* is *False*.
    """
    per_category = self.get_parameters(
        parameter,
        process=process,
        category=category,
        match_mode=match_mode,
        category_match_mode=category_match_mode,
        process_match_mode=process_match_mode,
        only_names=only_name,
    )

    # nothing matched at all
    if not per_category:
        if silent:
            return None
        msg = f"no parameter named '{parameter}' found"
        if process:
            msg += f" for process '{process}'"
        if category:
            msg += f" in category '{category}'"
        raise ValueError(msg)

    # level 1: the match must be confined to a single category
    if len(per_category) > 1:
        if silent:
            return None
        raise ValueError(
            f"parameter '{parameter}' found in more than one category: " +
            ",".join(per_category),
        )
    category_name, per_process = next(iter(per_category.items()))

    # level 2: within that category, the match must be confined to a single process
    if len(per_process) > 1:
        if silent:
            return None
        raise ValueError(
            f"parameter '{parameter}' found in more than one process in category " +
            f"'{category_name}': " + ",".join(per_process),
        )
    process_name, found = next(iter(per_process.items()))

    # level 3: within that process, exactly one parameter must have matched
    if len(found) > 1:
        if silent:
            return None
        # entries are already plain names when only_name is set
        found_names = found if only_name else [p.name for p in found]
        raise ValueError(
            f"parameter '{parameter}' found more than once in category '{category_name}' and " +
            f"process '{process_name}': " + ",".join(found_names),
        )

    return found[0]
[docs]
def has_parameter(
    self,
    parameter: str | Sequence[str],
    process: str | Sequence[str] | None = None,
    category: str | Sequence[str] | None = None,
    match_mode: Callable = any,
    category_match_mode: Callable = any,
    process_match_mode: Callable = any,
) -> bool:
    """
    Tells whether at least one parameter matches *parameter*, optionally restricted to processes and categories
    whose names match *process* and *category*. The check forwards all arguments to :py:meth:`get_parameters` and
    tests whether the result is non-empty.

    :param parameter: A string, pattern, or sequence of them to match parameter names.
    :param process: A string, pattern, or sequence of them to match process names.
    :param category: A string, pattern, or sequence of them to match category names.
    :param match_mode: Either ``any`` or ``all`` to control the parameter matching behavior (see
        :py:func:`pattern_matcher`).
    :param category_match_mode: Either ``any`` or ``all`` to control the category matching behavior (see
        :py:func:`pattern_matcher`).
    :param process_match_mode: Either ``any`` or ``all`` to control the process matching behavior (see
        :py:func:`pattern_matcher`).
    :returns: *True* if a matching parameter exists, *False* otherwise.
    """
    matches = self.get_parameters(
        parameter,
        process=process,
        category=category,
        match_mode=match_mode,
        category_match_mode=category_match_mode,
        process_match_mode=process_match_mode,
    )

    # the lookup only contains categories that hold matching parameters
    return len(matches) > 0
[docs]
def add_parameter(
    self,
    *args,
    process: str | Sequence[str] | None = None,
    category: str | Sequence[str] | None = None,
    category_match_mode: Callable = any,
    process_match_mode: Callable = any,
    group: str | Sequence[str] | None = None,
    **kwargs,
) -> DotDict:
    """
    Creates a parameter via :py:meth:`parameter_spec` from *args* and *kwargs* and attaches an independent deep
    copy of it to every process selected by *process* and *category* (looked up with :py:meth:`get_processes`).

    All selected processes are checked for an existing parameter of the same name before anything is attached, so
    a failing call leaves the model unmodified. When *group* is given and at least one process was selected, the
    parameter name is additionally registered through :py:meth:`add_parameter_to_group`.

    :param args: Positional arguments used to create the parameter.
    :param process: A string, pattern, or sequence of them to match process names.
    :param category: A string, pattern, or sequence of them to match category names.
    :param category_match_mode: Either ``any`` or ``all`` to control the category matching behavior (see
        :py:func:`pattern_matcher`).
    :param process_match_mode: Either ``any`` or ``all`` to control the process matching behavior (see
        :py:func:`pattern_matcher`).
    :param group: A string, pattern, or sequence of them to specify parameter groups.
    :param kwargs: Keyword arguments used to create the parameter.
    :returns: The created parameter (the template object, not one of the copies attached to processes).
    :raises ValueError: If a parameter with the same name already exists in one of the selected processes.
    """
    spec = self.parameter_spec(*args, **kwargs)

    # processes (grouped by category name) that should receive the parameter
    targets = self.get_processes(
        process=process,
        category=category,
        match_mode=process_match_mode,
        category_match_mode=category_match_mode,
    )

    # first pass: refuse to proceed if the name is already taken anywhere
    for category_name, procs in targets.items():
        for proc in procs:
            if not self.has_parameter(spec.name, process=proc.name, category=category_name):
                continue
            raise ValueError(
                f"parameter named '{spec.name}' already registered for process " +
                f"'{proc.name}' in category '{category_name}'",
            )

    # second pass: attach a separate copy per process so later edits do not leak between processes
    for procs in targets.values():
        for proc in procs:
            proc.parameters.append(_copy.deepcopy(spec))

    # register in groups only when the parameter ended up in at least one process
    if group and targets and any(targets.values()):
        self.add_parameter_to_group(spec.name, group)

    return spec
[docs]
def remove_parameter(
    self,
    parameter: str | Sequence[str],
    process: str | Sequence[str] | None = None,
    category: str | Sequence[str] | None = None,
    match_mode: Callable = any,
    category_match_mode: Callable = any,
    process_match_mode: Callable = any,
) -> bool:
    """
    Drops every parameter whose name matches *parameter* from all processes selected by *process* and *category*
    (looked up with :py:meth:`get_processes`). The parameter lists of the processes are modified in place.

    :param parameter: A string, pattern, or sequence of them to match parameter names.
    :param process: A string, pattern, or sequence of them to match process names.
    :param category: A string, pattern, or sequence of them to match category names.
    :param match_mode: Either ``any`` or ``all`` to control the parameter matching behavior (see
        :py:func:`pattern_matcher`).
    :param category_match_mode: Either ``any`` or ``all`` to control the category matching behavior (see
        :py:func:`pattern_matcher`).
    :param process_match_mode: Either ``any`` or ``all`` to control the process matching behavior (see
        :py:func:`pattern_matcher`).
    :returns: *True* if at least one parameter was removed, *False* otherwise.
    """
    # processes (grouped by category name) to strip the parameter from
    targets = self.get_processes(
        process=process,
        category=category,
        match_mode=process_match_mode,
        category_match_mode=category_match_mode,
    )

    is_match = pattern_matcher(parameter, mode=match_mode)

    removed_any = False
    for procs in targets.values():
        for proc in procs:
            kept = [p for p in proc.parameters if not is_match(p.name)]
            if len(kept) != len(proc.parameters):
                removed_any = True
            # slice assignment keeps the identity of the list object held by the process
            proc.parameters[:] = kept

    return removed_any
#
# handling of parameter groups
#
[docs]
def get_parameter_groups(
    self,
    group: str | Sequence[str] | None = None,
    match_mode: Callable = any,
    only_names: bool = False,
) -> list[DotDict | str]:
    """
    Collects all entries of ``self.parameter_groups`` whose names match *group*, in their stored order. When
    *group* is empty or *None*, the wildcard ``"*"`` is used as the pattern.

    :param group: A string, pattern, or sequence of them to match group names.
    :param match_mode: Either ``any`` or ``all`` to control the parameter group matching behavior (see
        :py:func:`pattern_matcher`).
    :param only_names: When *True*, group names are returned instead of the group objects.
    :returns: A list of parameter groups or their names.
    """
    is_match = pattern_matcher(group or "*", mode=match_mode)

    selected = []
    for grp in self.parameter_groups:
        if is_match(grp.name):
            selected.append(grp.name if only_names else grp)

    return selected
[docs]
def get_parameter_group(
    self,
    group: str | Sequence[str],
    match_mode: Callable = any,
    only_name: bool = False,
) -> DotDict | str:
    """
    Returns the single parameter group whose name matches *group*, looked up through
    :py:meth:`get_parameter_groups`. *group* can be a string, a pattern, or sequence of them.

    An exception is raised in case no or more than one parameter group is found. When *only_name* is *True*, only
    the name of the parameter group is returned rather than a structured dictionary.

    :param group: A string, pattern, or sequence of them to match group names.
    :param match_mode: Either ``any`` or ``all`` to control the parameter group matching behavior (see
        :py:func:`pattern_matcher`).
    :param only_name: A boolean flag to return only the name of the parameter group if set to *True*.
    :returns: A single matching parameter group or its name.
    :raises ValueError: If no parameter group or more than one parameter group is found.
    """
    # rename arguments to make their meaning explicit
    group_name = group

    groups = self.get_parameter_groups(group_name, match_mode=match_mode, only_names=only_name)

    # the match has to be unique
    if not groups:
        raise ValueError(f"no parameter group named '{group_name}' found")
    if len(groups) > 1:
        # this lookup is about parameter groups, so the message must not refer to categories
        raise ValueError(f"parameter group name '{group_name}' matches more than one parameter group")

    return groups[0]
[docs]
def has_parameter_group(
    self,
    group: str | Sequence[str],
    match_mode: Callable = any,
) -> bool:
    """
    Tells whether at least one parameter group has a name matching *group*. The check is delegated to
    :py:meth:`get_parameter_groups` and tests whether its result is non-empty.

    :param group: A string, pattern, or sequence of them to match group names.
    :param match_mode: Either ``any`` or ``all`` to control the parameter group matching behavior (see
        :py:func:`pattern_matcher`).
    :returns: *True* if a matching parameter group exists, *False* otherwise.
    """
    matching_groups = self.get_parameter_groups(group, match_mode=match_mode)

    # any hit is sufficient
    return len(matching_groups) > 0
[docs]
def add_parameter_group(self, *args, **kwargs) -> None:
    """
    Builds a parameter group via :py:meth:`parameter_group_spec` from *args* and *kwargs* and appends it to
    ``self.parameter_groups``. The name of the new group is first checked with :py:meth:`has_parameter_group`,
    and nothing is appended when that check reports an existing group.

    :param args: Positional arguments used to create the parameter group.
    :param kwargs: Keyword arguments used to create the parameter group.
    :raises ValueError: If a parameter group with the same name already exists.
    """
    new_group = self.parameter_group_spec(*args, **kwargs)

    # group names must stay unique
    name_taken = self.has_parameter_group(new_group.name)
    if name_taken:
        raise ValueError(f"parameter group named '{new_group.name}' already registered")

    self.parameter_groups.append(new_group)
[docs]
def remove_parameter_group(
    self,
    group: str | Sequence[str],
    match_mode: Callable = any,
) -> bool:
    """
    Drops every parameter group whose name matches *group* from ``self.parameter_groups``. The list is modified in
    place. *group* can be a string, a pattern, or sequence of them.

    :param group: A string, pattern, or sequence of them to match group names.
    :param match_mode: Either ``any`` or ``all`` to control the parameter group matching behavior (see
        :py:func:`pattern_matcher`).
    :returns: *True* if at least one group was removed, *False* otherwise.
    """
    is_match = pattern_matcher(group, mode=match_mode)

    survivors = [grp for grp in self.parameter_groups if not is_match(grp.name)]
    removed_any = len(survivors) != len(self.parameter_groups)

    # slice assignment keeps the identity of the list object
    self.parameter_groups[:] = survivors

    return removed_any
[docs]
def add_parameter_to_group(
    self,
    parameter: str | Sequence[str],
    group: str | Sequence[str],
    match_mode: Callable = any,
    parameter_match_mode: Callable = any,
) -> bool:
    """
    Registers parameter names in all parameter groups whose names match *group*.

    When any entry of *parameter* is a pattern or regular expression, the names to register are the existing
    parameter names matching *parameter*, obtained through :py:meth:`get_parameters` with ``flat=True``.
    Otherwise the given names are registered as they are, without checking that such parameters exist. Names
    already present in a group are not added again.

    :param parameter: A string, pattern, or sequence of them to match parameter names.
    :param group: A string, pattern, or sequence of them to match group names.
    :param match_mode: Either ``any`` or ``all`` to control the parameter group matching behavior (see
        :py:func:`pattern_matcher`).
    :param parameter_match_mode: Either ``any`` or ``all`` to control the parameter matching behavior (see
        :py:func:`pattern_matcher`).
    :returns: *True* if at least one parameter was added to a group, *False* otherwise.
    """
    # without matching groups there is nothing to do
    target_groups = self.get_parameter_groups(group, match_mode=match_mode)
    if not target_groups:
        return False

    def needs_expansion(value) -> bool:
        return is_pattern(value) or is_regex(value)

    # resolve patterns against existing parameters, take plain names literally
    requested = law.util.make_list(parameter)
    if any(needs_expansion(value) for value in requested):
        names = self.get_parameters(requested, flat=True, match_mode=parameter_match_mode)
    else:
        names = requested

    added_any = False
    for grp in target_groups:
        for name in names:
            if name in grp.parameter_names:
                continue
            grp.parameter_names.append(name)
            added_any = True

    return added_any
[docs]
def remove_parameter_from_groups(
    self,
    parameter: str | Sequence[str],
    group: str | Sequence[str] | None = None,
    match_mode: Callable = any,
    parameter_match_mode: Callable = any,
) -> bool:
    """
    Drops all parameter names matching *parameter* from the parameter groups whose names match *group* (looked up
    with :py:meth:`get_parameter_groups`). The name lists of the groups are modified in place. Both *parameter*
    and *group* can be a string, a pattern, or sequence of them.

    :param parameter: A string, pattern, or sequence of them to match parameter names.
    :param group: A string, pattern, or sequence of them to match group names.
    :param match_mode: Either ``any`` or ``all`` to control the parameter group matching behavior (see
        :py:func:`pattern_matcher`).
    :param parameter_match_mode: Either ``any`` or ``all`` to control the parameter matching behavior (see
        :py:func:`pattern_matcher`).
    :returns: *True* if at least one parameter was removed, *False* otherwise.
    """
    # without matching groups there is nothing to do
    target_groups = self.get_parameter_groups(group, match_mode=match_mode)
    if not target_groups:
        return False

    is_match = pattern_matcher(parameter, mode=parameter_match_mode)

    removed_any = False
    for grp in target_groups:
        kept = [name for name in grp.parameter_names if not is_match(name)]
        if len(kept) != len(grp.parameter_names):
            removed_any = True
        # slice assignment keeps the identity of the list object held by the group
        grp.parameter_names[:] = kept

    return removed_any
#
# reverse lookups
#
[docs]
def get_categories_with_process(
    self,
    process: str | Sequence[str],
    match_mode: Callable = any,
) -> list[str]:
    """
    Lists the category names under which :py:meth:`get_processes` reports processes matching *process*. *process*
    can be a string, a pattern, or sequence of them.

    :param process: A string, pattern, or sequence of them to match process names.
    :param match_mode: Either ``any`` or ``all`` to control the process matching behavior (see
        :py:func:`pattern_matcher`).
    :returns: A list of category names containing matching processes.
    """
    # the lookup is keyed by category name, and only those keys are of interest here
    per_category = self.get_processes(process=process, match_mode=match_mode, only_names=True)

    return list(per_category)
[docs]
def get_processes_with_parameter(
    self,
    parameter: str | Sequence[str],
    category: str | Sequence[str] | None = None,
    match_mode: Callable = any,
    category_match_mode: Callable = any,
    flat: bool = True,
) -> list[str] | dict[str, list[str]]:
    """
    Finds the processes that carry a parameter whose name matches *parameter*, optionally restricted to
    categories matching *category*. The lookup is done through :py:meth:`get_parameters`.

    With *flat* being *False*, a dictionary mapping category names to lists of process names is returned. With
    *flat* being *True* (the default), the per-category lists are concatenated and passed through
    ``law.util.make_unique``.

    :param parameter: A string, pattern, or sequence of them to match parameter names.
    :param category: A string, pattern, or sequence of them to match category names.
    :param match_mode: Either ``any`` or ``all`` to control the parameter matching behavior (see
        :py:func:`pattern_matcher`).
    :param category_match_mode: Either ``any`` or ``all`` to control the category matching behavior (see
        :py:func:`pattern_matcher`).
    :param flat: A boolean flag to return a flat, unique list of process names if set to *True*.
    :returns: A dictionary of process names mapped to category names, or a flat list of process names.
    """
    found = self.get_parameters(
        parameter=parameter,
        category=category,
        match_mode=match_mode,
        category_match_mode=category_match_mode,
        only_names=True,
    )

    # keep only the process names per category, dropping the parameter lists
    names_by_category = DotDict(
        (category_name, list(per_process))
        for category_name, per_process in found.items()
    )
    if not flat:
        return names_by_category

    # concatenate across categories, then drop duplicates
    all_names = []
    for process_names in names_by_category.values():
        all_names = all_names + process_names
    return law.util.make_unique(all_names)
[docs]
def get_categories_with_parameter(
    self,
    parameter: str | Sequence[str],
    process: str | Sequence[str] | None = None,
    match_mode: Callable = any,
    process_match_mode: Callable = any,
    flat: bool = True,
) -> list[str] | dict[str, list[str]]:
    """
    Finds the categories that contain a parameter whose name matches *parameter*, optionally restricted to
    processes matching *process*. The lookup is done through :py:meth:`get_parameters`.

    With *flat* being *False*, a dictionary mapping category names to lists of process names is returned. With
    *flat* being *True* (the default), only the category names (the keys of that dictionary) are returned as a
    list.

    :param parameter: A string, pattern, or sequence of them to match parameter names.
    :param process: A string, pattern, or sequence of them to match process names.
    :param match_mode: Either ``any`` or ``all`` to control the parameter matching behavior (see
        :py:func:`pattern_matcher`).
    :param process_match_mode: Either ``any`` or ``all`` to control the process matching behavior (see
        :py:func:`pattern_matcher`).
    :param flat: A boolean flag to return a flat, unique list of category names if set to *True*.
    :returns: A dictionary of category names mapped to process names, or a flat list of category names.
    """
    found = self.get_parameters(
        parameter=parameter,
        process=process,
        match_mode=match_mode,
        process_match_mode=process_match_mode,
        only_names=True,
    )

    # keep only the process names per category, dropping the parameter lists
    names_by_category = DotDict(
        (category_name, list(per_process))
        for category_name, per_process in found.items()
    )

    # keys are unique already, so flattening just means listing them
    return list(names_by_category) if flat else names_by_category
[docs]
def get_groups_with_parameter(
    self,
    parameter: str | Sequence[str],
    match_mode: Callable = any,
) -> list[str]:
    """
    Lists the names of all parameter groups in ``self.parameter_groups`` that hold at least one parameter name
    matching *parameter*. *parameter* can be a string, a pattern, or sequence of them.

    :param parameter: A string, pattern, or sequence of them to match parameter names.
    :param match_mode: Either ``any`` or ``all`` to control the parameter matching behavior (see
        :py:func:`pattern_matcher`).
    :returns: A list of names of parameter groups containing the matching parameter.
    """
    is_match = pattern_matcher(parameter, mode=match_mode)

    group_names = []
    for grp in self.parameter_groups:
        # a single matching member qualifies the group
        if any(is_match(name) for name in grp.parameter_names):
            group_names.append(grp.name)

    return group_names
#
# removal of empty and dangling objects
#
[docs]
def cleanup(
    self,
    keep_parameters: str | Sequence[str] | None = None,
) -> None:
    """
    Tidies up the model in three consecutive steps: :py:meth:`remove_empty_categories`, then
    :py:meth:`remove_dangling_parameters_from_groups` (which receives *keep_parameters*), and finally
    :py:meth:`remove_empty_parameter_groups`.

    :param keep_parameters: A string, pattern, or sequence of them to specify parameters to keep.
    """
    # step 1: categories
    self.remove_empty_categories()

    # step 2: group members, which runs before step 3 so that groups emptied here are dropped there
    self.remove_dangling_parameters_from_groups(keep_parameters=keep_parameters)

    # step 3: groups
    self.remove_empty_parameter_groups()
[docs]
def remove_empty_categories(self) -> None:
    """
    Drops every category from ``self.categories`` whose ``processes`` attribute is empty. The list is modified in
    place.
    """
    populated = []
    for cat in self.categories:
        if cat.processes:
            populated.append(cat)

    # slice assignment keeps the identity of the list object
    self.categories[:] = populated
[docs]
def remove_dangling_parameters_from_groups(
    self,
    keep_parameters: str | Sequence[str] | None = None,
) -> None:
    """
    Removes names of parameters from parameter groups that are not assigned to any process in any category. Names
    matching one of the *keep_parameters* patterns are retained even when no such parameter is assigned. The name
    lists of the groups are modified in place.

    :param keep_parameters: A string, pattern, or sequence of them to specify parameters to keep.
    """
    # names of all parameters currently assigned to processes; a set makes the membership test below O(1)
    # instead of scanning a list once per group member
    assigned_names = set(self.get_parameters("*", flat=True))

    # patterns of names that must survive even when dangling
    keep_patterns = law.util.make_set(keep_parameters) if keep_parameters else set()

    # go through groups and remove dangling parameters
    for group in self.parameter_groups:
        group.parameter_names[:] = [
            parameter_name
            for parameter_name in group.parameter_names
            if (
                parameter_name in assigned_names or
                law.util.multi_match(parameter_name, keep_patterns, mode=any)
            )
        ]
[docs]
def remove_empty_parameter_groups(self) -> None:
    """
    Drops every parameter group from ``self.parameter_groups`` whose ``parameter_names`` attribute is empty. The
    list is modified in place.
    """
    populated = []
    for grp in self.parameter_groups:
        if grp.parameter_names:
            populated.append(grp)

    # slice assignment keeps the identity of the list object
    self.parameter_groups[:] = populated
#
# generators
#
[docs]
def iter_processes(
    self,
    process: str | Sequence[str] | None = None,
    category: str | Sequence[str] | None = None,
    match_mode: Callable = any,
    category_match_mode: Callable = any,
) -> Generator[tuple[str, DotDict], None, None]:
    """
    Generator that iteratively yields all processes whose names match *process*, optionally in all categories
    whose names match *category*. The lookup is done through :py:meth:`get_processes`. The yielded value is a
    2-tuple containing the category name and the process object.

    :param process: A string, pattern, or sequence of them to match process names.
    :param category: A string, pattern, or sequence of them to match category names.
    :param match_mode: Either ``any`` or ``all`` to control the process matching behavior (see
        :py:func:`pattern_matcher`).
    :param category_match_mode: Either ``any`` or ``all`` to control the category matching behavior (see
        :py:func:`pattern_matcher`).
    :returns: A generator yielding 2-tuples of category name and process object.
    """
    per_category = self.get_processes(
        process=process,
        category=category,
        match_mode=match_mode,
        category_match_mode=category_match_mode,
    )

    # distinct loop names so that neither the mapping being iterated nor the *process* argument is rebound
    for category_name, category_processes in per_category.items():
        for process_obj in category_processes:
            yield (category_name, process_obj)
[docs]
def iter_parameters(
    self,
    parameter: str | Sequence[str] | None = None,
    process: str | Sequence[str] | None = None,
    category: str | Sequence[str] | None = None,
    match_mode: Callable = any,
    category_match_mode: Callable = any,
    process_match_mode: Callable = any,
) -> Generator[tuple[str, str, DotDict], None, None]:
    """
    Generator that iteratively yields all parameters whose names match *parameter*, optionally in all processes
    and categories whose names match *process* and *category*. The lookup is done through
    :py:meth:`get_parameters`. The yielded value is a 3-tuple containing the category name, the process name, and
    the parameter object.

    :param parameter: A string, pattern, or sequence of them to match parameter names.
    :param process: A string, pattern, or sequence of them to match process names.
    :param category: A string, pattern, or sequence of them to match category names.
    :param match_mode: Either ``any`` or ``all`` to control the parameter matching behavior (see
        :py:func:`pattern_matcher`).
    :param category_match_mode: Either ``any`` or ``all`` to control the category matching behavior (see
        :py:func:`pattern_matcher`).
    :param process_match_mode: Either ``any`` or ``all`` to control the process matching behavior (see
        :py:func:`pattern_matcher`).
    :returns: A generator yielding 3-tuples of category name, process name, and parameter object.
    """
    per_category = self.get_parameters(
        parameter=parameter,
        process=process,
        category=category,
        match_mode=match_mode,
        category_match_mode=category_match_mode,
        process_match_mode=process_match_mode,
    )

    # one distinct name per nesting level instead of rebinding the same variable while it is being iterated
    for category_name, per_process in per_category.items():
        for process_name, process_parameters in per_process.items():
            for parameter_obj in process_parameters:
                yield (category_name, process_name, parameter_obj)
#
# other helpers
#
[docs]
def scale_process(
    self,
    scale: int | float,
    process: str | Sequence[str] | None = None,
    category: str | Sequence[str] | None = None,
    match_mode: Callable = any,
    category_match_mode: Callable = any,
) -> bool:
    """
    Assigns ``float(scale)`` to the ``scale`` attribute of every process yielded by :py:meth:`iter_processes` for
    the given *process* and *category* selection.

    :param scale: The scale value to set for the matching processes.
    :param process: A string, pattern, or sequence of them to match process names.
    :param category: A string, pattern, or sequence of them to match category names.
    :param match_mode: Either ``any`` or ``all`` to control the process matching behavior (see
        :py:func:`pattern_matcher`).
    :param category_match_mode: Either ``any`` or ``all`` to control the category matching behavior (see
        :py:func:`pattern_matcher`).
    :returns: *True* if at least one process was found and scaled, *False* otherwise.
    """
    selection = self.iter_processes(
        process=process,
        category=category,
        match_mode=match_mode,
        category_match_mode=category_match_mode,
    )

    scaled_any = False
    for _category_name, proc in selection:
        # the conversion happens per process, so *scale* is never converted when nothing matches
        proc.scale = float(scale)
        scaled_any = True

    return scaled_any
# shorthand: expose the ``inference_model`` attribute of :py:class:`InferenceModel` under a module-level name so
# that it can be used without going through the class
# NOTE(review): presumably this is the decorator / factory used to derive new inference models - confirm against
# the class definition, which is not visible here
inference_model = InferenceModel.inference_model