Source code for columnflow.production.normalization

# coding: utf-8

"""
Column production methods related to sample normalization event weights.
"""

from __future__ import annotations

import copy
import itertools
import dataclasses
import collections

import law
import order as od
import scinum as sn

from columnflow.production import Producer, producer
from columnflow.util import maybe_import, DotDict
from columnflow.columnar_util import set_ak_column
from columnflow.types import Any, Sequence

np = maybe_import("numpy")
ak = maybe_import("awkward")


logger = law.logger.get_logger(__name__)


def get_stitching_datasets(self: Producer, debug: bool = False) -> tuple[od.Dataset, list[od.Dataset]]:
    """
    Helper function to obtain information about stitching datasets:

    - the inclusive dataset, which is the dataset that contains all processes
    - all datasets that are required to stitch this *dataset_inst*

    Datasets are considered in the order in which ``self.config_inst.datasets`` yields them, so both the returned
    list and the choice between several equally inclusive candidates are reproducible between runs.

    :param debug: When *True*, the determined datasets are logged on info level.
    :return: A 2-tuple containing the inclusive dataset and the list of all required datasets.
    :raises Exception: When no required dataset contains the leading processes of all other required datasets.
    """
    leading_process = self.dataset_inst.processes.get_first()

    # first collect all datasets that are needed to stitch the current dataset, i.e., datasets that contain the
    # leading process of this dataset, or whose leading process is contained in this dataset;
    # dict.fromkeys removes duplicates like a set would, but keeps the config order
    required_datasets = list(dict.fromkeys(
        d
        for d in self.config_inst.datasets
        if (
            d.has_process(leading_process, deep=True) or
            self.dataset_inst.has_process(d.processes.get_first(), deep=True)
        )
    ))

    # determine the inclusive dataset: the first one that contains the leading process of every other dataset
    inclusive_dataset = None
    for dataset_inst in required_datasets:
        contains_all_others = all(
            dataset_inst.has_process(other_dataset_inst.processes.get_first(), deep=True)
            for other_dataset_inst in required_datasets
            if not (dataset_inst == other_dataset_inst)
        )
        if contains_all_others:
            inclusive_dataset = dataset_inst
            break

    # explicit None check, so that the result does not depend on the truthiness of dataset objects
    if inclusive_dataset is None:
        raise Exception("inclusive dataset not found")

    if debug:
        logger.info(
            f"determined info for stitching content of dataset '{self.dataset_inst.name}':\n"
            f" - inclusive dataset: {inclusive_dataset.name}\n"
            f" - required datasets: {', '.join(d.name for d in required_datasets)}",
        )

    return inclusive_dataset, list(required_datasets)
