Source code for columnflow.tasks.framework.base

# coding: utf-8

"""
Generic tools and base tasks that are defined along typical objects in an analysis.
"""

from __future__ import annotations

import os
import abc
import enum
import importlib
import itertools
import inspect
import functools
import collections
import copy
import subprocess
from dataclasses import dataclass, field

import luigi
import law
import order as od

from columnflow.columnar_util import mandatory_coffea_columns, Route, ColumnCollection
from columnflow.util import is_regex, prettify, DotDict
from columnflow.types import Sequence, Callable, Any, T


# module logger
logger = law.logger.get_logger(__name__)
# separate logger whose name is suffixed with "-dev", e.g. for developer-oriented messages
logger_dev = law.logger.get_logger(f"{__name__}-dev")

# default analysis and config related objects, read from the "analysis" section of the law config
default_analysis = law.config.get_expanded("analysis", "default_analysis")
default_config = law.config.get_expanded("analysis", "default_config")
default_dataset = law.config.get_expanded("analysis", "default_dataset")
# defaults used by AnalysisTask.build_repr: maximum length of a representation string, maximum number
# of joined objects, and length of the hash appended when truncating
default_repr_max_len = law.config.get_expanded_int("analysis", "repr_max_len")
default_repr_max_count = law.config.get_expanded_int("analysis", "repr_max_count")
default_repr_hash_len = law.config.get_expanded_int("analysis", "repr_hash_len")

# placeholder to denote a default value that is resolved dynamically
# (see AnalysisTask.resolve_config_default)
RESOLVE_DEFAULT = "DEFAULT"


class Requirements(DotDict):
    """
    Container for task-level requirements of different tasks.

    It can be initialized with any number of other :py:class:`Requirements` instances (or other
    mappings) given as positional arguments, plus additional keyword arguments ``kwargs``. All of
    them are merged into the new container in order: positional containers first, then ``kwargs``.
    """

    def __init__(self, *others, **kwargs) -> None:
        super().__init__()

        # merge every source via update(), keyword arguments last
        for source in (*others, kwargs):
            self.update(source)
class OutputLocation(enum.Enum):
    """
    Output location flag.

    Each member's value equals its name. The interpretation of the flags happens in code not shown
    here; presumably "config" defers to the law config, "local" and "wlcg" select local and remote
    (WLCG) storage, and "wlcg_mirrored" a WLCG location with a local mirror — TODO confirm.
    """

    config = "config"
    local = "local"
    wlcg = "wlcg"
    wlcg_mirrored = "wlcg_mirrored"
@dataclass
class TaskShifts:
    """
    Container for *local* and *upstream* shifts at a point in the task graph.

    Both fields are sets of strings and default to new, empty sets per instance.
    """

    # NOTE: maybe these should be a dict of sets (one set per config) to allow for different shifts
    # per config
    local: set[str] = field(default_factory=set)
    upstream: set[str] = field(default_factory=set)
class BaseTask(law.Task):
    """
    Minimal base task on top of :py:class:`law.Task`.

    It sets the task namespace from the ``cf_task_namespace`` option of the ``analysis`` section in
    the law config (falling back to ``"cf"``) and provides a :py:class:`Requirements` container as a
    class attribute.
    """

    task_namespace = law.config.get_expanded("analysis", "cf_task_namespace", "cf")

    # container for upstream requirements for convenience
    reqs = Requirements()

    def get_params_dict(self) -> dict[str, Any]:
        """
        Returns a dictionary mapping the names of all luigi parameters of this task to their current
        values on this instance.
        """
        values = {}
        for name, parameter in self.get_params():
            # only genuine luigi parameters are considered
            if not isinstance(parameter, luigi.Parameter):
                continue
            values[name] = getattr(self, name)
        return values
class AnalysisTask(BaseTask, law.SandboxTask):
    """
    Base class for tasks that are bound to an analysis (selected via the *analysis* parameter in
    the format ``"<module>.<attribute>"``) and that carry a mandatory *version* which is encoded into
    output paths.
    """

    # analysis to load, defaulting to the "default_analysis" option of the law config
    analysis = luigi.Parameter(
        default=default_analysis,
        description=f"name of the analysis; default: '{default_analysis}'",
    )
    version = luigi.Parameter(
        description="mandatory version that is encoded into output paths",
    )

    # notification parameters, insignificant so they do not influence the task identity
    notify_slack = law.slack.NotifySlackParameter(significant=False)
    notify_mattermost = law.mattermost.NotifyMattermostParameter(significant=False)
    notify_custom = law.NotifyCustomParameter(significant=False)

    # sandbox settings: no sandbox by default, and an empty sandbox is allowed
    allow_empty_sandbox = True
    sandbox = None

    message_cache_size = 25
    local_workflow_require_branches = False
    output_collection_cls = law.SiblingFileCollection

    # defaults for targets
    default_store = "$CF_STORE_LOCAL"
    default_wlcg_fs = law.config.get_expanded("target", "default_wlcg_fs", "wlcg_fs")
    default_output_location = "config"

    # parameters excluded in the respective law contexts (presumably: task index, req(), repr,
    # branch and workflow parameter forwarding — TODO confirm against law docs)
    exclude_params_index = {"user"}
    exclude_params_req = {"user", "notify_slack", "notify_mattermost", "notify_custom"}
    exclude_params_repr = {"user", "notify_slack", "notify_mattermost", "notify_custom"}
    exclude_params_branch = {"user"}
    exclude_params_workflow = {"user", "notify_slack", "notify_mattermost", "notify_custom"}

    # cached and parsed sections of the law config for faster lookup
    # (filled lazily by the _get_cfg_*_dict class methods)
    _cfg_outputs_dict = None
    _cfg_versions_dict = None
    _cfg_resources_dict = None
[docs] @classmethod def modify_param_values(cls, params: dict[str, Any]) -> dict[str, Any]: params = super().modify_param_values(params) params = cls.resolve_param_values(params) return params
[docs] @classmethod def resolve_param_values(cls, params: dict[str, Any]) -> dict[str, Any]: # store a reference to the analysis inst if "analysis_inst" not in params and "analysis" in params: params["analysis_inst"] = cls.get_analysis_inst(params["analysis"]) return params
[docs] @classmethod def get_analysis_inst(cls, analysis: str) -> od.Analysis: # prepare names if "." not in analysis: raise ValueError(f"invalid analysis format: {analysis}") module_id, name = analysis.rsplit(".", 1) # import the module try: mod = importlib.import_module(module_id) except ImportError as e: raise ImportError(f"cannot import analysis module {module_id}: {e}") # get the analysis instance analysis_inst = getattr(mod, name, None) if analysis_inst is None: raise Exception(f"module {module_id} does not contain analysis instance {name}") return analysis_inst
[docs] @classmethod def req_params(cls, inst: AnalysisTask, **kwargs) -> dict[str, Any]: """ Returns parameters that are jointly defined in this class and another task instance of some other class. The parameters are used when calling ``Task.req(self)``. """ # always prefer certain parameters given as task family parameters (--TaskFamily-parameter) _prefer_cli = law.util.make_set(kwargs.get("_prefer_cli", [])) | { "version", "workflow", "job_workers", "poll_interval", "walltime", "max_runtime", "retries", "acceptance", "tolerance", "parallel_jobs", "shuffle_jobs", "htcondor_cpus", "htcondor_gpus", "htcondor_memory", "htcondor_disk", "htcondor_pool", "pilot", } kwargs["_prefer_cli"] = _prefer_cli # build the params params = super().req_params(inst, **kwargs) # when not explicitly set in kwargs and no global value was defined on the cli for the task # family, evaluate and use the default value if ( isinstance(getattr(cls, "version", None), luigi.Parameter) and "version" not in kwargs and not law.parser.global_cmdline_values().get(f"{cls.task_family}_version") and cls.task_family != law.parser.root_task_cls().task_family ): default_version = cls.get_default_version(inst, params) if default_version and default_version != law.NO_STR: params["version"] = default_version return params
@classmethod def _structure_cfg_items(cls, items: list[tuple[str, Any]]) -> dict[str, Any]: if not items: return {} # apply brace expansion to keys items = sum(( [(_key, value) for _key in law.util.brace_expand(key)] for key, value in items ), []) # breakup keys at double underscores and create a nested dictionary items_dict = {} for key, value in items: if not value: continue d = items_dict parts = key.split("__") for i, part in enumerate(parts): if i < len(parts) - 1: # fill intermediate structure if part not in d: d[part] = {} elif not isinstance(d[part], dict): d[part] = {"*": d[part]} d = d[part] else: # assign value to the last nesting level if part in d and isinstance(d[part], dict): d[part]["*"] = value else: d[part] = value return items_dict @classmethod def _get_cfg_outputs_dict(cls) -> dict[str, Any]: if cls._cfg_outputs_dict is None and law.config.has_section("outputs"): # collect config item pairs skip_keys = {"wlcg_file_systems", "lfn_sources"} items = [ (key, law.config.get_expanded("outputs", key, None, split_csv=True)) for key, value in law.config.items("outputs") if value and key not in skip_keys ] cls._cfg_outputs_dict = cls._structure_cfg_items(items) return cls._cfg_outputs_dict @classmethod def _get_cfg_versions_dict(cls) -> dict[str, Any]: if cls._cfg_versions_dict is None and law.config.has_section("versions"): # collect config item pairs items = [ (key, value) for key, value in law.config.items("versions") if value ] cls._cfg_versions_dict = cls._structure_cfg_items(items) return cls._cfg_versions_dict @classmethod def _get_cfg_resources_dict(cls) -> dict[str, Any]: if cls._cfg_resources_dict is None and law.config.has_section("resources"): # helper to split resource values into key-value pairs themselves def parse(key: str, value: str) -> tuple[str, list[tuple[str, Any]]]: params = [] for part in value.split(","): part = part.strip() if not part: continue if "=" not in part: logger.warning_once( f"invalid_resource_{key}", f"resource for 
key {key} contains invalid instruction {part}, skipping", ) continue param, value = (s.strip() for s in part.split("=", 1)) params.append((param, value)) return key, params # collect config item pairs items = [ parse(key, value) for key, value in law.config.items("resources") if value and not key.startswith("_") ] cls._cfg_resources_dict = cls._structure_cfg_items(items) return cls._cfg_resources_dict
[docs] @classmethod def get_default_version(cls, inst: AnalysisTask, params: dict[str, Any]) -> str | None: """ Determines the default version for instances of *this* task class when created through :py:meth:`req` from another task *inst* given parameters *params*. :param inst: The task instance from which *this* task should be created via :py:meth:`req`. :param params: The parameters that are passed to the task instance. :return: The default version, or *None* if no default version can be defined. """ # get different attributes by which the default version might be looked up keys = cls.get_config_lookup_keys(params) # forward to lookup implementation version = cls._get_default_version(inst, params, keys) # after a version is found, it can still be an exectuable taking the same arguments return version(cls, inst, params) if callable(version) else version
@classmethod def _get_default_version( cls, inst: AnalysisTask, params: dict[str, Any], keys: law.util.InsertableDict, ) -> str | None: # try to lookup the version in the analysis's auxiliary data analysis_inst = getattr(inst, "analysis_inst", None) if analysis_inst: version = cls._dfs_key_lookup(keys, analysis_inst.x("versions", {})) if version: return version # try to find it in the analysis section in the law config if law.config.has_section("versions"): versions_dict = cls._get_cfg_versions_dict() if versions_dict: version = cls._dfs_key_lookup(keys, versions_dict) if version: return version # no default version found return None
[docs] @classmethod def get_config_lookup_keys( cls, inst_or_params: AnalysisTask | dict[str, Any], ) -> law.util.InsertiableDict: """ Returns a dictionary with keys that can be used to lookup state specific values in a config or dictionary, such as default task versions or output locations. :param inst_or_params: The tasks instance or its parameters. :return: A dictionary with keys that can be used for nested lookup. """ keys = law.util.InsertableDict() # add the analysis name analysis = ( inst_or_params.get("analysis") if isinstance(inst_or_params, dict) else getattr(inst_or_params, "analysis", None) ) if analysis not in {law.NO_STR, None, ""}: keys["analysis"] = analysis # add the task family keys["task_family"] = cls.task_family return keys
    @classmethod
    def _dfs_key_lookup(
        cls,
        keys: law.util.InsertableDict[str, str | Sequence[str]] | Sequence[str | Sequence[str]],
        nested_dict: dict[str, Any],
        empty_value: Any = None,
    ) -> str | Callable | None:
        """
        Looks up a value in *nested_dict* along an ordered sequence of *keys* (the flattened values
        when *keys* is a dict).

        The dictionary is traversed depth-first. Each node's pattern is compared against the remaining
        keys in order: non-matching keys are consumed, so keys are effectively optional. On the first
        match, a non-dict node value is returned immediately, while a dict node is descended into with
        the keys following the matched one. Matching uses ``law.util.multi_match``, in regex mode when
        ``is_regex(pattern)`` is true.

        :param keys: Ordered lookup keys, possibly nested sequences.
        :param nested_dict: The (pattern-keyed) nested dictionary to search.
        :param empty_value: Value returned when *nested_dict* is empty or nothing matches.
        :return: The first found value, or *empty_value*.
        """
        # opinionated nested dictionary lookup along an ordered sequence of (optional) keys,
        # that allows for patterns in the keys and, interpreting the nested dict as a tree, finds
        # matches in a depth-first (dfs) manner
        if not nested_dict:
            return empty_value

        # the keys to use for the lookup are the flattened values of the keys dict
        flat_keys = collections.deque(law.util.flatten(keys.values() if isinstance(keys, dict) else keys))

        # start tree traversal using a queue lookup consisting of names and values of tree nodes,
        # as well as the remaining keys (as a deferred function) to compare for that particular path
        lookup = collections.deque([tpl + ((lambda: flat_keys.copy()),) for tpl in nested_dict.items()])
        while lookup:
            pattern, obj, keys_func = lookup.popleft()

            # create the copy of comparison keys on demand
            # (the original sequence is living once on the previous stack until now)
            _keys = keys_func()

            # check if the pattern matches any key
            regex = is_regex(pattern)
            while _keys:
                key = _keys.popleft()
                if law.util.multi_match(key, pattern, regex=regex):
                    # when obj is not a dict, we found the value
                    if not isinstance(obj, dict):
                        return obj

                    # go one level deeper and stop the current iteration
                    # (the outer lambda binds the current _keys, i.e., the keys after the match; children
                    # are pushed to the front in their original order via reversed + extendleft)
                    keys_func = (lambda _keys: (lambda: _keys.copy()))(_keys)
                    lookup.extendleft(tpl + (keys_func,) for tpl in reversed(obj.items()))
                    break

        # at this point, no value could be found
        return empty_value
