# coding: utf-8
"""
Lightweight mixin classes for tasks.
"""
from __future__ import annotations
import time
import itertools
from collections import Counter
import luigi
import law
import order as od
from columnflow.types import Sequence, Any, Iterable, Union
from columnflow.tasks.framework.base import AnalysisTask, ConfigTask, RESOLVE_DEFAULT
from columnflow.tasks.framework.parameters import SettingsParameter
from columnflow.calibration import Calibrator
from columnflow.selection import Selector
from columnflow.production import Producer
from columnflow.weight import WeightProducer
from columnflow.ml import MLModel
from columnflow.inference import InferenceModel
from columnflow.columnar_util import Route, ColumnCollection, ChunkedIOHandler
from columnflow.util import maybe_import, DotDict
# NOTE(review): maybe_import presumably tolerates a missing "awkward" installation
# instead of raising at import time - confirm in columnflow.util
ak = maybe_import("awkward")
# module-level logger, named after this module
logger = law.logger.get_logger(__name__)
class CalibratorMixin(ConfigTask):
    """
    Mixin to include a single :py:class:`~columnflow.calibration.Calibrator` into tasks.

    Inheriting from this mixin will give access to instantiate and access a
    :py:class:`~columnflow.calibration.Calibrator` instance with name *calibrator*,
    which is an input parameter for this task.
    """

    calibrator = luigi.Parameter(
        default=RESOLVE_DEFAULT,
        description="the name of the calibrator to be applied; default: value of the "
        "'default_calibrator' config",
    )
    # whitespace-normalized copy of the description, attached to the parameter object
    # NOTE(review): presumably only consumed by the documentation build - confirm
    calibrator.__annotations__ = " ".join("""
        the name of the calibrator to be applied; default: value of the
        'default_calibrator' config""".split())

    # decides whether the task itself runs the calibrator and implements its shifts
    register_calibrator_sandbox = False
    register_calibrator_shifts = False

    @classmethod
    def get_calibrator_inst(cls, calibrator: str, kwargs=None) -> Calibrator:
        """
        Initialize a :py:class:`~columnflow.calibration.Calibrator` instance.

        The calibrator class registered under the name *calibrator* is looked up with
        ``Calibrator.get_cls``. When *kwargs* is given and non-empty, it is passed
        through
        :py:meth:`~columnflow.tasks.framework.base.AnalysisTask.get_calibrator_kwargs`
        and the result is handed to the calibrator as ``inst_dict``; otherwise
        ``inst_dict`` is *None*.

        :param calibrator: Name of the calibrator class to instantiate.
        :param kwargs: Optional dict of keyword arguments that are potentially relevant
            for this :py:class:`~columnflow.calibration.Calibrator` instance.
        :raises RuntimeError: If the requested calibrator class is not
            :py:attr:`~columnflow.calibration.Calibrator.exposed`.
        :return: The initialized :py:class:`~columnflow.calibration.Calibrator` instance.
        """
        # get_cls returns the calibrator class, not an instance
        calibrator_cls: type[Calibrator] = Calibrator.get_cls(calibrator)
        if not calibrator_cls.exposed:
            raise RuntimeError(f"cannot use unexposed calibrator '{calibrator}' in {cls.__name__}")

        inst_dict = cls.get_calibrator_kwargs(**kwargs) if kwargs else None
        return calibrator_cls(inst_dict=inst_dict)

    @classmethod
    def resolve_param_values(cls, params: dict[str, Any]) -> dict[str, Any]:
        """
        Resolve parameter values *params* relevant for the
        :py:class:`CalibratorMixin` and all classes it inherits from.

        When *params* contains a ``"config_inst"``, the ``"calibrator"`` parameter is
        resolved against the ``"default_calibrator"`` setting of that config and the
        resulting :py:class:`~columnflow.calibration.Calibrator` instance, obtained via
        :py:meth:`~.CalibratorMixin.get_calibrator_inst`, is stored under
        ``"calibrator_inst"``. Without a config instance, *params* is returned as
        resolved by the super class.

        :param params: Dictionary with parameters provided by the user at
            commandline level.
        :return: Dictionary of parameters, potentially including new values for
            ``"calibrator"`` and ``"calibrator_inst"``.
        """
        params = super().resolve_param_values(params)

        config_inst = params.get("config_inst")
        if config_inst:
            # add the default calibrator when empty
            params["calibrator"] = cls.resolve_config_default(
                params,
                params.get("calibrator"),
                container=config_inst,
                default_str="default_calibrator",
                multiple=False,
            )
            params["calibrator_inst"] = cls.get_calibrator_inst(params["calibrator"], params)

        return params

    @classmethod
    def get_known_shifts(cls, config_inst: od.Config, params: dict[str, Any]) -> tuple[set[str], set[str]]:
        """
        Adds the set of shifts that the current ``calibrator_inst`` registers to the
        set of known ``shifts`` and ``upstream_shifts``.

        First, the sets of ``shifts`` and ``upstream_shifts`` are obtained from the
        ``get_known_shifts`` method of the super class. Afterwards, if *params* contains
        a ``"calibrator_inst"``, its ``all_shifts`` are added to ``shifts`` when
        :py:attr:`~CalibratorMixin.register_calibrator_shifts` is *True*, and to
        ``upstream_shifts`` otherwise.

        :param config_inst: Config instance for the current task.
        :param params: Dictionary containing the current set of parameters provided
            by the user at commandline level.
        :return: Tuple with updated sets of ``shifts`` and ``upstream_shifts``.
        """
        shifts, upstream_shifts = super().get_known_shifts(config_inst, params)

        # get the calibrator and add its shifts
        calibrator_inst = params.get("calibrator_inst")
        if calibrator_inst:
            if cls.register_calibrator_shifts:
                shifts |= calibrator_inst.all_shifts
            else:
                upstream_shifts |= calibrator_inst.all_shifts

        return shifts, upstream_shifts

    @classmethod
    def req_params(cls, inst: law.Task, **kwargs) -> dict[str, Any]:
        """
        Returns the required parameters for the task.

        Adds ``"calibrator"`` to the ``_prefer_cli`` set in *kwargs* so that a
        ``--calibrator`` value set on task-level via command line is preferred, and
        then defers to the super class.

        :param inst: The current task instance.
        :param kwargs: Additional keyword arguments.
        :return: Dictionary of required parameters.
        """
        # prefer --calibrator set on task-level via cli
        kwargs["_prefer_cli"] = law.util.make_set(kwargs.get("_prefer_cli", [])) | {"calibrator"}

        return super().req_params(inst, **kwargs)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # cache for calibrator inst
        self._calibrator_inst = None

    @property
    def calibrator_inst(self) -> Calibrator:
        """
        Access the current :py:class:`~columnflow.calibration.Calibrator` instance.

        The instance is created on first access and cached afterwards. Upon creation,
        if :py:attr:`register_calibrator_sandbox` is *True* and the calibrator reports
        a sandbox, that sandbox is set on the task and an already initialized sandbox
        is rebuilt.

        :return: Current :py:class:`~columnflow.calibration.Calibrator` instance.
        """
        if self._calibrator_inst is None:
            self._calibrator_inst = self.get_calibrator_inst(self.calibrator, {"task": self})

            # overwrite the sandbox when set
            if self.register_calibrator_sandbox:
                sandbox = self._calibrator_inst.get_sandbox()
                if sandbox:
                    self.sandbox = sandbox
                    # rebuild the sandbox inst when already initialized
                    if self._sandbox_initialized:
                        self._initialize_sandbox(force=True)

        return self._calibrator_inst

    @property
    def calibrator_repr(self) -> str:
        """
        Return a string representation of the calibrator.
        """
        return str(self.calibrator_inst)

    def store_parts(self) -> law.util.InsertableDict[str, str]:
        """
        Create parts to create the output path to store intermediary results
        for the current :py:class:`~law.task.base.Task`.

        This method calls :py:meth:`store_parts` of the ``super`` class and inserts
        `{"calibrator": "calib__{self.calibrator_repr}"}` before keyword ``version``.
        For more information, see e.g.
        :py:meth:`~columnflow.tasks.framework.base.ConfigTask.store_parts`.

        :return: Updated parts to create output path to store intermediary results.
        """
        parts = super().store_parts()
        parts.insert_before("version", "calibrator", f"calib__{self.calibrator_repr}")
        return parts

    def find_keep_columns(self, collection: ColumnCollection) -> set[Route]:
        """
        Finds the columns to keep based on the *collection*.

        If the collection is ``ALL_FROM_CALIBRATOR``, the columns produced by the
        calibrator are added to those found by the super class.

        :param collection: The collection of columns.
        :return: Set of columns to keep.
        """
        columns = super().find_keep_columns(collection)

        if collection == ColumnCollection.ALL_FROM_CALIBRATOR:
            columns |= self.calibrator_inst.produced_columns

        return columns

    @classmethod
    def get_config_lookup_keys(
        cls,
        inst_or_params: CalibratorMixin | dict[str, Any],
    ) -> law.util.InsertableDict:
        """
        Extends the config lookup keys of the super class by a ``"calibrator"`` entry.

        The calibrator name is read from *inst_or_params*, which can either be a
        dictionary of parameters or a task instance. The entry ``"calib_{calibrator}"``
        is only added when the name is neither ``law.NO_STR``, *None* nor empty.

        :param inst_or_params: Task instance or dictionary of parameters.
        :return: The lookup keys, potentially extended by the ``"calibrator"`` key.
        """
        keys = super().get_config_lookup_keys(inst_or_params)

        # uniform accessor for both dicts and task instances
        get = (
            inst_or_params.get
            if isinstance(inst_or_params, dict)
            else lambda attr: (getattr(inst_or_params, attr, None))
        )

        # add the calibrator name
        calibrator = get("calibrator")
        if calibrator not in {law.NO_STR, None, ""}:
            keys["calibrator"] = f"calib_{calibrator}"

        return keys
