# coding: utf-8
"""
Collection of general helpers and utilities.
"""
from __future__ import annotations
__all__ = []
import re
import itertools
from collections import OrderedDict
import law
import order as od
from columnflow.util import maybe_import
from columnflow.types import Callable, Any, Sequence
ak = maybe_import("awkward")
np = maybe_import("numpy")
def get_events_from_categories(
    events: ak.Array,
    categories: Sequence[str | od.Category],
    config_inst: od.Config | None = None,
) -> ak.Array:
    """
    Helper function that returns all events from an awkward array *events* that are categorized
    into one of the leafs of one of the *categories*.

    :param events: Awkward array. Requires the 'category_ids' field to be present.
    :param categories: Sequence of category instances. Can also be a sequence of strings when
        passing a *config_inst*. An empty sequence selects no events.
    :param config_inst: Optional config instance to load category instances.
    :raises ValueError: If "category_ids" is not present in the *events* fields.
    :return: Awkward array of all events that are categorized into one of the leafs of one of the
        *categories*
    """
    if "category_ids" not in events.fields:
        raise ValueError(
            f"{get_events_from_categories.__name__} requires the 'category_ids' field to be present",
        )

    categories = law.util.make_list(categories)
    if config_inst:
        # get category insts
        categories = [config_inst.get_category(cat) for cat in categories]

    # collect all leaf categories, treating a category without leafs as its own leaf;
    # starting from an empty set keeps the union well-defined when no categories are given
    leaf_category_insts = set().union(*(
        cat.get_leaf_categories() or {cat}
        for cat in categories
    ))

    # do the "or" of all leaf categories
    mask = np.zeros(len(events), dtype=bool)
    for cat in leaf_category_insts:
        # an event matches when any of its category ids equals the leaf id
        cat_mask = ak.any(events.category_ids == cat.id, axis=1)
        mask = cat_mask | mask

    return events[mask]
def get_root_processes_from_campaign(campaign: od.config.Campaign) -> od.unique.UniqueObjectIndex:
    """
    Extracts all root process objects from datasets contained in an order *campaign* and returns
    them in a unique object index.

    :param campaign: :py:class:`~order.config.Campaign` object containing information
        about relevant datasets
    :return: Unique indices for :py:class:`~order.process.Process` instances of
        root processes associated with these datasets. The index is empty when the campaign
        contains no datasets.
    """
    # get all dataset processes
    # (starting from an empty set keeps the union well-defined for a campaign without datasets)
    processes: set[od.Process] = set().union(*(
        dataset.processes
        for dataset in campaign.datasets
    ))

    # get their root processes, a process without root processes counts as its own root
    root_processes: set[od.Process] = set().union(*(
        (process.get_root_processes() or [process])
        for process in processes
    ))

    # create an empty index and fill subprocesses via walking
    index = od.UniqueObjectIndex(cls=od.Process)
    for root_process in root_processes:
        for process, _, _ in root_process.walk_processes(include_self=True):
            index.add(process, overwrite=True)

    return index