[docs] @classmethod def get_array_function_dict(cls, params: dict[str, Any]) -> dict[str, Any]: if "analysis_inst" in params: analysis_inst = params["analysis_inst"] else: analysis_inst = cls.get_analysis_inst(params["analysis"]) return {"analysis_inst": analysis_inst}
[docs] @classmethod def find_config_objects( cls, names: str | Sequence[str] | set[str], container: od.UniqueObject | Sequence[od.UniqueObject], object_cls: od.UniqueObjectMeta, groups_str: str | None = None, accept_patterns: bool = True, deep: bool = False, strict: bool = False, multi_strategy: str = "first", ) -> list[str] | dict[od.UniqueObject, list[str]]: """ Returns all names of objects of type *object_cls* known to a *container* (e.g. :py:class:`od.Analysis` or :py:class:`od.Config`) that match *names*. A name can also be a pattern to match if *accept_patterns* is *True*, or, when given, the key of a mapping named *group_str* in the container auxiliary data that matches group names to object names. When *deep* is *True* the lookup of objects in the *container* is recursive. When *strict* is *True*, an error is raised if no matches are found for any of the *names*. *container* can also refer to a sequence of container objects. If this is the case, the default object retrieval is performed for all of them and the resulting values can be handled with five different strategies, controlled via *multi_strategy*: - ``"first"``: The first resolved name is returned. - ``"same"``: The resolved names are forced to be identical and an exception is raised if they differ. The first resolved value is returned. - ``"union"``: The set union of all resolved names is returned in a list. - ``"intersection"``: The set intersection of all resolved names is returned in a list. - ``"all"``: The resolved values are returned in a dictionary mapped to their respective container. Example: .. 
code-block:: python find_config_objects(names=["st_tchannel_*"], container=config_inst, object_cls=od.Dataset) # -> ["st_tchannel_t", "st_tchannel_tbar"] """ # when the container is a sequence, find objects per container and apply the multi_strategy if isinstance(container, (list, tuple)): if multi_strategy not in (strategies := {"first", "same", "union", "intersection", "all"}): raise ValueError(f"invalid multi_strategy: {multi_strategy}, must be one of {','.join(strategies)}") all_object_names = { _container: cls.find_config_objects( names=names, container=_container, object_cls=object_cls, groups_str=groups_str, accept_patterns=accept_patterns, deep=deep, strict=strict, ) for _container in container } if multi_strategy == "all": return all_object_names if multi_strategy == "first": return all_object_names[container[0]] if multi_strategy == "union": return list(set.union(*map(set, all_object_names.values()))) if multi_strategy == "intersection": return list(set.intersection(*map(set, all_object_names.values()))) # "same", so check that values are identical first = all_object_names[container[0]] if not all(all_object_names[c] == first for c in container[1:]): raise ValueError( f"different objects found across containers looking for '{object_cls}' objects '{names}':\n" f"{prettify(all_object_names)}", ) return first # prepare value caching singular = object_cls.cls_name_singular plural = object_cls.cls_name_plural _cache: dict[str, set[str]] = {} def get_all_object_names() -> set[str]: if "all_object_names" not in _cache: if deep: _cache["all_object_names"] = {obj.name for obj, _, _ in getattr(container, f"walk_{plural}")()} else: _cache["all_object_names"] = set(getattr(container, plural).names()) return _cache["all_object_names"] def has_obj(name: str) -> bool: if "has_obj_func" not in _cache: kwargs = {} if object_cls in container._deep_child_classes: kwargs["deep"] = deep _cache["has_obj_func"] = functools.partial(getattr(container, f"has_{singular}"), 
**kwargs) return _cache["has_obj_func"](name) object_names = [] lookup = law.util.make_list(names) missing = set() while lookup: name = lookup.pop(0) if has_obj(name): # known object object_names.append(name) elif groups_str and name in (object_groups := container.x(groups_str, {})): # a key in the object group dict lookup.extend(list(object_groups[name])) elif accept_patterns: # must eventually be a pattern, perform an object traversal found = [] for _name in sorted(get_all_object_names()): if law.util.multi_match(_name, name): found.append(_name) if not found: missing.add(name) object_names.extend(found) if missing and strict: missing_str = ",".join(sorted(missing)) raise ValueError(f"names/patterns did not yield any matches: {missing_str}") return law.util.make_unique(object_names)
[docs] @classmethod def resolve_config_default( cls, *, param: Any, task_params: dict[str, Any], container: str | od.AuxDataMixin | Sequence[od.AuxDataMixin], default_str: str | None = None, multi_strategy: str = "first", ) -> Any | list[Any] | dict[od.AuxDataMixin, Any]: """ Resolves a given parameter value *param*, checks if it should be placed with a default value when empty, and in this case, does the actual default value resolution. This resolution is triggered only in case *param* refers to :py:attr:`RESOLVE_DEFAULT`, a 1-tuple containing this attribute, or *None*. If so, the default is identified via the *default_str* from an :py:class:`order.AuxDataMixin` *container* and points to an auxiliary that can be either a string or a function. In the latter case, it is called with the task class, the container instance, and all task parameters. Note that when no *container* is given, *param* is returned unchanged. *container* can also refer to a sequence of :py:class:`order.AuxDataMixin` objects. If this is the case, the default resolution is performed for all of them and the resulting values can be handled with five different strategies, controlled via *multi_strategy*: - ``"first"``: The first resolved value is returned. - ``"same"``: The resolved values are forced to be identical and an exception is raised if they differ. The first resolved value is returned. - ``"union"``: The set union of all resolved values is returned in a list. - ``"intersection"``: The set intersection of all resolved values is returned in a list. - ``"all"``: The resolved values are returned in a dictionary mapped to their respective container. Example: .. 
code-block:: python # assuming this is your config config_inst = od.Config( id=1, name="my_config", aux={ "default_selector": "my_selector", }, ) # and these are the task parameters params = { "config_inst": config_inst, } AnalysisTask.resolve_config_default( param=RESOLVE_DEFAULT, task_params=params, container=config_inst, # <-- same as passing the "config_inst" key of params default_str="default_selector", ) # -> "my_selector" Example where the default points to a function: .. code-block:: python def default_selector(task_cls, config_inst, task_params) -> str: # determine the selector based on dynamic conditions return "my_other_selector config_inst = od.Config( id=1, name="my_config", aux={ "default_selector": default_selector, # <-- function }, ) AnalysisTask.resolve_config_default( param=RESOLVE_DEFAULT, task_params=params, container=config_inst, default_str="default_selector", ) # -> "my_other_selector" """ if multi_strategy not in (strategies := {"first", "same", "union", "intersection", "all"}): raise ValueError(f"invalid multi_strategy: {multi_strategy}, must be one of {','.join(strategies)}") # check if the parameter value is to be resolved resolve_default = param in (None, RESOLVE_DEFAULT, (RESOLVE_DEFAULT,)) return_single_value = True if param is None or isinstance(param, str) else False # interpret missing parameters (e.g. 
NO_STR) as None # (special case: an empty string is usually an active decision, but counts as missing too) if law.is_no_param(param) or resolve_default or param == "": param = None # get the container inst (typically a config_inst or analysis_inst) if isinstance(container, str): container = task_params.get(container) if not container: return param # actual resolution params: dict[od.AuxDataMixin, Any] if resolve_default: params = {} for _container in law.util.make_list(container): _param = param # expand default when container is set if _container and default_str: _param = _container.x(default_str, None) # allow default to be a function, taking task parameters as input if isinstance(_param, Callable): _param = _param(cls, _container, task_params) # handle empty values and return type if not return_single_value: _param = () if _param is None else law.util.make_tuple(_param) elif isinstance(_param, (list, tuple)): _param = _param[0] if _param else None params[_container] = _param else: params = {_container: param for _container in law.util.make_list(container)} # handle values if not isinstance(container, (list, tuple)): return params[container] if multi_strategy == "all": return params if multi_strategy == "first": return params[container[0]] # NOTE: in there two strategies, we loose all order information if multi_strategy == "union": return list(set.union(*map(set, params.values()))) if multi_strategy == "intersection": return list(set.intersection(*map(set, params.values()))) # "same", so check that values are identical first = params[container[0]] if not all(params[c] == first for c in container[1:]): default_str_repr = f" for '{default_str}'" if default_str else "" raise ValueError(f"multiple default values found{default_str_repr} in {container}: {params}") return first
[docs] @classmethod def resolve_config_default_and_groups( cls, *, param: Any, task_params: dict[str, Any], container: str | od.AuxDataMixin | Sequence[od.AuxDataMixin], groups_str: str, default_str: str | None = None, multi_strategy: str = "first", debug=False, ) -> Any | list[Any] | dict[od.AuxDataMixin, Any]: """ This method is similar to :py:meth:`~.resolve_config_default` in that it checks if a parameter value *param* is empty and should be replaced with a default value. All arguments except for *groups_str* are forwarded to this method. What this method does in addition is that it checks if the values contained in *param* (after default value resolution) refers to a group of values identified via the *groups_str* from the :py:class:`order.AuxDataMixin` *container* that maps a string to a tuple of strings. If it does, each value in *param* that refers to a group is expanded by the actual group values. Example: .. code-block:: python # assuming this is your config config_inst = od.Config( id=1, name="my_config", aux={ "default_producer": "my_producers", "producer_groups": { "my_producers": ["producer_1", "producer_2"], "my_other_producers": ["my_producers", "producer_3", "producer_4"], }, }, ) # and these are the task parameters params = { "config_inst": config_inst, } AnalysisTask.resolve_config_default_and_groups( param=RESOLVE_DEFAULT, task_params=params, container=config_inst, # <-- same as passing the "config_inst" key of params default_str="default_producer", groups_str="producer_groups", ) # -> ["producer_1", "producer_2"] Example showing recursive group expansion: .. 
code-block:: python # assuming config_inst and params are the same as above AnalysisTask.resolve_config_default_and_groups( param="my_other_producers", # <-- points to a group that contains another group task_params=params, container=config_inst, default_str="default_producer", # <-- not used as param is set explicitly groups_str="producer_groups", ) # -> ["producer_1", "producer_2", "producer_3", "producer_4"] """ if multi_strategy not in (strategies := {"first", "same", "union", "intersection", "all"}): raise ValueError(f"invalid multi_strategy: {multi_strategy}, must be one of {','.join(strategies)}") # get the container if isinstance(container, str): container = task_params.get(container, None) if not container: return param containers = law.util.make_list(container) # resolve the parameter params: dict[od.AuxDataMixin, Any] = cls.resolve_config_default( param=param, task_params=task_params, container=containers, default_str=default_str, multi_strategy="all", ) if not params: return param # expand groups recursively values = {} for _container, _param in params.items(): if not (param_groups := _container.x(groups_str, {})): values[_container] = law.util.make_tuple(_param) continue lookup = collections.deque(law.util.make_list(_param)) handled_groups = set() _values = [] while lookup: value = lookup.popleft() if value in param_groups: if value in handled_groups: raise Exception( f"definition of '{groups_str}' contains circular references involving group '{value}'", ) lookup.extendleft(law.util.make_list(param_groups[value])) handled_groups.add(value) else: _values.append(value) values[_container] = tuple(_values) # handle values if not isinstance(container, (list, tuple)): return values[container] if multi_strategy == "all": return values if multi_strategy == "first": return values[container[0]] if multi_strategy == "union": return list(set.union(*map(set, values.values()))) if multi_strategy == "intersection": return list(set.intersection(*map(set, 
values.values()))) # "same", so check that values are identical first = values[container[0]] if not all(values[c] == first for c in container[1:]): default_str_repr = f" for '{default_str}'" if default_str else "" raise ValueError( f"multiple default values found{default_str_repr} after expanding groups '{groups_str}' in " f"{containers}: {values}", ) return first
[docs] @classmethod def build_repr( cls, objects: Any | Sequence[Any], *, sep: str = "__", prepend_count: bool = False, max_len: int = default_repr_max_len, max_count: int = default_repr_max_count, hash_len: int = default_repr_hash_len, ) -> str: """ Generic method to construct a string representation given a single or a sequece of *objects*. :param objects: The object or objects to be represented. :param sep: The separator used to join the objects. :param prepend_count: When *True*, the number of objects is prepended to the string, followed by *sep*. :param max_len: The maximum length of the string. If exceeded, the string is truncated and hashed. :param max_count: The maximum number of objects to include in the string. Additional objects are hashed, but only if the resulting representation length does not exceed *max_len*. If so, the overall truncation and hashing is applied instead. :param hash_len: The length of the hash that is appended to the string when it is truncated. :return: The string representation. """ if 0 < max_len < hash_len: raise ValueError(f"max_len must be greater than hash_len: {max_len} <= {hash_len}") # join objects when a sequence is given if isinstance(objects, (list, tuple)): r = f"{len(objects)}{sep}" if prepend_count else "" # truncate when requested and the expected length will not exceed max_len, in which case the overall # truncation applies the hashing if ( 0 < max_count < len(objects) and not (0 < max_len < (len(r) + sum(map(len, objects[:max_count])) + len(sep) * max_count + hash_len)) ): r += sep.join(objects[:max_count]) r += f"{sep}{law.util.create_hash(objects[max_count:], l=hash_len)}" else: r += sep.join(objects) else: r = str(objects) # handle overall truncation if max_len > 0 and len(r) > max_len: r = f"{r[:max_len - hash_len - len(sep)]}{sep}{law.util.create_hash(r, l=hash_len)}" return r
def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) # store the analysis instance self.analysis_inst = self.get_analysis_inst(self.analysis) # cached values added and accessed by cached_value() self._cached_values = {}
[docs] def cached_value(self, key: str, func: Callable[[], T]) -> T: """ Upon first invocation, the function *func* is called and its return value is stored under *key* in :py:attr:`_cached_values`. Subsequent calls with the same *key* return the cached value. :param key: The key under which the value is stored. :param func: The function that is called to generate the value. :return: The cached value. """ if key not in self._cached_values: self._cached_values[key] = func() return self._cached_values[key]
[docs] def reset_sandbox(self, sandbox: str) -> None: """ Resets the sandbox to a new *sandbox* value. """ # do nothing if the value actualy does not change if self.sandbox == sandbox: return # change it and rebuild the sandbox inst when already initialized self.sandbox = sandbox if self._sandbox_initialized: self._initialize_sandbox(force=True)
[docs] def store_parts(self) -> law.util.InsertableDict: """ Returns a :py:class:`law.util.InsertableDict` whose values are used to create a store path. For instance, the parts ``{"keyA": "a", "keyB": "b", 2: "c"}`` lead to the path "a/b/c". The keys can be used by subclassing tasks to overwrite values. :return: Dictionary with parts that will be translated into an output directory path. """ parts = law.util.InsertableDict() # add the analysis name parts["analysis"] = self.analysis_inst.name # in this base class, just add the task class name parts["task_family"] = self.task_family # add the version when set if self.version is not None: parts["version"] = self.version return parts
def _get_store_parts_modifier(
    self,
    modifier: str | Callable[[AnalysisTask, dict], dict],
) -> Callable[[AnalysisTask, dict], dict] | None:
    """
    Resolves *modifier* into a callable, or *None*.

    Strings are looked up by name in the ``store_parts_modifiers`` auxiliary entry of the analysis. Any value
    that is not callable, either directly or after this lookup, results in *None*.

    :param modifier: A callable or the name of a registered store parts modifier.
    :return: The callable modifier or *None*.
    """
    if not isinstance(modifier, str):
        return modifier if callable(modifier) else None

    resolved = self.analysis_inst.x("store_parts_modifiers", {}).get(modifier)
    return resolved if callable(resolved) else None
