import copy
from typing import List, Optional, Union, Tuple, Callable
import torch
from openff.nagl.molecule._dgl import DGLMolecule, DGLMoleculeBatch
from openff.nagl.nn.activation import ActivationFunction
from openff.nagl.nn.gcn._base import _GCNStackMeta, BaseConvModule
from openff.nagl.nn._sequential import SequentialLayers
from openff.nagl.nn._pooling import PoolingLayer, get_pooling_layer
from openff.nagl.nn.postprocess import PostprocessLayer, _PostprocessLayerMeta
class ConvolutionModule(torch.nn.Module):
    """A module that runs a stack of graph convolution (GCN) layers over the
    atom features of a molecule, storing the convolved features back onto the
    molecule's graph.
    """

    def __init__(
        self,
        n_input_features: int,
        hidden_feature_sizes: List[int],
        architecture: str = "SAGEConv",
        layer_activation_functions: Optional[List[ActivationFunction]] = None,
        layer_dropout: Optional[List[float]] = None,
        layer_aggregator_types: Optional[List[str]] = None,
    ):
        super().__init__()

        # Keep every constructor argument around so ``copy`` can rebuild an
        # equivalent, freshly-initialized module later.
        self.n_input_features = n_input_features
        self.hidden_feature_sizes = hidden_feature_sizes
        self.architecture = architecture
        self.layer_activation_functions = layer_activation_functions
        self.layer_dropout = layer_dropout
        self.layer_aggregator_types = layer_aggregator_types

        # Resolve the architecture name (e.g. "SAGEConv") to a registered GCN
        # stack class and instantiate the layer stack from it.
        stack_class = _GCNStackMeta._get_class(architecture)
        self.gcn_layers = stack_class.with_layers(
            n_input_features=n_input_features,
            hidden_feature_sizes=hidden_feature_sizes,
            layer_activation_functions=layer_activation_functions,
            layer_dropout=layer_dropout,
            layer_aggregator_types=layer_aggregator_types,
        )

    def copy(self, copy_weights: bool = False):
        """Return a new module built from the same constructor arguments.

        Args:
            copy_weights: If ``True``, also copy this module's current
                parameters into the new instance via its state dict.
        """
        duplicate = type(self)(
            self.n_input_features,
            self.hidden_feature_sizes,
            copy.deepcopy(self.architecture),
            self.layer_activation_functions,
            self.layer_dropout,
            self.layer_aggregator_types,
        )
        if copy_weights:
            duplicate.load_state_dict(self.state_dict())
        return duplicate

    def forward(self, molecule: Union[DGLMolecule, DGLMoleculeBatch]):
        """Run the GCN stack and write the result onto the molecule's graph.

        The input graph is heterogeneous — edges are split into forward edge
        types and their symmetric reverse counterparts. The convolution layers
        do not need that distinction, so a homogeneous view with a single edge
        type is produced for them to operate on.

        Note: nothing is returned; the convolved features are stored under
        ``molecule._graph_feature_name`` in ``molecule.graph.ndata``.
        """
        homograph = molecule.to_homogenous()
        convolved = self.gcn_layers(homograph, molecule.atom_features)
        molecule.graph.ndata[molecule._graph_feature_name] = convolved

    @property
    def _is_dgl(self):
        # Delegates to the layer stack; True when the stack is DGL-backed.
        return self.gcn_layers._is_dgl

    def _as_nagl(self, copy_weights: bool = False):
        # NOTE(review): ``copy_weights`` is not forwarded to ``self.copy()``
        # here; only the converted GCN stack receives it. Preserved as-is.
        converted = self.copy()
        if self._is_dgl:
            converted.gcn_layers = converted.gcn_layers._as_nagl(
                copy_weights=copy_weights
            )
        return converted

    @classmethod
    def from_config(
        cls,
        convolution_config,
        n_input_features: int,
    ):
        """Build a module from a convolution config object.

        Args:
            convolution_config: A config exposing ``architecture`` and a
                ``layers`` sequence, where each layer carries
                ``hidden_feature_size``, ``activation_function``, ``dropout``
                and ``aggregator_type``.
            n_input_features: The number of input atom features.
        """
        layers = convolution_config.layers
        return cls(
            n_input_features,
            [layer.hidden_feature_size for layer in layers],
            architecture=convolution_config.architecture,
            layer_activation_functions=[
                layer.activation_function for layer in layers
            ],
            layer_dropout=[layer.dropout for layer in layers],
            layer_aggregator_types=[layer.aggregator_type for layer in layers],
        )
