Source code for load_atoms.database.backend

"""
The :code:`backend` module is responsible for down/loading datasets by name, 
storing them locally, and serving them to :code:`load-atoms` via the
:func:`~load_atoms.load_dataset` function.
"""

from __future__ import annotations

import shutil
import tempfile
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Iterator

import ase.io
from ase import Atoms
from typing_extensions import override

from load_atoms.atoms_dataset import (
    AtomsDataset,
    get_file_extension_and_dataset_class,
)
from load_atoms.database.database_entry import (
    LICENSE_URLS,
    DatabaseEntry,
)
from load_atoms.database.internet import FileDownload, download, download_all
from load_atoms.progress import Progress, get_progress_for_dataset
from load_atoms.utils import (
    UnknownDatasetException,
    debug_mode,
    frontend_url,
    testing,
)

BASE_GITHUB_URL = "https://github.com/jla-gardner/load-atoms/raw/main/database"


def load_dataset_by_id(dataset_id: str, root: Path) -> AtomsDataset:
    """
    Load the :class:`AtomsDataset` and corresponding :class:`DatabaseEntry`
    for the given dataset id, saving the dataset to the given ``root``
    directory.

    Parameters
    ----------
    dataset_id
        The id of the dataset to load.
    root
        The root folder to save the structures to.
    """

    # prepare local paths
    yaml_file_path = root / "database-entries" / f"{dataset_id}.yaml"
    yaml_file_path.parent.mkdir(parents=True, exist_ok=True)

    with get_progress_for_dataset(dataset_id) as progress:
        # down/load the database entry for the dataset
        database_entry = get_database_entry(
            dataset_id, yaml_file_path, progress
        )

        # get the file extension and dataset class for the dataset
        extension, dataset_class = get_file_extension_and_dataset_class(
            database_entry.format
        )
        data_file_path = root / f"{dataset_id}.{extension}"

        # if the dataset already exists, load it from disk
        if data_file_path.exists():
            with progress.new_task("Reading from disk"):
                dataset = dataset_class.load(data_file_path)

        # otherwise, use the importer to get the structures
        else:
            # 1. get the Importer class from a suitably down/loaded file
            importer_type: type[BaseImporter] = get_importer_type(
                dataset_id, progress
            )

            # 2. download the files to an appropriate directory
            download_dir_name = importer_type.permanent_download_dirname()
            use_tmp_dir = (
                download_dir_name is None
                and not debug_mode()
                and not testing()
            )
            if use_tmp_dir:
                download_dir = Path(tempfile.mkdtemp())
            else:
                download_dir = (
                    root / "raw-downloads" / (download_dir_name or dataset_id)
                )
            download_all(
                importer_type.files_to_download(), download_dir, progress
            )

            # 3. use the importer to get the structures (removing annoying calc)
            def iterator():
                for structure in importer_type.get_structures(
                    download_dir, progress
                ):
                    structure.calc = None
                    yield structure

            try:
                dataset_class.save(data_file_path, iterator(), database_entry)
            except Exception as e:
                # remove the partially created dataset
                if data_file_path.exists():
                    if data_file_path.is_dir():
                        shutil.rmtree(data_file_path)
                    else:
                        data_file_path.unlink()
                raise ValueError(
                    "Failed to import dataset: please report an issue at "
                    "https://github.com/jla-gardner/load-atoms/issues if you "
                    "think this is a bug."
                ) from e
            dataset = dataset_class.load(data_file_path)

            # 4. clean up the temporary directory if necessary
            if use_tmp_dir:
                shutil.rmtree(download_dir)

        # add the usage information to the progress bar
        log_usage_information(database_entry, progress)
        progress.refresh()

    return dataset

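
# Illustrative usage sketch (not part of this module). The dataset id and the
# cache directory below are assumptions made for the example; in normal use,
# :func:`load_atoms.load_dataset` calls this function on the user's behalf.
#
#     from pathlib import Path
#     from load_atoms.database.backend import load_dataset_by_id
#
#     dataset = load_dataset_by_id("C-GAP-17", root=Path.home() / ".load-atoms")
#     print(len(dataset))
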

class BaseImporter(ABC):
    """
    Base class to inherit from to create new, dataset-specific importers.

    Parameters
    ----------
    files_to_download
        A list of :class:`FileDownload` objects.
    tmp_dirname
        The name of the temporary directory to download the files to.
    cleanup
        Whether to clean up the temporary directory after processing.
    """

    @classmethod
    def files_to_download(cls) -> list[FileDownload]:
        return []

    @classmethod
    @abstractmethod
    def get_structures(
        cls,
        tmp_dir: Path,
        progress: Progress,
    ) -> Iterator[Atoms]:
        """
        Iterate over :class:`ase.Atoms` objects.

        All files passed to the base class will have already been downloaded
        and verified when this is called.

        Parameters
        ----------
        tmp_dir
            The temporary directory where downloaded files are stored.

        Yields
        ------
        Atoms
            An iterator of ASE Atoms objects processed from the downloaded
            files.
        """

    @classmethod
    def permanent_download_dirname(cls) -> str | None:
        """
        Get a path to the directory where the files should be saved.

        If ``None`` (the default) is returned, the files will be downloaded
        to a temporary directory, and removed after the dataset is imported.
        """
        return None


class SingleFileImporter(BaseImporter):
    @classmethod
    @abstractmethod
    def file_to_download(cls) -> FileDownload: ...

    @classmethod
    def files_to_download(cls) -> list[FileDownload]:
        return [cls.file_to_download()]

    @override
    @classmethod
    def get_structures(
        cls, tmp_dir: Path, progress: Progress
    ) -> Iterator[Atoms]:
        file_path = tmp_dir / Path(cls.files_to_download()[0].local_name)
        with progress.new_task(f"Reading {file_path.resolve()}"):
            for atoms in cls._read_file(file_path):
                yield cls.process_atoms(atoms)

    @classmethod
    def process_atoms(cls, atoms: Atoms) -> Atoms:
        return atoms

    @classmethod
    def _read_file(cls, file_path: Path) -> Iterator[Atoms]:
        yield from ase.io.iread(file_path, index=":")


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~ HELPERS ~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


def get_database_entry(
    dataset_id: str,
    yaml_file_path: Path,
    progress: Progress,
) -> DatabaseEntry:
    from load_atoms import __version__ as load_atoms_version

    if not yaml_file_path.exists():
        try:
            download(
                DatabaseEntry.remote_url_for_yaml(dataset_id),
                yaml_file_path,
                progress,
            )
        except Exception as e:
            raise UnknownDatasetException(dataset_id) from e

    db_entry = DatabaseEntry.from_yaml_file(yaml_file_path)
    if (
        db_entry.minimum_load_atoms_version is not None
        and db_entry.minimum_load_atoms_version > load_atoms_version
    ):
        raise Exception(
            f"Dataset {dataset_id} requires load-atoms version "
            f">={db_entry.minimum_load_atoms_version} "
            f"(current version: {load_atoms_version}). "
            "Please upgrade load-atoms to load this dataset "
            "(e.g. `pip install --upgrade load-atoms`)."
        )
    return db_entry


def get_importer_type(
    dataset_id: str,
    progress: Progress,
) -> type[BaseImporter]:
    importer_name = DatabaseEntry.importer_file_stem(dataset_id)
    expected_importer_path = (
        Path(__file__).parent / "importers" / f"{importer_name}.py"
    )
    if not expected_importer_path.exists():
        try:
            download(
                DatabaseEntry.remote_url_for_importer(dataset_id),
                expected_importer_path,
                progress,
            )
        except Exception as e:
            # couldn't download the importer:
            raise UnknownDatasetException(dataset_id) from e

    try:
        return __import__(
            f"load_atoms.database.importers.{importer_name}",
            fromlist=["Importer"],
        ).Importer
    except Exception as e:
        raise Exception(
            f"Unable to load dataset {dataset_id} due to a problem loading "
            "the dataset's importer file. Please try updating load-atoms:\n"
            "    pip install --upgrade load-atoms\n"
            "If the problem persists, please report an issue at:\n"
            "    https://github.com/jla-gardner/load-atoms/issues"
        ) from e


def log_usage_information(info: DatabaseEntry, progress: Progress):
    progress.add_text("\n")
    name = f"[bold]{info.name}[/bold]"

    if info.license is not None:
        style = f"dodger_blue2 link={LICENSE_URLS[info.license]} underline"
        progress.add_text(
            f"The {name} dataset is covered by the "
            f"[{style}]{info.license}[/] license."
        )

    if info.citation is not None:
        progress.add_text(
            f"Please cite the {name} dataset "
            "if you use it in your work."
        )

    progress.add_text(f"For more information about the {name} dataset, visit:")
    url = frontend_url(info)
    url_style = f"dodger_blue2 underline link={url}"
    progress.add_text(f"[{url_style}]load-atoms/{info.name}")


def unzip_file(file_path: Path, progress: Progress) -> Path:
    """
    Unzip a file and return the path to the extracted directory.

    Parameters
    ----------
    file_path
        The path to the file to unzip.
    progress
        A :class:`Progress` object to track the unzip progress.
    """
    extract_to = file_path.parent / f"{file_path.name}-extracted"
    if not extract_to.exists():
        with progress.new_task(
            f"Unzipping {file_path.resolve()}",
        ):
            shutil.unpack_archive(file_path, extract_dir=extract_to)
    return extract_to


def rename(atoms: Atoms, mapping: dict[str, str]) -> Atoms:
    """Rename the properties of an Atoms object."""
    for old_name, new_name in mapping.items():
        if old_name in atoms.arrays:
            atoms.arrays[new_name] = atoms.arrays.pop(old_name)
        elif old_name in atoms.info:
            atoms.info[new_name] = atoms.info.pop(old_name)
    return atoms
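

# Illustrative sketch (not part of this module): a minimal dataset-specific
# importer built on ``SingleFileImporter``. The URL, hash and property names
# are hypothetical, and the exact ``FileDownload`` fields shown are assumed;
# real importers live in ``load_atoms/database/importers/`` and are fetched
# on demand by ``get_importer_type``.
#
#     class Importer(SingleFileImporter):
#         @classmethod
#         def file_to_download(cls) -> FileDownload:
#             return FileDownload(
#                 url="https://example.com/my-dataset.extxyz",  # hypothetical
#                 expected_hash="0123456789abcdef",             # hypothetical
#             )
#
#         @classmethod
#         def process_atoms(cls, atoms: Atoms) -> Atoms:
#             # standardise property names before the dataset is saved
#             return rename(atoms, {"energies": "energy"})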