def get_br_from_inclusive_datasets(
    self: Producer,
    process_insts: Sequence[od.Process] | set[od.Process],
    dataset_selection_stats: dict[str, dict[str, float | dict[str, float]]],
    merged_selection_stats: dict[str, float | dict[str, float]],
    debug: bool = False,
) -> dict[od.Process, float]:
    """
    Helper function to compute the branching ratios from sum of weights of inclusive samples.

    :param process_insts: The processes for which branching ratios should be computed.
    :param dataset_selection_stats: Per-dataset selection stats. For each dataset name, the fields
        ``"sum_mc_weight_per_process"`` and ``"num_events_per_process"`` are read, both mapping process ids (as
        strings) to numbers.
    :param merged_selection_stats: Not used by this implementation, accepted for interface compatibility.
    :param debug: When *True*, a table of the extracted branching ratios is logged on info level.
    :return: A mapping of process instances to the nominal branching ratio of the most precise path.
    :raises Exception: When the datasets containing a process have no inclusive / exclusive relation from which a
        branching ratio could be derived, or when no path with events exists for a process.
    """
    # step 1: per desired process, collect datasets that contain them
    process_datasets = collections.defaultdict(set)
    for process_inst in process_insts:
        for dataset_name, dstats in dataset_selection_stats.items():
            if str(process_inst.id) in dstats["sum_mc_weight_per_process"]:
                process_datasets[process_inst].add(self.config_inst.get_dataset(dataset_name))

    # step 2: per dataset, collect all "lowest level" processes that are contained in them
    dataset_processes = collections.defaultdict(set)
    for dataset_name in dataset_selection_stats:
        dataset_inst = self.config_inst.get_dataset(dataset_name)
        dataset_process_inst = dataset_inst.processes.get_first()
        for process_inst in process_insts:
            if process_inst == dataset_process_inst or dataset_process_inst.has_process(process_inst, deep=True):
                dataset_processes[dataset_inst].add(process_inst)

    # step 3: per process, structure the assigned datasets and corresponding processes in DAGs, from more inclusive down
    #         to more exclusive phase spaces; usually each DAG can contain multiple paths to compute the BR of a single
    #         process; this is resolved in step 4
    @dataclasses.dataclass
    class Node:
        process_inst: od.Process
        dataset_inst: od.Dataset | None = None
        next: set[Node] = dataclasses.field(default_factory=set)

        def __hash__(self) -> int:
            # "next" is deliberately not part of the hash since it is mutated after nodes were added to sets
            return hash((self.process_inst, self.dataset_inst))

        def str_lines(self) -> list[str]:
            lines = [
                f"{self.__class__.__name__}(",
                f" process={self.process_inst.name}({self.process_inst.id})",
                f" dataset={self.dataset_inst.name if self.dataset_inst else 'None'}",
            ]
            if self.next:
                lines.append(" next={")
                for n in self.next:
                    lines.extend(f" {line}" for line in n.str_lines())
                lines.append(" }")
            else:
                lines.append(r" next={}")
            lines.append(")")
            return lines

        def __str__(self) -> str:
            return "\n".join(self.str_lines())

    process_dags = {}
    for process_inst, dataset_insts in process_datasets.items():
        # first, per dataset, remember all sub (more exclusive) datasets
        # (the O(n^2) is not necessarily optimal, but we are dealing with very small numbers here, thus acceptable)
        sub_datasets = {}
        for d_incl, d_excl in itertools.permutations(dataset_insts, 2):
            if d_incl.processes.get_first().has_process(d_excl.processes.get_first(), deep=True):
                sub_datasets.setdefault(d_incl, set()).add(d_excl)

        # then, expand to a DAG structure
        nodes = {}
        excl_nodes = set()
        for d_incl, d_excls in sub_datasets.items():
            for d_excl in d_excls:
                if d_incl not in nodes:
                    nodes[d_incl] = Node(d_incl.processes.get_first(), d_incl)
                if d_excl not in nodes:
                    nodes[d_excl] = Node(d_excl.processes.get_first(), d_excl)
                nodes[d_incl].next.add(nodes[d_excl])
                excl_nodes.add(nodes[d_excl])

        # mark the root node as the head of the DAG; without any inclusive / exclusive relation between the datasets
        # (e.g. when the process is only found in a single dataset) there is no root to start from, which is reported
        # explicitly instead of failing with an opaque error from popping an empty set
        root_nodes = set(nodes.values()) - excl_nodes
        if not root_nodes:
            raise Exception(
                f"cannot determine the branching ratio of process '{process_inst.name}' ({process_inst.id}) since "
                "no inclusive / exclusive relation was found between the datasets containing it: "
                f"{', '.join(sorted(d.name for d in dataset_insts))}",
            )
        # in case of multiple candidates, pick by dataset name to obtain a reproducible choice
        dag = min(root_nodes, key=lambda node: node.dataset_inst.name)

        # add another node to leaves that only contains the process instance
        for node in excl_nodes:
            if node.next or node.process_inst == process_inst:
                continue
            if process_inst not in nodes:
                nodes[process_inst] = Node(process_inst)
            node.next.add(nodes[process_inst])

        process_dags[process_inst] = dag

    # step 4: per process, compute the branching ratio for each possible path in the DAG, while keeping track of the
    #         statistical precision of each combination, evaluated based on the raw number of events; then pick the
    #         most precise path; again, there should usually be just a single path, but multiple ones are possible when
    #         datasets have complex overlap
    def get_single_br(dataset_inst: od.Dataset, process_inst: od.Process) -> sn.Number | None:
        dstats = dataset_selection_stats[dataset_inst.name]

        # process_inst might refer to a mid-layer process, so check which lowest-layer processes it is made of
        lowest_process_ids = (
            [process_inst.id]
            if process_inst in process_insts
            else [
                int(process_id_str)
                for process_id_str in dstats["sum_mc_weight_per_process"]
                if process_inst.has_process(int(process_id_str), deep=True)
            ]
        )

        # extract stats
        process_sum_weights = sum(
            dstats["sum_mc_weight_per_process"].get(str(process_id), 0.0)
            for process_id in lowest_process_ids
        )
        dataset_sum_weights = sum(dstats["sum_mc_weight_per_process"].values())
        process_num_events = sum(
            dstats["num_events_per_process"].get(str(process_id), 0.0)
            for process_id in lowest_process_ids
        )
        dataset_num_events = sum(dstats["num_events_per_process"].values())

        # when there are no events, return None
        if process_num_events == 0:
            logger.warning(
                f"found no events for process '{process_inst.name}' ({process_inst.id}) with subprocess ids "
                f"'{','.join(map(str, lowest_process_ids))}' in selection stats of dataset {dataset_inst.name}",
            )
            return None

        # compute the ratio of events, assuming correlated poisson counting errors since numbers come from the same
        # dataset, then compute the relative uncertainty
        num_ratio = (
            sn.Number(process_num_events, process_num_events**0.5) /
            sn.Number(dataset_num_events, dataset_num_events**0.5)
        )
        rel_unc = num_ratio(sn.UP, unc=True, factor=True)

        # compute the branching ratio, using the same relative uncertainty and store using the dataset name to mark its
        # limited statistics as the source of uncertainty which is important for consistent error propagation
        br = sn.Number(process_sum_weights / dataset_sum_weights, {f"{dataset_inst.name}_stats": rel_unc * 1j})

        return br

    def path_repr(br_path: tuple[sn.Number, ...], dag_path: tuple[Node, ...]) -> str:
        return " X ".join(
            f"{node.process_inst.name} (br = {br.combine_uncertainties().str(format=3)})"
            for br, node in zip(br_path, dag_path)
        )

    process_brs = {}
    process_brs_debug = {}
    for process_inst, dag in process_dags.items():
        # breadth-first traversal of the DAG, multiplying branching ratios along each path
        brs = []
        unit_br = sn.Number(1.0, 0.0)
        queue = collections.deque([(dag, unit_br, (unit_br,), (dag,))])
        while queue:
            node, br, br_path, dag_path = queue.popleft()
            if not node.next:
                brs.append((br, br_path, dag_path))
                continue
            for sub_node in node.next:
                sub_br = get_single_br(node.dataset_inst, sub_node.process_inst)
                if sub_br is not None:
                    queue.append((sub_node, br * sub_br, br_path + (sub_br,), dag_path + (sub_node,)))

        # all paths are dropped when none of them has events, in which case there is nothing to select from
        if not brs:
            raise Exception(
                f"could not compute a branching ratio for process '{process_inst.name}' ({process_inst.id}) since "
                "no path through the datasets containing it has any events",
            )

        # combine all uncertainties
        brs = [(br.combine_uncertainties(), *paths) for br, *paths in brs]

        # select the most certain one
        brs.sort(key=lambda tpl: tpl[0](sn.UP, unc=True, factor=True))
        best_br, best_br_path, best_dag_path = brs[0]
        process_brs[process_inst] = best_br.nominal
        process_brs_debug[process_inst] = (best_br.nominal, best_br(sn.UP, unc=True, factor=True))  # value and % unc

        # show a warning in case the relative uncertainty is large
        if (rel_unc := best_br(sn.UP, unc=True, factor=True)) > 0.1:
            logger.warning(
                f"large error on the branching ratio of {rel_unc * 100:.2f}% for process '{process_inst.name}' "
                f"({process_inst.id}), calculated along\n {path_repr(best_br_path, best_dag_path)}",
            )

        # in case there were multiple values, check their compatibility with the best one and warn if they diverge
        for i, (br, br_path, dag_path) in enumerate(brs[1:], 2):
            abs_diff = abs(best_br.n - br.n)
            rel_diff = abs_diff / best_br.n
            pull = abs(best_br.n - br.n) / (best_br.u(direction="up")**2 + br.u(direction="up")**2)**0.5
            if rel_diff > 0.1 and pull > 3:
                logger.warning(
                    f"detected diverging branching ratios between the best and the one on position {i} for process "
                    f"'{process_inst.name}' (abs_diff={abs_diff:.4f}, rel_diff={rel_diff:.4f}, pull={pull:.2f} ):"
                    f"\nbest path: {best_br.str(format=3)} from {path_repr(best_br_path, best_dag_path)}"
                    f"\npath {i} : {br.str(format=3)} from {path_repr(br_path, dag_path)}",
                )

    if debug:
        from tabulate import tabulate

        header = ["process name", "process id", "branching ratio", "uncertainty (%)"]
        rows = [
            [
                process_inst.name,
                process_inst.id,
                process_brs_debug[process_inst][0],
                f"{process_brs_debug[process_inst][1] * 100:.4f}",
            ]
            for process_inst in sorted(process_brs_debug)
        ]
        logger.info(f"extracted branching ratios from process occurrence in datasets:\n{tabulate(rows, header)}")

    return process_brs