def get_datasets_from_process(
    config: od.config.Config,
    process: str | od.process.Process,
    strategy: str = "inclusive",
    only_first: bool = True,
    check_deep: bool = False,
) -> list[od.dataset.Dataset]:
    r"""Given a *process* and the *config* it belongs to, returns a list of order dataset objects that
    contain matching processes. This is done by walking through *process* and its child processes
    and checking whether they are contained in known datasets. *strategy* controls how possible
    ambiguities are resolved:

        - ``"all"``: The full process tree is traversed and all matching datasets are considered.
          Note that this might lead to a potential over-representation of the phase space.
        - ``"inclusive"``: If a dataset is found to match a process, its child processes are not
          checked further.
        - ``"exclusive"``: If **any** (deep) subprocess of *process* is found to be contained in a
          dataset, return datasets of subprocesses but not that of *process* itself (if any).
        - ``"exclusive_strict"``: If **all** (deep) subprocesses of *process* are found to be
          contained in a dataset, return these datasets but not that of *process* itself (if any).

    As an example, consider the process tree

    .. mermaid::
        :align: center
        :zoom:

        flowchart BT
            A[single top]
            B{s channel}
            C{t channel}
            D{tw channel}
            E(t)
            F(tbar)
            G(t)
            H(tbar)
            I(t)
            J(tbar)
            B --> A
            C --> A
            D --> A
            E --> B
            F --> B
            G --> C
            H --> C
            I --> D
            J --> D

    and datasets existing for

        1. single top - s channel - t
        2. single top - s channel - tbar
        3. single top - t channel
        4. single top - t channel - t
        5. single top - tw channel
        6. single top - tw channel - t
        7. single top - tw channel - tbar

    in the *config*. Depending on *strategy*, the returned datasets for process ``single top`` are:

        - ``"all"``: ``[1, 2, 3, 4, 5, 6, 7]``. Simply all datasets matching any subprocess.
        - ``"inclusive"``: ``[1, 2, 3, 5]``. Skipping ``single top - t channel - t``,
          ``single top - tw channel - t``, and ``single top - tw channel - tbar``, since more
          inclusive datasets (``single top - t channel`` and ``single top - tw channel``) exist.
        - ``"exclusive"``: ``[1, 2, 4, 6, 7]``. Skipping ``single top - t channel`` and
          ``single top - tw channel`` since more exclusive datasets (``single top - t channel - t``,
          ``single top - tw channel - t``, and ``single top - tw channel - tbar``) exist.
        - ``"exclusive_strict"``: ``[1, 2, 3, 6, 7]``. Like ``"exclusive"``, but not skipping
          ``single top - t channel`` since not all subprocesses of ``t channel`` match a dataset
          (there is no ``single top - t channel - tbar`` dataset).

    In addition, two arguments configure how the check is performed whether a process is contained
    in a dataset. If *only_first* is *True*, only the first matching dataset is considered.
    Otherwise, all datasets matching a specific process are returned. For the check itself,
    *check_deep* is forwarded to :py:meth:`order.Dataset.has_process`.

    :param config: Config instance containing the information about known datasets.
    :param process: Process instance or process name for which you want to obtain
        list of datasets.
    :param strategy: controls how possible ambiguities are resolved. Choices:
        [``"all"``, ``"inclusive"``, ``"exclusive"``, ``"exclusive_strict"``]
    :param only_first: If *True*, only the first matching dataset is considered.
    :param check_deep: Forwarded to :py:meth:`order.Dataset.has_process`
    :raises ValueError: If *strategy* is not in list of allowed choices
    :return: List of datasets that correspond to *process*, depending on the
        specifics of the query
    """
    # check the strategy
    known_strategies = ["all", "inclusive", "exclusive", "exclusive_strict"]
    if strategy not in known_strategies:
        _known_strategies = ", ".join(map("'{}'".format, known_strategies))
        raise ValueError(f"unknown strategy {strategy}, known values are {_known_strategies}")

    # make sure we are dealing with a process instance
    root_inst = config.get_process(process)

    # the tree traversal differs depending on the strategy, so distinguish cases
    if strategy in ["all", "inclusive"]:
        dataset_insts: list[od.Dataset] = []
        for process_inst, _, child_insts in root_inst.walk_processes(include_self=True, algo="bfs"):
            found_dataset = False

            # check datasets
            for dataset_inst in config.datasets:
                if dataset_inst.has_process(process_inst, deep=check_deep):
                    dataset_insts.append(dataset_inst)
                    found_dataset = True

                    # stop checking more datasets when only the first matters
                    if only_first:
                        break

            # in the inclusive strategy, children do not need to be traversed if a dataset was
            # found; emptying the yielded child list in place prunes the walk below this process
            if strategy == "inclusive" and found_dataset:
                del child_insts[:]

        return law.util.make_unique(dataset_insts)

    # at this point, strategy is exclusive or exclusive_strict
    # the mapping holds, per process, the datasets matched for it; an empty list marks a process
    # that was skipped because datasets were already found further down in its subtree
    dataset_insts_dict: OrderedDict[od.Process, list[od.Dataset]] = OrderedDict()
    for process_inst, _, child_insts in root_inst.walk_processes(include_self=True, algo="dfs_post"):
        # check if child processes have matched datasets already
        if child_insts:
            n_found = sum(int(child_inst in dataset_insts_dict) for child_inst in child_insts)

            # potentially skip the current process
            if strategy == "exclusive" and n_found:
                # add an empty list as a marker so that ancestors further up also notice the
                # match of a deep subprocess and do not contribute their own dataset
                dataset_insts_dict[process_inst] = []
                continue
            if strategy == "exclusive_strict" and n_found == len(child_insts):
                # add an empty list to mark this is done
                dataset_insts_dict[process_inst] = []
                continue

            # at this point, the process itself must be checked,
            # so remove potentially found datasets of children
            dataset_insts_dict = OrderedDict({
                child_inst: _dataset_insts
                for child_inst, _dataset_insts in dataset_insts_dict.items()
                if child_inst not in child_insts
            })

        # check datasets
        for dataset_inst in config.datasets:
            if dataset_inst.has_process(process_inst, deep=check_deep):
                dataset_insts_dict.setdefault(process_inst, []).append(dataset_inst)

                # stop checking more datasets when only the first matters
                if only_first:
                    break

    # flatten the per-process lists, preserving insertion order
    return [
        dataset_inst
        for _dataset_insts in dataset_insts_dict.values()
        for dataset_inst in _dataset_insts
    ]
