@dataclass
class TestingConfig:
    """Configuration for testing a GraphPES model."""

    model_path: str
    """The path to the ``model.pt`` file."""

    data: Union[GraphDataset, dict[str, GraphDataset], None]  # noqa: UP007
    """
    Either:

    - a single :class:`~graph_pes.data.GraphDataset`. Results will be logged
      as ``"<prefix>/<metric>"``.
    - a mapping from names to datasets. Results will be logged as
      ``"<prefix>/<dataset-name>/<metric>"``, allowing for testing on
      multiple datasets.
    - ``None``, in which case we will attempt to load the datasets specified
      during training from the ``<model_path>/../train-config.yaml`` file.
    """

    loader_kwargs: dict[str, Any]
    """
    Keyword arguments to pass to the
    :class:`~graph_pes.data.loader.GraphDataLoader`.

    Defaults to:

    .. code-block:: yaml

        loader_kwargs:
            batch_size: 2
            num_workers: 0

    Tune these (e.g. increase ``batch_size`` and ``num_workers``) to make
    testing faster.
    """

    torch: TorchConfig
    """The torch configuration to use for testing."""

    logger: Union[Literal["auto", "csv"], dict[str, Any]] = "auto"  # noqa: UP007
    """
    The logger to use for logging the test metrics.

    If ``"auto"``, we will attempt to find the training config at
    ``<model_path>/../train-config.yaml``: if it contains a ``wandb`` section,
    a :class:`~graph_pes.training.callbacks.WandbLogger` is created from it;
    otherwise, a CSVLogger is used. An error is raised if the training config
    cannot be found.

    If ``"csv"``, we will use a CSVLogger that writes to the directory
    containing the model file.

    If a dictionary, we will instantiate a new
    :class:`~graph_pes.training.callbacks.WandbLogger` with the provided
    arguments.
    """

    accelerator: str = "auto"
    """The accelerator to use for testing."""

    prefix: str = "test"
    """
    The prefix to use for logging. Individual metrics will be logged as
    ``"<prefix>/<dataset-name>/<metric>"``.
    """

    def get_logger(self) -> Logger:
        root_dir = Path(self.model_path).parent
        if self.logger == "csv":
            return CSVLogger(save_dir=root_dir, name="")
        elif isinstance(self.logger, dict):
            return WandbLogger(
                output_dir=root_dir,
                log_epoch=False,
                **self.logger,
            )

        if self.logger != "auto":
            raise ValueError(f"Invalid logger: {self.logger}")

        train_config_path = root_dir / "train-config.yaml"
        if not train_config_path.exists():
            raise ValueError(
                f"Could not find training config at {train_config_path}. "
                "Please specify a logger explicitly."
            )

        with open(train_config_path) as f:
            logger_data = yaml.safe_load(f).get("wandb", None)

        # no W&B settings in the training config: fall back to a CSV logger
        if logger_data is None:
            return CSVLogger(save_dir=root_dir, name="")

        return WandbLogger(output_dir=root_dir, log_epoch=False, **logger_data)

    def get_datasets(self) -> dict[str, GraphDataset]:
        if self.data is None:
            # try to find the data object from the training config.
            # this is a bit of a hack, but it works:
            # 1. load the config file
            # 2. create a dummy config class
            # 3. instantiate the config from the dict
            # 4. extract the data object in the correct format
            root_dir = Path(self.model_path).parent
            train_config_path = root_dir / "train-config.yaml"
            with open(train_config_path) as f:
                train_config = yaml.safe_load(f)

            @dataclass
            class DummyConfig:
                data: DatasetCollection

                @classmethod
                def defaults(cls) -> dict:
                    return {}

            data_collection = instantiate_config_from_dict(
                {"data": train_config.get("data", {})}, DummyConfig
            )[1].data

            data = {
                "train": data_collection.train,
                "valid": data_collection.valid,
            }
            if data_collection.test is not None:
                if isinstance(data_collection.test, dict):
                    data.update(data_collection.test)
                else:
                    data["test"] = data_collection.test

            return data

        elif isinstance(self.data, GraphDataset):
            return {"test": self.data}

        else:
            return self.data

    @classmethod
    def defaults(cls) -> dict:
        return {
            "torch": {"float32_matmul_precision": "high", "dtype": "float32"},
            "loader_kwargs": {"batch_size": 2, "num_workers": 0},
            "data": None,
        }
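

# ---------------------------------------------------------------------------
# Illustrative sketch (hypothetical helpers, NOT part of the graph-pes API).
#
# The two functions below are a dependency-free mirror of the rules applied by
# ``TestingConfig.get_logger`` and ``TestingConfig.get_datasets``, so that the
# behaviour can be checked without Lightning, W&B or real datasets installed.
# Assumptions made for this sketch only:
# - ``train_config`` stands in for the parsed ``train-config.yaml``
#   (``None`` if the file is missing);
# - ``train_data`` stands in for the training-time ``DatasetCollection`` as a
#   plain dict with "train", "valid" and optional "test" entries;
# - loggers are reported by kind ("csv" / "wandb") rather than constructed.
# ---------------------------------------------------------------------------
from typing import Any, Union


def _sketch_resolve_logger(
    logger: Union[str, dict[str, Any]],  # noqa: UP007
    train_config: Union[dict[str, Any], None],  # noqa: UP007
) -> tuple[str, dict[str, Any]]:
    """Return (logger kind, W&B keyword arguments), mirroring get_logger."""
    if logger == "csv":
        return "csv", {}
    if isinstance(logger, dict):
        # explicit W&B settings are forwarded as keyword arguments
        return "wandb", dict(logger)
    if logger != "auto":
        raise ValueError(f"Invalid logger: {logger}")
    if train_config is None:
        raise ValueError(
            "Could not find training config. Please specify a logger explicitly."
        )
    wandb_kwargs = train_config.get("wandb", None)
    if wandb_kwargs is None:
        # no W&B section in the training config: fall back to CSV
        return "csv", {}
    return "wandb", dict(wandb_kwargs)


def _sketch_normalise_datasets(
    data: Any,
    train_data: Union[dict[str, Any], None] = None,  # noqa: UP007
) -> dict[str, Any]:
    """Return a name -> dataset mapping, mirroring get_datasets."""
    if data is None:
        if train_data is None:
            raise ValueError("No datasets given and no training data found.")
        # reuse the training-time datasets: train, valid and any test set(s)
        out = {"train": train_data["train"], "valid": train_data["valid"]}
        test = train_data.get("test")
        if test is not None:
            if isinstance(test, dict):
                out.update(test)  # named test sets keep their own names
            else:
                out["test"] = test
        return out
    if isinstance(data, dict):
        return dict(data)
    # a single dataset is stored under the name "test"
    return {"test": data}


# Example: _sketch_resolve_logger("auto", {"wandb": {"project": "demo"}})
# returns ("wandb", {"project": "demo"}), whereas a training config without a
# "wandb" section falls back to ("csv", {}).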