Source code for columnflow.plotting.plot_util

# coding: utf-8

"""
Some utils for plot functions.
"""

from __future__ import annotations

__all__ = []

import re
import math
import operator
import functools
from collections import OrderedDict

import law
import order as od
import scinum as sn

from columnflow.util import maybe_import, try_int, try_complex, safe_div, UNSET
from columnflow.hist_util import copy_axis, sum_hists
from columnflow.types import TYPE_CHECKING, Iterable, Any, Callable, Sequence, Hashable

np = maybe_import("numpy")
if TYPE_CHECKING:
    hist = maybe_import("hist")
    plt = maybe_import("matplotlib.pyplot")


logger = law.logger.get_logger(__name__)


# Mapping from short label keys to the text shown as the left part of the CMS label.
# get_cms_label() looks up its *llabel* argument here and uses the argument verbatim when it is
# not a known key. The "public" key maps to an empty text.
label_options = {
    "wip": "Work in progress",
    "pre": "Preliminary",
    "pw": "Private work (CMS data/simulation)",
    "pwip": "Private work in progress (CMS)",
    "sim": "Simulation",
    "simwip": "Simulation work in progress",
    "simpre": "Simulation preliminary",
    "simpw": "Private work (CMS simulation)",
    "datapw": "Private work (CMS data)",
    "od": "OpenData",
    "odwip": "OpenData work in progress",
    "odpw": "OpenData private work",
    "public": "",
}


def get_cms_label(ax: plt.Axes, llabel: str) -> dict:
    """
    Build the keyword arguments that configure the CMS label of a plot.

    :param ax: The axis the CMS label is attached to.
    :param llabel: Either a key of ``label_options`` or the literal text of the left label.
    :return: Dictionary with the entries "ax", "llabel", "fontsize" and "data", plus an empty "exp"
        entry when the resolved label text itself contains "CMS".
    """
    # translate known short keys, keep unknown values as they are
    text = label_options.get(llabel, llabel)

    config = dict(ax=ax, llabel=text, fontsize=22, data=False)

    # the label text already names the experiment, so the separate experiment text is emptied
    if "CMS" in text:
        config["exp"] = ""

    return config
def get_attr_or_aux(proc: od.AuxDataMixin, attr: str, default: Any) -> Any:
    """
    Look up *attr* on *proc*, first as a regular attribute, then as an auxiliary entry.

    :param proc: The object to inspect; it must provide ``has_aux`` and ``get_aux``.
    :param attr: Name of the attribute or auxiliary entry.
    :param default: Value returned when neither an attribute nor an auxiliary entry exists.
    :return: The attribute value, the auxiliary value, or *default* (in this order of precedence).
    """
    # compare against the sentinel by identity: an equality check would call the comparison
    # operators of the attribute value itself, which is ambiguous for array-like values
    value = getattr(proc, attr, UNSET)
    if value is not UNSET:
        return value

    if proc.has_aux(attr):
        return proc.get_aux(attr)

    return default
def round_dynamic(value: int | float) -> int | float:
    """
    Rounds a *value* to a number of significant digits that depends on its magnitude.

    One significant digit is kept when the absolute value is below 10, two significant digits
    otherwise. A rounded result of at least 1 is returned as an integer; every other result
    (values below 1, including negative ones) is returned as the float produced by recombining
    the rounded mantissa and its magnitude.

    :param value: The value to round.
    :return: The rounded value.
    """
    # number of significant digits to keep
    n_digits = 1 if abs(value) < 10 else 2

    # let scinum split the rounded number into a mantissa string and a power of ten
    mantissa, _, exponent = sn.round_value(value, method=n_digits)

    # put both pieces back together
    rounded = float(mantissa) * 10**exponent

    # only results of at least 1 are converted to integers
    if rounded >= 1:
        return int(rounded)
    return rounded
def apply_settings(
    instances: Iterable[od.AuxDataMixin],
    settings: dict[str, Any] | None,
    parent_check: Callable[[od.AuxDataMixin, str], bool] | None = None,
) -> None:
    """
    Applies the entries of the *settings* dictionary to the given order *instances*.

    :param instances: Instances to apply settings to.
    :param settings: Dictionary of settings. Each key is compared against the instances (and, if
        *parent_check* is given, against their parents) and each value is a dictionary of
        attributes that are set on matching instances, either as a regular attribute or, when
        that fails, as an auxiliary entry.
    :param parent_check: Optional function receiving an instance and a name that decides whether
        the instance has a parent with that name.
    """
    # nothing to apply
    if not settings:
        return

    for inst in instances:
        for name, attrs in settings.items():
            if inst != name:
                # not addressed directly, so the settings only apply through a matching parent
                if not callable(parent_check) or not parent_check(inst, name):
                    continue

            for key, value in attrs.items():
                # prefer a regular attribute and fall back to an auxiliary entry
                try:
                    setattr(inst, key, value)
                except (AttributeError, ValueError):
                    inst.set_aux(key, value)
def hists_merge_cutflow_steps(
    hists: dict,
) -> dict:
    """
    Make the 'step' axis uniform among a set of histograms.

    Takes a dict of one-dimensional histograms, builds for each of them a histogram whose single
    axis contains the steps of all inputs, and returns a dict with the same keys holding these
    histograms. Entries of steps that are missing in an input histogram are copied from the last
    preceding step that exists in that input. A dict with fewer than two histograms is returned
    unchanged.

    :param hists: Dictionary of one-dimensional histograms.
    :raises ValueError: If any of the histograms is not one-dimensional.
    :return: Dictionary of histograms with a common step structure.
    """
    # a single histogram (or none) is trivially uniform
    if len(hists) < 2:
        return hists

    originals = list(hists.values())

    # only one-dimensional histograms are supported
    if any(h.ndim != 1 for h in originals):
        raise ValueError(
            "cannot merge cutflow steps: histograms must be one-dimensional",
        )

    # summing all histograms with exactly one nonzero coefficient reproduces a single
    # histogram, but on the axis that results from combining all of them
    merged = [
        sum(h * coeff for h, coeff in zip(originals, coeffs))
        for coeffs in np.eye(len(originals))
    ]

    # fill steps that an original histogram lacks with the content of its preceding step
    all_steps = list(merged[0].axes[0])
    for original, combined in zip(originals, merged):
        previous = all_steps[0]
        for step in all_steps[1:]:
            if step in original.axes[0]:
                previous = step
                continue
            combined[step] = combined[previous]

    # reuse the original keys
    return dict(zip(hists, merged))
