Source code for autojob.parametrizations

"""Represent a reference to a variable."""

from collections.abc import Callable
from collections.abc import Iterable
from collections.abc import Mapping
from collections.abc import MutableMapping
from functools import reduce
import logging
from typing import TYPE_CHECKING
from typing import Annotated
from typing import Any
from typing import Generic
from typing import TypeVar

from pydantic import Field
from pydantic import GetCoreSchemaHandler
from pydantic import ValidatorFunctionWrapHandler
from pydantic_core import CoreSchema
from pydantic_core import core_schema

from autojob.utils.schemas import Unset

if TYPE_CHECKING:
    from autojob.bases.task_base import TaskBase

# Type of the value a ``VariableReference`` evaluates to; it parametrizes
# ``VariableReference`` and annotates the return of ``evaluate``.
_T = TypeVar("_T")
# Anything an attribute path can be resolved against: constrained to either a
# string-keyed mutable mapping or an arbitrary object.
_Referenceable = TypeVar("_Referenceable", MutableMapping[str, Any], object)

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# A non-empty list of strings (``min_length=1``); used in this module as the
# sequence of dictionary keys / attribute names leading to a value.
AttributePath = Annotated[list[str], Field(min_length=1)]
# A non-empty list of ``AttributePath`` s.
AttributePaths = Annotated[list[AttributePath], Field(min_length=1)]


