# coding: utf-8
"""
Default histogram producers that define columnflow's default behavior.
"""
from __future__ import annotations
import functools
import law
import order as od
from columnflow.histogramming import HistProducer, hist_producer
from columnflow.columnar_util import has_ak_column, Route
from columnflow.hist_util import create_hist_from_variables, fill_hist, translate_hist_intcat_to_strcat
from columnflow.util import maybe_import
from columnflow.types import TYPE_CHECKING, Any
np = maybe_import("numpy")
ak = maybe_import("awkward")
if TYPE_CHECKING:
hist = maybe_import("hist")
[docs]
@hist_producer()
def cf_default(self: HistProducer, events: ak.Array, **kwargs) -> ak.Array:
    """
    Baseline histogram producer providing the hooks that make up columnflow's default histogramming behavior:

        - create_hist: defines the histogram structure
        - __call__: receives an event chunk and returns it together with event weights (1's in this case)
        - fill_hist: receives the data and fills the histogram
        - post_process_hist: post-processes the histogram before it is saved

    Returns a 2-tuple consisting of the unchanged *events* and a float32 array of ones with one entry per event.
    """
    # every event enters the histogram with unit weight
    unit_weights = np.ones(len(events), dtype=np.float32)
    return events, ak.Array(unit_weights)
@cf_default.create_hist
def cf_default_create_hist(
    self: HistProducer,
    variables: list[od.Variable],
    task: law.Task,
    **kwargs,
) -> hist.Hist:
    """
    Build the empty histogram used by the default histogram producer.

    All *variables* are forwarded to :py:func:`create_hist_from_variables` together with three integer category axes
    named "category", "process" and "shift", and with ``weight=True``. *task* and additional *kwargs* are accepted
    but not used here.
    """
    # axis specs are passed through unchanged; the third element of the "shift" spec is presumably the initial list
    # of categories of that axis - confirm against create_hist_from_variables
    int_cat_axes = [
        ("category", "intcat"),
        ("process", "intcat"),
        ("shift", "intcat", [0]),
    ]

    return create_hist_from_variables(*variables, categorical_axes=int_cat_axes, weight=True)
@cf_default.fill_hist
def cf_default_fill_hist(self: HistProducer, h: hist.Hist, data: dict[str, Any], task: law.Task) -> None:
    """
    Fill the histogram *h* with *data*.

    *data* maps axis names to the arrays to be filled. Before filling, all variable and integer axes of *h* whose
    data has more than one dimension are checked to have identical per-event counts (as computed by ``ak.count`` along
    axis 1). The actual filling is delegated to :py:func:`fill_hist`, forwarding ``task.last_edge_inclusive``.

    :raises ValueError: If two or more nested data arrays have differing per-event counts.
    """
    # in case multiple variable axes are given that refer to data arrays with more than one dimension (i.e. nested),
    # check if they are broadcasting-compatible since otherwise, the full combinatorics of values would be filled
    # which is not supported by fill_hist in its default implementation
    import hist

    var_axes = [
        ax for ax in h.axes
        if isinstance(ax, (hist.axis.Variable, hist.axis.Integer)) and ax.name in data and data[ax.name].ndim > 1
    ]
    if len(var_axes) > 1:
        # the first nested axis serves as the reference all others are compared to
        ref_counts = ak.count(data[var_axes[0].name], axis=1)
        for ax in var_axes[1:]:
            counts = ak.count(data[ax.name], axis=1)
            if not ak.all(counts == ref_counts):
                # list every involved axis as its own bullet, followed by the advice on a separate line
                axes_repr = "\n - ".join(f"{var_ax.name}: {data[var_ax.name]}" for var_ax in var_axes)
                err = (
                    "detected multiple variable axes with data to be filled that is not broadcasting-compatible:\n"
                    f" - {axes_repr}\n"
                    "please use a custom histogram producer whose fill_hist implementation supports the desired "
                    "filling logic including combinatorics or custom broadcasting"
                )
                raise ValueError(err)

    fill_hist(h, data, last_edge_inclusive=task.last_edge_inclusive)