class CalibratorsMixin(ConfigTask):
    """
    Mixin to include multiple :py:class:`~columnflow.calibration.Calibrator` instances into tasks.

    Inheriting from this mixin will allow a task to instantiate and access a set of
    :py:class:`~columnflow.calibration.Calibrator` instances with names *calibrators*,
    which is a comma-separated list of calibrator names and is an input parameter for this task.
    """

    calibrators = law.CSVParameter(
        default=(RESOLVE_DEFAULT,),
        description="comma-separated names of calibrators to be applied; default: value of the "
        "'default_calibrator' config",
        brace_expand=True,
        parse_empty=True,
    )

    # decides whether the task itself runs the calibrators and implements their shifts
    register_calibrators_shifts = False

    @classmethod
    def get_calibrator_insts(cls, calibrators: Iterable[str], kwargs=None) -> list[Calibrator]:
        """
        Instantiate all requested *calibrators*.

        When *kwargs* is given and non-empty, it is passed through
        ``cls.get_calibrator_kwargs`` once and the result is handed to every
        calibrator as ``inst_dict``; otherwise ``inst_dict`` is *None*.

        :param calibrators: Names of calibrators to instantiate.
        :param kwargs: Optional dict of keyword arguments to forward to the individual
            :py:class:`~columnflow.calibration.Calibrator` instances.
        :raises RuntimeError: If a requested calibrator is not
            :py:attr:`~columnflow.calibration.Calibrator.exposed`.
        :return: List of :py:class:`~columnflow.calibration.Calibrator` instances, in
            the order of *calibrators*.
        """
        inst_dict = cls.get_calibrator_kwargs(**kwargs) if kwargs else None

        insts = []
        for calibrator in calibrators:
            calibrator_cls = Calibrator.get_cls(calibrator)
            if not calibrator_cls.exposed:
                raise RuntimeError(
                    f"cannot use unexposed calibrator '{calibrator}' in {cls.__name__}",
                )
            insts.append(calibrator_cls(inst_dict=inst_dict))

        return insts

    @classmethod
    def resolve_param_values(
        cls,
        params: law.util.InsertableDict[str, Any],
    ) -> law.util.InsertableDict[str, Any]:
        """
        Resolve values *params* and check against possible default values and
        calibrator groups.

        When *params* contains a ``"config_inst"``, the ``"calibrators"`` are resolved
        against the default value ``"default_calibrator"`` and the group definitions
        ``"calibrator_groups"`` of that config. For more information, see
        :py:meth:`~columnflow.tasks.framework.base.ConfigTask.resolve_config_default_and_groups`.

        :param params: Parameter values to resolve.
        :return: Dictionary of parameters that, when a config instance is present,
            contains the list of requested :py:class:`~columnflow.calibration.Calibrator`
            instances under the keyword ``"calibrator_insts"``. See
            :py:meth:`~.CalibratorsMixin.get_calibrator_insts` for more information.
        """
        params = super().resolve_param_values(params)

        config_inst = params.get("config_inst")
        if config_inst:
            params["calibrators"] = cls.resolve_config_default_and_groups(
                params,
                params.get("calibrators"),
                container=config_inst,
                default_str="default_calibrator",
                groups_str="calibrator_groups",
            )
            params["calibrator_insts"] = cls.get_calibrator_insts(params["calibrators"], params)

        return params

    @classmethod
    def get_known_shifts(
        cls,
        config_inst: od.Config,
        params: dict[str, Any],
    ) -> tuple[set[str], set[str]]:
        """
        Adds the set of all shifts that the ``calibrator_insts`` register to the
        set of known ``shifts`` and ``upstream_shifts``.

        First, the sets of ``shifts`` and ``upstream_shifts`` are obtained from the
        ``get_known_shifts`` method of the super class. Afterwards, for every entry of
        ``"calibrator_insts"`` in *params*, its ``all_shifts`` are added to ``shifts``
        when :py:attr:`~CalibratorsMixin.register_calibrators_shifts` is *True*, and to
        ``upstream_shifts`` otherwise.

        :param config_inst: Config instance for the current task.
        :param params: Dictionary containing the current set of parameters provided
            by the user at commandline level.
        :return: Tuple with updated sets of ``shifts`` and ``upstream_shifts``.
        """
        shifts, upstream_shifts = super().get_known_shifts(config_inst, params)

        # get the calibrators and add their shifts
        for calibrator_inst in params.get("calibrator_insts") or []:
            if cls.register_calibrators_shifts:
                shifts |= calibrator_inst.all_shifts
            else:
                upstream_shifts |= calibrator_inst.all_shifts

        return shifts, upstream_shifts

    @classmethod
    def req_params(cls, inst: law.Task, **kwargs) -> dict[str, Any]:
        """
        Returns the required parameters for the task.

        Adds ``"calibrators"`` to the ``_prefer_cli`` set in *kwargs* so that a
        ``--calibrators`` value set on task-level via command line is preferred, and
        then defers to the super class.

        :param inst: The current task instance.
        :param kwargs: Additional keyword arguments.
        :return: Dictionary of required parameters.
        """
        # prefer --calibrators set on task-level via cli
        kwargs["_prefer_cli"] = law.util.make_set(kwargs.get("_prefer_cli", [])) | {"calibrators"}

        return super().req_params(inst, **kwargs)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # cache for calibrator insts
        self._calibrator_insts = None

    @property
    def calibrator_insts(self) -> list[Calibrator]:
        """
        Access the current list of :py:class:`~columnflow.calibration.Calibrator` instances.

        The instances are created on first access and cached afterwards.

        :return: Current list of :py:class:`~columnflow.calibration.Calibrator` instances.
        """
        if self._calibrator_insts is None:
            self._calibrator_insts = self.get_calibrator_insts(self.calibrators, {"task": self})
        return self._calibrator_insts

    @property
    def calibrators_repr(self) -> str:
        """
        Return a string representation of the calibrators.

        The representation is ``"none"`` when no calibrators are set. Otherwise, the
        string representations of the first five calibrators are joined by ``"__"``
        and, when there are more than five, a hash of the remaining ones is appended.
        """
        if not self.calibrators:
            return "none"

        names = [str(calib) for calib in self.calibrator_insts]
        calibs_repr = "__".join(names[:5])
        if len(self.calibrators) > 5:
            # keep the representation short by hashing everything beyond the fifth entry
            calibs_repr += f"__{law.util.create_hash(names[5:])}"
        return calibs_repr

    def store_parts(self) -> law.util.InsertableDict:
        """
        Create parts to create the output path to store intermediary results
        for the current :py:class:`~law.task.base.Task`.

        Calls :py:meth:`store_parts` of the ``super`` class and inserts
        `{"calibrators": "calib__{self.calibrators_repr}"}` before keyword ``version``.
        See :py:attr:`calibrators_repr` for the composition of the representation.
        For more information, see e.g.
        :py:meth:`~columnflow.tasks.framework.base.ConfigTask.store_parts`.

        :return: Updated parts to create output path to store intermediary results.
        """
        parts = super().store_parts()
        parts.insert_before("version", "calibrators", f"calib__{self.calibrators_repr}")
        return parts

    def find_keep_columns(self, collection: ColumnCollection) -> set[Route]:
        """
        Finds the columns to keep based on the *collection*.

        If the collection is ``ALL_FROM_CALIBRATORS``, the columns produced by all
        calibrators are added to those found by the super class.

        :param collection: The collection of columns.
        :return: Set of columns to keep.
        """
        columns: set[Route] = super().find_keep_columns(collection)

        if collection == ColumnCollection.ALL_FROM_CALIBRATORS:
            # union on an empty set instance also works when there are no calibrators,
            # whereas the unbound set.union() raises a TypeError without arguments
            columns |= set().union(*(
                calibrator_inst.produced_columns
                for calibrator_inst in self.calibrator_insts
            ))

        return columns
class SelectorMixin(ConfigTask):
    """
    Mixin to include a single :py:class:`~columnflow.selection.Selector`
    instance into tasks.

    Inheriting from this mixin will allow a task to instantiate and access a
    :py:class:`~columnflow.selection.Selector` instance with name *selector*,
    which is an input parameter for this task.
    """

    selector = luigi.Parameter(
        default=RESOLVE_DEFAULT,
        description="the name of the selector to be applied; default: value of the "
        "'default_selector' config",
    )

    # decides whether the task itself runs the selector and implements its shifts
    register_selector_sandbox = False
    register_selector_shifts = False

    @classmethod
    def get_selector_inst(
        cls,
        selector: str,
        kwargs=None,
    ) -> Selector:
        """
        Initialize the requested *selector*.

        The selector class registered under the name *selector* is looked up with
        ``Selector.get_cls``. When *kwargs* is given and non-empty, it is passed
        through ``cls.get_selector_kwargs`` and the result is handed to the selector
        as ``inst_dict``; otherwise ``inst_dict`` is *None*.

        :param selector: Name of the :py:class:`~columnflow.selection.Selector` to
            instantiate.
        :param kwargs: Optional dict of keyword arguments to forward to the
            :py:class:`~columnflow.selection.Selector` instance.
        :raises RuntimeError: If the requested selector class is not exposed.
        :return: The initialized :py:class:`~columnflow.selection.Selector` instance.
        """
        selector_cls = Selector.get_cls(selector)
        if not selector_cls.exposed:
            raise RuntimeError(f"cannot use unexposed selector '{selector}' in {cls.__name__}")

        inst_dict = cls.get_selector_kwargs(**kwargs) if kwargs else None
        return selector_cls(inst_dict=inst_dict)

    @classmethod
    def resolve_param_values(cls, params: dict[str, Any]) -> dict:
        """
        Resolve values *params* and check against possible default values.

        When *params* contains a ``"config_inst"``, the ``"selector"`` parameter is
        resolved against the default value ``"default_selector"`` of that config. For
        more information, see
        :py:meth:`~columnflow.tasks.framework.base.AnalysisTask.resolve_config_default`.

        :param params: Parameter values to resolve.
        :return: Dictionary of parameters that, when a config instance is present,
            contains the requested :py:class:`~columnflow.selection.Selector` instance
            under the keyword ``"selector_inst"``.
        """
        params = super().resolve_param_values(params)

        # add the default selector when empty
        config_inst = params.get("config_inst")
        if config_inst:
            params["selector"] = cls.resolve_config_default(
                params,
                params.get("selector"),
                container=config_inst,
                default_str="default_selector",
                multiple=False,
            )
            params["selector_inst"] = cls.get_selector_inst(params["selector"], params)

        return params

    @classmethod
    def get_known_shifts(
        cls,
        config_inst: od.Config,
        params: dict[str, Any],
    ) -> tuple[set[str], set[str]]:
        """
        Adds the set of shifts that the current ``selector_inst`` registers to the
        set of known ``shifts`` and ``upstream_shifts``.

        First, the sets of ``shifts`` and ``upstream_shifts`` are obtained from the
        ``get_known_shifts`` method of the super class. Afterwards, if *params* contains
        a ``"selector_inst"``, its ``all_shifts`` are added to ``shifts`` when
        :py:attr:`~SelectorMixin.register_selector_shifts` is *True*, and to
        ``upstream_shifts`` otherwise.

        :param config_inst: Config instance for the current task.
        :param params: Dictionary containing the current set of parameters provided
            by the user at commandline level.
        :return: Tuple with updated sets of ``shifts`` and ``upstream_shifts``.
        """
        shifts, upstream_shifts = super().get_known_shifts(config_inst, params)

        # get the selector and add its shifts
        selector_inst = params.get("selector_inst")
        if selector_inst:
            if cls.register_selector_shifts:
                shifts |= selector_inst.all_shifts
            else:
                upstream_shifts |= selector_inst.all_shifts

        return shifts, upstream_shifts

    @classmethod
    def req_params(cls, inst: law.Task, **kwargs) -> dict[str, Any]:
        """
        Get the required parameters for the task, preferring the ``--selector`` set on
        task-level via CLI.

        Adds ``"selector"`` to the ``_prefer_cli`` set in *kwargs* and then calls the
        ``req_params`` method of the super class with the updated kwargs.

        :param inst: The current task instance.
        :param kwargs: Additional keyword arguments that may contain parameters for the task.
        :return: A dictionary of parameters required for the task.
        """
        # prefer --selector set on task-level via cli
        kwargs["_prefer_cli"] = law.util.make_set(kwargs.get("_prefer_cli", [])) | {"selector"}

        return super().req_params(inst, **kwargs)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # cache for selector inst
        self._selector_inst = None

    @property
    def selector_inst(self) -> Selector:
        """
        Access the current :py:class:`~columnflow.selection.Selector` instance.

        The instance is created on first access and cached afterwards. Upon creation,
        if :py:attr:`register_selector_sandbox` is *True* and the selector reports a
        sandbox, that sandbox is set on the task and an already initialized sandbox is
        rebuilt.

        :return: Current :py:class:`~columnflow.selection.Selector` instance.
        """
        if self._selector_inst is None:
            self._selector_inst = self.get_selector_inst(self.selector, {"task": self})

            # overwrite the sandbox when set
            if self.register_selector_sandbox:
                sandbox = self._selector_inst.get_sandbox()
                if sandbox:
                    self.sandbox = sandbox
                    # rebuild the sandbox inst when already initialized
                    if self._sandbox_initialized:
                        self._initialize_sandbox(force=True)

        return self._selector_inst

    @property
    def selector_repr(self) -> str:
        """
        Return a string representation of the selector.
        """
        return str(self.selector_inst)

    def store_parts(self) -> law.util.InsertableDict:
        """
        Create parts to create the output path to store intermediary results
        for the current :py:class:`~law.task.base.Task`.

        Calls :py:meth:`store_parts` of the ``super`` class and inserts
        `{"selector": "sel__{self.selector_repr}"}` before keyword ``version``.

        :return: Updated parts to create output path to store intermediary results.
        """
        parts = super().store_parts()
        parts.insert_before("version", "selector", f"sel__{self.selector_repr}")
        return parts

    def find_keep_columns(self, collection: ColumnCollection) -> set[Route]:
        """
        Finds the columns to keep based on the *collection*.

        If the collection is ``ALL_FROM_SELECTOR``, the columns produced by the
        selector are added to those found by the super class.

        :param collection: The collection of columns.
        :return: Set of columns to keep.
        """
        columns = super().find_keep_columns(collection)

        if collection == ColumnCollection.ALL_FROM_SELECTOR:
            columns |= self.selector_inst.produced_columns

        return columns

    @classmethod
    def get_config_lookup_keys(
        cls,
        inst_or_params: SelectorMixin | dict[str, Any],
    ) -> law.util.InsertableDict:
        """
        Extends the config lookup keys of the super class by a ``"selector"`` entry.

        The selector name is read from *inst_or_params*, which can either be a
        dictionary of parameters or a task instance. The entry ``"sel_{selector}"`` is
        only added when the name is neither ``law.NO_STR``, *None* nor empty.

        :param inst_or_params: Task instance or dictionary of parameters.
        :return: The lookup keys, potentially extended by the ``"selector"`` key.
        """
        keys = super().get_config_lookup_keys(inst_or_params)

        # uniform accessor for both dicts and task instances
        get = (
            inst_or_params.get
            if isinstance(inst_or_params, dict)
            else lambda attr: (getattr(inst_or_params, attr, None))
        )

        # add the selector name
        selector = get("selector")
        if selector not in {law.NO_STR, None, ""}:
            keys["selector"] = f"sel_{selector}"

        return keys
