from __future__ import annotations
import hashlib
import os
import warnings
from collections import defaultdict
from pathlib import Path
from typing import Callable, Iterable, KeysView, Mapping, Sequence, TypeVar
import numpy as np
from ase import Atoms
# Base URL of the documentation site's per-dataset pages; ``frontend_url``
# appends "<dataset name>.html" to it.
FRONTEND_URL = "https://jla-gardner.github.io/load-atoms/datasets/"
# Base URL for raw files under the repository's ``database/`` directory on
# the ``main`` branch. Not used in this part of the module — presumably used
# elsewhere to fetch dataset files; verify against callers.
BASE_REMOTE_URL = "https://github.com/jla-gardner/load-atoms/raw/main/database/"
def debug_mode() -> bool:
    """
    Report whether debug mode is switched on.

    Debug mode is on only when the ``LOAD_ATOMS_DEBUG`` environment variable
    is exactly ``"1"``; any other value, or the variable being unset, means
    it is off.
    """
    flag = os.getenv("LOAD_ATOMS_DEBUG", "0")
    return flag == "1"
def testing() -> bool:
    """
    Report whether the code is currently running inside a pytest test.

    The signal used is the presence (with any value, even empty) of the
    ``PYTEST_CURRENT_TEST`` environment variable, which pytest sets while a
    test is executing.
    """
    return os.environ.get("PYTEST_CURRENT_TEST") is not None
def generate_checksum(file_path: Path | str) -> str:
    """
    Return a short SHA-256 fingerprint of a file's contents.

    The file is read in binary mode in 4096-byte chunks, so it is never held
    in memory all at once. Only the first 12 hexadecimal characters of the
    digest are returned.
    """
    digest = hashlib.sha256()
    with open(file_path, "rb") as handle:
        while True:
            chunk = handle.read(4096)
            if not chunk:
                break
            digest.update(chunk)
    return digest.hexdigest()[:12]


def matches_checksum(file_path: Path, hash: str) -> bool:
    """
    Return ``True`` if the file's short checksum equals ``hash``.

    ``hash`` is compared against the 12-character value produced by
    ``generate_checksum``. (The parameter name shadows the builtin ``hash``
    but is part of the public signature.)
    """
    actual = generate_checksum(file_path)
    return actual == hash
T, Y, G = TypeVar("T"), TypeVar("Y"), TypeVar("G")


class LazyMapping(Mapping[T, Y]):
    """
    A read-only mapping whose values are produced on demand and cached.

    The first lookup of a key calls ``loader(key)`` and remembers the result;
    later lookups of the same key return the remembered value without calling
    ``loader`` again. Membership tests, iteration and ``len`` are answered
    from ``keys`` alone and never trigger loading.

    Parameters
    ----------
    keys: Sequence[T]
        The keys of the mapping.
    loader: Callable[[T], Y]
        A function that takes a key and returns a value.

    Examples
    --------
    >>> def loader(key):
    ...     print(f"Loading value for key={key}")
    ...     return key * 2
    ...
    >>> mapping = LazyMapping([1, 2, 3], loader)
    >>> mapping[3]
    Loading value for key=3
    6
    >>> mapping[3]
    6
    >>> 1 in mapping
    True
    >>> 4 in mapping
    False
    """

    def __init__(self, keys: Sequence[T], loader: Callable[[T], Y]):
        self._keys = keys
        self.loader = loader
        # cache of already-loaded values, keyed like ``self._keys``
        self._mapping = {}

    def __contains__(self, key: object) -> bool:
        return key in self._keys

    def __len__(self):
        return len(self._keys)

    def __iter__(self):
        return iter(self._keys)

    def keys(self):
        return KeysView(self)

    def __getitem__(self, key: T) -> Y:
        # unknown keys are rejected before the cache or loader is consulted
        if key in self._keys:
            cache = self._mapping
            if key not in cache:
                cache[key] = self.loader(key)
            return cache[key]
        raise KeyError(key)

    def __repr__(self) -> str:
        return f"LazyMapping(keys={self._keys})"
