- class openff.nagl.GNNModel(config: ModelConfig, chemical_domain: Optional[ChemicalDomain] = None)
Bases: BaseGNNModel
Methods
compute_properties – Compute all trained properties for a molecule.
compute_property – Compute the trained property for a molecule.
load – Load a model from a file.
save – Save this model to a file.
- compute_properties(molecule: Molecule, as_numpy: bool = True, check_domains: bool = False, error_if_unsupported: bool = True) → Dict[str, Tensor]
Compute all trained properties for a molecule.
- Parameters
molecule (Molecule) – The molecule to compute the properties for.
as_numpy (bool) – Whether to return the results as numpy arrays. If False, the results will be torch.Tensor objects.
check_domains (bool) – Whether to check if the molecule is similar to the training data.
error_if_unsupported (bool) – Whether to raise an error if the molecule is not represented in the training data. This is only used if check_domains is True. If False, a warning will be raised instead.
- Returns
result (Dict[str, torch.Tensor] or Dict[str, numpy.ndarray]) – A dictionary mapping each readout name to its computed values.
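The interplay between check_domains and error_if_unsupported can be summarised as a small decision function. The sketch below is a library-free illustration of the documented rules only: check_molecule_domain, its is_supported flag, and UnsupportedMoleculeError are hypothetical names, and the real method's internal similarity test and exception type may differ.

```python
import warnings


class UnsupportedMoleculeError(ValueError):
    """Hypothetical exception type; the library's real error class may differ."""


def check_molecule_domain(
    is_supported: bool,
    check_domains: bool = False,
    error_if_unsupported: bool = True,
) -> None:
    """Apply the documented check_domains / error_if_unsupported rules.

    ``is_supported`` is a hypothetical stand-in for the result of the model's
    own "is this molecule like the training data?" test.
    """
    if not check_domains:
        # No domain check requested, so error_if_unsupported has no effect.
        return
    if is_supported:
        # Molecule is within the training domain: nothing to report.
        return
    message = "molecule is not represented in the model's training data"
    if error_if_unsupported:
        raise UnsupportedMoleculeError(message)
    # error_if_unsupported=False: downgrade the failure to a warning.
    warnings.warn(message)
```

For example, check_molecule_domain(False, check_domains=True, error_if_unsupported=False) emits a warning instead of raising, while the same call with check_domains=False does nothing at all.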
- compute_property(molecule: Molecule, readout_name: Optional[str] = None, as_numpy: bool = True, check_domains: bool = False, error_if_unsupported: bool = True)
Compute the trained property for a molecule.
- Parameters
molecule (Molecule) – The molecule to compute the property for.
readout_name (str, optional) – The name of the readout property to return. If this is not given and there is only one readout, the result of that readout will be returned.
as_numpy (bool) – Whether to return the result as a numpy array. If False, the result will be a torch.Tensor.
check_domains (bool) – Whether to check if the molecule is similar to the training data.
error_if_unsupported (bool) – Whether to raise an error if the molecule is not represented in the training data. This is only used if check_domains is True. If False, a warning will be raised instead.
- Returns
result (torch.Tensor or numpy.ndarray)
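The readout_name rule can be pictured as picking one entry from the name-to-values mapping that compute_properties returns. This is a hedged, library-free sketch: select_readout is a hypothetical helper, the readout names in the example are placeholders, and raising ValueError when several readouts exist and no name is given is an assumption about a case the parameter description does not cover.

```python
from typing import Dict, Optional, TypeVar

T = TypeVar("T")


def select_readout(results: Dict[str, T], readout_name: Optional[str] = None) -> T:
    """Pick one readout from a name -> values mapping (hypothetical helper)."""
    if readout_name is not None:
        # An explicit name always wins; an unknown name is an error.
        if readout_name not in results:
            raise KeyError(f"unknown readout {readout_name!r}; available: {sorted(results)}")
        return results[readout_name]
    if len(results) == 1:
        # Only one readout: return it without requiring a name.
        (value,) = results.values()
        return value
    # Several readouts and no name given: ambiguous (assumed to be an error).
    raise ValueError(f"readout_name is required when there are multiple readouts: {sorted(results)}")
```

With a single placeholder readout, select_readout({"charges": [0.1, -0.1]}) returns [0.1, -0.1] without a name; with two readouts, the name must be supplied.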
- classmethod from_yaml(filename)
- classmethod load(model: str, eval_mode: bool = True)
Load a model from a file.
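- Parameters
model (str) – The path of the file to load the model from.
eval_mode (bool) – Whether to put the loaded model in evaluation mode.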
- Returns
model (GNNModel)
Examples
>>> model.save("model.pt")
>>> new_model = GNNModel.load("model.pt")
Notes
This method is not compatible with ordinary PyTorch models saved with torch.save, as it expects the file to contain a dictionary of hyperparameters and a state dictionary.
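To see why an ordinary saved model is rejected, the sketch below checks a loaded checkpoint for the two keys that save is documented to write. It is a library-free illustration under stated assumptions: read_model_checkpoint is a hypothetical helper, pickle stands in for the PyTorch loader so it runs without PyTorch, and the ValueError is illustrative rather than the exception GNNModel.load actually raises.

```python
import pickle
from typing import Any, Dict

# Keys documented for files written by GNNModel.save.
REQUIRED_KEYS = ("hyperparameters", "state_dict")


def read_model_checkpoint(path: str) -> Dict[str, Any]:
    """Load a checkpoint and check it has the two-key layout (hypothetical).

    pickle.load stands in for the real deserializer so the sketch runs
    without PyTorch installed.
    """
    with open(path, "rb") as f:
        checkpoint = pickle.load(f)
    # A bare state dict (parameter name -> tensor) lacks both required keys.
    if not isinstance(checkpoint, dict) or not all(k in checkpoint for k in REQUIRED_KEYS):
        raise ValueError(
            f"{path} does not contain {REQUIRED_KEYS}; "
            "it looks like an ordinary saved model, not a GNNModel file"
        )
    return checkpoint
```

A file holding only parameter tensors fails the key check, which mirrors the incompatibility described in the Notes.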
- save(path: str)
Save this model to a file.
- Parameters
path (str) – The path to save the model to.
Examples
>>> model.save("model.pt")
>>> new_model = GNNModel.load("model.pt")
Notes
This method writes a dictionary of the hyperparameters and the state dictionary, with the keys “hyperparameters” and “state_dict”.
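The file layout described in these Notes can be sketched as follows. This is a hypothetical, PyTorch-free illustration: write_model_checkpoint is not part of openff.nagl, pickle.dump stands in for whatever serializer the real method uses, and the argument values in the example are placeholders rather than a real model's configuration.

```python
import pickle
from typing import Any, Dict


def write_model_checkpoint(
    path: str,
    hyperparameters: Dict[str, Any],
    state_dict: Dict[str, Any],
) -> None:
    """Write the documented two-key layout (hypothetical helper).

    In the documented method the values are the model's hyperparameters and
    its state dictionary; here they are passed in directly.
    """
    checkpoint = {
        "hyperparameters": hyperparameters,  # describes the network to rebuild
        "state_dict": state_dict,  # learned parameter values
    }
    with open(path, "wb") as f:
        pickle.dump(checkpoint, f)
```

Keeping the hyperparameters next to the weights presumably lets load rebuild the network before restoring the state dictionary; a bare state dictionary carries no such architecture information.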