def add_shift_aliases(
    config: od.Config,
    shift_source: str,
    aliases: dict,
) -> None:
    """
    Looks up the up and down shift instances belonging to *shift_source* (the shift name without
    a direction) in *config* and stores *aliases* in their auxiliary ``column_aliases`` data.

    *aliases* maps alias targets (keys) to sources (values). Both strings may contain template
    fields named after attributes of the :py:class:`od.Shift` instance, such as *name*, *id*, and
    *direction*, which are substituted per shift.

    Example:

    .. code-block:: python

        add_shift_aliases(config, "pdf", {"pdf_weight": "pdf_weight_{direction}"})
        # adds {"pdf_weight": "pdf_weight_up"} to the "pdf_up" shift in "config"
        # plus {"pdf_weight": "pdf_weight_down"} to the "pdf_down" shift in "config"
    """
    def render(template: str, attrs: dict) -> str:
        # prefix field names with an underscore before formatting with the instance __dict__
        # (presumably because the attributes are stored there under "_<name>" keys)
        return re.sub(r"\{([^_])", r"{_\1", template).format(**attrs)

    for direction in ("up", "down"):
        shift_name = od.Shift.join_name(shift_source, direction)
        shift_inst = config.get_shift(shift_name)

        # start from the aliases already registered on the shift, if any
        column_aliases = shift_inst.x("column_aliases", {})

        # render all keys and values first, then merge them in one go
        rendered = {
            render(target, shift_inst.__dict__): render(source, shift_inst.__dict__)
            for target, source in aliases.items()
        }
        column_aliases.update(rendered)

        # store the extended (or newly created) mapping
        shift_inst.x.column_aliases = column_aliases
def get_shifts_from_sources(config: od.Config, *shift_sources: Sequence[str]) -> list[od.Shift]:
    """
    Returns the shift instances of *config* for all given *shift_sources*, with the up variation
    of each source directly followed by its down variation.
    """
    return [
        config.get_shift(f"{source}_{direction}")
        for source in shift_sources
        for direction in (od.Shift.UP, od.Shift.DOWN)
    ]
def expand_shift_sources(shifts: Sequence[str] | set[str]) -> list[str]:
    """
    Expands shift sources in *shifts* into full shift names. Entries that already are shift names
    (``<source>_<direction>``) are kept as they are, whereas plain sources are replaced by their
    up and down variations. Duplicates are removed from the returned list.

    Example:

    .. code-block:: python

        expand_shift_sources(["jes", "jer_up"])
        # -> ["jes_up", "jes_down", "jer_up"]
    """
    expanded = []
    for shift in shifts:
        try:
            # splitting succeeds only for names that already carry a direction
            od.Shift.split_name(shift)
        except ValueError:
            # non-string values cannot be interpreted as a source, so pass the error on
            if not isinstance(shift, str):
                raise
            expanded += [f"{shift}_{od.Shift.UP}", f"{shift}_{od.Shift.DOWN}"]
        else:
            expanded.append(shift)

    return law.util.make_unique(expanded)
