Source code for openff.nagl.nn.gcn._sage

import copy
from typing import Optional, Literal, Dict, TYPE_CHECKING, Union

import torch

from ._base import ActivationFunction, BaseGCNStack, BaseConvModule
from openff.nagl.nn.gcn import _function as _fn

# import dgl.function as fn
from openff.utilities import requires_package
from openff.utilities.exceptions import MissingOptionalDependencyError

if TYPE_CHECKING:
    import dgl


class SAGEConv(BaseConvModule):
    r"""GraphSAGE convolution layer that does not require DGL.

    The parameter layout mirrors DGL's ``SAGEConv`` (``fc_neigh``,
    ``fc_self``, ``fc_pool``, ``lstm`` and, for the ``"gcn"`` aggregator,
    ``bias``) so that a DGL layer's state dict can be loaded into this class.
    Message passing goes through :mod:`openff.nagl.nn.gcn._function`.

    Parameters
    ----------
    in_feats : int
        Size of the input node features. Used for both source and
        destination nodes.
    out_feats : int
        Size of the output node features.
    aggregator_type : {"mean", "gcn", "pool", "lstm"}
        How neighbor messages are combined. Any other value raises
        ``ValueError``.
    feat_drop : float
        Dropout probability applied to the input features.
    bias : bool, default True
        For non-``"gcn"`` aggregators, whether ``fc_self`` carries a bias.
        For ``"gcn"``, whether a separate learnable bias is added after
        aggregation.
    norm : callable, optional
        Applied to the layer output after ``activation``.
    activation : callable, optional
        Applied to the layer output before ``norm``.
    """

    def __init__(
        self,
        in_feats: int,
        out_feats: int,
        aggregator_type: Literal["mean", "gcn", "pool", "lstm"],
        feat_drop: float,
        bias: bool = True,
        norm: Optional[callable] = None,
        activation: Optional[callable] = None,
    ):
        super().__init__()
        if aggregator_type not in SAGEConvStack.available_aggregator_types:
            raise ValueError(
                f"Aggregator type {aggregator_type} not supported by {SAGEConvStack.name}."
            )
        self._in_src_feats, self._in_dst_feats = in_feats, in_feats
        self._out_feats = out_feats
        self._aggre_type = aggregator_type
        self.norm = norm
        self.feat_drop = torch.nn.Dropout(feat_drop)
        self.activation = activation

        # aggregator type: mean/pool/lstm/gcn
        if aggregator_type == "pool":
            self.fc_pool = torch.nn.Linear(self._in_src_feats, self._in_src_feats)
        if aggregator_type == "lstm":
            self.lstm = torch.nn.LSTM(
                self._in_src_feats, self._in_src_feats, batch_first=True
            )

        self.fc_neigh = torch.nn.Linear(self._in_src_feats, out_feats, bias=False)

        # The attribute names below (fc_self / bias) must stay aligned with
        # DGL's SAGEConv: SAGEConvStack._as_nagl loads DGL state dicts into
        # this class. "gcn" has no self-connection layer, so its bias is a
        # standalone parameter (or a None buffer when bias=False).
        if aggregator_type != "gcn":
            self.fc_self = torch.nn.Linear(self._in_dst_feats, out_feats, bias=bias)
        elif bias:
            self.bias = torch.nn.parameter.Parameter(torch.zeros(self._out_feats))
        else:
            self.register_buffer("bias", None)

        self.reset_parameters()

    def reset_parameters(self):
        r"""

        Description
        -----------
        Reinitialize learnable parameters.

        Note
        ----
        The linear weights :math:`W^{(l)}` are initialized using Glorot uniform initialization.
        The LSTM module is using xavier initialization method for its weights.
        """
        gain = torch.nn.init.calculate_gain("relu")
        if self._aggre_type == "pool":
            torch.nn.init.xavier_uniform_(self.fc_pool.weight, gain=gain)
        if self._aggre_type == "lstm":
            self.lstm.reset_parameters()
        if self._aggre_type != "gcn":
            torch.nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)
        torch.nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)

    def _lstm_reducer(self, nodes) -> Dict[str, torch.Tensor]:
        """LSTM reducer
        NOTE(zihao): lstm reducer with default schedule (degree bucketing)
        is slow, we could accelerate this with degree padding in the future.
        """
        m = nodes.mailbox["m"]  # (B, L, D)
        batch_size = m.shape[0]
        # Zero initial hidden and cell states: (num_layers=1, batch, hidden).
        h = (
            m.new_zeros((1, batch_size, self._in_src_feats)),
            m.new_zeros((1, batch_size, self._in_src_feats)),
        )
        _, (rst, _) = self.lstm(m, h)
        return {"neigh": rst.squeeze(0)}

    def forward(self, graph, feat, edge_weight=None):
        r"""

        Description
        -----------
        Compute GraphSAGE layer.

        Parameters
        ----------
        graph : DGLGraph
            The graph.
        feat : torch.Tensor or pair of torch.Tensor
            If a torch.Tensor is given, it represents the input feature of shape
            :math:`(N, D_{in})`
            where :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.
            If a pair of torch.Tensor is given, the pair must contain two tensors of shape
            :math:`(N_{in}, D_{in_{src}})` and :math:`(N_{out}, D_{in_{dst}})`.
        edge_weight : torch.Tensor, optional
            Optional tensor on the edge. If given, the convolution will weight
            with regard to the message.

        Returns
        -------
        torch.Tensor
            The output feature of shape :math:`(N_{dst}, D_{out})`
            where :math:`N_{dst}` is the number of destination nodes in the input graph,
            :math:`D_{out}` is the size of the output feature.

        Raises
        ------
        ValueError
            If ``edge_weight`` does not have one entry per edge, or if the
            ``"gcn"`` aggregator receives a feature pair whose per-node
            feature shapes differ.
        """
        with graph.local_scope():
            if isinstance(feat, tuple):
                feat_src = self.feat_drop(feat[0])
                feat_dst = self.feat_drop(feat[1])
            else:
                feat_src = feat_dst = self.feat_drop(feat)
                if graph.is_block:
                    # Destination features are the leading rows of the source features.
                    feat_dst = feat_src[: graph.number_of_dst_nodes()]

            msg_fn = _fn.copy_u("h", "m")
            if edge_weight is not None:
                n_edges = graph.number_of_edges()
                if edge_weight.shape[0] != n_edges:
                    raise ValueError(
                        f"edge_weight has {edge_weight.shape[0]} entries, "
                        f"but the graph has {n_edges} edges."
                    )
                graph.edata["_edge_weight"] = edge_weight
                msg_fn = _fn.u_mul_e("h", "_edge_weight", "m")

            h_self = feat_dst

            # Apply fc_neigh before message passing (A(XW)) when that shrinks
            # the features being passed around; otherwise after it ((AX)W).
            lin_before_mp = self._in_src_feats > self._out_feats

            # Handle the case of graphs without edges: seed "neigh" with zeros
            # in case message passing does not populate it. The width must
            # match what the aggregator would produce -- for "mean"/"gcn" with
            # lin_before_mp the messages are already projected to out_feats,
            # so an in_feats-wide seed would cause a shape mismatch below.
            if graph.number_of_edges() == 0:
                projected = lin_before_mp and self._aggre_type in ("mean", "gcn")
                neigh_width = self._out_feats if projected else self._in_src_feats
                graph.dstdata["neigh"] = feat_dst.new_zeros(
                    (feat_dst.shape[0], neigh_width)
                )

            # Message Passing
            if self._aggre_type == "mean":
                graph.srcdata["h"] = (
                    self.fc_neigh(feat_src) if lin_before_mp else feat_src
                )
                graph.update_all(msg_fn, _fn.mean("m", "neigh"))
                h_neigh = graph.dstdata["neigh"]
                if not lin_before_mp:
                    h_neigh = self.fc_neigh(h_neigh)

            elif self._aggre_type == "gcn":
                # Only the per-node feature shapes must agree; source and
                # destination node counts may differ for bipartite inputs.
                if isinstance(feat, tuple) and feat[0].shape[1:] != feat[1].shape[1:]:
                    raise ValueError(
                        "The 'gcn' aggregator requires source and destination "
                        "features of the same size, got "
                        f"{tuple(feat[0].shape[1:])} and {tuple(feat[1].shape[1:])}."
                    )
                graph.srcdata["h"] = (
                    self.fc_neigh(feat_src) if lin_before_mp else feat_src
                )
                if isinstance(feat, tuple):  # heterogeneous
                    graph.dstdata["h"] = (
                        self.fc_neigh(feat_dst) if lin_before_mp else feat_dst
                    )
                else:
                    if graph.is_block:
                        graph.dstdata["h"] = graph.srcdata["h"][: graph.num_dst_nodes()]
                    else:
                        graph.dstdata["h"] = graph.srcdata["h"]
                graph.update_all(msg_fn, _fn.sum("m", "neigh"))
                # Average over the neighbors plus the node itself.
                degs = graph.in_degrees().to(feat_dst)
                h_neigh = (graph.dstdata["neigh"] + graph.dstdata["h"]) / (
                    degs.unsqueeze(-1) + 1
                )
                if not lin_before_mp:
                    h_neigh = self.fc_neigh(h_neigh)

            elif self._aggre_type == "pool":
                graph.srcdata["h"] = torch.relu(self.fc_pool(feat_src))
                graph.update_all(msg_fn, _fn.max("m", "neigh"))
                h_neigh = self.fc_neigh(graph.dstdata["neigh"])

            elif self._aggre_type == "lstm":
                graph.srcdata["h"] = feat_src
                graph.update_all(msg_fn, self._lstm_reducer)
                h_neigh = self.fc_neigh(graph.dstdata["neigh"])

            else:
                # Unreachable after __init__ validation; kept as a safeguard.
                raise KeyError(
                    "Aggregator type {} not recognized.".format(self._aggre_type)
                )

            # GraphSAGE GCN does not require fc_self.
            if self._aggre_type == "gcn":
                rst = h_neigh
                # add bias manually for GCN
                if self.bias is not None:
                    rst = rst + self.bias
            else:
                rst = self.fc_self(h_self) + h_neigh

            # activation
            if self.activation is not None:
                rst = self.activation(rst)
            # normalization
            if self.norm is not None:
                rst = self.norm(rst)

            return rst