class SelectorStepsMixin(SelectorMixin):
    """
    Mixin to include multiple selector steps into tasks.

    Inheriting from this mixin will allow a task to access selector steps, which can be a
    comma-separated list of selector step names and is an input parameter for this task.
    """

    selector_steps = law.CSVParameter(
        default=(),
        description="a subset of steps of the selector to apply; uses all steps when empty; "
        "default: empty",
        brace_expand=True,
        parse_empty=True,
    )

    # NOTE(review): presumably hides an empty selector_steps value from the task
    # representation - confirm against the base task implementation
    exclude_params_repr_empty = {"selector_steps"}

    # when False, selector steps are sorted during parameter resolution and in store_parts
    selector_steps_order_sensitive = False

    @classmethod
    def resolve_param_values(cls, params: dict[str, Any]) -> dict[str, Any]:
        """
        Resolve values *params* and check against possible default values and
        selector step groups.

        When *params* contains a ``"config_inst"``, the ``"selector_steps"`` are
        resolved against the default value ``"default_selector_steps"`` and the groups
        ``"selector_step_groups"`` of that config. For more information, see
        :py:meth:`~columnflow.tasks.framework.base.ConfigTask.resolve_config_default_and_groups`.

        If :py:attr:`SelectorStepsMixin.selector_steps_order_sensitive` is *False* and
        ``"selector_steps"`` is present in the parameters, the steps are
        :py:func:`sorted` and stored as a tuple.

        :param params: Parameter values to resolve.
        :return: Dictionary of parameters that contains the requested
            selector steps under the keyword ``"selector_steps"``.
        """
        params = super().resolve_param_values(params)

        # apply selector_steps_groups and default_selector_steps from config
        config_inst = params.get("config_inst")
        if config_inst:
            params["selector_steps"] = cls.resolve_config_default_and_groups(
                params,
                params.get("selector_steps"),
                container=config_inst,
                default_str="default_selector_steps",
                groups_str="selector_step_groups",
            )

        # sort selector steps when the order does not matter
        if not cls.selector_steps_order_sensitive and "selector_steps" in params:
            params["selector_steps"] = tuple(sorted(params["selector_steps"]))

        return params

    @classmethod
    def req_params(cls, inst: law.Task, **kwargs) -> dict[str, Any]:
        """
        Get the required parameters for the task, preferring the ``--selector-steps``
        set on task-level via CLI.

        Adds ``"selector_steps"`` to the ``_prefer_cli`` set in *kwargs* and then calls
        the ``req_params`` method of the super class with the updated kwargs.

        :param inst: The current task instance.
        :param kwargs: Additional keyword arguments that may contain parameters for the task.
        :return: A dictionary of parameters required for the task.
        """
        # prefer --selector-steps set on task-level via cli
        kwargs["_prefer_cli"] = law.util.make_set(kwargs.get("_prefer_cli", [])) | {"selector_steps"}

        return super().req_params(inst, **kwargs)

    def store_parts(self) -> law.util.InsertableDict:
        """
        Create parts to create the output path to store intermediary results
        for the current :py:class:`~law.task.base.Task`.

        Calls :py:meth:`store_parts` of the ``super`` class and, when selector steps
        are set, appends ``"__steps_"`` followed by the steps joined with ``"_"`` to the
        existing ``"selector"`` part. The steps are sorted beforehand unless
        :py:attr:`selector_steps_order_sensitive` is *True*.
        For more information, see e.g.
        :py:meth:`~columnflow.tasks.framework.base.ConfigTask.store_parts`.

        :return: Updated parts to create output path to store intermediary results.
        """
        parts = super().store_parts()

        steps = self.selector_steps
        if not self.selector_steps_order_sensitive:
            steps = sorted(steps)
        if steps:
            parts["selector"] += "__steps_" + "_".join(steps)

        return parts
[docs]class ProducerMixin(ConfigTask):
"""
Mixin to include a single :py:class:`~columnflow.production.Producer` into tasks.
Inheriting from this mixin will give access to instantiate and access a
:py:class:`~columnflow.production.Producer` instance with name *producer*,
which is an input parameter for this task.
"""
producer = luigi.Parameter(
default=RESOLVE_DEFAULT,
description="the name of the producer to be applied; default: value of the "
"'default_producer' config",
)
# decides whether the task itself runs the producer and implements its shifts
register_producer_sandbox = False
register_producer_shifts = False
[docs] @classmethod
def get_producer_inst(cls, producer: str, kwargs=None) -> Producer:
"""
Initialize :py:class:`~columnflow.production.Producer` instance.
Extracts relevant *kwargs* for this producer instance using the
:py:meth:`~columnflow.tasks.framework.base.AnalaysisTask.get_producer_kwargs`
method.
After this process, the previously initialized instance of a
:py:class:`~columnflow.production.Producer` with the name
*producer* is initialized using the
:py:meth:`~columnflow.util.DerivableMeta.get_cls` method with the
relevant keyword arguments.
:param producer: Name of the :py:class:`~columnflow.production.Producer`
instance
:param kwargs: Any set keyword argument that is potentially relevant for
this :py:class:`~columnflow.production.Producer` instance
:raises RuntimeError: if requested :py:class:`~columnflow.production.Producer` instance
is not :py:attr:`~columnflow.production.Producer.exposed`
:return: The initialized :py:class:`~columnflow.production.Producer`
instance.
"""
producer_cls: Producer = Producer.get_cls(producer)
if not producer_cls.exposed:
raise RuntimeError(f"cannot use unexposed producer '{producer}' in {cls.__name__}")
inst_dict = cls.get_producer_kwargs(**kwargs) if kwargs else None
return producer_cls(inst_dict=inst_dict)
[docs] @classmethod
def resolve_param_values(cls, params: dict[str, Any]) -> dict[str, Any]:
"""
Resolve parameter values *params* relevant for the
:py:class:`ProducerMixin` and all classes it inherits from.
Loads the ``config_inst`` and loads the parameter ``"producer"``.
In case the parameter is not found, defaults to ``"default_producer"``.
Finally, this function adds the keyword ``"producer_inst"``, which
contains the :py:class:`~columnflow.production.Producer` instance
obtained using :py:meth:`~.ProducerMixin.get_producer_inst` method.
:param params: Dictionary with parameters provided by the user at
commandline level.
:return: Dictionary of parameters that now includes new value for
``"producer_inst"``.
"""
params = super().resolve_param_values(params)
# add the default producer when empty
config_inst = params.get("config_inst")
if config_inst:
params["producer"] = cls.resolve_config_default(
params,
params.get("producer"),
container=config_inst,
default_str="default_producer",
multiple=False,
)
params["producer_inst"] = cls.get_producer_inst(params["producer"], params)
return params
[docs] @classmethod
def get_known_shifts(cls, config_inst: od.Config, params: dict[str, Any]) -> tuple[set[str], set[str]]:
"""
Adds set of shifts that the current ``producer_inst`` registers to the
set of known ``shifts`` and ``upstream_shifts``.
First, the set of ``shifts`` and ``upstream_shifts`` are obtained from
the *config_inst* and the current set of parameters *params* using the
``get_known_shifts`` methods of all classes that :py:class:`ProducerMixin`
inherits from.
Afterwards, check if the current ``producer_inst`` registers shifts.
If :py:attr:`~ProducerMixin.register_producer_shifts` is ``True``,
add them to the current set of ``shifts``. Otherwise, add the
shifts obtained from the ``producer_inst`` to ``upstream_shifts``.
:param config_inst: Config instance for the current task.
:param params: Dictionary containing the current set of parameters provided
by the user at commandline level
:return: Tuple with updated sets of ``shifts`` and ``upstream_shifts``.
"""
shifts, upstream_shifts = super().get_known_shifts(config_inst, params)
# get the producer, update it and add its shifts
producer_inst = params.get("producer_inst")
if producer_inst:
if cls.register_producer_shifts:
shifts |= producer_inst.all_shifts
else:
upstream_shifts |= producer_inst.all_shifts
return shifts, upstream_shifts
[docs] @classmethod
def req_params(cls, inst: law.Task, **kwargs) -> dict[str, Any]:
"""
Get the required parameters for the task, preferring the ``--producer`` set on task-level via CLI.
This method first checks if the ``--producer`` parameter is set at the task-level via the command line.
If it is, this parameter is preferred and added to the '_prefer_cli' key in the kwargs dictionary.
The method then calls the 'req_params' method of the superclass with the updated kwargs.
:param inst: The current task instance.
:param kwargs: Additional keyword arguments that may contain parameters for the task.
:return: A dictionary of parameters required for the task.
"""
# prefer --producer set on task-level via cli
kwargs["_prefer_cli"] = law.util.make_set(kwargs.get("_prefer_cli", [])) | {"producer"}
return super().req_params(inst, **kwargs)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# cache for producer inst
self._producer_inst = None
@property
def producer_inst(self) -> Producer:
"""
Access current :py:class:`~columnflow.production.Producer` instance.
Loads the current :py:class:`~columnflow.production.Producer` *producer_inst* from
the cache or initializes it.
If the producer requests a specific ``sandbox``, set this sandbox as
the environment for the current :py:class:`~law.task.base.Task`.
:return: Current :py:class:`~columnflow.production.Producer` instance
"""
if self._producer_inst is None:
self._producer_inst = self.get_producer_inst(self.producer, {"task": self})
# overwrite the sandbox when set
if self.register_producer_sandbox:
sandbox = self._producer_inst.get_sandbox()
if sandbox:
self.sandbox = sandbox
# rebuild the sandbox inst when already initialized
if self._sandbox_initialized:
self._initialize_sandbox(force=True)
return self._producer_inst
@property
def producer_repr(self) -> str:
    """
    String representation of the producer, or ``"none"`` when no producer is configured
    (i.e. the *producer* parameter equals ``law.NO_STR``).
    """
    if self.producer == law.NO_STR:
        return "none"

    return str(self.producer_inst)
def store_parts(self) -> law.util.InsertableDict[str, str]:
    """
    Extend the parts of the output path with an entry for the producer.

    The parts of the superclass are obtained first. Then, the entry
    ``"producer": "prod__<producer_repr>"`` is inserted before the ``"version"`` entry, with
    ``<producer_repr>`` being the value of :py:attr:`producer_repr`.

    :return: The updated parts used to build the output path.
    """
    parts = super().store_parts()
    parts.insert_before("version", "producer", f"prod__{self.producer_repr}")

    return parts
def find_keep_columns(self, collection: ColumnCollection) -> set[Route]:
    """
    Determine the columns to keep for a *collection*.

    Starts from the columns reported by the superclass. When *collection* is
    ``ColumnCollection.ALL_FROM_PRODUCER``, the ``produced_columns`` of the producer instance
    are added.

    :param collection: The collection of columns.
    :return: The set of columns to keep.
    """
    keep = super().find_keep_columns(collection)

    wants_producer_columns = collection == ColumnCollection.ALL_FROM_PRODUCER
    if wants_producer_columns:
        keep |= self.producer_inst.produced_columns

    return keep
@classmethod
def get_config_lookup_keys(
    cls,
    inst_or_params: ProducerMixin | dict[str, Any],
) -> law.util.InsertableDict:
    """
    Extend the config lookup keys of the superclass by an entry for the producer.

    The producer name is read from *inst_or_params*, which can either be a dictionary of
    parameters or an object (e.g. a task instance) with a ``producer`` attribute. When the
    name is neither ``law.NO_STR``, ``None`` nor an empty string, the entry
    ``"producer": "prod_<name>"`` is added to the keys.

    :param inst_or_params: Task instance or dictionary of parameters.
    :return: The lookup keys of the superclass, potentially extended by ``"producer"``.
    """
    keys = super().get_config_lookup_keys(inst_or_params)

    # read the producer name from either the parameter dict or the instance
    if isinstance(inst_or_params, dict):
        producer = inst_or_params.get("producer")
    else:
        producer = getattr(inst_or_params, "producer", None)

    # only add the key when a producer is actually configured
    if producer not in {law.NO_STR, None, ""}:
        keys["producer"] = f"prod_{producer}"

    return keys