def create_category_id(
    config: od.Config,
    category_name: str,
    hash_len: int = 7,
    salt: Any = None,
) -> int:
    """
    Builds and returns an integer id for a :py:class:`order.Category` named *category_name* in
    the :py:class:`order.Config` object *config*. The id is derived from a hash of the config
    name, the config id, *category_name* and *salt*, created with
    :py:func:`law.util.create_hash` which receives *hash_len*. A custom *salt* can be passed to
    resolve an unintentional (yet unlikely) collision of two ids.

    .. note::

        Please note that the size of the returned id depends on *hash_len*. When storing the id
        subsequently in an array, please be aware that values 8 or more require a ``np.int64``.
    """
    # hash the identifying information and interpret the hex digest as an integer
    hex_hash = law.util.create_hash((config.name, config.id, category_name, salt), l=hash_len)
    category_id = int(hex_hash, base=16)

    # shift by the power of ten with one more decimal digit than the largest possible hash value,
    # so that all ids lie above a common threshold
    max_hash = int("F" * hash_len, base=16)
    n_digits = len(str(max_hash))

    return category_id + 10 ** n_digits
def add_category(
    config: od.Config,
    parent: od.Config | od.Category | od.Channel | None = None,
    **kwargs,
) -> od.Category:
    """
    Adds a new category to *parent* by forwarding all *kwargs* to its ``add_category`` method and
    returns the result. When *kwargs* contain no field *id*, one is generated with
    :py:func:`create_category_id` from *config* and the category name.

    :param config: :py:class:`order.Config` object for which the category is created.
    :param parent: Parent object to which the category is added. If *None*, *config* is used.
    :param kwargs: Keyword arguments describing the category, a field *name* is mandatory.
    :raises ValueError: If *kwargs* contain no field *name*.
    :return: The newly created category instance.
    """
    # the name is mandatory as it is also the basis of automatically generated ids
    if "name" not in kwargs:
        fields = ",".join(str(key) for key in kwargs)
        raise ValueError(f"a field 'name' is required to create a category, got '{fields}'")

    # generate an id only when none was given explicitly
    if "id" not in kwargs:
        kwargs["id"] = create_category_id(config, kwargs["name"])

    # fall back to the config itself as the parent
    target = config if parent is None else parent

    return target.add_category(**kwargs)