class SAGEConvStack(BaseGCNStack[Union[SAGEConv, "dgl.nn.pytorch.SAGEConv"]]):
    """
    GraphSAGE graph convolutional neural network for atom embeddings.

    `GraphSAGE <https://snap.stanford.edu/graphsage/>`_ GCNs learn a function
    that iteratively improves a node embedding by mixing in aggregated feature
    vectors of progressively more distant neighborhoods. GraphSAGE is
    inductive, scales to large graphs, and makes good use of feature-rich node
    embeddings.

    Layers in this network use the DGL
    :py:class:`SAGEConv <dgl.nn.pytorch.conv.SAGEConv>` class when DGL is
    installed, and otherwise fall back to the pure-PyTorch ``SAGEConv``
    defined in this module.

    See Also
    --------
    dgl.nn.pytorch.conv.SAGEConv
    """

    name = "SAGEConv"
    available_aggregator_types = ["mean", "gcn", "pool", "lstm"]
    default_aggregator_type = "mean"
    default_dropout = 0.0
    default_activation_function = ActivationFunction.ReLU

    @classmethod
    def _create_gcn_layer(
        cls,
        n_input_features: int,
        n_output_features: int,
        aggregator_type: str,
        dropout: float,
        activation_function: ActivationFunction,
        **kwargs,
    ) -> Union[SAGEConv, "dgl.nn.pytorch.SAGEConv"]:
        """Create one GraphSAGE layer, preferring DGL's implementation.

        If DGL is unavailable (``_create_gcn_layer_dgl`` raises
        ``MissingOptionalDependencyError``), the NAGL ``SAGEConv`` is built
        instead with the same arguments.
        """
        layer_kwargs = dict(
            n_input_features=n_input_features,
            n_output_features=n_output_features,
            aggregator_type=aggregator_type,
            dropout=dropout,
            activation_function=activation_function,
            **kwargs,
        )
        try:
            return cls._create_gcn_layer_dgl(**layer_kwargs)
        except MissingOptionalDependencyError:
            return cls._create_gcn_layer_nagl(**layer_kwargs)

    @classmethod
    def _create_gcn_layer_nagl(
        cls,
        n_input_features: int,
        n_output_features: int,
        aggregator_type: str,
        dropout: float,
        activation_function: ActivationFunction,
        **kwargs,
    ) -> "SAGEConv":
        """Create a NAGL ``SAGEConv`` layer. Extra ``kwargs`` are ignored."""
        return SAGEConv(
            in_feats=n_input_features,
            out_feats=n_output_features,
            activation=activation_function,
            feat_drop=dropout,
            aggregator_type=aggregator_type,
        )

    @classmethod
    @requires_package("dgl")
    def _create_gcn_layer_dgl(
        cls,
        n_input_features: int,
        n_output_features: int,
        aggregator_type: str,
        dropout: float,
        activation_function: ActivationFunction,
        **kwargs,
    ) -> "dgl.nn.pytorch.SAGEConv":
        """Create a DGL ``SAGEConv`` layer. Extra ``kwargs`` are ignored.

        ``requires_package("dgl")`` makes this raise
        ``MissingOptionalDependencyError`` when DGL is not installed.
        """
        import dgl

        return dgl.nn.pytorch.SAGEConv(
            in_feats=n_input_features,
            out_feats=n_output_features,
            activation=activation_function,
            feat_drop=dropout,
            aggregator_type=aggregator_type,
        )

    def _as_nagl(self, copy_weights: bool = False):
        """Return a deep copy of this stack built from NAGL ``SAGEConv`` layers.

        Parameters
        ----------
        copy_weights
            If True, each DGL layer's state dict is loaded into the matching
            NAGL layer; otherwise the new layers keep their fresh
            initialization.

        If this stack does not hold DGL layers, a plain deep copy is returned.
        """
        if not self._is_dgl:
            return copy.deepcopy(self)

        new_obj = type(self)()
        new_obj.hidden_feature_sizes = self.hidden_feature_sizes
        for layer in self:
            # Rebuild each layer from the hyperparameters stored on the DGL layer.
            # NOTE(review): ``layer.activation`` is the activation object held
            # by the layer, not necessarily an ``ActivationFunction`` member.
            new_layer = self._create_gcn_layer_nagl(
                n_input_features=layer._in_src_feats,
                n_output_features=layer._out_feats,
                aggregator_type=layer._aggre_type,
                dropout=layer.feat_drop.p,
                activation_function=layer.activation,
            )
            if copy_weights:
                new_layer.load_state_dict(layer.state_dict())
            new_obj.append(new_layer)
        # Deep copy so the result shares nothing (e.g. hidden_feature_sizes)
        # with ``self``.
        return copy.deepcopy(new_obj)