def apply_process_settings(
    hists: dict[Hashable, hist.Hist],
    process_settings: dict | None = None,
) -> tuple[dict[Hashable, hist.Hist], dict[str, Any]]:
    """
    Applies the *process_settings* dictionary to the process instances that are the keys of
    *hists*. Settings also apply to processes that have a parent process with a matching name.

    :param hists: Dictionary mapping process instances to histograms.
    :param process_settings: Settings per process name, see :py:func:`apply_settings`.
    :return: Tuple of the unchanged *hists* dictionary and a dictionary for style information
        gathered along the way, which is currently always empty.
    """
    def has_parent(proc, parent_name):
        return proc.has_parent_process(parent_name)

    # settings end up on the process instances themselves
    apply_settings(hists.keys(), process_settings, parent_check=has_parent)

    # placeholder for information that could be inserted into the style config
    process_style_config = {}

    return hists, process_style_config
def apply_process_scaling(hists: dict[Hashable, hist.Hist]) -> dict[Hashable, hist.Hist]:
    """
    Scales histograms according to the "scale" setting of their process and adjusts the labels.

    The setting is read per process as attribute or auxiliary entry. The special value "stack"
    is replaced by the (dynamically rounded) ratio of the integral of all stacked MC histograms
    to the integral of the histogram itself, falling back to 1 when that ratio evaluates to a
    falsy value. Scale factors accepted by ``try_int`` are converted to integers and multiplied
    onto the histogram, and factors other than 1 are written into the process label through the
    SCALE placeholder. Remaining SCALE placeholders are removed from all labels.

    :param hists: Dictionary mapping process instances to histograms; updated in place.
    :return: The same dictionary with scaled histograms.
    """
    # integral of the stacked MC histograms, computed on first use only
    cached_integral = None

    def get_stack_integral() -> float:
        nonlocal cached_integral
        if cached_integral is None:
            cached_integral = sum(
                remove_residual_axis_single(proc_h, "shift", select_value="nominal").sum().value
                for proc, proc_h in hists.items()
                if proc.is_mc and not get_attr_or_aux(proc, "unstack", False)
            )
        return cached_integral

    for proc_inst, h in hists.items():
        scale = get_attr_or_aux(proc_inst, "scale", None)

        # resolve the "stack" keyword into a rounded number
        if scale == "stack":
            h_nominal = remove_residual_axis_single(h, "shift", select_value="nominal")
            scale = round_dynamic(safe_div(get_stack_integral(), h_nominal.sum().value)) or 1

        if try_int(scale):
            scale = int(scale)
            hists[proc_inst] = h * scale

            if scale != 1:
                # plain number for moderate factors, compact exponent notation beyond
                if scale < 1e5:
                    scale_str = str(scale)
                else:
                    scale_str = re.sub(r"e(\+?)(-?)(0*)", r"e\2", f"{scale:.1e}")
                proc_inst.label = apply_label_placeholders(
                    proc_inst.label,
                    apply="SCALE",
                    scale=scale_str,
                )

        # placeholders that were not filled must not show up in the label
        proc_inst.label = remove_label_placeholders(proc_inst.label, drop="SCALE")

    return hists
def apply_variable_settings(
    hists: dict[Hashable, hist.Hist],
    variable_insts: list[od.Variable],
    variable_settings: dict | None = None,
) -> tuple[dict[Hashable, hist.Hist], dict[od.Variable, dict[str, Any]]]:
    """
    Applies the *variable_settings* dictionary to the *variable_insts*. The *rebin*, *overflow*,
    *underflow*, *slice* and *x_transformations* settings are additionally applied directly to
    the histograms.

    :param hists: Dictionary of histograms whose axes are named after the variables; updated in
        place, except for x transformations which create a new dictionary.
    :param variable_insts: Variable instances to configure.
    :param variable_settings: Settings per variable name, see :py:func:`apply_settings`.
    :raises ValueError: If an unknown x transformation is requested.
    :return: Tuple of the histograms and a dictionary with style config updates per variable.
    """
    import hist

    # info gathered along the way that can be inserted into the style config
    variable_style_config = {}

    # settings end up on the variable instances themselves
    apply_settings(variable_insts, variable_settings)

    for var_inst in variable_insts:
        variable_style_config[var_inst] = {}

        # rebinning
        rebin_factor = get_attr_or_aux(var_inst, "rebin", None)
        if try_int(rebin_factor):
            rebin_factor = int(rebin_factor)
            for proc_inst, h in list(hists.items()):
                hists[proc_inst] = h[{var_inst.name: hist.rebin(rebin_factor)}]

        # overflow and underflow bins
        overflow = get_attr_or_aux(var_inst, "overflow", False)
        underflow = get_attr_or_aux(var_inst, "underflow", False)
        if overflow or underflow:
            for proc_inst, h in list(hists.items()):
                hists[proc_inst] = use_flow_bins(h, var_inst.name, underflow=underflow, overflow=overflow)

        # slicing, requires at least two values that can be interpreted as (complex) numbers
        slices = get_attr_or_aux(var_inst, "slice", None)
        if (
            slices and isinstance(slices, Iterable) and len(slices) >= 2 and
            try_complex(slices[0]) and try_complex(slices[1])
        ):
            slice_0 = int(slices[0]) if try_int(slices[0]) else complex(slices[0])
            slice_1 = int(slices[1]) if try_int(slices[1]) else complex(slices[1])
            for proc_inst, h in list(hists.items()):
                hists[proc_inst] = h[{var_inst.name: slice(slice_0, slice_1)}]

        # additional x axis transformations
        for trafo in law.util.make_list(get_attr_or_aux(var_inst, "x_transformations", None) or []):
            if trafo not in {"equal_distance_with_edges", "equal_distance_with_indices"}:
                raise ValueError(f"unknown x transformation '{trafo}'")

            # forced representation into equal bins
            hists, orig_edges = rebin_equal_width(hists, var_inst.name)

            # read the new edges from the axis that was just rebinned; it is looked up by name
            # since it is not necessarily the last axis of the histogram
            new_edges = next(iter(hists.values())).axes[var_inst.name].edges

            # store edge values as well as ticks if needed
            ax_cfg = {"xlim": (new_edges[0], new_edges[-1])}
            if trafo == "equal_distance_with_edges":
                # optionally round edges
                rnd = get_attr_or_aux(var_inst, "x_edge_rounding", (lambda e: e))
                edge_labels = [rnd(e) for e in orig_edges]
                ax_cfg |= {"xmajorticks": new_edges, "xmajorticklabels": edge_labels, "xminorticks": []}
            variable_style_config[var_inst].setdefault("ax_cfg", {}).update(ax_cfg)
            variable_style_config[var_inst].setdefault("rax_cfg", {}).update(ax_cfg)

    return hists, variable_style_config
