# Source code for columnflow.calibration

# coding: utf-8

"""
Object and event calibration tools.
"""

from __future__ import annotations

import copy
import inspect

import law

from columnflow.columnar_util import TaskArrayFunction
from columnflow.util import DerivableMeta, UNSET
from columnflow.types import Callable, Sequence, Any, UNSET_TYPE


class TaskArrayFunctionWithCalibratorRequirements(TaskArrayFunction):
    """
    :py:class:`TaskArrayFunction` that can declare other calibrators as requirements.

    The names in *require_calibrators* are requested in :py:meth:`requires_func` via
    :py:meth:`_req_calibrator`. In :py:meth:`setup_func`, the ``"columns"`` entry of each corresponding
    input is registered in *reader_targets* under the key ``"required_calibrator_<name>"``.
    """

    # names of other calibrators to require; a value of None results in an empty list being forwarded
    # to the base constructor
    require_calibrators: Sequence[str] | set[str] | None = None

    def __init__(self, *args, **kwargs):
        """
        Normalizes the ``require_calibrators`` keyword argument and forwards everything to the base
        constructor.

        An explicitly passed value wins (a falsy one becomes an empty list). Otherwise a class-level
        list or tuple is shallow-copied, and any other class-level value is not forwarded at all.
        """
        cls_value = self.__class__.require_calibrators

        if "require_calibrators" in kwargs or cls_value is None:
            kwargs["require_calibrators"] = kwargs.get("require_calibrators") or []
        elif isinstance(cls_value, (list, tuple)):
            # copy so that the instance does not share the container with the class
            kwargs["require_calibrators"] = copy.copy(cls_value)

        super().__init__(*args, **kwargs)

    def _req_calibrator(self, task: law.Task, calibrator: str) -> Any:
        """
        Hook to customize how a required *calibrator* is requested for *task*.

        :param task: The task that the requirement is created for.
        :param calibrator: Name of the required calibrator.
        :return: The result of ``CalibrateEvents.req_other_calibrator``.
        """
        # local import, NOTE(review): presumably to avoid a circular import at module level - confirm
        from columnflow.tasks.calibration import CalibrateEvents

        return CalibrateEvents.req_other_calibrator(task, calibrator=calibrator)

    def requires_func(self, task: law.Task, reqs: dict, **kwargs) -> None:
        """
        Adds the requirements of all required calibrators to *reqs* under ``"required_calibrators"``.

        :param task: The task that requirements are defined for.
        :param reqs: Dictionary of requirements, updated in place.
        """
        super().requires_func(task=task, reqs=reqs, **kwargs)

        # no requirements for workflows in pilot mode
        if callable(getattr(task, "is_workflow", None)) and task.is_workflow() and getattr(task, "pilot", False):
            return

        # add required calibrators when set
        if (calibs := self.require_calibrators):
            reqs["required_calibrators"] = {
                calib: self._req_calibrator(task, calib)
                for calib in law.util.make_unique(calibs)
            }

    def setup_func(
        self,
        task: law.Task,
        reqs: dict,
        inputs: dict,
        reader_targets: law.util.InsertableDict,
        **kwargs,
    ) -> None:
        """
        Registers the ``"columns"`` input of every required calibrator in *reader_targets*.

        :param task: The task that is set up.
        :param reqs: Dictionary of requirements.
        :param inputs: Dictionary of inputs, optionally containing ``"required_calibrators"``.
        :param reader_targets: Mapping of reader targets, updated in place.
        """
        super().setup_func(task=task, reqs=reqs, inputs=inputs, reader_targets=reader_targets, **kwargs)

        if "required_calibrators" in inputs:
            for calib, inp in inputs["required_calibrators"].items():
                reader_targets[f"required_calibrator_{calib}"] = inp["columns"]