class ProducersMixin(ConfigTask):
    """
    Mixin to include multiple :py:class:`~columnflow.production.Producer` instances into tasks.

    Inheriting from this mixin will allow a task to instantiate and access a set of
    :py:class:`~columnflow.production.Producer` instances with names *producers*,
    which is a comma-separated list of producer names and is an input parameter for this task.
    """

    producers = law.CSVParameter(
        default=(RESOLVE_DEFAULT,),
        description="comma-separated names of producers to be applied; default: value of the "
        "'default_producer' config",
        brace_expand=True,
        parse_empty=True,
    )

    # decides whether the task itself runs the producers and implements their shifts
    register_producers_shifts = False

    @classmethod
    def get_producer_insts(cls, producers: Iterable[str], kwargs=None) -> list[Producer]:
        """
        Get instances of all requested *producers*.

        When *kwargs* are given, they are passed through ``get_producer_kwargs`` and the result
        is forwarded as ``inst_dict`` to every producer class.

        :param producers: Names of :py:class:`~columnflow.production.Producer`
            instances to load
        :param kwargs: Additional keyword arguments to forward to individual
            :py:class:`~columnflow.production.Producer` instances
        :raises RuntimeError: if requested producers are not
            :py:attr:`~columnflow.production.Producer.exposed`
        :return: List of :py:class:`~columnflow.production.Producer` instances.
        """
        inst_dict = cls.get_producer_kwargs(**kwargs) if kwargs else None

        insts = []
        for producer in producers:
            producer_cls = Producer.get_cls(producer)
            if not producer_cls.exposed:
                raise RuntimeError(f"cannot use unexposed producer '{producer}' in {cls.__name__}")
            insts.append(producer_cls(inst_dict=inst_dict))

        return insts

    @classmethod
    def resolve_param_values(
        cls,
        params: law.util.InsertableDict[str, Any],
    ) -> law.util.InsertableDict[str, Any]:
        """
        Resolve values *params* and check against possible default values and
        producer groups.

        Check the values in *params* against the default value ``"default_producer"``
        and possible group definitions ``"producer_groups"`` in the current config inst.
        For more information, see
        :py:meth:`~columnflow.tasks.framework.base.ConfigTask.resolve_config_default_and_groups`.
        Nothing is resolved when *params* does not contain a ``"config_inst"``.

        :param params: Parameter values to resolve
        :return: Dictionary of parameters that contains the list requested
            :py:class:`~columnflow.production.Producer` instances under the
            keyword ``"producer_insts"``. See :py:meth:`~.ProducersMixin.get_producer_insts`
            for more information.
        """
        params = super().resolve_param_values(params)

        config_inst = params.get("config_inst")
        if config_inst:
            params["producers"] = cls.resolve_config_default_and_groups(
                params,
                params.get("producers"),
                container=config_inst,
                default_str="default_producer",
                groups_str="producer_groups",
            )
            params["producer_insts"] = cls.get_producer_insts(params["producers"], params)

        return params

    @classmethod
    def get_known_shifts(cls, config_inst: od.Config, params: dict[str, Any]) -> tuple[set[str], set[str]]:
        """
        Adds set of all shifts that the list of ``producer_insts`` register to the
        set of known ``shifts`` and ``upstream_shifts``.

        First, the set of ``shifts`` and ``upstream_shifts`` are obtained from the superclass.
        Afterwards, the ``all_shifts`` of every :py:class:`~columnflow.production.Producer`
        found under ``"producer_insts"`` in *params* are added.
        If :py:attr:`~ProducersMixin.register_producers_shifts` is ``True``,
        they are added to the current set of ``shifts``. Otherwise, they are added to
        ``upstream_shifts``.

        :param config_inst: Config instance for the current task.
        :param params: Dictionary containing the current set of parameters provided
            by the user at commandline level
        :return: Tuple with updated sets of ``shifts`` and ``upstream_shifts``.
        """
        shifts, upstream_shifts = super().get_known_shifts(config_inst, params)

        # get the producers, update them and add their shifts
        for producer_inst in params.get("producer_insts") or []:
            if cls.register_producers_shifts:
                shifts |= producer_inst.all_shifts
            else:
                upstream_shifts |= producer_inst.all_shifts

        return shifts, upstream_shifts

    @classmethod
    def req_params(cls, inst: law.Task, **kwargs) -> dict[str, Any]:
        """
        Get the required parameters for the task, preferring the ``--producers`` set on
        task-level via CLI.

        The name ``"producers"`` is added to the set stored under the ``"_prefer_cli"`` key of
        *kwargs* before the ``req_params`` method of the superclass is called with the updated
        *kwargs*.

        :param inst: The current task instance.
        :param kwargs: Additional keyword arguments that may contain parameters for the task.
        :return: A dictionary of parameters required for the task.
        """
        # prefer --producers set on task-level via cli
        kwargs["_prefer_cli"] = law.util.make_set(kwargs.get("_prefer_cli", [])) | {"producers"}

        return super().req_params(inst, **kwargs)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # cache for producer insts
        self._producer_insts = None

    @property
    def producer_insts(self) -> list[Producer]:
        """
        Access current list of :py:class:`~columnflow.production.Producer` instances.

        The list is created once via :py:meth:`get_producer_insts` and cached afterwards.

        :return: Current list :py:class:`~columnflow.production.Producer` instances
        """
        if self._producer_insts is None:
            self._producer_insts = self.get_producer_insts(self.producers, {"task": self})

        return self._producer_insts

    @property
    def producers_repr(self) -> str:
        """
        Return a string representation of the producers.

        This is ``"none"`` without producers. Otherwise, the string representations of the
        first five producer instances are joined with ``"__"`` and, in case there are more than
        five, a hash of the representations of the remaining ones is appended.
        """
        prods_repr = "none"
        if self.producers:
            prods_repr = "__".join([str(prod) for prod in self.producer_insts[:5]])
            if len(self.producers) > 5:
                prods_repr += f"__{law.util.create_hash([str(prod) for prod in self.producer_insts[5:]])}"

        return prods_repr

    def store_parts(self):
        """
        Create parts to create the output path to store intermediary results
        for the current :py:class:`~law.task.base.Task`.

        Calls :py:meth:`store_parts` of the ``super`` class and inserts
        ``"producers": "prod__<producers_repr>"`` before keyword ``version``, where
        ``<producers_repr>`` is the value of :py:attr:`producers_repr`.

        :return: Updated parts to create output path to store intermediary results.
        """
        parts = super().store_parts()
        parts.insert_before("version", "producers", f"prod__{self.producers_repr}")

        return parts

    def find_keep_columns(self, collection: ColumnCollection) -> set[Route]:
        """
        Finds the columns to keep based on the *collection*.

        This method first calls the 'find_keep_columns' method of the superclass with the given
        *collection*. If the *collection* is equal to ``ALL_FROM_PRODUCERS``, it adds the
        columns produced by all producer instances to the set of columns. An empty list of
        producer instances adds nothing.

        :param collection: The collection of columns.
        :return: A set of columns to keep.
        """
        columns = super().find_keep_columns(collection)

        if collection == ColumnCollection.ALL_FROM_PRODUCERS:
            # start from an empty set since the unbound set.union() raises a TypeError when
            # called without arguments, i.e., when no producers are configured
            columns |= set().union(*(
                producer_inst.produced_columns
                for producer_inst in self.producer_insts
            ))

        return columns
class MLModelMixinBase(AnalysisTask):
    """
    Base mixin to include a machine learning application into tasks.

    Tasks inheriting from this mixin receive the *ml_model* and *ml_model_settings* parameters
    and helpers to create and query a :py:class:`~columnflow.ml.MLModel` instance.
    """

    ml_model = luigi.Parameter(
        description="the name of the ML model to be applied",
    )
    ml_model_settings = SettingsParameter(
        default=DotDict(),
        description="settings passed to the init function of the ML model",
    )

    exclude_params_repr_empty = {"ml_model"}

    @property
    def ml_model_repr(self):
        """
        String representation of the ML model instance stored as ``ml_model_inst``.
        """
        model_inst = self.ml_model_inst
        return str(model_inst)

    @classmethod
    def req_params(cls, inst: law.Task, **kwargs) -> dict[str, Any]:
        """
        Build the parameters to be passed on from *inst*, giving preference to a ``--ml-model``
        that was set on task-level via the command line.

        The name ``"ml_model"`` is unconditionally added to the set stored under the
        ``"_prefer_cli"`` key of *kwargs*. Afterwards, the ``req_params`` method of the
        superclass is invoked with the updated *kwargs*.

        :param inst: The current task instance.
        :param kwargs: Additional keyword arguments, possibly containing ``"_prefer_cli"``.
        :return: The dictionary of parameters as returned by the superclass.
        """
        # extend, rather than replace, preferences that were requested by the caller
        prefer_cli = law.util.make_set(kwargs.get("_prefer_cli", []))
        kwargs["_prefer_cli"] = prefer_cli | {"ml_model"}

        return super().req_params(inst, **kwargs)

    @classmethod
    def get_ml_model_inst(
        cls,
        ml_model: str,
        analysis_inst: od.Analysis,
        requested_configs: list[str] | None = None,
        **kwargs,
    ) -> MLModel:
        """
        Create an instance of the model named *ml_model*.

        The model class is looked up by name and instantiated with *analysis_inst* and
        *kwargs*. When *requested_configs* are given, they are passed to the
        ``training_configs`` method of the model and, in case the result is not empty, the
        model's ``_setup`` method is called with it.

        :param ml_model: Name of :py:class:`~columnflow.ml.MLModel` to load.
        :param analysis_inst: Analysis inst forwarded to the init function of the model class.
        :param requested_configs: Configs handed to ``training_configs`` of the model.
        :param kwargs: Additional keyword arguments forwarded to the model class.
        :return: :py:class:`~columnflow.ml.MLModel` instance.
        """
        model_cls = MLModel.get_cls(ml_model)
        model: MLModel = model_cls(analysis_inst, **kwargs)

        # without requested configs there is nothing to set up
        if not requested_configs:
            return model

        training_configs = model.training_configs(list(requested_configs))
        if training_configs:
            model._setup(training_configs)

        return model

    def events_used_in_training(
        self,
        config_inst: od.config.Config,
        dataset_inst: od.dataset.Dataset,
        shift_inst: od.shift.Shift,
    ) -> bool:
        """
        Decide whether events of the combination of *dataset_inst* and *shift_inst* are used
        in the training.

        This is the case when *dataset_inst* is contained in the datasets that the current
        ``ml_model_inst`` reports for *config_inst*, and *shift_inst* does not have the tag
        ``"disjoint_from_nominal"``.

        :param config_inst: The configuration instance.
        :param dataset_inst: The dataset instance.
        :param shift_inst: The shift instance.
        :return: True if the events shall be used in the training, False otherwise.
        """
        # datasets unknown to the model are never part of the training
        if dataset_inst not in self.ml_model_inst.datasets(config_inst):
            return False

        return not shift_inst.has_tag("disjoint_from_nominal")