# ! Only single-source VariableReferences are supported ATM
# ! There should be a check upon Workflow creation for circular
# ! references
class VariableReference(Generic[_T]):
    """A reference to a variable.

    Attributes:
        set_path: A list of strings indicating the path to the variable to
            be set.
        get_path: A list of strings indicating the path to the variable to
            be obtained.
        get_paths: A list of lists of strings each indicating a path to a
            variable to be obtained.
        constant: A value to be used to set the variable.
        composer: A function that takes in an ``AttributePath`` and
            ``AttributePaths`` and returns a value.

    Example:
        Evaluate the value of a VariableReference

        >>> from autojob.parametrizations import VariableReference
        >>> context = {
        ...     "a": {
        ...         "b": 4,
        ...     }
        ... }
        >>> ref = VariableReference(
        ...     set_path=["a"],
        ...     get_path=["a", "b"],
        ...     constant=4,
        ... )
        >>> ref.evaluate(context)
        4

    Example:
        Set a value using a VariableReference

        >>> from autojob.parametrizations import VariableReference
        >>> context = {
        ...     "a": {
        ...         "b": 4,
        ...     }
        ... }
        >>> ref = VariableReference(
        ...     set_path=["a"],
        ...     get_path=["a", "b"],
        ...     constant=4,
        ... )
        >>> class Object(object):
        ...     pass
        >>> target_object = Object()
        >>> target_dict = {}
        >>> ref.set_input_value(context, target_object)
        >>> target_object.a
        4
        >>> ref.set_input_value(context, target_dict)
        >>> target_dict["a"]
        4
    """

    def __init__(
        self,
        *,
        set_path: AttributePath,
        get_path: AttributePath | None = None,
        get_paths: AttributePaths | None = None,
        constant: Any = None,
        composer: Callable | None = None,
    ) -> None:
        """Instantiate a ``VariableReference``.

        Args:
            set_path: An ``AttributePath`` indicating the variable to set.
            get_path: An ``AttributePath`` indicating the source variable.
                Defaults to None.
            get_paths: A list of ``AttributePath`` s, each of which will be
                combined to the source variable. Defaults to None.
            constant: A constant value used to set the variable. Defaults to
                None.
            composer: A function that accepts the value of the source
                variable(s) and returns a value to be used to set the
                variable. Defaults to None.
        """
        self.set_path = set_path
        self.get_path = get_path
        self.get_paths = get_paths
        self.constant = constant
        self.composer = composer
        super().__init__()

    @classmethod
    def __get_pydantic_core_schema__(
        cls, source_type: Any, handler: GetCoreSchemaHandler
    ) -> CoreSchema:
        """Get a Pydantic schema.

        In Python mode the input must already be an instance of this class;
        each of its attributes is then validated (and reassigned) in turn
        against the schema generated for that attribute. In JSON mode the
        input is first validated as a typed dictionary, turned into a
        ``VariableReference`` and finally run through the Python-mode schema.

        Args:
            source_type: The type the schema is requested for (unused).
            handler: The Pydantic handler used to generate the schemas of the
                individual attributes.

        Returns:
            The core schema used to validate ``VariableReference`` s.
        """
        set_path_schema = handler.generate_schema(AttributePath)
        get_path_schema = handler.generate_schema(AttributePath | None)
        get_paths_schema = handler.generate_schema(AttributePaths | None)
        constant_schema = handler.generate_schema(Any)
        # ! composer not supported and Callable not JSON serializable
        # ! so just use None (instead of Callable | None)
        composer_schema = handler.generate_schema(None)

        def _attribute_validator(
            name: str,
        ) -> Callable[
            ["VariableReference[Any]", ValidatorFunctionWrapHandler],
            "VariableReference[Any]",
        ]:
            """Build a wrap validator that validates the attribute ``name``."""

            def _validate(
                v: "VariableReference[Any]",
                handler: ValidatorFunctionWrapHandler,
            ) -> "VariableReference[Any]":
                # The wrapped schema only sees the attribute value; the
                # reference itself is what moves on through the chain.
                setattr(v, name, handler(getattr(v, name)))
                return v

            return _validate

        def _from_json(data: Mapping[str, Any]) -> "VariableReference[Any]":
            """Build a reference from an already validated dictionary."""
            return VariableReference(
                set_path=data["set_path"],
                get_path=data.get("get_path", None),
                get_paths=data.get("get_paths", None),
                constant=data.get("constant", None),
                composer=data.get("composer", None),
            )

        attribute_schemas = (
            ("set_path", set_path_schema),
            ("get_path", get_path_schema),
            ("get_paths", get_paths_schema),
            ("constant", constant_schema),
            ("composer", composer_schema),
        )
        python_schema = core_schema.chain_schema(
            [
                core_schema.is_instance_schema(cls),
                *(
                    core_schema.no_info_wrap_validator_function(
                        _attribute_validator(name), schema
                    )
                    for name, schema in attribute_schemas
                ),
            ]
        )

        return core_schema.json_or_python_schema(
            json_schema=core_schema.chain_schema(
                [
                    core_schema.typed_dict_schema(
                        {
                            "set_path": core_schema.typed_dict_field(
                                set_path_schema
                            ),
                            # Must be the nullable schema: ``get_path`` is
                            # optional and an explicit ``null`` is valid.
                            "get_path": core_schema.typed_dict_field(
                                get_path_schema,
                                required=False,
                            ),
                            "get_paths": core_schema.typed_dict_field(
                                get_paths_schema,
                                required=False,
                            ),
                            # ! Use default JSON caster
                            "constant": core_schema.typed_dict_field(
                                constant_schema,
                                required=False,
                            ),
                            "composer": core_schema.typed_dict_field(
                                composer_schema,
                                required=False,
                            ),
                        }
                    ),
                    core_schema.no_info_before_validator_function(
                        _from_json,
                        python_schema,
                    ),
                ]
            ),
            python_schema=python_schema,
        )

    def evaluate(self, context: _Referenceable) -> _T:
        """Evaluate a variable reference in the given context.

        If ``get_path`` is set, the value is looked up in ``context`` with
        ``getattrpath``; otherwise ``constant`` is returned.

        Args:
            context: A dictionary (or object) containing values to be used to
                evaluate the ``VariableReference``.

        Raises:
            NotImplementedError: ``get_paths`` and ``composer``
                ``VariableReference`` s are not supported.

        Returns:
            The value.
        """
        if self.get_path is not None:
            value: _T = getattrpath(context, self.get_path)
            return value

        if self.get_paths is not None or self.composer is not None:
            msg = "Multiple get paths and composers are not yet implemented"
            raise NotImplementedError(msg)

        return self.constant

    def set_input_value(
        self, context: dict[str, Any], shell: _Referenceable
    ) -> None:
        """Set the value of a key specified by the ``VariableReference``.

        This method modifies ``shell`` in place. The container holding the
        variable is located by following all but the last element of
        ``set_path``; the last element names the key (for mappings) or
        attribute (for other objects) to set. If the evaluated value equals
        ``Unset``, the key or attribute is deleted instead.

        Args:
            context: A dictionary containing values to be used to evaluate the
                ``VariableReference``.
            shell: A dictionary or object containing values to be set.
        """
        to_set = self.set_path[-1]
        to_get = getattrpath(shell, self.set_path[:-1])
        value = self.evaluate(context)
        is_mapping = isinstance(to_get, Mapping)

        # NOTE(review): equality (not identity) comparison with ``Unset`` is
        # kept as-is; confirm it is safe for every kind of value stored here.
        if value == Unset:
            logger.info(f"Unsetting value: {to_set}")
            if is_mapping:
                del to_get[to_set]
            else:
                delattr(to_get, to_set)
            return

        logger.info(f"Setting value: {to_set} to: {value}")
        if is_mapping:
            to_get[to_set] = value
        else:
            setattr(to_get, to_set, value)