def frontend_url(dataset_info):
    """
    Build the URL of a dataset's documentation page.

    ``dataset_info`` must have a string ``name`` attribute; the result is
    ``FRONTEND_URL`` followed by ``<name>.html``.
    """
    prefix = FRONTEND_URL + dataset_info.name
    return prefix + ".html"
class UnknownDatasetException(Exception):
    """Raised when a dataset identifier is not recognised."""

    def __init__(self, dataset_id):
        message = f"Unknown dataset: {dataset_id}"
        super().__init__(message)
def union(things: Iterable[Iterable]):
    """
    Return the set of all elements that appear in any of ``things``.

    An empty ``things`` yields an empty set. A new set is always returned.
    """
    combined = set()
    for collection in things:
        combined.update(collection)
    return combined
def intersect(things: Iterable[Iterable]):
    """
    Return the elements common to every iterable in ``things``.

    An empty ``things`` yields an empty set. Every element of ``things`` is
    converted to a set first, and a new set is always returned.
    """
    as_sets = [set(item) for item in things]
    if len(as_sets) == 0:
        return set()
    head, *tail = as_sets
    return head.intersection(*tail)
def lpad(thing: str, length: int = 4, fill: str = " "):
    """
    Indent every line of ``thing`` by ``length`` copies of ``fill``.

    Lines are delimited by ``"\\n"``; empty lines (including one after a
    trailing newline) are padded too.
    """
    padding = f"{fill * length}"
    lines = thing.split("\n")
    return "\n".join(padding + line for line in lines)
def random_split(
    things: Sequence[T],
    splits: Sequence[int] | Sequence[float],
    seed: int = 0,
) -> list[list[T]]:
    """
    Randomly partition ``things`` into disjoint chunks of the requested sizes.

    Parameters
    ----------
    things
        The items to split.
    splits
        Either absolute chunk sizes (ints) or fractions of ``len(things)``
        (floats, including numpy floating scalars). The type of the first
        entry decides how the whole sequence is interpreted. Fractional
        sizes are rounded down.
    seed
        Seed for the random permutation, so the split is reproducible.

    Returns
    -------
    list[list[T]]
        One list per entry in ``splits`` (an empty list if ``splits`` is
        empty). Items not covered by the requested sizes are left out.

    Raises
    ------
    ValueError
        If any size is negative, or the sizes sum to more than
        ``len(things)``.
    """
    if len(splits) == 0:
        return []

    n_things = len(things)
    if isinstance(splits[0], (float, np.floating)):
        # Round to 9 d.p. before truncating so floating-point error doesn't
        # drop an item, e.g. 0.29 * 100 == 28.999999999999996 -> 29, not 28.
        sizes = [int(round(float(s) * n_things, 9)) for s in splits]
    else:
        sizes = list(splits)

    if any(size < 0 for size in sizes):
        raise ValueError("Split sizes must be non-negative.")
    if sum(sizes) > n_things:
        raise ValueError(
            "The sum of the splits cannot exceed the dataset size."
        )

    boundaries = np.cumsum(sizes)
    order = np.random.RandomState(seed).permutation(n_things)
    return [
        [things[i] for i in order[start:stop]]
        for start, stop in zip([0, *boundaries], boundaries)
    ]
