import abc
from typing import ClassVar, Dict, Generic, List, Optional, Type, TypeVar
import torch.nn
import torch.nn.functional
from openff.nagl._base.metaregistry import create_registry_metaclass
from openff.nagl.nn.activation import ActivationFunction
from openff.nagl.nn._base import ContainsLayersMixin
GCNLayerType = TypeVar("GCNLayerType", bound=torch.nn.Module)
# class GCNStackMeta(abc.ABCMeta):
# """A metaclass for GCN stacks.
# This metaclass is used to register GCN layers by name.
# """
# registry: ClassVar[Dict[str, Type]] = {}
# def __init__(cls, name, bases, namespace, **kwargs):
# super().__init__(name, bases, namespace, **kwargs)
# if hasattr(cls, "name") and cls.name:
# cls.registry[cls.name] = cls
# @classmethod
# def get_gcn_class(cls, class_name: str):
# if isinstance(class_name, cls):
# return class_name
# if isinstance(type(class_name), cls):
# return type(class_name)
# try:
# return cls.registry[class_name]
# except KeyError:
# raise ValueError(
# f"Unknown GCN layer type: {class_name}. "
# f"Supported types: {list(cls.registry.keys())}"
# )
class BaseConvModule(torch.nn.Module):
    """Marker base class for non-DGL convolution modules (see the ``_is_dgl`` property of ``BaseGCNStack``)."""
class _GCNStackMeta(abc.ABCMeta, create_registry_metaclass("name")):
    """Metaclass that registers ``BaseGCNStack`` subclasses by their ``name`` attribute."""
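# Subclasses that define a non-empty ``name`` are expected to be registered by
# ``_GCNStackMeta`` so that stack classes can be looked up by name, e.g.
# (assuming the registry metaclass exposes a ``registry`` mapping, as in the
# commented-out ``GCNStackMeta`` above):
#
#     # stack_cls = _GCNStackMeta.registry["sageconv"]  # hypothetical key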
class BaseGCNStack(
torch.nn.ModuleList,
Generic[GCNLayerType],
ContainsLayersMixin,
abc.ABC,
metaclass=_GCNStackMeta,
):
"""A wrapper around a stack of GCN graph convolutional layers.
Note:
This class is based on the ``dgllife.model.SAGEConv`` module.
"""
    # hidden_feature_sizes: List[int]
    @property
    @classmethod
    @abc.abstractmethod
    def name(cls) -> str:
        """The name under which this GCN stack class is registered."""

    @property
    @classmethod
    @abc.abstractmethod
    def available_aggregator_types(cls) -> List[str]:
        """The aggregator options supported by the GCN layers."""

    @property
    @classmethod
    @abc.abstractmethod
    def default_aggregator_type(cls) -> str:
        """The aggregator type to use for the GCN layers by default."""

    @property
    @classmethod
    @abc.abstractmethod
    def default_dropout(cls) -> float:
        """The dropout probability to use for the GCN layers by default."""

    @property
    @classmethod
    @abc.abstractmethod
    def default_activation_function(cls) -> ActivationFunction:
        """The activation function to use for the GCN layers by default."""
@classmethod
def _check_input_lengths(
cls,
n_layers: int,
layer_activation_functions: Optional[List[ActivationFunction]] = None,
layer_dropout: Optional[List[float]] = None,
layer_aggregator_types: Optional[List[str]] = None,
):
layer_activation_functions, layer_dropout = super()._check_input_lengths(
n_layers,
layer_activation_functions,
layer_dropout,
)
if layer_aggregator_types is None:
layer_aggregator_types = cls.default_aggregator_type
layer_aggregator_types = cls._check_argument_input_length(
n_layers,
layer_aggregator_types,
"layer_aggregator_types",
)
return layer_activation_functions, layer_dropout, layer_aggregator_types
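    # A sketch of the expected behaviour, assuming ``_check_argument_input_length``
    # (from ``ContainsLayersMixin``) broadcasts a single default value to one entry
    # per layer and validates explicit lists against ``n_layers``:
    #
    #     # cls._check_input_lengths(2, None, None, None) is expected to return
    #     # ([cls.default_activation_function] * 2,
    #     #  [cls.default_dropout] * 2,
    #     #  [cls.default_aggregator_type] * 2)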
    @classmethod
def with_layers(
cls,
n_input_features: int,
hidden_feature_sizes: List[int],
layer_activation_functions: Optional[List[ActivationFunction]] = None,
layer_dropout: Optional[List[float]] = None,
layer_aggregator_types: Optional[List[str]] = None,
):
"""Create this model with layers with the specified parameters."""
obj = cls()
n_layers = len(hidden_feature_sizes)
(
layer_activation_functions,
layer_dropout,
layer_aggregator_types,
) = cls._check_input_lengths(
n_layers,
layer_activation_functions,
layer_dropout,
layer_aggregator_types,
)
for i in range(n_layers):
n_output_features = hidden_feature_sizes[i]
activation_function = layer_activation_functions[i]
dropout = layer_dropout[i]
aggregator_type = layer_aggregator_types[i]
obj.append_gcn_layer(
n_input_features=n_input_features,
n_output_features=n_output_features,
activation_function=activation_function,
aggregator_type=aggregator_type,
dropout=dropout,
)
n_input_features = n_output_features
return obj
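    # Example (sketch) of building a two-layer stack with ``with_layers``.
    # ``SAGEConvStack`` is a hypothetical concrete subclass used purely for
    # illustration, the feature sizes are arbitrary, and ``ActivationFunction.ReLU``
    # is assumed to be an available member:
    #
    #     stack = SAGEConvStack.with_layers(
    #         n_input_features=39,
    #         hidden_feature_sizes=[128, 128],
    #         layer_activation_functions=[ActivationFunction.ReLU] * 2,
    #         layer_dropout=[0.0, 0.0],
    #         layer_aggregator_types=["mean", "mean"],
    #     )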
    def __init__(self, *args, _is_dgl: bool = False, **kwargs):
        # The ``_is_dgl`` keyword is accepted but not stored; whether the stack
        # is DGL-based is inferred from the layer types (see the ``_is_dgl``
        # property below).
        super().__init__(*args, **kwargs)
        self.hidden_feature_sizes: List[int] = []
    def append_gcn_layer(
self,
n_output_features: int,
n_input_features: Optional[int] = None,
aggregator_type: Optional[str] = None,
dropout: Optional[float] = None,
activation_function: Optional[ActivationFunction] = None,
):
"""Add a new layer to the stack."""
if n_input_features is None:
try:
n_input_features = self.hidden_feature_sizes[-1]
except IndexError:
raise ValueError(
"Must specify n_input_features if no layers have been created yet."
)
self.hidden_feature_sizes.append(n_output_features)
self.append(
self.create_gcn_layer(
n_input_features,
n_output_features,
aggregator_type,
dropout,
activation_function,
)
)
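    # Example (sketch): growing the stack one layer at a time. After the first
    # layer, ``n_input_features`` may be omitted and defaults to the previous
    # layer's output size:
    #
    #     stack.append_gcn_layer(n_output_features=128, n_input_features=39)
    #     stack.append_gcn_layer(n_output_features=64)  # input size inferred as 128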
    @classmethod
def create_gcn_layer(
cls,
n_input_features: int,
n_output_features: int,
aggregator_type: Optional[str] = None,
dropout: Optional[float] = None,
activation_function: Optional[ActivationFunction] = None,
**kwargs,
) -> GCNLayerType:
"""Create a new GCN layer."""
if aggregator_type is None:
aggregator_type = cls.default_aggregator_type
if dropout is None:
dropout = cls.default_dropout
if activation_function is None:
activation_function = cls.default_activation_function
activation = ActivationFunction.get_value(activation_function)
# activation = ActivationFunction.get_function(activation_function)
return cls._create_gcn_layer(
n_input_features=n_input_features,
n_output_features=n_output_features,
aggregator_type=aggregator_type,
dropout=dropout,
activation_function=activation,
**kwargs,
)
@classmethod
@abc.abstractmethod
def _create_gcn_layer(
cls,
n_input_features: int,
n_output_features: int,
aggregator_type: str,
dropout: float,
activation_function: ActivationFunction,
**kwargs,
) -> GCNLayerType:
"""A function which returns an instantiated GCN layer.
Args:
in_feats: Number of input node features.
out_feats: Number of output node features.
activation_function: The activation_function function to.
dropout: `The dropout probability.
aggregator_type: The aggregator type, which can be one of ``"sum"``,
``"max"``, ``"mean"``.
init_eps: The initial value of epsilon.
learn_eps: If True epsilon will be a learnable parameter.
Returns:
The instantiated GCN layer.
"""
    def reset_parameters(self):
"""Reinitialize model parameters."""
for gnn in self:
gnn.reset_parameters()
    def forward(self, graph, inputs: torch.Tensor) -> torch.Tensor:
        """Update the node representations.

        Args:
            graph: The batch of graphs to operate on.
            inputs: The inputs to the layers with shape=(n_nodes, in_feats).

        Returns:
            The output hidden features with shape=(n_nodes, hidden_feats[-1]).
        """
for gnn in self:
inputs: torch.Tensor = gnn(graph, inputs)
return inputs
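    # Example (sketch) of a forward pass, assuming ``stack`` was built as above
    # and ``graph``/``node_features`` are a compatible batched graph and its
    # node feature tensor:
    #
    #     hidden = stack(graph, node_features)
    #     # hidden.shape == (n_nodes, stack.hidden_feature_sizes[-1])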
    @property
    def _is_dgl(self) -> bool:
        """Whether the stacked layers are DGL modules rather than ``BaseConvModule`` subclasses."""
        return not isinstance(self[0], BaseConvModule)