class Calibrator(TaskArrayFunctionWithCalibratorRequirements):
    """
    Base class for all calibrators.
    """

    exposed = True

    # register attributes for arguments accepted by decorator
    mc_only: bool = False
    data_only: bool = False

    @classmethod
    def calibrator(
        cls,
        func: Callable | None = None,
        bases: tuple = (),
        mc_only: bool | UNSET_TYPE = UNSET,
        data_only: bool | UNSET_TYPE = UNSET,
        require_calibrators: Sequence[str] | set[str] | None | UNSET_TYPE = UNSET,
        **kwargs,
    ) -> DerivableMeta | Callable:
        """
        Decorator for creating a new :py:class:`~.Calibrator` subclass with additional, optional *bases*
        and attaching the decorated function to it as ``call_func``.

        When *mc_only* (*data_only*) is *True*, the calibrator is skipped and not considered by other
        calibrators, selectors and producers in case they are evaluated on a :py:class:`order.Dataset`
        (using the :py:attr:`dataset_inst` attribute) whose ``is_mc`` (``is_data``) attribute is *False*.

        All additional *kwargs* are added as class members of the new subclass. A ``cls_name`` entry in
        *kwargs* is used as the name of the new class, which otherwise defaults to the name of *func*.

        :param func: Function to be wrapped and integrated into new :py:class:`Calibrator` class.
        :param bases: Additional bases for the new :py:class:`Calibrator`.
        :param mc_only: Boolean flag indicating that this :py:class:`Calibrator` should only run on Monte
            Carlo simulation and skipped for real data.
        :param data_only: Boolean flag indicating that this :py:class:`Calibrator` should only run on real
            data and skipped for Monte Carlo simulation.
        :param require_calibrators: Sequence of names of other calibrators to add to the requirements.
        :raises Exception: During class derivation, when both *mc_only* and *data_only* are set, or when
            either of them is combined with a custom ``skip_func``.
        :return: New :py:class:`Calibrator` subclass when *func* is given, otherwise the decorator that
            creates it.
        """
        def decorator(func: Callable) -> DerivableMeta:
            # create the class dict, only adding decorator arguments that were explicitly set
            cls_dict = {**kwargs, "call_func": func}
            if mc_only is not UNSET:
                cls_dict["mc_only"] = mc_only
            if data_only is not UNSET:
                cls_dict["data_only"] = data_only
            if require_calibrators is not UNSET:
                cls_dict["require_calibrators"] = require_calibrators

            # get the module from the frame that called this decorator
            # NOTE(review): when *func* is passed directly (bare decorator use), this function is invoked
            # from within ``calibrator`` itself, so the frame at index 1 is that of ``calibrator`` rather
            # than the decorating module - confirm that this is intended
            frame = inspect.stack()[1]
            module = inspect.getmodule(frame[0])

            # get the calibrator name
            cls_name = cls_dict.pop("cls_name", func.__name__)

            # hook to update the class dict during class derivation
            def update_cls_dict(cls_name, cls_dict, get_attr):
                only_mc = get_attr("mc_only")
                only_data = get_attr("data_only")

                # the two flags are mutually exclusive and cannot be combined with a custom skip_func
                if only_mc and only_data:
                    raise Exception(f"calibrator {cls_name} received both mc_only and data_only")
                if (only_mc or only_data) and cls_dict.get("skip_func"):
                    raise Exception(
                        f"calibrator {cls_name} received custom skip_func, but either mc_only or data_only are set",
                    )

                # optionally add skip function
                if "skip_func" not in cls_dict:
                    def skip_func(self, **kwargs) -> bool:
                        # check mc_only and data_only
                        if only_mc and not self.dataset_inst.is_mc:
                            return True
                        if only_data and not self.dataset_inst.is_data:
                            return True

                        # in all other cases, do not skip
                        return False

                    cls_dict["skip_func"] = skip_func

                return cls_dict

            cls_dict["update_cls_dict"] = update_cls_dict

            # create the subclass
            subclass = cls.derive(cls_name, bases=bases, cls_dict=cls_dict, module=module)

            return subclass

        return decorator(func) if func else decorator
# shorthand: expose the decorator at module level so it can be imported directly from this module
calibrator = Calibrator.calibrator