# coding: utf-8
"""
Object and event calibration tools.
"""
from __future__ import annotations
import copy
import inspect
import law
from columnflow.columnar_util import TaskArrayFunction
from columnflow.util import DerivableMeta, UNSET
from columnflow.types import Callable, Sequence, Any, UNSET_TYPE
class TaskArrayFunctionWithCalibratorRequirements(TaskArrayFunction):
    """
    :py:class:`TaskArrayFunction` that can declare other calibrators as requirements.

    The names listed in *require_calibrators* are turned into task requirements by :py:meth:`requires_func`, and
    the ``"columns"`` entry of each corresponding input is registered as an additional reader target by
    :py:meth:`setup_func`.
    """

    # names of other calibrators that are required; None is treated like an empty list
    require_calibrators: Sequence[str] | set[str] | None = None

    def __init__(self, *args, **kwargs):
        """
        Resolves the *require_calibrators* value and forwards all arguments to the parent constructor.

        An explicitly passed *require_calibrators* keyword argument takes precedence, with *None* or an empty value
        being replaced by a new list. Otherwise, a class-level list, tuple or set is shallow-copied so that the
        object stored on the class itself is not the one handed to the parent constructor.
        """
        cls_calibs = self.__class__.require_calibrators

        if "require_calibrators" in kwargs or cls_calibs is None:
            kwargs["require_calibrators"] = kwargs.get("require_calibrators") or []
        elif isinstance(cls_calibs, (list, tuple, set)):
            # sets are allowed by the annotation and are copied just like lists and tuples
            kwargs["require_calibrators"] = copy.copy(cls_calibs)

        super().__init__(*args, **kwargs)

    def _unique_calibrators(self) -> list[str] | tuple[str, ...]:
        """
        Returns the names in *require_calibrators* without duplicates, or an empty list when nothing is required.
        """
        calibs = self.require_calibrators
        if not calibs:
            return []

        if isinstance(calibs, str):
            # a bare string denotes a single calibrator name, not a sequence of characters
            calibs = [calibs]
        elif isinstance(calibs, (set, frozenset)):
            # sets have no defined order, so sort them to obtain a reproducible order of requirements
            calibs = sorted(calibs)
        elif not isinstance(calibs, (list, tuple)):
            calibs = list(calibs)

        # lists and tuples are handled exactly as before, duplicates are dropped in order of appearance
        return law.util.make_unique(calibs)

    def _req_calibrator(self, task: law.Task, calibrator: str) -> Any:
        """
        Hook to customize how required calibrators are requested.

        :param task: The task for which the requirement is created.
        :param calibrator: Name of the required calibrator.
        :return: The return value of ``CalibrateEvents.req_other_calibrator`` for *task* and *calibrator*.
        """
        # NOTE(review): imported locally, presumably to avoid a circular import -- confirm
        from columnflow.tasks.calibration import CalibrateEvents

        return CalibrateEvents.req_other_calibrator(task, calibrator=calibrator)

    def requires_func(self, task: law.Task, reqs: dict, **kwargs) -> None:
        """
        Adds the requirements of all required calibrators to *reqs*.

        After invoking the parent implementation, a dictionary mapping each unique calibrator name to the
        requirement returned by :py:meth:`_req_calibrator` is stored in ``reqs["required_calibrators"]``. Nothing is
        added for workflows in pilot mode or when no calibrators are required.

        :param task: The task requesting the requirements.
        :param reqs: Dictionary of requirements, updated in place.
        """
        super().requires_func(task=task, reqs=reqs, **kwargs)

        # no requirements for workflows in pilot mode
        if callable(getattr(task, "is_workflow", None)) and task.is_workflow() and getattr(task, "pilot", False):
            return

        # add required calibrators when set
        if (calibs := self._unique_calibrators()):
            reqs["required_calibrators"] = {
                calib: self._req_calibrator(task, calib)
                for calib in calibs
            }

    def setup_func(
        self,
        task: law.Task,
        reqs: dict,
        inputs: dict,
        reader_targets: law.util.InsertableDict,
        **kwargs,
    ) -> None:
        """
        Registers the columns of all required calibrators as additional reader targets.

        After invoking the parent implementation, the ``"columns"`` entry of every input found in
        ``inputs["required_calibrators"]`` is added to *reader_targets* with the key
        ``"required_calibrator_<name>"``.

        :param task: The task invoking the setup.
        :param reqs: Dictionary of resolved requirements.
        :param inputs: Dictionary of inputs, optionally containing a ``"required_calibrators"`` entry.
        :param reader_targets: Mapping of reader targets, updated in place.
        """
        super().setup_func(task=task, reqs=reqs, inputs=inputs, reader_targets=reader_targets, **kwargs)

        if "required_calibrators" in inputs:
            for calib, inp in inputs["required_calibrators"].items():
                reader_targets[f"required_calibrator_{calib}"] = inp["columns"]
