# coding: utf-8
"""
Collection of general helpers and utilities.
"""
from __future__ import annotations
__all__ = []
import os
import io
import re
import abc
import uuid
import queue
import threading
import subprocess
import importlib
import fnmatch
import inspect
import pprint
import multiprocessing
import multiprocessing.pool
from functools import wraps
from collections import OrderedDict
import law
import luigi
from columnflow import env_is_dev, env_is_remote, docs_url, github_url
from columnflow.types import Callable, Any, Sequence, Union, ModuleType
#: Placeholder for an unset value.
UNSET = object()

#: List of the first 200 primes (2 up to and including 1223), computed by trial division.
primes = [
    n for n in range(2, 1224)
    if all(n % d for d in range(2, int(n ** 0.5) + 1))
]
[docs]
def maybe_import(
    name: str,
    package: str | None = None,
    force: bool = False,
) -> ModuleType | MockModule:
    """
    Calls *importlib.import_module* internally and returns the module if it exists, or otherwise a
    :py:class:`MockModule` instance with the same (absolute) name. When *force* is *True* and the
    import fails, an *ImportError* is raised.

    An import error is only turned into a mock module when the module that could not be found is
    the requested one or one of its parent packages. When the requested module exists but fails to
    import one of its own dependencies, the original error is re-raised.

    :param name: Name of the module to import, either absolute or relative to *package*.
    :param package: Anchor package used to resolve relative names.
    :param force: When *True*, import errors are always re-raised instead of returning a mock.
    :return: The imported module or a :py:class:`MockModule`.
    :raises ImportError: When *force* is *True*, when the error is not a missing-module error, or
        when a module other than the requested one (or its parent packages) is missing.
    """
    try:
        return importlib.import_module(name, package)
    except ImportError as e:
        # only "No module named '...'" errors can be mocked, and only when not forced
        m = re.match(r"^No\smodule\snamed\s\'(.+)\'$", str(e))
        if force or not m:
            raise

        # resolve relative names since the error message always contains the absolute name
        from importlib.util import resolve_name
        full_name = resolve_name(name, package)

        # mock only when the requested module itself or one of its parent packages is missing;
        # compare full dotted components so that e.g. a missing "foo" does not match "foobar"
        missing = m.group(1)
        if full_name != missing and not full_name.startswith(f"{missing}."):
            raise

        return MockModule(full_name)
# cache for the lazily imported and configured pyplot module
_plt = None


def import_plt() -> ModuleType:
    """
    Lazily imports and configures matplotlib's pyplot module and returns it.

    On the first call, matplotlib is switched to the "Agg" backend, LaTeX text rendering is enabled
    with the amsmath package added to the preamble, and the legend edge color is set to white. The
    configured module is cached so that subsequent calls return it without configuring it again.
    """
    global _plt
    if _plt is None:
        import matplotlib
        matplotlib.use("Agg")
        matplotlib.rc("text", usetex=True)
        # the preamble must be a single string; list values were deprecated in matplotlib 3.3 and
        # are rejected by newer versions
        matplotlib.rcParams["text.latex.preamble"] = r"\usepackage{amsmath}"
        matplotlib.rcParams["legend.edgecolor"] = "white"
        import matplotlib.pyplot as plt
        _plt = plt
    return _plt
# cache for the lazily imported and configured ROOT module
_ROOT = None


def import_ROOT() -> ModuleType:
    """
    Lazily imports ROOT on the first call and returns the module.

    The first call sets ``ROOT.PyConfig.IgnoreCommandLineOptions`` and calls
    ``ROOT.gROOT.SetBatch()``; the configured module is cached and returned directly by all
    subsequent calls.
    """
    global _ROOT
    if _ROOT:
        # already imported and configured
        return _ROOT

    import ROOT
    ROOT.PyConfig.IgnoreCommandLineOptions = True
    ROOT.gROOT.SetBatch()
    _ROOT = ROOT
    return _ROOT
