from __future__ import annotations
import pickle
import warnings
from abc import ABC, abstractmethod
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import (
    Any,
    Callable,
    Iterable,
    Iterator,
    Literal,
    Mapping,
    Sequence,
    overload,
)
import ase
import ase.io
import lmdb
import numpy as np
from ase import Atoms
from ase.data import chemical_symbols
from typing_extensions import Self, override
from yaml import dump
from .database import DatabaseEntry
from .utils import (
    LazyMapping,
    choose_n,
    freeze_dict,
    intersect,
    k_fold_split,
    random_split,
    split_keeping_ratio,
)
class AtomsDataset(ABC, Sequence[Atoms]):
    """
    An abstract base class for datasets of :class:`ase.Atoms` objects.

    This class provides a common interface for interacting with datasets of
    atomic structures, abstracting over the underlying storage mechanism.

    The two current concrete implementations are
    :class:`~load_atoms.atoms_dataset.InMemoryAtomsDataset`
    and :class:`~load_atoms.atoms_dataset.LmdbAtomsDataset`.
    """

    def __init__(self, description: DatabaseEntry | None = None):
        """
        Parameters
        ----------
        description
            Optional metadata describing the dataset. Stored unchanged as
            :code:`self.description`; when set, its :code:`name` is used as
            the title of the dataset's ``repr``.
        """
        self.description = description
    @property
    @abstractmethod
    def structure_sizes(self) -> np.ndarray:
        """
        An array containing the number of atoms in each structure, such that:

        .. code-block:: python

            for idx, structure in enumerate(dataset):
                assert len(structure) == dataset.structure_sizes[idx]
        """
[docs]    @abstractmethod
    def __len__(self) -> int:
        """The number of structures in the dataset.""" 
    @overload
    def __getitem__(self, index: int) -> Atoms:
        ...
    @overload
    def __getitem__(
        self, index: list[int] | list[bool] | np.ndarray | slice
    ) -> Self:
        ...
[docs]    def __getitem__(  # type: ignore
        self,
        index: int | list[int] | np.ndarray | slice,
    ) -> Atoms | Self:
        r"""
        Get the structure(s) at the given index(es).
        If a single :class:`int` is provided, the corresponding structure is
        returned:
        .. code-block:: pycon
            >>> QM7 = load_dataset("QM7")
            >>> QM7[0]
            Atoms(symbols='CH4', pbc=False)
        If a :class:`slice` is provided, a new :class:`AtomsDataset` is returned
        containing the structures in the slice:
        .. code-block:: pycon
            >>> QM7[:5]
            Dataset:
                structures: 5
                atoms: 32
                species:
                    H: 68.75%
                    C: 28.12%
                    O: 3.12%
                properties:
                    per atom: ()
                    per structure: (energy)
        If a :class:`list` or :class:`numpy.ndarray` of :class:`int`\ s is
        provided, a new :class:`AtomsDataset` is returned containing the
        structures at the given indices:
        .. code-block:: pycon
            >>> len(QM7[[0, 2, 4]])
            3
        If a :class:`list` or :class:`numpy.ndarray` of :class:`bool`\ s is
        provided with the same length as the dataset, a new
        :class:`AtomsDataset` is returned containing the structures where the
        boolean is :code:`True`
        (see also :func:`~load_atoms.AtomsDataset.filter_by`):
        .. code-block:: pycon
            >>> bool_idx = [
            ...     len(s) > 10 for s in QM7
            ... ]
            >>> len(QM7[bool_idx]) == sum(bool_idx)
            True
        Parameters
        ----------
        index
            The index(es) to get the structure(s) at.
        """
        if isinstance(index, slice):
            idxs = range(len(self))[index]
            return self._index_subset(idxs)
        if isinstance(index, list):
            if all(isinstance(i, bool) for i in index):
                if len(index) != len(self):
                    raise ValueError(
                        "Boolean index list must be the same length as the "
                        "dataset."
                    )
                return self._index_subset([i for i, b in enumerate(index) if b])
            else:
                return self._index_subset(index)
        if isinstance(index, np.ndarray):
            return self._index_subset(np.arange(len(self))[index])
        index = int(index)
        return self._index_structure(index) 
    def __iter__(self) -> Iterator[ase.Atoms]:
        for i in range(len(self)):
            yield self._index_structure(i)
    @abstractmethod
    def _index_structure(self, index: int) -> Atoms:
        """
        Get the structure at the given index.

        Called by ``__getitem__`` for single-integer indexing, and by
        ``__iter__`` for every position from ``0`` to ``len(self) - 1``.
        """
    @abstractmethod
    def _index_subset(self, idxs: Sequence[int]) -> Self:
        """
        Get a new dataset containing the structures at the given indices.

        ``__getitem__`` calls this with a :class:`range` (for slices), a
        :class:`list` of integers, or an integer :class:`numpy.ndarray`.
        """
    def __repr__(self) -> str:
        """
        Summarise the dataset as a YAML-formatted string.

        The summary lists the number of structures and atoms, the percentage
        of atoms belonging to each species (most common first), and the names
        of the shared per-atom and per-structure properties. It is titled
        with ``self.description.name`` when a description is set, and with
        ``"Dataset"`` otherwise.
        """
        name = self.description.name if self.description else "Dataset"

        # "numbers" and "positions" are excluded from the reported per-atom
        # properties
        per_atom_names = sorted(
            set(self.arrays.keys()) - {"numbers", "positions"}
        )
        per_atom_properties = "(" + ", ".join(per_atom_names) + ")"
        per_structure_properties = (
            "(" + ", ".join(sorted(self.info.keys())) + ")"
        )

        species_counts = self.species_counts()
        # ``n_atoms`` sums the structure sizes on every access, so read it
        # once rather than once per species
        n_atoms = self.n_atoms
        species_percentages = {
            symbol: f"{count / n_atoms:0.2%}"
            for symbol, count in sorted(
                species_counts.items(), key=lambda item: item[1], reverse=True
            )
        }

        return dump(
            {
                name: {
                    "structures": f"{len(self):,}",
                    "atoms": f"{n_atoms:,}",
                    "species": species_percentages,
                    "properties": {
                        "per atom": per_atom_properties,
                        "per structure": per_structure_properties,
                    },
                }
            },
            sort_keys=False,
            indent=4,
        )
