Source code for graph_pes.config.shared

from __future__ import annotations

from dataclasses import dataclass, fields
from typing import Literal, Protocol, TypeVar

import dacite
import data2objects

from graph_pes.graph_pes_model import GraphPESModel
from graph_pes.models import AdditionModel
from graph_pes.training.loss import Loss, TotalLoss
from graph_pes.utils.misc import nested_merge

T = TypeVar("T")


def _nice_dict_repr(d: dict) -> str:
    def _print_dict(d: dict, indent: int = 0) -> str:
        nice = {
            k: v
            if not isinstance(v, dict)
            else "\n" + _print_dict(v, indent + 1)
            for k, v in d.items()
        }
        return "\n".join([f"{'  ' * indent}{k}: {v}" for k, v in nice.items()])

    return _print_dict(d, 0)
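
# For example (illustrative):
#     _nice_dict_repr({"a": 1, "b": {"c": 2}})
# renders as:
#     a: 1
#     b:
#       c: 2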


class HasDefaults(Protocol):
    @classmethod
    def defaults(cls) -> dict: ...


HD = TypeVar("HD", bound=HasDefaults)


def instantiate_config_from_dict(
    config_dict: dict, config_class: type[HD]
) -> tuple[dict, HD]:
    """Instantiate a config object from a dictionary."""

    config_dict = nested_merge(config_class.defaults(), config_dict)
    final_dict: dict = data2objects.fill_referenced_parts(config_dict)  # type: ignore

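    # import these submodules so that data2objects can resolve references
    # to the classes and functions they define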
    import graph_pes
    import graph_pes.data
    import graph_pes.interfaces
    import graph_pes.models
    import graph_pes.training
    import graph_pes.training.callbacks
    import graph_pes.training.loss
    import graph_pes.training.opt

    object_dict = data2objects.from_dict(
        final_dict,
        modules=[
            graph_pes,
            graph_pes.models,
            graph_pes.training,
            graph_pes.training.opt,
            graph_pes.training.loss,
            graph_pes.data,
            graph_pes.training.callbacks,
            graph_pes.interfaces,
        ],
    )
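    # keep only the top-level keys that correspond to fields of
    # the config class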
    field_names = {f.name for f in fields(config_class)}  # type: ignore
    object_dict = {
        k: v
        for k, v in object_dict.items()
        if k in field_names  # type: ignore
    }

    try:
        return (
            final_dict,
            dacite.from_dict(
                data_class=config_class,
                data=object_dict,
                config=dacite.Config(strict=True),
            ),
        )
    except Exception as e:
        raise ValueError(
            f"Failed to instantiate a config object of type {config_class} "
            f"from the following dictionary:\n{_nice_dict_repr(final_dict)}"
        ) from e
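
# Illustrative usage sketch (not part of the library): any class satisfying
# the ``HasDefaults`` protocol can be instantiated this way. ``DemoConfig``
# is a hypothetical example, not a real graph-pes config class.
#
#     @dataclass
#     class DemoConfig:
#         seed: int
#         dtype: str
#
#         @classmethod
#         def defaults(cls) -> dict:
#             return {"seed": 42, "dtype": "float32"}
#
#     final_dict, config = instantiate_config_from_dict(
#         {"dtype": "float64"}, DemoConfig
#     )
#     # expected: config == DemoConfig(seed=42, dtype="float64")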


@dataclass
class TorchConfig:
    """Configuration for PyTorch."""

    dtype: Literal["float16", "float32", "float64"]
    """
    The dtype to use for all model parameters and graph properties.
    Defaults to ``"float32"``.
    """

    float32_matmul_precision: Literal["highest", "high", "medium"]
    """
    The precision to use internally for float32 matrix multiplications.
    Refer to the
    `PyTorch documentation <https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html>`__
    for details.

    Defaults to ``"high"`` to favour accelerated learning over
    numerical exactness for matmuls.
    """  # noqa: E501
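

# Illustrative sketch (an assumption, not graph-pes's own wiring code): a
# TorchConfig maps onto the corresponding global torch settings roughly
# like this:
#
#     import torch
#
#     torch_config = TorchConfig(
#         dtype="float32", float32_matmul_precision="high"
#     )
#     torch.set_default_dtype(getattr(torch, torch_config.dtype))
#     torch.set_float32_matmul_precision(
#         torch_config.float32_matmul_precision
#     )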


def parse_model(
    model: GraphPESModel | dict[str, GraphPESModel],
) -> GraphPESModel:
    if isinstance(model, GraphPESModel):
        return model

    elif isinstance(model, dict):
        if not all(isinstance(m, GraphPESModel) for m in model.values()):
            _types = {k: type(v) for k, v in model.items()}
            raise ValueError(
                "Expected all values in the model dictionary to be "
                "GraphPESModel instances, but got something else: "
                f"types: {_types}\n"
                f"values: {model}\n"
            )
        return AdditionModel(**model)

    raise ValueError(
        "Expected to be able to parse a GraphPESModel or a "
        "dictionary of named GraphPESModels from the model config, "
        f"but got something else: {model}"
    )


def parse_loss(
    loss: Loss | TotalLoss | dict[str, Loss] | list[Loss],
) -> TotalLoss:
    if isinstance(loss, Loss):
        return TotalLoss([loss])
    elif isinstance(loss, TotalLoss):
        return loss
    elif isinstance(loss, dict):
        return TotalLoss(list(loss.values()))
    elif isinstance(loss, list):
        return TotalLoss(loss)

    raise ValueError(
        "Expected to be able to parse a Loss, TotalLoss, a list of "
        "Loss instances, or a dictionary mapping keys to Loss instances from "
        "the loss config, but got something else:\n"
        f"{loss}"
    )
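

# Illustrative usage sketch. ``LennardJones`` and ``SchNet`` are assumed to
# be available from ``graph_pes.models``, and the ``Loss(...)`` arguments
# are elided; substitute whichever models and losses you actually use.
#
#     from graph_pes.models import LennardJones, SchNet
#
#     # a single model passes straight through ...
#     model = parse_model(SchNet())
#
#     # ... while a dictionary of named models becomes an AdditionModel
#     model = parse_model({"pair": LennardJones(), "many-body": SchNet()})
#
#     # every accepted loss specification is normalised to a TotalLoss
#     total_loss = parse_loss([Loss(...), Loss(...)])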