[docs]
def import_file(path: str, attr: str | None = None):
    """
    Loads the content of a python file located at *path* and returns its package content as a
    :py:class:`DotDict`. When *attr* is set, only the attribute with that name is returned.

    The file is not required to be importable as its content is loaded directly into the
    interpreter. While this approach is not necessarily clean, it can be useful in places where
    custom code must be loaded. Note that the file is executed, so it must be trusted.

    :param path: Path of the file, environment variables are expanded.
    :param attr: Optional name of a single attribute to return.
    :return: All globals defined by the file, or the value of *attr*.
    :raises AttributeError: When *attr* is set but not defined by the file.
    """
    # load the package contents (do not try this at home)
    path = expand_path(path)
    pkg = DotDict()
    # read raw bytes so that compile() decodes them according to python's source encoding rules
    # (utf-8 or a declared coding cookie) rather than the locale, and pass the path so that
    # tracebacks point to the actual file instead of "<string>"
    with open(path, "rb") as f:
        code = compile(f.read(), path, "exec")
    exec(code, pkg)

    # extract a particular attribute
    if attr:
        if attr not in pkg:
            raise AttributeError(f"no local member '{attr}' found in file {path}")
        return pkg[attr]

    return pkg
[docs]
def ipython_shell(
    confirm_exit: bool = False,
    pretty_print: bool = True,
    banner: bool = False,
):
    """
    Creates an embeddable IPython shell with configurable parameters and returns it, as obtained
    from ``InteractiveShellEmbed.instance``. The shell is not entered here; presumably the caller
    invokes the returned object to start the interactive session.

    :param confirm_exit: Whether to ask for confirmation before exiting the shell.
    :param pretty_print: Whether to use pretty printing.
    :param banner: Whether to display the IPython banner.
    """
    from traitlets.config import Config
    from IPython.terminal.embed import InteractiveShellEmbed

    # terminal shell options are collected in a sub-section of the config
    shell_config = Config()
    terminal_options = shell_config.TerminalInteractiveShell
    terminal_options.confirm_exit = confirm_exit
    terminal_options.pretty_print = pretty_print

    return InteractiveShellEmbed.instance(config=shell_config, display_banner=banner)
[docs]
def prettify(obj: Any, **kwargs) -> str:
    """
    Prettifies the string representation of an object *obj* and returns it.

    :param obj: Object to prettify.
    :param kwargs: Optional formatting arguments such as *indent* or *width*, as accepted by
        :py:func:`pprint.pformat`.
    :return: Prettified string representation, terminated by a newline.
    """
    # pformat produces exactly the text pprint would write, only without the trailing newline
    return pprint.pformat(obj, **kwargs) + "\n"
[docs]
def get_docs_url(*parts: str, anchor: str | None = None) -> str:
    """
    Returns a URL pointing to the documentation of a particular page defined by *parts*. Leading
    and trailing slashes of each part are removed before joining. When an *anchor* is defined, it
    is appended to the URL.
    """
    segments = [docs_url] + [str(part).strip("/") for part in parts]
    url = "/".join(segments)
    return f"{url}#{anchor}" if anchor else url
[docs]
def get_github_url(*parts: str) -> str:
    """
    Returns a URL pointing to the repository on github, extended by the URL fragments *parts*
    whose leading and trailing slashes are removed before joining.
    """
    stripped_parts = [str(fragment).strip("/") for fragment in parts]
    return "/".join([github_url] + stripped_parts)
[docs]
def get_release_url(tag: str) -> str:
    """
    Returns a URL pointing to the release notes of a particular *tag*. Any leading "/" and "v"
    characters are removed from *tag* and a single "v" prefix is added.
    """
    version = tag.lstrip("/v")
    return get_github_url("releases", "tag", "v" + version)
[docs]
def get_code_url(*parts: str, branch: str = "master") -> str:
    """
    Returns a URL pointing to specific code on the github repository, defined by the path *parts*
    within the given *branch*.
    """
    url_parts = ("blob", branch) + parts
    return get_github_url(*url_parts)
[docs]
def create_random_name() -> str:
    """
    Returns a random string in the canonical form of a version 4 UUID.
    """
    random_id = uuid.uuid4()
    return f"{random_id}"
[docs]
def expand_path(*path: str) -> str:
    """
    Takes *path* fragments, joins them and recursively expands all contained environment variables
    and user home directory references (``~``).

    Expansion is repeated until the path no longer changes, so variables whose values contain
    further variables are fully resolved. References that cannot be expanded, such as undefined
    variables or a ``~`` that is not at the start of the path, are left untouched.
    """
    path = os.path.join(*map(str, path))
    while True:
        expanded = os.path.expandvars(os.path.expanduser(path))
        # stop at the fixed point; looping while "$" or "~" is still present would never
        # terminate for references that cannot be expanded
        if expanded == path:
            return path
        path = expanded