def k_fold_split(
    things: list[T],
    k: int,
    fold: int,
) -> tuple[list[T], list[T]]:
    """
    Split ``things`` into a train and test set for one fold of k-fold CV.

    The index order is rotated by ``fold * len(things) // k`` and the last
    ``len(things) // k`` items of the rotated order form the test set; the
    rest form the training set. No shuffling is performed.

    Parameters
    ----------
    things
        The items to split.
    k
        The number of folds.
    fold
        Which fold to return, ``0 <= fold < k``.

    Returns
    -------
    tuple[list[T], list[T]]
        ``(train, test)``. If ``len(things) < k`` the test set is empty and
        every item is in the training set.

    Raises
    ------
    AssertionError
        If ``fold`` is not in ``[0, k)``.
    """
    # NOTE(review): ``assert`` is skipped under ``python -O``; kept so the
    # exception type seen by existing callers is unchanged.
    assert 0 <= fold < k

    n_things = len(things)
    idxs = np.roll(np.arange(n_things), fold * n_things // k)
    n_test = n_things // k
    # Slice at an explicit boundary rather than ``idxs[:-n_test]``: when
    # ``n_test == 0`` that would give an empty train set and put *every*
    # item into the test set (``idxs[-0:]`` is the whole array).
    boundary = n_things - n_test
    train, test = idxs[:boundary], idxs[boundary:]
    return [things[i] for i in train], [things[i] for i in test]
def split_keeping_ratio(
    things_to_split: Sequence[T],
    group_ids: Sequence[G],
    splitting_function: Callable[[list[T]], Sequence[list[T]]],
):
    """
    Split ``things_to_split`` so each split keeps the same group proportions.

    Items are grouped by the matching entry of ``group_ids`` (groups are
    ordered by first appearance). ``splitting_function`` is applied to each
    group separately, and split ``i`` of the result is the concatenation of
    split ``i`` from every group.

    Parameters
    ----------
    things_to_split
        The items to split.
    group_ids
        One group identifier per item.
    splitting_function
        Maps a list of items to a sequence of splits (lists). It must return
        the same number of splits for every group.

    Returns
    -------
    list[list[T]]
        The merged splits. For empty input, the (empty) splits obtained by
        calling ``splitting_function([])``.

    Raises
    ------
    AssertionError
        If ``things_to_split`` and ``group_ids`` differ in length.
    ValueError
        If ``splitting_function`` returns different numbers of splits for
        different groups.
    """
    assert len(things_to_split) == len(group_ids)

    # 1. separate into groups
    groups: dict[G, list[T]] = defaultdict(list)
    for thing, group_id in zip(things_to_split, group_ids):
        groups[group_id].append(thing)

    if not groups:
        # Nothing to split: ask the splitting function how many (empty)
        # splits it produces so callers can still unpack the result.
        return [list(split) for split in splitting_function([])]

    # 2. split each group
    splits_by_group: dict[G, Sequence[list[T]]] = {
        key: splitting_function(value) for key, value in groups.items()
    }

    # 3. merge the splits, thus keeping the ratio
    n_splits = len(next(iter(splits_by_group.values())))
    final_splits: list[list[T]] = [[] for _ in range(n_splits)]
    for group_id, group_splits in splits_by_group.items():
        if len(group_splits) != n_splits:
            raise ValueError(
                f"The splitting function returned {len(group_splits)} splits "
                f"for group {group_id!r}, but {n_splits} for the first group."
            )
        for i, split in enumerate(group_splits):
            final_splits[i].extend(split)
    return final_splits
def choose_n(things: Sequence[Y], n: int, seed: int = 42) -> list[Y]:
    """
    Pick ``n`` items from ``things`` without replacement, reproducibly.

    The items are the first ``n`` entries of a random permutation generated
    by ``np.random.RandomState(seed)``. If ``n`` exceeds ``len(things)``,
    every item is returned in shuffled order.
    """
    rng = np.random.RandomState(seed)
    shuffled_order = rng.permutation(len(things))
    selected = shuffled_order[:n]
    return [things[position] for position in selected]
# Default message for FrozenDict / freeze_dict.
# NOTE(review): modifications are rejected with a ValueError rather than
# silently ignored as this wording suggests; the text is left unchanged
# because it is user-facing and may be matched by callers.
_default_error_msg = (
    "This dictionary is read-only: any modifications are ignored."
)


class FrozenDict(dict):
    """
    A dictionary that raises an error when any modifications are attempted.

    Item assignment and deletion, ``clear``, ``update``, in-place ``|=``,
    ``setdefault``, ``pop`` and ``popitem`` all raise
    ``ValueError(error_msg)``. Read access works as for a normal ``dict``.

    ``copy.copy``, ``copy.deepcopy`` and pickling produce a new
    ``FrozenDict`` with the same contents and ``error_msg``.

    Parameters
    ----------
    d
        The initial contents.
    error_msg
        Message for the ``ValueError`` raised on any modification attempt.
    """

    def __init__(self, d: dict, error_msg: str = _default_error_msg):
        super().__init__(d)
        self.error_msg = error_msg

    def __setitem__(self, key, value):
        raise ValueError(self.error_msg)

    def __delitem__(self, key):
        raise ValueError(self.error_msg)

    def __ior__(self, other):
        # dict's in-place ``|=`` merges directly at the C level without going
        # through ``update``/``__setitem__``, so it must be blocked explicitly.
        raise ValueError(self.error_msg)

    def clear(self):
        raise ValueError(self.error_msg)

    def update(self, *args, **kwargs):
        raise ValueError(self.error_msg)

    def setdefault(self, key, default=None):
        raise ValueError(self.error_msg)

    def pop(self, key, default=None):
        raise ValueError(self.error_msg)

    def popitem(self):
        raise ValueError(self.error_msg)

    def __reduce__(self):
        # The default copy/pickle protocol for dict subclasses re-inserts the
        # items through ``__setitem__``, which is forbidden here. Rebuild via
        # ``__init__`` instead. (Extra attributes added by subclasses are not
        # carried over.)
        return (self.__class__, (dict(self), self.error_msg))


def freeze_dict(d: dict, error_msg: str = _default_error_msg) -> FrozenDict:
    """Return a :class:`FrozenDict` copy of ``d`` using ``error_msg``."""
    return FrozenDict(d, error_msg)
def matches(thing1: float | np.ndarray, thing2: float | np.ndarray) -> bool:
    """
    Check whether two values have the same shape and approximately equal
    entries.

    Shapes are obtained with ``np.shape``, so scalars, numpy arrays and
    (nested) lists/tuples are all compared by their array shape. Values are
    then compared with ``np.allclose`` using its default tolerances.
    Differently-shaped inputs never match, even if they would broadcast
    together.
    """
    # np.shape (rather than ``.shape`` on ndarrays only) stops lists being
    # treated as scalars, which let e.g. [1.0] "match" [1.0, 1.0] through
    # broadcasting and made mismatched-length lists raise inside allclose.
    if np.shape(thing1) != np.shape(thing2):
        return False
    return bool(np.allclose(thing1, thing2))
def remove_calculator(atoms: Atoms) -> None:
    """
    Detach ``atoms.calc`` while keeping the results it had computed.

    If ``atoms`` has a (truthy) calculator, ``atoms.calc`` is set to ``None``
    and entries of the calculator's ``results`` are copied onto the atoms
    object:

    - ``"energy"`` and ``"stress"`` go to ``atoms.info``
    - ``"forces"`` goes to ``atoms.arrays``
    - any other result key is dropped.

    A result is written when the destination has no (or a ``None``) value
    for that key, or a value that ``matches`` it. If the destination already
    holds a *different* value, that value is kept, the calculator's value is
    discarded, and a warning is issued via ``warnings.warn``.
    """
    calculator = atoms.calc
    if not calculator:
        return

    atoms.calc = None
    calc_results = calculator.results
    destinations = {
        "energy": atoms.info,
        "forces": atoms.arrays,
        "stress": atoms.info,
    }

    for name, computed in calc_results.items():
        target = destinations.get(name)
        if target is None:
            continue

        existing = target.get(name)
        conflict = existing is not None and not matches(existing, computed)
        if conflict:
            warnings.warn(
                f'We found different values for "{name}" on an atoms object '
                "and its calculator. We will preserve the value already on "
                "the atoms object and discard that from the calculator. ",
                stacklevel=2,
            )
        else:
            target[name] = computed