Source code for autojob.tasks.task

"""This module defines a concrete implementation of :class:`TaskBase`."""

import json
import logging
from pathlib import Path
from typing import Any
from typing import Self
from uuid import UUID

import ase.io
from pydantic import UUID4
from pydantic import Field
from pydantic import ValidationError
from pydantic import ValidationInfo
from pydantic import ValidatorFunctionWrapHandler
from pydantic import field_validator

from autojob import SETTINGS
from autojob.bases.task_base import SetTaskClassMixin
from autojob.bases.task_base import TaskBase
from autojob.bases.task_base import TaskInputsBase
from autojob.bases.task_base import TaskMetadataBase
from autojob.bases.task_base import TaskOutputsBase
from autojob.plugins import get_task_class
from autojob.utils.files import get_uri
from autojob.utils.files import template_script

logger = logging.getLogger(__name__)

# TODO: define annotated ID_string type for Pydantic models
# Length of the legacy shortuuid-style ID strings that TaskMetadata.validate_ids
# accepts (when alphanumeric) in place of a UUID.
LEGACY_TASK_ID_LENGTH = 10


class TaskMetadata(TaskMetadataBase):
    """A concrete implementation of TaskMetadataBase."""

    @field_validator(
        "study_group_id", "study_id", "task_group_id", "task_id", mode="wrap"
    )
    @classmethod
    def validate_ids(
        cls,
        v: Any,
        handler: ValidatorFunctionWrapHandler,
        info: ValidationInfo,
    ) -> str | UUID4 | None:
        """Validate an ID.

        IDs can either be a UUID or a 10-digit alphanumeric shortuuid string.
        ``None`` is additionally accepted for the optional ID fields
        (``study_group_id``, ``study_id``, ``workflow_step_id`` and
        ``task_group_id``).

        Args:
            v: The raw value being validated.
            handler: The wrapped (default) pydantic validator for the field.
            info: The validation context; its ``field_name`` decides whether
                ``None`` is acceptable.

        Returns:
            The validated UUID, the legacy ID string, or None.

        Raises:
            ValueError: If the value is not an acceptable ID for the field.
        """
        value = v
        try:
            value = handler(v)
        except ValidationError:
            pass
        else:
            # string validation is insufficient, so we only accept validation
            # on UUIDs
            # TODO: define annotated ID_string
            if isinstance(value, UUID):
                return value

        # Legacy IDs: fixed-length alphanumeric strings. Check the same object
        # whose type was just verified (``value``), not the raw input ``v``,
        # which may not be a str if the handler coerced it.
        if (
            isinstance(value, str)
            and len(value) == LEGACY_TASK_ID_LENGTH
            and value.isalnum()
        ):
            return value

        if (
            info.field_name
            in (
                "study_group_id",
                "study_id",
                "workflow_step_id",
                "task_group_id",
            )
            and value is None
        ):
            return None

        msg = f"{v} is not a UUID4 or a 10-digit alphanumeric shortuuid string"
        raise ValueError(msg)

    # TODO: use get_last_updated and test
    @classmethod
    def from_directory(cls, src: str | Path) -> Self:
        """Create a TaskMetadata document from a task directory.

        Args:
            src: The task directory containing ``SETTINGS.TASK_METADATA_FILE``.

        Returns:
            A TaskMetadata built from the JSON metadata file, with its ``uri``
            field set to ``get_uri(dir_name=src)``.

        Raises:
            FileNotFoundError: If the metadata file does not exist in ``src``.
        """
        logger.debug("Loading task metadata from %s", src)
        task_file = Path(src).joinpath(SETTINGS.TASK_METADATA_FILE)
        with task_file.open(mode="r", encoding="utf-8") as file:
            raw_metadata: dict[str, Any] = json.load(file)

        # Any "uri" stored in the file is replaced by one computed from src.
        raw_metadata["uri"] = get_uri(dir_name=src)
        logger.debug("Successfully loaded task metadata from %s", src)
        return cls(**raw_metadata)
