# Source code for atomistics.calculators.wrapper

"""
Wrappers that turn a function evaluating a single structure into one that
evaluates a whole task dictionary of structures.
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Callable, cast

from ase import Atoms

from atomistics.calculators.interface import TaskEnum, TaskName, TaskOutputEnum

if TYPE_CHECKING:
    from atomistics.calculators.interface import ResultsDict, SimpleEvaluator, TaskDict


def _convert_task_dict(
    old_task_dict: dict[TaskName, dict[str, Atoms] | Atoms],
) -> TaskDict:
    """
    Converts the existing task dictionaries of the format
    `{result_type_string: {structure_label_string: structure, ...}, ...}`
    to the new format
    `{structure_label_string: (structure, [result_type_string, ...]), ...}`.

    Can be removed if/when the rest of the codebase passing in these task
    dictionaries gets updated to the new format.
    """
    task_dict: TaskDict = {}
    for method_name, subdict in old_task_dict.items():
        if not isinstance(subdict, dict):
            subdict = {"label_hidden": subdict}
        for label, structure in subdict.items():
            try:
                task_dict[label][1].append(method_name)
            except KeyError:
                task_dict[label] = (structure, [method_name])
    return task_dict


def as_task_dict_evaluator(
    calculate: SimpleEvaluator,
) -> Callable[..., ResultsDict]:
    """
    Lift a single-structure evaluator to one that operates on task dictionaries.

    ``calculate`` is called with one structure and a list of ``TaskEnum``
    members, followed by any extra positional/keyword arguments. The returned
    function instead accepts a legacy task dictionary of the form
    ``{task_name: {structure_label: structure, ...}, ...}`` (or
    ``{task_name: structure}`` for a single unlabelled structure). It calls
    ``calculate`` once per distinct structure label and regroups the
    per-structure outputs by output name and structure label.

    Args:
        calculate (SimpleEvaluator): The function that interprets structures
            into physical properties.

    Returns:
        callable: The function operating on task dictionaries.
    """

    def evaluate_with_calculator(
        task_dict: dict[TaskName, dict[str, Atoms]],
        # TODO: Make workflows pass task dicts: dict[str, TaskSpec] ~ TaskDict,
        *calculate_args: Any,
        **calculate_kwargs: Any,
    ) -> ResultsDict:
        """
        Evaluate all structures in ``task_dict`` and aggregate the results.

        Converts the legacy task-dict format to the per-structure format and
        calls the wrapped ``calculate`` function once per structure. It then
        maps each requested output back into a nested results dictionary.

        Args:
            task_dict (dict[TaskName, dict[str, Atoms]]): Mapping from task
                names to ``structure_label -> structure`` dicts.
            *calculate_args: Positional arguments forwarded to ``calculate``.
            **calculate_kwargs: Keyword arguments forwarded to ``calculate``.

        Returns:
            ResultsDict: Nested mapping ``{output_label: {structure_label:
            result}}``. For the unlabelled single-structure shorthand, the
            result is stored directly as ``{output_label: result}``.

        Raises:
            KeyError: If ``calculate`` does not return an entry for one of
                the requested outputs.
        """
        # A string target keeps ``cast`` from building the union type at
        # runtime; ``cast`` returns its value unchanged either way.
        converted_task_dict = _convert_task_dict(
            cast("dict[TaskName, dict[str, Atoms] | Atoms]", task_dict)
        )
        results_dict: ResultsDict = {}
        for label, (structure, task_lst) in converted_task_dict.items():
            tasks = [TaskEnum(t) for t in task_lst]
            output = calculate(structure, tasks, *calculate_args, **calculate_kwargs)
            for task_name in tasks:
                result_name = TaskOutputEnum(task_name).name
                # Read the value before touching results_dict. A missing
                # output then raises one direct KeyError instead of being
                # confused with a not-yet-created results_dict entry.
                value = output[result_name]
                if label == "label_hidden":
                    # Unlabelled structure: store the bare value without a
                    # per-label sub-dictionary.
                    results_dict[result_name] = value
                else:
                    results_dict.setdefault(result_name, {})[label] = value
        return results_dict

    return evaluate_with_calculator