Source code for columnflow.plotting.plot_functions_2d

# coding: utf-8

"""
Example 2d plot functions.
"""

from __future__ import annotations

from collections import OrderedDict
from functools import partial
from unittest.mock import patch

import law

from columnflow.util import maybe_import
from columnflow.plotting.plot_util import (
    remove_residual_axis,
    apply_variable_settings,
    apply_process_settings,
    apply_process_scaling,
    apply_density,
    get_position,
    reduce_with,
)

hist = maybe_import("hist")
np = maybe_import("numpy")
mpl = maybe_import("matplotlib")
plt = maybe_import("matplotlib.pyplot")
mplhep = maybe_import("mplhep")
od = maybe_import("order")
mticker = maybe_import("matplotlib.ticker")


def plot_2d(
    hists: OrderedDict,
    config_inst: od.Config,
    category_inst: od.Category,
    variable_insts: list[od.Variable],
    shift_insts: list[od.Shift],
    style_config: dict | None = None,
    density: bool | None = False,
    shape_norm: bool | None = False,
    zscale: str | None = "",
    # z axis range
    zlim: tuple | None = None,
    # how to handle bins with values outside the z range
    extremes: str | None = "",
    # colors to use for marking out-of-bounds values
    extreme_colors: tuple[str] | None = None,
    colormap: str | None = "",
    skip_legend: bool = False,
    cms_label: str = "wip",
    process_settings: dict | None = None,
    variable_settings: dict | None = None,
    **kwargs,
) -> tuple[plt.Figure, tuple[plt.Axes]]:
    """
    Draw the sum of all histograms in *hists* as a single 2d color plot.

    :param hists: Mapping of process instances to histograms; all of them are added up into one
        histogram before plotting. Must contain at least one entry.
    :param config_inst: Config whose ``x.luminosity`` provides the nominal luminosity for the CMS
        label.
    :param category_inst: Category whose label is used as the default annotation text.
    :param variable_insts: The two variables defining the x axis (index 0) and y axis (index 1).
    :param shift_insts: Not used inside this function body.
    :param style_config: Optional overrides that are deep-merged on top of the default style, the
        process style and the variable style configs.
    :param density: When truthy, :py:func:`apply_density` is applied to the histograms.
    :param shape_norm: When truthy, the summed histogram is divided by its integral.
    :param zscale: ``"log"`` selects a symmetric log color normalization, every other value a
        linear one. When empty, ``"log"`` is chosen if either variable has ``log_y`` set.
    :param zlim: Lower and upper limit of the z axis. Entries are resolved via
        :py:func:`reduce_with`, so string specifiers such as ``"min"`` and ``"max"`` are allowed.
        Defaults to ``("min", "max")``.
    :param extremes: Treatment of bins outside *zlim*: ``"color"`` (default when empty) marks them
        with *extreme_colors*, ``"hide"`` blanks them and ``"clip"`` clamps them to the limits.
    :param extreme_colors: Pair of colors (under, over) used when *extremes* is ``"color"``. When
        empty, a dark/light gray pair is chosen depending on the colormap brightness at its ends.
    :param colormap: Name of the matplotlib colormap, ``"viridis"`` when empty.
    :param skip_legend: When *True*, no legend is drawn.
    :param cms_label: Key of a predefined CMS label text, a custom text, or ``"skip"`` to omit the
        label entirely.
    :param process_settings: Forwarded to :py:func:`apply_process_settings`.
    :param variable_settings: Forwarded to :py:func:`apply_variable_settings`.
    :param kwargs: Accepted for interface compatibility, not used inside this function body.
    :return: The created figure and a one-element tuple containing its axes.
    """
    # remove shift axis from histograms
    hists = remove_residual_axis(hists, "shift")

    hists, process_style_config = apply_process_settings(hists, process_settings)
    hists, variable_style_config = apply_variable_settings(hists, variable_insts, variable_settings)
    hists = apply_process_scaling(hists)
    if density:
        hists = apply_density(hists, density)

    # use CMS plotting style
    plt.style.use(mplhep.style.CMS)
    fig, ax = plt.subplots()

    # how to handle yscale information from 2 variable insts?
    if not zscale:
        zscale = "log" if (variable_insts[0].log_y or variable_insts[1].log_y) else "linear"

    # how to handle bin values outside plot range
    if not extremes:
        extremes = "color"

    # add all processes into 1 histogram, starting from a copy so that the input is not modified
    hist_list = list(hists.values())
    h_sum = sum(hist_list[1:], hist_list[0].copy())
    if shape_norm:
        h_sum = h_sum / h_sum.sum().value

    # mask bins without any entries (variance == 0)
    h_view = h_sum.view()
    h_view.value[h_view.variance == 0] = np.nan

    # check histogram value range, falling back to 0 when all bins are masked
    vmin, vmax = np.nanmin(h_sum.values()), np.nanmax(h_sum.values())
    vmin, vmax = np.nan_to_num(np.array([vmin, vmax]), nan=0.0)

    # default to full z range
    if zlim is None:
        zlim = ("min", "max")

    # resolve string specifiers like "min", "max", etc.
    zlim = tuple(reduce_with(lim, h_sum.values()) for lim in zlim)

    # if requested, hide or clip bins outside specified plot range
    if extremes == "hide":
        h_view.value[h_view.value < zlim[0]] = np.nan
        h_view.value[h_view.value > zlim[1]] = np.nan
    elif extremes == "clip":
        h_view.value[h_view.value < zlim[0]] = zlim[0]
        h_view.value[h_view.value > zlim[1]] = zlim[1]

    # update histogram values from view
    h_sum[...] = h_view

    # choose appropriate colorbar normalization
    # based on scale type and histogram content
    if zscale == "log":
        # log scale (turning linear for low values),
        # use SymLogNorm to correctly handle both positive and negative values
        cbar_norm = mpl.colors.SymLogNorm(
            vmin=zlim[0],
            vmax=zlim[1],
            # TODO: better heuristics?
            linscale=1.0,
            linthresh=max(0.05 * min(abs(zlim[0]), abs(zlim[1])), 1e-3),
        )
    else:
        # linear scale
        cbar_norm = mpl.colors.Normalize(
            vmin=zlim[0],
            vmax=zlim[1],
        )

    # obtain colormap
    cmap = plt.get_cmap(colormap or "viridis")

    # use dark and light gray to mark extreme values
    if extremes == "color":
        # choose light/dark order depending on the
        # lightness of first/last colormap color
        if not extreme_colors:
            extreme_colors = ["#444444", "#bbbbbb"]
            if sum(cmap(0.0)[:3]) > sum(cmap(1.0)[:3]):
                extreme_colors = extreme_colors[::-1]

        # copy of colormap with extreme colors set
        cmap = cmap.with_extremes(
            under=extreme_colors[0],
            over=extreme_colors[1],
        )

    # unit format on axes (could be configurable)
    unit_format = "{title} [{unit}]"

    # setup style config
    # TODO: some kind of z-label is still missing
    default_style_config = {
        "ax_cfg": {
            "xlim": (variable_insts[0].x_min, variable_insts[0].x_max),
            "ylim": (variable_insts[1].x_min, variable_insts[1].x_max),
            "xlabel": variable_insts[0].get_full_x_title(unit_format=unit_format),
            "ylabel": variable_insts[1].get_full_x_title(unit_format=unit_format),
            "xscale": "log" if variable_insts[0].log_x else "linear",
            # the second variable is drawn on the y axis, so its own "log_x" flag applies here
            "yscale": "log" if variable_insts[1].log_x else "linear",
        },
        "legend_cfg": {
            "title": "Process" if len(hists) == 1 else "Processes",
            "handles": [mpl.lines.Line2D([0], [0], lw=0) for _ in hists],  # dummy handle
            "labels": [proc_inst.label for proc_inst in hists],
            "ncol": 1,
            "loc": "upper right",
        },
        "cms_label_cfg": {
            "lumi": round(0.001 * config_inst.x.luminosity.get("nominal"), 2),  # /pb -> /fb
        },
        "plot2d_cfg": {
            "norm": cbar_norm,
            "cmap": cmap,
            # "labels": True,  # this enables displaying numerical values for each bin, but needs some optimization
            "cbar": True,
            "cbarextend": True,
        },
        "annotate_cfg": {
            "text": category_inst.label,
        },
    }
    style_config = law.util.merge_dicts(
        default_style_config,
        process_style_config,
        variable_style_config[variable_insts[0]],
        variable_style_config[variable_insts[1]],
        style_config,
        deep=True,
    )

    # apply style_config
    ax.set(**style_config["ax_cfg"])
    if not skip_legend:
        ax.legend(**style_config["legend_cfg"])

    if variable_insts[0].discrete_x:
        ax.set_xticks([], minor=True)
    if variable_insts[1].discrete_x:
        ax.set_yticks([], minor=True)

    # annotation of category label
    annotate_kwargs = {
        "text": "",
        "xy": (
            get_position(*ax.get_xlim(), factor=0.05, logscale=False),
            get_position(*ax.get_ylim(), factor=0.95, logscale=False),
        ),
        "xycoords": "data",
        "color": "black",
        "fontsize": 22,
        "horizontalalignment": "left",
        "verticalalignment": "top",
    }
    # read from the merged config (not the defaults) so that user overrides take effect,
    # consistent with the handling of "cms_label_cfg" below
    annotate_kwargs.update(style_config.get("annotate_cfg", {}))
    plt.annotate(**annotate_kwargs)

    # cms label
    if cms_label != "skip":
        label_options = {
            "wip": "Work in progress",
            "pre": "Preliminary",
            "pw": "Private work",
            "sim": "Simulation",
            "simwip": "Simulation work in progress",
            "simpre": "Simulation preliminary",
            "simpw": "Simulation private work",
            "od": "OpenData",
            "odwip": "OpenData work in progress",
            "odpw": "OpenData private work",
            "public": "",
        }
        cms_label_kwargs = {
            "ax": ax,
            # unknown keys are used verbatim as the label text
            "llabel": label_options.get(cms_label, cms_label),
            "fontsize": 22,
            "data": False,
        }
        cms_label_kwargs.update(style_config.get("cms_label_cfg", {}))
        mplhep.cms.label(**cms_label_kwargs)

    # decide at which ends of the colorbar to draw symbols
    # indicating that there are values outside the range
    if extremes == "hide":
        extend = "neither"
    elif vmax > zlim[1] and vmin < zlim[0]:
        extend = "both"
    elif vmin < zlim[0]:
        extend = "min"
    elif vmax > zlim[1]:
        extend = "max"
    else:
        extend = "neither"

    # call plot method, patching the colorbar function
    # called internally by mplhep to draw the extension symbols
    with patch.object(plt, "colorbar", partial(plt.colorbar, extend=extend)):
        h_sum.plot2d(ax=ax, **style_config["plot2d_cfg"])

    # fix color bar minor ticks with SymLogNorm
    if isinstance(cbar_norm, mpl.colors.SymLogNorm):
        # returned collections can vary -> brute-force set
        # norm on all colorbars that are found
        cbars = {
            coll.colorbar
            for coll in ax.collections
            if coll.colorbar
        }
        for cbar in cbars:
            _scale = cbar.ax.yaxis._scale
            _scale.subs = [2, 3, 4, 5, 6, 7, 8, 9]
            cbar.ax.yaxis.set_minor_locator(
                mticker.SymmetricalLogLocator(_scale.get_transform(), subs=_scale.subs),
            )
            cbar.ax.yaxis.set_minor_formatter(
                mticker.LogFormatterSciNotation(_scale.base),
            )

    plt.tight_layout()

    return fig, (ax,)