def update_dataset_selection_stats(
    self: Producer,
    dataset_selection_stats: dict[str, dict[str, float | dict[str, float]]],
) -> dict[str, dict[str, float | dict[str, float]]]:
    """
    Hook to optionally update the per-dataset selection stats.

    This default implementation applies no changes and hands back the very object it received.

    :param dataset_selection_stats: The selection stats per dataset.
    :return: The (here unmodified) *dataset_selection_stats*.
    """
    return dataset_selection_stats
@producer(
    uses={"process_id", "mc_weight"},
    # name of the output column
    weight_name="normalization_weight",
    # which luminosity to apply, uses the value stored in the config when None
    luminosity=None,
    # whether to normalize weights per dataset to the mean weight first (to cancel out numeric differences)
    normalize_weights_per_dataset=True,
    # whether to allow stitching datasets
    allow_stitching=False,
    get_xsecs_from_inclusive_datasets=False,
    get_stitching_datasets=get_stitching_datasets,
    get_br_from_inclusive_datasets=get_br_from_inclusive_datasets,
    update_dataset_selection_stats=update_dataset_selection_stats,
    update_dataset_selection_stats_br=None,
    update_dataset_selection_stats_sum_weights=None,
    # only run on mc
    mc_only=True,
)
def normalization_weights(self: Producer, events: ak.Array, **kwargs) -> ak.Array:
    """
    Uses luminosity information of internal :py:attr:`config_inst`, the cross section of a process identified
    through the ``process_id`` column and the sum of event weights from the selection stats to assign each event a
    normalization weight. The normalization weight is stored in a new column named after the :py:attr:`weight_name`
    attribute.

    The computation of all weights requires that the selection statistics ("stats" output of
    :py:class:`SelectEvents`) contains a field ``"sum_mc_weight_per_process"`` which itself is a dictionary mapping
    process ids to the sum of event weights for that process.

    *luminosity* is used to scale the yield of the simulation. When *None*, the ``luminosity`` auxiliary field of the
    config is used.

    When :py:attr:`allow_stitching` is set to True, the sum of event weights is computed for all datasets with a leaf
    process contained in the leaf processes of the :py:attr:`dataset_inst`. For stitching, the process_id needs to be
    reconstructed for each leaf process on a per event basis. Moreover, when stitching is enabled, an additional
    normalization weight is computed for the inclusive dataset only and stored in a column named
    ``<weight_name>_inclusive``. This weight resembles the normalization weight for the inclusive dataset, as if it
    were unstitched and should therefore only be applied, when using the inclusive dataset as a standalone dataset.

    :param events: The events to process, providing the ``process_id`` and ``mc_weight`` columns.
    :return: The events with the added weight column(s).
    :raises Exception: When ``process_id`` contains ids that are missing in the weight lookup table.
    """
    # read the process id column
    process_id = np.asarray(events.process_id)

    # ensure all ids were assigned a cross section; ids are converted to plain python integers so that they compare
    # and print like the known ids and can be used for process lookups without relying on numpy scalar support
    unique_process_ids = {int(proc_id) for proc_id in np.unique(process_id)}
    invalid_ids = unique_process_ids - self.known_process_ids
    if invalid_ids:
        # sorted to obtain a reproducible message
        invalid_names = [
            f"{self.config_inst.get_process(proc_id).name} ({proc_id})"
            for proc_id in sorted(invalid_ids)
        ]
        raise Exception(
            f"process_id field contains entries {', '.join(invalid_names)} for which no cross sections were found; "
            f"process ids with cross sections: {self.known_process_ids}",
        )

    # read the weight per process (defined as lumi * xsec / sum_weights) from the lookup table
    process_weight = np.squeeze(np.asarray(self.process_weight_table[process_id].todense()), axis=-1)

    # compute the weight and store it
    norm_weight = events.mc_weight * process_weight
    events = set_ak_column(events, self.weight_name, norm_weight, value_type=np.float32)

    # when stitching, also compute the inclusive-only weight
    if self.allow_stitching and self.dataset_inst == self.inclusive_dataset:
        incl_norm_weight = events.mc_weight * self.inclusive_weight
        events = set_ak_column(events, self.weight_name_incl, incl_norm_weight, value_type=np.float32)

    return events