def local_path(
    self,
    *path,
    store_parts_modifier: str | Callable[[AnalysisTask, dict], dict] | None = None,
    **kwargs,
) -> str:
    """
    local_path(*path, store=None, fs=None, store_parts_modifier=None)

    Joins path fragments from *store* (defaulting to :py:attr:`default_store`), :py:meth:`store_parts` and *path*
    and returns the joined path. When a *fs* is given, it should refer to the config section of a local file
    system. In that case *store* is not prepended, since that file system resolves absolute paths itself. The
    store parts can optionally be transformed by *store_parts_modifier* (see
    :py:meth:`_get_store_parts_modifier`).
    """
    # the store directory is only prepended when no file system is given
    prefix = []
    if not kwargs.pop("fs", None):
        prefix.append(kwargs.get("store") or self.default_store)

    # build the store parts and optionally pass them through the modifier
    store_parts = self.store_parts()
    modifier = self._get_store_parts_modifier(store_parts_modifier)
    if callable(modifier):
        store_parts = modifier(self, store_parts)

    fragments = [*prefix, *store_parts.values(), *path]
    return os.path.join(*(str(fragment) for fragment in fragments))
def local_target(
    self,
    *path,
    store_parts_modifier: str | Callable[[AnalysisTask, dict], dict] | None = None,
    **kwargs,
) -> law.LocalTarget:
    """
    local_target(*path, dir=False, store=None, fs=None, store_parts_modifier=None, **kwargs)

    Creates either a local file or directory target, depending on *dir*. All *path* fragments, *store*, *fs* and
    *store_parts_modifier* are forwarded to :py:meth:`local_path`. All remaining *kwargs*, including *fs*, are
    passed to the target class.
    """
    is_dir = kwargs.pop("dir", False)
    store = kwargs.pop("store", None)
    # fs is read but kept in kwargs, so it is also handed to the target class below
    fs = kwargs.get("fs", None)

    target_cls = law.LocalDirectoryTarget if is_dir else law.LocalFileTarget
    target_path = self.local_path(*path, store=store, fs=fs, store_parts_modifier=store_parts_modifier)
    return target_cls(target_path, **kwargs)