def getattrpath(obj: _Referenceable, path: Iterable[str]) -> Any:
    """Follow an attribute path through nested mappings and objects.

    Every element of ``path`` is resolved against the value reached so far:
    mappings are looked up by key, anything else by attribute access. A key
    that is missing from a mapping is first inserted with an empty
    dictionary as its value, so walking a path fills in any missing
    mappings along the way (the mapping is modified in place).

    Args:
        obj: A dictionary or object.
        path: A iterable of strings indicating the sequence of attributes or
            dictionary keys pointing to the value to get.

    Returns:
        The attribute or dictionary value. An empty ``path`` returns ``obj``
        itself.
    """
    current: Any = obj
    for step in path:
        if not isinstance(current, Mapping):
            current = getattr(current, step)
            continue
        # Create the missing entry so deeper steps have something to walk.
        if step not in current:
            current[step] = {}
        current = current.get(step)
    return current
def _find_targets_and_sources(
    task: "TaskBase",
) -> list[tuple[list[str], dict[str, Any]]]:
    """Collect the top-level ``*_inputs`` fields of a task.

    The task is dumped with ``model_dump(exclude_none=True)`` and every
    top-level entry whose name ends in ``_inputs`` is kept.

    Args:
        task: The task whose inputs are to be collected.

    Returns:
        A list of ``(target, source)`` tuples in dump order, where ``target``
        is a one-element path holding the field name and ``source`` is the
        dumped value of that field.
    """
    dumped = task.model_dump(exclude_none=True)
    return [
        ([field_name], field_value)
        for field_name, field_value in dumped.items()
        if field_name.endswith("_inputs")
    ]


# TODO: Add support for task_mods, opt_mods, and analysis_mods
def create_parametrization(
    previous: "TaskBase",
    calc_mods: dict[str, Any] | None = None,
    sched_mods: dict[str, Any] | None = None,
    exclude_metadata: Iterable[str] | None = None,
) -> list[VariableReference[Any]]:
    """Create a parametrization from parameter mods and a previous task.

    Args:
        previous: A :class:`.calculation.Calculation` representing the
            previous calculation.
        calc_mods: A dictionary containing modifications to calculator
            **parameters**. Defaults to an empty dictionary.
        sched_mods: A dictionary containing modifications to scheduler inputs.
            Defaults to an empty dictionary.
        exclude_metadata: A list of metadata fields to exclude from the
            parametrization.

    Returns:
        A list of ``VariableReference`` s that can be used to set the values
        of the new calculation.

    Warning:
        When specifying `sched_mods`, be wary of setting mutually exclusive
        scheduler parameters (e.g, `mem` and `mem_per_cpu` or `cores` and
        `cores_per_node`). For example, if the `mem` parameter is set and one
        wants to set the `mem_per_cpu` parameter, set the `mem` key to
        `Unset` in `sched_mods` in addition to setting the `mem_per_cpu` key.

    Note:
        When setting the input parameters for the restart task, this function
        will assume that any top-level task attribute suffixed with `_inputs`
        is an input. For example, when restarting a :class:`.Calculation`,
        :attr:`.Calculation.task_inputs`,
        :attr:`.Calculation.calculation_inputs`, and
        :attr:`.Calculation.scheduler_inputs` will be carried over. When
        restarting a :class:`.MolecularDynamics` task,
        :attr:`.Calculation.md_inputs` will be carried over in addition to
        the three aforementioned inputs.
    """
    if not calc_mods:
        calc_mods = {}
    if not sched_mods:
        sched_mods = {}

    excluded = set(exclude_metadata) if exclude_metadata else None
    metadata = previous.task_metadata.model_dump(exclude=excluded)

    # Order matters: a later VariableReference overwrites an earlier one, so
    # the explicit mods go after the values carried over from ``previous``.
    targets_and_sources = [
        (["task_metadata"], metadata),
        *_find_targets_and_sources(previous),
    ]

    if hasattr(previous, "calculation_inputs"):
        # calc_mods only modify calc_params
        targets_and_sources.append(
            (["calculation_inputs", "calc_params"], calc_mods)
        )

    if hasattr(previous, "scheduler_inputs"):
        targets_and_sources.append((["scheduler_inputs"], sched_mods))

    return [
        VariableReference(set_path=[*target, key], constant=value)
        for target, source in targets_and_sources
        for key, value in source.items()
    ]