"""Represent a reference to a variable."""
from collections.abc import Callable
from collections.abc import Iterable
from collections.abc import Mapping
from collections.abc import MutableMapping
from functools import reduce
import logging
from typing import TYPE_CHECKING
from typing import Annotated
from typing import Any
from typing import Generic
from typing import TypeVar
from pydantic import Field
from pydantic import GetCoreSchemaHandler
from pydantic import ValidatorFunctionWrapHandler
from pydantic_core import CoreSchema
from pydantic_core import core_schema
from autojob.utils.schemas import Unset
if TYPE_CHECKING:
from autojob.bases.task_base import TaskBase
# Type of the value that a ``VariableReference`` evaluates to.
_T = TypeVar("_T")
# A context that values can be read from: either a string-keyed mutable
# mapping or an arbitrary object whose attributes are accessed by name.
_Referenceable = TypeVar("_Referenceable", MutableMapping[str, Any], object)
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# A non-empty list of strings; each string is one step (attribute name or
# dictionary key) on the way to a value.
AttributePath = Annotated[list[str], Field(min_length=1)]
# A non-empty list of ``AttributePath`` s.
AttributePaths = Annotated[list[AttributePath], Field(min_length=1)]
# ! Only single-source VariableReferences are supported at the moment
# ! There should be a check upon Workflow creation for circular
# ! references
class VariableReference(Generic[_T]):
    """A reference to a variable.

    A ``VariableReference`` pairs a destination (``set_path``) with the
    source of the value to put there: either a path into a context
    (``get_path``) or a literal value (``constant``). When ``get_path`` is
    given it takes precedence over ``constant``.

    Attributes:
        set_path: A list of strings indicating the path to the variable
            to be set.
        get_path: A list of strings indicating the path to the variable
            to be obtained.
        get_paths: A list of lists of strings each indicating a path to
            a variable to be obtained. Not yet supported by
            :meth:`evaluate`.
        constant: A value to be used to set the variable.
        composer: A function that takes in an ``AttributePath`` and
            ``AttributePaths`` and returns a value. Not yet supported by
            :meth:`evaluate`.

    Example: Evaluate the value of a VariableReference

        >>> from autojob.parametrizations import VariableReference
        >>> context = {
        ...     "a": {
        ...         "b": 4,
        ...     }
        ... }
        >>> ref = VariableReference(
        ...     set_path=["a"],
        ...     get_path=["a", "b"],
        ...     constant=4,
        ... )
        >>> ref.evaluate(context)
        4

    Example: Set a value using a VariableReference

        NOTE(review): this example calls ``set_input_value``, which is not
        defined in this class body -- confirm where it is provided before
        relying on the example as a doctest.

        >>> from autojob.parametrizations import VariableReference
        >>> context = {
        ...     "a": {
        ...         "b": 4,
        ...     }
        ... }
        >>> ref = VariableReference(
        ...     set_path=["a"],
        ...     get_path=["a", "b"],
        ...     constant=4,
        ... )
        >>> class Object(object):
        ...     pass
        >>> target_object = Object()
        >>> target_dict = {}
        >>> ref.set_input_value(context, target_object)
        >>> target_object.a
        4
        >>> ref.set_input_value(context, target_dict)
        >>> target_dict["a"]
        4
    """

    def __init__(
        self,
        *,
        set_path: AttributePath,
        get_path: AttributePath | None = None,
        get_paths: AttributePaths | None = None,
        constant: Any = None,
        composer: Callable | None = None,
    ) -> None:
        """Instantiate a ``VariableReference``.

        Args:
            set_path: An ``AttributePath`` indicating the variable to set.
            get_path: An ``AttributePath`` indicating the source variable.
                Defaults to None.
            get_paths: A list of ``AttributePath`` s, each of which will be
                combined to the source variable. Defaults to None.
            constant: A constant value used to set the variable. Defaults to
                None.
            composer: A function that accepts the value of the source
                variable(s) and returns a value to be used to set the
                variable. Defaults to None.
        """
        self.set_path = set_path
        self.get_path = get_path
        self.get_paths = get_paths
        self.constant = constant
        self.composer = composer
        super().__init__()

    @classmethod
    def __get_pydantic_core_schema__(
        cls, source_type: Any, handler: GetCoreSchemaHandler
    ) -> CoreSchema:
        """Get a Pydantic schema.

        In Python mode the input must already be an instance of this class;
        each of its attributes is then validated (and replaced by the
        validated value) in turn. In JSON mode the input is validated as a
        typed dictionary, converted into an instance, and finally passed
        through the same Python-mode chain.

        Args:
            source_type: The type for which the schema is requested. Unused.
            handler: The Pydantic handler used to generate the schemas of
                the individual attributes.

        Returns:
            The core schema describing how to validate a
            ``VariableReference``.
        """
        set_path_schema = handler.generate_schema(AttributePath)
        get_path_schema = handler.generate_schema(AttributePath | None)
        get_paths_schema = handler.generate_schema(AttributePaths | None)
        constant_schema = handler.generate_schema(Any)
        # ! composer not supported and Callable not JSON serializable
        # ! so just use None (instead of Callable | None)
        composer_schema = handler.generate_schema(None)

        # Each wrap validator below validates a single attribute of an
        # existing instance and writes the validated value back in place.
        def _set_path(
            v: VariableReference[Any], handler: ValidatorFunctionWrapHandler
        ) -> VariableReference[Any]:
            v.set_path = handler(v.set_path)
            return v

        def _get_path(
            v: VariableReference[Any], handler: ValidatorFunctionWrapHandler
        ) -> VariableReference[Any]:
            v.get_path = (
                handler(v.get_path) if v.get_path is not None else v.get_path
            )
            return v

        def _get_paths(
            v: VariableReference[Any], handler: ValidatorFunctionWrapHandler
        ) -> VariableReference[Any]:
            v.get_paths = handler(v.get_paths)
            return v

        def _constant(
            v: VariableReference[Any], handler: ValidatorFunctionWrapHandler
        ) -> VariableReference[Any]:
            v.constant = handler(v.constant)
            return v

        def _composer(
            v: VariableReference[Any], handler: ValidatorFunctionWrapHandler
        ) -> VariableReference[Any]:
            v.composer = handler(v.composer)
            return v

        def _from_json(data: dict[str, Any]) -> "VariableReference[Any]":
            # Build the instance from ``cls`` so that the instance check at
            # the head of ``python_schema`` also passes for subclasses.
            return cls(
                set_path=data["set_path"],
                get_path=data.get("get_path", None),
                get_paths=data.get("get_paths", None),
                constant=data.get("constant", None),
                composer=data.get("composer", None),
            )

        python_schema = core_schema.chain_schema(
            [
                core_schema.is_instance_schema(cls),
                core_schema.no_info_wrap_validator_function(
                    _set_path, set_path_schema
                ),
                core_schema.no_info_wrap_validator_function(
                    _get_path, get_path_schema
                ),
                core_schema.no_info_wrap_validator_function(
                    _get_paths, get_paths_schema
                ),
                core_schema.no_info_wrap_validator_function(
                    _constant, constant_schema
                ),
                core_schema.no_info_wrap_validator_function(
                    _composer, composer_schema
                ),
            ]
        )
        json_fields = {
            "set_path": core_schema.typed_dict_field(set_path_schema),
            # Use the optional schema (not ``set_path_schema``) so that an
            # explicit ``null`` is accepted, as it is in Python mode.
            "get_path": core_schema.typed_dict_field(
                get_path_schema,
                required=False,
            ),
            "get_paths": core_schema.typed_dict_field(
                get_paths_schema,
                required=False,
            ),
            # ! Use default JSON caster
            "constant": core_schema.typed_dict_field(
                constant_schema,
                required=False,
            ),
            "composer": core_schema.typed_dict_field(
                composer_schema,
                required=False,
            ),
        }
        return core_schema.json_or_python_schema(
            json_schema=core_schema.chain_schema(
                [
                    core_schema.typed_dict_schema(json_fields),
                    core_schema.no_info_before_validator_function(
                        _from_json,
                        python_schema,
                    ),
                ]
            ),
            python_schema=python_schema,
        )

    def evaluate(self, context: _Referenceable) -> _T:
        """Evaluate a variable reference in the given context.

        If ``get_path`` is set, the value is looked up in ``context``;
        otherwise ``constant`` is returned.

        Args:
            context: A dictionary (or object) containing values to be used to
                evaluate the ``VariableReference``.

        Raises:
            NotImplementedError: ``get_path`` is None and ``get_paths`` or
                ``composer`` is set; such ``VariableReference`` s are not
                supported.

        Returns:
            The value.
        """
        if self.get_path is not None:
            value: _T = getattrpath(
                context,
                self.get_path,
            )
        elif any(x is not None for x in (self.get_paths, self.composer)):
            msg = "Multiple get paths and composers are not yet implemented"
            raise NotImplementedError(msg)
        else:
            value = self.constant
        return value