def remove_negative_contributions(hists: dict[Hashable, hist.Hist]) -> dict[Hashable, hist.Hist]:
    """
    Returns a copy of *hists* in which all negative bin values are set to zero.

    The input dictionary and its histograms are left untouched; only the values are changed,
    variances are kept.

    :param hists: Dictionary of histograms.
    :return: New dictionary with copies of the histograms without negative bin values.
    """
    cleaned = hists.copy()

    for proc_inst, h in hists.items():
        h_clipped = h.copy()
        values = h_clipped.view().value
        values[values < 0] = 0
        cleaned[proc_inst] = h_clipped

    return cleaned
def use_flow_bins(
    h_in: hist.Hist,
    axis_name: str | int,
    underflow: bool = True,
    overflow: bool = True,
) -> hist.Hist:
    """
    Adds the content of the flow bins of axis *axis_name* of histogram *h_in* to the first/last
    bin and empties the flow bins.

    :param h_in: Input histogram; its view must provide *value* and *variance* fields.
    :param axis_name: Name or index of the axis of interest. Negative indices count from the
        last axis.
    :param underflow: Whether to add the content of the underflow bin to the first bin of axis
        *axis_name*.
    :param overflow: Whether to add the content of the overflow bin to the last bin of axis
        *axis_name*.
    :raises Exception: If the axis does not have both an underflow and an overflow bin.
    :return: Copy of the histogram with underflow and/or overflow content added to the
        first/last bin of the histogram.
    """
    # work on a copy of the histogram
    h_out = h_in.copy()

    # nothing to do if neither flag is set
    if not overflow and not underflow:
        logger.info(f"{use_flow_bins.__name__} has nothing to do since overflow and underflow are set to False")
        return h_out

    # determine the index of the axis of interest and check if it has flow bins activated
    axis_idx = axis_name if isinstance(axis_name, int) else h_in.axes.name.index(axis_name)
    h_view = h_out.view(flow=True)
    if h_out.view().shape[axis_idx] + 2 != h_view.shape[axis_idx]:
        raise Exception(f"We expect axis {axis_name} to have assigned an underflow and overflow bin")

    # the index tuples built below need a non-negative axis position
    n_dims = len(h_out.shape)
    if axis_idx < 0:
        axis_idx += n_dims

    def at(idx: int) -> tuple:
        # index tuple that selects bin *idx* along the axis of interest and everything else
        return (slice(None),) * axis_idx + (idx,) + (slice(None),) * (n_dims - axis_idx - 1)

    if overflow:
        # replace last bin with last bin + overflow, then empty the overflow bin
        h_view.value[at(-2)] = h_view.value[at(-2)] + h_view.value[at(-1)]
        h_view.value[at(-1)] = 0
        h_view.variance[at(-2)] = h_view.variance[at(-2)] + h_view.variance[at(-1)]
        h_view.variance[at(-1)] = 0

    if underflow:
        # replace first bin with first bin + underflow, then empty the underflow bin
        h_view.value[at(1)] = h_view.value[at(0)] + h_view.value[at(1)]
        h_view.value[at(0)] = 0
        h_view.variance[at(1)] = h_view.variance[at(0)] + h_view.variance[at(1)]
        h_view.variance[at(0)] = 0

    return h_out
def apply_density(hists: dict, density: bool = True) -> dict:
    """
    Divides the content of all histograms by the product of their axis widths.

    :param hists: Dictionary of histograms; its entries are replaced in place.
    :param density: When *False*, nothing is changed.
    :return: The same dictionary that was passed in.
    """
    if density:
        for key, h in hists.items():
            # multiplying the widths of all axes also covers multi-dimensional histograms
            bin_area = functools.reduce(operator.mul, h.axes.widths)
            hists[key] = h / bin_area

    return hists
def remove_residual_axis_single(
    h: hist.Hist,
    ax_name: str,
    max_bins: int = 1,
    select_value: Any = None,
) -> hist.Hist:
    """
    Removes the axis named *ax_name* from a copy of histogram *h* by summing over it.

    :param h: Input histogram, which is never modified.
    :param ax_name: Name of the axis to remove. When the histogram has no such axis, a plain
        copy is returned.
    :param max_bins: Maximum number of bins the axis may have (after the optional selection).
    :param select_value: When not *None*, the axis is first reduced to the bin located at this
        value.
    :raises Exception: If the axis has more than *max_bins* bins.
    :return: Copy of the histogram without the axis.
    """
    import hist

    # callers always get a new object
    h_out = h.copy()

    # without the axis there is nothing to remove
    if ax_name not in h_out.axes.name:
        return h_out

    # optionally pick a single value, the list keeps the axis in place
    if select_value is not None:
        h_out = h_out[{ax_name: [hist.loc(select_value)]}]

    # refuse to sum over more bins than allowed
    n_remaining = len(h_out.axes[ax_name])
    if n_remaining > max_bins:
        raise Exception(
            f"axis '{ax_name}' of histogram has {n_remaining} bins whereas at most {max_bins} bins are "
            f"accepted for removal of residual axis",
        )

    # sum over what is left of the axis
    return h_out[{ax_name: sum}]