[docs]
def real_path(*path: str) -> str:
    """
    Takes *path* fragments, expands them with :py:func:`expand_path` and returns the real,
    absolute location with symbolic links resolved.
    """
    expanded = expand_path(*path)
    return os.path.realpath(expanded)
[docs]
def ensure_dir(path: str) -> str:
    """
    Ensures that a directory at *path* (and its intermediate directories) exists and returns the
    full, expanded path. An already existing path is accepted as is.

    :param path: Directory path, environment variables and ``~`` are expanded.
    :return: The real, absolute path.
    """
    path = real_path(path)
    if not os.path.exists(path):
        # exist_ok guards against the directory being created concurrently (e.g. by parallel jobs)
        # between the existence check above and this call
        os.makedirs(path, exist_ok=True)
    return path
[docs]
def wget(src: str, dst: str, force: bool = False) -> str:
    """
    Downloads a file from a remote *src* to a local destination *dst* using the ``wget`` command,
    creating intermediate directories when needed. When *dst* is an existing directory, the file is
    stored inside it under the base name of *src*. When the destination refers to an existing file,
    an exception is raised unless *force* is *True*, in which case the file is removed first.

    :param src: Remote location to download from.
    :param dst: Local destination file or existing directory.
    :param force: Whether an existing destination file may be replaced.
    :return: The full, normalized destination path.
    :raises IOError: When the destination file exists and *force* is *False*.
    :raises Exception: When the ``wget`` command fails.
    """
    dst = real_path(dst)
    if os.path.isdir(dst):
        # place the file inside the existing directory
        dst = os.path.join(dst, os.path.basename(src))
    else:
        # create missing intermediate directories, exist_ok guards against concurrent creation
        dst_dir = os.path.dirname(dst)
        if not os.path.exists(dst_dir):
            os.makedirs(dst_dir, exist_ok=True)

    # remove existing dst or complain
    if os.path.exists(dst):
        if not force:
            raise IOError(f"target '{dst}' already exists")
        os.remove(dst)

    # actual download
    cmd = ["wget", src, "-O", dst]
    code, _, error = law.util.interruptable_popen(
        law.util.quote_cmd(cmd),
        shell=True,
        executable="/bin/bash",
        stderr=subprocess.PIPE,
    )
    if code != 0:
        raise Exception(f"wget failed: {error}")

    return dst
[docs]
def call_thread(
    func: Callable,
    args: tuple = (),
    kwargs: dict | None = None,
    timeout: float | None = None,
) -> tuple[bool, Any, str | None]:
    """
    Executes a function *func* in a separate thread and waits at most *timeout* seconds for it to
    finish. *args* and *kwargs* are forwarded to the function.

    The return value is a 3-tuple ``(finished_in_time, result, error)``. When the call does not
    finish in time, ``(False, None, None)`` is returned. Otherwise, *error* is the string
    representation of an exception raised by *func* (with *result* being *None*), or *None* on
    success.

    Threads cannot be killed. A call that exceeds the timeout therefore keeps running in the
    background. It runs in a daemon thread, so it does not prevent the interpreter from exiting.
    """
    def wrapper(q, *args, **kwargs):
        try:
            ret = (func(*args, **kwargs), None)
        except Exception as e:
            ret = (None, str(e))
        q.put(ret)

    q = queue.Queue(1)
    thread = threading.Thread(
        target=wrapper,
        args=(q,) + args,
        kwargs=kwargs or {},
        daemon=True,
    )
    thread.start()
    thread.join(timeout)

    if thread.is_alive():
        return (False, None, None)

    try:
        return (True,) + q.get_nowait()
    except queue.Empty:
        # the thread ended without storing a result, e.g. because func raised a BaseException such
        # as SystemExit that is not caught by the wrapper; a blocking get() would hang forever
        return (True, None, "thread exited without returning a result")
