# coding: utf-8
"""
Collection of general helpers and utilities.
"""
from __future__ import annotations
__all__ = []
import re
import dataclasses
import itertools
from collections import OrderedDict, defaultdict
import law
import order as od
from columnflow.util import maybe_import, get_docs_url
from columnflow.columnar_util import flat_np_view, layout_ak_array
from columnflow.types import Callable, Any, Sequence
ak = maybe_import("awkward")
np = maybe_import("numpy")
logger = law.logger.get_logger(__name__)
[docs]
def get_events_from_categories(
    events: ak.Array,
    categories: Sequence[str | od.Category],
    config_inst: od.Config | None = None,
) -> ak.Array:
    """
    Helper function that returns all events from an awkward array *events* that are categorized
    into one of the leafs of one of the *categories*.

    Categories without leaf categories are treated as their own leaf. An empty *categories*
    sequence selects no events.

    :param events: Awkward array. Requires the 'category_ids' field to be present.
    :param categories: Sequence of category instances. Can also be a sequence of strings when passing a
        *config_inst*.
    :param config_inst: Optional config instance to load category instances.
    :raises ValueError: If "category_ids" is not present in the *events* fields.
    :return: Awkward array of all events that are categorized into one of the leafs of one of the
        *categories*.
    """
    if "category_ids" not in events.fields:
        raise ValueError(
            f"{get_events_from_categories.__name__} requires the 'category_ids' field to be present",
        )

    categories = law.util.make_list(categories)
    if config_inst is not None:
        # get category insts
        categories = [config_inst.get_category(cat) for cat in categories]

    # collect the leaf categories of all requested categories;
    # set().union (unlike set.union(*...)) also works for an empty sequence of categories
    leaf_category_insts = set().union(*(
        cat.get_leaf_categories() or {cat}
        for cat in categories
    ))

    # do the "or" of all leaf categories
    mask = np.zeros(len(events), dtype=bool)
    for cat in leaf_category_insts:
        cat_mask = ak.any(events.category_ids == cat.id, axis=1)
        mask = cat_mask | mask

    return events[mask]
[docs]
def get_category_name_columns(
    category_ids: ak.Array,
    config_inst: od.Config,
) -> ak.Array:
    """
    Function that transforms column of category ids to column of category names.

    :param category_ids: Awkward array of category ids.
    :param config_inst: Config instance from which to load category instances.
    :raises ValueError: If any of the category ids is not defined in the *config_inst*.
    :return: Awkward array of category names with the same shape as *category_ids*.
    """
    flat_ids = flat_np_view(category_ids)

    # map all category ids present in *category_ids* to category instances
    category_map = {
        _id: config_inst.get_category(_id, default=None)
        for _id in set(flat_ids)
    }

    # report unknown ids in a deterministic (sorted) order
    undefined_ids = sorted(cat_id for cat_id, cat_inst in category_map.items() if cat_inst is None)
    if undefined_ids:
        raise ValueError(f"undefined category ids: {', '.join(map(str, undefined_ids))}")

    # vectorized id -> name mapping; otypes must be set explicitly since np.vectorize cannot infer
    # the output type from size-0 inputs (e.g. chunks without any category ids) and raises otherwise
    map_to_name = np.vectorize(lambda _id: category_map[_id].name, otypes=[str])

    # apply the mapping and restore the original (jagged) layout
    flat_names = map_to_name(flat_ids)
    return layout_ak_array(flat_names, category_ids)
[docs]
def get_root_processes_from_campaign(campaign: od.config.Campaign) -> od.unique.UniqueObjectIndex:
    """
    Extracts all root process objects from datasets contained in an order *campaign* and returns
    them in a unique object index. The index also contains all subprocesses of these root processes.
    A campaign without datasets results in an empty index.

    :param campaign: :py:class:`~order.config.Campaign` object containing information
        about relevant datasets
    :return: Unique indices for :py:class:`~order.process.Process` instances of
        root processes associated with these datasets
    """
    # get all dataset processes; set().union (unlike set.union(*...)) tolerates zero datasets
    processes: set[od.Process] = set().union(*(dataset.processes for dataset in campaign.datasets))

    # get their root processes, falling back to the process itself when get_root_processes()
    # returns nothing (presumably a process without parents)
    root_processes: set[od.Process] = set().union(*(
        process.get_root_processes() or [process]
        for process in processes
    ))

    # create an empty index and fill subprocesses via walking
    index = od.UniqueObjectIndex(cls=od.Process)
    for root_process in root_processes:
        for process, _, _ in root_process.walk_processes(include_self=True):
            index.add(process, overwrite=True)

    return index