[docs]    def write(
        self,
        path: Path | str,
        format: str | None = None,
        append: bool = False,
        **kwargs: Any,
    ):
        """
        Write the dataset to a file, using :func:`ase.io.write`.
        Parameters
        ----------
        path
            The path to write the dataset to.
        format
            The format to write the dataset in.
        append
            Whether to append to the file.
        kwargs
            Additional keyword arguments to pass to :func:`ase.io.write`.
        """
        ase.io.write(
            path,
            self,
            format=format,  # type: ignore
            append=append,
            **kwargs,
        ) 
    @property
    @abstractmethod
    def info(self) -> Mapping[str, Any]:
        r"""
        Get a mapping from keys that are shared across all structures'
        ``.info`` attributes to the concatenated corresponding values.

        The returned mapping conforms to:

        .. code-block:: python

            for key, value in dataset.info.items():
                for i, structure in enumerate(dataset):
                    assert structure.info[key] == value[i]
        """
    @property
    @abstractmethod
    def arrays(self) -> Mapping[str, np.ndarray]:
        """
        Get a mapping from each structure's :code:`.arrays` keys to arrays.

        The returned mapping conforms to:

        .. code-block:: python

            for key, value in dataset.arrays.items():
                assert value.shape[0] == dataset.n_atoms
                assert np.array_equal(
                    value,
                    np.concatenate(
                        [structure.arrays[key] for structure in dataset]
                    ),
                )
        """
[docs]    @classmethod
    @abstractmethod
    def save(
        cls,
        path: Path,
        structures: Iterable[Atoms],
        description: DatabaseEntry | None = None,
    ):
        """
        Save the dataset to a file.
        Parameters
        ----------
        path
            The path to save the dataset to.
        structures
            The structures to save to the dataset.
        description
            The description of the dataset.
        """ 
[docs]    @classmethod
    @abstractmethod
    def load(cls, path: Path) -> Self:
        """
        Load the dataset from a file.
        Parameters
        ----------
        path
            The path to load the dataset from.
        """
        pass 
    # concrete methods
[docs]    def species_counts(self) -> Mapping[str, int]:
        """
        Get the number of atoms of each species in the dataset.
        """
        return {
            chemical_symbols[species]: (self.arrays["numbers"] == species).sum()
            for species in np.unique(self.arrays["numbers"])
        } 
    def __contains__(self, item: Any) -> bool:
        """
        Check if the dataset contains a structure.
        Warning: this method is not efficient for large datasets.
        """
        return any(item == other for other in self)
    @property
    def n_atoms(self) -> int:
        r"""
        The total number of atoms in the dataset.
        This is equivalent to the sum of the number of atoms in each structure.
        """
        return int(self.structure_sizes.sum())
