GNNModel

class openff.nagl.GNNModel(convolution_architecture: Union[str, BaseGCNStack], n_convolution_hidden_features: int, n_convolution_layers: int, n_readout_hidden_features: int, n_readout_layers: int, activation_function: Union[str, ActivationFunction], postprocess_layer: Union[str, PostprocessLayer], readout_name: str, learning_rate: float, atom_features: Tuple[AtomFeature, ...], bond_features: Tuple[BondFeature, ...], loss_function: Callable = rmse_loss, convolution_dropout: float = 0, readout_dropout: float = 0)

Bases: BaseGNNModel

A model that applies a graph convolutional step followed by pooling and readout steps.

Parameters
  • convolution_architecture (Union[str, BaseGCNStack]) – The graph convolution architecture. This can be given either as a class, e.g. SAGEConvStack, or as a string, e.g. "SAGEConv".

  • n_convolution_hidden_features (int) – The number of features in each of the hidden convolutional layers.

  • n_convolution_layers (int) – The number of hidden convolutional layers to generate. These are the layers in the convolutional module between the input layer and the pooling layer.

  • n_readout_hidden_features (int) – The number of features in each of the hidden readout layers.

  • n_readout_layers (int) – The number of hidden readout layers to generate. These are the layers between the convolution module’s pooling layer and the readout module’s output layer. The pooling layer may be considered to be both the convolution module’s output layer and the readout module’s input layer.

  • activation_function (Union[str, ActivationFunction]) – The activation function to use for the readout module. This can be given either as a class, e.g. ReLU, or as a string, e.g. "ReLU".

  • postprocess_layer (Union[str, PostprocessLayer]) – The postprocess layer to use. This can be given either as a class, e.g. ComputePartialCharges, or as a string, e.g. "compute_partial_charges".

  • readout_name (str) – A human-readable name for the readout module.

  • learning_rate (float) – The learning rate for optimization.

  • atom_features (Tuple[AtomFeature, ...]) – The atom features to use.

  • bond_features (Tuple[BondFeature, ...]) – The bond features to use.

  • loss_function (Callable[[torch.Tensor, torch.Tensor], torch.Tensor]) – The loss function. This is RMSE (rmse_loss) by default, but can be any callable that takes a predicted tensor and a target tensor and returns a scalar loss tensor.

  • convolution_dropout (float) – The dropout probability to use in the convolutional layers.

  • readout_dropout (float) – The dropout probability to use in the readout layers.
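Several parameters above (convolution_architecture, activation_function and postprocess_layer) accept either a class or a string name. One common way to support this is a registry that maps accepted names to classes. The sketch below shows that pattern under stated assumptions; it is not openff.nagl's implementation, and the registry, the resolve_option helper and the ReLU and Tanh stand-in classes are all hypothetical.

```python
from typing import Dict, Union


class ReLU:
    """Hypothetical stand-in for an activation-function class."""


class Tanh:
    """Hypothetical stand-in for another activation-function class."""


# Hypothetical registry mapping accepted string names to classes.
ACTIVATION_REGISTRY: Dict[str, type] = {"ReLU": ReLU, "Tanh": Tanh}


def resolve_option(value: Union[str, type], registry: Dict[str, type]) -> type:
    """Return a class, accepting either the class itself or its registered name."""
    if isinstance(value, type):
        # Already a class: use it unchanged.
        return value
    if isinstance(value, str):
        try:
            # Look the name up in the registry.
            return registry[value]
        except KeyError:
            options = ", ".join(sorted(registry))
            raise ValueError(
                f"Unknown option {value!r}; expected one of: {options}"
            ) from None
    raise TypeError(f"Expected a class or a string, got {type(value).__name__}")


print(resolve_option("ReLU", ACTIVATION_REGISTRY) is ReLU)  # True
print(resolve_option(Tanh, ACTIVATION_REGISTRY) is Tanh)    # True
```

Using an explicit mapping, rather than matching on a class's __name__, allows string names that differ from the class name, as with "compute_partial_charges" and ComputePartialCharges.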

Methods

compute_property

Compute the trained property for a molecule.

from_yaml_file

Construct a GNNModel from a YAML file.

load

Load a model from a file.

save

Save this model to a file.

Attributes

n_atom_features

The number of features used to represent an atom.

compute_property(molecule: Molecule, as_numpy: bool = False) → Tensor

Compute the trained property for a molecule.

Parameters
  • molecule (Molecule) – The molecule to compute the property for.

  • as_numpy (bool) – Whether to return the result as a numpy array. If False, the result will be a torch.Tensor.

Returns

result (torch.Tensor or numpy.ndarray)

classmethod from_yaml_file(*paths, **kwargs) → GNNModel

Construct a GNNModel from a YAML file.

classmethod load(model: str, eval_mode: bool = True)

Load a model from a file.

Parameters
  • model (str) – The path to the model to load. This should be a file containing a dictionary of hyperparameters and a state dictionary, with the keys “hyperparameters” and “state_dict”. This can be created using the save method.

  • eval_mode (bool) – Whether to set the model to evaluation mode.

Returns

model (GNNModel)

Examples

>>> model.save("model.pt")
>>> new_model = GNNModel.load("model.pt")

Notes

This method is not compatible with ordinary PyTorch models saved with torch.save, as it expects a dictionary containing both the hyperparameters and a state dictionary.
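To make the incompatibility concrete: a file written with torch.save(model.state_dict(), path) holds only a mapping from parameter names to tensors, with no "hyperparameters" or "state_dict" keys. The sketch below checks an already-loaded object for the layout described above; it does not read files itself. The helper has_gnnmodel_layout is hypothetical, and plain lists stand in for real tensors.

```python
from collections.abc import Mapping
from typing import Any

# Keys that the documentation says the loaded dictionary must contain.
REQUIRED_KEYS = ("hyperparameters", "state_dict")


def has_gnnmodel_layout(obj: Any) -> bool:
    """Hypothetical check: is obj a mapping that contains both required keys?"""
    return isinstance(obj, Mapping) and all(key in obj for key in REQUIRED_KEYS)


# A payload in the documented layout (contents are placeholders).
compatible = {"hyperparameters": {"n_convolution_layers": 3}, "state_dict": {}}

# A bare state dictionary maps parameter names straight to values
# (a list stands in for a tensor here) and has neither required key.
bare_state_dict = {"layer.weight": [0.1, 0.2]}

print(has_gnnmodel_layout(compatible))       # True
print(has_gnnmodel_layout(bare_state_dict))  # False
```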

property n_atom_features: int

The number of features used to represent an atom.

save(path: str)

Save this model to a file.

Parameters

path (str) – The path to save the model to.

Examples

>>> model.save("model.pt")
>>> new_model = GNNModel.load("model.pt")

Notes

This method writes a dictionary containing the hyperparameters and the state dictionary under the keys “hyperparameters” and “state_dict”, which is the format that load expects.
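Storing the hyperparameters alongside the weights is presumably what lets load rebuild the model: a state dictionary holds parameter values but not the architecture needed to create matching layers. The round trip below sketches that two-key layout using pickle purely as a stand-in serializer; it makes no claim about how save actually writes the file, and save_checkpoint, load_checkpoint and the placeholder values are hypothetical.

```python
import os
import pickle
import tempfile
from typing import Any, Dict


def save_checkpoint(path: str, hyperparameters: Dict[str, Any], state_dict: Dict[str, Any]) -> None:
    """Hypothetical writer: store both parts under the documented keys.

    pickle is only a stand-in serializer for this sketch.
    """
    payload = {"hyperparameters": hyperparameters, "state_dict": state_dict}
    with open(path, "wb") as handle:
        pickle.dump(payload, handle)


def load_checkpoint(path: str) -> Dict[str, Any]:
    """Hypothetical reader: return the stored two-key dictionary."""
    with open(path, "rb") as handle:
        return pickle.load(handle)


with tempfile.TemporaryDirectory() as tmpdir:
    path = os.path.join(tmpdir, "model.pkl")
    # Placeholder hyperparameters and weights (a list stands in for a tensor).
    save_checkpoint(path, {"n_readout_layers": 1, "learning_rate": 0.001}, {"layer.weight": [0.5]})
    restored = load_checkpoint(path)

print(sorted(restored))  # ['hyperparameters', 'state_dict']
```

A loader built on this layout would first construct the model from restored["hyperparameters"] and only then copy in restored["state_dict"].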