[docs]
def getattrpath(obj: _Referenceable, path: Iterable[str]) -> Any:
    """Follow an attribute path through nested mappings and objects.

    Each element of ``path`` is resolved against the result of the previous
    one, starting from ``obj``. A mapping is indexed by key; anything else is
    accessed with ``getattr``.

    Note:
        When a key is missing from a mapping, an empty dictionary is stored
        under that key and then returned, so a lookup can modify ``obj``.

    Args:
        obj: A dictionary or object.
        path: An iterable of strings indicating the sequence of attributes or
            dictionary keys pointing to the value to get.

    Returns:
        The attribute or dictionary value. An empty ``path`` returns ``obj``
        itself.
    """
    current: Any = obj
    for step in path:
        if not isinstance(current, Mapping):
            current = getattr(current, step)
            continue
        # Missing keys are filled with an empty dictionary before the read.
        if step not in current:
            current[step] = {}
        current = current.get(step)
    return current
def _find_targets_and_sources(
    task: "TaskBase",
) -> list[tuple[list[str], dict[str, Any]]]:
    """Collect the ``*_inputs`` fields of a task as (target, source) pairs.

    Args:
        task: The task whose dumped fields (``None`` values excluded) are
            inspected.

    Returns:
        One ``([field_name], field_value)`` tuple for every dumped field
        whose name ends with ``_inputs``, in dump order.
    """
    dumped = task.model_dump(exclude_none=True)
    return [
        ([field], contents)
        for field, contents in dumped.items()
        if field.endswith("_inputs")
    ]
