# Source code for openff.nagl.nn.activation
"Activation functions"
import enum
from typing import Callable
import torch
# Public names exported by ``from <this module> import *``.
__all__ = ["ActivationFunction"]
class ActivationFunction(enum.Enum):
    """Activation function options.

    Each member's value is the corresponding ``torch.nn`` module class
    (e.g. ``ActivationFunction.ReLU.value is torch.nn.ReLU``). Use
    :meth:`get` to resolve a member from a member, a member name (exact or
    case-insensitive), or a module class.
    """

    Identity = torch.nn.Identity
    Tanh = torch.nn.Tanh
    ReLU = torch.nn.ReLU
    LeakyReLU = torch.nn.LeakyReLU
    ELU = torch.nn.ELU
    Sigmoid = torch.nn.Sigmoid

    @classmethod
    def _lowercase(cls):
        """Return a mapping of lower-cased member name -> member."""
        return {name.lower(): member for name, member in cls.__members__.items()}

    @classmethod
    def get(cls, name: str) -> "ActivationFunction":
        """Resolve ``name`` to an :class:`ActivationFunction` member.

        Parameters
        ----------
        name
            An ``ActivationFunction`` member (returned unchanged), a member
            name (matched exactly first, then case-insensitively), or a
            member value such as ``torch.nn.ReLU``.

        Returns
        -------
        ActivationFunction
            The matching member.

        Raises
        ------
        KeyError
            If ``name`` is a string that matches no member name.
        ValueError
            If ``name`` is neither a string nor a member/member value
            (raised by the enum value lookup ``cls(name)``).
        """
        if isinstance(name, cls):
            return name
        if isinstance(name, str):
            member = cls.__members__.get(name)
            if member is None:
                member = cls._lowercase().get(name.lower())
            if member is None:
                # Keep raising KeyError (callers may catch it), but say which
                # names are accepted instead of echoing the lower-cased input.
                options = ", ".join(cls.__members__)
                raise KeyError(
                    f"Unknown activation function {name!r}; "
                    f"valid options are: {options}"
                )
            return member
        return cls(name)

    @classmethod
    def get_value(cls, name: str) -> Callable[[torch.Tensor], torch.Tensor]:
        """Return an instantiated activation module for ``name``.

        ``name`` is resolved with :meth:`get`, and the member's module class
        is instantiated with no arguments. If ``name`` is not a member, a
        member name or a member value (for example an already-constructed
        module instance), it is returned unchanged.

        Raises
        ------
        KeyError
            If ``name`` is a string that matches no member name.
        """
        try:
            member = cls.get(name)
        except ValueError:
            # Not a known member value: pass it through as-is. Only the lookup
            # is guarded, so errors raised while instantiating the module
            # below are not silently swallowed.
            return name
        return member.value()

    @classmethod
    def get_function(cls, name: str) -> Callable[[torch.Tensor], torch.Tensor]:
        """Return the functional (``torch.nn.functional``) form of ``name``.

        ``name`` is resolved with :meth:`get`. The returned callable is the
        functional counterpart of the member's module. ``Identity`` maps to
        a function that returns its input unchanged.

        Raises
        ------
        KeyError
            If ``name`` is a string that matches no member name.
        ValueError
            If ``name`` is not a string and not a member/member value.
        """
        import torch.nn.functional as F

        member = cls.get(name)
        functions = {
            "Identity": lambda tensor: tensor,
            "Tanh": F.tanh,
            "ReLU": F.relu,
            "LeakyReLU": F.leaky_relu,
            "ELU": F.elu,
            "Sigmoid": F.sigmoid,
        }
        return functions[member.name]

    # Private aliases for ``get_value`` and ``get``.
    _get_object = get_value
    _get_class = get