class MLModelTrainingMixin(MLModelMixinBase):
    """
    A mixin class for training machine learning models.

    This class provides parameters for configuring the training of machine learning models,
    namely the configs to use as well as the calibrators, selectors and producers per config.
    """

    configs = law.CSVParameter(
        default=(),
        description="comma-separated names of analysis config to use; should only contain a single "
        "name in case the ml model is bound to a single config; when empty, the ml model is "
        "expected to fully define the configs it uses; empty default",
        brace_expand=True,
        parse_empty=True,
    )
    calibrators = law.MultiCSVParameter(
        default=(),
        description="multiple comma-separated sequences of names of calibrators to apply, "
        "separated by ':'; each sequence corresponds to a config in --configs; when empty, the "
        "'default_calibrator' setting of each config is used if set, or the model is expected to "
        "fully define the calibrators it requires upstream; empty default",
        brace_expand=True,
        parse_empty=True,
    )
    selectors = law.CSVParameter(
        default=(),
        description="comma-separated names of selectors to apply; each selector corresponds to a "
        "config in --configs; when empty, the 'default_selector' setting of each config is used if "
        "set, or the ml model is expected to fully define the selector it uses requires upstream; "
        "empty default",
        brace_expand=True,
        parse_empty=True,
    )
    producers = law.MultiCSVParameter(
        default=(),
        description="multiple comma-separated sequences of names of producers to apply, "
        "separated by ':'; each sequence corresponds to a config in --configs; when empty, the "
        "'default_producer' setting of each config is used if set, or ml model is expected to "
        "fully define the producers it requires upstream; empty default",
        brace_expand=True,
        parse_empty=True,
    )

    @classmethod
    def _broadcast_to_configs(cls, ml_model_inst: MLModel, values: tuple, what: str) -> tuple:
        """
        Broadcast *values* to the configs of *ml_model_inst* and validate their number.

        A single entry is repeated for every config. Afterwards, the number of entries must
        match the number of configs. This check runs before any per-config access, so that a
        mismatch results in a descriptive error rather than an ``IndexError`` (too few entries)
        or silently ignored entries (too many).

        :param ml_model_inst: The ML model instance providing ``config_insts``.
        :param values: The entries to broadcast, one per config or a single one.
        :param what: Description of the entries, used in the error message.
        :return: The entries, one per config.
        :raises Exception: If the number of entries does not match the number of configs.
        """
        n_configs = len(ml_model_inst.config_insts)

        # broadcast to configs
        if len(values) == 1 and n_configs != 1:
            values = tuple(values * n_configs)

        # validate the number of entries
        if len(values) != n_configs:
            raise Exception(
                f"MLModel '{ml_model_inst.cls_name}' uses {n_configs} configs but received "
                f"{len(values)} {what}",
            )

        return values

    @classmethod
    def resolve_calibrators(
        cls,
        ml_model_inst: MLModel,
        params: dict[str, Any],
    ) -> tuple[tuple[str]]:
        """
        Resolve the calibrators for the given ML model instance.

        This method retrieves the calibrators from the parameters *params*, broadcasts them
        to the configs if necessary and validates the number of sequences.
        It then resolves `calibrator_groups` and `default_calibrator` from the config(s)
        associated with this ML model instance.
        Finally, it checks the retrieved calibrators against
        the training calibrators of the model using
        :py:meth:`~columnflow.ml.MLModel.training_calibrators` and instantiates them once.

        :param ml_model_inst: The ML model instance.
        :param params: A dictionary of parameters that may contain the calibrators.
        :return: A tuple of tuples containing the resolved calibrators.
        :raises Exception: If the number of calibrator sequences does not match
            the number of configs used by the ML model.
        """
        calibrators: Union[tuple[str], tuple[tuple[str]]] = params.get("calibrators") or ((),)

        # broadcast to configs and validate number of sequences
        calibrators = cls._broadcast_to_configs(ml_model_inst, calibrators, "calibrator sequences")

        # apply calibrators_groups and default_calibrator from the config
        calibrators = tuple(
            ConfigTask.resolve_config_default_and_groups(
                params,
                _calibrators,
                container=config_inst,
                default_str="default_calibrator",
                groups_str="calibrator_groups",
            )
            for config_inst, _calibrators in zip(ml_model_inst.config_insts, calibrators)
        )

        # final check by model
        calibrators = tuple(
            tuple(ml_model_inst.training_calibrators(config_inst, list(_calibrators)))
            for config_inst, _calibrators in zip(ml_model_inst.config_insts, calibrators)
        )

        # instantiate them once
        for config_inst, _calibrators in zip(ml_model_inst.config_insts, calibrators):
            init_kwargs = law.util.merge_dicts(params, {"config_inst": config_inst})
            for calibrator in _calibrators:
                CalibratorMixin.get_calibrator_inst(calibrator, kwargs=init_kwargs)

        return calibrators

    @classmethod
    def resolve_selectors(
        cls,
        ml_model_inst: MLModel,
        params: dict[str, Any],
    ) -> tuple[str]:
        """
        Resolve the selectors for the given ML model instance.

        This method retrieves the selectors from the parameters *params*, broadcasts them
        to the configs if necessary and validates their number.
        It then resolves `default_selector` from the config(s) associated
        with this ML model instance.
        Finally, it checks the retrieved selectors against the training selectors
        of the model, using
        :py:meth:`~columnflow.ml.MLModel.training_selector`, and instantiates them once.

        :param ml_model_inst: The ML model instance.
        :param params: A dictionary of parameters that may contain the selectors.
        :return: A tuple containing the resolved selectors.
        :raises Exception: If the number of selectors does not match
            the number of configs used by the ML model.
        """
        selectors = params.get("selectors") or (None,)

        # broadcast to configs and validate sequence length
        selectors = cls._broadcast_to_configs(ml_model_inst, selectors, "selectors")

        # use config defaults
        selectors = tuple(
            ConfigTask.resolve_config_default(
                params,
                selector,
                container=config_inst,
                default_str="default_selector",
                multiple=False,
            )
            for config_inst, selector in zip(ml_model_inst.config_insts, selectors)
        )

        # final check by model
        selectors = tuple(
            ml_model_inst.training_selector(config_inst, selector)
            for config_inst, selector in zip(ml_model_inst.config_insts, selectors)
        )

        # instantiate them once
        for config_inst, selector in zip(ml_model_inst.config_insts, selectors):
            init_kwargs = law.util.merge_dicts(params, {"config_inst": config_inst})
            SelectorMixin.get_selector_inst(selector, kwargs=init_kwargs)

        return selectors

    @classmethod
    def resolve_producers(
        cls,
        ml_model_inst: MLModel,
        params: dict[str, Any],
    ) -> tuple[tuple[str]]:
        """
        Resolve the producers for the given ML model instance.

        This method retrieves the producers from the parameters *params*, broadcasts them
        to the configs if necessary and validates the number of sequences.
        It then resolves `producer_groups` and `default_producer` from the config(s) associated
        with this ML model instance.
        Finally, it checks the retrieved producers against the training producers
        of the model, using
        :py:meth:`~columnflow.ml.MLModel.training_producers`, and instantiates them once.

        :param ml_model_inst: The ML model instance.
        :param params: A dictionary of parameters that may contain the producers.
        :return: A tuple of tuples containing the resolved producers.
        :raises Exception: If the number of producer sequences does not match
            the number of configs used by the ML model.
        """
        producers = params.get("producers") or ((),)

        # broadcast to configs and validate number of sequences
        producers = cls._broadcast_to_configs(ml_model_inst, producers, "producer sequences")

        # apply producers_groups and default_producer from the config
        producers = tuple(
            ConfigTask.resolve_config_default_and_groups(
                params,
                _producers,
                container=config_inst,
                default_str="default_producer",
                groups_str="producer_groups",
            )
            for config_inst, _producers in zip(ml_model_inst.config_insts, producers)
        )

        # final check by model
        producers = tuple(
            tuple(ml_model_inst.training_producers(config_inst, list(_producers)))
            for config_inst, _producers in zip(ml_model_inst.config_insts, producers)
        )

        # instantiate them once
        for config_inst, _producers in zip(ml_model_inst.config_insts, producers):
            init_kwargs = law.util.merge_dicts(params, {"config_inst": config_inst})
            for producer in _producers:
                ProducerMixin.get_producer_inst(producer, kwargs=init_kwargs)

        return producers

    @classmethod
    def resolve_param_values(cls, params: dict[str, Any]) -> dict[str, Any]:
        """
        Resolve the parameter values for the given parameters.

        When *params* contains both ``"analysis_inst"`` and ``"ml_model"``, this method creates
        the ML model instance, resolves the configs, calibrators, selectors, and producers (in
        that order) and finally calls the model's setup hook.

        :param params: A dictionary of parameters that may contain the analysis instance and ML model.
        :return: A dictionary containing the resolved parameters.
        :raises Exception: If the ML model instance received configs to define training configs,
            but did not define any.
        """
        params = super().resolve_param_values(params)

        if "analysis_inst" in params and "ml_model" in params:
            analysis_inst = params["analysis_inst"]

            # NOTE: we could try to implement resolving the default ml_model here
            ml_model_inst = cls.get_ml_model_inst(
                params["ml_model"],
                analysis_inst,
                parameters=params["ml_model_settings"],
            )
            params["ml_model_inst"] = ml_model_inst

            # resolve configs
            _configs = params.get("configs", ())
            params["configs"] = tuple(ml_model_inst.training_configs(list(_configs)))
            if not params["configs"]:
                raise Exception(
                    f"MLModel '{ml_model_inst.cls_name}' received configs '{_configs}' to define "
                    "training configs, but did not define any",
                )
            ml_model_inst._set_configs(params["configs"])

            # resolve calibrators
            params["calibrators"] = cls.resolve_calibrators(ml_model_inst, params)

            # resolve selectors
            params["selectors"] = cls.resolve_selectors(ml_model_inst, params)

            # resolve producers
            params["producers"] = cls.resolve_producers(ml_model_inst, params)

            # call the model's setup hook
            ml_model_inst._setup()

        return params

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # get the ML model instance
        # NOTE(review): "configs" is not a named parameter of get_ml_model_inst and therefore
        # ends up in the keyword arguments forwarded to the model class, whereas the sibling
        # mixins pass "requested_configs" - confirm that this is intended
        self.ml_model_inst = self.get_ml_model_inst(
            self.ml_model,
            self.analysis_inst,
            configs=list(self.configs),
            parameters=self.ml_model_settings,
        )

    def store_parts(self) -> law.util.InsertableDict[str, str]:
        """
        Generate a dictionary of store parts for the current instance.

        This method extends the base method to include additional parts related to the configs,
        calibrators, selectors, producers (CSP), and the ML model instance itself.

        The configs part joins the first five config names with ``"__"`` and appends a hash of
        the remaining ones if there are more than five.

        If the list of either of the CSPs is empty, the corresponding part is set to ``"none"``,
        otherwise, the first two elements of the list are joined with ``"__"``.
        If the list of either of the CSPs contains more than two elements, the part is extended
        with the number of elements and a hash of the remaining elements, which is
        created with :py:meth:`law.util.create_hash`.

        :return: An InsertableDict containing the store parts.
        """
        parts = super().store_parts()

        # since MLTraining is no CalibratorsMixin, SelectorMixin, ProducerMixin, ConfigTask,
        # all these parts are missing in the `store_parts`
        configs_repr = "__".join(self.configs[:5])
        if len(self.configs) > 5:
            configs_repr += f"_{law.util.create_hash(self.configs[5:])}"
        parts.insert_after("task_family", "configs", configs_repr)

        for label, fct_names in [
            ("calib", self.calibrators),
            # wrap each selector so that all three entries are sequences of sequences
            ("sel", tuple((sel,) for sel in self.selectors)),
            ("prod", self.producers),
        ]:
            if not fct_names or not any(fct_names):
                fct_names = ["none"]
            elif len(set(fct_names)) == 1:
                # when functions are the same per config, only use them once
                fct_names = fct_names[0]
                n_fct_per_config = str(len(fct_names))
            else:
                # when functions differ between configs, flatten
                n_fct_per_config = "".join(str(len(x)) for x in fct_names)
                fct_names = tuple(fct_name for fct_names_cfg in fct_names for fct_name in fct_names_cfg)

            part = "__".join(fct_names[:2])
            # n_fct_per_config is always defined here since the "none" case has a single element
            if len(fct_names) > 2:
                part += f"_{n_fct_per_config}_{law.util.create_hash(fct_names[2:])}"
            parts.insert_before("version", label, f"{label}__{part}")

        if self.ml_model_inst:
            parts.insert_before("version", "ml_model", f"ml__{self.ml_model_repr}")

        return parts
