Source code for openff.nagl.nn._pooling
import abc
import functools
from typing import ClassVar, Dict, Union, TYPE_CHECKING, Iterable
import torch.nn
from openff.nagl.molecule._dgl import DGLMolecule, DGLMoleculeBatch, DGLMoleculeOrBatch
from openff.nagl.nn._sequential import SequentialLayers
if TYPE_CHECKING:
import dgl
class PoolingLayer(torch.nn.Module, abc.ABC):
"""A convenience class for pooling together node feature vectors produced by
a graph convolutional layer.
"""
n_feature_columns: ClassVar[int] = 0
@abc.abstractmethod
def forward(self, molecule: DGLMoleculeOrBatch) -> torch.Tensor:
"""Returns the pooled feature vector."""
@abc.abstractmethod
def get_nvalues_per_molecule(self, molecule: DGLMoleculeOrBatch) -> Iterable[int]:
"""Returns the number of values per molecule."""
[docs]class PoolAtomFeatures(PoolingLayer):
"""A convenience class for pooling the node feature vectors produced by
a graph convolutional layer.
This class simply returns the features "h" from the graphs node data.
"""
n_feature_columns: ClassVar[int] = 1
[docs] def forward(self, molecule: DGLMoleculeOrBatch) -> torch.Tensor:
return molecule.graph.ndata[molecule._graph_feature_name]
[docs] def get_nvalues_per_molecule(self, molecule: DGLMoleculeOrBatch) -> Iterable[int]:
return molecule.n_atoms_per_molecule
[docs]class PoolBondFeatures(PoolingLayer):
"""A convenience class for pooling the node feature vectors produced by
a graph convolutional layer into a set of symmetric bond (edge) features.
"""
n_feature_columns: ClassVar[int] = 2
def __init__(self, layers: SequentialLayers):
super().__init__()
self.layers = layers
@staticmethod
def _apply_edges(
edges: "dgl.udf.EdgeBatch", feature_name: str = "h"
) -> Dict[str, torch.Tensor]:
h_u = edges.src[feature_name]
h_v = edges.dst[feature_name]
return {feature_name: torch.cat([h_u, h_v], 1)}
# def _directionwise_forward(
# self,
# molecule: DGLMoleculeOrBatch,
# edge_type: str = "forward",
# ):
# graph = molecule.graph
# apply_edges = functools.partial(
# self._apply_edges,
# feature_name=molecule._graph_feature_name,
# )
# with graph.local_scope():
# graph.apply_edges(apply_edges, etype=edge_type)
# edges = graph.edges[edge_type].data[molecule._graph_feature_name]
# return self.layers(edges)
[docs] def forward(self, molecule: DGLMoleculeOrBatch) -> torch.Tensor:
graph = molecule.graph
node = molecule._graph_feature_name
apply_edges = functools.partial(
self._apply_edges,
feature_name=node,
)
with graph.local_scope():
graph.apply_edges(apply_edges, etype=molecule._graph_forward_edge_type)
h_forward = graph.edges[molecule._graph_forward_edge_type].data[node]
with graph.local_scope():
graph.apply_edges(apply_edges, etype=molecule._graph_backward_edge_type)
h_reverse = graph.edges[molecule._graph_backward_edge_type].data[node]
# h_forward = self._directionwise_forward(
# molecule,
# molecule._graph_forward_edge_type,
# )
# h_reverse = self._directionwise_forward(
# molecule,
# molecule._graph_backward_edge_type,
# )
return self.layers(h_forward) + self.layers(h_reverse)
[docs] def get_nvalues_per_molecule(self, molecule: DGLMoleculeOrBatch) -> Iterable[int]:
return molecule.n_bonds_per_molecule
def get_pooling_layer(layer: Union[str, PoolingLayer]) -> PoolingLayer:
if isinstance(layer, PoolingLayer):
return layer
if isinstance(layer, str):
if layer.lower() in {"atom", "atoms"}:
return PoolAtomFeatures()
if layer.lower() in {"bond", "bonds"}:
return PoolBondFeatures()
raise NotImplementedError(f"Unsupported pooling layer '{layer}'.")