[docs]    def filter_by(
        self,
        *functions: Callable[[ase.Atoms], bool],
        **info_kwargs: Any,
    ) -> Self:
        """
        Return a new dataset containing only the structures that match the
        given criteria.
        Parameters
        ----------
        functions
            Functions to filter the dataset by. Each function should take an
            ASE Atoms object as input and return a boolean.
        info_kwargs
            Keyword arguments to filter the dataset by. Only atoms objects with
            matching info keys and values will be returned.
        Example
        -------
        Get small, amorphous structures with large forces:
        .. code-block:: pycon
            :emphasize-lines: 3-7
            >>> from load_atoms import load_dataset
            >>> dataset = load_dataset("C-GAP-17")
            >>> dataset.filter_by(
            ...     lambda structure: len(structure) < 50,
            ...     lambda structure: structure.arrays["force"].max() > 5,
            ...     config_type="bulk_amo"
            ... )
            Dataset:
                structures: 609
                atoms: 23,169
                species:
                    C: 100.00%
                properties:
                    per atom: (force)
                    per structure: (config_type, detailed_ct, split, energy)
        """
        def matches_info(structure: ase.Atoms) -> bool:
            for key, value in info_kwargs.items():
                if structure.info.get(key, None) != value:
                    return False
            return True
        functions = (*functions, matches_info)
        def the_filter(structure: ase.Atoms) -> bool:
            return all(f(structure) for f in functions)
        index = [i for i, structure in enumerate(self) if the_filter(structure)]
        return self[index] 
[docs]    def random_split(
        self,
        splits: Sequence[float] | Sequence[int],
        seed: int = 42,
        keep_ratio: str | None = None,
    ) -> list[Self]:
        r"""
        Randomly split the dataset into multiple, disjoint parts.
        Parameters
        ----------
        splits
            The number of structures to put in each split.
            If a list of :class:`float`\ s, the splits will be
            calculated as a fraction of the dataset size.
        seed
            The random seed to use for shuffling the dataset.
        keep_ratio
            If not :code:`None`, splits will be generated to maintain the
            ratio of structures in each split with the specified :code:`.info`
            value.
        Returns
        -------
        list[Self]
            A list of new datasets, each containing a subset of the original
        Examples
        --------
        Split a :code:`dataset` into 80% training and 20% test sets:
        >>> train, test = dataset.random_split([0.8, 0.2])
        Split a :code:`dataset` into 3 parts:
        >>> train, val, test = dataset.random_split([1_000, 100, 100])
        Maintain the ratio of :code:`config_type` values in each split:
        .. code-block:: pycon
            :emphasize-lines: 16-19
            >>> from load_atoms import load_dataset
            >>> import numpy as np
            >>> # helper function
            >>> def ratios(thing):
            ...     values, counts = np.unique(thing, return_counts=True)
            ...     max_len = max(len(str(v)) for v in values)
            ...     for v, c in zip(values, counts / counts.sum()):
            ...         print(f"{v:>{max_len}}: {c:>6.2%}")
            ...
            >>> dataset = load_dataset("C-GAP-17")
            >>> ratios(dataset.info["config_type"])
              bulk_amo: 75.28%
            bulk_cryst:  8.83%
                 dimer:  0.66%
              surf_amo: 15.23%
            >>> train, val, test = dataset.random_split(
            ...     [0.6, 0.2, 0.2],
            ...     keep_ratio="config_type"
            ... )
            >>> ratios(train.info["config_type"])
              bulk_amo: 75.28%
            bulk_cryst:  8.83%
                 dimer:  0.66%
              surf_amo: 15.23%
        """
        if keep_ratio is None:
            return [
                self[split]
                for split in random_split(range(len(self)), splits, seed)
            ]
        if keep_ratio not in self.info:
            raise KeyError(
                f"Unknown key {keep_ratio}. "
                "Available keys are: " + ", ".join(self.info.keys())
            )
        if isinstance(splits[0], int):
            final_sizes: list[int] = splits  # type: ignore
        else:
            final_sizes = [int(s * len(self)) for s in splits]
        normalised_fractional_splits = [s / sum(splits) for s in splits]
        split_idxs = split_keeping_ratio(
            range(len(self)),
            group_ids=self.info[keep_ratio],
            splitting_function=partial(
                random_split, seed=seed, splits=normalised_fractional_splits
            ),
        )
        return [
            self[choose_n(split, size, seed)]
            for split, size in zip(split_idxs, final_sizes)
        ] 