def remove_residual_axis(
    hists: dict,
    ax_name: str,
    max_bins: int = 1,
    select_value: Any = None,
) -> dict:
    """
    Applies :py:func:`remove_residual_axis_single` to every histogram in *hists*.

    The axis named *ax_name* is removed from each histogram that has it; an exception is raised
    when the axis has more than *max_bins* bins after the optional selection of *select_value*.

    :return: New dictionary with the same keys and the reduced histograms.
    """
    reduced = {}
    for key, h in hists.items():
        reduced[key] = remove_residual_axis_single(
            h,
            ax_name,
            max_bins=max_bins,
            select_value=select_value,
        )
    return reduced
def prepare_style_config(
    config_inst: od.Config,
    category_inst: od.Category,
    variable_inst: od.Variable,
    density: bool | None = False,
    shape_norm: bool | None = False,
    yscale: str | None = "",
    **kwargs: Any,
) -> dict:
    """
    Sets up a default style config based on the config, category and variable instances.

    :param config_inst: Config providing the luminosity and the center-of-mass energy.
    :param category_inst: Category whose label is used for the annotation.
    :param variable_inst: Variable providing axis titles, limits, scales and tick options.
    :param density: Whether the y title should be given per unit of the variable.
    :param shape_norm: Accepted for interface compatibility, not used here.
    :param yscale: Scale of the y axis; when empty, it follows ``variable_inst.log_y``.
    :return: Dictionary with the entries "ax_cfg", "rax_cfg", "legend_cfg", "annotate_cfg" and
        "cms_label_cfg".
    """
    # the variable decides on the y scale unless one is given
    if not yscale:
        yscale = "log" if variable_inst.log_y else "linear"

    # auxiliary entries take precedence over the limits of the variable itself
    x_min = variable_inst.x("x_min", variable_inst.x_min)
    x_max = variable_inst.x("x_max", variable_inst.x_max)

    # annotation built from the category label and an optional selection label of the variable
    cat_label = join_labels(category_inst.label, variable_inst.x("selection_label", None))

    # unit format on axes (could be configurable)
    unit_format = "{title} [{unit}]"

    # densities are quoted per unit, plain yields without any unit
    if density:
        y_title_kwargs = {"unit": variable_inst.unit or "unit", "unit_format": "{title} / {unit}"}
    else:
        y_title_kwargs = {"unit": False, "unit_format": unit_format}
    ylabel = variable_inst.get_full_y_title(bin_width=False, **y_title_kwargs)

    xlabel = variable_inst.get_full_x_title(unit_format=unit_format)
    xrotation = variable_inst.x("x_label_rotation", None)

    ax_cfg = {
        "xlim": (x_min, x_max),
        "ylabel": ylabel,
        "xlabel": xlabel,
        "yscale": yscale,
        "xscale": "log" if variable_inst.log_x else "linear",
        "xrotation": xrotation,
    }
    rax_cfg = {
        "ylabel": "Data / MC",
        "xlabel": xlabel,
        "xrotation": xrotation,
    }

    # /pb -> /fb
    lumi_fb = 0.001 * config_inst.x.luminosity.get("nominal")

    style_config = {
        "ax_cfg": ax_cfg,
        "rax_cfg": rax_cfg,
        "legend_cfg": {},
        "annotate_cfg": {"text": cat_label or ""},
        "cms_label_cfg": {
            "lumi": f"{lumi_fb:.1f}",
            "com": config_inst.campaign.ecm,
        },
    }

    # disable minor ticks based on variable_inst
    axis_type = variable_inst.x("axis_type", "variable")
    if variable_inst.discrete_x or "int" in axis_type:
        # remove the "xscale" attribute since it messes up the bin edges
        ax_cfg.pop("xscale")
        ax_cfg["xminorticks"] = []
    if variable_inst.discrete_y:
        ax_cfg["yminorticks"] = []

    return style_config