class ReadoutModule(torch.nn.Module):
    """A module that transforms the node features generated by a series of graph
    convolutions via propagation through a pooling, readout and optional postprocess
    layer.
    """

    def __init__(
        self,
        pooling_layer: PoolingLayer,
        readout_layers: SequentialLayers,
        postprocess_layer: Optional[PostprocessLayer] = None,
    ):
        """
        Args:
            pooling_layer: The pooling layer that will concatenate the node features
                computed by a graph convolution into appropriate extended features (e.g.
                bond or angle features). The concatenated features will be provided as
                input to the dense readout layers.
            readout_layers: The dense NN readout layers to apply to the output of the
                pooling layers.
            postprocess_layer: A (optional) postprocessing layer to apply to the output
                of the readout layers
        """
        super().__init__()
        self.pooling_layer = get_pooling_layer(pooling_layer)
        self.readout_layers = readout_layers
        if postprocess_layer is not None and not isinstance(
            postprocess_layer, PostprocessLayer
        ):
            # Resolve e.g. a string/config identifier to a concrete layer object.
            postprocess_layer = _PostprocessLayerMeta._get_object(postprocess_layer)
        # Always assign (even when None) so ``forward`` can safely test it.
        self.postprocess_layer = postprocess_layer

    def forward(self, molecule: Union[DGLMolecule, DGLMoleculeBatch]) -> torch.Tensor:
        """Pool, read out and (optionally) postprocess the molecule's node features."""
        x = self._forward_unpostprocessed(molecule)
        if self.postprocess_layer is not None:
            x = self.postprocess_layer.forward(molecule, x)
        return x

    def _forward_unpostprocessed(
        self, molecule: Union[DGLMolecule, DGLMoleculeBatch]
    ) -> torch.Tensor:
        """
        Forward pass without postprocessing the readout modules.
        This is quality-of-life method for debugging and testing.
        It is *not* intended for public use.
        """
        x = self.pooling_layer.forward(molecule)
        x = self.readout_layers.forward(x)
        return x

    def copy(self, copy_weights: bool = False):
        """Return a structurally equivalent new module.

        Args:
            copy_weights: If ``True``, also copy this module's current
                parameters into the new instance via its state dict.
        """
        pooling = type(self.pooling_layer)()
        readout = self.readout_layers.copy(copy_weights=copy_weights)
        # When there is no postprocess layer, ``type(None)()`` yields None,
        # which the constructor accepts.
        postprocess = type(self.postprocess_layer)()
        copied = type(self)(pooling, readout, postprocess)
        if copy_weights:
            copied.load_state_dict(self.state_dict())
        return copied

    @classmethod
    def from_config(
        cls,
        readout_config,
        n_input_features: int,
    ):
        """Build a module from a readout config object.

        Args:
            readout_config: A config exposing ``pooling``, an optional
                ``postprocess``, and a ``layers`` sequence where each layer
                carries ``hidden_feature_size``, ``activation_function`` and
                ``dropout``.
            n_input_features: The number of input features to the readout layers.
        """
        pooling_layer = readout_config.pooling

        hidden_feature_sizes = [
            layer.hidden_feature_size for layer in readout_config.layers
        ]
        layer_activation_functions = [
            layer.activation_function for layer in readout_config.layers
        ]
        layer_dropout = [
            layer.dropout for layer in readout_config.layers
        ]

        # BUG FIX: previously ``postprocess_layer`` was only bound inside the
        # conditional below, raising NameError at the ``return`` when
        # ``readout_config.postprocess`` was None.
        postprocess_layer = None
        if readout_config.postprocess is not None:
            postprocess_layer = _PostprocessLayerMeta._get_object(
                readout_config.postprocess
            )
            # Append an identity-activated layer sized to the postprocess
            # layer's expected input features.
            hidden_feature_sizes.append(postprocess_layer.n_features)
            layer_activation_functions.append(ActivationFunction.Identity)
            layer_dropout.append(0.0)

        readout_layers = SequentialLayers.with_layers(
            n_input_features,
            hidden_feature_sizes,
            layer_activation_functions,
            layer_dropout,
        )
        return cls(
            pooling_layer,
            readout_layers,
            postprocess_layer,
        )