def create_category_combinations(
    config: od.Config,
    categories: dict[str, list[od.Category]],
    name_fn: Callable[[Any], str],
    kwargs_fn: Callable[[Any], dict] | None = None,
    skip_existing: bool = True,
    skip_fn: Callable[[dict[str, od.Category], str], bool] | None = None,
) -> int:
    """
    Builds all combinations of categories taken from two or more of the groups in *categories*,
    attaches each combined category as a child to its direct parent categories found in *config*
    (see :py:class:`order.Category`) and returns the number of newly created categories.

    *categories* maps group names to sequences of categories. For every combination, *name_fn* is
    called with a dictionary mapping the contributing group names to the selected categories and
    must return the name of the combined category (see example below).

    The combined category is created with that name plus the keyword arguments returned by
    *kwargs_fn*, which receives the same dictionary. Missing fields are filled with defaults:
    ``"id"`` is generated with :py:func:`create_category_id` and ``"selection"`` is set to the
    list of selections of the contributing categories.

    A combination is skipped when its name is already known to *config* and *skip_existing* is
    *True*, or when *skip_fn* is callable and returns *True* for the dictionary of contributing
    categories (it is called with this dictionary as its only argument).

    Example:

    .. code-block:: python

        categories = {
            "lepton": [cfg.get_category("e"), cfg.get_category("mu")],
            "n_jets": [cfg.get_category("1j"), cfg.get_category("2j")],
            "n_tags": [cfg.get_category("0t"), cfg.get_category("1t")],
        }

        def name_fn(categories):
            # simple implementation: join names in defined order if existing
            return "__".join(cat.name for cat in categories.values() if cat)

        def kwargs_fn(categories):
            # return arguments that are forwarded to the category init
            # (use id "+" here which simply increments the last taken id, see order.Category)
            return {"id": "+"}

        create_category_combinations(cfg, categories, name_fn, kwargs_fn)

    :param config: :py:class:`order.Config` object for which the categories are created.
    :param categories: Dictionary that maps group names to sequences of categories.
    :param name_fn: Callable that receives a dictionary mapping group names to categories and
        returns the name of the newly created category.
    :param kwargs_fn: Callable that receives a dictionary mapping group names to categories and
        returns a dictionary of keyword arguments that are forwarded to the category constructor.
    :param skip_existing: If *True*, skip the creation of a category when it already exists in
        *config*.
    :param skip_fn: Callable that receives a dictionary mapping group names to categories and
        returns *True* if the combination should be skipped.
    :raises TypeError: If *name_fn* is not a callable.
    :raises TypeError: If *kwargs_fn* is not a callable when set.
    :raises ValueError: If a non-unique category id is detected.
    :return: Number of newly created categories.
    """
    n_created = 0

    # ids of all categories currently reachable from the config, used for the uniqueness check
    known_ids = {known_cat.id for known_cat, _, _ in config.walk_categories()}

    all_group_names = list(categories.keys())
    n_groups = len(categories)

    # combinations require at least two groups
    if n_groups < 2:
        return n_created

    # validate the callables
    if not callable(name_fn):
        raise TypeError(f"name_fn must be a function, but got {name_fn}")
    if kwargs_fn and not callable(kwargs_fn):
        raise TypeError(f"when set, kwargs_fn must be a function, but got {kwargs_fn}")

    # increase the number of combined groups step by step so that the parents of a combination
    # (which combine one group less) were handled in a previous step
    for size in range(2, n_groups + 1):
        for selected_groups in itertools.combinations(all_group_names, size):
            # cartesian product of the categories of the selected groups
            candidates = [categories[group] for group in selected_groups]
            for selected_cats in itertools.product(*candidates):
                combination = dict(zip(selected_groups, selected_cats))
                cat_name = name_fn(combination)

                # skip when already existing
                if skip_existing and config.has_category(cat_name, deep=True):
                    continue

                # skip when manually triggered
                if callable(skip_fn) and skip_fn(combination):
                    continue

                # assemble constructor arguments and fill in defaults
                kwargs = kwargs_fn(combination) if callable(kwargs_fn) else {}
                if "id" not in kwargs:
                    kwargs["id"] = create_category_id(config, cat_name)
                if "selection" not in kwargs:
                    kwargs["selection"] = [member.selection for member in combination.values()]

                # create the new category
                new_cat = od.Category(name=cat_name, **kwargs)
                n_created += 1

                # integer ids must not clash with the id of a differently named category
                new_id = kwargs["id"]
                if isinstance(new_id, int):
                    if new_id in known_ids:
                        clashing_cat = config.get_category(new_id)
                        if clashing_cat.name != cat_name:
                            raise ValueError(
                                f"non-unique category id '{kwargs['id']}' for '{cat_name}' has already been used for "
                                f"category '{clashing_cat.name}'",
                            )
                    known_ids.add(new_id)

                # direct parents combine all but one of the selected groups
                for parent_groups in itertools.combinations(selected_groups, size - 1):
                    if len(parent_groups) == 1:
                        # a single group refers to one of the input categories itself
                        parent_name = combination[parent_groups[0]].name
                    else:
                        parent_name = name_fn({
                            group: combination[group]
                            for group in parent_groups
                        })
                    config.get_category(parent_name, deep=True).add_category(new_cat)

    return n_created
def verify_config_processes(config: od.Config, warn: bool = False) -> None:
    """
    Checks for every dataset in *config* that each of its linked processes is known to *config*
    according to ``config.has_process``. When at least one process is not, an exception listing
    all offending dataset - process pairs is raised, or, if *warn* is *True*, the same message is
    printed as a warning instead.
    """
    # collect all (dataset, process) pairs whose process is unknown to the config
    missing_pairs = [
        (dataset, process)
        for dataset in config.datasets
        for process in dataset.processes
        if not config.has_process(process)
    ]

    # nothing to report
    if not missing_pairs:
        return

    # header followed by one line per missing pair
    header = f"found {len(missing_pairs)} dataset(s) whose process is not registered in the '{config.name}' config:"
    details = "".join(
        f"\n dataset '{dataset.name}' -> process '{process.name}'"
        for dataset, process in missing_pairs
    )
    msg = header + details

    if warn:
        print(f"{law.util.colored('WARNING', 'red')}: {msg}")
        return

    raise Exception(msg)