def prepare_stack_plot_config(
    hists: OrderedDict,
    shape_norm: bool | None = False,
    hide_stat_errors: bool | None = None,
    shift_insts: Sequence[od.Shift] | None = None,
    density: bool = False,
    **kwargs,
) -> OrderedDict:
    """
    Prepares a plot config to create plots containing a stack of backgrounds with uncertainty
    bands, unstacked processes as lines and data entries with errorbars.

    :param hists: Ordered mapping of process instances to histograms.
    :param shape_norm: Passed to the normalization function to request normalized shapes.
    :param hide_stat_errors: Whether statistical errors are hidden; a "hide_stat_errors" setting
        of a process overrides this for lines and data.
    :param shift_insts: Shift instances. When exactly one is given, its shift is selected from
        the histograms, otherwise (including *None*) the "nominal" shift is selected.
    :param density: Forwarded to the drawing of the data points.
    :param kwargs: May contain "shape_norm_func", a function taking a histogram and *shape_norm*
        and returning the normalization.
    :return: Ordered plot config with the entries "mc_stack", "line_<i>", "mc_stat_unc",
        "mc_syst_unc" and "data", as far as applicable.
    """
    import hist

    # separate histograms into stack, lines and data hists
    mc_hists, mc_colors, mc_edgecolors, mc_labels = [], [], [], []
    mc_syst_hists = []
    line_hists, line_colors, line_labels, line_hide_stat_errors = [], [], [], []
    data_hists, data_hide_stat_errors = [], []
    data_label = None

    # shift_insts is optional, so guard against None before asking for its length
    default_shift = shift_insts[0].name if shift_insts and len(shift_insts) == 1 else "nominal"

    for process_inst, h in hists.items():
        # if given, per-process setting overrides task parameter
        proc_hide_stat_errors = get_attr_or_aux(process_inst, "hide_stat_errors", hide_stat_errors)
        if process_inst.is_data:
            data_hists.append(remove_residual_axis_single(h, "shift", select_value=default_shift))
            data_hide_stat_errors.append(proc_hide_stat_errors)
            # the first data process provides the label
            if data_label is None:
                data_label = process_inst.label
        elif get_attr_or_aux(process_inst, "unstack", False):
            line_hists.append(remove_residual_axis_single(h, "shift", select_value=default_shift))
            line_colors.append(process_inst.color1)
            line_labels.append(process_inst.label)
            line_hide_stat_errors.append(proc_hide_stat_errors)
        else:
            mc_hists.append(remove_residual_axis_single(h, "shift", select_value=default_shift))
            mc_colors.append(process_inst.color1)
            mc_edgecolors.append(process_inst.color2)
            mc_labels.append(process_inst.label)
            # histograms with more than one shift enter the systematic band
            if "shift" in h.axes.name and h.axes["shift"].size > 1:
                mc_syst_hists.append(h)

    h_data, h_mc, h_mc_stack = None, None, None
    if data_hists:
        h_data = sum_hists(data_hists)
    if mc_hists:
        h_mc = sum_hists(mc_hists)
        h_mc_stack = hist.Stack(*mc_hists)

    # setup plotting configs
    plot_config = OrderedDict()
    shape_norm_func = kwargs.get("shape_norm_func", lambda h, shape_norm: sum(h.values()) if shape_norm else 1)

    # the normalization of the summed MC histogram is shared by all stack related entries
    mc_norm = shape_norm_func(h_mc, shape_norm) if h_mc is not None else None

    # draw stack
    if h_mc_stack is not None:
        plot_config["mc_stack"] = {
            "method": "draw_stack",
            "hist": h_mc_stack,
            "kwargs": {
                "norm": mc_norm,
                "label": mc_labels,
                "color": mc_colors,
                "edgecolor": mc_edgecolors,
                "linewidth": [(0 if c is None else 1) for c in mc_colors],
            },
        }

    # draw lines
    for i, h in enumerate(line_hists):
        line_norm = shape_norm_func(h, shape_norm)
        plot_config[f"line_{i}"] = plot_cfg = {
            "method": "draw_hist",
            "hist": h,
            "kwargs": {
                "norm": line_norm,
                "label": line_labels[i],
                "color": line_colors[i],
                "error_type": "variance",
            },
        }

        # suppress error bars by overriding `yerr`
        if line_hide_stat_errors[i]:
            for key in ("kwargs", "ratio_kwargs"):
                if key in plot_cfg:
                    plot_cfg[key]["yerr"] = False

    # draw statistical error for stack
    if h_mc_stack is not None and not hide_stat_errors:
        plot_config["mc_stat_unc"] = {
            "method": "draw_stat_error_bands",
            "hist": h_mc,
            "kwargs": {"norm": mc_norm, "label": "MC stat. unc."},
            "ratio_kwargs": {"norm": h_mc.values()},
        }

    # draw systematic error for stack
    if h_mc_stack is not None and mc_syst_hists:
        plot_config["mc_syst_unc"] = {
            "method": "draw_syst_error_bands",
            "hist": h_mc,
            "kwargs": {
                "syst_hists": mc_syst_hists,
                "shift_insts": shift_insts,
                "norm": mc_norm,
                "label": "MC syst. unc.",
            },
            "ratio_kwargs": {
                "syst_hists": mc_syst_hists,
                "shift_insts": shift_insts,
                "norm": h_mc.values(),
            },
        }

    # draw data
    if data_hists:
        data_norm = shape_norm_func(h_data, shape_norm)
        plot_config["data"] = plot_cfg = {
            "method": "draw_errorbars",
            "hist": h_data,
            "kwargs": {
                "norm": data_norm,
                "label": data_label or "Data",
                "error_type": "poisson_unweighted",
                "density": density,
            },
        }

        # the ratio is only defined with respect to the summed MC histogram
        if h_mc is not None:
            plot_config["data"]["ratio_kwargs"] = {
                "norm": h_mc.values() * data_norm / mc_norm,
                "error_type": "poisson_unweighted",
                "density": density,
            }

        # suppress error bars by overriding `yerr`
        if any(data_hide_stat_errors):
            for key in ("kwargs", "ratio_kwargs"):
                if key in plot_cfg:
                    plot_cfg[key]["yerr"] = False

    return plot_config
def split_ax_kwargs(kwargs: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]:
    """
    Split the given dictionary into two dictionaries: one with the entries that are passed on
    to matplotlib's ``ax.set()`` function, and one with the special entries that are handled
    separately by :py:func:`apply_ax_kwargs`.

    :param kwargs: Axis options to split.
    :return: Tuple of the ``ax.set()`` options and the separately handled options, both in the
        order of *kwargs*.
    """
    # options that ax.set() does not receive
    special_keys = frozenset((
        "xmajorticks", "xminorticks", "xmajorticklabels", "xminorticklabels", "xloc", "xrotation",
        "ymajorticks", "yminorticks", "yloc", "yrotation",
        "xticklabelformat", "yticklabelformat", "xoffsettext", "yoffsettext",
    ))

    set_kwargs = {key: value for key, value in kwargs.items() if key not in special_keys}
    other_kwargs = {key: value for key, value in kwargs.items() if key in special_keys}

    return set_kwargs, other_kwargs
def apply_ax_kwargs(ax: plt.Axes, kwargs: dict[str, Any]) -> None:
    """
    Apply the given keyword arguments to the given axis. Options accepted by ``ax.set()`` are
    applied in a single call, all others (see :py:func:`split_ax_kwargs`) are applied through
    dedicated axis methods. Options whose value is *None* are skipped.

    :param ax: The axis to configure.
    :param kwargs: Axis options.
    """
    set_kwargs, extra = split_ax_kwargs(kwargs)

    # standard options
    ax.set(**set_kwargs)

    # tick positions
    if (ticks := extra.get("xmajorticks")) is not None:
        ax.set_xticks(ticks, minor=False)
    if (ticks := extra.get("ymajorticks")) is not None:
        ax.set_yticks(ticks, minor=False)
    if (ticks := extra.get("xminorticks")) is not None:
        ax.set_xticks(ticks, minor=True)
    if (ticks := extra.get("yminorticks")) is not None:
        ax.set_yticks(ticks, minor=True)

    # tick labels of the x axis
    if (tick_labels := extra.get("xmajorticklabels")) is not None:
        ax.set_xticklabels(tick_labels, minor=False)
    if (tick_labels := extra.get("xminorticklabels")) is not None:
        ax.set_xticklabels(tick_labels, minor=True)

    # location of the axis labels, the label text itself is kept
    if (loc := extra.get("xloc")) is not None:
        ax.set_xlabel(ax.get_xlabel(), loc=loc)
    if (loc := extra.get("yloc")) is not None:
        ax.set_ylabel(ax.get_ylabel(), loc=loc)

    # rotation of the tick labels
    if (rotation := extra.get("xrotation")) is not None:
        ax.tick_params(axis="x", labelrotation=rotation)
    if (rotation := extra.get("yrotation")) is not None:
        ax.tick_params(axis="y", labelrotation=rotation)

    # format of the tick labels
    for axis_name in ("x", "y"):
        fmt = extra.get(f"{axis_name}ticklabelformat")
        if fmt is None:
            continue
        fmt = fmt.copy()
        formatter = getattr(ax, f"{axis_name}axis").major.formatter
        # formatters without a "set_scientific" method do not receive the scientific style
        # and math text options
        if not callable(getattr(formatter, "set_scientific", None)):
            if fmt.get("style") == "sci":
                fmt.pop("style")
            fmt.pop("useMathText", None)
        ax.ticklabel_format(axis=axis_name, **fmt)

    # properties of the offset texts
    for axis_name in ("x", "y"):
        offset_cfg = extra.get(f"{axis_name}offsettext")
        if offset_cfg is not None:
            getattr(ax, f"{axis_name}axis").get_offset_text().set(**offset_cfg)