class TaskInputs(TaskInputsBase):
    """The set of task-level inputs."""

    @classmethod
    def from_directory(cls, src: str | Path, **kwargs) -> "TaskInputs":
        """Generate a TaskInputs document from a completed task's directory.

        Args:
            src: The directory of a completed Task.
            kwargs: Additional keyword arguments:

                - strict_mode: Whether or not to raise an error if the input
                  atoms are not found. Defaults to `SETTINGS.STRICT_MODE`.

        Raises:
            FileNotFoundError: If the inputs JSON is missing, or if the input
                atoms file is missing and ``strict_mode`` is True.

        Returns:
            A :class:`TaskInputs` object.
        """
        strict_mode = kwargs.get("strict_mode", SETTINGS.STRICT_MODE)
        logger.debug("Loading task inputs from directory: %s", src)
        logger.debug("Strict mode: %sabled", "en" if strict_mode else "dis")

        inputs_json = Path(src, SETTINGS.INPUTS_FILE)
        with inputs_json.open(mode="r", encoding="utf-8") as file:
            data = json.load(file)

        # Task-level inputs live under "task_inputs"; any remaining top-level
        # keys (presumably the ``additional_data`` merged in when the file was
        # written) are forwarded as extra keyword arguments.
        inputs = cls(**data.pop("task_inputs"), **data)

        if not inputs.atoms_filename:
            logger.debug("No input atoms to retrieve")
        else:
            logger.debug("Retrieving input atoms")
            try:
                inputs.atoms = ase.io.read(Path(src, inputs.atoms_filename))
            except FileNotFoundError:
                if strict_mode:
                    raise
                # Same severity as a missing output structure in TaskOutputs.
                logger.warning(
                    "Unable to retrieve input atoms from directory: %s", src
                )
            else:
                logger.info(
                    "Successfully loaded input atoms from directory: %s", src
                )

        logger.info("Successfully loaded task inputs from directory: %s", src)
        return inputs
class TaskOutputs(TaskOutputsBase):
    """The set of task-level outputs."""

    @classmethod
    def from_directory(
        cls,
        src: str | Path,
        **kwargs,
    ) -> "TaskOutputs":
        """Build a TaskOutputs document from the directory of a finished task.

        Args:
            src: Path to the directory of a completed task.
            kwargs: Extra keyword arguments. Recognized:

                - strict_mode: If True, a missing output atoms file raises
                  instead of being tolerated. Defaults to
                  ``SETTINGS.STRICT_MODE``.

        Returns:
            A :class:`~TaskOutputs` object. Its ``atoms`` is None when the
            output structure file could not be found (non-strict mode only).
        """
        directory = Path(src)
        strict = kwargs.get("strict_mode", SETTINGS.STRICT_MODE)
        state = "en" if strict else "dis"
        logger.debug("Loading task outputs from directory: %s", directory)
        logger.debug("Strict mode: %sabled", state)

        output_file = directory / SETTINGS.OUTPUT_ATOMS_FILE
        final_atoms = None
        try:
            final_atoms = ase.io.read(output_file)
        except FileNotFoundError:
            if strict:
                raise
            logger.warning(
                "Unable to retrieve output atoms from directory: %s", directory
            )

        logger.debug(
            "Successfully loaded task outputs from directory: %s", directory
        )
        return cls(atoms=final_atoms)