[docs]
def get_datasets_from_process(
    config: od.config.Config,
    process: str | od.process.Process,
    strategy: str = "inclusive",
    only_first: bool = True,
    check_deep: bool = False,
) -> list[od.dataset.Dataset]:
    r"""
    Given a *process* and the *config* it belongs to, returns a list of order dataset objects that
    contain matching processes. This is done by walking through *process* and its child processes
    and checking whether they are contained in known datasets. *strategy* controls how possible
    ambiguities are resolved:

    - ``"all"``: The full process tree is traversed and all matching datasets are considered.
      Note that this might lead to a potential over-representation of the phase space.
    - ``"inclusive"``: If a dataset is found to match a process, its child processes are not
      checked further.
    - ``"exclusive"``: If **any** (deep) subprocess of *process* is found to be contained in a
      dataset, return datasets of subprocesses but not that of *process* itself (if any).
    - ``"exclusive_strict"``: If **all** (deep) subprocesses of *process* are found to be
      contained in a dataset, return these datasets but not that of *process* itself (if any).

    As an example, consider the process tree

    .. mermaid::
       :align: center
       :zoom:

       flowchart BT
           A[single top]
           B{s channel}
           C{t channel}
           D{tw channel}
           E(t)
           F(tbar)
           G(t)
           H(tbar)
           I(t)
           J(tbar)

           B --> A
           C --> A
           D --> A
           E --> B
           F --> B
           G --> C
           H --> C
           I --> D
           J --> D

    and datasets existing for

        1. single top - s channel - t
        2. single top - s channel - tbar
        3. single top - t channel
        4. single top - t channel - t
        5. single top - tw channel
        6. single top - tw channel - t
        7. single top - tw channel - tbar

    in the *config*. Depending on *strategy*, the returned datasets for process ``single top`` are:

    - ``"all"``: ``[1, 2, 3, 4, 5, 6, 7]``. Simply all datasets matching any subprocess.
    - ``"inclusive"``: ``[1, 2, 3, 5]``. Skipping ``single top - t channel - t``,
      ``single top - tw channel - t``, and ``single top - tw channel - tbar``, since more inclusive
      datasets (``single top - t channel`` and ``single top - tw channel``) exist.
    - ``"exclusive"``: ``[1, 2, 4, 6, 7]``. Skipping ``single top - t channel`` and
      ``single top - tw channel`` since more exclusive datasets (``single top - t channel - t``,
      ``single top - tw channel - t``, and ``single top - tw channel - tbar``) exist.
    - ``"exclusive_strict"``: ``[1, 2, 3, 6, 7]``. Like ``"exclusive"``, but not skipping
      ``single top - t channel`` since not all subprocesses of ``t channel`` match a dataset
      (there is no ``single top - t channel - tbar`` dataset).

    In addition, two arguments configure how the check is performed whether a process is contained
    in a dataset. If *only_first* is *True*, only the first matching dataset is considered.
    Otherwise, all datasets matching a specific process are returned. For the check itself,
    *check_deep* is forwarded to :py:meth:`order.Dataset.has_process`.

    :param config: Config instance containing the information about known datasets.
    :param process: Process instance or process name for which you want to obtain
        list of datasets.
    :param strategy: controls how possible ambiguities are resolved. Choices:
        [``"all"``, ``"inclusive"``, ``"exclusive"``, ``"exclusive_strict"``]
    :param only_first: If *True*, only the first matching dataset is considered.
    :param check_deep: Forwarded to :py:meth:`order.Dataset.has_process`
    :raises ValueError: If *strategy* is not in list of allowed choices
    :return: List of datasets that correspond to *process*, depending on the
        specifics of the query
    """
    # check the strategy
    known_strategies = ["all", "inclusive", "exclusive", "exclusive_strict"]
    if strategy not in known_strategies:
        _known_strategies = ", ".join(map("'{}'".format, known_strategies))
        raise ValueError(f"unknown strategy {strategy}, known values are {_known_strategies}")

    # make sure we are dealing with a process instance
    root_inst = config.get_process(process)

    # the tree traversal differs depending on the strategy, so distinguish cases
    if strategy in ["all", "inclusive"]:
        dataset_insts: list[od.Dataset] = []
        for process_inst, _, child_insts in root_inst.walk_processes(include_self=True, algo="bfs"):
            found_dataset = False
            # check datasets
            for dataset_inst in config.datasets:
                if dataset_inst.has_process(process_inst, deep=check_deep):
                    dataset_insts.append(dataset_inst)
                    found_dataset = True
                    # stop checking more datasets when only the first matters
                    if only_first:
                        break
            # in the inclusive strategy, children do not need to be traversed if a dataset was found
            # (done by clearing the yielded child list in place)
            if strategy == "inclusive" and found_dataset:
                del child_insts[:]
        return law.util.make_unique(dataset_insts)

    # at this point, strategy is exclusive or exclusive_strict; children are visited before their
    # parents ("dfs_post") so that their results are known when the parent is processed; processes
    # covered by their subprocesses are marked with an empty list so that parents see them as covered
    dataset_insts_dict: OrderedDict[od.Process, list[od.Dataset]] = OrderedDict()
    for process_inst, _, child_insts in root_inst.walk_processes(include_self=True, algo="dfs_post"):
        # check if child processes have matched datasets already
        if child_insts:
            n_found = sum(int(child_inst in dataset_insts_dict) for child_inst in child_insts)
            # potentially skip the current process; in both strategies, mark it as covered so that
            # its ancestors are skipped as well and do not add datasets overlapping with those of
            # the subprocesses (previously, "exclusive" did not mark, letting ancestors re-add)
            if (
                (strategy == "exclusive" and n_found) or
                (strategy == "exclusive_strict" and n_found == len(child_insts))
            ):
                dataset_insts_dict[process_inst] = []
                continue
            # at this point, the process itself must be checked,
            # so remove potentially found datasets of children
            dataset_insts_dict = OrderedDict({
                child_inst: _dataset_insts
                for child_inst, _dataset_insts in dataset_insts_dict.items()
                if child_inst not in child_insts
            })

        # check datasets
        for dataset_inst in config.datasets:
            if dataset_inst.has_process(process_inst, deep=check_deep):
                dataset_insts_dict.setdefault(process_inst, []).append(dataset_inst)
                # stop checking more datasets when only the first matters
                if only_first:
                    break

    return sum(dataset_insts_dict.values(), [])