def get_position(minimum: float, maximum: float, factor: float = 1.4, logscale: bool = False) -> float:
    """
    Get a position relative to the range between *minimum* and *maximum*.

    :param minimum: Lower end of the range.
    :param maximum: Upper end of the range.
    :param factor: Fraction of the range at which the position is located, measured from
        *minimum*; values above 1 lie beyond *maximum*.
    :param logscale: Whether the fraction refers to the distance between the base-10 logarithms
        of the range ends.
    :return: The position.
    """
    if not logscale:
        return (maximum - minimum) * factor + minimum

    # interpolate between the exponents and transform back
    log_min = math.log10(minimum)
    return 10 ** ((math.log10(maximum) - log_min) * factor + log_min)
def join_labels(
    *labels: str | list[str | None] | None,
    inline_sep: str = ",",
    multiline_sep: str = "\n",
) -> str:
    """
    Joins labels into a single string while skipping empty ones.

    :param labels: Single labels or lists of labels; empty entries such as *None* are dropped.
    :param inline_sep: Separator used when the first label is a string.
    :param multiline_sep: Separator used when the first label is not a string.
    :return: The joined label, or an empty string when no labels are given.
    """
    if not labels:
        return ""

    # the type of the first label selects the separator
    sep = inline_sep if isinstance(labels[0], str) else multiline_sep

    # flatten all labels into one list of parts
    parts = [part for label in labels for part in law.util.make_list(label)]

    return sep.join(part for part in parts if part)
def reduce_with(spec: str | float | callable, values: list[float]) -> float:
    """
    Reduce an array of *values* to a single value using the function indicated by *spec*.
    Intended as a helper for resolving range specifications supplied as strings.

    Supported specifiers are:

        * 'min': minimum value
        * 'max': maximum value
        * 'maxabs': the absolute value of the maximum or minimum, whichever is larger
        * 'minabs': the absolute value of the maximum or minimum, whichever is smaller

    NaN entries are ignored by all of them. A hyphen (``-``) can be prefixed to any specifier
    to return its negative.

    Callables can be passed as *spec* and should take a single array-valued argument and return
    a single value. Anything else that is not a string, such as a float, is returned directly.

    :raises ValueError: If *spec* is a string that does not name a supported specifier.
    """
    # callables do the reduction themselves
    if callable(spec):
        return spec(values)

    # fixed literals are passed through
    if not isinstance(spec, str):
        return spec

    # split off an optional leading hyphen that flips the sign
    sign, name = (-1., spec[1:]) if spec.startswith("-") else (1., spec)

    reducer = reduce_with.funcs.get(name)
    if reducer is None:
        available = ", ".join(reduce_with.funcs)
        raise ValueError(
            f"unknown reduction function '{name}'. "
            f"Available: {available}",
        )

    return sign * reducer(np.asarray(values))


# reduction functions available to reduce_with by name
reduce_with.funcs = {
    "min": lambda v: np.nanmin(v),
    "max": lambda v: np.nanmax(v),
    "maxabs": lambda v: max(abs(np.nanmax(v)), abs(np.nanmin(v))),
    "minabs": lambda v: min(abs(np.nanmax(v)), abs(np.nanmin(v))),
}
def broadcast_1d_to_nd(x: np.array, final_shape: list, axis: int = 1) -> np.array:
    """
    Helper function to broadcast a 1d array *x* to an nd array with shape *final_shape*, with
    the values of *x* running along *axis*.

    :param x: One-dimensional array whose length must be equal to *final_shape[axis]*.
    :param final_shape: Shape of the resulting array.
    :param axis: Axis of the resulting array along which *x* is placed.
    :raises Exception: If *x* is not one-dimensional or its length does not fit.
    :return: The broadcasted array.
    """
    if len(x.shape) != 1:
        raise Exception("Only 1d arrays allowed")

    n_entries = x.shape[0]
    if final_shape[axis] != n_entries:
        raise Exception(f"Initial shape should match with final shape in requested axis {axis}")

    # length-1 dimensions everywhere except for the requested axis
    expanded_shape = [1] * len(final_shape)
    expanded_shape[axis] = n_entries

    return np.broadcast_to(np.reshape(x, expanded_shape), final_shape)
def broadcast_nminus1d_to_nd(x: np.array, final_shape: list, axis: int = 1) -> np.array:
    """
    Helper function to broadcast a (n-1)d array *x* to an nd array with shape *final_shape* by
    repeating it along a new axis at position *axis*.

    :param x: Array whose shape must be equal to *final_shape* with the entry at *axis* removed.
    :param final_shape: Shape of the resulting array.
    :param axis: Position of the new axis.
    :raises Exception: If the number of dimensions or the shape of *x* does not fit.
    :return: The broadcasted array.
    """
    if len(final_shape) - len(x.shape) != 1:
        raise Exception("Only (n-1)d arrays allowed")

    # the final shape without the new axis has to be the shape of x
    reduced_shape = [size for i, size in enumerate(final_shape) if i != range(len(final_shape))[axis]]
    if reduced_shape != list(x.shape):
        raise Exception(
            f"input shape ({x.shape}) should agree with final_shape {final_shape} "
            f"after inserting new axis at {axis}",
        )

    # add a length-1 dimension where the new axis goes
    expanded_shape = list(x.shape)
    expanded_shape.insert(axis, 1)

    return np.broadcast_to(np.reshape(x, expanded_shape), final_shape)