[docs]
def _call_proc_target(conn, func: Callable, args: tuple, kwargs: dict) -> None:
    """
    Entry point of the child process started by :py:func:`call_proc`. Calls *func* with *args* and
    *kwargs* and sends a 2-tuple ``(result, error)`` through the connection *conn*.
    """
    try:
        ret = (func(*args, **kwargs), None)
    except Exception as e:
        ret = (None, str(e))
    try:
        conn.send(ret)
    finally:
        conn.close()


def call_proc(
    func: Callable,
    args: tuple = (),
    kwargs: dict | None = None,
    timeout: float | None = None,
) -> tuple[bool, Any, str | None]:
    """
    Executes a function *func* in a separate process and aborts the call when *timeout* is reached.
    *args* and *kwargs* are forwarded to the function.

    The return value is a 3-tuple ``(finished_in_time, result, error)``.
    - When the call does not finish in time, the process is terminated and ``(False, None, None)``
      is returned.
    - Otherwise, *error* is the string representation of an exception raised by *func* (with
      *result* being *None*), or *None* on success.
    - When the process exits without sending a result, *error* describes its exit code.

    The result is sent back through a pipe and must therefore be picklable. Depending on the
    multiprocessing start method, *func* and its arguments might need to be picklable as well.
    """
    from multiprocessing.connection import wait

    recv_conn, send_conn = multiprocessing.Pipe(duplex=False)
    proc = multiprocessing.Process(
        target=_call_proc_target,
        args=(send_conn, func, args, kwargs or {}),
    )
    proc.start()
    # only the child sends, so close the parent's copy of the sending end; this way, receiving
    # raises EOFError instead of blocking forever when the child exits without sending a result
    send_conn.close()

    try:
        # wait for a result or the termination of the process; the result must be received before
        # joining since a child sending a large object blocks until the data is read
        if not wait([recv_conn, proc.sentinel], timeout):
            proc.terminate()
            proc.join()
            return (False, None, None)

        try:
            ret = recv_conn.recv()
        except EOFError:
            proc.join()
            return (
                True,
                None,
                f"process exited with code {proc.exitcode} without returning a result",
            )

        proc.join()
        return (True,) + ret
    finally:
        recv_conn.close()
[docs]
@law.decorator.factory(accept_generator=True)
def ensure_proxy(
    fn: Callable,
    opts: dict,
    task: law.Task,
    *args,
    **kwargs,
) -> tuple[Callable, Callable, Callable]:
    """
    Law task decorator that checks whether either a valid voms or arc proxy exists before calling
    the decorated method. The check is skipped in remote jobs (see :py:attr:`env_is_remote`) and
    when the option ``skip_ensure_proxy`` in the ``analysis`` section of the law config is set.

    :raises Exception: When the check is performed and neither proxy is valid.
    """
    def before_call():
        # nothing to check in remote jobs or when explicitly disabled in the law config
        if env_is_remote or law.config.get_expanded_boolean("analysis", "skip_ensure_proxy", False):
            return None

        # a single valid proxy of either kind is sufficient (arc is only checked if voms is not)
        if not (law.wlcg.check_vomsproxy_validity() or law.arc.check_arcproxy_validity()):
            raise Exception("neither voms nor arc proxy valid")

    def call(state):
        return fn(task, *args, **kwargs)

    def after_call(state):
        return None

    return before_call, call, after_call
[docs]
def dev_sandbox(sandbox: str, add: bool = True, remove: bool = True) -> str:
    """
    Takes a sandbox key *sandbox* and adds or removes the substring "_dev" right before the file
    extension (if any), depending on whether the current environment is used for development (see
    :py:attr:`env_is_dev`) and the *add* and *remove* flags. Only "venv" and "bash" sandboxes are
    considered.

    If *sandbox* does not contain the "_dev" postfix and both :py:attr:`env_is_dev` and *add* are
    *True*, the postfix is appended.

    If *sandbox* does (!) contain the "_dev" postfix, :py:attr:`env_is_dev` is *False* and *remove*
    is *True*, the postfix is removed.

    In both cases, the modified key is only returned when the resulting file exists. In any other
    case, *sandbox* is returned unchanged.

    Examples:

    .. code-block:: python

        # if env_is_dev and /path/to/script_dev.sh exists
        dev_sandbox("bash::/path/to/script.sh")
        # -> "bash::/path/to/script_dev.sh"

        # otherwise
        dev_sandbox("bash::/path/to/script.sh")
        # -> "bash::/path/to/script.sh"
    """
    sandbox_type, sandbox_path = law.Sandbox.split_key(sandbox)

    # only venv and bash sandboxes have dev variants
    if sandbox_type not in ("venv", "bash"):
        return sandbox

    stem, ext = os.path.splitext(sandbox_path)
    is_dev = stem.endswith("_dev")

    if is_dev and not env_is_dev and remove:
        new_path = stem[:-len("_dev")] + ext
    elif not is_dev and env_is_dev and add:
        new_path = stem + "_dev" + ext
    else:
        # nothing to do in any other case
        return sandbox

    # only switch to the alternative sandbox when its file actually exists
    if os.path.exists(real_path(new_path)):
        return law.Sandbox.join_key(sandbox_type, new_path)
    return sandbox