# TODO: Add support for task_mods, opt_mods, and analysis_mods
def create_parametrization(
    previous: "TaskBase",
    calc_mods: dict[str, Any] | None = None,
    sched_mods: dict[str, Any] | None = None,
    exclude_metadata: Iterable[str] | None = None,
) -> list[VariableReference[Any]]:
    """Create a parametrization from parameter mods and a previous task.

    Every entry of the previous task's metadata and of its ``*_inputs``
    attributes becomes a constant ``VariableReference``. The entries of
    ``calc_mods`` and ``sched_mods`` are appended afterwards so that they
    override the carried-over values.

    Args:
        previous: A :class:`.calculation.Calculation` representing the
            previous calculation.
        calc_mods: A dictionary containing modifications to calculator
            **parameters**. Defaults to an empty dictionary.
        sched_mods: A dictionary containing modifications to scheduler
            inputs. Defaults to an empty dictionary.
        exclude_metadata: A list of metadata fields to exclude from the
            parametrization.

    Returns:
        A list of ``VariableReference`` s that can be used to set the values
        of the new calculation.

    Warning:
        When specifying `sched_mods`, be wary of setting mutually exclusive
        scheduler parameters (e.g., `mem` and `mem_per_cpu` or `cores` and
        `cores_per_node`). For example, if the `mem` parameter is set and one
        wants to set the `mem_per_cpu` parameter, set the `mem` key to `Unset`
        in `sched_mods` in addition to setting the `mem_per_cpu` key.

    Note:
        When setting the input parameters for the restart task, this function
        will assume that any top-level task attribute suffixed with `_inputs`
        is an input. For example, when restarting a :class:`.Calculation`,
        :attr:`.Calculation.task_inputs`,
        :attr:`.Calculation.calculation_inputs`, and
        :attr:`.Calculation.scheduler_inputs` will be carried over. When
        restarting a :class:`.MolecularDynamics` task,
        :attr:`.Calculation.md_inputs` will be carried over in addition
        to the three aforementioned inputs.
    """
    excluded = set(exclude_metadata) if exclude_metadata else None
    carried_metadata = previous.task_metadata.model_dump(exclude=excluded)

    # Layer order matters: references produced by later layers overwrite
    # those produced by earlier ones when they are applied.
    layers = [(["task_metadata"], carried_metadata)]
    layers += _find_targets_and_sources(previous)
    if hasattr(previous, "calculation_inputs"):
        # calc_mods only modify calc_params
        layers.append((["calculation_inputs", "calc_params"], calc_mods or {}))
    if hasattr(previous, "scheduler_inputs"):
        layers.append((["scheduler_inputs"], sched_mods or {}))

    parametrization: list[VariableReference] = []
    for prefix, values in layers:
        parametrization.extend(
            VariableReference(set_path=[*prefix, key], constant=constant)
            for key, constant in values.items()
        )
    return parametrization