class Calibrator(TaskArrayFunctionWithCalibratorRequirements):
    """
    Base class for all calibrators.
    """

    exposed = True

    # register attributes for arguments accepted by decorator
    mc_only: bool = False
    data_only: bool = False

    @classmethod
    def calibrator(
        cls,
        func: Callable | None = None,
        bases: tuple = (),
        mc_only: bool | UNSET_TYPE = UNSET,
        data_only: bool | UNSET_TYPE = UNSET,
        require_calibrators: Sequence[str] | set[str] | None | UNSET_TYPE = UNSET,
        **kwargs,
    ) -> DerivableMeta | Callable:
        """
        Decorator for creating a new :py:class:`~.Calibrator` subclass with additional, optional *bases* and attaching
        the decorated function to it as ``call_func``.

        The decorator can be applied both directly (``@calibrator``) and with arguments (``@calibrator(...)``). In
        both cases, the module of the code applying the decorator is passed on as the module of the new subclass.

        When *mc_only* (*data_only*) is *True*, the calibrator is skipped and not considered by other calibrators,
        selectors and producers in case they are evaluated on a :py:class:`order.Dataset` (using the
        :py:attr:`dataset_inst` attribute) whose ``is_mc`` (``is_data``) attribute is *False*.

        All additional *kwargs* are added as class members of the new subclasses. A ``cls_name`` entry, when given,
        is used as the name of the new subclass instead of the name of *func*.

        :param func: Function to be wrapped and integrated into new :py:class:`Calibrator` class.
        :param bases: Additional bases for the new :py:class:`Calibrator`.
        :param mc_only: Boolean flag indicating that this :py:class:`Calibrator` should only run on Monte Carlo
            simulation and skipped for real data.
        :param data_only: Boolean flag indicating that this :py:class:`Calibrator` should only run on real data and
            skipped for Monte Carlo simulation.
        :param require_calibrators: Sequence of names of other calibrators to add to the requirements.
        :raises Exception: When both *mc_only* and *data_only* are set, or when one of them is set together with a
            custom ``skip_func``.
        :return: New :py:class:`Calibrator` subclass, or the decorator creating it when *func* is not given.
        """
        def caller_module():
            # frame 0 is this helper, frame 1 the function invoking it and frame 2 the caller of that function,
            # i.e., the code that applies the decorator; no source context is needed, so skip reading it
            return inspect.getmodule(inspect.stack(context=0)[2].frame)

        def derive(func: Callable, module) -> DerivableMeta:
            # create the class dict
            cls_dict = {**kwargs, "call_func": func}
            if mc_only is not UNSET:
                cls_dict["mc_only"] = mc_only
            if data_only is not UNSET:
                cls_dict["data_only"] = data_only
            if require_calibrators is not UNSET:
                cls_dict["require_calibrators"] = require_calibrators

            # get the calibrator name
            cls_name = cls_dict.pop("cls_name", func.__name__)

            # hook to update the class dict during class derivation
            def update_cls_dict(cls_name, cls_dict, get_attr):
                mc_only = get_attr("mc_only")
                data_only = get_attr("data_only")

                # optionally add skip function
                if mc_only and data_only:
                    raise Exception(f"calibrator {cls_name} received both mc_only and data_only")
                if (mc_only or data_only) and cls_dict.get("skip_func"):
                    raise Exception(
                        f"calibrator {cls_name} received custom skip_func, but either mc_only or data_only are set",
                    )

                if "skip_func" not in cls_dict:
                    def skip_func(self, **kwargs) -> bool:
                        # check mc_only and data_only
                        if mc_only and not self.dataset_inst.is_mc:
                            return True
                        if data_only and not self.dataset_inst.is_data:
                            return True

                        # in all other cases, do not skip
                        return False

                    cls_dict["skip_func"] = skip_func

                return cls_dict

            cls_dict["update_cls_dict"] = update_cls_dict

            # create the subclass
            return cls.derive(cls_name, bases=bases, cls_dict=cls_dict, module=module)

        def decorator(func: Callable) -> DerivableMeta:
            # invoked directly by the code that decorates func
            return derive(func, caller_module())

        # usage with arguments, the actual decorator is applied by the caller afterwards
        if func is None:
            return decorator

        # bare usage, the caller of this method is the code that decorates func, so resolve the module here
        # rather than inside the decorator, whose direct caller would be this method
        return derive(func, caller_module())
# module-level shorthand for the Calibrator.calibrator decorator
calibrator = Calibrator.calibrator