@normalization_weights.init
def normalization_weights_init(self: Producer, **kwargs) -> None:
    """
    Initializes the normalization weights producer by setting up the normalization weight column.

    Also stores the inclusive dataset and the list of datasets whose selection stats are required. Without
    stitching, both refer to the dataset of this producer itself.
    """
    super(normalization_weights, self).init_func(**kwargs)

    # declare the weight name to be a produced column
    self.produces.add(self.weight_name)

    # when stitching is enabled, store specific information
    if self.allow_stitching:
        # remember the inclusive dataset and all datasets needed to determine the weights of processes in _this_ dataset
        self.inclusive_dataset, self.required_datasets = self.get_stitching_datasets()

        # potentially also store the weight needed for only using the inclusive dataset
        if self.dataset_inst == self.inclusive_dataset:
            self.weight_name_incl = f"{self.weight_name}_inclusive"
            self.produces.add(self.weight_name_incl)
    else:
        self.inclusive_dataset = self.dataset_inst
        self.required_datasets = [self.dataset_inst]


@normalization_weights.requires
def normalization_weights_requires(
    self: Producer,
    task: law.Task,
    reqs: dict[str, DotDict[str, Any]],
    **kwargs,
) -> dict[str, DotDict[str, Any]]:
    """
    Adds the requirements needed by the underlying :py:attr:`task` to access selection stats into *reqs*.

    :param task: The task for which requirements are defined.
    :param reqs: The requirements dictionary, updated in place with the key ``"selection_stats"``.
    :return: The updated *reqs*.
    :raises Exception: When a required dataset is not known to the config.
    """
    super(normalization_weights, self).requires_func(task=task, reqs=reqs, **kwargs)

    # check that all datasets are known
    for dataset in self.required_datasets:
        if not self.config_inst.has_dataset(dataset):
            raise Exception(f"unknown dataset '{dataset}' required for normalization weights computation")

    from columnflow.tasks.selection import MergeSelectionStats

    reqs["selection_stats"] = {
        dataset.name: MergeSelectionStats.req_different_branching(
            task,
            dataset=dataset.name,
            branch=-1 if task.is_workflow() else 0,
        )
        for dataset in self.required_datasets
    }

    return reqs