@cf_default.post_process_hist
def cf_default_post_process_hist(self: HistProducer, h: hist.Hist, task: law.Task) -> hist.Hist:
    """
    Convert the integer "category", "process" and "shift" axes of *h* (where present) into string axes, so that
    lookups stay consistent across configs in which the ids might differ. Returns the translated histogram.
    """
    present_axes = {axis.name for axis in h.axes}

    # category ids are resolved on demand through the config, with repeated lookups being cached
    if "category" in present_axes:
        @functools.cache
        def category_name_of(cat_id: int) -> str:
            category_inst = self.config_inst.get_category(cat_id)
            return category_inst.name

        h = translate_hist_intcat_to_strcat(h, "category", category_name_of)

    # process ids found on the axis are resolved up front into an id -> name mapping
    if "process" in present_axes:
        process_names = {}
        for proc_id in h.axes["process"]:
            process_names[proc_id] = self.config_inst.get_process(proc_id).name
        h = translate_hist_intcat_to_strcat(h, "process", process_names)

    # only the global shift of the task is mapped
    if "shift" in present_axes:
        shift_inst = task.global_shift_inst
        h = translate_hist_intcat_to_strcat(h, "shift", {shift_inst.id: shift_inst.name})

    return h
[docs]
@cf_default.hist_producer()
def all_weights(self: HistProducer, events: ak.Array, **kwargs) -> tuple[ak.Array, ak.Array]:
    """
    HistProducer that combines all event weights from the *event_weights* aux entry from either the config or the
    dataset. The weights are multiplied together to form the full event weight.

    The expected structure of the *event_weights* aux entry is a dictionary with the weight column name as key and a
    list of shift sources as values. The shift sources are used to declare the shifts that the produced event weight
    depends on. Example:

    .. code-block:: python

        from columnflow.config_util import get_shifts_from_sources

        # add weights and their corresponding shifts for all datasets
        cfg.x.event_weights = {
            "normalization_weight": [],
            "muon_weight": get_shifts_from_sources(cfg, "mu_sf"),
            "btag_weight": get_shifts_from_sources(cfg, "btag_hf", "btag_lf"),
        }

        for dataset_inst in cfg.datasets:
            # add dataset-specific weights and their corresponding shifts
            dataset_inst.x.event_weights = {}
            if not dataset_inst.has_tag("skip_pdf"):
                dataset_inst.x.event_weights["pdf_weight"] = get_shifts_from_sources(cfg, "pdf")

    For non-MC datasets and for empty event chunks, the weight remains an array of ones. Both aux entries are
    optional. A dataset-specific weight column that is missing in *events* is skipped with a warning that is logged
    once, whereas config-level weight columns are applied without such a check.

    Returns a 2-tuple consisting of the unchanged *events* and the combined weight array.
    """
    weight = ak.Array(np.ones(len(events)))

    # build the full event weight
    if self.dataset_inst.is_mc and len(events):
        # multiply weights from global config `event_weights` aux entry, which is treated as optional to be
        # consistent with the init hook that only considers it when present
        for column in self.config_inst.x("event_weights", []):
            weight = weight * Route(column).apply(events)

        # multiply weights from dataset-specific `event_weights` aux entry
        for column in self.dataset_inst.x("event_weights", []):
            if has_ak_column(events, column):
                weight = weight * Route(column).apply(events)
            else:
                self.logger.warning_once(
                    f"missing_dataset_weight_{column}",
                    f"weight '{column}' for dataset {self.dataset_inst.name} not found",
                )

    return events, weight
@all_weights.init
def all_weights_init(self: HistProducer, **kwargs) -> None:
    """
    Init hook of :py:func:`all_weights`. After invoking the parent init function, and only for non-data datasets, all
    weight columns named in the *event_weights* aux entries of the config and the dataset (where present) are added to
    the used columns, and the names of their associated shifts are added to the shifts of this producer.
    """
    super(all_weights, self).init_func(**kwargs)

    # data carries no event weights, so there is nothing to register
    if self.dataset_inst.is_data:
        return

    weight_columns = set()

    # config-level entry first, then the optional dataset-level one
    for inst in (self.config_inst, self.dataset_inst):
        if not inst.has_aux("event_weights"):
            continue

        event_weights = inst.x.event_weights

        # keys are the weight columns, values hold the shift instances the weight depends on
        weight_columns.update(Route(column) for column in event_weights)
        for shift_insts in event_weights.values():
            self.shifts |= {shift_inst.name for shift_inst in shift_insts}

    # add weight columns to uses
    self.uses |= weight_columns