[docs]
def add_shift_aliases(
    config: od.Config,
    shift_source: str,
    aliases: dict,
) -> None:
    """
    Looks up the up and down shift instances in *config* belonging to *shift_source* (a shift name
    without its direction suffix) and extends their ``column_aliases`` auxiliary entry with
    *aliases*.

    *aliases* maps alias targets (keys) to alias sources (values). Both are treated as templates
    and formatted with the attributes of the respective :py:class:`od.Shift`, e.g. ``{name}``,
    ``{id}`` or ``{direction}``.

    Example:

    .. code-block:: python

        add_shift_aliases(config, "pdf", {"pdf_weight": "pdf_weight_{direction}"})
        # adds {"pdf_weight": "pdf_weight_up"} to the "pdf_up" shift in "config"
        # plus {"pdf_weight": "pdf_weight_down"} to the "pdf_down" shift in "config"
    """
    for direction in ("up", "down"):
        shift_inst = config.get_shift(od.Shift.join_name(shift_source, direction))
        column_aliases = shift_inst.x("column_aliases", {})

        def inject(template: str) -> str:
            # template fields get an underscore prefix ("{direction}" -> "{_direction}") because
            # they are resolved against the instance __dict__, where the attributes appear to be
            # stored privately (NOTE(review): relies on order internals)
            return re.sub(r"\{([^_])", r"{_\1", template).format(**shift_inst.__dict__)

        column_aliases.update({
            inject(target): inject(source)
            for target, source in aliases.items()
        })

        # extend existing or register new column aliases
        shift_inst.x.column_aliases = column_aliases