def get_profile_width(h_in: hist.Hist, axis: int = 1) -> tuple[np.array, np.array]:
    """
    Function that takes a histogram *h_in* and returns the mean and the spread of the bin
    centers of the axis *axis*, weighted by the bin values, when profiling over that axis.

    :param h_in: Input histogram.
    :param axis: Index of the axis that is profiled over.
    :return: Tuple of the weighted mean and the weighted mean squared deviation from that mean
        (i.e. a variance; no square root is taken). Both arrays have the shape of the
        histogram values without *axis*; entries are NaN where the summed bin values vanish.
    """
    values = h_in.values()
    centers = broadcast_1d_to_nd(h_in.axes[axis].centers, values.shape, axis)

    num = np.sum(values * centers, axis=axis)
    den = np.sum(values, axis=axis)

    # empty slices lead to 0 / 0, which is meant to end up as NaN without a warning
    with np.errstate(invalid="ignore"):
        mean = num / den
        _mean = broadcast_nminus1d_to_nd(mean, values.shape, axis)
        width = np.sum(values * (centers - _mean) ** 2, axis=axis) / den

    return mean, width
def get_profile_variations(h_in: hist.Hist, axis: int = 1) -> dict[str, hist.Hist]:
    """
    Returns a profile histogram plus the up and down variations of the profile. The axis given
    is profiled over and removed from the final histograms.

    The values of the nominal histogram are the means returned by :py:func:`get_profile_width`,
    those of the variations are the means shifted up and down by the square root of the second
    quantity returned by that function. All other fields of the profile are left as they are.

    :param h_in: Input histogram.
    :param axis: Index of the axis to profile over.
    :return: Dictionary with the histograms "nominal", "up" and "down".
    """
    # start with the profile such that only the mean has to be replaced
    # NOTE: how do the variances change for the up/down variations?
    h_profile = h_in.profile(axis)
    mean, variance = get_profile_width(h_in, axis=axis)
    shift = np.sqrt(variance)

    h_nom, h_up, h_down = (h_profile.copy() for _ in range(3))

    # the view of h_profile serves as a scratch buffer -> do not use h_profile anymore!
    h_view = h_profile.view()
    for target, values in ((h_nom, mean), (h_up, mean + shift), (h_down, mean - shift)):
        h_view.value = values
        target[...] = h_view

    return {"nominal": h_nom, "up": h_up, "down": h_down}
def blind_sensitive_bins(
    hists: dict[od.Process, hist.Hist],
    config_inst: od.Config,
    threshold: float,
) -> dict[od.Process, hist.Hist]:
    """
    Blinds data in bins where the expected sensitivity S / sqrt(S + B) of the nominal signal
    (S) and background (B) yields is at or above *threshold*, as well as in all bins between
    the first and the last such bin. Blinded bins get a value of -999 and a variance of 0.

    The signal processes are taken from the "signals" entry of the ``process_groups`` auxiliary
    of the config. The data histograms are not changed in place; copies of them are modified
    and returned together with the signal and background histograms. When no signal, no
    background or no data histogram is found, the input is returned unchanged.

    :param hists: Dictionary mapping process instances to histograms.
    :param config_inst: Config that defines the signal processes.
    :param threshold: Sensitivity from which on bins are blinded.
    :return: Dictionary with signal, background and (blinded) data histograms, in this order.
    """
    # processes that define what counts as signal
    signal_procs = {
        config_inst.get_process(proc)
        for proc in config_inst.x.process_groups.get("signals", [])
    }

    def is_signal(proc) -> bool:
        # either a signal process itself or contained in one
        return any(signal == proc or signal.has_process(proc) for signal in signal_procs)

    # sort histograms into signals, backgrounds and data, the latter as copies
    signals, backgrounds, data = {}, {}, {}
    for proc, h in hists.items():
        if proc.is_mc:
            (signals if is_signal(proc) else backgrounds)[proc] = h
        if proc.is_data:
            data[proc] = h.copy()

    # all three groups are needed for blinding
    if not signals or not backgrounds or not data:
        logger.info(
            "one of the following categories: signals, backgrounds or data was not found in the given processes, "
            "returning unchanged histograms",
        )
        return hists

    # nominal signal and background yields per bin
    signal_yield = sum(remove_residual_axis(signals, "shift", select_value="nominal").values()).values()
    background_yield = sum(remove_residual_axis(backgrounds, "shift", select_value="nominal").values()).values()

    # sensitivity defined as S / sqrt(S + B)
    mask = signal_yield / np.sqrt(signal_yield + background_yield) >= threshold

    # also blind the bins in between blinded ones
    n_blinded = sum(mask)
    if n_blinded > 1:
        first_ind, last_ind = np.where(mask)[0][::n_blinded - 1]
        mask[first_ind:last_ind] = True

    # mark blinded data points with a negative value and remove their variance
    for h in data.values():
        h.values()[..., mask] = -999
        h.variances()[..., mask] = 0

    # put all histograms back together
    return law.util.merge_dicts(signals, backgrounds, data)