[docs]    def k_fold_split(
        self,
        k: int = 5,
        fold: int = 0,
        shuffle: bool = True,
        seed: int = 42,
        keep_ratio: str | None = None,
    ) -> tuple[Self, Self]:
        """
        Generate (an optionally shuffled) train/test split for cross-validation.
        Parameters
        ----------
        k
            The number of folds to use.
        fold
            The fold to use for testing.
        shuffle
            Whether to shuffle the dataset before splitting.
        seed
            The random seed to use for shuffling the dataset.
        keep_ratio
            If not :code:`None`, splits will be generated to maintain the
            ratio of structures in each split with the specified :code:`.info`
            value.
        Returns
        -------
        Tuple[Self, Self]
            The train and test datasets.
        Example
        -------
        Basic usage:
        .. code-block:: pycon
            :emphasize-lines: 2
            >>> for i in range(5):
            ...     train, test = dataset.k_fold_split(k=5, fold=i)
            ...     ...  # do something, e.g. train a model
        Maintain the ratio of :code:`config_type` values in each split
        (see also :func:`~load_atoms.AtomsDataset.random_split` for a more
        detailed example of this feature):
        .. code-block:: pycon
            >>> train, test = dataset.k_fold_split(
            ...     k=5, fold=0, keep_ratio="config_type"
            ... )
        """
        if k < 2:
            raise ValueError("k must be at least 2")
        fold = fold % k
        if shuffle:
            idxs = np.random.RandomState(seed).permutation(len(self))
        else:
            idxs = np.arange(len(self))
        if keep_ratio is None:
            train_idxs, test_idxs = k_fold_split(idxs.tolist(), k, fold)
        else:
            if keep_ratio not in self.info:
                raise KeyError(
                    f"Unknown key {keep_ratio}. "
                    "Available keys are: " + ", ".join(self.info.keys())
                )
            if not shuffle:
                raise ValueError(
                    "Keep ratio splits are only supported when shuffling."
                )
            group_ids = self.info[keep_ratio][idxs]
            train_idxs, test_idxs = split_keeping_ratio(
                idxs.tolist(), group_ids, partial(k_fold_split, k=k, fold=fold)
            )
        return self[train_idxs], self[test_idxs]  
class InMemoryAtomsDataset(AtomsDataset):
    """
    An in-memory implementation of :class:`AtomsDataset`.

    Internally, this class wraps a :class:`list` of :class:`ase.Atoms` objects,
    all of which are stored in RAM. Suitable for small to moderately large
    datasets.
    """

    def __init__(
        self,
        structures: list[ase.Atoms],
        description: DatabaseEntry | None = None,
    ):
        """
        Parameters
        ----------
        structures
            The structures making up the dataset. The list is stored as-is
            (not copied).
        description
            Optional description of the dataset.
        """
        super().__init__(description)

        if len(structures) == 1:
            warnings.warn(
                "Creating a dataset with a single structure. "
                "Typically, datasets contain multiple structures - "
                "did you mean to do this?",
                stacklevel=2,
            )

        self._structures = structures
        self._info = _get_info_mapping(structures)
        self._arrays = _get_arrays_mapping(structures)

    @property
    @override
    def info(self) -> LazyMapping[str, Any]:
        return self._info

    @property
    @override
    def arrays(self) -> LazyMapping[str, np.ndarray]:
        return self._arrays

    @property
    @override
    def structure_sizes(self) -> np.ndarray:
        # an explicit integer dtype keeps the sizes integral even for an
        # empty dataset (numpy would otherwise default to float64)
        return np.array([len(s) for s in self._structures], dtype=int)

    @override
    def __len__(self) -> int:
        return len(self._structures)

    @override
    def _index_structure(self, index: int) -> Atoms:
        return self._structures[index]

    @override
    def _index_subset(self, idxs: Sequence[int]) -> InMemoryAtomsDataset:
        # the subset is a plain dataset: the description is not carried over
        return InMemoryAtomsDataset([self._index_structure(i) for i in idxs])

    @classmethod
    @override
    def save(
        cls,
        path: Path,
        structures: Iterable[Atoms],
        description: DatabaseEntry | None = None,
    ):
        """
        Pickle the structures and description to the file at ``path``,
        creating parent directories as needed.
        """
        path.parent.mkdir(parents=True, exist_ok=True)
        to_save = {
            "structures": list(structures),
            "description": description,
        }
        with open(path, "wb") as f:
            pickle.dump(to_save, f)

    @classmethod
    @override
    def load(cls, path: Path) -> InMemoryAtomsDataset:
        """
        Load a dataset previously written by :meth:`save`.

        The file is unpickled, which can execute arbitrary code: only load
        files from sources you trust.
        """
        with open(path, "rb") as f:
            data = pickle.load(f)
        # ``data`` holds the "structures" and "description" written by save
        return cls(**data)