[docs]
def get_shift_from_configs(configs: list[od.Config], shift: str | od.Shift, silent: bool = False) -> od.Shift | None:
    """
    Searches *configs* in order for a shift given by name or instance (*shift*) and returns the
    instance stored in the first config that knows it.

    :param configs: Configs to search, in order of priority.
    :param shift: Shift name or :py:class:`order.Shift` instance (only its name is used).
    :param silent: When *True*, return *None* instead of raising if no config contains the shift.
    :raises ValueError: If the shift is not found and *silent* is *False*.
    :return: The matching shift instance, or *None*.
    """
    shift_name = shift.name if isinstance(shift, od.Shift) else shift

    for config in configs:
        if config.has_shift(shift_name):
            return config.get_shift(shift_name)

    if not silent:
        raise ValueError(f"shift '{shift_name}' not found in any of the given configs: {configs}")
    return None
[docs]
def get_shifts_from_sources(config: od.Config, *shift_sources: str) -> list[od.Shift]:
    """
    Takes a *config* object and returns a list of shift instances for both directions given a
    sequence *shift_sources*. For each source, the up shift precedes the down shift.

    :param config: Config from which the shift instances are retrieved.
    :param shift_sources: Shift source names (shift names without direction suffix).
    :return: List of shift instances ``[<s1>_up, <s1>_down, <s2>_up, ...]``.
    """
    # a flat comprehension instead of sum(lists, []), which is quadratic in the number of sources
    return [
        config.get_shift(f"{source}_{direction}")
        for source in shift_sources
        for direction in (od.Shift.UP, od.Shift.DOWN)
    ]
[docs]
def group_shifts(
    shifts: od.Shift | Sequence[od.Shift],
) -> tuple[od.Shift | None, dict[str, tuple[od.Shift, od.Shift]]]:
    """
    Takes several :py:class:`order.Shift` instances *shifts* and groups them according to their
    shift source. The nominal shift, if present, is returned separately. The remaining shifts are
    grouped by their source and the corresponding up and down shifts are stored in a dictionary
    as ``(up, down)`` tuples.

    Example:

    .. code-block:: python

        # assuming the following shifts exist
        group_shifts([nominal, x_up, y_up, y_down, x_down])
        # -> (nominal, {"x": (x_up, x_down), "y": (y_up, y_down)})

    An exception is raised in case a shift source is represented only by its up or down shift.

    :param shifts: A single shift or a sequence of shifts.
    :raises ValueError: If any shift source lacks either its up or its down shift.
    :return: Tuple of the nominal shift (or *None*) and a dictionary mapping shift sources to
        ``(up, down)`` tuples.
    """
    nominal = None
    grouped = defaultdict(lambda: [None, None])
    up_sources = set()
    down_sources = set()
    for shift in law.util.make_list(shifts):
        if shift.name == "nominal":
            nominal = shift
            continue
        # index 0 holds the up shift, index 1 the down shift (indexing with shift.is_up directly
        # would store the up shift at index 1 and invert the documented order)
        grouped[shift.source][0 if shift.is_up else 1] = shift
        (up_sources if shift.is_up else down_sources).add(shift.source)

    # check completeness of shifts
    if (diff := up_sources.symmetric_difference(down_sources)):
        raise ValueError(f"shift sources {diff} are not complete and cannot be grouped")

    return nominal, {source: tuple(pair) for source, pair in grouped.items()}