def wlcg_path(
    self,
    *path,
    store_parts_modifier: str | Callable[[AnalysisTask, dict], dict] | None = None,
) -> str:
    """
    Joins path fragments from :py:meth:`store_parts` and *path* and returns the joined path.

    The store parts can optionally be transformed by *store_parts_modifier*. The full URI of the target is not
    considered here, as it is usually defined in ``[wlcg_fs]`` sections in the law config and is therefore
    handled by :py:meth:`wlcg_target`.
    """
    store_parts = self.store_parts()
    modifier = self._get_store_parts_modifier(store_parts_modifier)
    if callable(modifier):
        store_parts = modifier(self, store_parts)

    return os.path.join(*(str(part) for part in (*store_parts.values(), *path)))
def wlcg_target(
    self,
    *path,
    store_parts_modifier: str | Callable[[AnalysisTask, dict], dict] | None = None,
    **kwargs,
) -> law.wlcg.WLCGTarget:
    """
    wlcg_target(*path, dir=False, fs=default_wlcg_fs, store_parts_modifier=None, **kwargs)

    Creates either a remote WLCG file or directory target, depending on *dir*. All *path* fragments and
    *store_parts_modifier* are forwarded to :py:meth:`wlcg_path`, and all *kwargs* to the respective target
    class. When *fs* is empty, it defaults to the :py:attr:`default_wlcg_fs` attribute.
    """
    _dir = kwargs.pop("dir", False)
    if not kwargs.get("fs"):
        kwargs["fs"] = self.default_wlcg_fs

    # select the target class
    cls = law.wlcg.WLCGDirectoryTarget if _dir else law.wlcg.WLCGFileTarget

    # create the remote path relative to the file system
    path = self.wlcg_path(*path, store_parts_modifier=store_parts_modifier)

    # create the target instance and return it
    return cls(path, **kwargs)
[docs]
def target(self, *path, **kwargs) -> law.LocalTarget | law.wlcg.WLCGTarget | law.MirroredTarget:
    """
    target(*path, location=None, **kwargs)

    Creates an output target for the given *path* fragments at the output *location*, which defaults to
    :py:attr:`default_output_location`. *location* can be an :py:class:`OutputLocation` member or its name. For
    :py:attr:`OutputLocation.config`, the location is looked up with :py:meth:`_dfs_key_lookup` in
    :py:meth:`_get_cfg_outputs_dict` using :py:meth:`get_config_lookup_keys`, falling back to ``["local"]``. The
    resulting list starts with the location type, followed by additional options:

    - ``local``: ``[local, store_or_fs, store_parts_modifier]``; the second element is used as *fs* when it names
      a section in the law config, and as *store* otherwise,
    - ``wlcg``: ``[wlcg, fs, store_parts_modifier]``,
    - ``wlcg_mirrored``: ``[wlcg_mirrored, local_store_or_fs, wlcg_fs, store_parts_modifier]``.

    Values already present in *kwargs* take precedence over these options. Remaining *kwargs* are forwarded to
    :py:meth:`local_target` or :py:meth:`wlcg_target`.

    :raises Exception: When no target can be created for the resolved location.
    """
    # get the default location
    location = kwargs.pop("location", self.default_output_location)

    # parse it and obtain config values if necessary
    if isinstance(location, str):
        location = OutputLocation[location]
    if location == OutputLocation.config:
        lookup_keys = self.get_config_lookup_keys(self)
        outputs_dict = self._get_cfg_outputs_dict()
        # deep copy since the first element is replaced in place below
        location = copy.deepcopy(self._dfs_key_lookup(lookup_keys, outputs_dict))
        if not location:
            self.logger.debug(
                f"no option 'outputs::{self.task_family}' found in law.cfg to obtain target "
                "location, falling back to 'local'",
            )
            location = ["local"]
        location[0] = OutputLocation[location[0]]
    location = law.util.make_list(location)
    # NOTE(review): only the config lookup above converts the first element into an OutputLocation; a list passed
    # directly as *location* with a string first element matches none of the branches below — confirm intended

    # forward to correct function
    if location[0] == OutputLocation.local:
        # get other options, padded with None
        loc, store_parts_modifier = (location[1:] + [None, None])[:2]
        loc_key = "fs" if (loc and law.config.has_section(loc)) else "store"
        kwargs.setdefault(loc_key, loc)
        kwargs.setdefault("store_parts_modifier", store_parts_modifier)
        return self.local_target(*path, **kwargs)

    if location[0] == OutputLocation.wlcg:
        # get other options, padded with None
        fs, store_parts_modifier = (location[1:] + [None, None])[:2]
        kwargs.setdefault("fs", fs)
        kwargs.setdefault("store_parts_modifier", store_parts_modifier)
        return self.wlcg_target(*path, **kwargs)

    if location[0] == OutputLocation.wlcg_mirrored:
        # get other options, padded with None
        loc, wlcg_fs, store_parts_modifier = (location[1:] + [None, None, None])[:3]
        kwargs.setdefault("store_parts_modifier", store_parts_modifier)
        # create the wlcg target
        wlcg_kwargs = kwargs.copy()
        wlcg_kwargs.setdefault("fs", wlcg_fs)
        wlcg_target = self.wlcg_target(*path, **wlcg_kwargs)
        # TODO: add rule for falling back to wlcg target?
        # create the local target
        local_kwargs = kwargs.copy()
        loc_key = "fs" if (loc and law.config.has_section(loc)) else "store"
        local_kwargs.setdefault(loc_key, loc)
        local_target = self.local_target(*path, **local_kwargs)
        # build the mirrored target from these two
        mirrored_target_cls = (
            law.MirroredFileTarget
            if isinstance(local_target, law.LocalFileTarget)
            else law.MirroredDirectoryTarget
        )
        # whether to wait for local synchronization (for debugging purposes)
        local_sync = law.util.flag_to_bool(os.getenv("CF_MIRRORED_TARGET_LOCAL_SYNC", "true"))
        # create and return the target
        return mirrored_target_cls(
            path=local_target.abspath,
            remote_target=wlcg_target,
            local_target=local_target,
            local_sync=local_sync,
        )

    raise Exception(f"cannot determine output location based on '{location}'")
def get_parquet_writer_opts(self, repeating_values: bool = False) -> dict[str, Any]:
    """
    Returns an option dictionary that can be passed as *writer_opts* to
    :py:meth:`~law.pyarrow.merge_parquet_task`, for instance at the end of chunked processing steps that produce
    a single parquet file. See :py:class:`~pyarrow.parquet.ParquetWriter` for valid options. Subclasses can
    override this method to customize the behavior.

    :param repeating_values: Whether the values to be written are predominantly repeating. In that case,
        dictionary encoding is enabled.
    :return: A dictionary with options that can be passed to parquet writer objects.
    """
    opts = {
        "compression": "ZSTD",
        "compression_level": 1,
        # dictionary encoding only pays off for predominantly repeating values
        "use_dictionary": bool(repeating_values),
        # keep the parquet structure after merging identical to that of the input files, i.e., do not switch
        # from "*.list.item.*" to "*.list.element*." structures, see
        # https://github.com/scikit-hep/awkward/issues/3331 and https://github.com/apache/arrow/issues/31731
        "use_compliant_nested_type": False,
    }
    return opts
[docs]
class ConfigTask(AnalysisTask):
    """
    Base task for tasks that work on a single config (*config* parameter) or on several configs (*configs*
    parameter) of the analysis.

    Subclasses set :py:attr:`single_config` to *True* or *False*. Depending on this flag, the parameter that does
    not apply is set to *None* in :py:meth:`modify_task_attributes`.
    """

    # name of the config in single config mode
    config = luigi.Parameter(
        default=default_config,
        description=f"name of the analysis config to use; default: '{default_config}'",
    )
    # names of the configs in multi config mode
    configs = law.CSVParameter(
        default=(default_config,),
        description=f"comma-separated names of analysis configs to use; default: '{default_config}'",
        brace_expand=True,
    )
    # private parameter holding the TaskShifts object that is created during parameter resolution
    known_shifts = luigi.Parameter(
        default=None,
        visibility=luigi.parameter.ParameterVisibility.PRIVATE,
    )

    # known_shifts is not forwarded to requirements, sandboxes and remote workflows, and is neither indexed nor
    # part of the task representation
    exclude_params_req = {"known_shifts"}
    exclude_params_sandbox = {"known_shifts"}
    exclude_params_remote_workflow = {"known_shifts"}
    exclude_params_index = {"known_shifts"}
    exclude_params_repr = {"known_shifts"}

    # the field in the store parts behind which the new part is inserted
    # added here for subclasses that typically refer to the store part added by _this_ class
    config_store_anchor = "config"
@classmethod
def modify_task_attributes(cls) -> None:
    """
    Hook that is called by law's task register meta class right after subclass creation to update class-level
    attributes.

    Once :py:attr:`single_config` is a boolean, the parameter attribute that does not apply to the chosen mode
    (``configs`` for single config tasks, ``config`` otherwise) is set to *None*, provided it exists.
    """
    super().modify_task_attributes()

    # nothing to adjust as long as the mode has not been decided
    if not isinstance(cls.single_config, bool):
        return

    unused_attr = "config" if not cls.has_single_config() else "configs"
    if getattr(cls, unused_attr, law.no_value) != law.no_value:
        setattr(cls, unused_attr, None)
@property
@abc.abstractmethod
def single_config(cls) -> bool:
    """
    Abstract flag that tells whether the task uses a single config (*True*) or multiple configs (*False*).

    Classes that should be instantiated must set this flag to *True* or *False*, as a boolean class attribute or
    property. It is declared as an abstract property to guard against instantiating a misconfigured subclass.
    ``property`` combined with :py:func:`abc.abstractmethod` replaces the deprecated ``abc.abstractproperty``.
    """
    ...
@classmethod
def has_single_config(cls) -> bool:
    """
    Returns whether the class is configured to use a single config.

    :raises AttributeError: When the *single_config* attribute of the class is not a boolean.
    :return: *True* if the class uses a single config, *False* otherwise.
    """
    flag = cls.single_config
    if isinstance(flag, bool):
        return flag
    raise AttributeError(f"unspecified 'single_config' attribute in {cls}: {flag}")
@classmethod
def ensure_single_config(cls, value: bool, *, attr: str | None = None) -> None:
    """
    Ensures that the :py:attr:`single_config` flag of this task is set to *value*, and raises an exception if it
    is not. This method is typically used to guard attribute access, in which case *attr* is mentioned in the
    exception message.

    :param value: The value to compare the flag with.
    :param attr: The attribute that triggered the check.
    :raises Exception: When the flag differs from *value*.
    """
    single_config = cls.has_single_config()
    if single_config == value:
        return

    if attr:
        # describe the actual config mode of the task that prevents accessing attr
        s = "a single config" if single_config else "multiple configs"
        msg = f"cannot access attribute '{attr}' when task '{cls}' has {s}"
    else:
        # describe the config mode that was expected
        s = "a single config" if value else "multiple configs"
        msg = f"task '{cls}' expected to use {s}"
    raise Exception(msg)
@classmethod
def config_mode(cls) -> str:
    """
    Returns a short label for the config mode of this task.

    :return: ``"single"`` for single config tasks, ``"multi"`` otherwise.
    """
    return {True: "single", False: "multi"}[cls.has_single_config()]
