Source code for autojob.task
"""Represent and model the results of a task."""
from datetime import UTC
from datetime import datetime
from enum import StrEnum
from enum import unique
import importlib
import json
import logging
from pathlib import Path
import re
from typing import TYPE_CHECKING
from typing import Any
from typing import ClassVar
from typing import Self
from typing import TextIO
from uuid import UUID
from uuid import uuid4
import warnings
from ase import Atoms
import ase.io
import jinja2
from pydantic import UUID4
from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import Field
from pydantic import ImportString
from pydantic import PrivateAttr
from pydantic import SerializationInfo
from pydantic import TypeAdapter
from pydantic import ValidationError
from pydantic import ValidationInfo
from pydantic import ValidatorFunctionWrapHandler
from pydantic import field_serializer
from pydantic import field_validator
from pydantic import model_validator
from pymatgen.entries.computed_entries import ComputedEntry
from autojob import SETTINGS
from autojob import hpc
from autojob.calculation.parameters import CalculatorType
from autojob.coordinator.classification import CalculationType
from autojob.legacy import StudyType
from autojob.utils.files import extract_structure_name
from autojob.utils.files import get_loader
from autojob.utils.files import get_uri
from autojob.utils.schemas import PydanticAtoms
if TYPE_CHECKING:
from ase.calculators.calculator import Calculator
from pydantic.main import IncEx
logger = logging.getLogger(__name__)
LEGACY_TASK_ID_LENGTH = 10
[docs]
@unique
class TaskOutcome(StrEnum):
"""The state of a task."""
SUCCESS = "successful"
FAILED = "failed"
ERROR = "error"
RUNNING = "running"
IDLE = "idle"
[docs]
class TaskMetadata(BaseModel):
"""The metadata for a task."""
model_config = ConfigDict(populate_by_name=True, extra="allow")
label: str = Field(
default="", description="A description of the job", alias="Name"
)
tags: list[str] = Field(
default=[],
title="tag",
description="Metadata tagged to a given job",
alias="Notes",
)
uri: str | None = Field(
default=None,
description="The uri for the directory containing this task",
)
study_group_id: UUID4 | str | None = Field(
default=None,
description="The study group uuid",
alias="Study Group ID",
union_mode="left_to_right",
)
study_id: UUID4 | str | None = Field(
default=None,
description="The study uuid",
alias="Study ID",
union_mode="left_to_right",
)
workflow_step_id: UUID4 | None = Field(
default=None, description="The workflow step uuid"
)
task_id: UUID4 | str = Field(
default=uuid4(),
description="The task uuid.",
alias="Job ID",
union_mode="left_to_right",
)
calculation_id: str | None = Field(
default=None,
description="The Calculation uuid (for backwards-compatibility)",
alias="Calculation ID",
)
calculation_type: CalculationType | None = Field(
default=None,
description="The Calculation type (for backwards-compatibility)",
alias="Calculation Type",
)
calculator_type: CalculatorType | None = Field(
default=None,
description="The Calculator type (for backwards-compatibility)",
alias="Calculator Type",
)
study_type: StudyType | None = Field(
default=None,
description="The study type (for backwards-compatibility)",
alias="Study Type",
)
last_updated: datetime | None = Field(
default=None,
description="Timestamp for the most recent calculation for this task "
"document",
)
# A special key used to determine how to instantiate subclasses.
_build_class: ImportString["type[Task]"] | None = PrivateAttr(
None,
)
[docs]
@model_validator(mode="after")
def add_build_class(self) -> "TaskMetadata":
"""Add a build class to a constructed TaskMetadata object.
Note that this is for backwards-compatibility with Tasks
created with ``calculation_type`` set and will be removed in future
releases. ``_build_class`` can be set directly during
instantiation.
.. deprecated: 0.12.0
"""
v = self.__pydantic_extra__.pop("_build_class", None)
if v is not None:
if isinstance(v, str):
builder = TypeAdapter(
ImportString["type[Task]"]
).validate_json(v)
else:
builder = TypeAdapter(
ImportString["type[Task]"]
).validate_python(v)
self._build_class = builder
elif self.calculation_type and v is None:
match self.calculation_type:
case CalculationType.RELAXATION:
builder = "calculation"
case CalculationType.VIB:
builder = "vibration"
case _:
warnings.warn(
"No build class conversion defined for "
"calculation type {self.calculation_type!s}. The "
"default, Task, will be used.",
stacklevel=2,
)
builder = "task"
module = importlib.import_module(f"autojob.calculation.{builder}")
self._build_class = getattr(module, builder.capitalize())
return self
[docs]
@field_validator("study_group_id", "study_id", "task_id", mode="wrap")
@classmethod
def validate_ids(
cls,
v: Any,
handler: ValidatorFunctionWrapHandler,
info: ValidationInfo, # noqa: ARG003
) -> str | UUID4:
"""Validate an ID."""
value = v
try:
value = handler(v)
if isinstance(value, UUID):
return value
except ValidationError:
pass
if (
isinstance(value, str)
and len(value) == LEGACY_TASK_ID_LENGTH
and v.isalnum()
):
return value
msg = f"{v} is not a UUID4 or a 10-digit alphanumeric shortuuid string"
raise ValueError(msg)
[docs]
@field_validator("tags", mode="wrap")
@classmethod
def validate_tags(
cls,
v: Any,
handler: ValidatorFunctionWrapHandler,
info: ValidationInfo, # noqa: ARG003
) -> list[str]:
"""Validate job notes/tags."""
try:
tags: list[str] = handler(v)
return tags
except ValidationError:
return [x.strip() for x in str(v).split(";")] if v else []
[docs]
@field_serializer(
"study_group_id",
"study_id",
"task_id",
"workflow_step_id",
mode="plain",
)
@staticmethod
def serialize_ids(
v: Any,
info: SerializationInfo, # noqa: ARG004
) -> str | None:
"""Serialize IDs."""
return v if v is None else str(v)
[docs]
@field_serializer(
"last_updated",
mode="plain",
)
@staticmethod
def serialize_last_updated(
v: Any,
info: SerializationInfo, # noqa: ARG004
) -> str | None:
"""Serialize the last updated time."""
return v if v is None else str(v)
[docs]
@field_serializer(
"calculation_type",
"calculator_type",
"study_type",
mode="plain",
)
@staticmethod
def serialize_types(
v: Any,
info: SerializationInfo, # noqa: ARG004
) -> str | None:
"""Serialize ``autojob`` types."""
return str(v) if v else None
[docs]
@field_serializer(
"tags",
mode="plain",
when_used="json",
)
@staticmethod
def serialize_tags(v: Any, info: SerializationInfo) -> str: # noqa: ARG004
"""Serialize tags."""
return "; ".join(v)
[docs]
@classmethod
def from_directory(cls, dir_name: str | Path) -> "TaskMetadata":
"""Create a TaskMetadata document from a task directory."""
logger.debug(f"Loading task metadata from {dir_name}")
task_file = Path(dir_name).joinpath(SETTINGS.JOB_FILE)
with task_file.open(mode="r", encoding="utf-8") as file:
raw_metadata: dict[str, Any] = json.load(file)
raw_metadata["uri"] = get_uri(dir_name=dir_name)
logger.debug(f"Successfully loaded task metadata from {dir_name}")
return cls(**raw_metadata)
[docs]
def model_dump_legacy(
self,
) -> dict[str, Any]:
"""Return a legacy calculation metadata dictionary."""
exclude: IncEx = {
"workflow_step_id",
"uri",
"calculator_type",
"last_updated",
"task_id",
}
model = self.model_dump(
exclude=exclude,
by_alias=True,
mode="json",
)
model["Date Created"] = datetime.now(tz=UTC).isoformat()
return model
class _TaskIODoc(BaseModel):
"""A base class for task input/output documents."""
atoms: PydanticAtoms | None = Field(
default=None, description="Input or output ase.Atoms"
)
[docs]
class TaskInputs(_TaskIODoc):
"""The set of task-level inputs."""
files_to_copy: list[str] = Field(
default=[],
description="The files to copy from the preceding task into the "
"scratch directory of this task.",
)
files_to_delete: list[str] = Field(
default=[],
description="The files to delete from the directory of the task after "
"job completion.",
)
files_to_carry_over: list[str] = Field(
default=[],
description="The files to carry over from the completed task to the "
"new job.",
)
auto_restart: bool = Field(
default=True,
description="Whether or not to automatically restart this calculation "
"with the same parameters if the task finishes unsuccessfully.",
)
model_config: ClassVar[ConfigDict] = ConfigDict(extra="allow")
# TODO: use SETTINGS.SLURM_SCRIPT
[docs]
@classmethod
def from_directory(cls, dir_name: str | Path) -> "TaskInputs":
"""Generate a TaskInputs document from a completed task's directory.
Note that this method retrieves task inputs from the first .sh file
in the `dir_name` returned by `Path.rglob`.
Args:
dir_name: The directory of a completed Task.
Returns:
A TaskInputs object.
"""
logger.debug(f"Loading task inputs from {dir_name}")
logger.warning(
"Note that it is currently not possible to retrieve the "
"`files_to_carryover` attribute from the directory of a completed "
"Task. Currently, `files_to_carryover` is set in "
"`Calculation.from_directory`."
)
with Path(dir_name, SETTINGS.SLURM_SCRIPT).open(
mode="r", encoding="utf-8"
) as file:
lines = file.readlines()
files_to_copy = TaskInputs.extract_files_to_copy(stream=lines)
files_to_delete = TaskInputs.extract_files_to_delete(stream=lines)
auto_restart = TaskInputs.check_auto_restart(stream=lines)
atoms = TaskInputs.get_input_atoms(dir_name=dir_name)
logger.debug(f"Successfully loaded task inputs from {dir_name}")
return cls(
atoms=atoms,
files_to_copy=files_to_copy,
files_to_delete=files_to_delete,
auto_restart=auto_restart,
)
[docs]
@staticmethod
def extract_files_to_copy(stream: TextIO | list[str]) -> list[str]:
"""Parse the slurm submission script for the calculation files to copy.
This method will parse the files to copy when they are listed in
legacy format (i.e., without assignment to a variable) or if they are
assigned to a variable named AUTOJOB_FILES_TO_COPY.
Args:
stream: A TextIO or list of strings containing the text in the
slurm submission script.
Returns:
A list of the strings passed as calculation files to copy.
"""
log_addendum = ""
if hasattr(stream, "name") and hasattr(stream, "tell"):
offset = stream.tell()
log_addendum = f" from {stream.name}"
logger.debug(f"Extracting files to copy{log_addendum}")
env_var_re1 = re.compile(r'AUTOJOB_FILES_TO_COPY="(?P<files>.*)"')
env_var_re2 = re.compile(r'AUTOJOB_COPY_TO_SCRATCH="(?P<files>.*)"')
legacy_delete_re = re.compile(
r'cp -v "\$SLURM_SUBMIT_DIR"/\{(?P<files>\S+)\}'
)
regexes = [env_var_re1, env_var_re2, legacy_delete_re]
files: list[str] = []
for line in stream:
for regex in regexes:
match = regex.match(line)
if match:
files = match.group("files").split(",")
break
if files:
break
if hasattr(stream, "seek"):
stream.seek(offset)
logger.debug(
f"Successfully extracted files to copy{log_addendum}\n"
f"Files: {files!r}"
)
return files
[docs]
@staticmethod
def extract_files_to_delete(stream: TextIO | list[str]) -> list[str]:
"""Parse the slurm submission script for the deleted calculation files.
This method will parse the files to delete when they are listed in
legacy format (i.e., without assignment to a variable) or if they are
assigned to a variable named AUTOJOB_FILES_TO_DELETE.
Args:
stream: A TextIO or list of strings containing the text in the
slurm submission script.
Returns:
A list of the strings passed as calculation files to delete.
"""
log_addendum = ""
if hasattr(stream, "name") and hasattr(stream, "tell"):
offset = stream.tell()
log_addendum = f" from {stream.name}"
logger.debug(f"Extracting files to delete{log_addendum}")
env_var_re = re.compile(r'AUTOJOB_FILES_TO_DELETE="(?P<files>.*)"')
file_delete_start_re = re.compile(r"^rm -vf (?P<files>[^\\]*)(.*)?$")
listing_files = False
files: list[str] = []
for line in stream:
match1 = env_var_re.match(line)
match2 = file_delete_start_re.match(line)
if match1:
files.extend(match1.group("files").split())
break
elif match2:
listing_files = True
files.extend(match2.group("files").split())
elif not line.startswith("sleep 10") and listing_files:
files.extend(line.rstrip("\\ \n").split())
elif listing_files:
break
if hasattr(stream, "seek"):
stream.seek(offset)
logger.debug(
f"Successfully extracted files to delete{log_addendum}\n"
f"Files: {files!r}"
)
return files
[docs]
@staticmethod
def check_auto_restart(stream: TextIO | list[str]) -> bool:
"""Determines if auto-restart was enabled during job submission.
Args:
stream: A TextIO or list[str] containing the lines of the slurm job
submission script.
Returns:
Whether or not auto-restart was enabled during job submission.
"""
log_addendum = ""
if hasattr(stream, "name") and hasattr(stream, "tell"):
offset = stream.tell()
log_addendum = f" in {stream.name}"
logger.debug(f"Checking if auto-restart enabled{log_addendum}")
auto_restart_enabled = False
conditional_re = re.compile(
r"^if \[(\[)? \$restart = true (\])?\]; then$"
)
advance_re = re.compile(r"^autojob advance")
for line in stream:
if conditional_re.match(line) or advance_re.match(line):
auto_restart_enabled = True
break
if hasattr(stream, "seek"):
stream.seek(offset)
adverb = "" if auto_restart_enabled else " not"
logger.debug(f"Auto-restart was{adverb} enabled{log_addendum}")
return auto_restart_enabled
[docs]
@staticmethod
def get_input_atoms(dir_name: str | Path) -> Atoms:
"""Retrieve an Atoms object representing the input structure.
Note that the filename used to identify the structure file is saved to
:attr:`Atoms.info` under the ``"structure"`` key.
Args:
dir_name: the directory containing the completed calculation
Returns:
An Atoms object.
"""
dir_name = Path(dir_name)
logger.debug(f"Retrieving input atoms from {dir_name}")
try:
python_script = dir_name.joinpath(SETTINGS.PYTHON_SCRIPT)
with python_script.open(mode="r", encoding="utf-8") as file:
filename = extract_structure_name(file)
except (FileNotFoundError, RuntimeError):
logger.warning(
f"Unable to extract structure filename from Python script: {python_script}"
)
filename = SETTINGS.INPUT_ATOMS
atoms = ase.io.read(dir_name.joinpath(filename))
if "structure" not in atoms.info:
atoms.info["structure"] = Path(filename).stem.removeprefix("./")
logger.debug(f"Successfully retrieved input atoms from {dir_name}")
return atoms
[docs]
class TaskOutputs(_TaskIODoc):
"""The set of task-level outputs."""
entry: ComputedEntry | None = Field(
default=None, description="The ComputedEntry from the task"
)
outcome: TaskOutcome = Field(
TaskOutcome.IDLE, description="The outcome of the task"
)
[docs]
@classmethod
def from_directory(
cls,
dir_name: str | Path,
*,
strict_mode: bool = SETTINGS.STRICT_MODE,
) -> "TaskOutputs":
"""Generate a TaskOutputs document from a completed task's directory.
Note that the `atoms` object may not be set if the task is incomplete.
In such a case, one may need to use a task-specific `get_output_atoms`
function (i.e., `Calculation.get_output_atoms`)
Args:
dir_name: The directory of a completed Task.
strict_mode: Whether or not to catch thrown errors. Errors will be
thrown if ``strict_mode=True``.
Returns:
A TaskOutputs object.
"""
dir_name = Path(dir_name)
logger.debug(f"Loading task outputs from {dir_name}")
logger.debug(f"Strict mode {'en' if strict_mode else 'dis'}abled")
outputs = {
"atoms": TaskOutputs.get_output_atoms(
dir_name=dir_name, strict_mode=strict_mode
)
}
# TODO: Change this to an external call to the scheduler executable
# TODO: e.g., sacct -j XXXXXXXX | grep ...
if dir_name.joinpath("scratch_dir").exists():
outputs["outcome"] = TaskOutcome.RUNNING
logger.debug(f"Successfully loaded task outputs from {dir_name}")
return cls(**outputs)
[docs]
@staticmethod
def get_output_atoms(
dir_name: str | Path,
*,
strict_mode: bool = SETTINGS.STRICT_MODE,
) -> Atoms | None:
"""Retrieve an Atoms object representing the output structure.
Args:
dir_name: The directory from which to retrieve the output
structure.
strict_mode: Whether or not to raise an error if reading the output
atoms file fails. Defaults to True.
Returns:
An Atoms object representing the output structure or None if no
Atoms object can be retrieved.
"""
structure = Path(dir_name).joinpath(SETTINGS.OUTPUT_ATOMS)
logger.debug(f"Retrieving output atoms from {structure}")
logger.debug(f"Strict mode {'en' if strict_mode else 'dis'}abled")
atoms: Atoms | None = None
try:
atoms = ase.io.read(structure, index=-1)
logger.debug(
f"Successfully retrieved output atoms from {structure}"
)
except (FileNotFoundError, AttributeError):
if strict_mode:
raise
msg = (
f"Unable to retrieve atoms from: {structure}.\n"
"File not found."
)
logger.warning(msg)
return atoms
[docs]
class Task(BaseModel):
"""Represent the result of a task."""
task_metadata: TaskMetadata
task_inputs: TaskInputs = Field(description="Task inputs")
task_outputs: TaskOutputs | None = Field(
default=None, description="Task outputs"
)
model_config = ConfigDict(extra="allow")
[docs]
@classmethod
def load_magic(
cls,
dir_name: str | Path,
*,
strict_mode: bool = True,
) -> Self:
"""Magically load the contents of a directory as a ``Task`` subclass.
Args:
dir_name: The directory from which to load the task.
strict_mode: Whether to raise errors thrown due to missing outputs
Defaults to True in which case errors will be thrown.
Raises:
RuntimeError: No build class specified in the task metadata. Only
raised if ``strict_mode`` is True.
Returns:
The loaded Task.
"""
logger.debug("Magic mode enabled")
logger.debug(f"Strict mode {'en' if strict_mode else 'dis'}abled")
metadata = TaskMetadata.from_directory(dir_name)
if metadata._build_class:
return metadata._build_class.from_directory(
dir_name, strict_mode=strict_mode
)
elif strict_mode:
msg = (
f"No build class provided for task in {dir_name!s}. "
"Unable to use magic mode"
)
raise RuntimeError(msg)
else:
msg = (
"No build class provided for task in %s. Unable to "
"use magic mode, so a %s will be created instead."
)
logger.warning(msg, dir_name, cls.__name__)
return cls.from_directory(dir_name=dir_name, strict_mode=strict_mode)
[docs]
@classmethod
def from_directory(
cls,
dir_name: str | Path,
*,
strict_mode: bool = SETTINGS.STRICT_MODE,
magic_mode: bool = False,
) -> Self:
"""Generate a Task document from a completed task's directory.
Args:
dir_name: The directory of a completed Task.
strict_mode: Whether or not to require all outputs. If True,
errors will be thrown on missing outputs. Defaults to
``SETTINGS.STRICT_MODE``.
magic_mode: Whether to defer the final object creation. If True,
the final object will be an instance of the class specified
by the `_build_class` attribute of the :class:`TaskMetadata`
object created. Otherwise, a :class:`Task` object will be
returned. Defaults to False.
Returns:
A :class:`Task` object.
Note:
The class indicated by ``_build_class`` should be a subclass of the
class from which this method is called as successive steps may
expect attributes and methods of the calling class to be present.
"""
logger.debug(f"Loading task from {dir_name}")
logger.debug(f"Strict mode {'en' if strict_mode else 'dis'}abled")
metadata = TaskMetadata.from_directory(dir_name=dir_name)
if magic_mode:
return cls.load_magic(dir_name=dir_name, strict_mode=strict_mode)
inputs = TaskInputs.from_directory(dir_name=dir_name)
outputs = TaskOutputs.from_directory(
dir_name=dir_name, strict_mode=strict_mode
)
new_task = cls(
task_metadata=metadata, task_inputs=inputs, task_outputs=outputs
)
logger.debug(f"Successfully loaded task from {dir_name}")
return new_task
[docs]
@staticmethod
def create_shell(context: dict[str, Any] | None = None) -> "Task":
"""Recursively create a minimal Task, a shell, for writing inputs.
Args:
context: A dictionary mapping strings to Task attributes that will
be used to populate the shell. Defaults to an empty dictionary.
Returns:
The minimal Task.
"""
context = context or {}
return Task(
task_metadata=context.get("task_metadata", TaskMetadata()),
task_inputs=context.get("task_inputs", TaskInputs()),
)
[docs]
def create_new_task_tree(
self,
root: Path,
create_legacy_dir: bool = False,
) -> Path:
"""Create the directory and parent directories of a new task.
Args:
root: A Path representing the root directory from which to create the
directory of the new task.
create_legacy_dir: Whether or not to create the legacy calculation
directory. Note that calculation ID must be specified to create
legacy directory.
Raises:
TypeError: The calculation ID is None and ``create_legacy_dir`` is True.
Returns:
A Path object representing the directory of the newly created task.
"""
if create_legacy_dir:
new_calc_id = self.task_metadata.calculation_id
if new_calc_id is None:
msg = "Calculation ID must be specified to create legacy directory."
raise TypeError(msg)
new_calc = Path(root).joinpath(new_calc_id)
new_task_parent = new_calc
else:
new_task_parent = root
new_task_id = str(self.task_metadata.task_id)
new_task = new_task_parent.joinpath(new_task_id)
new_task.mkdir(parents=True)
return new_task
[docs]
def to_directory(
self,
dst: str | Path,
*,
legacy_mode: bool = False,
) -> None:
"""Dump the results of a task to a directory.
Args:
dst: The directory in which to save the task results.
legacy_mode: Whether or not to use the legacy mode.
"""
logger.debug(f"Dumping Task to {dst}")
_ = self.write_inputs(dst)
filename = Path(dst).joinpath(SETTINGS.JOB_FILE)
if legacy_mode:
exclude: IncEx = {
"workflow_step_id",
"uri",
"last_updated",
}
else:
exclude = {
"calculation_id",
"calculator_type",
"last_updated",
"calculation_type",
"study_type",
}
model = self.task_metadata.model_dump(
exclude=exclude,
by_alias=True,
mode="json",
)
with filename.open(mode="w", encoding="utf-8") as file:
json.dump(model, file, indent=4)
logger.debug(f"Successfully dumped Task to {dst}")
[docs]
def patch_task(
self,
*,
output_atoms: Atoms,
converged: bool,
error: hpc.SchedulerError | None,
files_to_carry_over: list[str],
strict_mode: bool = SETTINGS.STRICT_MODE,
) -> None:
"""Patch Task attributes using Calculation values.
Note that this method modifies the Task in place. The following
attributes are patched:
- ``Task.task_outputs.atoms``: replaced with ``output_atoms``
- ``Task.task_inputs.files_to_carryover``: replaced with
``files_to_carry_over``
- ``Task.task_outputs.outcome``: set according to ``converged`` and
``error``
Args:
dir_name: The directory from which to source values.
output_atoms: An Atoms object representing the output geometry.
converged: Whether the Calculation is converged.
error: The hpc.SchedulerError from the calculation.
files_to_carry_over: The files to carry over from the previous
calculation.
strict_mode: Whether to raise an error if no output atoms found.
Defaults to True.
"""
if self.task_outputs is None:
if strict_mode:
msg = (
"Patching incomplete Tasks is not supported in strict_mode"
)
raise RuntimeError(msg)
logger.info(
"No task outputs to patch in task %s",
self.task_metadata.task_id,
)
return None
if self.task_outputs.atoms is None:
logger.debug("Patching output atoms")
self.task_outputs.atoms = output_atoms
if not self.task_inputs.files_to_carry_over:
logger.debug(
f"Patching files to carryover: {files_to_carry_over!r}"
)
self.task_inputs.files_to_carry_over = files_to_carry_over
if converged:
self.task_outputs.outcome = TaskOutcome.SUCCESS
elif error is not None:
self.task_outputs.outcome = TaskOutcome.ERROR
[docs]
def prepare_input_atoms(self) -> None:
"""Copy the final magnetic moments to initial magnetic moments.
This function modifies atoms in place. Note that if atoms were obtained
from a ``vasprun.xml`` via ``ase.io.read("vasprun.xml")``, no magnetic
moments will be read. In order to ensure continuity between runs, it is
a good idea to retain the ``WAVECAR`` between runs.
"""
logger.debug("Preparing atoms for next run.")
atoms = self.task_inputs.atoms
if atoms is None:
logger.info("No input atoms found.")
return None
calc: Calculator = self.task_inputs.atoms.calc
if calc is None:
logger.info("No calculator found.")
return None
magmoms = calc.results.get("magmoms", None)
if magmoms is None:
logger.info(
"No magnetic moments to copy found. Using the initial "
"magnetic moments: "
f"{self.task_inputs.atoms.get_initial_magnetic_moments()!r}"
)
return None
self.task_inputs.atoms.set_initial_magnetic_moments(magmoms)
logger.debug("Copied magnetic moments to initial magnetic moments")
[docs]
def write_inputs(
self,
dir_name: str | Path,
*,
run_script_template: str = SETTINGS.DEFAULT_TEMPLATE,
) -> list[Path]:
"""Write the required inputs for a Task to a directory.
If input atoms are present, they will be written to ``dir_name`` with
the filename found under the ``"filename"`` key in the ``Atoms.info``
dictionary if a value exists. Otherwise, the value of
``SETTINGS.INPUT_ATOMS`` will be used.
Args:
dir_name: The directory in which to write the inputs.
run_script_template: The template file to use. Defaults to
``SETTINGS.DEFAULT_TEMPLATE``.
Returns:
A list of Path objects where each Path represents the filename of
an input written to ``dir_name``.
"""
logger.debug(f"Writing {self.__class__} inputs to {dir_name}")
inputs: list[Path] = []
inputs.append(
self.write_script(
dir_name, run_script_template=run_script_template
)
)
if self.task_inputs.atoms is None:
logger.debug("No input atoms to write")
return inputs
atoms_filename = self.task_inputs.atoms.info.get("filename", None)
input_atoms = (
SETTINGS.INPUT_ATOMS if atoms_filename is None else atoms_filename
)
filename = Path(dir_name).joinpath(input_atoms)
self.task_inputs.atoms.write(filename)
inputs.append(filename)
logger.debug(
f"Successfully wrote {self.__class__} inputs to {dir_name}: "
f"{inputs!r}"
)
return inputs
[docs]
def write_script(
self,
dst: Path,
*,
run_script_template: str = SETTINGS.DEFAULT_TEMPLATE,
**kwargs,
) -> Path:
"""Write the input script using the template given.
Args:
dst: The directory in which to write the input script.
run_script_template: The template file to use. Defaults to
``SETTINGS.DEFAULT_TEMPLATE``.
**kwargs: additional keyword arguments to be used to
render the script template.
Returns:
A Path representing the filename of the written input script.
"""
env = jinja2.Environment(
loader=get_loader(),
autoescape=True,
trim_blocks=True,
lstrip_blocks=True,
keep_trailing_newline=True,
)
template = env.get_template(run_script_template)
filename = dst.joinpath(SETTINGS.SLURM_SCRIPT)
inputs = self.task_inputs.model_dump()
inputs.update(kwargs)
with filename.open(mode="x", encoding="utf-8") as file:
file.write(template.render(**inputs))
return filename
Task.model_rebuild()