[docs]
def freeze(cont: Any) -> Any:
    """
    Constructs an immutable version of a native Python container.

    Recursively replaces all mutable containers encountered within *cont* by an immutable
    equivalent: lists (and tuples) are converted to tuples, sets to ``frozenset`` objects, and
    dictionaries to tuples of (*key*, *value*) pairs in iteration order. Any other object is
    returned unchanged.
    """
    if isinstance(cont, set):
        return frozenset(map(freeze, cont))
    if isinstance(cont, (list, tuple)):
        return tuple(map(freeze, cont))
    if isinstance(cont, dict):
        return tuple((key, freeze(value)) for key, value in cont.items())
    return cont
[docs]
def memoize(f: Callable) -> Callable:
    """
    Function decorator that implements memoization. The result of *f* is cached on the first call
    and returned from the cache on every subsequent call with the same arguments.

    Arguments are made hashable via :py:func:`freeze`, so lists, sets and dictionaries can be
    passed as well. Note that arguments freezing to the same value (e.g. a list and a tuple with
    equal items) share a cache entry. The cache is unbounded and lives as long as the decorated
    function.
    """
    cache = {}

    @wraps(f)
    def wrapper(*args, **kwargs):
        key = freeze({"args": args, "kwargs": kwargs})
        try:
            return cache[key]
        except KeyError:
            pass
        result = cache[key] = f(*args, **kwargs)
        return result

    return wrapper
[docs]
def safe_div(a: int | float, b: int | float) -> float:
    """
    Returns *a* divided by *b* if *b* is truthy (i.e. not zero), and zero otherwise.
    """
    if not b:
        return 0.0
    return a / b
[docs]
def try_float(f: Any) -> bool:
    """
    Tests whether a value *f* can be converted to a float.

    :param f: Value to test.
    :return: *True* if ``float(f)`` succeeds, and *False* otherwise, including for integers that
        are too large to be represented as a float.
    """
    try:
        float(f)
    except (ValueError, TypeError, OverflowError):
        # OverflowError is raised e.g. for integers exceeding the float range
        return False
    return True
[docs]
def try_complex(f: Any) -> bool:
    """
    Tests whether a value *f* can be converted to a complex number.

    :param f: Value to test.
    :return: *True* if ``complex(f)`` succeeds, and *False* otherwise, including for integers that
        are too large to be represented as a float component.
    """
    try:
        complex(f)
    except (ValueError, TypeError, OverflowError):
        # OverflowError is raised e.g. for integers exceeding the float range
        return False
    return True
[docs]
def try_int(i: Any) -> bool:
    """
    Tests whether a value *i* can be converted to an integer.

    :param i: Value to test.
    :return: *True* if ``int(i)`` succeeds, and *False* otherwise, including for infinite floats.
    """
    try:
        int(i)
    except (ValueError, TypeError, OverflowError):
        # OverflowError is raised when converting infinite floats
        return False
    return True
[docs]
def maybe_int(i: Any) -> Any:
    """
    Returns *i* converted to an integer if it is an integer (including booleans) or a float with
    an integral value, and returns *i* unchanged otherwise.
    """
    is_whole = isinstance(i, (int, bool)) or (isinstance(i, float) and i.is_integer())
    return int(i) if is_whole else i