class MLModelMixin(ConfigTask, MLModelMixinBase):
    """
    Mixin to include a single, optional :py:class:`~columnflow.ml.MLModel` into config tasks.

    The *ml_model* parameter defaults to the ``"default_ml_model"`` setting of the config.
    Subclasses can set ``allow_empty_ml_model`` to ``False`` to require a model.
    """

    ml_model = luigi.Parameter(
        default=RESOLVE_DEFAULT,
        description="the name of the ML model to be applied; default: value of the "
        "'default_ml_model' config",
    )

    allow_empty_ml_model = True

    exclude_params_repr_empty = {"ml_model"}

    @classmethod
    def resolve_param_values(cls, params: dict[str, Any]) -> dict[str, Any]:
        """
        Resolve the *ml_model* parameter and create the model instance once.

        When *params* contains both ``"analysis_inst"`` and ``"config_inst"``, the model name is
        resolved against the ``"default_ml_model"`` setting of the config. If a model is
        configured afterwards, its instance is stored under ``"ml_model_inst"``.

        :param params: Parameter values to resolve.
        :return: The resolved parameters.
        :raises Exception: If no model is configured and ``allow_empty_ml_model`` is ``False``.
        """
        params = super().resolve_param_values(params)

        # add the default ml model when empty
        if "analysis_inst" in params and "config_inst" in params:
            analysis_inst = params["analysis_inst"]
            config_inst = params["config_inst"]

            params["ml_model"] = cls.resolve_config_default(
                params,
                params.get("ml_model"),
                container=config_inst,
                default_str="default_ml_model",
                multiple=False,
            )

            # initialize it once to trigger its set_config hook which might, in turn,
            # add objects to the config itself
            if params.get("ml_model") not in (None, law.NO_STR):
                params["ml_model_inst"] = cls.get_ml_model_inst(
                    params["ml_model"],
                    analysis_inst,
                    requested_configs=[config_inst],
                    parameters=params["ml_model_settings"],
                )
            elif not cls.allow_empty_ml_model:
                raise Exception(f"no ml_model configured for {cls.task_family}")

        return params

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # get the ML model instance, which remains None when no model is configured
        self.ml_model_inst = None
        if self.ml_model != law.NO_STR:
            self.ml_model_inst = self.get_ml_model_inst(
                self.ml_model,
                self.analysis_inst,
                requested_configs=[self.config_inst],
                parameters=self.ml_model_settings,
            )

    def store_parts(self) -> law.util.InsertableDict:
        """
        Extend the store parts of the superclass by ``"ml_model": "ml__<ml_model_repr>"``,
        inserted before ``"version"``, in case a model instance exists.

        :return: The updated store parts.
        """
        parts = super().store_parts()

        if self.ml_model_inst:
            parts.insert_before("version", "ml_model", f"ml__{self.ml_model_repr}")

        return parts

    def find_keep_columns(self, collection: ColumnCollection) -> set[Route]:
        """
        Finds the columns to keep based on the *collection*.

        Starts from the columns reported by the superclass. When *collection* is
        ``ALL_FROM_ML_EVALUATION`` and a model instance exists, all values of the mapping
        returned by the model's ``produced_columns`` method are added. An empty mapping adds
        nothing.

        :param collection: The collection of columns.
        :return: A set of columns to keep.
        """
        columns = super().find_keep_columns(collection)

        if collection == ColumnCollection.ALL_FROM_ML_EVALUATION and self.ml_model_inst:
            # start from an empty set since the unbound set.union() raises a TypeError when
            # called without arguments, i.e., when the model reports no produced columns
            columns |= set().union(*self.ml_model_inst.produced_columns().values())

        return columns
class MLModelDataMixin(MLModelMixin):
    """
    Variant of :py:class:`MLModelMixin` that requires an ML model and stores its outputs under
    an ``"ml_data"`` part instead of ``"ml_model"``.
    """

    allow_empty_ml_model = False

    def store_parts(self) -> law.util.InsertableDict:
        """
        Replace the ``"ml_model"`` store part of the superclass by an ``"ml_data"`` part.

        The new part is named after the ``store_name`` of the model instance, falling back to
        :py:attr:`ml_model_repr` when the store name is empty.

        :return: The updated store parts.
        """
        parts = super().store_parts()

        # prefer the dedicated store name of the model, if any
        data_name = self.ml_model_inst.store_name
        if not data_name:
            data_name = self.ml_model_repr

        # put the new entry in place of the old one, then drop the latter
        parts.insert_before("ml_model", "ml_data", f"ml__{data_name}")
        parts.pop("ml_model")

        return parts
class MLModelsMixin(ConfigTask):
    """
    Mixin to include multiple, optional :py:class:`~columnflow.ml.MLModel` instances into
    config tasks.

    The *ml_models* parameter defaults to the ``"default_ml_model"`` setting of the config.
    Subclasses can set ``allow_empty_ml_models`` to ``False`` to require at least one model.
    """

    ml_models = law.CSVParameter(
        default=(RESOLVE_DEFAULT,),
        description="comma-separated names of ML models to be applied; default: value of the "
        "'default_ml_model' config",
        brace_expand=True,
        parse_empty=True,
    )

    allow_empty_ml_models = True

    exclude_params_repr_empty = {"ml_models"}

    @property
    def ml_models_repr(self):
        """Returns a string representation of the ML models."""
        ml_models_repr = "__".join([str(model_inst) for model_inst in self.ml_model_insts])
        return ml_models_repr

    @classmethod
    def resolve_param_values(cls, params: dict[str, Any]) -> dict[str, Any]:
        """
        Resolve the *ml_models* parameter and create the model instances once.

        When *params* contains both an analysis and a config instance, the model names are
        resolved against the ``"default_ml_model"`` and ``"ml_model_groups"`` settings of the
        config. If models are configured afterwards, their instances are stored under
        ``"ml_model_insts"``.

        :param params: Parameter values to resolve.
        :return: The resolved parameters.
        :raises Exception: If no models are configured and ``allow_empty_ml_models`` is
            ``False``.
        """
        params = super().resolve_param_values(params)

        analysis_inst = params.get("analysis_inst")
        config_inst = params.get("config_inst")

        if analysis_inst and config_inst:
            # apply ml_model_groups and default_ml_model from the config
            params["ml_models"] = cls.resolve_config_default_and_groups(
                params,
                params.get("ml_models"),
                container=config_inst,
                default_str="default_ml_model",
                groups_str="ml_model_groups",
            )

            # special case: initialize them once to trigger their set_config hook
            if params.get("ml_models"):
                params["ml_model_insts"] = [
                    MLModelMixinBase.get_ml_model_inst(
                        ml_model,
                        analysis_inst,
                        requested_configs=[config_inst],
                    )
                    for ml_model in params["ml_models"]
                ]
            elif not cls.allow_empty_ml_models:
                raise Exception(f"no ml_models configured for {cls.task_family}")

        return params

    @classmethod
    def req_params(cls, inst: law.Task, **kwargs) -> dict:
        """
        Get the required parameters for the task, preferring the ``--ml-models`` set on
        task-level via CLI.

        :param inst: The current task instance.
        :param kwargs: Additional keyword arguments, possibly containing ``"_prefer_cli"``.
        :return: The dictionary of parameters as returned by the superclass.
        """
        # prefer --ml-models set on task-level via cli
        kwargs["_prefer_cli"] = law.util.make_set(kwargs.get("_prefer_cli", [])) | {"ml_models"}

        return super().req_params(inst, **kwargs)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # get the ML model instances
        self.ml_model_insts = [
            MLModelMixinBase.get_ml_model_inst(
                ml_model,
                self.analysis_inst,
                requested_configs=[self.config_inst],
            )
            for ml_model in self.ml_models
        ]

    def store_parts(self) -> law.util.InsertableDict:
        """
        Extend the store parts of the superclass by ``"ml_models": "ml__<ml_models_repr>"``,
        inserted before ``"version"``, in case model instances exist.

        :return: The updated store parts.
        """
        parts = super().store_parts()

        if self.ml_model_insts:
            parts.insert_before("version", "ml_models", f"ml__{self.ml_models_repr}")

        return parts

    def find_keep_columns(self, collection: ColumnCollection) -> set[Route]:
        """
        Finds the columns to keep based on the *collection*.

        Starts from the columns reported by the superclass. When *collection* is
        ``ALL_FROM_ML_EVALUATION``, all values of the mappings returned by the
        ``produced_columns`` methods of all model instances are added. Neither an empty list of
        models nor empty mappings add anything.

        :param collection: The collection of columns.
        :return: A set of columns to keep.
        """
        columns = super().find_keep_columns(collection)

        if collection == ColumnCollection.ALL_FROM_ML_EVALUATION:
            # start from an empty set since the unbound set.union() raises a TypeError when
            # called without arguments, i.e., without models or without produced columns
            columns |= set().union(*(
                model_columns
                for model_inst in self.ml_model_insts
                for model_columns in model_inst.produced_columns().values()
            ))

        return columns
class InferenceModelMixin(ConfigTask):
    """
    Mixin to include a single :py:class:`~columnflow.inference.InferenceModel` into tasks.

    The *inference_model* parameter defaults to the ``"default_inference_model"`` setting of
    the config.
    """

    inference_model = luigi.Parameter(
        default=RESOLVE_DEFAULT,
        description="the name of the inference model to be used; default: value of the "
        "'default_inference_model' config",
    )

    @classmethod
    def resolve_param_values(cls, params: dict[str, Any]) -> dict[str, Any]:
        """
        Resolve the *inference_model* parameter against the ``"default_inference_model"``
        setting of the config, provided that *params* contains a config instance.

        :param params: Parameter values to resolve.
        :return: The resolved parameters.
        """
        params = super().resolve_param_values(params)

        # the default can only be looked up with a config at hand
        config_inst = params.get("config_inst")
        if not config_inst:
            return params

        params["inference_model"] = cls.resolve_config_default(
            params,
            params.get("inference_model"),
            container=config_inst,
            default_str="default_inference_model",
            multiple=False,
        )

        return params

    @classmethod
    def get_inference_model_inst(cls, inference_model: str, config_inst: od.Config) -> InferenceModel:
        """
        Look up the inference model class named *inference_model* and instantiate it with
        *config_inst*.

        :param inference_model: Name of the inference model class.
        :param config_inst: Config instance passed to the model class.
        :return: The new :py:class:`~columnflow.inference.InferenceModel` instance.
        """
        model_cls = InferenceModel.get_cls(inference_model)
        return model_cls(config_inst)

    @classmethod
    def req_params(cls, inst: law.Task, **kwargs) -> dict:
        """
        Get the required parameters for the task, preferring the ``--inference-model`` set on
        task-level via CLI.

        :param inst: The current task instance.
        :param kwargs: Additional keyword arguments, possibly containing ``"_prefer_cli"``.
        :return: The dictionary of parameters as returned by the superclass.
        """
        # extend, rather than replace, preferences that were requested by the caller
        prefer_cli = law.util.make_set(kwargs.get("_prefer_cli", []))
        kwargs["_prefer_cli"] = prefer_cli | {"inference_model"}

        return super().req_params(inst, **kwargs)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # create the inference model instance for the config of this task
        self.inference_model_inst = self.get_inference_model_inst(
            self.inference_model,
            self.config_inst,
        )

    @property
    def inference_model_repr(self):
        """
        String representation of the inference model, i.e., the value of the
        *inference_model* parameter.
        """
        return str(self.inference_model)

    def store_parts(self) -> law.util.InsertableDict:
        """
        Extend the store parts of the superclass by
        ``"inf_model": "inf__<inference_model_repr>"``, inserted before ``"version"``, unless
        the *inference_model* parameter equals ``law.NO_STR``.

        :return: The updated store parts.
        """
        parts = super().store_parts()

        # nothing to add without an inference model
        if self.inference_model == law.NO_STR:
            return parts

        parts.insert_before("version", "inf_model", f"inf__{self.inference_model_repr}")

        return parts
class CategoriesMixin(ConfigTask):
    """
    Mixin to select categories of a config via the *categories* parameter.

    Subclasses can define ``default_categories`` that are used when neither the parameter nor
    the config provide any, and set ``allow_empty_categories`` to accept an empty selection.
    """

    categories = law.CSVParameter(
        default=(),
        description="comma-separated category names or patterns to select; can also be the key of "
        "a mapping defined in 'category_groups' auxiliary data of the config; when empty, uses the "
        "auxiliary data enty 'default_categories' when set; empty default",
        brace_expand=True,
        parse_empty=True,
    )

    default_categories = None
    allow_empty_categories = False

    @classmethod
    def resolve_param_values(cls, params):
        """
        Resolve the *categories* parameter against the config.

        Nothing is done unless *params* contains both ``"config_inst"`` and ``"categories"``.
        Empty categories are replaced by the ``"default_categories"`` auxiliary data of the
        config when set, and by the ``default_categories`` of this class otherwise. The result
        is then resolved via ``find_config_objects``, taking into account the
        ``"category_groups"`` auxiliary data of the config.

        :param params: Parameter values to resolve.
        :return: The resolved parameters.
        :raises ValueError: If no categories were found and ``allow_empty_categories`` is
            ``False``.
        """
        params = super().resolve_param_values(params)

        # both the config and the categories are needed for the resolution
        if "config_inst" not in params or "categories" not in params:
            return params
        config_inst = params["config_inst"]

        # first fallback for empty categories: the config default
        if not params["categories"] and config_inst.x("default_categories", ()):
            params["categories"] = tuple(config_inst.x.default_categories)

        # second fallback for still empty categories: the class-level default
        if not params["categories"] and cls.default_categories:
            params["categories"] = tuple(cls.default_categories)

        # translate names, patterns and groups into categories known to the config
        category_groups = config_inst.x("category_groups", {})
        found = cls.find_config_objects(
            params["categories"],
            config_inst,
            od.Category,
            category_groups,
            deep=True,
        )

        # an empty result is only acceptable when explicitly allowed
        if not (found or cls.allow_empty_categories):
            raise ValueError(f"no categories found matching {params['categories']}")

        params["categories"] = tuple(found)

        return params

    @property
    def categories_repr(self):
        """
        Representation of the selected categories: the category itself when there is exactly
        one, otherwise the number of categories followed by a hash of their sorted values.
        """
        n_categories = len(self.categories)
        if n_categories != 1:
            return f"{n_categories}_{law.util.create_hash(sorted(self.categories))}"

        return self.categories[0]