@dataclass
class LmdbMetadata:
    """
    Summary information stored (pickled, under the ``"metadata"`` key)
    alongside the structures of an :class:`LmdbAtomsDataset`, so that it can
    be read without loading any structure.
    """

    # number of atoms in each stored structure, in storage order
    structure_sizes: np.ndarray
    # for each stored structure, a mapping of chemical symbol -> atom count
    species_per_structure: list[dict[str, int]]
    # sorted ``.arrays`` keys shared by all structures, excluding "numbers"
    # and "positions"
    per_atom_properties: list[str]
    # sorted ``.info`` keys shared by all structures
    per_structure_properties: list[str]
class LmdbAtomsDataset(AtomsDataset):
    r"""
    An LMDB-backed implementation of :class:`AtomsDataset`.

    Internally, this class wraps an :class:`lmdb.Environment` object, which
    stores the dataset in an LMDB database. Suitable for large datasets that
    cannot fit in memory. Accessing data from this dataset type is (marginally)
    slower than for :class:`InMemoryAtomsDataset`\ s, but allows for efficient
    processing of extremely large datasets that cannot otherwise fit in memory.

    .. warning::

        The :class:`ase.Atoms` objects in an LMDB dataset are read-only.
        Modifying the :code:`.info` or :code:`.arrays` of an :class:`ase.Atoms`
        object will have no effect, and will instead throw an error.
    """

    def __init__(self, path: Path, idx_subset: np.ndarray | None = None):
        """
        Open the LMDB database at ``path`` for reading.

        The stored entries are unpickled, which can execute arbitrary code:
        only open databases from sources you trust.

        Parameters
        ----------
        path
            The directory containing the LMDB database.
        idx_subset
            The storage indices of the structures exposed by this dataset,
            in order. If :code:`None`, every stored structure is exposed.
        """
        self.path = path

        # setup lmdb environment, and keep a transaction open for the lifetime
        # of the dataset (to enable fast reads)
        self.env = lmdb.open(
            str(path), readonly=True, lock=False, map_async=True
        )
        self.txn = self.env.begin(write=False)

        # The description entry is optional: ``txn.get`` returns None for an
        # absent key, which must not be handed to ``pickle.loads``.
        pickled_description = self.txn.get("description".encode("ascii"))
        super().__init__(
            description=(
                None
                if pickled_description is None
                else pickle.loads(pickled_description)
            )
        )

        self.metadata: LmdbMetadata = pickle.loads(
            self.txn.get("metadata".encode("ascii"))
        )
        self.idx_subset = (
            idx_subset
            if idx_subset is not None
            else np.arange(len(self.metadata.structure_sizes))
        )

        # Both mappings load their values by iterating over this dataset
        # (i.e. reading every structure), hence the loader warnings.
        self._info = _get_info_mapping(
            structures=self,
            keys=self.metadata.per_structure_properties,
            loader_warning=(
                "LmdbAtomsDatasets do not hold all structure properties in "
                "memory: accessing .info will be slow."
            ),
        )
        self._arrays = _get_arrays_mapping(
            structures=self,
            keys=self.metadata.per_atom_properties + ["numbers", "positions"],
            loader_warning=(
                "LmdbAtomsDatasets do not hold all per-atom properties in "
                "memory: accessing .arrays will be slow."
            ),
        )
    def close(self):
        self.txn.close()
        self.env.close()
    def __enter__(self):
        return self
    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
    def _index_structure(self, index: int) -> Atoms:
        """
        Read and unpickle the structure at position ``index`` of this dataset.

        The position is first translated to a storage index via
        ``idx_subset``. The returned structure's ``info`` and ``arrays`` are
        replaced with frozen versions produced by ``freeze_dict``.
        """
        storage_key = f"{self.idx_subset[index]}".encode("ascii")
        atoms = pickle.loads(self.txn.get(storage_key))

        message_template = (
            "The Atoms objects in an LMDB dataset are read-only: "
            "modifying the {} of an Atoms object will have no effect."
        )
        for attribute in ("info", "arrays"):
            frozen = freeze_dict(
                getattr(atoms, attribute), message_template.format(attribute)
            )
            setattr(atoms, attribute, frozen)
        return atoms
    @property
    @override
    def structure_sizes(self) -> np.ndarray:
        sizes = self.metadata.structure_sizes
        if self.idx_subset is not None:
            return sizes[self.idx_subset]
        return sizes
    @override
    def __len__(self) -> int:
        return len(self.structure_sizes)
    @override
    def species_counts(self) -> Mapping[str, int]:
        to_sum: list[dict[str, int]] = []
        all_species: set[str] = set()
        for idx in self.idx_subset:
            species_count = self.metadata.species_per_structure[idx]
            to_sum.append(species_count)
            all_species.update(species_count.keys())
        summed = {symbol: 0 for symbol in all_species}
        for species_count in to_sum:
            for symbol, count in species_count.items():
                summed[symbol] += count
        return summed
    @override
    def __contains__(self, item: Any) -> bool:
        warnings.warn(
            "Checking if an LMDB dataset contains a structure is slow "
            "because it requires loading every structure into memory. "
            "Consider using a different dataset format if possible.",
            stacklevel=2,
        )
        return any(item == structure for structure in self)
    @override
    def _index_subset(self, idxs: Sequence[int]) -> LmdbAtomsDataset:
        """
        Build a new dataset over the same database path, exposing only the
        structures at the given positions of this dataset.
        """
        # translate positions within this dataset into storage indices
        storage_idxs = self.idx_subset[idxs]
        # NOTE(review): this opens a further lmdb environment on the same
        # path within the same process - confirm this is safe with lmdb.
        return LmdbAtomsDataset(self.path, storage_idxs)
    @property
    @override
    def info(self) -> LazyMapping[str, Any]:
        return self._info
    @property
    @override
    def arrays(self) -> LazyMapping[str, np.ndarray]:
        return self._arrays
    @classmethod
    @override
    def load(cls, path: Path) -> LmdbAtomsDataset:
        """Open the LMDB database at ``path`` as a dataset of all its structures."""
        dataset = cls(path)
        return dataset
    @classmethod
    @override
    def save(
        cls,
        path: Path,
        structures: Iterable[Atoms],
        description: DatabaseEntry | None = None,
    ):
        """
        Write the structures to an LMDB database in the directory ``path``.

        Each structure is pickled under its position (``"0"``, ``"1"``, ...)
        as key. An :class:`LmdbMetadata` summary is stored under
        ``"metadata"``, and the description (possibly :code:`None`) under
        ``"description"``.

        Parameters
        ----------
        path
            The directory to create the database in (created if missing).
        structures
            The structures to save. They are consumed in a single pass.
        description
            The description of the dataset.
        """
        path.mkdir(parents=True, exist_ok=True)

        one_TB = int(1e12)
        env = lmdb.open(str(path), map_size=one_TB)
        # close the environment even if writing fails part-way through
        try:
            with env.begin(write=True) as txn:
                structure_sizes = []
                species_per_structure = []
                per_atom_properties: list[set[str]] = []
                per_structure_properties: list[set[str]] = []

                for idx, structure in enumerate(structures):
                    # Save structure
                    txn.put(f"{idx}".encode("ascii"), pickle.dumps(structure))

                    # Update metadata
                    structure_sizes.append(len(structure))
                    atomic_numbers, counts = np.unique(
                        structure.arrays["numbers"], return_counts=True
                    )
                    species_per_structure.append(
                        {
                            chemical_symbols[Z]: int(count)
                            for Z, count in zip(atomic_numbers, counts)
                        }
                    )
                    per_atom_properties.append(
                        set(structure.arrays.keys()) - {"numbers", "positions"}
                    )
                    per_structure_properties.append(
                        set(structure.info.keys())
                    )

                # Save metadata
                metadata = LmdbMetadata(
                    structure_sizes=np.array(structure_sizes),
                    species_per_structure=species_per_structure,
                    per_atom_properties=sorted(intersect(per_atom_properties)),
                    per_structure_properties=sorted(
                        intersect(per_structure_properties)
                    ),
                )
                txn.put("metadata".encode("ascii"), pickle.dumps(metadata))

                # Always save the description, even when it is None: the
                # loader reads this key unconditionally, and writing it also
                # replaces any description left by an earlier save to the
                # same path.
                txn.put(
                    "description".encode("ascii"), pickle.dumps(description)
                )
        finally:
            env.close()