@classmethod
def _get_config_container(cls, params: dict[str, Any]) -> od.Config | list[od.Config] | None:
    """
    Extracts the config instance (single config mode, ``config_inst``) or instances (multi config mode,
    ``config_insts``) from the task parameters *params*.

    :param params: Dictionary of task parameters.
    :return: The config instance(s), or *None* when the respective entry is missing or empty.
    """
    key = "config_inst" if cls.has_single_config() else "config_insts"
    return params.get(key) or None
[docs]
@classmethod
def resolve_param_values(cls, params: dict[str, Any]) -> dict[str, Any]:
    """
    Resolves task parameters in several stages, all of which require an ``analysis_inst`` in *params*:

    1. the config instance(s) are looked up from ``config`` or ``configs``,
    2. :py:meth:`resolve_param_values_pre_init` is called,
    3. when ``known_shifts`` is not set yet, a new :py:class:`TaskShifts` object is created and filled through
       :py:meth:`resolve_instances`,
    4. :py:meth:`resolve_param_values_post_init` is called,
    5. :py:meth:`resolve_shifts` is called.

    :param params: Dictionary of task parameters.
    :return: Updated dictionary of task parameters.
    """
    params = super().resolve_param_values(params)

    if (analysis_inst := params.get("analysis_inst")):
        # store a reference to the config inst(s)
        if cls.has_single_config():
            if "config_inst" not in params and "config" in params:
                params["config_inst"] = analysis_inst.get_config(params["config"])
                params["config_insts"] = [params["config_inst"]]
        else:
            if "config_insts" not in params and "configs" in params:
                params["config_insts"] = list(map(analysis_inst.get_config, params["configs"]))

        # resolving of parameters that is required before ArrayFunctions etc. can be initialized
        params = cls.resolve_param_values_pre_init(params)

        # check if shifts are already known
        if params.get("known_shifts", None) is None:
            logger_dev.debug(f"{cls.task_family}: shifts unknown")

            # initialize ArrayFunctions etc. and collect known shifts
            shifts = params["known_shifts"] = TaskShifts()
            params = cls.resolve_instances(params, shifts)
            # store the filled shifts object again, regardless of the dict returned by resolve_instances
            params["known_shifts"] = shifts

        # resolving of parameters that can only be performed after ArrayFunction initialization
        params = cls.resolve_param_values_post_init(params)

        # resolving of shifts
        params = cls.resolve_shifts(params)

    return params
@classmethod
def resolve_instances(cls, params: dict[str, Any], shifts: TaskShifts) -> dict[str, Any]:
    """
    Build the array function instances.

    For single-config/dataset tasks, resolve_instances is implemented by mixin classes such as the
    ProducersMixin. This base implementation first collects the shifts of this task through
    :py:meth:`get_known_shifts`. When :py:attr:`resolution_task_cls` is set, it then delegates to
    ``resolve_instances`` and ``get_known_shifts`` of that class: once for single config tasks, and once per
    config instance for multi config tasks. The delegation always works on copies, so the upstream resolution
    does not modify *params*. If resolve_instances needs to be called for other combinations of parameters
    (e.g. per dataset), it can be overwritten by the task class.

    :param params: Dictionary of task parameters.
    :param shifts: Collection of local and global shifts, updated in place.
    :return: Updated dictionary of task parameters.
    """
    cls.get_known_shifts(params, shifts)

    if not cls.resolution_task_cls:
        params["known_shifts"] = shifts
        return params

    logger_dev.debug(
        f"{cls.task_family}: uses ConfigTask.resolve_instances base implementation; "
        f"resolution_task_cls was defined as {cls.resolution_task_cls}; ",
    )

    # base implementation for ConfigTasks that do not define any datasets, needed for e.g.
    # MergeShiftedHistograms; build one parameter copy per resolution step
    if cls.has_single_config():
        param_copies = [params.copy()]
    else:
        param_copies = [
            {**params, "config_inst": config_inst, "config": config_inst.name}
            for config_inst in params["config_insts"]
        ]

    for _params in param_copies:
        _params = cls.resolution_task_cls.resolve_instances(_params, shifts)
        cls.resolution_task_cls.get_known_shifts(_params, shifts)

    params["known_shifts"] = shifts
    return params
@classmethod
def resolve_param_values_pre_init(cls, params: dict[str, Any]) -> dict[str, Any]:
    """
    Hook to resolve parameters before the array function instances are initialized. This base implementation
    returns *params* unchanged.

    :param params: Dictionary of task parameters.
    :return: Updated dictionary of task parameters.
    """
    return params
@classmethod
def resolve_param_values_post_init(cls, params: dict[str, Any]) -> dict[str, Any]:
    """
    Hook to resolve parameters after the array function instances are initialized. This base implementation
    returns *params* unchanged.

    :param params: Dictionary of task parameters.
    :return: Updated dictionary of task parameters.
    """
    return params
@classmethod
def resolve_shifts(cls, params: dict[str, Any]) -> dict[str, Any]:
    """
    Hook to resolve shifts, called by :py:meth:`resolve_param_values` after all other parameters are resolved.
    This base implementation returns *params* unchanged.

    :param params: Dictionary of task parameters.
    :return: Updated dictionary of task parameters.
    """
    return params
@classmethod
def get_known_shifts(
    cls,
    params: dict[str, Any],
    shifts: TaskShifts,
) -> None:
    """
    Hook to extend the *local* and *upstream* fields of *shifts* with shifts implemented by this task and with
    dependent shifts implemented by upstream tasks. This base implementation adds nothing.

    NOTE(review): despite the ``None`` annotation, this base implementation returns *params*, while callers in
    this module ignore the return value; confirm whether subclasses rely on it.

    :param params: Dictionary of task parameters.
    :param shifts: TaskShifts object to adjust.
    """
    return params
# task class whose resolve_instances() and get_known_shifts() are called by the base implementation of
# resolve_instances(); when None, that base implementation only collects the shifts of this task itself
resolution_task_cls = None
@classmethod
def req_params(cls, inst: law.Task, *args, **kwargs) -> dict[str, Any]:
    """
    Returns the parameters for requiring this task class from *inst*.

    ``known_shifts`` is excluded from the regularly forwarded parameters. It is therefore added manually when
    *inst* is a workflow of exactly this class with known shifts, so that shifts are shared between workflows
    and their branches.
    """
    params = super().req_params(inst, *args, **kwargs)

    is_own_workflow = isinstance(inst, law.BaseWorkflow) and inst.__class__ == cls
    if is_own_workflow:
        known_shifts = getattr(inst, "known_shifts", None)
        if known_shifts:
            params["known_shifts"] = known_shifts

    return params
@classmethod
def _multi_sequence_repr(
    cls,
    values: Sequence[str] | Sequence[Sequence[str]],
    sort: bool = False,
) -> str:
    """
    Returns a string representation of a singly (single config) or doubly (multi config) nested sequence of
    string *values*.

    In the former case, the values are sorted if *sort* is *True* and formed into a representation. In the
    latter case, the behavior depends on whether the values are identical between configs. If they are, they
    are handled as a single sequence. Otherwise, the representation consists of the number of values per config
    and a hash of the combined, flat values. Empty sequences are represented as ``"none"``.

    :param values: Nested values.
    :param sort: Whether to sort the values.
    :return: A string representation.
    """
    # empty case
    if not values:
        return "none"

    # optional sorting helper
    maybe_sort = (lambda vals: sorted(vals)) if sort else (lambda vals: vals)

    # helper to perform the single representation, assuming already sorted values
    def single_repr(values: Sequence[str]) -> str:
        # empty inner sequences are represented like the empty outer case
        if not values:
            return "none"
        if len(values) == 1:
            return values[0]
        return f"{len(values)}_{law.util.create_hash(values)}"

    # single case
    if not isinstance(values[0], (list, tuple)):
        return single_repr(maybe_sort(values))

    # multi case with a single sequence
    if len(values) == 1:
        return single_repr(maybe_sort(values[0]))

    # multi case with identical sequences, compared as lists so that tuples and lists with equal content match
    values = [maybe_sort(_values) for _values in values]
    if all(list(_values) == list(values[0]) for _values in values[1:]):
        return single_repr(values[0])

    # build full representation; chaining (instead of sum(values, [])) also supports tuple sequences
    _repr = "_".join(str(len(_values)) for _values in values)
    all_values = list(itertools.chain.from_iterable(values))
    return _repr + f"_{law.util.create_hash(all_values)}"
@classmethod
def broadcast_to_configs(cls, value: Any, name: str, n_config_insts: int) -> tuple[Any]:
    """
    Broadcasts *value* to a tuple with one entry per config.

    Anything that is not a non-empty tuple is treated as a single value. A single value is repeated
    *n_config_insts* times, and a tuple with exactly *n_config_insts* entries is returned as is.

    :param value: The value or per-config tuple of values.
    :param name: Name of the value, used in the error message.
    :param n_config_insts: The number of configs.
    :raises ValueError: When the tuple length is neither 1 nor *n_config_insts*.
    :return: Tuple with one entry per config.
    """
    values = value if isinstance(value, tuple) and value else (value,)
    if len(values) == 1:
        return values * n_config_insts
    if len(values) == n_config_insts:
        return values
    raise ValueError(
        f"number of {name} sequences ({len(values)}) does not match number of configs "
        f"({n_config_insts})",
    )
@classmethod
def _get_default_version(
    cls,
    inst: AnalysisTask,
    params: dict[str, Any],
    keys: law.util.InsertableDict,
) -> str | None:
    """
    Looks up a default version in the ``versions`` auxiliary entry of the config of single config tasks. Falls
    back to the parent implementation when nothing is found.
    """
    if isinstance(inst, ConfigTask) and inst.has_single_config():
        config_version = cls._dfs_key_lookup(keys, inst.config_inst.x("versions", {}))
        if config_version:
            return config_version

    return super()._get_default_version(inst, params, keys)
@classmethod
def get_config_lookup_keys(
    cls,
    inst_or_params: ConfigTask | dict[str, Any],
) -> law.util.InsertableDict:
    """
    Returns the lookup keys of the parent class, extended by the config name right before the task family.
    Empty or unset config names are not inserted.

    :param inst_or_params: A task instance or a dictionary of task parameters.
    :return: The insertable dictionary of lookup keys.
    """
    keys = super().get_config_lookup_keys(inst_or_params)

    # add the config name in front of the task family
    config = (
        inst_or_params.get("config")
        if isinstance(inst_or_params, dict)
        else getattr(inst_or_params, "config", None)
    )
    if config not in {law.NO_STR, None, ""}:
        keys.insert_before("task_family", "config", config)

    return keys
@classmethod
def get_array_function_dict(cls, params: dict[str, Any]) -> dict[str, Any]:
    """
    Returns keyword arguments for array function initialization. Extends those of the parent class with the
    ``config_inst``, taken from *params* or looked up by name through the analysis instance. Only allowed for
    single config tasks.
    """
    cls.ensure_single_config(True, attr="get_array_function_dict")

    kwargs = super().get_array_function_dict(params)

    if "config_inst" in params:
        kwargs["config_inst"] = params["config_inst"]
    elif "config" in params and "analysis_inst" in kwargs:
        analysis_inst = kwargs["analysis_inst"]
        kwargs["config_inst"] = analysis_inst.get_config(params["config"])

    return kwargs