[docs]
def is_pattern(s: str) -> bool:
    """
    Returns *True* if a string *s* contains the fnmatch wildcard characters "*" or "?", or starts
    with the negation character "!", and *False* otherwise.
    """
    return any(char in s for char in "*?") or s.startswith("!")
[docs]
def is_regex(s: str) -> bool:
    """
    Returns *True* if a string *s* is considered a regular expression, i.e., when it starts with
    "^" and ends with "$", and *False* otherwise.
    """
    has_start_anchor = s.startswith("^")
    return has_start_anchor and s.endswith("$")
[docs]
def pattern_matcher(pattern: Sequence[str] | str, mode: Callable = any) -> Callable[[str], bool]:
    r"""
    Takes a string *pattern* which might be an actual pattern for fnmatching, a regular expression
    or just a plain string and returns a function that can be used to test if a string matches that
    pattern.

    Patterns starting with "^" and ending with "$" are considered regular expressions, and otherwise
    fnmatch patterns. In the latter case, when the pattern starts with a "!", the match is inverted.
    Strings without any pattern characters are compared for equality.

    When *pattern* is a list, tuple or set, all its patterns are compared the same way and the
    result is the combination given a *mode* which typically should be *any* or *all*.

    Example:

    .. code-block:: python

        matcher = pattern_matcher("foo*")
        matcher("foo123")  # -> True
        matcher("bar123")  # -> False

        matcher = pattern_matcher(r"^foo\d+.*$")
        matcher("foox")  # -> False
        matcher("foo1")  # -> True

        matcher = pattern_matcher("!foo*")
        matcher("foo123")  # -> False
        matcher("bar123")  # -> True

        matcher = pattern_matcher(("foo*", "*bar"), mode=any)
        matcher("foo123")  # -> True
        matcher("123bar")  # -> True

        matcher = pattern_matcher(("foo*", "*bar"), mode=all)
        matcher("foo123")  # -> False
        matcher("123bar")  # -> False
        matcher("foo123bar")  # -> True
    """
    # multiple patterns: build one matcher per pattern and combine their results with mode
    if isinstance(pattern, (list, tuple, set)):
        sub_matchers = [pattern_matcher(p) for p in pattern]

        def match_combined(s: str) -> bool:
            return mode(sub_matcher(s) for sub_matcher in sub_matchers)

        return match_combined

    # trivial patterns that match everything
    if pattern in ("*", "^.*$"):
        return lambda s: True

    # regular expressions
    if is_regex(pattern):
        regex = re.compile(pattern)
        return lambda s: bool(regex.match(s))

    # fnmatch patterns, optionally negated by a leading "!"
    if is_pattern(pattern):
        if pattern.startswith("!"):
            positive_pattern = pattern[1:]
            return lambda s: not fnmatch.fnmatch(s, positive_pattern)
        return lambda s: fnmatch.fnmatch(s, pattern)

    # plain strings are compared exactly
    return lambda s: s == pattern
[docs]
def dict_add_strict(d: dict, key: str, value: Any) -> None:
    """
    Adds the pair *key*-*value* to the dictionary *d*, but only if it does not change an existing
    value. Re-adding an equal value is allowed.

    :raises KeyError: When *key* already exists in *d* with a different value.
    """
    if key in d:
        if d[key] != value:
            raise KeyError(f"'{d.__class__.__name__}' object already has key {key}")
    d[key] = value
[docs]
def get_source_code(obj: Any, indent: str | int | None = None) -> str:
    """
    Returns the source code of any object *obj* as a string. When *indent* is not *None*, the
    indentation of the first line is removed from all lines and replaced by *indent* if it is a
    string, or by that many spaces in case it is an integer.

    Tabs are replaced before re-indenting and lines consisting only of whitespace are emptied.
    Lines that are less indented than the first line are kept unchanged. In valid code, such lines
    can only be continuation lines, e.g. the content of multi-line strings, so none of their
    characters are cut off.
    """
    code = inspect.getsource(obj)
    if indent is None:
        return code

    code = code.replace("\t", " ")
    lines = code.split("\n")
    n_old_indent = len(lines[0]) - len(lines[0].lstrip(" "))
    old_indent = " " * n_old_indent
    new_indent = (" " * indent) if isinstance(indent, int) else indent

    def reindent(line: str) -> str:
        if not line.strip():
            return ""
        if line.startswith(old_indent):
            return new_indent + line[n_old_indent:]
        # less indented than the first line, so slicing would remove actual content
        return line

    return "\n".join(reindent(line) for line in lines)