class VariablesMixin(ConfigTask):
    """
    Mixin that adds a *variables* parameter to a task. The parameter values are resolved in
    :py:meth:`resolve_param_values` by looking up :py:class:`order.Variable` objects through
    ``find_config_objects``. Names that contain the delimiter ``"-"`` are treated as
    multi-dimensional variables whose parts are resolved separately.
    """

    variables = law.CSVParameter(
        default=(),
        description="comma-separated variable names or patterns to select; can also be the key of "
        "a mapping defined in the 'variable_groups' auxiliary data of the config; when empty, uses "
        "all variables of the config; empty default",
        brace_expand=True,
        parse_empty=True,
    )

    # class-level fallback, used when neither the parameter nor the config provides variables
    default_variables = None

    # when True, an empty resolution result does not raise
    allow_empty_variables = False

    # when True, the lookup is performed with strict=False
    allow_missing_variables = False

    @classmethod
    def resolve_param_values(cls, params):
        """
        Resolves the ``"variables"`` entry of *params* and returns the updated *params*.

        An empty value is first replaced by the ``default_variables`` auxiliary entry of the
        config and, if still empty, by the class-level *default_variables*. When a selection
        exists, one-dimensional names are looked up in one go, whereas the parts of
        multi-dimensional names are looked up one by one and recombined into all combinations.
        When the selection is still empty, all variable names of the config are used.

        :raises ValueError: If no variable remains and *allow_empty_variables* is not set.
        """
        params = super().resolve_param_values(params)

        if "config_inst" not in params:
            return params
        config_inst = params["config_inst"]

        # resolve variables
        if "variables" in params:
            # when empty, use the config default
            if not params["variables"] and config_inst.x("default_variables", ()):
                params["variables"] = tuple(config_inst.x.default_variables)

            # when still empty and default variables are defined, use them instead
            if not params["variables"] and cls.default_variables:
                params["variables"] = tuple(cls.default_variables)

            # single lookup helper so that all calls share identical arguments
            def find_variables(names):
                return cls.find_config_objects(
                    names,
                    config_inst,
                    od.Variable,
                    config_inst.x("variable_groups", {}),
                    strict=not cls.allow_missing_variables,
                )

            # resolve them
            if params["variables"]:
                # first, split into single- and multi-dimensional variables
                single_vars = []
                multi_var_parts = []
                for variable in params["variables"]:
                    parts = cls.split_multi_variable(variable)
                    if len(parts) == 1:
                        single_vars.append(variable)
                    else:
                        multi_var_parts.append(parts)

                # resolve single variables
                variables = find_variables(single_vars)

                # for each multi-variable, resolve each part separately and create the full
                # combinatorics of all possibly pattern-resolved parts
                for parts in multi_var_parts:
                    resolved_parts = [find_variables(part) for part in parts]
                    variables.extend([
                        cls.join_multi_variable(_parts)
                        for _parts in itertools.product(*resolved_parts)
                    ])
            else:
                # fallback to using all known variables
                variables = config_inst.variables.names()

            # complain when no variables were found
            if not variables and not cls.allow_empty_variables:
                raise ValueError(f"no variables found matching {params['variables']}")

            params["variables"] = tuple(variables)

        return params

    @classmethod
    def split_multi_variable(cls, variable: str) -> tuple[str, ...]:
        """
        Splits a multi-dimensional *variable* given in the format ``"var_a[-var_b[-...]]"`` into
        separate variable names using a delimiter (``"-"``) and returns a tuple.
        """
        return tuple(variable.split("-"))

    @classmethod
    def join_multi_variable(cls, variables: Sequence[str]) -> str:
        """
        Joins the name of multiple *variables* using a delimiter (``"-"``) into a single string
        that represents a multi-dimensional variable and returns it.
        """
        return "-".join(map(str, variables))

    def __init__(self, *args, **kwargs):
        """
        Forwards all arguments to the parent constructor and builds *variable_tuples*, a mapping
        from each selected variable name to the tuple of its parts.
        """
        super().__init__(*args, **kwargs)

        # map every selected name to its parts (a 1-tuple for one-dimensional variables)
        self.variable_tuples = {
            var_name: self.split_multi_variable(var_name)
            for var_name in self.variables
        }

    @property
    def variables_repr(self):
        """
        Representation of the selected variables: the single entry itself when exactly one is
        selected, otherwise the number of entries followed by a hash of the sorted entries.
        """
        if len(self.variables) == 1:
            return self.variables[0]

        return f"{len(self.variables)}_{law.util.create_hash(sorted(self.variables))}"
class DatasetsProcessesMixin(ConfigTask):
    """
    Mixin that adds *datasets* and *processes* parameters to a task. Both are resolved in
    :py:meth:`resolve_param_values` through ``find_config_objects``, and the corresponding
    objects of the config are stored under ``"dataset_insts"`` and ``"process_insts"``.
    """

    datasets = law.CSVParameter(
        default=(),
        description="comma-separated dataset names or patters to select; can also be the key of a "
        "mapping defined in the 'dataset_groups' auxiliary data of the config; when empty, uses "
        "all datasets registered in the config that contain any of the selected --processes; empty "
        "default",
        brace_expand=True,
        parse_empty=True,
    )
    processes = law.CSVParameter(
        default=(),
        description="comma-separated process names or patterns for filtering processes; can also "
        "be the key of a mapping defined in the 'process_groups' auxiliary data of the config; "
        "uses all processes of the config when empty; empty default",
        brace_expand=True,
        parse_empty=True,
    )

    # when True, an empty resolution result for the respective parameter does not raise
    allow_empty_datasets = False
    allow_empty_processes = False

    @classmethod
    def resolve_param_values(cls, params):
        """
        Resolves the ``"processes"`` and ``"datasets"`` entries of *params* and returns the
        updated *params*.

        Empty processes default to all process names of the config. Empty datasets default to
        all datasets of the config that have at least one of the selected processes or any
        process found by walking through them. When datasets are empty and no processes are
        part of *params*, the dataset selection stays empty.

        :raises ValueError: If processes or datasets resolve to nothing and the respective
            *allow_empty_processes* or *allow_empty_datasets* flag is not set.
        """
        params = super().resolve_param_values(params)

        if "config_inst" not in params:
            return params
        config_inst = params["config_inst"]

        # resolve processes
        if "processes" in params:
            if params["processes"]:
                processes = cls.find_config_objects(
                    params["processes"],
                    config_inst,
                    od.Process,
                    config_inst.x("process_groups", {}),
                    deep=True,
                )
            else:
                processes = config_inst.processes.names()

            # complain when no processes were found
            if not processes and not cls.allow_empty_processes:
                raise ValueError(f"no processes found matching {params['processes']}")

            params["processes"] = tuple(processes)
            params["process_insts"] = [config_inst.get_process(p) for p in params["processes"]]

        # resolve datasets
        if "datasets" in params:
            if params["datasets"]:
                datasets = cls.find_config_objects(
                    params["datasets"],
                    config_inst,
                    od.Dataset,
                    config_inst.x("dataset_groups", {}),
                )
            elif "processes" in params:
                # pick all datasets that contain any of the requested (sub) processes
                sub_process_insts = [
                    proc
                    for process_inst in map(config_inst.get_process, params["processes"])
                    for proc, _, _ in process_inst.walk_processes(include_self=True)
                ]
                datasets = [
                    dataset_inst.name
                    for dataset_inst in config_inst.datasets
                    if any(map(dataset_inst.has_process, sub_process_insts))
                ]
            else:
                # neither datasets nor processes to derive them from; without this fallback the
                # check below would fail with an UnboundLocalError instead of a meaningful error
                datasets = []

            # complain when no datasets were found
            if not datasets and not cls.allow_empty_datasets:
                raise ValueError(f"no datasets found matching {params['datasets']}")

            params["datasets"] = tuple(datasets)
            params["dataset_insts"] = [config_inst.get_dataset(d) for d in params["datasets"]]

        return params

    @classmethod
    def get_known_shifts(cls, config_inst, params):
        """
        Returns the two shift sets of the parent class, with the keys of the ``info`` mapping of
        every selected MC dataset added to the upstream set.
        """
        shifts, upstream_shifts = super().get_known_shifts(config_inst, params)

        # add shifts of all datasets to upstream ones
        for dataset_inst in params.get("dataset_insts") or []:
            if dataset_inst.is_mc:
                upstream_shifts |= set(dataset_inst.info.keys())

        return shifts, upstream_shifts

    @property
    def datasets_repr(self):
        """
        Representation of the selected datasets: the single entry itself when exactly one is
        selected, otherwise the number of entries followed by a hash of the sorted entries.
        """
        if len(self.datasets) == 1:
            return self.datasets[0]

        return f"{len(self.datasets)}_{law.util.create_hash(sorted(self.datasets))}"

    @property
    def processes_repr(self):
        """
        Representation of the selected processes: the single entry itself when exactly one is
        selected, otherwise the number of entries followed by a hash of the entries in their
        given order.
        """
        if len(self.processes) == 1:
            return self.processes[0]

        # NOTE: unlike datasets_repr, the entries are hashed without sorting
        return f"{len(self.processes)}_{law.util.create_hash(self.processes)}"
class ShiftSourcesMixin(ConfigTask):
    """
    Mixin that adds a *shift_sources* parameter to a task. Sources are shift names without a
    direction; they are expanded to ``"<source>_up"`` and ``"<source>_down"`` shifts for the
    lookup through ``find_config_objects`` and reduced back to sources afterwards.
    """

    shift_sources = law.CSVParameter(
        default=(),
        description="comma-separated shift source names (without direction) or patterns to select; "
        "can also be the key of a mapping defined in the 'shift_groups' auxiliary data of the "
        "config; default: ()",
        brace_expand=True,
        parse_empty=True,
    )

    # when True, an empty resolution result does not raise
    allow_empty_shift_sources = False

    @classmethod
    def resolve_param_values(cls, params):
        """
        Resolves the ``"shift_sources"`` entry of *params* and returns the updated *params*.

        :raises ValueError: If no shift is found and *allow_empty_shift_sources* is not set.
        """
        params = super().resolve_param_values(params)

        if "config_inst" not in params:
            return params
        config_inst = params["config_inst"]

        # resolve shift sources
        if "shift_sources" in params:
            # convert to full shift first to do the object finding
            shifts = cls.find_config_objects(
                cls.expand_shift_sources(params["shift_sources"]),
                config_inst,
                od.Shift,
                config_inst.x("shift_groups", {}),
            )

            # complain when no shifts were found
            if not shifts and not cls.allow_empty_shift_sources:
                raise ValueError(f"no shifts found matching {params['shift_sources']}")

            # convert back to sources
            params["shift_sources"] = tuple(cls.reduce_shifts(shifts))

        return params

    @classmethod
    def expand_shift_sources(cls, sources: Sequence[str] | set[str]) -> list[str]:
        """
        Returns a list with the two entries ``"<source>_up"`` and ``"<source>_down"`` for each
        entry in *sources*, in iteration order.
        """
        return [
            f"{source}_{direction}"
            for source in sources
            for direction in ("up", "down")
        ]

    @classmethod
    def reduce_shifts(cls, shifts: Sequence[str] | set[str]) -> list[str]:
        """
        Returns the unique sources of *shifts*, i.e., the first item that
        ``od.Shift.split_name`` yields for each of them, in order of first occurrence.
        """
        # dict keys preserve insertion order, whereas going through a set would make the order
        # of string entries depend on the per-process hash seed
        return list(dict.fromkeys(od.Shift.split_name(shift)[0] for shift in shifts))

    def __init__(self, *args, **kwargs):
        """
        Forwards all arguments to the parent constructor and stores the expanded up and down
        shifts of all selected sources in the attribute *shifts*.
        """
        super().__init__(*args, **kwargs)

        self.shifts = self.expand_shift_sources(self.shift_sources)

    @property
    def shift_sources_repr(self):
        """
        Representation of the selected shift sources: the single entry itself when exactly one
        is selected, otherwise the number of entries followed by a hash of the sorted entries.
        """
        if len(self.shift_sources) == 1:
            return self.shift_sources[0]

        return f"{len(self.shift_sources)}_{law.util.create_hash(sorted(self.shift_sources))}"