def rebin_equal_width(
    hists: dict[Hashable, hist.Hist],
    axis_name: str,
) -> tuple[dict[Hashable, hist.Hist], np.ndarray]:
    """
    In a dictionary, rebins an axis named *axis_name* of all histograms to have the same amount
    of bins but with equal width. For variable axes, this is achieved by using integer edge
    values starting at 0; regular axes keep their start and stop values. Bin contents are not
    changed but copied to the rebinned histograms.

    The axis is looked up in the first histogram and the result is used for all histograms.

    :param hists: Non-empty dictionary of histograms to rebin.
    :param axis_name: Name of the axis to rebin.
    :raises ValueError: If the first histogram has no axis named *axis_name*.
    :return: Tuple of the rebinned histograms and the original bin edges of the axis.
    """
    import hist

    assert hists

    # locate the axis in the first histogram
    first_hist = list(hists.values())[0]
    matches = [(i, ax) for i, ax in enumerate(first_hist.axes) if ax.name == axis_name]
    if not matches:
        raise ValueError(f"axis '{axis_name}' not found in histograms")
    var_index, var_axis = matches[0]

    assert isinstance(var_axis, (hist.axis.Variable, hist.axis.Regular))
    orig_edges = var_axis.edges

    # arguments that differ in the copy of the axis
    if isinstance(var_axis, hist.axis.Variable):
        axis_kwargs = {"edges": list(range(len(orig_edges)))}
    else:
        # hist.axis.Regular
        axis_kwargs = {"start": orig_edges[0], "stop": orig_edges[-1]}

    # build the new histograms in a container of the same type
    new_hists = type(hists)()
    for key, h in hists.items():
        # exchange only the axis of interest
        replaced_axis = copy_axis(var_axis, **axis_kwargs)
        new_axes = h.axes[:var_index] + (replaced_axis,) + h.axes[var_index + 1:]
        new_h = hist.Hist(*new_axes, storage=h.storage_type())

        # bin contents are taken over as they are
        new_h.view()[...] = h.view()
        new_hists[key] = new_h

    return new_hists, orig_edges
def apply_label_placeholders(
    label: str,
    apply: str | Sequence[str] | None = None,
    skip: str | Sequence[str] | None = None,
    **kwargs: Any,
) -> str:
    """
    Interprets placeholders in the format "__NAME__" in a label and returns an updated label.
    Currently supported placeholders are:

        - SHORT: removes the placeholder and everything after it up to the end of the line
        - BREAK: inserts a line break
        - SCALE: inserts a scale factor, passed as "scale" in kwargs; when "scale_format" is
          given as well, the scale factor is formatted accordingly; when the label has no SCALE
          placeholder, the formatted scale factor is appended

    *apply* and *skip* can be used to de/select certain placeholders, with *apply* taking
    precedence when both are given.
    """
    # decide which placeholders are handled
    if apply:
        selected = {p.upper() for p in law.util.make_list(apply)}

        def do_apply(p):
            return p in selected
    elif skip:
        skipped = {p.upper() for p in law.util.make_list(skip)}

        def do_apply(p):
            return p not in skipped
    else:
        def do_apply(p):
            return True

    # shortening
    if do_apply("SHORT"):
        label = re.sub(r"__SHORT__.*", "", label)

    # line breaks
    if do_apply("BREAK"):
        label = label.replace("__BREAK__", "\n")

    # scale factor
    if do_apply("SCALE") and "scale" in kwargs:
        scale_format = kwargs.get("scale_format", "$\\times${}")
        scale_str = scale_format.format(kwargs["scale"])
        if "__SCALE__" in label:
            label = label.replace("__SCALE__", scale_str)
        else:
            label = label + scale_str

    return label
def remove_label_placeholders(
    label: str,
    keep: str | Sequence[str] | None = None,
    drop: str | Sequence[str] | None = None,
) -> str:
    """
    Removes placeholders in the format "__NAME__" from a label.

    :param label: The label to clean.
    :param keep: Placeholder names (case-insensitive) that must stay in the label. When given,
        all other placeholders found in the label are removed and *drop* is ignored.
    :param drop: Placeholder names that are removed. When neither *keep* nor *drop* is given,
        all placeholders consisting of upper-case letters and digits are removed.
    :return: The label without the removed placeholders.
    """
    # when placeholders should be kept, determine all existing ones and identify remaining to drop
    if keep:
        kept = {k.upper() for k in law.util.make_list(keep)}
        placeholders = re.findall("__([^_]+)__", label)
        drop = [p for p in dict.fromkeys(placeholders) if p.upper() not in kept]

        # everything found is to be kept; continuing with an empty drop list would select the
        # "remove all" pattern below and delete exactly the placeholders that should stay
        if not drop:
            return label

    # drop specific placeholders or all
    if drop:
        drop = law.util.make_list(drop)
        sel = f"({'|'.join(d.upper() for d in drop)})"
    else:
        sel = "[A-Z0-9]+"

    return re.sub(f"__{sel}__", "", label)
def calculate_stat_error(h: hist.Hist, error_type: str, density: bool = True) -> np.ndarray:
    """
    Calculate the error to be plotted for the given histogram *h*.
    Supported error types are:

        - "variance": the plotted error is the square root of the variance for each bin
        - "poisson_unweighted": the plotted error is the poisson error for each bin
        - "poisson_weighted": the plotted error is the poisson error for each bin, weighted by
          the variance

    :param h: The histogram; its view must provide *value* and *variance* fields.
    :param error_type: One of the error types above.
    :param density: Whether *h* was divided by its bin widths. If so, the errors are computed on
        the histogram multiplied by the widths and divided by the widths again afterwards.
    :raises ValueError: If *error_type* is unknown, or if the unweighted poisson interval
        contains NaN values.
    :return: The errors, symmetric for "variance" and as an array of lower and upper errors for
        the poisson types.
    """
    # undo density if needed
    area = None
    if density:
        area = functools.reduce(operator.mul, h.axes.widths)
        h = h * area

    if error_type == "variance":
        yerr = h.view().variance ** 0.5

    elif error_type in {"poisson_unweighted", "poisson_weighted"}:
        # compute asymmetric poisson confidence interval
        from hist.intervals import poisson_interval

        weighted = error_type == "poisson_weighted"
        values = h.view().value
        variances = h.view().variance if weighted else None
        interval = poisson_interval(values, variances)

        # negative values are considered as blinded bins -> set confidence interval to 0
        interval[:, values < 0] = 0

        nan_mask = np.isnan(interval)
        if weighted:
            # might happen if some bins are empty, see the weighted branch of poisson_interval
            # in hist/intervals.py
            interval[nan_mask] = 0
        elif np.any(nan_mask):
            raise ValueError("Unweighted Poisson interval calculation returned NaN values, check Hist package")

        # distances of the interval borders to the values
        yerr = np.array([values - interval[0], interval[1] - values])

        negative = yerr < 0
        if np.any(negative):
            logger.warning("found yerr < 0, forcing to 0; this should not happen, please check your histogram")
            yerr[negative] = 0

    else:
        raise ValueError(f"unknown error type '{error_type}'")

    # re-apply density if needed
    if density:
        yerr = yerr / area

    return yerr