def __init__(self, *args, **kwargs) -> None:
    super().__init__(*args, **kwargs)

    # resolve the config name(s) of this task into config instances
    config_names = [self.config] if self.has_single_config() else self.configs
    self.config_insts = [self.analysis_inst.get_config(name) for name in config_names]

    # single config tasks additionally get a direct reference
    if self.has_single_config():
        self.config_inst = self.config_insts[0]

@property
def config_repr(self) -> str:
    """
    Names of all config instances of this task, joined by ``"__"``.
    """
    return "__".join(inst.name for inst in self.config_insts)
def store_parts(self) -> law.util.InsertableDict:
    """
    Returns the store parts of the parent class, extended by the config representation, which is inserted
    directly after the task family.
    """
    parts = super().store_parts()
    parts.insert_after("task_family", "config", self.config_repr)
    return parts
def find_keep_columns(self, collection: ColumnCollection) -> set[Route]:
    """
    Returns a set of :py:class:`Route` objects describing columns that should be kept given a type of column
    *collection*. This base implementation only handles :py:attr:`ColumnCollection.MANDATORY_COFFEA`.

    :param collection: The collection to return.
    :return: A set of :py:class:`Route` objects.
    """
    if collection != ColumnCollection.MANDATORY_COFFEA:
        return set()
    return {Route(column) for column in mandatory_coffea_columns}
def _expand_keep_column(
    self,
    column: ColumnCollection | Route | str | Sequence[str | int | slice | type(Ellipsis) | list | tuple],
) -> set[Route]:
    """
    Expands a *column* into a set of :py:class:`Route` objects.

    Collections are expanded through :py:meth:`find_keep_columns`, strings are brace-expanded, and any other
    value is passed to :py:class:`Route` as is.

    :param column: The column to expand.
    :return: A set of :py:class:`Route` objects.
    """
    if isinstance(column, ColumnCollection):
        return self.find_keep_columns(column)
    if not isinstance(column, str):
        return {Route(column)}
    return {Route(expanded) for expanded in law.util.brace_expand(column)}
[docs]
class ShiftTask(ConfigTask):
    """
    Config task that additionally handles a systematic *shift*, which is split into a global shift and a local
    shift (the one implemented by this task itself) in :py:meth:`resolve_shifts`.
    """

    shift = luigi.Parameter(
        default="nominal",
        description="name of a systematic shift to apply; must fulfill order.Shift naming rules; "
        "default: 'nominal'",
    )
    # shift implemented by this task itself, resolved from the requested shift when not given
    local_shift = luigi.Parameter(default=law.NO_STR)

    # skip passing local_shift to cli completion, req params and sandboxing
    exclude_params_index = {"local_shift"}
    exclude_params_req = {"local_shift"}
    exclude_params_sandbox = {"local_shift"}
    exclude_params_remote_workflow = {"local_shift"}

    # when False, resolve_shifts raises an exception if no shift is given
    allow_empty_shift = False
[docs]
@classmethod
def resolve_shifts(cls, params: dict[str, Any]) -> dict[str, Any]:
    """
    Resolves the requested ``shift`` into a global and a local shift, using ``known_shifts``:

    - when this task implements the shift (``known_shifts.local``), both are set to it,
    - when only upstream tasks implement it (``known_shifts.upstream``), the global shift is kept and the local
      shift becomes ``"nominal"``,
    - otherwise, both become ``"nominal"``.

    An explicitly given ``local_shift`` is kept, and the global shift is then the requested one. Shift instances
    per config are stored in ``global_shift_insts`` and ``local_shift_insts``, plus ``global_shift_inst`` and
    ``local_shift_inst`` for single config tasks. Configs that do not define a shift fall back to their nominal
    shift.

    :param params: Dictionary of task parameters.
    :raises Exception: When ``known_shifts`` is missing, or when no shift is given and
        :py:attr:`allow_empty_shift` is *False*.
    :raises ValueError: When the requested shift is unknown to all configs.
    :return: Updated dictionary of task parameters.
    """
    params = super().resolve_shifts(params)

    if "known_shifts" not in params:
        raise Exception(f"{cls.task_family}: known shifts should be resolved before calling 'resolve_shifts'")
    known_shifts = params["known_shifts"]

    # get configs
    # NOTE(review): config_insts is iterated below, so it presumably must be set once a shift is given — confirm
    config_insts = params.get("config_insts")

    # require that the shift is set and known
    if (requested_shift := params.get("shift")) in (None, law.NO_STR):
        if not cls.allow_empty_shift:
            raise Exception(f"no shift found in params: {params}")
        global_shift = local_shift = law.NO_STR
    else:
        # check if the shift is known to one of the configs
        shift_defined_in_config = False
        for config_inst in config_insts:
            if requested_shift not in config_inst.shifts:
                logger_dev.debug(f"shift {requested_shift} unknown to config {config_inst}")
            else:
                shift_defined_in_config = True
        if not shift_defined_in_config:
            raise ValueError(f"shift {requested_shift} unknown to all configs")

        # actual shift resolution: compare the requested shift to known ones
        # local_shift -> the requested shift if implemented by the task itself, else nominal
        # shift -> the requested shift if implemented by this task
        # or an upstream task (== global shift), else nominal
        global_shift = requested_shift
        if (local_shift := params.get("local_shift")) in {None, law.NO_STR}:
            # check cases
            if requested_shift in known_shifts.local:
                local_shift = requested_shift
            elif requested_shift in known_shifts.upstream:
                local_shift = "nominal"
            else:
                global_shift = "nominal"
                local_shift = "nominal"

    # store parameters
    params["shift"] = global_shift
    params["local_shift"] = local_shift

    # store references to shift instances
    if (
        params["shift"] != law.NO_STR and
        params["local_shift"] != law.NO_STR and
        (not params.get("global_shift_insts") or not params.get("local_shift_insts"))
    ):
        params["global_shift_insts"] = {}
        params["local_shift_insts"] = {}

        # configs not defining the shift fall back to their nominal shift
        get_shift_or_nominal = lambda config, shift: config.get_shift(shift, default=config.get_shift("nominal"))

        for config_inst in config_insts:
            params["global_shift_insts"][config_inst] = get_shift_or_nominal(config_inst, params["shift"])
            params["local_shift_insts"][config_inst] = get_shift_or_nominal(config_inst, params["local_shift"])

        if cls.has_single_config():
            config_inst = params["config_inst"]
            params["global_shift_inst"] = params["global_shift_insts"][config_inst]
            params["local_shift_inst"] = params["local_shift_insts"][config_inst]

    return params
@classmethod
def get_config_lookup_keys(
    cls,
    inst_or_params: ShiftTask | dict[str, Any],
) -> law.util.InsertableDict:
    """
    Returns the lookup keys of the parent class, extended by the (global) shift name. Empty or unset shift names
    are not added.

    :param inst_or_params: A task instance or a dictionary of task parameters.
    :return: The insertable dictionary of lookup keys.
    """
    keys = super().get_config_lookup_keys(inst_or_params)

    # add the (global) shift name
    shift = (
        inst_or_params.get("shift")
        if isinstance(inst_or_params, dict)
        else getattr(inst_or_params, "shift", None)
    )
    if shift not in (law.NO_STR, None, ""):
        keys["shift"] = shift

    return keys
def __init__(self, *args, **kwargs) -> None:
    super().__init__(*args, **kwargs)

    # shift instances per config, kept as None unless both shifts are set
    self.global_shift_insts = None
    self.local_shift_insts = None

    shifts_given = self.shift not in (None, law.NO_STR) and self.local_shift not in (None, law.NO_STR)
    if shifts_given:
        def get_shift_inst(config_inst, shift_name):
            # configs that do not know the shift fall back to nominal
            return config_inst.get_shift(shift_name if shift_name in config_inst.shifts else "nominal")

        self.global_shift_insts = {c: get_shift_inst(c, self.shift) for c in self.config_insts}
        self.local_shift_insts = {c: get_shift_inst(c, self.local_shift) for c in self.config_insts}

    # single config tasks additionally get direct references
    if self.has_single_config():
        if self.global_shift_insts:
            self.global_shift_inst = self.global_shift_insts[self.config_inst]
            self.local_shift_inst = self.local_shift_insts[self.config_inst]
        else:
            self.global_shift_inst = None
            self.local_shift_inst = None
def store_parts(self) -> law.util.InsertableDict:
    """
    Returns the store parts of the parent class. When shift instances are set, the shift name is inserted after
    the part named by :py:attr:`config_store_anchor`.
    """
    parts = super().store_parts()
    if not self.global_shift_insts:
        return parts
    parts.insert_after(self.config_store_anchor, "shift", self.shift)
    return parts
[docs]
class DatasetTask(ShiftTask):
    """
    Shift task that processes a single *dataset* of a single config.
    """

    # all dataset tasks are meant to work for a single config
    single_config = True

    dataset = luigi.Parameter(
        default=default_dataset,
        description=f"name of the dataset to process; default: '{default_dataset}'",
    )

    # merging of input files per branch, see file_merging_factor: non-integers disable merging, 0 merges all
    # files into one branch, and positive integers set the number of files per branch
    file_merging = None
@classmethod
def resolve_param_values_pre_init(cls, params: dict[str, Any]) -> dict[str, Any]:
    """
    Extends the parent pre-init resolution by looking up the ``dataset_inst`` from the ``config_inst`` and the
    ``dataset`` name, unless it is already present.

    :param params: Dictionary of task parameters.
    :return: Updated dictionary of task parameters.
    """
    params = super().resolve_param_values_pre_init(params)

    needs_dataset_inst = "dataset_inst" not in params and "config_inst" in params and "dataset" in params
    if needs_dataset_inst:
        params["dataset_inst"] = params["config_inst"].get_dataset(params["dataset"])

    return params
@classmethod
def get_known_shifts(
    cls,
    params: dict[str, Any],
    shifts: TaskShifts,
) -> None:
    """
    Extends the known shifts with those of the dataset. For data, all local and upstream shifts are removed. For
    simulation, the keys of the dataset info are added as upstream shifts.
    """
    super().get_known_shifts(params, shifts)

    dataset_inst = params.get("dataset_inst")
    if not dataset_inst:
        return

    if dataset_inst.is_data:
        # data is never shifted
        shifts.local.clear()
        shifts.upstream.clear()
        return

    # dataset variations count as upstream shifts
    shifts.upstream |= set(dataset_inst.info.keys())