class WeightProducerMixin(ConfigTask):
    """
    Mixin that adds a *weight_producer* parameter to a task and gives access to the
    corresponding :py:class:`~columnflow.weight.WeightProducer` instance through the
    *weight_producer_inst* property.
    """

    weight_producer = luigi.Parameter(
        default=RESOLVE_DEFAULT,
        description="the name of the weight producer to be used; default: value of the "
        "'default_weight_producer' config",
    )

    # decides whether the task itself runs the weight producer and implements its shifts
    register_weight_producer_sandbox = False
    register_weight_producer_shifts = False

    @classmethod
    def get_weight_producer_inst(
        cls,
        weight_producer: str,
        kwargs: dict | None = None,
    ) -> WeightProducer:
        """
        Looks up the weight producer class named *weight_producer* and returns a new instance.
        When *kwargs* are given, they are passed through ``get_weight_producer_kwargs`` and the
        result is handed to the constructor as *inst_dict*; otherwise *inst_dict* is *None*.

        :raises RuntimeError: If the weight producer class is not exposed.
        """
        producer_cls = WeightProducer.get_cls(weight_producer)

        # only exposed weight producers can be used by tasks
        if not producer_cls.exposed:
            raise RuntimeError(
                f"cannot use unexposed weight producer '{weight_producer}' in {cls.__name__}",
            )

        if kwargs:
            inst_dict = cls.get_weight_producer_kwargs(**kwargs)
        else:
            inst_dict = None

        return producer_cls(inst_dict=inst_dict)

    @classmethod
    def resolve_param_values(cls, params: dict[str, Any]) -> dict[str, Any]:
        """
        Resolves the ``"weight_producer"`` entry of *params* through ``resolve_config_default``
        and stores a matching instance under ``"weight_producer_inst"``. Nothing is resolved
        when *params* has no truthy ``"config_inst"``.

        :raises Exception: If the resolved weight producer is *None*.
        """
        params = super().resolve_param_values(params)

        config_inst = params.get("config_inst")
        if not config_inst:
            return params

        # add the default weight producer when empty
        resolved = cls.resolve_config_default(
            params,
            params.get("weight_producer"),
            container=config_inst,
            default_str="default_weight_producer",
            multiple=False,
        )
        params["weight_producer"] = resolved
        if resolved is None:
            raise Exception(f"no weight producer configured for task {cls.task_family}")

        params["weight_producer_inst"] = cls.get_weight_producer_inst(resolved, params)

        return params

    @classmethod
    def get_known_shifts(
        cls,
        config_inst: od.Config,
        params: dict[str, Any],
    ) -> tuple[set[str], set[str]]:
        """
        Returns the two shift sets of the parent class, extended by *all_shifts* of the weight
        producer instance found in *params*. The shifts are added to the first set when
        *register_weight_producer_shifts* is set, and to the upstream set otherwise.
        """
        shifts, upstream_shifts = super().get_known_shifts(config_inst, params)

        # without a weight producer instance there is nothing to add
        producer_inst = params.get("weight_producer_inst")
        if not producer_inst:
            return shifts, upstream_shifts

        if cls.register_weight_producer_shifts:
            shifts |= producer_inst.all_shifts
        else:
            upstream_shifts |= producer_inst.all_shifts

        return shifts, upstream_shifts

    def __init__(self, *args, **kwargs) -> None:
        """
        Forwards all arguments to the parent constructor and sets up the instance cache used by
        the *weight_producer_inst* property.
        """
        super().__init__(*args, **kwargs)

        # cache for weight producer inst
        self._weight_producer_inst = None

    @property
    def weight_producer_inst(self) -> WeightProducer:
        """
        The weight producer instance of this task, created on first access and cached. When
        *register_weight_producer_sandbox* is set and the instance reports a sandbox, the task's
        sandbox is overwritten on creation.
        """
        # serve from the cache when possible
        if self._weight_producer_inst is not None:
            return self._weight_producer_inst

        self._weight_producer_inst = self.get_weight_producer_inst(
            self.weight_producer,
            {"task": self},
        )

        # overwrite the sandbox when set
        if self.register_weight_producer_sandbox:
            producer_sandbox = self._weight_producer_inst.get_sandbox()
            if producer_sandbox:
                self.sandbox = producer_sandbox
                # rebuild the sandbox inst when already initialized
                if self._sandbox_initialized:
                    self._initialize_sandbox(force=True)

        return self._weight_producer_inst

    @property
    def weight_producer_repr(self) -> str:
        """
        String representation of the weight producer instance.
        """
        producer_inst = self.weight_producer_inst
        return str(producer_inst)

    def store_parts(self: WeightProducerMixin) -> law.util.InsertableDict[str, str]:
        """
        Extends the store parts of the parent class by a ``"weightprod"`` entry that is placed
        right before the ``"version"`` part.
        """
        store_parts = super().store_parts()
        weight_part = f"weight__{self.weight_producer_repr}"
        store_parts.insert_before("version", "weightprod", weight_part)
        return store_parts
class ChunkedIOMixin(AnalysisTask):
    """
    Mixin for tasks that iterate over their inputs in chunks using a
    :py:class:`~columnflow.columnar_util.ChunkedIOHandler`. It also provides optional checks
    for non-finite output values and overlapping input columns.
    """

    check_finite_output = luigi.BoolParameter(
        default=False,
        significant=False,
        description="when True, checks whether output arrays only contain finite values before "
        "writing to them to file",
    )
    check_overlapping_inputs = luigi.BoolParameter(
        default=False,
        significant=False,
        description="when True, checks whether columns if input arrays overlap in at least one field",
    )

    exclude_params_req = {"check_finite_output", "check_overlapping_inputs"}

    # define default chunk and pool sizes that can be adjusted per inheriting task
    default_chunk_size = ChunkedIOHandler.default_chunk_size
    default_pool_size = ChunkedIOHandler.default_pool_size

    @classmethod
    def raise_if_not_finite(cls, ak_array: ak.Array) -> None:
        """
        Checks whether all values in array *ak_array* are finite.

        The check is performed using the :external+numpy:py:func:`numpy.isfinite` function.

        :param ak_array: Array with events to check.
        :raises ValueError: If any value in *ak_array* is not finite.
        """
        import numpy as np
        from columnflow.columnar_util import get_ak_routes

        for route in get_ak_routes(ak_array):
            if ak.any(~np.isfinite(ak.flatten(route.apply(ak_array), axis=None))):
                raise ValueError(
                    f"found one or more non-finite values in column '{route.column}' "
                    f"of array {ak_array}",
                )

    @classmethod
    def raise_if_overlapping(cls, ak_arrays: Sequence[ak.Array]) -> None:
        """
        Checks whether fields of *ak_arrays* overlap.

        :param ak_arrays: Arrays with fields to check.
        :raises ValueError: If at least one overlap is found.
        """
        from columnflow.columnar_util import get_ak_routes

        # when less than two arrays are given, there cannot be any overlap
        if len(ak_arrays) < 2:
            return

        # determine overlapping routes
        counts = Counter(sum(map(get_ak_routes, ak_arrays), []))
        overlapping_routes = [r for r, c in counts.items() if c > 1]

        # raise
        if overlapping_routes:
            # routes are objects rather than plain strings, so join their column names explicitly
            overlapping_columns = ",".join(str(r.column) for r in overlapping_routes)
            raise ValueError(
                f"found {len(overlapping_routes)} overlapping columns across {len(ak_arrays)} "
                f"arrays: {overlapping_columns}",
            )

    def iter_chunked_io(self, *args, **kwargs):
        """
        Generator that yields the objects provided by a
        :py:class:`~columnflow.columnar_util.ChunkedIOHandler` while tracking progress.

        When the only positional argument is a handler instance, it is used as is. Otherwise, a
        new handler is created from all *args* and *kwargs*, where missing ``chunk_size`` and
        ``pool_size`` values are read from the ``"analysis"`` section of the law config with the
        class-level defaults as fallback. While iterating, the handler is accessible as
        *chunked_io*. After the loop finishes, the accumulated time that the caller spent
        between consecutive chunks is published.
        """
        # get the chunked io handler from first arg or create a new one with all args
        if len(args) == 1 and isinstance(args[0], ChunkedIOHandler):
            handler = args[0]
        else:
            # default chunk and pool sizes
            for key in ["chunk_size", "pool_size"]:
                if kwargs.get(key) is None:
                    # get the default from the config, defaulting to the class default
                    kwargs[key] = law.config.get_expanded_int(
                        "analysis",
                        f"{self.task_family}__chunked_io_{key}",
                        getattr(self, f"default_{key}"),
                    )
                # when still not set, remove it and let the handler decide using its defaults
                if kwargs.get(key) is None:
                    kwargs.pop(key, None)
            # create the handler
            handler = ChunkedIOHandler(*args, **kwargs)

        # iterate in the handler context
        with handler:
            self.chunked_io = handler
            msg = f"iterate through {handler.n_entries:_} events in {handler.n_chunks} chunks ..."
            try:
                # measure runtimes excluding IO
                loop_durations = []
                for obj in self.iter_progress(handler, max(handler.n_chunks, 1), msg=msg):
                    t1 = time.perf_counter()

                    # yield the object provided by the handler
                    yield obj

                    # save the runtime
                    loop_durations.append(time.perf_counter() - t1)

                # print runtimes
                self.publish_message(
                    "event processing in loop body took "
                    f"{law.util.human_duration(seconds=sum(loop_durations))}",
                )

            finally:
                self.chunked_io = None

        # eager cleanup
        del handler
class HistHookMixin(ConfigTask):
    """
    Mixin that adds a *hist_hooks* parameter to a task, naming functions stored in the
    ``hist_hooks`` auxiliary entry of the config that are applied to histograms.
    """

    hist_hooks = law.CSVParameter(
        default=(),
        description="names of functions in the config's auxiliary dictionary 'hist_hooks' that are "
        "invoked before plotting to update a potentially nested dictionary of histograms; "
        "default: empty",
    )

    def invoke_hist_hooks(self, hists: dict) -> dict:
        """
        Applies all selected hist hooks to *hists*, one after another in the given order, and
        returns the result of the last one. Each hook is called with the task and the current
        histograms. Entries that are *None*, empty or ``law.NO_STR`` are skipped.

        :raises KeyError: If a hook name is missing in the ``hist_hooks`` auxiliary entry.
        :raises TypeError: If the object registered under a hook name is not callable.
        """
        if not self.hist_hooks:
            return hists

        skipped = (None, "", law.NO_STR)
        for hook_name in self.hist_hooks:
            if hook_name in skipped:
                continue

            # fetch the registered hooks from the config instance
            registered = self.config_inst.x("hist_hooks", {})
            if hook_name not in registered:
                raise KeyError(
                    f"hist hook '{hook_name}' not found in 'hist_hooks' auxiliary entry of config",
                )

            hook_func = registered[hook_name]
            if not callable(hook_func):
                raise TypeError(f"hist hook '{hook_name}' is not callable: {hook_func}")

            # announce and run the hook
            self.publish_message(f"invoking hist hook '{hook_name}'")
            hists = hook_func(self, hists)

        return hists

    @property
    def hist_hooks_repr(self) -> str:
        """
        Representation of the selected hist hooks: the first five valid names joined by ``"__"``,
        followed by a hash of all remaining ones, if any.
        """
        skipped = (None, "", law.NO_STR)
        valid_hooks = [name for name in self.hist_hooks if name not in skipped]

        # spell out the leading names, condense everything beyond into a hash
        leading, remaining = valid_hooks[:5], valid_hooks[5:]
        hooks_repr = "__".join(leading)
        if remaining:
            hooks_repr += f"__{law.util.create_hash(remaining)}"

        return hooks_repr