[docs]
class DotDict(OrderedDict):
    """
    Subclass of *OrderedDict* that provides read, write and delete access to items via attributes
    by implementing ``__getattr__``, ``__setattr__`` and ``__delattr__``. In case an item is
    accessed or deleted via attribute and it does not exist, an *AttributeError* is raised rather
    than a *KeyError*. Example:

    .. code-block:: python

        d = DotDict()
        d["foo"] = 1

        print(d["foo"])
        # => 1

        print(d.foo)
        # => 1

        print(d["bar"])
        # => KeyError

        print(d.bar)
        # => AttributeError

        d.bar = 123
        print(d.bar)
        # => 123

        del d.bar
        print("bar" in d)
        # => False

        # use wrap() to convert a nested dict
        d = DotDict.wrap({"foo": {"bar": 1}})
        print(d.foo.bar)
        # => 1
    """

    def __getattr__(self, attr: str) -> Any:
        try:
            return self[attr]
        except KeyError:
            # suppress the KeyError context, attribute access should only report the missing attr
            raise AttributeError(
                f"'{self.__class__.__name__}' object has no attribute '{attr}'",
            ) from None

    def __setattr__(self, attr: str, value: Any) -> None:
        self[attr] = value

    def __delattr__(self, attr: str) -> None:
        # mirror __setattr__, which stores attributes as items
        try:
            del self[attr]
        except KeyError:
            raise AttributeError(
                f"'{self.__class__.__name__}' object has no attribute '{attr}'",
            ) from None

    def copy(self) -> DotDict:
        """
        Returns a shallow copy of this instance with the same class.
        """
        return self.__class__(self)

    @classmethod
    def wrap(cls, *args, **kwargs) -> DotDict:
        """
        Creates an instance from the arguments accepted by *OrderedDict* and recursively replaces
        all nested dictionary types with :py:class:`DotDict`'s for deep attribute-style access.
        Values of other types, including lists, are kept as they are.
        """
        def _wrap(obj: Any) -> Any:
            if isinstance(obj, dict):
                return cls((key, _wrap(value)) for key, value in obj.items())
            return obj

        return _wrap(OrderedDict(*args, **kwargs))
[docs]
class MockModule(object):
    """
    Mockup object that resembles a module with arbitrarily deep structure such that, e.g.,

    .. code-block:: python

        coffea = MockModule("coffea")
        print(coffea.nanoevents.NanoEventsArray)
        # -> "<MockModule 'coffea' at 0x7f3a2c1d5e10>"

    will always succeed at declaration, but most likely fail at execution time. In fact, each
    attribute access will return the mock object again. This might only be useful in places where
    a module is potentially not existing (e.g. due to sandboxing) but one wants to import it either
    way a) to perform only one top-level import as opposed to imports in all functions of a package,
    or b) to provide type hints for documentation purposes.

    Mock objects are falsy and raise an exception when called. They support the ``|`` operator on
    both sides, so union type hints such as ``ak.Array | None`` or ``None | ak.Array`` can be
    evaluated even when the actual module is missing.

    .. py:attribute:: _name

        type: str

        The name of the mock module.
    """

    def __init__(self, name: str):
        super().__init__()
        self._name = name

    def __getattr__(self, attr: str) -> MockModule:
        return self

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__} '{self._name}' at {hex(id(self))}>"

    def __call__(self, *args, **kwargs) -> None:
        raise Exception(f"{self._name} is a mock module and cannot be called")

    def __nonzero__(self) -> bool:
        # python 2 spelling of __bool__, kept for compatibility
        return False

    def __bool__(self) -> bool:
        return False

    def __or__(self, other: Any) -> Any:
        # forward union type hints, e.g. "mock | None"
        return Union[type(self), other]

    def __ror__(self, other: Any) -> Any:
        # reflected variant for union type hints with the mock on the right, e.g. "None | mock"
        return Union[other, type(self)]
