Source code for openff.nagl.nn._containers

import copy
from typing import List, Optional, Union, Tuple, Callable

import torch

from openff.nagl.molecule._dgl import DGLMolecule, DGLMoleculeBatch

from openff.nagl.nn.activation import ActivationFunction
from openff.nagl.nn.gcn._base import _GCNStackMeta, BaseConvModule
from openff.nagl.nn._sequential import SequentialLayers
from openff.nagl.nn._pooling import PoolingLayer, get_pooling_layer
from openff.nagl.nn.postprocess import PostprocessLayer, _PostprocessLayerMeta


class ConvolutionModule(torch.nn.Module):
    """A module that applies a stack of graph convolutional layers to the
    atom features of a molecule."""

    def __init__(
        self,
        n_input_features: int,
        hidden_feature_sizes: List[int],
        architecture: str = "SAGEConv",
        layer_activation_functions: Optional[List[ActivationFunction]] = None,
        layer_dropout: Optional[List[float]] = None,
        layer_aggregator_types: Optional[List[str]] = None,
    ):
        super().__init__()
        self.n_input_features = n_input_features
        self.hidden_feature_sizes = hidden_feature_sizes
        self.architecture = architecture
        self.layer_activation_functions = layer_activation_functions
        self.layer_dropout = layer_dropout
        self.layer_aggregator_types = layer_aggregator_types

        gcn_cls = _GCNStackMeta._get_class(architecture)
        self.gcn_layers = gcn_cls.with_layers(
            n_input_features=n_input_features,
            hidden_feature_sizes=hidden_feature_sizes,
            layer_activation_functions=layer_activation_functions,
            layer_dropout=layer_dropout,
            layer_aggregator_types=layer_aggregator_types,
        )

    def copy(self, copy_weights: bool = False):
        """Return a copy of this module, optionally copying the learned
        weights over to the new module."""
        copied = type(self)(
            self.n_input_features,
            self.hidden_feature_sizes,
            copy.deepcopy(self.architecture),
            self.layer_activation_functions,
            self.layer_dropout,
            self.layer_aggregator_types,
        )
        if copy_weights:
            copied.load_state_dict(self.state_dict())
        return copied

    def forward(self, molecule: Union[DGLMolecule, DGLMoleculeBatch]):
        # The input graph will be heterogeneous - the edges are split into
        # forward edge types and their symmetric reverse counterparts. The
        # convolution layer doesn't need this information, so we produce a
        # homogeneous graph with only a single edge type for it to operate on.
        homograph = molecule.to_homogenous()
        feature_tensor = self.gcn_layers(homograph, molecule.atom_features)
        molecule.graph.ndata[molecule._graph_feature_name] = feature_tensor

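    # Usage sketch (illustrative, not part of the original module): since
    # ``forward`` writes the convolved features onto the molecule's graph in
    # place rather than returning them, a caller reads them back from the
    # graph's node data, e.g.:
    #
    #     convolution.forward(molecule)  # ``molecule``: a featurized DGLMolecule
    #     features = molecule.graph.ndata[molecule._graph_feature_name]
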
    @property
    def _is_dgl(self):
        return self.gcn_layers._is_dgl

    def _as_nagl(self, copy_weights: bool = False):
        copied = self.copy()
        if self._is_dgl:
            copied.gcn_layers = copied.gcn_layers._as_nagl(copy_weights=copy_weights)
        return copied

    @classmethod
    def from_config(cls, convolution_config, n_input_features: int):
        """Create a ``ConvolutionModule`` from a convolution config and the
        number of input atom features."""
        hidden_feature_sizes = [
            layer.hidden_feature_size for layer in convolution_config.layers
        ]
        layer_activation_functions = [
            layer.activation_function for layer in convolution_config.layers
        ]
        layer_dropout = [layer.dropout for layer in convolution_config.layers]
        layer_aggregator_types = [
            layer.aggregator_type for layer in convolution_config.layers
        ]
        return cls(
            n_input_features,
            hidden_feature_sizes,
            architecture=convolution_config.architecture,
            layer_activation_functions=layer_activation_functions,
            layer_dropout=layer_dropout,
            layer_aggregator_types=layer_aggregator_types,
        )

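
# A minimal construction sketch for ``ConvolutionModule`` (illustrative; the
# feature sizes, activations and "mean" aggregators below are arbitrary
# example values, not defaults of this module):
#
#     from openff.nagl.nn.activation import ActivationFunction
#     from openff.nagl.nn._containers import ConvolutionModule
#
#     convolution = ConvolutionModule(
#         n_input_features=34,
#         hidden_feature_sizes=[128, 128],
#         architecture="SAGEConv",
#         layer_activation_functions=[
#             ActivationFunction.ReLU,
#             ActivationFunction.ReLU,
#         ],
#         layer_dropout=[0.0, 0.0],
#         layer_aggregator_types=["mean", "mean"],
#     )
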
class ReadoutModule(torch.nn.Module):
    """A module that transforms the node features generated by a series of
    graph convolutions by propagating them through a pooling layer, dense
    readout layers and an optional postprocess layer.
    """

    def __init__(
        self,
        pooling_layer: PoolingLayer,
        readout_layers: SequentialLayers,
        postprocess_layer: Optional[PostprocessLayer] = None,
    ):
        """
        Args:
            pooling_layer: The pooling layer that will concatenate the node
                features computed by a graph convolution into appropriate
                extended features (e.g. bond or angle features). The
                concatenated features will be provided as input to the dense
                readout layers.
            readout_layers: The dense NN readout layers to apply to the output
                of the pooling layers.
            postprocess_layer: An optional postprocess layer to apply to the
                output of the readout layers.
        """
        super().__init__()
        self.pooling_layer = get_pooling_layer(pooling_layer)
        self.readout_layers = readout_layers
        if postprocess_layer is not None:
            if not isinstance(postprocess_layer, PostprocessLayer):
                postprocess_layer = _PostprocessLayerMeta._get_object(
                    postprocess_layer
                )
        # Assigned unconditionally so that ``forward`` can test for ``None``.
        self.postprocess_layer = postprocess_layer

    def forward(self, molecule: Union[DGLMolecule, DGLMoleculeBatch]) -> torch.Tensor:
        x = self._forward_unpostprocessed(molecule)
        if self.postprocess_layer is not None:
            x = self.postprocess_layer.forward(molecule, x)
        return x

    def _forward_unpostprocessed(
        self, molecule: Union[DGLMolecule, DGLMoleculeBatch]
    ) -> torch.Tensor:
        """Forward pass without postprocessing the readout output.

        This is a quality-of-life method for debugging and testing. It is
        *not* intended for public use.
        """
        x = self.pooling_layer.forward(molecule)
        x = self.readout_layers.forward(x)
        return x

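    # Debugging sketch (illustrative): comparing the raw readout output with
    # the final output isolates the effect of the postprocess layer:
    #
    #     raw = readout._forward_unpostprocessed(molecule)
    #     final = readout.forward(molecule)
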
    def copy(self, copy_weights: bool = False):
        """Return a copy of this module, optionally copying the learned
        weights over to the new module."""
        pooling = type(self.pooling_layer)()
        readout = self.readout_layers.copy(copy_weights=copy_weights)
        # ``type(None)()`` returns ``None``, so a missing postprocess layer is
        # carried over as ``None``.
        postprocess = type(self.postprocess_layer)()
        copied = type(self)(pooling, readout, postprocess)
        if copy_weights:
            copied.load_state_dict(self.state_dict())
        return copied

    @classmethod
    def from_config(cls, readout_config, n_input_features: int):
        """Create a ``ReadoutModule`` from a readout config and the number of
        input features."""
        pooling_layer = readout_config.pooling
        hidden_feature_sizes = [
            layer.hidden_feature_size for layer in readout_config.layers
        ]
        layer_activation_functions = [
            layer.activation_function for layer in readout_config.layers
        ]
        layer_dropout = [layer.dropout for layer in readout_config.layers]

        # Default to ``None`` so the module can be built when no postprocess
        # step is configured.
        postprocess_layer = None
        if readout_config.postprocess is not None:
            postprocess_layer = _PostprocessLayerMeta._get_object(
                readout_config.postprocess
            )
            # The postprocess layer consumes a fixed number of raw features,
            # so append a final layer of that size with an identity activation
            # and no dropout.
            hidden_feature_sizes.append(postprocess_layer.n_features)
            layer_activation_functions.append(ActivationFunction.Identity)
            layer_dropout.append(0.0)

        readout_layers = SequentialLayers.with_layers(
            n_input_features,
            hidden_feature_sizes,
            layer_activation_functions,
            layer_dropout,
        )
        return cls(pooling_layer, readout_layers, postprocess_layer)
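
# End-to-end sketch of how the two modules compose (illustrative; the
# ``convolution_config`` and ``readout_config`` objects are assumed to come
# from an openff-nagl model configuration, and the readout's
# ``n_input_features`` must match the last hidden feature size of the
# convolution stack):
#
#     convolution = ConvolutionModule.from_config(
#         convolution_config, n_input_features=34
#     )
#     readout = ReadoutModule.from_config(readout_config, n_input_features=128)
#
#     convolution.forward(molecule)           # writes features onto the graph
#     prediction = readout.forward(molecule)  # pool -> readout -> postprocess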