# Source code for autojob.harvest.harvest

"""Harvest data from the directories of completed tasks.

Example:
    Harvest the results in the current working directory as
    calculations

    .. code-block:: python

        from pathlib import Path

        from autojob.tasks.calculation import Calculation
        from autojob.harvest.harvest import harvest

        harvest(dir_name=Path.cwd(), strictness="relaxed", preferred="calculation")

.. important::

    Always verify the units of harvested quantities.
"""

import json
import logging
from pathlib import Path
from typing import TYPE_CHECKING
from typing import Any
from typing import Literal

from tqdm import tqdm

from autojob import SETTINGS
from autojob.plugins import get_task_class
from autojob.utils.files import find_task_dirs

if TYPE_CHECKING:
    from autojob.bases.task_base import TaskBase

logger = logging.getLogger(__name__)


def _concatenate_list_sources(sources: list[str] | list[Path]) -> list[str]:
    """Read the lines from a list of files and concatentate their lines.

    Args:
        sources: A list of filenames.

    Returns:
        The unique, non-empty lines in the files provided.
    """
    res = []
    for source in sources:
        with Path(source).open(mode="r", encoding="utf-8") as file:
            lines = [
                line.rstrip() for line in file.readlines() if line.rstrip()
            ]
        res.extend(lines)
    return res


def _load_task_from_cache(archive_file: Path) -> "TaskBase":
    """Rebuild a task from the first entry of a JSON cache file.

    Args:
        archive_file: Path to the JSON cache file. Its top-level value must be
            a mapping; only the first value in that mapping is used.

    Returns:
        The task instance built from the cached entry. The entry supplies
        the class name under ``["task_metadata"]["task_class"]`` and is then
        passed as keyword arguments to that class.

    Raises:
        KeyError: If the cache holds no entries or the entry lacks
            ``task_metadata.task_class``.
    """
    logger.debug("Loading task from cache file: %s", archive_file)
    with archive_file.open(mode="r", encoding="utf-8") as file:
        cache: dict[str, Any] = json.load(file)

    try:
        d = next(iter(cache.values()))
    except StopIteration:
        # Surface an empty cache as KeyError so ``harvest`` applies its
        # strictness rules instead of leaking a bare StopIteration.
        msg = f"Cache file contains no tasks: {archive_file}"
        raise KeyError(msg) from None

    cached_class = get_task_class(d["task_metadata"]["task_class"])
    task = cached_class(**d)
    logger.info(
        "Successfully loaded task from cache file: %s",
        archive_file,
    )
    return task


def harvest(
    dir_name: str | Path,
    *,
    strictness: Literal["strict", "relaxed", "atomic"] | None = None,
    whitelists: list[str] | list[Path] | None = None,
    blacklists: list[str] | list[Path] | None = None,
    preferred: str | None = None,
    use_cache: bool = False,
    prog_bar: bool = False,
) -> "list[TaskBase]":
    """Collect all data in subdirectories of the given directory.

    Args:
        dir_name: The directory under which to collect data.
        strictness: How to treat tasks for which errors are thrown during
            their harvesting. If ``"strict"``, all harvesting will abort. If
            ``"atomic"``, only calculations for which errors are not thrown
            will be harvested. If ``"relaxed"``, every effort is made to
            harvest all calculations. The default behaviour is controlled by
            the value of ``SETTINGS.STRICT_MODE``. If
            ``SETTINGS.STRICT_MODE=True``, the default behaviour will be that
            of ``strictness="strict"``. Otherwise, the default behaviour will
            be that of ``strictness="relaxed"``.
        whitelists: A list of strings or paths representing whitelist
            filenames, where each whitelist points to a list of task IDs that
            should be harvested. When specified, only tasks with task IDs
            matching these IDs will be harvested. Defaults to None in which
            case all tasks are eligible for harvesting.
        blacklists: A list of strings or paths representing blacklist
            filenames, where each blacklist points to a list of task IDs that
            should not be harvested. When specified, no tasks with task IDs in
            these lists will be harvested. Defaults to None in which case no
            tasks are excluded.
        preferred: The name of the preferred TaskBase type to use to harvest
            each calculation. Defaults to ``SETTINGS.DEFAULT_TASK``.
        use_cache: Whether or not to use cached results. If False, then cached
            results will be ignored and overwritten. Otherwise, outputs will be
            read from an existing cache.
        prog_bar: Whether or not to display the progress bar. Defaults to
            False.

    Returns:
        A list of :class:`~task_base.TaskBase` instances containing the data
        within ``dir_name``.

    Raises:
        FileNotFoundError: If a whitelist/blacklist file is missing, or (in
            strict mode) a task's files cannot be found.
        KeyError: In strict mode, if a task's data or cache is incomplete.
    """
    logger.info("Harvesting calculations from directory: %s", dir_name)
    strict_mode = (
        SETTINGS.STRICT_MODE
        if strictness is None
        else strictness in ("strict", "atomic")
    )
    # "atomic" harvests each task strictly but skips failures rather than
    # aborting the whole harvest.
    reraise = strict_mode and strictness != "atomic"

    jobs = find_task_dirs(Path(dir_name))
    # Kept separate from any class read from a cache file so a cached task
    # never changes how later, non-cached jobs are harvested.
    default_task_class = get_task_class(preferred or SETTINGS.DEFAULT_TASK)

    # Read each list file once (not once per job) and test membership in a set.
    if whitelists is not None:
        allowed = set(_concatenate_list_sources(whitelists))
        jobs = [j for j in jobs if j.name in allowed]

    if blacklists is not None:
        denied = set(_concatenate_list_sources(blacklists))
        jobs = [j for j in jobs if j.name not in denied]

    harvested = []
    # tqdm's ``disable`` hides the bar, so it is the negation of ``prog_bar``.
    for job in tqdm(jobs, disable=not prog_bar):
        archive_file = job.joinpath(SETTINGS.ARCHIVE_FILE)
        try:
            if use_cache and archive_file.exists():
                harvested_task = _load_task_from_cache(archive_file)
            else:
                harvested_task = default_task_class.from_directory(
                    job,
                    strict_mode=strict_mode,
                    magic_mode=True,
                )
        except (FileNotFoundError, KeyError) as e:
            if reraise:
                raise
            logger.error(e)
            continue
        harvested.append(harvested_task)

    logger.info("%s calculations harvested", len(harvested))
    return harvested