# ruff: noqa: UP006, UP007
# ^^ NB: dacite parsing requires the old type hint syntax in
# order to be compatible with all versions of Python that
# we are targeting (3.9+)
from __future__ import annotations

from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Literal, Union

import yaml
from pytorch_lightning import Callback

from graph_pes.config.shared import TorchConfig
from graph_pes.data.datasets import DatasetCollection
from graph_pes.graph_pes_model import GraphPESModel
from graph_pes.training.callbacks import VerboseSWACallback
from graph_pes.training.loss import Loss, TotalLoss
from graph_pes.training.opt import LRScheduler, Optimizer


@dataclass
class FittingOptions:
    """Options for the fitting process."""

    pre_fit_model: bool
    max_n_pre_fit: Union[int, None]
    early_stopping_patience: Union[int, None]
    loader_kwargs: Dict[str, Any]
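# Illustrative sketch (not part of the original module): the noqa pragma
# above exists because these dataclasses are populated from parsed-YAML
# dictionaries via dacite, which needs the Dict/List/Union spellings on
# Python 3.9. The field values below are made up for the example:
#
#     import dacite
#
#     options = dacite.from_dict(
#         data_class=FittingOptions,
#         data={
#             "pre_fit_model": True,
#             "max_n_pre_fit": 1000,
#             "early_stopping_patience": None,
#             "loader_kwargs": {"batch_size": 32},
#         },
#     )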
@dataclass
class SWAConfig:
    """
    Configuration for Stochastic Weight Averaging.

    Internally, this is handled by `this PyTorch Lightning callback
    <https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.StochasticWeightAveraging.html>`__.
    """

    lr: float
    """
    The learning rate to use during the SWA phase. If not specified,
    the learning rate from the end of the training phase will be used.
    """

    start: Union[int, float] = 0.8
    """
    The epoch at which to start SWA. If a float, it will be interpreted
    as a fraction of the total number of epochs.
    """

    anneal_epochs: int = 10
    """
    The number of epochs over which to linearly anneal the learning
    rate to zero.
    """

    strategy: Literal["linear", "cos"] = "linear"
    """The strategy to use for annealing the learning rate."""

    def instantiate_lightning_callback(self):
        return VerboseSWACallback(
            swa_lrs=self.lr,
            swa_epoch_start=self.start,
            annealing_epochs=self.anneal_epochs,
            annealing_strategy=self.strategy,
        )
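# Illustrative sketch (the values here are assumptions, not defaults):
# converting an SWAConfig into the callback that is ultimately handed
# to the Lightning trainer:
#
#     swa = SWAConfig(lr=1e-3, start=0.9, anneal_epochs=5)
#     callback = swa.instantiate_lightning_callback()
#     # per the class docstring, `callback` wraps Lightning's
#     # StochasticWeightAveraging callback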
@dataclass
class FittingConfig(FittingOptions):
    """Configuration for the fitting process."""

    trainer_kwargs: Dict[str, Any]
    optimizer: Optimizer = Optimizer(name="AdamW", lr=1e-3, amsgrad=False)
    scheduler: Union[LRScheduler, None] = None
    swa: Union[SWAConfig, None] = None
    callbacks: List[Callback] = field(default_factory=list)


@dataclass
class GeneralConfig:
    """General configuration for a training run."""

    seed: int
    root_dir: str
    run_id: Union[str, None]
    torch: TorchConfig
    log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO"
    progress: Literal["rich", "logged"] = "rich"


# TODO:
# - move get_model to utils, call it parse_model
# - turn loss into dict with human readable names + move parse_loss to utils
@dataclass
class TrainingConfig:
    """
    A schema for a configuration file to train a
    :class:`~graph_pes.GraphPESModel`.
    """

    model: Union[GraphPESModel, Dict[str, GraphPESModel]]
    data: DatasetCollection
    loss: Union[Loss, TotalLoss, Dict[str, Loss], List[Loss]]
    fitting: FittingConfig
    general: GeneralConfig
    wandb: Union[Dict[str, Any], None]

    ### Methods ###

    def get_data(self) -> DatasetCollection:
        if isinstance(self.data, DatasetCollection):
            return self.data
        elif isinstance(self.data, dict):
            return DatasetCollection(**self.data)
        raise ValueError(
            "Expected to be able to parse a DatasetCollection instance or a "
            "dictionary mapping 'train' and 'valid' keys to GraphDataset "
            "instances from the data config, but got something else: "
            f"{self.data}"
        )

    @classmethod
    def defaults(cls) -> dict:
        with open(Path(__file__).parent / "training-defaults.yaml") as f:
            return yaml.safe_load(f)
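# Illustrative sketch (file name and merge strategy are assumptions):
# one plausible end-to-end parse, layering a user YAML file over the
# shipped defaults and letting dacite build the nested dataclasses:
#
#     import dacite
#
#     with open("config.yaml") as f:  # hypothetical user config
#         user_config = yaml.safe_load(f)
#
#     # naive shallow merge; a real loader may merge nested keys too
#     config_dict = {**TrainingConfig.defaults(), **user_config}
#     config = dacite.from_dict(TrainingConfig, config_dict)
#     datasets = config.get_data()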