"""Create studies."""
from __future__ import annotations
from datetime import UTC
from datetime import datetime
from itertools import groupby
import json
import logging
from pathlib import Path
import shutil
from tempfile import TemporaryDirectory
from typing import Any
from typing import ClassVar
from pydantic import UUID4
from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import Field
from pydantic import FieldSerializationInfo
from pydantic import SerializerFunctionWrapHandler
from pydantic import field_serializer
from autojob import SETTINGS
from autojob.bases.task_base import TaskBase
from autojob.next import add_item_to_parent
from autojob.next import create_task_group_tree
from autojob.next import create_task_tree
from autojob.tasks.task import Task
from autojob.utils.files import create_templated_dir_name
from autojob.utils.files import find_task_dirs
from autojob.utils.files import find_task_group_dirs
from autojob.utils.schemas import id_factory
logger = logging.getLogger(__name__)

# Names of study-level metadata keys.
# NOTE(review): "label" is listed here, but the Study model in this module
# declares ``name`` and has no ``label`` field -- confirm which is intended.
STUDY_FIELDS = ["label", "study_group_id", "study_id", "date_created"]
class Study(BaseModel):
    """A collection of tasks.

    Besides its tasks, a study stores its creation time, a study ID, a
    study group ID, a name and a list of notes. A study can be written to a
    directory tree with :meth:`to_directory` and rebuilt from one with
    :meth:`from_directory`.
    """

    # pydantic copies mutable field defaults for every instance, so the
    # ``default=[]`` values below are not shared between studies.
    tasks: list[TaskBase] = Field(
        default=[], description="A list of tasks in the study"
    )
    date_created: datetime = Field(
        default_factory=lambda: datetime.now(tz=UTC),
        description="The date and time that the study was created",
    )
    # IDs are validated as a UUID4 first and fall back to a plain string
    # (``union_mode="left_to_right"``); defaults come from ``id_factory``.
    study_id: UUID4 | str = Field(
        default_factory=id_factory("s"), union_mode="left_to_right"
    )
    study_group_id: UUID4 | str = Field(
        default_factory=id_factory("g"), union_mode="left_to_right"
    )
    name: str = Field(default="", description="The name of the study")
    notes: list[str] = Field(
        default=[], description="A list of notes about the study"
    )

    model_config: ClassVar = ConfigDict(populate_by_name=True)

    @field_serializer("tasks", mode="wrap", return_type=list[Task] | list[str])
    def serialize_tasks(
        self,
        v: Any,
        _: SerializerFunctionWrapHandler,
        info: FieldSerializationInfo,
    ) -> list[Task] | list[str]:
        """Serialize the tasks in the study.

        Args:
            v: The value of the ``tasks`` field.
            _: The wrapped default serializer. It is never called.
            info: Serialization context; only ``info.mode`` is read.

        Returns:
            In JSON mode, the task ID of every task as a string. In any
            other mode, ``v`` unchanged.
        """
        if info.mode == "json":
            return [str(t.task_metadata.task_id) for t in self.tasks]
        return v

    @field_serializer("date_created", when_used="json")
    def serialize_date_created(self, v: datetime) -> str:
        """Serialize the study creation date.

        Args:
            v: The creation date and time of the study.

        Returns:
            ``v`` formatted with :meth:`datetime.datetime.isoformat`.
        """
        return v.isoformat()

    @classmethod
    def from_directory(
        cls, dir_name: Path, *, strict_mode: bool | None = None
    ) -> Study:
        """Recreate a study from a directory.

        The study metadata is read from ``SETTINGS.STUDY_METADATA_FILE``
        inside ``dir_name``. Every task group directory found under
        ``dir_name`` whose metadata ``"task_group_id"`` is listed in the
        study's ``"task_groups"`` entry is then searched for task
        directories, and each one is loaded with
        :meth:`Task.from_directory`.

        Args:
            dir_name: The directory of the study.
            strict_mode: Whether or not to require all outputs. If True,
                errors will be thrown on missing outputs. Defaults to
                ``SETTINGS.STRICT_MODE``.

        Returns:
            The :class:`Study` contained in ``dir_name``.

        Raises:
            FileNotFoundError: If the study metadata file or a task group
                metadata file does not exist.
            KeyError: If the study metadata has no ``"task_groups"`` entry
                or a task group's metadata has no ``"task_group_id"``.
        """
        strict_mode = (
            SETTINGS.STRICT_MODE if strict_mode is None else strict_mode
        )

        metadata_file = dir_name.joinpath(SETTINGS.STUDY_METADATA_FILE)
        with metadata_file.open(mode="r", encoding="utf-8") as file:
            metadata: dict[str, Any] = json.load(file)

        # "task_groups" is not a model field, so it must be removed before
        # the metadata is passed to the constructor. A set gives constant
        # time membership tests while scanning the group directories.
        wanted_ids = set(metadata.pop("task_groups"))

        task_group_dirs: list[Path] = []
        for p in find_task_group_dirs(dir_name):
            tg_metadata_file = Path(p, SETTINGS.TASK_GROUP_METADATA_FILE)
            with tg_metadata_file.open(mode="r", encoding="utf-8") as file:
                task_group_metadata = json.load(file)
            if task_group_metadata["task_group_id"] in wanted_ids:
                task_group_dirs.append(p)

        tasks: list[Task] = []
        for source in task_group_dirs:
            # NOTE(review): the group directory is opened as-is above but
            # joined onto ``dir_name`` here. The two only agree if
            # ``find_task_group_dirs`` yields absolute paths (the join is
            # then a no-op) -- confirm against that helper.
            tasks.extend(
                Task.from_directory(
                    task_dir,
                    strict_mode=strict_mode,
                    magic_mode=True,
                )
                for task_dir in find_task_dirs(Path(dir_name, source))
            )

        metadata["tasks"] = tasks
        return cls(**metadata)

    def to_directory(
        self,
        dest: Path,
        *,
        study_template: str | None = None,
        task_group_template: str | None = None,
        task_template: str | None = None,
    ) -> Path:
        """Dump a study and its tasks to a directory.

        The whole tree is first built inside a temporary directory and only
        then copied into ``dest``, so errors raised while building the tree
        occur before anything is written to ``dest``. The copy uses
        ``dirs_exist_ok=True``, so an existing study directory of the same
        name is written into rather than rejected.

        Args:
            dest: The directory in which to dump the :class:`Study`.
            study_template: A template string for naming study directories.
                Defaults to None in which case the study ID will be used to
                create the directory.
            task_group_template: A template string for naming task group
                directories. Defaults to None in which case the task group ID
                will be used to create the directory.
            task_template: A template string for naming task directories.
                Defaults to None in which case the task ID will be used to
                create the directory.

        Returns:
            The study directory that was created inside ``dest``.
        """

        def tg_key(t: TaskBase) -> str:
            return str(t.task_metadata.task_group_id)

        # groupby only merges adjacent items, so sort by the same key first.
        # The sort is stable: tasks keep their relative order within a group.
        ordered = sorted(self.tasks, key=tg_key)
        ids_and_tgs = [
            (tg_id, list(members))
            for tg_id, members in groupby(ordered, key=tg_key)
        ]

        # Tasks are written to their own directories below; the study
        # metadata only records which task groups belong to the study.
        metadata = self.model_dump(mode="json", exclude={"tasks"})
        metadata["task_groups"] = [tg_id for tg_id, _ in ids_and_tgs]

        with TemporaryDirectory() as tmpdir:
            if study_template:
                dir_name = create_templated_dir_name(
                    study_template, Path(tmpdir), metadata
                )
            else:
                dir_name = str(self.study_id)

            study_path = Path(tmpdir, dir_name)
            study_path.mkdir()

            metadata_file = Path(study_path, SETTINGS.STUDY_METADATA_FILE)
            with metadata_file.open(mode="w", encoding="utf-8") as file:
                json.dump(metadata, file, indent=4)

            for _, task_group in ids_and_tgs:
                tg_dest = create_task_group_tree(
                    task=task_group[0],
                    dest=study_path,
                    name_template=task_group_template,
                )
                tg_metadata_file = Path(
                    tg_dest,
                    SETTINGS.TASK_GROUP_METADATA_FILE,
                )
                for i, task in enumerate(task_group):
                    create_task_tree(
                        task,
                        tg_dest,
                        name_template=task_template,
                    )
                    # The first task ID is added w/ create_task_group_tree
                    if i > 0:
                        add_item_to_parent(
                            str(task.task_metadata.task_id),
                            tg_metadata_file,
                            "tasks",
                        )

            # Copy while the temporary directory still exists.
            created = Path(dest, dir_name)
            shutil.copytree(
                study_path,
                created,
                dirs_exist_ok=True,
            )

        return created