- class openff.nagl.GNNModel(config: ModelConfig, chemical_domain: ChemicalDomain | None = None, lookup_tables: dict[str, AtomPropertiesLookupTable] = None)
Bases: BaseGNNModel
A GNN model for predicting properties of molecules.
- Parameters:
config (ModelConfig or dict) – The configuration for the model.
chemical_domain (ChemicalDomain or dict) – The applicable chemical domain for the model.
lookup_tables (dict) – A dictionary of lookup tables for properties. The keys should be the property names, and the values should be instances of BaseLookupTable.
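Methods
compute_properties(molecule, ...) – Compute all trained properties for a molecule.
compute_property(molecule, ...) – Compute the trained property for a molecule.
from_yaml(filename)
load(model, ...) – Load a model from a file.
save(path) – Save this model to a file.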
Attributes
training
- compute_properties(molecule: Molecule, as_numpy: bool = True, check_domains: bool = False, error_if_unsupported: bool = True, check_lookup_table: bool = True) -> Dict[str, Tensor]
Compute all trained properties for a molecule.
- Parameters:
molecule (Molecule) – The molecule to compute the properties for.
as_numpy (bool) – Whether to return each result as a numpy array. If False, each result will be a torch.Tensor.
check_domains (bool) – Whether to check if the molecule is similar to the training data.
error_if_unsupported (bool) – Whether to raise an error if the molecule is not represented in the training data. This is only used if check_domains is True. If False, a warning is issued instead.
check_lookup_table (bool) – Whether to check a lookup table for the property values. If False, or if the molecule is not in the lookup table, the properties are computed using the model.
- Returns:
result (Dict[str, torch.Tensor] or Dict[str, numpy.ndarray])
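The interplay of check_lookup_table, check_domains and error_if_unsupported amounts to a small decision flow. The sketch below is a hypothetical, standard-library-only model of the behaviour documented above, not NAGL's implementation: molecules are stood in for by strings, each lookup table is a plain dict keyed by those strings, predict and in_domain are placeholder callables, and running the domain check before the lookup is an assumption.

```python
import warnings
from typing import Callable, Dict, List, Optional

# Hypothetical stand-ins: a "molecule" is a SMILES-like string and each
# property value is a list of per-atom floats.
Values = Dict[str, List[float]]


def compute_properties_sketch(
    molecule: str,
    predict: Callable[[str], Values],
    lookup_tables: Optional[Dict[str, Dict[str, List[float]]]] = None,
    in_domain: Callable[[str], bool] = lambda mol: True,
    check_domains: bool = False,
    error_if_unsupported: bool = True,
    check_lookup_table: bool = True,
) -> Values:
    """Model of the documented control flow only; not NAGL's implementation."""
    # The domain check runs only when requested (assumed to run first).
    if check_domains and not in_domain(molecule):
        message = f"{molecule!r} is outside the model's chemical domain"
        if error_if_unsupported:
            raise ValueError(message)
        warnings.warn(message)

    # Model predictions for every readout, copied so the caller's dict is untouched.
    results = dict(predict(molecule))
    if check_lookup_table and lookup_tables:
        for name, table in lookup_tables.items():
            # A lookup-table hit replaces the model's prediction for that property.
            if molecule in table:
                results[name] = table[molecule]
    return results
```

For example, with a table {"charges": {"CCO": [0.1, -0.1, 0.0]}}, calling the sketch on "CCO" returns the tabulated charges, while passing check_lookup_table=False (or a molecule missing from the table) falls back to whatever predict returns.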
- compute_property(molecule: Molecule, readout_name: str | None = None, as_numpy: bool = True, check_domains: bool = False, error_if_unsupported: bool = True, check_lookup_table: bool = True)
Compute the trained property for a molecule.
- Parameters:
molecule (Molecule) – The molecule to compute the property for.
readout_name (str or None) – The name of the readout property to return. If this is not given and there is only one readout, the result of that readout will be returned.
as_numpy (bool) – Whether to return the result as a numpy array. If False, the result will be a torch.Tensor.
check_domains (bool) – Whether to check if the molecule is similar to the training data.
error_if_unsupported (bool) – Whether to raise an error if the molecule is not represented in the training data. This is only used if check_domains is True. If False, a warning is issued instead.
check_lookup_table (bool) – Whether to check a lookup table for the property values. If False, or if the molecule is not in the lookup table, the property is computed using the model.
- Returns:
result (torch.Tensor or numpy.ndarray)
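compute_property narrows the per-readout results down to a single value, falling back to the only readout when readout_name is omitted. The helper below, select_readout, is a hypothetical sketch of that rule applied to a dictionary of readouts like the one compute_properties returns. What NAGL does when several readouts exist and no name is given is not stated here, so raising ValueError in that case is an assumption.

```python
from typing import Dict, Optional, TypeVar

T = TypeVar("T")


def select_readout(results: Dict[str, T], readout_name: Optional[str] = None) -> T:
    """Pick one readout from a dict of readouts (hypothetical helper)."""
    if readout_name is not None:
        # An explicit name always wins; an unknown name raises KeyError.
        return results[readout_name]
    if len(results) == 1:
        # With exactly one readout, no name is needed.
        return next(iter(results.values()))
    # Assumption: with zero or several readouts and no name, refuse to guess.
    raise ValueError(
        f"readout_name is required; available readouts: {sorted(results)}"
    )
```

Refusing to guess keeps the result unambiguous: silently returning "the first" readout would depend on dictionary order.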
- classmethod from_yaml(filename)
- classmethod load(model: str, eval_mode: bool = True, **kwargs)
Load a model from a file.
- Parameters:
model (str) – The path to the model to load. This should be a file containing a dictionary of hyperparameters and a state dictionary, with the keys “hyperparameters” and “state_dict”. Such a file can be created with the save method.
eval_mode (bool) – Whether to set the model to evaluation mode.
**kwargs – Additional keyword arguments to pass to torch.load.
- Returns:
model (GNNModel)
Examples
>>> from openff.nagl import GNNModel
>>> # `model` is an existing GNNModel instance
>>> model.save("model.pt")
>>> new_model = GNNModel.load("model.pt")
Notes
This method is not compatible with ordinary PyTorch models saved with torch.save, as it expects a dictionary of hyperparameters and a state dictionary.
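The incompatibility comes down to the shape of the object stored in the file: load needs a dict with exactly two named parts, not a whole pickled module or a bare state dict. The check below is a hypothetical, framework-free illustration of that contract; split_checkpoint is not part of openff.nagl, and a real loader would then still need to rebuild the model from the hyperparameters and load the state dict into it.

```python
from typing import Any, Dict, Tuple

REQUIRED_KEYS = ("hyperparameters", "state_dict")


def split_checkpoint(checkpoint: Any) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Check the layout load() expects and return (hyperparameters, state_dict).

    Hypothetical helper for illustration; it is not part of openff.nagl.
    """
    # A whole module saved with torch.save(model, ...) is not a dict at all.
    if not isinstance(checkpoint, dict):
        raise TypeError(f"expected a dict checkpoint, got {type(checkpoint).__name__}")
    # A bare state dict, or any other dict, lacks the two expected keys.
    missing = [key for key in REQUIRED_KEYS if key not in checkpoint]
    if missing:
        raise KeyError(f"checkpoint is missing keys: {missing}")
    return checkpoint["hyperparameters"], checkpoint["state_dict"]
```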
- save(path: str)
Save this model to a file.
- Parameters:
path (str) – The path to write the model file to.
Examples
>>> from openff.nagl import GNNModel
>>> # `model` is an existing GNNModel instance
>>> model.save("model.pt")
>>> new_model = GNNModel.load("model.pt")
Notes
This method writes a dictionary of the hyperparameters and the state dictionary, with the keys “hyperparameters” and “state_dict”.
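The two-key file layout can be illustrated with a round trip. torch.save serializes objects with Python's pickle module by default, so the sketch below uses the standard-library pickle module as a stand-in to keep it dependency-free. save_checkpoint, load_checkpoint and the "n_layers" hyperparameter are hypothetical names chosen for illustration, not NAGL API, and the real method stores tensors rather than plain lists.

```python
import os
import pickle
import tempfile
from typing import Any, Dict


def save_checkpoint(path: str, hyperparameters: Dict[str, Any], state_dict: Dict[str, Any]) -> None:
    """Write the two-key layout described above (pickle stands in for torch.save)."""
    with open(path, "wb") as handle:
        pickle.dump({"hyperparameters": hyperparameters, "state_dict": state_dict}, handle)


def load_checkpoint(path: str) -> Dict[str, Any]:
    """Read the layout back; only unpickle files from sources you trust."""
    with open(path, "rb") as handle:
        return pickle.load(handle)


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model.pkl")
    # "n_layers" is a made-up hyperparameter; real values come from ModelConfig.
    save_checkpoint(path, {"n_layers": 3}, {"weights": [0.5, -0.25]})
    restored = load_checkpoint(path)
```

Keeping the hyperparameters next to the state dict is what lets a loader rebuild a correctly shaped model before restoring its weights, which a bare state dict alone cannot do.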