class InfoLoader:
    """
    Callable that gathers one ``.info`` value from every structure.

    Calling the loader with a key returns an array holding
    ``structure.info[key]`` for each structure, in iteration order. If a
    warning message was supplied, it is emitted on every call.
    """

    def __init__(self, structures: Iterable[Atoms], warning: str | None = None):
        self.warning = warning
        self.structures = structures

    def __call__(self, key: str) -> np.ndarray:
        if self.warning is not None:
            warnings.warn(self.warning, stacklevel=2)
        values = [structure.info[key] for structure in self.structures]
        return np.array(values)
class ArraysLoader:
    """
    Callable that gathers one ``.arrays`` entry from every structure.

    Calling the loader with a key returns the concatenation of
    ``structure.arrays[key]`` over all structures, in iteration order. If a
    warning message was supplied, it is emitted on every call.
    """

    def __init__(self, structures: Iterable[Atoms], warning: str | None = None):
        self.warning = warning
        self.structures = structures

    def __call__(self, key: str) -> np.ndarray:
        if self.warning is not None:
            warnings.warn(self.warning, stacklevel=2)
        pieces = [structure.arrays[key] for structure in self.structures]
        return np.concatenate(pieces)
def _get_info_mapping(
    structures: Iterable[Atoms],
    keys: list[str] | None = None,
    loader_warning: str | None = None,
) -> LazyMapping[str, np.ndarray]:
    """
    Build a lazy mapping from ``.info`` keys to per-structure values.

    When ``keys`` is not given, the keys are those common to every
    structure's ``.info`` (found by iterating over ``structures`` once).
    Values are produced by an :class:`InfoLoader` over ``structures``.
    """
    loader = InfoLoader(structures, loader_warning)
    if keys is not None:
        return LazyMapping(keys, loader)
    shared_keys = intersect(s.info.keys() for s in structures)
    return LazyMapping(list(shared_keys), loader)
