@dataclass
class TestingConfig:
    """Configuration for testing a GraphPES model."""

    model_path: str
    """The path to the ``model.pt`` file."""

    data: Union[GraphDataset, dict[str, GraphDataset]]  # noqa: UP007
    """
    Either:

    - a single :class:`~graph_pes.data.GraphDataset`. Results will be
      logged as ``"<prefix>/<metric>"``.
    - a mapping from names to datasets. Results will be logged as
      ``"<prefix>/<dataset-name>/<metric>"``, allowing for testing on
      multiple datasets.
    """

    loader_kwargs: dict[str, Any]
    """
    Keyword arguments to pass to the
    :class:`~graph_pes.data.loader.GraphDataLoader`. Defaults to:

    .. code-block:: yaml

        loader_kwargs:
            batch_size: 2
            num_workers: 0

    You should tune this to make testing faster.
    """

    torch: TorchConfig
    """The torch configuration to use for testing."""

    logger: Union[Literal["auto", "csv"], dict[str, Any]] = "auto"  # noqa: UP007
    """
    The logger to use for logging the test metrics. If ``"auto"``, we will
    attempt to find the training config from
    ``<model_path>/../train-config.yaml``, and use the logger from that
    config. If ``"csv"``, we will use a CSVLogger. If a dictionary, we will
    instantiate a new :class:`~graph_pes.training.callbacks.WandbLogger`
    with the provided arguments.
    """

    accelerator: str = "auto"
    """The accelerator to use for testing."""

    prefix: str = "testing"
    """The prefix to use for logging. Individual metrics will be logged as
    ``<prefix>/<dataset_name>/<metric>``.
    """

    def get_logger(self) -> Logger:
        root_dir = Path(self.model_path).parent

        if self.logger == "csv":
            return CSVLogger(save_dir=root_dir, name="")
        elif isinstance(self.logger, dict):
            return WandbLogger(
                output_dir=root_dir,
                log_epoch=False,
                **self.logger,
            )

        if not self.logger == "auto":
            raise ValueError(f"Invalid logger: {self.logger}")

        train_config_path = root_dir / "train-config.yaml"
        if not train_config_path.exists():
            raise ValueError(
                f"Could not find training config at {train_config_path}. "
                "Please specify a logger explicitly."
            )

        with open(train_config_path) as f:
            logger_data = yaml.safe_load(f).get("wandb", None)

        if logger_data is None:
            return CSVLogger(save_dir=root_dir, name="")

        return WandbLogger(output_dir=root_dir, log_epoch=False, **logger_data)

    @classmethod
    def defaults(cls) -> dict:
        return {
            "torch": {"float32_matmul_precision": "high", "dtype": "float32"},
            "loader_kwargs": {"batch_size": 2, "num_workers": 0},
        }
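The `defaults()` classmethod above supplies fallback values for `torch` and `loader_kwargs`; a config loader would typically merge user-supplied options over these defaults before constructing the dataclass. As a minimal, self-contained sketch of that merging pattern (the actual graph-pes config-resolution logic may differ; `user` and `merged` are hypothetical names for illustration):

```python
# Fallback values, copied from TestingConfig.defaults() above.
defaults = {
    "torch": {"float32_matmul_precision": "high", "dtype": "float32"},
    "loader_kwargs": {"batch_size": 2, "num_workers": 0},
}

# Hypothetical user-supplied overrides, e.g. parsed from a YAML file.
user = {"loader_kwargs": {"batch_size": 64}}

# Shallow one-level merge: nested dicts are combined key-by-key so that
# unspecified sub-keys (here, num_workers) keep their default values.
merged = {**defaults}
for key, value in user.items():
    if isinstance(value, dict) and isinstance(merged.get(key), dict):
        merged[key] = {**merged[key], **value}
    else:
        merged[key] = value

print(merged["loader_kwargs"])  # {'batch_size': 64, 'num_workers': 0}
```

Note that only the top level of nesting is merged here; a deeper config schema would need a recursive merge.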