class Task(TaskBase, SetTaskClassMixin):
    """A concrete implementation of TaskBase."""

    task_metadata: TaskMetadata = Field(
        default_factory=TaskMetadata, description="Task metadata"
    )
    task_inputs: TaskInputs = Field(
        default_factory=TaskInputs, description="Task inputs"
    )
    task_outputs: TaskOutputs | None = Field(
        default=None, description="Task outputs"
    )

    @classmethod
    def load_magic(cls, src: str | Path, *, strict_mode: bool = True) -> Self:
        """Load a :class:`~TaskBase` subclass using its "base class" metadata.

        The task class named by ``task_class`` in the directory's task
        metadata is resolved with :func:`get_task_class`. That class's
        ``from_directory`` is then used to load the task.

        Args:
            src: The directory from which to load the task.
            strict_mode: Whether or not to require all outputs. If True,
                errors will be thrown on missing outputs and on a missing
                task class. Defaults to True.

        Raises:
            RuntimeError: No build class specified in the task metadata. Only
                raised if ``strict_mode`` is True.

        Returns:
            The loaded task. If no task class is recorded and ``strict_mode``
            is False, an instance of ``cls`` is returned instead.
        """
        logger.debug("Magically loading task from directory: %s", src)
        logger.debug("Strict mode: %sabled", "en" if strict_mode else "dis")
        class_name = TaskMetadata.from_directory(src).task_class

        if class_name:
            logger.debug("Loading task with task class: %s", class_name)
            task_class = get_task_class(class_name)
            return task_class.from_directory(src, strict_mode=strict_mode)
        elif strict_mode:
            msg = (
                f"No build class provided for task in {src!s}. "
                "Unable to use magic mode"
            )
            raise RuntimeError(msg)
        else:
            msg = (
                "No build class provided for task in %s. Unable to "
                "use magic mode, so a %s will be created instead."
            )
            logger.warning(msg, src, cls.__name__)
            # magic_mode defaults to False here, so this cannot recurse back
            # into load_magic.
            return cls.from_directory(src=src, strict_mode=strict_mode)

    @classmethod
    def from_directory(cls, src: str | Path, **kwargs) -> Self:
        """Generate a Task document from a completed task's directory.

        Args:
            src: The directory of a completed Task.
            kwargs: Additional keyword arguments:

                - strict_mode: Whether or not to require all outputs. If True,
                  errors will be thrown on missing outputs. Defaults to
                  ``SETTINGS.STRICT_MODE``.
                - magic_mode: Whether or not to instantiate subclasses. If
                  True, the task returned must be an instance determined by
                  metadata in the directory. Defaults to False.

        Returns:
            An instance of :class:`Task` or a :class:`Task` subclass.
        """
        strict_mode = kwargs.get("strict_mode", SETTINGS.STRICT_MODE)
        magic_mode = kwargs.get("magic_mode", False)
        logger.debug("Loading task from directory: %s", src)
        logger.debug("Magic mode: %sabled", "en" if magic_mode else "dis")
        logger.debug("Strict mode: %sabled", "en" if strict_mode else "dis")

        if magic_mode:
            return cls.load_magic(src=src, strict_mode=strict_mode)

        metadata = TaskMetadata.from_directory(src=src)
        inputs = TaskInputs.from_directory(src=src, strict_mode=strict_mode)
        outputs = TaskOutputs.from_directory(src=src, strict_mode=strict_mode)

        new_task = cls(
            task_metadata=metadata, task_inputs=inputs, task_outputs=outputs
        )
        logger.debug("Successfully loaded task from directory: %s", src)
        return new_task

    def write_input_atoms(self, dest: str | Path) -> Path | None:
        """Write the input atoms to a file.

        Args:
            dest: The directory in which to write the Atoms file.

        Returns:
            The filename in which the Atoms were written, or None if there
            were no input atoms to write.
        """
        logger.debug("Writing input atoms to directory: %s", dest)

        # NOTE(review): this is a truthiness check; an ``Atoms`` object with
        # zero atoms is presumably falsy and would also be skipped. Confirm
        # that this is intended.
        if not self.task_inputs.atoms:
            logger.debug("No input atoms to write")
            return None

        atoms_file = Path(dest, self.task_inputs.atoms_filename)
        ase.io.write(filename=atoms_file, images=self.task_inputs.atoms)
        # Report success only after the write has actually happened.
        logger.debug("Successfully wrote input atoms to file: %s", atoms_file)
        return atoms_file

    def write_inputs_json(
        self,
        dest: str | Path,
        *,
        additional_data: dict[str, Any] | None = None,
    ) -> Path:
        """Write the inputs JSON to a file.

        The task inputs are stored under the ``"task_inputs"`` key. ``atoms``
        is excluded; it is written separately by :meth:`write_input_atoms`.
        Entries of ``additional_data`` are merged in at the top level.

        Args:
            dest: The directory in which to write the inputs JSON.
            additional_data: A dictionary mapping strings to JSON-serializable
                values to be merged with the task inputs that will be written
                to the inputs JSON. Defaults to an empty dictionary.

        Returns:
            The filename in which the inputs JSON was written.
        """
        logger.debug("Writing inputs JSON to directory: %s", dest)
        additional_data = additional_data or {}
        inputs_json_data = {
            "task_inputs": self.task_inputs.model_dump(
                mode="json", exclude={"atoms"}
            ),
            **additional_data,
        }
        inputs_json = Path(dest, SETTINGS.INPUTS_FILE)

        with inputs_json.open(mode="w", encoding="utf-8") as file:
            json.dump(inputs_json_data, file, indent=4)

        logger.debug("Successfully wrote input json to file: %s", inputs_json)
        return inputs_json

    def write_metadata(self, dest: str | Path) -> Path:
        """Write the task metadata to a file.

        Args:
            dest: The directory in which to write the task metadata.

        Returns:
            The filename in which the task metadata was written.
        """
        logger.debug("Writing task metadata to directory: %s", dest)
        task_metadata = self.task_metadata.model_dump(mode="json")
        metadata = Path(dest, SETTINGS.TASK_METADATA_FILE)

        with metadata.open(mode="w", encoding="utf-8") as file:
            json.dump(task_metadata, file, indent=4)

        logger.debug("Successfully wrote task metadata to file: %s", metadata)
        return metadata

    def write_task_script(
        self,
        dest: str | Path,
        *,
        additional_data: dict[str, Any] | None = None,
    ) -> Path:
        """Write the SLURM input script using the template given.

        The template is rendered with a context made of the task's
        ``model_dump()`` plus a ``"settings"`` entry. Any ``additional_data``
        entries are merged on top and take precedence.

        Args:
            dest: The directory in which to write the SLURM file.
            additional_data: A dictionary mapping strings to JSON-serializable
                values to be merged with the task inputs that will be written
                to the task script. Defaults to an empty dictionary.

        Returns:
            A Path representing the filename of the written SLURM script.
        """
        logger.debug("Writing task script to directory: %s", dest)
        context = {**self.model_dump(), "settings": SETTINGS.model_dump()}
        context |= additional_data or {}
        filename = Path(dest, self.task_inputs.task_script)
        template_script(
            dest=filename,
            script_template=self.task_inputs.task_script_template,
            context=context,
        )
        logger.debug("Successfully wrote task script to file: %s", filename)
        return filename

    def write_inputs(
        self,
        dest: str | Path,
        **kwargs,  # noqa: ARG002
    ) -> list[Path]:
        """Write required inputs for a task to a directory.

        Args:
            dest: The directory in which to write the task inputs.
            kwargs: Additional keyword arguments (currently unused).

        Returns:
            A list of input files written. The atoms file is included only if
            input atoms were written.
        """
        logger.debug(
            "Writing %s inputs to directory: %s", self.__class__.__name__, dest
        )
        atoms = self.write_input_atoms(dest)
        input_json = self.write_inputs_json(dest)
        task_metadata = self.write_metadata(dest)
        script = self.write_task_script(dest)
        files = [atoms] if atoms else []
        files = [*files, input_json, task_metadata, script]
        logger.debug(
            "Successfully wrote %s inputs to directory: %s",
            self.__class__.__name__,
            dest,
        )
        return files