[docs]
def expand_shift_sources(shifts: Sequence[str] | set[str]) -> list[str]:
    """
    Expands shift sources in *shifts* into their up and down shift names, while keeping entries
    that already are full shift names (``<source>_<direction>``) as they are. Duplicates are removed
    from the resulting list.

    Example:

    .. code-block:: python

        expand_shift_sources(["jes", "jer_up"])
        # -> ["jes_up", "jes_down", "jer_up"]

    :param shifts: Shift names and/or shift sources.
    :return: List of unique shift names.
    """
    expanded: list[str] = []
    for entry in shifts:
        try:
            # only succeeds for full shift names
            od.Shift.split_name(entry)
        except ValueError:
            if not isinstance(entry, str):
                raise
            expanded += [f"{entry}_{od.Shift.UP}", f"{entry}_{od.Shift.DOWN}"]
        else:
            expanded.append(entry)

    return law.util.make_unique(expanded)
[docs]
def create_category_id(
    config: od.Config,
    category_name: str,
    hash_len: int = 7,
    salt: Any = None,
) -> int:
    """
    Derives a deterministic integer id for a :py:class:`order.Category` named *category_name*
    within the :py:class:`order.Config` *config*. The id is based on
    :py:func:`law.util.create_hash` with *hash_len* hex characters. A custom *salt* can be passed
    to resolve an (unlikely) collision between two ids.

    .. note::

        Please note that the size of the returned id depends on *hash_len*. When storing the id
        subsequently in an array, please be aware that values 8 or more require a ``np.int64``.
    """
    digest = law.util.create_hash((config.name, config.id, category_name, salt), l=hash_len)

    # shift all hashes by a power of ten larger than the largest possible hash, so that all ids
    # share the same number of decimal digits (a leading 1 followed by the padded hash value)
    max_hash = int("F" * hash_len, base=16)
    offset = 10 ** len(str(max_hash))

    return int(digest, base=16) + offset
[docs]
def add_category(
    config: od.Config,
    parent: od.Config | od.Category | od.Channel | None = None,
    **kwargs,
) -> od.Category:
    """
    Creates a new category from *kwargs* and attaches it to *parent*, e.g. a
    :py:class:`order.Config` or another :py:class:`order.Category`. When no *id* is passed in
    *kwargs*, one is derived via :py:func:`create_category_id`.

    :param config: :py:class:`order.Config` object for which the category is created.
    :param parent: Object the category is added to; *config* itself when *None*.
    :param kwargs: Keyword arguments forwarded to ``parent.add_category``; must contain *name*.
    :raises ValueError: If *kwargs* does not contain a *name*.
    :return: The newly created category instance.
    """
    if "name" not in kwargs:
        given = ",".join(str(key) for key in kwargs)
        raise ValueError(f"a field 'name' is required to create a category, got '{given}'")

    if "id" not in kwargs:
        kwargs["id"] = create_category_id(config, kwargs["name"])

    target = config if parent is None else parent
    return target.add_category(**kwargs)
[docs]
@dataclasses.dataclass
class CategoryGroup:
    """
    Thin container describing a group of categories, primarily consumed by
    :py:func:`create_category_combinations` when building combined categories.

    :param categories: Categories of the group, given as :py:class:`order.Category` objects or names.
    :param is_complete: *True* when the category selections jointly cover the full phase space (no gaps).
    :param has_overlap: *False* when the categories are pairwise disjoint (no overlap).
    :param warn: When *True*, a warning is issued when summing over the group of categories.
    """

    # categories (instances or names) that belong to the group
    categories: list[od.Category | str]
    # whether the union of selections covers the full phase space
    is_complete: bool
    # whether at least two categories may select the same events
    has_overlap: bool
    # whether summing over this group should trigger a warning
    warn: bool = True

    @property
    def is_partition(self) -> bool:
        """
        Whether the group forms a full partition of the phase space, i.e., it is complete and free of
        overlaps.
        """
        complete, overlapping = self.is_complete, self.has_overlap
        return complete and not overlapping
