Source code for columnflow.tasks.framework.base

# coding: utf-8

"""
Generic tools and base tasks that are defined along typical objects in an analysis.
"""

from __future__ import annotations

import os
import enum
import importlib
import itertools
import inspect
import functools
import collections
import copy

import luigi
import law
import order as od

from columnflow.columnar_util import mandatory_coffea_columns, Route, ColumnCollection
from columnflow.util import is_regex, DotDict
from columnflow.types import Sequence, Callable, Any, T


logger = law.logger.get_logger(__name__)

# default analysis and config related objects
default_analysis = law.config.get_expanded("analysis", "default_analysis")
default_config = law.config.get_expanded("analysis", "default_config")
default_dataset = law.config.get_expanded("analysis", "default_dataset")

# placeholder to denote a default value that is resolved dynamically
RESOLVE_DEFAULT = "DEFAULT"


[docs]class Requirements(DotDict): """General class for requirements of different tasks. Can be initialized with other :py:class:`~columnflow.util.DotDict` instances and additional keyword arguments ``kwargs``, which are added. """ def __init__(self, *others, **kwargs): super().__init__() # add others and kwargs for reqs in others + (kwargs,): self.update(reqs)
[docs]class BaseTask(law.Task): task_namespace = law.config.get_expanded("analysis", "cf_task_namespace", "cf") # container for upstream requirements for convenience reqs = Requirements()
[docs]class OutputLocation(enum.Enum): """ Output location flag. """ config = "config" local = "local" wlcg = "wlcg" wlcg_mirrored = "wlcg_mirrored"
[docs]class AnalysisTask(BaseTask, law.SandboxTask): analysis = luigi.Parameter( default=default_analysis, description=f"name of the analysis; default: '{default_analysis}'", ) version = luigi.Parameter( description="mandatory version that is encoded into output paths", ) notify_slack = law.slack.NotifySlackParameter(significant=False) notify_mattermost = law.mattermost.NotifyMattermostParameter(significant=False) notify_custom = law.NotifyCustomParameter(significant=False) allow_empty_sandbox = True sandbox = None message_cache_size = 25 local_workflow_require_branches = False output_collection_cls = law.SiblingFileCollection # defaults for targets default_store = "$CF_STORE_LOCAL" default_wlcg_fs = law.config.get_expanded("target", "default_wlcg_fs", "wlcg_fs") default_output_location = "config" exclude_params_index = {"user"} exclude_params_req = {"user", "notify_slack", "notify_mattermost", "notify_custom"} exclude_params_repr = {"user", "notify_slack", "notify_mattermost", "notify_custom"} exclude_params_branch = {"user"} exclude_params_workflow = {"user", "notify_slack", "notify_mattermost", "notify_custom"} # cached and parsed sections of the law config for faster lookup _cfg_outputs_dict = None _cfg_versions_dict = None _cfg_resources_dict = None
[docs] @classmethod def modify_param_values(cls, params: dict) -> dict: params = super().modify_param_values(params) params = cls.resolve_param_values(params) return params
[docs] @classmethod def resolve_param_values(cls, params: dict) -> dict: # store a reference to the analysis inst if "analysis_inst" not in params and "analysis" in params: params["analysis_inst"] = cls.get_analysis_inst(params["analysis"]) return params
[docs] @classmethod def get_analysis_inst(cls, analysis: str) -> od.Analysis: # prepare names if "." not in analysis: raise ValueError(f"invalid analysis format: {analysis}") module_id, name = analysis.rsplit(".", 1) # import the module try: mod = importlib.import_module(module_id) except ImportError as e: raise ImportError(f"cannot import analysis module {module_id}: {e}") # get the analysis instance analysis_inst = getattr(mod, name, None) if analysis_inst is None: raise Exception(f"module {module_id} does not contain analysis instance {name}") return analysis_inst
[docs] @classmethod def req_params(cls, inst: AnalysisTask, **kwargs) -> dict: """ Returns parameters that are jointly defined in this class and another task instance of some other class. The parameters are used when calling ``Task.req(self)``. """ # always prefer certain parameters given as task family parameters (--TaskFamily-parameter) _prefer_cli = law.util.make_set(kwargs.get("_prefer_cli", [])) | { "version", "workflow", "job_workers", "poll_interval", "walltime", "max_runtime", "retries", "acceptance", "tolerance", "parallel_jobs", "shuffle_jobs", "htcondor_cpus", "htcondor_gpus", "htcondor_memory", "htcondor_disk", "htcondor_pool", "pilot", } kwargs["_prefer_cli"] = _prefer_cli # build the params params = super().req_params(inst, **kwargs) # when not explicitly set in kwargs and no global value was defined on the cli for the task # family, evaluate and use the default value if ( isinstance(getattr(cls, "version", None), luigi.Parameter) and "version" not in kwargs and not law.parser.global_cmdline_values().get(f"{cls.task_family}_version") and cls.task_family != law.parser.root_task().task_family ): default_version = cls.get_default_version(inst, params) if default_version and default_version != law.NO_STR: params["version"] = default_version return params
@classmethod def _structure_cfg_items(cls, items: list[tuple[str, Any]]) -> dict: if not items: return {} # apply brace expansion to keys items = sum(( [(_key, value) for _key in law.util.brace_expand(key)] for key, value in items ), []) # breakup keys at double underscores and create a nested dictionary items_dict = {} for key, value in items: if not value: continue d = items_dict parts = key.split("__") for i, part in enumerate(parts): if i < len(parts) - 1: # fill intermediate structure if part not in d: d[part] = {} elif not isinstance(d[part], dict): d[part] = {"*": d[part]} d = d[part] else: # assign value to the last nesting level if part in d and isinstance(d[part], dict): d[part]["*"] = value else: d[part] = value return items_dict @classmethod def _get_cfg_outputs_dict(cls): if cls._cfg_outputs_dict is None and law.config.has_section("outputs"): # collect config item pairs skip_keys = {"wlcg_file_systems", "lfn_sources"} items = [ (key, law.config.get_expanded("outputs", key, None, split_csv=True)) for key, value in law.config.items("outputs") if value and key not in skip_keys ] cls._cfg_outputs_dict = cls._structure_cfg_items(items) return cls._cfg_outputs_dict @classmethod def _get_cfg_versions_dict(cls): if cls._cfg_versions_dict is None and law.config.has_section("versions"): # collect config item pairs items = [ (key, value) for key, value in law.config.items("versions") if value ] cls._cfg_versions_dict = cls._structure_cfg_items(items) return cls._cfg_versions_dict @classmethod def _get_cfg_resources_dict(cls): if cls._cfg_resources_dict is None and law.config.has_section("resources"): # helper to split resource values into key-value pairs themselves def parse(key: str, value: str) -> tuple[str, list[tuple[str, Any]]]: params = [] for part in value.split(","): part = part.strip() if not part: continue if "=" not in part: logger.warning_once( f"invalid_resource_{key}", f"resource for key {key} contains invalid instruction {part}, skipping", ) continue param, value = (s.strip() for s in part.split("=", 1)) params.append((param, value)) return key, params # collect config item pairs items = [ parse(key, value) for key, value in law.config.items("resources") if value ] cls._cfg_resources_dict = cls._structure_cfg_items(items) return cls._cfg_resources_dict
[docs] @classmethod def get_default_version(cls, inst: AnalysisTask, params: dict[str, Any]) -> str | None: """ Determines the default version for instances of *this* task class when created through :py:meth:`req` from another task *inst* given parameters *params*. :param inst: The task instance from which *this* task should be created via :py:meth:`req`. :param params: The parameters that are passed to the task instance. :return: The default version, or *None* if no default version can be defined. """ # get different attributes by which the default version might be looked up keys = cls.get_config_lookup_keys(params) # forward to lookup implementation version = cls._get_default_version(inst, params, keys) # after a version is found, it can still be an exectuable taking the same arguments return version(cls, inst, params) if callable(version) else version
@classmethod def _get_default_version( cls, inst: AnalysisTask, params: dict[str, Any], keys: law.util.InsertableDict, ) -> str | None: # try to lookup the version in the analysis's auxiliary data analysis_inst = params.get("analysis_inst") or getattr(inst, "analysis_inst", None) if analysis_inst: version = cls._dfs_key_lookup(keys, analysis_inst.x("versions", {})) if version: return version # try to find it in the analysis section in the law config if law.config.has_section("versions"): versions_dict = cls._get_cfg_versions_dict() if versions_dict: version = cls._dfs_key_lookup(keys, versions_dict) if version: return version # no default version found return None
[docs] @classmethod def get_config_lookup_keys( cls, inst_or_params: AnalysisTask | dict[str, Any], ) -> law.util.InsertiableDict: """ Returns a dictionary with keys that can be used to lookup state specific values in a config or dictionary, such as default task versions or output locations. :param inst_or_params: The tasks instance or its parameters. :return: A dictionary with keys that can be used for nested lookup. """ keys = law.util.InsertableDict() get = ( inst_or_params.get if isinstance(inst_or_params, dict) else lambda attr: (getattr(inst_or_params, attr, None)) ) # add the analysis name analysis = get("analysis") if analysis not in {law.NO_STR, None, ""}: keys["analysis"] = analysis # add the task family keys["task_family"] = cls.task_family return keys
@classmethod def _dfs_key_lookup( cls, keys: law.util.InsertableDict[str, str] | Sequence[str], nested_dict: dict[str, Any], empty_value: Any = None, ) -> str | Callable | None: # opinionated nested dictionary lookup alongside in ordered sequence of (optional) keys, # that allows for patterns in the keys and, interpreting the nested dict as a tree, finds # matches in a depth-first (dfs) manner if not nested_dict: return empty_value # the keys to use for the lookup are the values of the keys dict keys = collections.deque(keys.values() if isinstance(keys, dict) else keys) # start tree traversal using a queue lookup consisting of names and values of tree nodes, # as well as the remaining keys (as a deferred function) to compare for that particular path lookup = collections.deque([tpl + ((lambda: keys.copy()),) for tpl in nested_dict.items()]) while lookup: pattern, obj, keys_func = lookup.popleft() # create the copy of comparison keys on demand # (the original sequence is living once on the previous stack until now) _keys = keys_func() # check if the pattern matches any key regex = is_regex(pattern) while _keys: key = _keys.popleft() if law.util.multi_match(key, pattern, regex=regex): # when obj is not a dict, we found the value if not isinstance(obj, dict): return obj # go one level deeper and stop the current iteration keys_func = (lambda _keys: (lambda: _keys.copy()))(_keys) lookup.extendleft(tpl + (keys_func,) for tpl in reversed(obj.items())) break # at this point, no value could be found return empty_value
[docs] @classmethod def get_known_shifts(cls, config_inst: od.Config, params: dict) -> tuple[set[str], set[str]]: """ Returns two sets of shifts in a tuple: shifts implemented by _this_ task, and dependent shifts that are implemented by upstream tasks. """ # get shifts from upstream dependencies, consider both their own and upstream shifts as one upstream_shifts = set() for req in cls.reqs.values(): upstream_shifts |= set.union(*(req.get_known_shifts(config_inst, params) or (set(),))) return set(), upstream_shifts
[docs] @classmethod def get_array_function_kwargs( cls, task: AnalysisTask | None = None, **params, ) -> dict[str, Any]: if task: analysis_inst = task.analysis_inst elif "analysis_inst" in params: analysis_inst = params["analysis_inst"] else: analysis_inst = cls.get_analysis_inst(params["analysis"]) return { "task": task, "analysis_inst": analysis_inst, }
[docs] @classmethod def get_calibrator_kwargs(cls, *args, **kwargs) -> dict[str, Any]: # implemented here only for simplified mro control return cls.get_array_function_kwargs(*args, **kwargs)
[docs] @classmethod def get_selector_kwargs(cls, *args, **kwargs) -> dict[str, Any]: # implemented here only for simplified mro control return cls.get_array_function_kwargs(*args, **kwargs)
[docs] @classmethod def get_producer_kwargs(cls, *args, **kwargs) -> dict[str, Any]: # implemented here only for simplified mro control return cls.get_array_function_kwargs(*args, **kwargs)
[docs] @classmethod def get_weight_producer_kwargs(cls, *args, **kwargs) -> dict[str, Any]: # implemented here only for simplified mro control return cls.get_array_function_kwargs(*args, **kwargs)
[docs] @classmethod def find_config_objects( cls, names: str | Sequence[str] | set[str], container: od.UniqueObject, object_cls: od.UniqueObjectMeta, object_groups: dict[str, list] | None = None, accept_patterns: bool = True, deep: bool = False, strict: bool = False, ) -> list[str]: """ Returns all names of objects of type *object_cls* known to a *container* (e.g. :py:class:`od.Analysis` or :py:class:`od.Config`) that match *names*. A name can also be a pattern to match if *accept_patterns* is *True*, or, when given, the key of a mapping *object_group* that matches group names to object names. When *deep* is *True* the lookup of objects in the *container* is recursive. When *strict* is *True*, an error is raised if no matches are found for any of the *names*. Example: .. code-block:: python find_config_objects(["st_tchannel_*"], config_inst, od.Dataset) # -> ["st_tchannel_t", "st_tchannel_tbar"] """ singular = object_cls.cls_name_singular plural = object_cls.cls_name_plural _cache = {} def get_all_object_names(): if "all_object_names" not in _cache: if deep: _cache["all_object_names"] = { obj.name for obj, _, _ in getattr(container, f"walk_{plural}")() } else: _cache["all_object_names"] = set(getattr(container, plural).names()) return _cache["all_object_names"] def has_obj(name): if "has_obj_func" not in _cache: kwargs = {} if object_cls in container._deep_child_classes: kwargs["deep"] = deep _cache["has_obj_func"] = functools.partial( getattr(container, f"has_{singular}"), **kwargs, ) return _cache["has_obj_func"](name) object_names = [] lookup = law.util.make_list(names) missing = set() while lookup: name = lookup.pop(0) if has_obj(name): # known object object_names.append(name) elif object_groups and name in object_groups: # a key in the object group dict lookup.extend(list(object_groups[name])) elif accept_patterns: # must eventually be a pattern, perform an object traversal found = [] for _name in sorted(get_all_object_names()): if law.util.multi_match(_name, name): found.append(_name) if not found: missing.add(name) object_names.extend(found) if missing and strict: missing_str = ",".join(sorted(missing)) raise ValueError(f"names/patterns did not yield any matches: {missing_str}") return law.util.make_unique(object_names)
[docs] @classmethod def resolve_config_default( cls, task_params: dict[str, Any], param: str | tuple[str] | None, container: str | od.AuxDataMixin = "config_inst", default_str: str | None = None, multiple: bool = False, ) -> str | tuple | Any | None: """ Resolves a given parameter value *param*, checks if it should be placed with a default value when empty, and in this case, does the actual default value resolution. This resolution is triggered only in case *param* refers to :py:attr:`RESOLVE_DEFAULT`, a 1-tuple containing this attribute, or *None*, If so, the default is identified via the *default_str* from an :py:class:`order.AuxDataMixin` *container* and points to an auxiliary that can be either a string or a function. In the latter case, it is called with the task class, the container instance, and all task parameters. Note that when no *container* is given, *param* is returned unchanged. When *multiple* is *True*, a tuple is returned. If *multiple* is *False* and the resolved parameter is an iterable, the first entry is returned. Example: .. code-block:: python def resolve_param_values(params): params["producer"] = AnalysisTask.resolve_config_default( params, params.get("producer"), container=params["config_inst"] default_str="default_producer", multiple=True, ) config_inst = od.Config( id=0, name="my_config", aux={"default_producer": ["my_producer_1", "my_producer_2"]}, ) params = { "config_inst": config_inst, "producer": RESOLVE_DEFAULT, } resolve_param_values(params) # sets params["producer"] to ("my_producer_1", "my_producer_2") params = { "config_inst": config_inst, "producer": "some_other_producer", } resolve_param_values(params) # sets params["producer"] to "some_other_producer" Example where the default points to a function: .. code-block:: python def resolve_param_values(params): params["ml_model"] = AnalysisTask.resolve_config_default( params, params.get("ml_model"), container=params["config_inst"] default_str="default_ml_model", multiple=True, ) # a function that chooses the ml_model based on an attibute that is set in an inference_model def default_ml_model(task_cls, container, task_params): default_ml_model = None # check if task is using an inference model if "inference_model" in task_params.keys(): inference_model = task_params.get("inference_model", None) # if inference model is not set, assume it's the container default if inference_model in (None, "NO_STR"): inference_model = container.x.default_inference_model # get the default_ml_model from the inference_model_inst inference_model_inst = columnflow.inference.InferenceModel._subclasses[inference_model] default_ml_model = getattr(inference_model_inst, "ml_model_name", default_ml_model) return default_ml_model return default_ml_model config_inst = od.Config( id=0, name="my_config", aux={"default_ml_model": default_ml_model}, ) @inference_model(ml_model_name="default_ml_model") def my_inference_model(self): # some inference model implementation ... params = {"config_inst": config_inst, "ml_model": None, "inference_model": "my_inference_model"} resolve_param_values(params) # sets params["ml_model"] to "my_ml_model" params = {"config_inst": config_inst, "ml_model": "some_ml_model", "inference_model": "my_inference_model"} resolve_param_values(params) # sets params["ml_model"] to "some_ml_model" """ # check if the parameter value is to be resolved resolve_default = param in (None, RESOLVE_DEFAULT, (RESOLVE_DEFAULT,)) # interpret missing parameters (e.g. NO_STR) as None # (special case: an empty string is usually an active decision, but counts as missing too) if law.is_no_param(param) or resolve_default or param == "" or param == (): param = None # actual resolution if resolve_default: # get the container inst (mostly a config_inst or analysis_inst) if isinstance(container, str): container = task_params.get(container) # expand default when container is set if container and default_str: param = container.x(default_str, None) if default_str else None # allow default to be a function, taking task parameters as input if isinstance(param, Callable): param = param(cls, container, task_params) # when still empty, return an empty value if param is None: return () if multiple else None # return either a tuple or the first param, based on the *multiple* param = law.util.make_tuple(param) return param if multiple else (param[0] if param else None)
[docs] @classmethod def resolve_config_default_and_groups( cls, task_params: dict[str, Any], param: str | tuple[str] | None, container: str | od.AuxDataMixin = "config_inst", default_str: str | None = None, groups_str: str | None = None, ) -> tuple[str]: """ This method is similar to :py:meth:`~.resolve_config_default` in that it checks if a parameter value *param* is empty and should be replaced with a default value. See the referenced method for documentation on *task_params*, *param*, *container* and *default_str*. What this method does in addition is that it checks if the values contained in *param* (after default value resolution) refers to a group of values identified via the *groups_str* from the :py:class:`order.AuxDataMixin` *container* that maps a string to a tuple of strings. If it does, each value in *param* that refers to a group is expanded by the actual group values. Example: .. code-block:: python config_inst = od.Config( id=0, name="my_config", aux={ "default_producer": ["features_1", "my_producer_group"], "producer_groups": {"my_producer_group": ["features_2", "features_3"]}, }, ) params = {"producer": RESOLVE_DEFAULT} AnalysisTask.resolve_config_default_and_groups( params, params.get("producer"), container=config_inst, default_str="default_producer", groups_str="producer_groups", ) # -> ("features_1", "features_2", "features_3") """ # resolve the parameter param = cls.resolve_config_default( task_params=task_params, param=param, container=container, default_str=default_str, multiple=True, ) if not param: return param # get the container inst and return if it's not set if isinstance(container, str): container = task_params.get(container, None) if not container: return param # expand groups recursively if groups_str and container.x(groups_str, {}): param_groups = container.x(groups_str) values = [] lookup = law.util.make_list(param) handled_groups = set() while lookup: value = lookup.pop(0) if value in param_groups: if value in handled_groups: raise Exception( f"definition of '{groups_str}' contains circular references involving " f"group '{value}'", ) lookup = law.util.make_list(param_groups[value]) + lookup handled_groups.add(value) else: values.append(value) param = values return law.util.make_tuple(param)
def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # store the analysis instance self.analysis_inst = self.get_analysis_inst(self.analysis) # cached values added and accessed by cached_value() self._cached_values = {}
[docs] def cached_value(self, key: str, func: Callable[[], T]) -> T: """ Upon first invocation, the function *func* is called and its return value is stored under *key* in :py:attr:`_cached_values`. Subsequent calls with the same *key* return the cached value. :param key: The key under which the value is stored. :param func: The function that is called to generate the value. :return: The cached value. """ if key not in self._cached_values: self._cached_values[key] = func() return self._cached_values[key]
[docs] def store_parts(self) -> law.util.InsertableDict: """ Returns a :py:class:`law.util.InsertableDict` whose values are used to create a store path. For instance, the parts ``{"keyA": "a", "keyB": "b", 2: "c"}`` lead to the path "a/b/c". The keys can be used by subclassing tasks to overwrite values. :return: Dictionary with parts to create a path to store intermediary results. """ parts = law.util.InsertableDict() # add the analysis name parts["analysis"] = self.analysis_inst.name # in this base class, just add the task class name parts["task_family"] = self.task_family # add the version when set if self.version is not None: parts["version"] = self.version return parts
def _get_store_parts_modifier( self, modifier: str | Callable[[AnalysisTask, dict], dict], ) -> Callable[[AnalysisTask, dict], dict] | None: if isinstance(modifier, str): # interpret it as a name of an entry in the store_parts_modifiers aux entry modifier = self.analysis_inst.x("store_parts_modifiers", {}).get(modifier) return modifier if callable(modifier) else None
[docs] def local_path( self, *path, store_parts_modifier: str | Callable[[AnalysisTask, dict], dict] | None = None, **kwargs, ) -> str: """ local_path(*path, store=None, fs=None, store_parts_modifier=None) Joins path fragments from *store* (defaulting to :py:attr:`default_store`), :py:meth:`store_parts` and *path* and returns the joined path. In case a *fs* is defined, it should refer to the config section of a local file system, and consequently, *store* is not prepended to the returned path as the resolution of absolute paths is handled by that file system. """ # if no fs is set, determine the main store directory parts = () if not kwargs.pop("fs", None): store = kwargs.get("store") or self.default_store parts += (store,) # get and optional modify the store parts store_parts = self.store_parts() store_parts_modifier = self._get_store_parts_modifier(store_parts_modifier) if callable(store_parts_modifier): store_parts = store_parts_modifier(self, store_parts) # concatenate all parts that make up the path and join them parts += tuple(store_parts.values()) + path path = os.path.join(*map(str, parts)) return path
[docs] def local_target( self, *path, store_parts_modifier: str | Callable[[AnalysisTask, dict], dict] | None = None, **kwargs, ): """ local_target(*path, dir=False, store=None, fs=None, store_parts_modifier=None, **kwargs) Creates either a local file or directory target, depending on *dir*, forwarding all *path* fragments, *store* and *fs* to :py:meth:`local_path` and all *kwargs* the respective target class. """ _dir = kwargs.pop("dir", False) store = kwargs.pop("store", None) fs = kwargs.get("fs", None) # select the target class cls = law.LocalDirectoryTarget if _dir else law.LocalFileTarget # create the local path path = self.local_path(*path, store=store, fs=fs, store_parts_modifier=store_parts_modifier) # create the target instance and return it return cls(path, **kwargs)
[docs] def wlcg_path( self, *path, store_parts_modifier: str | Callable[[AnalysisTask, dict], dict] | None = None, ) -> str: """ Joins path fragments from *store_parts()* and *path* and returns the joined path. The full URI to the target is not considered as it is usually defined in ``[wlcg_fs]`` sections in the law config and hence subject to :py:func:`wlcg_target`. """ # get and optional modify the store parts store_parts = self.store_parts() store_parts_modifier = self._get_store_parts_modifier(store_parts_modifier) if callable(store_parts_modifier): store_parts = store_parts_modifier(self, store_parts) # concatenate all parts that make up the path and join them parts = tuple(store_parts.values()) + path path = os.path.join(*map(str, parts)) return path
[docs] def wlcg_target( self, *path, store_parts_modifier: str | Callable[[AnalysisTask, dict], dict] | None = None, **kwargs, ): """ wlcg_target(*path, dir=False, fs=default_wlcg_fs, store_parts_modifier=None, **kwargs) Creates either a remote WLCG file or directory target, depending on *dir*, forwarding all *path* fragments to :py:meth:`wlcg_path` and all *kwargs* the respective target class. When *None*, *fs* defaults to the *default_wlcg_fs* class level attribute. """ _dir = kwargs.pop("dir", False) if not kwargs.get("fs"): kwargs["fs"] = self.default_wlcg_fs # select the target class cls = law.wlcg.WLCGDirectoryTarget if _dir else law.wlcg.WLCGFileTarget # create the local path path = self.wlcg_path(*path, store_parts_modifier=store_parts_modifier) # create the target instance and return it return cls(path, **kwargs)
[docs] def target(self, *path, **kwargs): """ target(*path, location=None, **kwargs) """ # get the default location location = kwargs.pop("location", self.default_output_location) # parse it and obtain config values if necessary if isinstance(location, str): location = OutputLocation[location] if location == OutputLocation.config: lookup_keys = self.get_config_lookup_keys(self) outputs_dict = self._get_cfg_outputs_dict() location = copy.deepcopy(self._dfs_key_lookup(lookup_keys, outputs_dict)) if not location: self.logger.debug( f"no option 'outputs::{self.task_family}' found in law.cfg to obtain target " "location, falling back to 'local'", ) location = ["local"] location[0] = OutputLocation[location[0]] location = law.util.make_list(location) # forward to correct function if location[0] == OutputLocation.local: # get other options loc, store_parts_modifier = (location[1:] + [None, None])[:2] loc_key = "fs" if (loc and law.config.has_section(loc)) else "store" kwargs.setdefault(loc_key, loc) kwargs.setdefault("store_parts_modifier", store_parts_modifier) return self.local_target(*path, **kwargs) if location[0] == OutputLocation.wlcg: # get other options fs, store_parts_modifier = (location[1:] + [None, None])[:2] kwargs.setdefault("fs", fs) kwargs.setdefault("store_parts_modifier", store_parts_modifier) return self.wlcg_target(*path, **kwargs) if location[0] == OutputLocation.wlcg_mirrored: # get other options loc, wlcg_fs, store_parts_modifier = (location[1:] + [None, None, None])[:3] kwargs.setdefault("store_parts_modifier", store_parts_modifier) # create the wlcg target wlcg_kwargs = kwargs.copy() wlcg_kwargs.setdefault("fs", wlcg_fs) wlcg_target = self.wlcg_target(*path, **wlcg_kwargs) # TODO: add rule for falling back to wlcg target? # create the local target local_kwargs = kwargs.copy() loc_key = "fs" if (loc and law.config.has_section(loc)) else "store" local_kwargs.setdefault(loc_key, loc) local_target = self.local_target(*path, **local_kwargs) # build the mirrored target from these two mirrored_target_cls = ( law.MirroredFileTarget if isinstance(local_target, law.LocalFileTarget) else law.MirroredDirectoryTarget ) # whether to wait for local synchrnoization (for debugging purposes) local_sync = law.util.flag_to_bool(os.getenv("CF_MIRRORED_TARGET_LOCAL_SYNC", "true")) # create and return the target return mirrored_target_cls( path=local_target.abspath, remote_target=wlcg_target, local_target=local_target, local_sync=local_sync, ) raise Exception(f"cannot determine output location based on '{location}'")
[docs] def get_parquet_writer_opts(self, repeating_values: bool = False) -> dict[str, Any]: """ Returns an option dictionary that can be passed as *writer_opts* to :py:meth:`~law.pyarrow.merge_parquet_task`, for instance, at the end of chunked processing steps that produce a single parquet file. See :py:class:`~pyarrow.parquet.ParquetWriter` for valid options. This method can be overwritten in subclasses to customize the exact behavior. :param repeating_values: Whether the values to be written have predominantly repeating values, in which case differnt compression and encoding strategies are followed. :return: A dictionary with options that can be passed to parquet writer objects. """ # use dict encoding if values are repeating dict_encoding = bool(repeating_values) # build and return options return { "compression": "ZSTD", "compression_level": 1, "use_dictionary": dict_encoding, # ensure that after merging, the resulting parquet structure is the same as that of the # input files, e.g. do not switch from "*.list.item.*" to "*.list.element*." structures, # see https://github.com/scikit-hep/awkward/issues/3331 and # https://github.com/apache/arrow/issues/31731 "use_compliant_nested_type": False, }
[docs]class ConfigTask(AnalysisTask): config = luigi.Parameter( default=default_config, description=f"name of the analysis config to use; default: '{default_config}'", )
[docs] @classmethod def resolve_param_values(cls, params: dict) -> dict: params = super().resolve_param_values(params) # store a reference to the config inst if "config_inst" not in params and "analysis_inst" in params and "config" in params: params["config_inst"] = params["analysis_inst"].get_config(params["config"]) return params
@classmethod def _get_default_version( cls, inst: AnalysisTask, params: dict[str, Any], keys: law.util.InsertableDict, ) -> str | None: # try to lookup the version in the config's auxiliary data config_inst = params.get("config_inst") or getattr(inst, "config_inst", None) if config_inst: version = cls._dfs_key_lookup(keys, config_inst.x("versions", {})) if version: return version return super()._get_default_version(inst, params, keys)
[docs] @classmethod def get_config_lookup_keys( cls, inst_or_params: ConfigTask | dict[str, Any], ) -> law.util.InsertiableDict: keys = super().get_config_lookup_keys(inst_or_params) get = ( inst_or_params.get if isinstance(inst_or_params, dict) else lambda attr: (getattr(inst_or_params, attr, None)) ) # add the config name in front of the task family config = get("config") if config not in {law.NO_STR, None, ""}: keys.insert_before("task_family", "config", config) return keys
[docs] @classmethod def get_array_function_kwargs(cls, task=None, **params): kwargs = super().get_array_function_kwargs(task=task, **params) if task: kwargs["config_inst"] = task.config_inst elif "config_inst" in params: kwargs["config_inst"] = params["config_inst"] elif "config" in params and "analysis_inst" in kwargs: kwargs["config_inst"] = kwargs["analysis_inst"].get_config(params["config"]) return kwargs
def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # store a reference to the config instance self.config_inst = self.analysis_inst.get_config(self.config)
[docs] def store_parts(self): parts = super().store_parts() # add the config name parts.insert_after("task_family", "config", self.config_inst.name) return parts
[docs] def find_keep_columns(self: ConfigTask, collection: ColumnCollection) -> set[Route]: """ Returns a set of :py:class:`Route` objects describing columns that should be kept given a type of column *collection*. :param collection: The collection to return. :return: A set of :py:class:`Route` objects. """ columns = set() if collection == ColumnCollection.MANDATORY_COFFEA: columns |= set(Route(c) for c in mandatory_coffea_columns) return columns
def _expand_keep_column( self: ConfigTask, column: ColumnCollection | Route | str | Sequence[str | int | slice | type(Ellipsis) | list | tuple], ) -> set[Route]: """ Expands a *column* into a set of :py:class:`Route` objects. *column* can be a :py:class:`ColumnCollection`, a string, or any type that is accepted by :py:class:`Route`. Collections are expanded through :py:meth:`find_keep_columns`. :param column: The column to expand. :return: A set of :py:class:`Route` objects. """ # expand collections if isinstance(column, ColumnCollection): return self.find_keep_columns(column) # brace expand strings if isinstance(column, str): return set(map(Route, law.util.brace_expand(column))) # let Route handle it return {Route(column)}
[docs]class ShiftTask(ConfigTask): shift = luigi.Parameter( default="nominal", description="name of a systematic shift to apply; must fulfill order.Shift naming rules; " "default: 'nominal'", ) local_shift = luigi.Parameter(default=law.NO_STR) # skip passing local_shift to cli completion, req params and sandboxing exclude_params_index = {"local_shift"} exclude_params_req = {"local_shift"} exclude_params_sandbox = {"local_shift"} exclude_params_remote_workflow = {"local_shift"} allow_empty_shift = False
[docs] @classmethod def modify_param_values(cls, params): """ When "config" and "shift" are set, this method evaluates them to set the global shift. For that, it takes the shifts stored in the config instance and compares it with those defined by this class. """ params = super().modify_param_values(params) # get params config_inst = params.get("config_inst") requested_shift = params.get("shift") requested_local_shift = params.get("local_shift") # require that the config is set if config_inst in (None, law.NO_STR): return params # require that the shift is set and known if requested_shift in (None, law.NO_STR): if cls.allow_empty_shift: params["shift"] = law.NO_STR params["local_shift"] = law.NO_STR return params raise Exception(f"no shift found in params: {params}") if requested_shift not in config_inst.shifts: raise ValueError(f"shift {requested_shift} unknown to {config_inst}") # determine the known shifts for this class shifts, upstream_shifts = cls.get_known_shifts(config_inst, params) # actual shift resolution: compare the requested shift to known ones # local_shift -> the requested shift if implemented by the task itself, else nominal # shift -> the requested shift if implemented by this task # or an upsteam task (== global shift), else nominal if requested_local_shift in (None, law.NO_STR): if requested_shift in shifts: params["shift"] = requested_shift params["local_shift"] = requested_shift elif requested_shift in upstream_shifts: params["shift"] = requested_shift params["local_shift"] = "nominal" else: params["shift"] = "nominal" params["local_shift"] = "nominal" # store references params["global_shift_inst"] = config_inst.get_shift(params["shift"]) params["local_shift_inst"] = config_inst.get_shift(params["local_shift"]) return params
[docs] @classmethod def resolve_param_values(cls, params: dict) -> dict: params = super().resolve_param_values(params) # set default shift if params.get("shift") in (None, law.NO_STR): params["shift"] = "nominal" return params
[docs] @classmethod def get_array_function_kwargs(cls, task=None, **params): kwargs = super().get_array_function_kwargs(task=task, **params) if task: if task.local_shift_inst: kwargs["local_shift_inst"] = task.local_shift_inst if task.global_shift_inst: kwargs["global_shift_inst"] = task.global_shift_inst else: if "local_shift_inst" in params: kwargs["local_shift_inst"] = params["local_shift_inst"] if "global_shift_inst" in params: kwargs["global_shift_inst"] = params["global_shift_inst"] return kwargs
[docs] @classmethod def get_config_lookup_keys( cls, inst_or_params: ShiftTask | dict[str, Any], ) -> law.util.InsertiableDict: keys = super().get_config_lookup_keys(inst_or_params) get = ( inst_or_params.get if isinstance(inst_or_params, dict) else lambda attr: (getattr(inst_or_params, attr, None)) ) # add the (global) shift name shift = get("shift") if shift not in {law.NO_STR, None, ""}: keys["shift"] = shift return keys
def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # store references to the shift instances self.local_shift_inst = None self.global_shift_inst = None if self.shift not in (None, law.NO_STR) and self.local_shift not in (None, law.NO_STR): self.global_shift_inst = self.config_inst.get_shift(self.shift) self.local_shift_inst = self.config_inst.get_shift(self.local_shift)
[docs] def store_parts(self): parts = super().store_parts() # add the shift name if self.global_shift_inst: parts.insert_after("config", "shift", self.global_shift_inst.name) return parts
[docs]class DatasetTask(ShiftTask): dataset = luigi.Parameter( default=default_dataset, description=f"name of the dataset to process; default: '{default_dataset}'", ) file_merging = None
[docs] @classmethod def resolve_param_values(cls, params): params = super().resolve_param_values(params) # store a reference to the dataset inst if "dataset_inst" not in params and "config_inst" in params and "dataset" in params: params["dataset_inst"] = params["config_inst"].get_dataset(params["dataset"]) return params
[docs] @classmethod def get_known_shifts(cls, config_inst: od.Config, params: dict) -> tuple[set[str], set[str]]: # dataset can have shifts, that are considered as upstream shifts shifts, upstream_shifts = super().get_known_shifts(config_inst, params) dataset_inst = params.get("dataset_inst") if dataset_inst: if dataset_inst.is_data: # clear all shifts for data shifts.clear() upstream_shifts.clear() else: # extend with dataset variations for mc upstream_shifts |= set(dataset_inst.info.keys()) return shifts, upstream_shifts
[docs] @classmethod def get_config_lookup_keys( cls, inst_or_params: DatasetTask | dict[str, Any], ) -> law.util.InsertiableDict: keys = super().get_config_lookup_keys(inst_or_params) get = ( inst_or_params.get if isinstance(inst_or_params, dict) else lambda attr: (getattr(inst_or_params, attr, None)) ) # add the dataset name before the shift name dataset = get("dataset") if dataset not in {law.NO_STR, None, ""}: keys.insert_before("shift", "dataset", dataset) return keys
[docs] @classmethod def get_array_function_kwargs(cls, task=None, **params): kwargs = super().get_array_function_kwargs(task=task, **params) if task: kwargs["dataset_inst"] = task.dataset_inst elif "dataset_inst" in params: kwargs["dataset_inst"] = params["dataset_inst"] elif "dataset" in params and "config_inst" in kwargs: kwargs["dataset_inst"] = kwargs["config_inst"].get_dataset(params["dataset"]) return kwargs
def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # store references to the dataset instance self.dataset_inst = self.config_inst.get_dataset(self.dataset) # store dataset info for the global shift key = ( self.global_shift_inst.name if self.global_shift_inst and self.global_shift_inst.name in self.dataset_inst.info else "nominal" ) self.dataset_info_inst = self.dataset_inst.get_info(key)
[docs] def store_parts(self): parts = super().store_parts() # insert the dataset parts.insert_after("config", "dataset", self.dataset_inst.name) return parts
@property def file_merging_factor(self) -> int: """ Returns the number of files that are handled in one branch. When the :py:attr:`file_merging` attribute is set to a positive integer, this value is returned. Otherwise, if the value is zero, the original number of files is used instead. Consecutive merging steps are not handled yet. """ n_files = self.dataset_info_inst.n_files if isinstance(self.file_merging, int): # interpret the file_merging attribute as the merging factor itself # zero means "merge all in one" if self.file_merging < 0: raise ValueError(f"invalid file_merging value {self.file_merging}") n_merge = n_files if self.file_merging == 0 else self.file_merging else: # no merging at all n_merge = 1 return n_merge
[docs] def create_branch_map(self): """ Define the branch map for when this task is used as a workflow. By default, use the merging information provided by :py:attr:`file_merging_factor` to return a dictionary which maps branches to one or more input file indices. E.g. `1 -> [3, 4, 5]` would mean that branch 1 is simultaneously handling input file indices 3, 4 and 5. """ n_merge = self.file_merging_factor n_files = self.dataset_info_inst.n_files # use iter_chunks which splits a list of length n_files into chunks of maximum size n_merge chunks = law.util.iter_chunks(n_files, n_merge) # use enumerate for simply indexing return dict(enumerate(chunks))
[docs]class CommandTask(AnalysisTask): """ A task that provides convenience methods to work with shell commands, i.e., printing them on the command line and executing them with error handling. """ print_command = law.CSVParameter( default=(), significant=False, description="print the command that this task would execute but do not run any task; this " "CSV parameter accepts a single integer value which sets the task recursion depth to also " "print the commands of required tasks (0 means non-recursive)", ) custom_args = luigi.Parameter( default="", description="custom arguments that are forwarded to the underlying command; they might not " "be encoded into output file paths; empty default", ) exclude_index = True exclude_params_req = {"custom_args"} interactive_params = AnalysisTask.interactive_params + ["print_command"] run_command_in_tmp = False def _print_command(self, args): max_depth = int(args[0]) print(f"print task commands with max_depth {max_depth}") for dep, _, depth in self.walk_deps(max_depth=max_depth, order="pre"): offset = depth * ("|" + law.task.interactive.ind) print(offset) print("{}> {}".format(offset, dep.repr(color=True))) offset += "|" + law.task.interactive.ind if isinstance(dep, CommandTask): # when dep is a workflow, take the first branch text = law.util.colored("command", style="bright") if isinstance(dep, law.BaseWorkflow) and dep.is_workflow(): dep = dep.as_branch(0) text += " (from branch {})".format(law.util.colored("0", "red")) text += ": " cmd = dep.build_command() if cmd: cmd = law.util.quote_cmd(cmd) if isinstance(cmd, (list, tuple)) else cmd text += law.util.colored(cmd, "cyan") else: text += law.util.colored("empty", "red") print(offset + text) else: print(offset + law.util.colored("not a CommandTask", "yellow"))
[docs] def build_command(self): # this method should build and return the command to run raise NotImplementedError
[docs] def touch_output_dirs(self): # keep track of created uris so we can avoid creating them twice handled_parent_uris = set() for outp in law.util.flatten(self.output()): # get the parent directory target parent = None if isinstance(outp, law.SiblingFileCollection): parent = outp.dir elif isinstance(outp, law.FileSystemFileTarget): parent = outp.parent # create it if parent and parent.uri() not in handled_parent_uris: parent.touch() handled_parent_uris.add(parent.uri())
[docs] def run_command(self, cmd, optional=False, **kwargs): # proper command encoding cmd = (law.util.quote_cmd(cmd) if isinstance(cmd, (list, tuple)) else cmd).strip() # when no cwd was set and run_command_in_tmp is True, create a tmp dir if "cwd" not in kwargs and self.run_command_in_tmp: tmp_dir = law.LocalDirectoryTarget(is_tmp=True) tmp_dir.touch() kwargs["cwd"] = tmp_dir.abspath self.publish_message("cwd: {}".format(kwargs.get("cwd", os.getcwd()))) # call it with self.publish_step("running '{}' ...".format(law.util.colored(cmd, "cyan"))): p, lines = law.util.readable_popen(cmd, shell=True, executable="/bin/bash", **kwargs) for line in lines: print(line) # raise an exception when the call failed and optional is not True if p.returncode != 0 and not optional: raise Exception(f"command failed with exit code {p.returncode}: {cmd}") return p
[docs] @law.decorator.log @law.decorator.notify @law.decorator.safe_output def run(self, **kwargs): self.pre_run_command() # default run implementation # first, create all output directories self.touch_output_dirs() # build the command cmd = self.build_command() # run it self.run_command(cmd, **kwargs) self.post_run_command()
[docs] def pre_run_command(self): return
[docs] def post_run_command(self): return
[docs]def wrapper_factory( base_cls: law.task.base.Task, require_cls: AnalysisTask, enable: Sequence[str], cls_name: str | None = None, attributes: dict | None = None, docs: str | None = None, ) -> law.task.base.Register: """Factory function creating wrapper task classes, inheriting from *base_cls* and :py:class:`~law.task.base.WrapperTask`, that do nothing but require multiple instances of *require_cls*. Unless *cls_name* is defined, the name of the created class defaults to the name of *require_cls* plus "Wrapper". Additional *attributes* are added as class-level members when given. The instances of *require_cls* to be required in the :py:meth:`~.wrapper_factory.Wrapper.requires()` method can be controlled by task parameters. These parameters can be enabled through the string sequence *enable*, which currently accepts: - ``configs``, ``skip_configs`` - ``shifts``, ``skip_shifts`` - ``datasets``, ``skip_datasets`` This allows to easily build wrapper tasks that loop over (combinations of) parameters that are either defined in the analysis or config, which would otherwise lead to mostly redundant code. Example: .. code-block:: python class MyTask(DatasetTask): ... MyTaskWrapper = wrapper_factory( base_cls=ConfigTask, require_cls=MyTask, enable=["datasets", "skip_datasets"], ) # this allows to run (e.g.) # law run MyTaskWrapper --datasets st_* --skip-datasets *_tbar When building the requirements, the full combinatorics of parameters is considered. However, certain conditions apply depending on enabled features. For instance, in order to use the "configs" feature (adding a parameter "--configs" to the created class, allowing to loop over a list of config instances known to an analysis), *require_cls* must be at least a :py:class:`ConfigTask` accepting "--config" (mind the singular form), whereas *base_cls* must explicitly not. :param base_cls: Base class for this wrapper :param require_cls: :py:class:`~law.task.base.Task` class to be wrapped :param enable: Enable these parameters to control the wrapped :py:class:`~law.task.base.Task` class instance. Currently allowed parameters are: "configs", "skip_configs", "shifts", "skip_shifts", "datasets", "skip_datasets" :param cls_name: Name of the wrapper instance. If :py:attr:`None`, defaults to the name of the :py:class:`~law.task.base.WrapperTask` class + `"Wrapper"` :param attributes: Add these attributes as class-level members of the new :py:class:`~law.task.base.WrapperTask` class :param docs: Manually set the documentation string `__doc__` of the new :py:class:`~law.task.base.WrapperTask` class instance :raises ValueError: If a parameter provided with `enable` is not in the list of known parameters :raises TypeError: If any parameter in `enable` is incompatible with the :py:class:`~law.task.base.WrapperTask` class instance or the inheritance structure of corresponding classes :raises ValueError: when `configs` are enabled but not found in the analysis config instance :raises ValueError: when `shifts` are enabled but not found in the analysis config instance :raises ValueError: when `datasets` are enabled but not found in the analysis config instance :return: The new :py:class:`~law.task.base.WrapperTask` for the :py:class:`~law.task.base.Task` class `required_cls` """ # check known features known_features = [ "configs", "skip_configs", "shifts", "skip_shifts", "datasets", "skip_datasets", ] for feature in enable: if feature not in known_features: raise ValueError( f"unknown enabled feature '{feature}', known features are " f"'{','.join(known_features)}'", ) # treat base_cls as a tuple base_classes = law.util.make_tuple(base_cls) base_cls = base_classes[0] # define wrapper feature flags has_configs = "configs" in enable has_skip_configs = has_configs and "skip_configs" in enable has_shifts = "shifts" in enable has_skip_shifts = has_shifts and "skip_shifts" in enable has_datasets = "datasets" in enable has_skip_datasets = has_datasets and "skip_datasets" in enable # helper to check if enabled features are compatible with required and base class def check_class_compatibility(name, min_require_cls, max_base_cls): if not issubclass(require_cls, min_require_cls): raise TypeError( f"when the '{name}' feature is enabled, require_cls must inherit from " f"{min_require_cls}, but {require_cls} does not", ) if issubclass(base_cls, min_require_cls): raise TypeError( f"when the '{name}' feature is enabled, base_cls must not inherit from " f"{min_require_cls}, but {base_cls} does", ) if not issubclass(max_base_cls, base_cls): raise TypeError( f"when the '{name}' feature is enabled, base_cls must be a super class of " f"{max_base_cls}, but {base_cls} is not", ) # check classes if has_configs: check_class_compatibility("configs", ConfigTask, AnalysisTask) if has_shifts: check_class_compatibility("shifts", ShiftTask, ConfigTask) if has_datasets: check_class_compatibility("datasets", DatasetTask, ShiftTask) # create the class class Wrapper(*base_classes, law.WrapperTask): exclude_params_repr_empty = set() if has_configs: configs = law.CSVParameter( default=(default_config,), description="names or name patterns of configs to use; can also be the key of a " "mapping defined in the 'config_groups' auxiliary data of the analysis; " f"default: {default_config}", brace_expand=True, ) if has_skip_configs: skip_configs = law.CSVParameter( default=(), description="names or name patterns of configs to skip after evaluating --configs; " "can also be the key of a mapping defined in the 'config_groups' auxiliary data " "of the analysis; empty default", brace_expand=True, ) exclude_params_repr_empty.add("skip_configs") if has_datasets: datasets = law.CSVParameter( default=("*",), description="names or name patterns of datasets to use; can also be the key of a " "mapping defined in the 'dataset_groups' auxiliary data of the corresponding " "config; default: ('*',)", brace_expand=True, ) if has_skip_datasets: skip_datasets = law.CSVParameter( default=(), description="names or name patterns of datasets to skip after evaluating " "--datasets; can also be the key of a mapping defined in the 'dataset_groups' " "auxiliary data of the corresponding config; empty default", brace_expand=True, ) exclude_params_repr_empty.add("skip_datasets") if has_shifts: shifts = law.CSVParameter( default=("nominal",), description="names or name patterns of shifts to use; can also be the key of a " "mapping defined in the 'shift_groups' auxiliary data of the corresponding " "config; default: ('nominal',)", brace_expand=True, ) if has_skip_shifts: skip_shifts = law.CSVParameter( default=(), description="names or name patterns of shifts to skip after evaluating --shifts; " "can also be the key of a mapping defined in the 'shift_groups' auxiliary data " "of the corresponding config; empty default", brace_expand=True, ) exclude_params_repr_empty.add("skip_shifts") def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # store wrapper flags self.wrapper_require_cls = require_cls self.wrapper_has_configs = has_configs and self.configs self.wrapper_has_skip_configs = has_skip_configs self.wrapper_has_shifts = has_shifts and self.shifts self.wrapper_has_skip_shifts = has_skip_shifts self.wrapper_has_datasets = has_datasets and self.datasets self.wrapper_has_skip_datasets = has_skip_datasets # store wrapped fields self.wrapper_fields = [ field for field in ["config", "shift", "dataset"] if getattr(self, f"wrapper_has_{field}s") ] # build the parameter space self.wrapper_parameters = self._build_wrapper_parameters() def _build_wrapper_parameters(self): # collect a list of tuples whose order corresponds to wrapper_fields params = [] # get the target config instances if self.wrapper_has_configs: configs = self.find_config_objects( self.configs, self.analysis_inst, od.Config, self.analysis_inst.x("config_groups", {}), ) if not configs: raise ValueError( f"no configs found in analysis {self.analysis_inst} matching {self.configs}", ) if self.wrapper_has_skip_configs: skip_configs = self.find_config_objects( self.skip_configs, self.analysis_inst, od.Config, self.analysis_inst.x("config_groups", {}), ) configs = [c for c in configs if c not in skip_configs] if not configs: raise ValueError( f"no configs found in analysis {self.analysis_inst} after skipping " f"{self.skip_configs}", ) config_insts = list(map(self.analysis_inst.get_config, sorted(configs))) else: config_insts = [self.config_inst] # for the remaining fields, build the full combinatorics per config_inst for config_inst in config_insts: # sequences for building combinatorics prod_sequences = [] if self.wrapper_has_configs: prod_sequences.append([config_inst.name]) # find all shifts if self.wrapper_has_shifts: shifts = self.find_config_objects( self.shifts, config_inst, od.Shift, config_inst.x("shift_groups", {}), ) if not shifts: raise ValueError( f"no shifts found in config {config_inst} matching {self.shifts}", ) if self.wrapper_has_skip_shifts: skip_shifts = self.find_config_objects( self.skip_shifts, config_inst, od.Shift, config_inst.x("shift_groups", {}), ) shifts = [s for s in shifts if s not in skip_shifts] if not shifts: raise ValueError( f"no shifts found in config {config_inst} after skipping " f"{self.skip_shifts}", ) # move "nominal" to the front if present shifts = sorted(shifts) if "nominal" in shifts: shifts.insert(0, shifts.pop(shifts.index("nominal"))) prod_sequences.append(shifts) # find all datasets if self.wrapper_has_datasets: datasets = self.find_config_objects( self.datasets, config_inst, od.Dataset, config_inst.x("dataset_groups", {}), ) if not datasets: raise ValueError( f"no datasets found in config {config_inst} matching " f"{self.datasets}", ) if self.wrapper_has_skip_datasets: skip_datasets = self.find_config_objects( self.skip_datasets, config_inst, od.Dataset, config_inst.x("dataset_groups", {}), ) datasets = [d for d in datasets if d not in skip_datasets] if not datasets: raise ValueError( f"no datasets found in config {config_inst} after skipping " f"{self.skip_datasets}", ) prod_sequences.append(sorted(datasets)) # add the full combinatorics params.extend(itertools.product(*prod_sequences)) return params def requires(self) -> Requirements: """Collect requirements defined by the underlying ``require_cls`` of the :py:class:`~law.task.base.WrapperTask` depending on optional additional parameters. :return: Requirements for the :py:class:`~law.task.base.WrapperTask` instance. """ # build all requirements based on the parameter space reqs = {} for values in self.wrapper_parameters: params = dict(zip(self.wrapper_fields, values)) # allow custom checks and updates params = self.update_wrapper_params(params) if not params: continue # add the requirement if not present yet req = self.wrapper_require_cls.req(self, **params) if req not in reqs.values(): reqs[values] = req return reqs def update_wrapper_params(self, params): return params # add additional class-level members if attributes: locals().update(attributes) # overwrite __module__ to point to the module of the calling stack frame = inspect.stack()[1] module = inspect.getmodule(frame[0]) Wrapper.__module__ = module.__name__ # overwrite __name__ Wrapper.__name__ = cls_name or require_cls.__name__ + "Wrapper" # set docs if docs: Wrapper.__docs__ = docs return Wrapper