@classmethod
def get_config_lookup_keys(
    cls,
    inst_or_params: DatasetTask | dict[str, Any],
) -> law.util.InsertableDict:
    """
    Returns the lookup keys of the parent class, extended by the dataset name right before the shift name.
    Empty or unset dataset names are not inserted.

    :param inst_or_params: A task instance or a dictionary of task parameters.
    :return: The insertable dictionary of lookup keys.
    """
    keys = super().get_config_lookup_keys(inst_or_params)

    # add the dataset name before the shift name
    dataset = (
        inst_or_params.get("dataset")
        if isinstance(inst_or_params, dict)
        else getattr(inst_or_params, "dataset", None)
    )
    if dataset not in {law.NO_STR, None, ""}:
        keys.insert_before("shift", "dataset", dataset)

    return keys
@classmethod
def get_array_function_dict(cls, params: dict[str, Any]) -> dict[str, Any]:
    """
    Extends the array function keyword arguments of the parent class with the ``dataset_inst``. It is taken from
    *params*, or otherwise looked up by name through the config instance.
    """
    kwargs = super().get_array_function_dict(params)

    if "dataset_inst" in params:
        kwargs["dataset_inst"] = params["dataset_inst"]
    elif "dataset" in params and "config_inst" in kwargs:
        config_inst = kwargs["config_inst"]
        kwargs["dataset_inst"] = config_inst.get_dataset(params["dataset"])

    return kwargs
def __init__(self, *args, **kwargs) -> None:
    super().__init__(*args, **kwargs)

    # store references to the dataset instance
    self.dataset_inst = self.config_inst.get_dataset(self.dataset)

    # pick the dataset info of the global shift when the dataset defines one, and the nominal one otherwise
    shift_inst = self.global_shift_inst
    info_key = "nominal"
    if shift_inst and shift_inst.name in self.dataset_inst.info:
        info_key = shift_inst.name
    self.dataset_info_inst = self.dataset_inst.get_info(info_key)
def store_parts(self) -> law.util.InsertableDict:
    """
    Returns the store parts of the parent class, with the dataset name inserted after the part named by
    :py:attr:`config_store_anchor`.
    """
    parts = super().store_parts()
    dataset_name = self.dataset_inst.name
    parts.insert_after(self.config_store_anchor, "dataset", dataset_name)
    return parts
@property
def file_merging_factor(self) -> int:
    """
    Returns the number of files that are handled in one branch.

    If the :py:attr:`file_merging` attribute is an integer, it is interpreted as the merging
    factor itself:
    - a positive value is returned as is;
    - zero means "merge all in one", so the number of files of the dataset is returned;
    - a negative value raises a :py:class:`ValueError`.

    Any non-integer value disables merging, i.e., a factor of 1 is returned. Consecutive merging
    steps are not handled yet.
    """
    n_files = self.dataset_info_inst.n_files
    merging = self.file_merging

    # no merging at all
    if not isinstance(merging, int):
        return 1

    if merging < 0:
        raise ValueError(f"invalid file_merging value {merging}")

    # zero means "merge all in one"
    return merging or n_files
def create_branch_map(self):
    """
    Defines the branch map for when this task is used as a workflow.

    The dataset's input file indices are split into consecutive chunks whose maximum size is
    given by :py:attr:`file_merging_factor`, and each chunk is assigned to one branch, numbered
    in chunk order. E.g. ``1 -> [3, 4, 5]`` means that branch 1 handles input file indices 3, 4
    and 5 simultaneously.
    """
    merging_factor = self.file_merging_factor
    total_files = self.dataset_info_inst.n_files

    # iter_chunks splits a list of length total_files into chunks of maximum size merging_factor
    branch_map = {}
    for branch, file_indices in enumerate(law.util.iter_chunks(total_files, merging_factor)):
        branch_map[branch] = file_indices
    return branch_map
class CommandTask(AnalysisTask):
    """
    Task providing helpers to work with shell commands.

    The command can be printed on the command line (``print_command`` parameter) or executed with
    error handling. Subclasses implement :py:meth:`build_command`. The default :py:meth:`run`
    performs these steps in order:
    - call the :py:meth:`pre_run_command` hook;
    - create all output directories;
    - build the command and execute it through :py:meth:`run_command`;
    - call the :py:meth:`post_run_command` hook.
    """

    print_command = law.CSVParameter(
        default=(),
        significant=False,
        description="print the command that this task would execute but do not run any task; this "
        "CSV parameter accepts a single integer value which sets the task recursion depth to also "
        "print the commands of required tasks (0 means non-recursive)",
    )
    custom_args = luigi.Parameter(
        default="",
        description="custom arguments that are forwarded to the underlying command; they might not "
        "be encoded into output file paths; empty default",
    )

    exclude_index = True
    exclude_params_req = {"custom_args"}
    interactive_params = AnalysisTask.interactive_params + ["print_command"]

    # when True, run_command executes in a fresh temporary directory unless a cwd is passed
    run_command_in_tmp = False

    def _print_command(self, args) -> None:
        """
        Prints the commands of this task and of its dependencies.

        Dependencies are walked in pre-order up to the recursion depth given by the first element
        of *args*.
        """
        max_depth = int(args[0])
        print(f"print task commands with max_depth {max_depth}")

        ind = law.task.interactive.ind
        for dep, _, depth in self.walk_deps(max_depth=max_depth, order="pre"):
            prefix = depth * ("|" + ind)
            print(prefix)
            print(f"{prefix}> {dep.repr(color=True)}")
            prefix += "|" + ind

            if not isinstance(dep, CommandTask):
                print(prefix + law.util.colored("not a CommandTask", "yellow"))
                continue

            # for workflows, show the command of the first branch
            line = law.util.colored("command", style="bright")
            if isinstance(dep, law.BaseWorkflow) and dep.is_workflow():
                dep = dep.as_branch(0)
                line += " (from branch {})".format(law.util.colored("0", "red"))
            line += ": "

            command = dep.build_command()
            if not command:
                line += law.util.colored("empty", "red")
            else:
                if isinstance(command, (list, tuple)):
                    command = law.util.quote_cmd(command)
                line += law.util.colored(command, "cyan")

            print(prefix + line)

    def build_command(self) -> str | list[str]:
        """
        Builds and returns the command to run.

        Must be implemented by subclasses.
        """
        raise NotImplementedError

    def touch_output_dirs(self) -> None:
        """
        Creates the parent directories of all (flattened) outputs, touching each directory URI once.
        """
        touched_uris = set()
        for target in law.util.flatten(self.output()):
            # determine the directory target to create
            if isinstance(target, law.SiblingFileCollection):
                directory = target.dir
            elif isinstance(target, law.FileSystemFileTarget):
                directory = target.parent
            else:
                continue

            if not directory:
                continue

            uri = directory.uri()
            if uri in touched_uris:
                continue
            directory.touch()
            touched_uris.add(uri)

    def run_command(self, cmd: str | list[str], optional: bool = False, **kwargs) -> subprocess.Popen:
        """
        Executes *cmd* in a bash shell, printing its output line by line.

        List or tuple commands are quoted and joined first. When no ``cwd`` is passed in *kwargs*
        and :py:attr:`run_command_in_tmp` is set, a temporary local directory is used as the
        working directory.

        :param cmd: Command string or sequence of command arguments.
        :param optional: When True, a non-zero exit code does not raise.
        :param kwargs: Forwarded to ``law.util.readable_popen``.
        :raises Exception: When the command fails and *optional* is False.
        :return: The process object.
        """
        if isinstance(cmd, (list, tuple)):
            cmd = law.util.quote_cmd(cmd)
        cmd = cmd.strip()

        if self.run_command_in_tmp and "cwd" not in kwargs:
            # keep tmp_dir referenced until the command finished
            tmp_dir = law.LocalDirectoryTarget(is_tmp=True)
            tmp_dir.touch()
            kwargs["cwd"] = tmp_dir.abspath
        self.publish_message("cwd: {}".format(kwargs.get("cwd", os.getcwd())))

        step_msg = "running '{}' ...".format(law.util.colored(cmd, "cyan"))
        with self.publish_step(step_msg):
            proc, output_lines = law.util.readable_popen(cmd, shell=True, executable="/bin/bash", **kwargs)
            for output_line in output_lines:
                print(output_line)

        failed = proc.returncode != 0
        if failed and not optional:
            raise Exception(f"command failed with exit code {proc.returncode}: {cmd}")

        return proc

    @law.decorator.log
    @law.decorator.notify
    @law.decorator.safe_output
    def run(self, **kwargs):
        """
        Default run implementation.

        Calls the pre-run hook, creates the output directories, builds the command and runs it
        (forwarding *kwargs* to :py:meth:`run_command`), then calls the post-run hook.
        """
        self.pre_run_command()
        self.touch_output_dirs()
        self.run_command(self.build_command(), **kwargs)
        self.post_run_command()

    def pre_run_command(self) -> None:
        """
        Hook called at the very beginning of :py:meth:`run`; does nothing by default.
        """

    def post_run_command(self) -> None:
        """
        Hook called at the very end of :py:meth:`run`; does nothing by default.
        """