[docs]
class FunctionArgs(object):
    """
    Light-weight utility class that stores all passed *args* and *kwargs* so that different
    functions can later be invoked with them by calling the instance.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.args, self.kwargs = args, kwargs

    def __call__(self, func: Callable) -> Any:
        positional, keyword = self.args, self.kwargs
        return func(*positional, **keyword)
[docs]
class ClassPropertyDescriptor(object):
    """
    Generic descriptor class that is used by :py:func:`classproperty`. *fget* and the optional
    *fset* are expected to be bindable objects such as *classmethod* or *staticmethod* instances.
    """

    def __init__(self, fget: Callable, fset: Callable | None = None):
        super().__init__()
        self.fget = fget
        self.fset = fset

    def __get__(self, obj: type, cls: type | None = None) -> Any:
        owner = type(obj) if cls is None else cls
        bound_getter = self.fget.__get__(obj, owner)
        return bound_getter()

    def __set__(self, obj: type, value: Any) -> None:
        # setting is only possible when a setter was provided
        if not self.fset:
            raise AttributeError("can't set attribute")
        bound_setter = self.fset.__get__(obj, type(obj))
        return bound_setter(value)
[docs]
def classproperty(func: Callable) -> ClassPropertyDescriptor:
    """
    Property decorator for class-level methods. Plain functions are wrapped into a *classmethod*,
    while existing *classmethod* or *staticmethod* objects are used as they are.
    """
    getter = func if isinstance(func, (classmethod, staticmethod)) else classmethod(func)
    return ClassPropertyDescriptor(getter)
[docs]
class Derivable(object, metaclass=DerivableMeta):
    """
    Base class for derivable objects whose features are provided by the meta class
    :py:class:`DerivableMeta`.

    .. py:classattribute:: cls_name

        type: str
        read-only

        Shorthand to access the name of the class.
    """

    @classproperty
    def cls_name(cls) -> str:
        # same as cls.__name__, exposed as a class-level property
        name = cls.__name__
        return name
[docs]
class KeyValueMessage(luigi.worker.SchedulerMessage):
    """
    Subclass of :py:class:`luigi.worker.SchedulerMessage` that adds :py:attr:`key` and
    :py:attr:`value` attributes, parsed from the incoming message assuming a format ``key = value``
    or ``key: value``.

    .. py:attribute:: key

        type: str

        The key of the message, without surrounding whitespace.

    .. py:attribute:: value

        type: str

        The value of the message, without surrounding whitespace.
    """

    # compile expression for key - value parsing of scheduler messages
    message_cre = re.compile(r"^\s*([^\=\:]+)\s*(\=|\:)\s*(.*)\s*$")

    @classmethod
    def from_message(cls, message: luigi.worker.SchedulerMessage) -> KeyValueMessage | None:
        """
        Factory for :py:class:`KeyValueMessage` instances that takes an existing *message* object
        and splits its content into a key value pair. The instance is returned if the parsing is
        successful, and *None* otherwise.
        """
        m = cls.message_cre.match(message.content)
        if not m:
            return None

        # key and value are keyword-only arguments of __init__ and must be passed by name; the
        # greedy groups of message_cre also capture surrounding whitespace, so strip it
        return cls(
            message._scheduler,
            message._task_id,
            message._message_id,
            message.content,
            key=m.group(1).strip(),
            value=m.group(3).strip(),
            **message.payload,
        )

    def __init__(self, *args, key, value, **kwargs):
        super().__init__(*args, **kwargs)
        self.key = key
        self.value = value

    def __str__(self) -> str:
        return str(self.value)
[docs]
def load_correction_set(target: law.FileSystemFileTarget) -> Any:
    """
    Loads a correction set using the correctionlib from a file *target*. Files with a "json"
    extension are read directly from their path, while all other files are assumed to be
    gzip-compressed json.

    As a side effect, ``correctionlib.highlevel.Correction`` instances become callable, with
    calls being forwarded to their ``evaluate`` method.
    """
    import correctionlib

    # allow corrections to be evaluated by calling them directly
    correction_cls = correctionlib.highlevel.Correction
    correction_cls.__call__ = correction_cls.evaluate

    if target.ext() != "json":
        # otherwise, assume the input file is compressed
        content = target.load(formatter="gzip").decode("utf-8")
        return correctionlib.CorrectionSet.from_string(content)

    return correctionlib.CorrectionSet.from_file(target.abspath)