def _get_arrays_mapping(
    structures: Iterable[Atoms],
    keys: list[str] | None = None,
    loader_warning: str | None = None,
) -> LazyMapping[str, np.ndarray]:
    """
    Build a lazy mapping from ``.arrays`` keys to concatenated per-atom values.

    When ``keys`` is not given, the keys are those common to every
    structure's ``.arrays`` (found by iterating over ``structures`` once).
    Values are produced by an :class:`ArraysLoader` over ``structures``.
    """
    loader = ArraysLoader(structures, loader_warning)
    if keys is not None:
        return LazyMapping(keys, loader)
    shared_keys = intersect(s.arrays.keys() for s in structures)
    return LazyMapping(list(shared_keys), loader)
def summarise_dataset(
    structures: list[Atoms] | AtomsDataset,
    description: DatabaseEntry | None = None,
) -> str:
    """
    Return the textual summary of a dataset.

    An existing :class:`AtomsDataset` is summarised as-is (``description`` is
    not used in that case); a list of structures is first wrapped, together
    with ``description``, in an :class:`InMemoryAtomsDataset`.
    """
    if not isinstance(structures, AtomsDataset):
        structures = InMemoryAtomsDataset(structures, description)
    return str(structures)
def get_file_extension_and_dataset_class(
    format: Literal["lmdb", "memory"]
) -> tuple[str, type[AtomsDataset]]:
    """
    Look up the file extension and dataset class for a storage format.

    Raises
    ------
    KeyError
        If ``format`` is neither ``"lmdb"`` nor ``"memory"``.
    """
    registry: dict[str, tuple[str, type[AtomsDataset]]] = {
        "lmdb": ("lmdb", LmdbAtomsDataset),
        "memory": ("pkl", InMemoryAtomsDataset),
    }
    return registry[format]