[docs]
def create_category_combinations(
    config: od.Config,
    categories: dict[str, CategoryGroup | list[od.Category]],
    name_fn: Callable[[Any], str],
    kwargs_fn: Callable[[Any], dict] | None = None,
    skip_existing: bool = True,
    skip_fn: Callable[[dict[str, od.Category]], bool] | None = None,
) -> int:
    """
    Given a *config* object and sequences of *categories* in a dict, creates all combinations of possible leaf
    categories at different depths, connects them with parent - child relations (see :py:class:`order.Category`) and
    returns the number of newly created categories.

    *categories* should be a dictionary that maps string names to :py:class:`CategoryGroup` objects which are thin
    wrappers around sequences of categories (objects or names). Group names (dictionary keys) are used as keyword
    arguments in a callable *name_fn* that is supposed to return the name of newly created categories (see example
    below).

    .. note::

        The :py:attr:`CategoryGroup.is_complete` and :py:attr:`CategoryGroup.has_overlap` attributes are imperative for
        columnflow to determine whether the summation over specific categories is valid or may result in under- or
        over-counting when combining leaf categories. These checks may be performed by other functions and tools based
        on information derived from groups and stored in auxiliary fields of the newly created categories.

    Each newly created category is instantiated with this name as well as arbitrary keyword arguments as returned by
    *kwargs_fn*. This function is called with the categories (in a dictionary, mapped to the sequence names as given in
    *categories*) that contribute to the newly created category and should return a dictionary. The returned dictionary
    is copied before use, so it is never modified. If the fields ``"id"`` and ``"selection"`` are missing, they are
    filled with reasonable defaults leading to an auto-generated, deterministic id and a list of all parent selection
    statements.

    If the name of a new category is already known to *config* it is skipped unless *skip_existing* is *False*. In
    addition, *skip_fn* can be a callable that receives a dictionary mapping group names to categories that represents
    the combination of categories to be added. In case *skip_fn* returns *True*, the combination is skipped.

    Note that *categories* is updated in place: plain lists are replaced by :py:class:`CategoryGroup` instances and
    category names inside groups are replaced by category instances.

    Example:

    .. code-block:: python

        categories = {
            "lepton": CategoryGroup(categories=["e", "mu"], is_complete=False, has_overlap=False),
            "n_jets": CategoryGroup(categories=["eq0j", "eq1j", "ge2j"], is_complete=True, has_overlap=False),
            "n_tags": CategoryGroup(categories=["0t", "1t"], is_complete=False, has_overlap=False),
        }

        def name_fn(categories):
            # simple implementation: join names in defined order if existing
            return "__".join(cat.name for cat in categories.values() if cat)

        def kwargs_fn(categories):
            # return arguments that are forwarded to the category init
            # (use id "+" here which simply increments the last taken id, see order.Category)
            # (note that this is also the default)
            return {"id": "+"}

        create_category_combinations(cfg, categories, name_fn, kwargs_fn)

    :param config: :py:class:`order.Config` object for which the categories are created.
    :param categories: Dictionary that maps group names to :py:class:`CategoryGroup` containers.
    :param name_fn: Callable that receives a dictionary mapping group names to categories and returns the name of the
        newly created category.
    :param kwargs_fn: Callable that receives a dictionary mapping group names to categories and returns a dictionary of
        keyword arguments that are forwarded to the category constructor.
    :param skip_existing: If *True*, skip the creation of a category when it already exists in *config*.
    :param skip_fn: Callable that receives a dictionary mapping group names to categories and returns *True* if the
        combination should be skipped.
    :raises TypeError: If *name_fn* is not a callable.
    :raises TypeError: If *kwargs_fn* is not a callable when set.
    :raises ValueError: If a non-unique category id is detected.
    :return: Number of newly created categories.
    """
    # cast categories
    for name, _categories in categories.items():
        # ensure CategoryGroup is used
        if not isinstance(_categories, CategoryGroup):
            docs_url = get_docs_url("api", "config_util.html", anchor="columnflow.config_util.CategoryGroup")
            logger.warning_once(
                "deprecated_category_group_lists",
                f"using a list to define a sequence of categories for create_category_combinations() is deprecated "
                f"and will be removed in a future version, please use a CategoryGroup instance instead: {docs_url}",
            )
            _categories = CategoryGroup(
                categories=law.util.make_list(_categories),
                is_complete=True,
                has_overlap=False,
            )
            categories[name] = _categories

        # cast string category names to instances
        for i, cat in enumerate(_categories.categories):
            if isinstance(cat, str):
                _categories.categories[i] = config.get_category(cat)

    n_created_categories = 0
    unique_ids_cache = {cat.id for cat, _, _ in config.walk_categories()}
    n_groups = len(categories)
    group_names = list(categories.keys())

    # nothing to do when there are less than 2 groups
    if n_groups < 2:
        return n_created_categories

    # check functions
    if not callable(name_fn):
        raise TypeError(f"name_fn must be a function, but got {name_fn}")
    if kwargs_fn and not callable(kwargs_fn):
        raise TypeError(f"when set, kwargs_fn must be a function, but got {kwargs_fn}")

    # start combining, considering one additional groups for combinatorics at a time
    for _n_groups in range(2, n_groups + 1):

        # build all group combinations
        for _group_names in itertools.combinations(group_names, _n_groups):

            # build the product of all categories for the given groups
            _categories = [categories[group_name].categories for group_name in _group_names]
            for root_cats in itertools.product(*_categories):
                # build the name
                root_cats = dict(zip(_group_names, root_cats))
                cat_name = name_fn(root_cats)

                # skip when already existing
                if skip_existing and config.has_category(cat_name, deep=True):
                    continue

                # skip when manually triggered
                if callable(skip_fn) and skip_fn(root_cats):
                    continue

                # create arguments for the new category; copy the returned dict so that a dict
                # shared across kwargs_fn calls is not filled with the defaults of this combination
                kwargs = dict(kwargs_fn(root_cats)) if callable(kwargs_fn) else {}
                if "id" not in kwargs:
                    kwargs["id"] = create_category_id(config, cat_name)
                if "selection" not in kwargs:
                    kwargs["selection"] = [c.selection for c in root_cats.values()]

                # ID uniqueness check before creation: raise an error when a non-unique id is detected;
                # the lookup is deep since the cache also contains ids of nested categories
                cat_id = kwargs["id"]
                if isinstance(cat_id, int):
                    if cat_id in unique_ids_cache:
                        matching_cat = config.get_category(cat_id, deep=True)
                        if matching_cat.name != cat_name:
                            raise ValueError(
                                f"non-unique category id '{kwargs['id']}' for '{cat_name}' has already been used for "
                                f"category '{matching_cat.name}'",
                            )
                    unique_ids_cache.add(cat_id)

                # create the new category
                cat = od.Category(name=cat_name, **kwargs)
                n_created_categories += 1

                # find direct parents and connect them
                for _parent_group_names in itertools.combinations(_group_names, _n_groups - 1):
                    if len(_parent_group_names) == 1:
                        parent_cat_name = root_cats[_parent_group_names[0]].name
                    else:
                        parent_cat_name = name_fn({
                            group_name: root_cats[group_name]
                            for group_name in _parent_group_names
                        })
                    parent_cat = config.get_category(parent_cat_name, deep=True)
                    parent_cat.add_category(cat)

    return n_created_categories
[docs]
def verify_config_processes(config: od.Config, warn: bool = False) -> None:
    """
    Verifies for all datasets contained in a *config* object that the linked processes are covered
    by any process object registered in *config* and raises an exception if not. If *warn* is
    *True*, a warning is logged instead.

    :param config: Config whose datasets and processes are checked.
    :param warn: Log a warning instead of raising when processes are missing.
    :raises ValueError: If processes are missing and *warn* is *False* (still caught by
        ``except Exception`` as before).
    """
    missing_pairs = [
        (dataset, process)
        for dataset in config.datasets
        for process in dataset.processes
        if not config.has_process(process)
    ]

    # nothing to do when nothing is missing
    if not missing_pairs:
        return

    # build the message
    msg = f"found {len(missing_pairs)} dataset(s) whose process is not registered in the '{config.name}' config:"
    for dataset, process in missing_pairs:
        msg += f"\n dataset '{dataset.name}' -> process '{process.name}'"

    # warn or raise; use a specific exception type and the module logger instead of a bare
    # Exception and print
    if not warn:
        raise ValueError(msg)
    logger.warning(msg)