@normalization_weights.setup
def normalization_weights_setup(
    self: Producer,
    task: law.Task,
    reqs: dict[str, DotDict[str, Any]],
    inputs: dict[str, Any],
    reader_targets: law.util.InsertableDict,
    **kwargs,
) -> None:
    """
    Sets up objects required by the computation of normalization weights and stores them as instance attributes:

    - :py:attr:`process_weight_table`: A sparse array serving as a lookup table for the calculated process weights.
      This weight is defined as the product of the luminosity, the cross section, divided by the sum of event
      weights per process.
    - :py:attr:`known_process_ids`: A set of all process ids that are known by the lookup table.
    - :py:attr:`inclusive_weight`: Only set when stitching is allowed and this dataset is the inclusive one; the
      weight the inclusive dataset would have on its own.

    :raises KeyError: When no cross section is registered for a required process at the campaign's center-of-mass
        energy.
    :raises Exception: When the main process of the dataset is found among multiple sub processes in the stats.
    """
    super(normalization_weights, self).setup_func(
        task=task,
        reqs=reqs,
        inputs=inputs,
        reader_targets=reader_targets,
        **kwargs,
    )

    import scipy.sparse

    # load the selection stats
    # (the loop variable is bound as a default argument so that each loader refers to its own input)
    dataset_selection_stats = {
        dataset: copy.deepcopy(task.cached_value(
            key=f"selection_stats_{dataset}",
            func=lambda inp=inp: inp["stats"].load(formatter="json"),
        ))
        for dataset, inp in inputs["selection_stats"].items()
    }

    # optionally normalize weights per dataset to their mean, to potentially align different numeric domains
    norm_factor = 1.0
    if self.normalize_weights_per_dataset:
        for dataset, stats in dataset_selection_stats.items():
            dataset_sum_weights = sum(stats["sum_mc_weight_per_process"].values())
            dataset_num_events = sum(stats["num_events_per_process"].values())
            # a mean weight cannot be defined without events and cannot be divided by when it vanishes, so leave
            # the stats of such datasets untouched instead of failing with a division by zero
            if dataset_num_events == 0 or dataset_sum_weights == 0:
                logger.warning(
                    f"cannot normalize weights of dataset '{dataset}' to their mean since its selection stats "
                    f"contain {dataset_num_events} events with a sum of weights of {dataset_sum_weights}",
                )
                continue
            dataset_mean_weight = dataset_sum_weights / dataset_num_events
            for process_id_str in stats["sum_mc_weight_per_process"]:
                stats["sum_mc_weight_per_process"][process_id_str] /= dataset_mean_weight
            if dataset == self.dataset_inst.name:
                norm_factor = 1.0 / dataset_mean_weight

    # drop unused stats
    dataset_selection_stats = {
        dataset: {field: stats[field] for field in ["num_events_per_process", "sum_mc_weight_per_process"]}
        for dataset, stats in dataset_selection_stats.items()
    }

    # separately treat stats for extracting BRs and sum of mc weights
    def extract_stats(*update_funcs):
        # create copy
        stats = copy.deepcopy(dataset_selection_stats)

        # update through the first of the functions that is callable
        for update_func in update_funcs:
            if callable(update_func):
                stats = update_func(stats)
                break

        # merge
        if len(stats) > 1:
            from columnflow.tasks.selection import MergeSelectionStats
            merged_stats = collections.defaultdict(float)
            for _stats in stats.values():
                MergeSelectionStats.merge_counts(merged_stats, _stats)
        else:
            merged_stats = stats[self.dataset_inst.name]

        return stats, merged_stats

    dataset_selection_stats_br, merged_selection_stats_br = extract_stats(
        self.update_dataset_selection_stats_br,
        self.update_dataset_selection_stats,
    )
    _, merged_selection_stats_sum_weights = extract_stats(
        self.update_dataset_selection_stats_sum_weights,
        self.update_dataset_selection_stats,
    )

    # get all process ids and instances seen and assigned during selection of this dataset
    # (i.e., all possible processes that might be encountered during event processing)
    process_ids = set(map(int, dataset_selection_stats_br[self.dataset_inst.name]["sum_mc_weight_per_process"]))
    process_insts = set(map(self.config_inst.get_process, process_ids))

    # consistency check: when the main process of the current dataset is part of these "lowest level" processes,
    # there should only be this single process, otherwise the manual (sub) process assignment does not match the
    # general dataset -> main process info
    if self.dataset_inst.processes.get_first() in process_insts and len(process_insts) > 1:
        raise Exception(
            f"dataset '{self.dataset_inst.name}' has main process '{self.dataset_inst.processes.get_first().name}' "
            "assigned to it (likely as per cmsdb), but the dataset selection stats for this dataset contain multiple "
            "sub processes, which is likely a misconfiguration of the manual sub process assignment upstream; found "
            f"sub processes: {', '.join(f'{process_inst.name} ({process_inst.id})' for process_inst in process_insts)}",
        )

    # get the luminosity
    lumi = float(self.config_inst.x.luminosity if self.luminosity is None else self.luminosity)

    # setup the event weight lookup table
    # (the default keeps the table well-defined when the stats of this dataset list no process at all)
    process_weight_table = scipy.sparse.dok_matrix((max(process_ids, default=0) + 1, 1), dtype=np.float32)

    def fill_weight_table(process_inst: od.Process, xsec: float, sum_weights: float) -> None:
        if sum_weights == 0:
            logger.warning(
                f"zero sum of weights found for computing normalization weight for process '{process_inst.name}' "
                f"({process_inst.id}) in dataset '{self.dataset_inst.name}', going to use weight of 0.0",
            )
            weight = 0.0
        else:
            weight = norm_factor * xsec * lumi / sum_weights
        process_weight_table[process_inst.id, 0] = weight

    # prepare info for the inclusive dataset
    inclusive_proc = self.inclusive_dataset.processes.get_first()
    try:
        inclusive_xsec = inclusive_proc.get_xsec(self.config_inst.campaign.ecm).nominal
    except KeyError as e:
        raise KeyError(
            f"no cross section registered for inclusive process {inclusive_proc} for center-of-mass energy of "
            f"{self.config_inst.campaign.ecm}",
        ) from e

    # compute the weight the inclusive dataset would have on its own without stitching
    if self.allow_stitching and self.dataset_inst == self.inclusive_dataset:
        inclusive_sum_weights = sum(
            dataset_selection_stats[self.inclusive_dataset.name]["sum_mc_weight_per_process"].values(),
        )
        # same treatment of vanishing sums as in fill_weight_table
        if inclusive_sum_weights == 0:
            logger.warning(
                "zero sum of weights found for computing the inclusive normalization weight of dataset "
                f"'{self.inclusive_dataset.name}', going to use weight of 0.0",
            )
            self.inclusive_weight = 0.0
        else:
            self.inclusive_weight = norm_factor * inclusive_xsec * lumi / inclusive_sum_weights

    # fill weights into the lut, depending on whether stitching is allowed / needed or not
    do_stitch = (
        self.allow_stitching and
        self.get_xsecs_from_inclusive_datasets and
        (len(process_insts) > 1 or len(self.required_datasets) > 1)
    )
    if do_stitch:
        logger.debug(
            f"using inclusive dataset '{self.inclusive_dataset.name}' and process '{inclusive_proc.name}' for cross "
            "section lookup",
        )

        # optionally run the dataset lookup again in debug mode when stitching
        is_first_branch = getattr(task, "branch", None) == 0
        if is_first_branch:
            self.get_stitching_datasets(debug=True)

        # extract branching ratios
        branching_ratios = self.get_br_from_inclusive_datasets(
            process_insts,
            dataset_selection_stats_br,
            merged_selection_stats_br,
            debug=is_first_branch,
        )

        # fill the process weight table
        for process_inst, br in branching_ratios.items():
            sum_weights = merged_selection_stats_sum_weights["sum_mc_weight_per_process"][str(process_inst.id)]
            fill_weight_table(process_inst, br * inclusive_xsec, sum_weights)
    else:
        # fill the process weight table with per-process cross sections
        for process_inst in process_insts:
            if self.config_inst.campaign.ecm not in process_inst.xsecs:
                raise KeyError(
                    f"no cross section registered for process {process_inst} for center-of-mass energy of "
                    f"{self.config_inst.campaign.ecm}",
                )
            xsec = process_inst.get_xsec(self.config_inst.campaign.ecm).nominal
            sum_weights = merged_selection_stats_sum_weights["sum_mc_weight_per_process"][str(process_inst.id)]
            fill_weight_table(process_inst, xsec, sum_weights)

    # store lookup table and known process ids
    self.process_weight_table = process_weight_table
    self.known_process_ids = process_ids


# variant that allows stitching and takes cross sections from the inclusive dataset, scaled by branching ratios
stitched_normalization_weights = normalization_weights.derive(
    "stitched_normalization_weights", cls_dict={
        "weight_name": "normalization_weight",
        "get_xsecs_from_inclusive_datasets": True,
        "allow_stitching": True,
    },
)

# variant that allows stitching but uses the cross sections registered per process
stitched_normalization_weights_brs_from_processes = stitched_normalization_weights.derive(
    "stitched_normalization_weights_brs_from_processes", cls_dict={
        "get_xsecs_from_inclusive_datasets": False,
    },
)