def wrapper_factory(
    base_cls: law.task.base.Task,
    require_cls: AnalysisTask,
    enable: Sequence[str],
    cls_name: str | None = None,
    attributes: dict | None = None,
    docs: str | None = None,
) -> law.task.base.Register:
    """
    Factory function creating wrapper task classes, inheriting from *base_cls* and
    :py:class:`~law.task.base.WrapperTask`, that do nothing but require multiple instances of
    *require_cls*. Unless *cls_name* is defined, the name of the created class defaults to the name
    of *require_cls* plus "Wrapper". Additional *attributes* are added as class-level members when
    given.

    The instances of *require_cls* to be required in the
    :py:meth:`~.wrapper_factory.Wrapper.requires()` method can be controlled by task parameters.
    These parameters can be enabled through the string sequence *enable*, which currently accepts:

        - ``configs``, ``skip_configs``
        - ``shifts``, ``skip_shifts``
        - ``datasets``, ``skip_datasets``

    The ``skip_*`` features only take effect when the corresponding base feature (e.g.
    ``configs`` for ``skip_configs``) is enabled as well.

    This allows to easily build wrapper tasks that loop over (combinations of) parameters that are
    either defined in the analysis or config, which would otherwise lead to mostly redundant code.
    Example:

    .. code-block:: python

        class MyTask(DatasetTask):
            ...

        MyTaskWrapper = wrapper_factory(
            base_cls=ConfigTask,
            require_cls=MyTask,
            enable=["datasets", "skip_datasets"],
        )

        # this allows to run (e.g.)
        # law run MyTaskWrapper --datasets st_* --skip-datasets *_tbar

    When building the requirements, the full combinatorics of parameters is considered. However,
    certain conditions apply depending on enabled features. For instance, in order to use the
    "configs" feature (adding a parameter "--configs" to the created class, allowing to loop over a
    list of config instances known to an analysis), *require_cls* must be at least a
    :py:class:`ConfigTask` accepting "--config" (mind the singular form), whereas *base_cls* must
    explicitly not.

    :param base_cls: Base class for this wrapper. A tuple of base classes is accepted as well, in
        which case its first element is used for the compatibility checks.
    :param require_cls: :py:class:`~law.task.base.Task` class to be wrapped.
    :param enable: Enable these parameters to control the wrapped :py:class:`~law.task.base.Task`
        class instance. Currently allowed parameters are: "configs", "skip_configs", "shifts",
        "skip_shifts", "datasets", "skip_datasets".
    :param cls_name: Name of the wrapper class. If :py:attr:`None`, defaults to the name of
        *require_cls* + ``"Wrapper"``.
    :param attributes: Add these attributes as class-level members of the new
        :py:class:`~law.task.base.WrapperTask` class.
    :param docs: Manually set the documentation string ``__doc__`` of the new
        :py:class:`~law.task.base.WrapperTask` class.
    :raises ValueError: If a parameter provided with *enable* is not in the list of known
        parameters.
    :raises TypeError: If any parameter in *enable* is incompatible with *require_cls* or
        *base_cls*, or with the inheritance structure of corresponding classes.
    :return: The new :py:class:`~law.task.base.WrapperTask` class requiring instances of
        *require_cls*.

    Instances of the returned class raise a :py:class:`ValueError` during initialization in one
    case: ``configs``, ``shifts`` or ``datasets`` are enabled, but no matching objects are found
    in the analysis or the respective config, either before or after applying the ``skip_*``
    patterns.
    """
    # check known features
    known_features = [
        "configs", "skip_configs",
        "shifts", "skip_shifts",
        "datasets", "skip_datasets",
    ]
    for feature in enable:
        if feature not in known_features:
            raise ValueError(
                f"unknown enabled feature '{feature}', known features are "
                f"'{','.join(known_features)}'",
            )

    # treat base_cls as a tuple
    base_classes = law.util.make_tuple(base_cls)
    base_cls = base_classes[0]

    # define wrapper feature flags; skip features require their base feature to be enabled
    has_configs = "configs" in enable
    has_skip_configs = has_configs and "skip_configs" in enable
    has_shifts = "shifts" in enable
    has_skip_shifts = has_shifts and "skip_shifts" in enable
    has_datasets = "datasets" in enable
    has_skip_datasets = has_datasets and "skip_datasets" in enable

    # helper to check if enabled features are compatible with required and base class
    def check_class_compatibility(name, min_require_cls, max_base_cls):
        if not issubclass(require_cls, min_require_cls):
            raise TypeError(
                f"when the '{name}' feature is enabled, require_cls must inherit from {min_require_cls}, but "
                f"{require_cls} does not",
            )
        if issubclass(base_cls, min_require_cls):
            raise TypeError(
                f"when the '{name}' feature is enabled, base_cls must not inherit from {min_require_cls}, but "
                f"{base_cls} does",
            )
        if not issubclass(max_base_cls, base_cls):
            raise TypeError(
                f"when the '{name}' feature is enabled, base_cls must be a super class of {max_base_cls}, but "
                f"{base_cls} is not",
            )

    # check classes
    if has_configs:
        check_class_compatibility("configs", ConfigTask, AnalysisTask)
    if has_shifts:
        check_class_compatibility("shifts", ShiftTask, ConfigTask)
    if has_datasets:
        check_class_compatibility("datasets", DatasetTask, ShiftTask)

    # create the class
    class Wrapper(*base_classes, law.WrapperTask):

        exclude_params_repr_empty = set()
        exclude_params_req_set = set()

        if has_configs:
            configs = law.CSVParameter(
                default=(default_config,),
                description="names or name patterns of configs to use; can also be the key of a mapping defined in the "
                f"'config_groups' auxiliary data of the analysis; default: {default_config}",
                brace_expand=True,
            )
            exclude_params_req_set.add("configs")
        if has_skip_configs:
            skip_configs = law.CSVParameter(
                default=(),
                description="names or name patterns of configs to skip after evaluating --configs; can also be the key "
                "of a mapping defined in the 'config_groups' auxiliary data of the analysis; empty default",
                brace_expand=True,
            )
            exclude_params_repr_empty.add("skip_configs")
            exclude_params_req_set.add("skip_configs")
        if has_datasets:
            datasets = law.CSVParameter(
                default=("*",),
                description="names or name patterns of datasets to use; can also be the key of a mapping defined in "
                "the 'dataset_groups' auxiliary data of the corresponding config; default: ('*',)",
                brace_expand=True,
            )
            exclude_params_req_set.add("datasets")
        if has_skip_datasets:
            skip_datasets = law.CSVParameter(
                default=(),
                description="names or name patterns of datasets to skip after evaluating --datasets; can also be the "
                "key of a mapping defined in the 'dataset_groups' auxiliary data of the corresponding config; empty "
                "default",
                brace_expand=True,
            )
            exclude_params_repr_empty.add("skip_datasets")
            exclude_params_req_set.add("skip_datasets")
        if has_shifts:
            shifts = law.CSVParameter(
                default=("nominal",),
                description="names or name patterns of shifts to use; can also be the key of a mapping defined in the "
                "'shift_groups' auxiliary data of the corresponding config; default: ('nominal',)",
                brace_expand=True,
            )
            exclude_params_req_set.add("shifts")
        if has_skip_shifts:
            skip_shifts = law.CSVParameter(
                default=(),
                description="names or name patterns of shifts to skip after evaluating --shifts; can also be the key "
                "of a mapping defined in the 'shift_groups' auxiliary data of the corresponding config; empty default",
                brace_expand=True,
            )
            exclude_params_repr_empty.add("skip_shifts")
            exclude_params_req_set.add("skip_shifts")

        def __init__(self, *args, **kwargs) -> None:
            super().__init__(*args, **kwargs)

            # store wrapper flags (the non-skip ones are falsy when the parameter value is empty)
            self.wrapper_require_cls = require_cls
            self.wrapper_has_configs = has_configs and self.configs
            self.wrapper_has_skip_configs = has_skip_configs
            self.wrapper_has_shifts = has_shifts and self.shifts
            self.wrapper_has_skip_shifts = has_skip_shifts
            self.wrapper_has_datasets = has_datasets and self.datasets
            self.wrapper_has_skip_datasets = has_skip_datasets

            # store wrapped fields, in the order used for the tuples of the parameter space
            self.wrapper_fields = [
                attr for attr in ["config", "shift", "dataset"]
                if getattr(self, f"wrapper_has_{attr}s")
            ]

            # build the parameter space
            self.wrapper_parameters = self._build_wrapper_parameters()

        def _build_wrapper_parameters(self):
            # collect a list of tuples whose order corresponds to wrapper_fields
            params = []

            # get the target config instances
            if self.wrapper_has_configs:
                configs = self.find_config_objects(
                    names=self.configs,
                    container=self.analysis_inst,
                    object_cls=od.Config,
                    groups_str="config_groups",
                )
                if not configs:
                    raise ValueError(f"no configs found in analysis {self.analysis_inst} matching {self.configs}")
                if self.wrapper_has_skip_configs:
                    skip_configs = self.find_config_objects(
                        names=self.skip_configs,
                        container=self.analysis_inst,
                        object_cls=od.Config,
                        groups_str="config_groups",
                    )
                    configs = [c for c in configs if c not in skip_configs]
                    if not configs:
                        raise ValueError(
                            f"no configs found in analysis {self.analysis_inst} after skipping {self.skip_configs}",
                        )
                config_insts = list(map(self.analysis_inst.get_config, sorted(configs)))
            else:
                config_insts = [self.config_inst]

            # for the remaining fields, build the full combinatorics per config_inst
            for config_inst in config_insts:
                # sequences for building combinatorics
                prod_sequences = []
                if self.wrapper_has_configs:
                    prod_sequences.append([config_inst.name])

                # find all shifts
                if self.wrapper_has_shifts:
                    shifts = self.find_config_objects(
                        names=self.shifts,
                        container=config_inst,
                        object_cls=od.Shift,
                        groups_str="shift_groups",
                    )
                    if not shifts:
                        raise ValueError(f"no shifts found in config {config_inst} matching {self.shifts}")
                    if self.wrapper_has_skip_shifts:
                        skip_shifts = self.find_config_objects(
                            names=self.skip_shifts,
                            container=config_inst,
                            object_cls=od.Shift,
                            groups_str="shift_groups",
                        )
                        shifts = [s for s in shifts if s not in skip_shifts]
                    if not shifts:
                        raise ValueError(f"no shifts found in config {config_inst} after skipping {self.skip_shifts}")
                    # move "nominal" to the front if present
                    shifts = sorted(shifts)
                    if "nominal" in shifts:
                        shifts.insert(0, shifts.pop(shifts.index("nominal")))
                    prod_sequences.append(shifts)

                # find all datasets
                if self.wrapper_has_datasets:
                    datasets = self.find_config_objects(
                        names=self.datasets,
                        container=config_inst,
                        object_cls=od.Dataset,
                        groups_str="dataset_groups",
                    )
                    if not datasets:
                        raise ValueError(f"no datasets found in config {config_inst} matching {self.datasets}")
                    if self.wrapper_has_skip_datasets:
                        skip_datasets = self.find_config_objects(
                            names=self.skip_datasets,
                            container=config_inst,
                            object_cls=od.Dataset,
                            groups_str="dataset_groups",
                        )
                        datasets = [d for d in datasets if d not in skip_datasets]
                        if not datasets:
                            raise ValueError(
                                f"no datasets found in config {config_inst} after skipping {self.skip_datasets}",
                            )
                    prod_sequences.append(sorted(datasets))

                # add the full combinatorics
                params.extend(itertools.product(*prod_sequences))

            return params

        def requires(self) -> dict:
            # build all requirements based on the parameter space
            reqs = {}
            # tasks are hashable, so track already added requirements in a set instead of
            # scanning reqs.values() for every parameter combination
            added_reqs = set()
            for values in self.wrapper_parameters:
                params = dict(zip(self.wrapper_fields, values))

                # allow custom checks and updates
                params = self.update_wrapper_params(params)
                if not params:
                    continue

                # add the requirement if not present yet
                req = self.wrapper_require_cls.req(self, **params)
                if req not in added_reqs:
                    added_reqs.add(req)
                    reqs[values] = req

            return reqs

        def update_wrapper_params(self, params):
            return params

        # add additional class-level members; at class scope, locals() is the namespace that is
        # passed to the metaclass, so members added here are processed like regular definitions
        if attributes:
            locals().update(attributes)

    # overwrite __module__ to point to the module of the caller; the caller frame's globals are
    # used since inspect.stack() is expensive (it collects source context for all frames) and
    # inspect.getmodule() may return None (e.g. for code executed via exec)
    caller_frame = getattr(inspect.currentframe(), "f_back", None)
    if caller_frame is not None:
        Wrapper.__module__ = caller_frame.f_globals.get("__name__", Wrapper.__module__)
    del caller_frame

    # overwrite __name__ (and __qualname__ to keep both consistent)
    Wrapper.__name__ = cls_name or f"{require_cls.__name__}Wrapper"
    Wrapper.__qualname__ = Wrapper.__name__

    # set docs
    if docs:
        Wrapper.__doc__ = docs

    return Wrapper