Source code for openff.nagl.training.metrics

import abc
import typing

import torch

from openff.nagl._base.metaregistry import create_registry_metaclass
from openff.nagl._base.base import ImmutableModel

try:
    from pydantic.v1.main import ModelMetaclass
except ImportError:
    from pydantic.main import ModelMetaclass

if typing.TYPE_CHECKING:
    import torch
    from openff.nagl.molecule._dgl.batch import DGLMoleculeBatch
    from openff.nagl.molecule._dgl.molecule import DGLMolecule


# class MetricMeta(ModelMetaclass, abc.ABCMeta, create_registry_metaclass("name")):
#     pass

class BaseMetric(ImmutableModel, abc.ABC):
    """Abstract base class for metrics comparing predicted and expected values.

    Subclasses implement :meth:`compute`. Instances are callable: calling a
    metric forwards both tensors to :meth:`compute` and returns its result.
    """

    # Identifier of the metric. Each concrete subclass in this module pins
    # this field to its own literal string (for example ``"rmse"``).
    name: typing.Literal[""]

    @abc.abstractmethod
    def compute(
        self,
        predicted_values: "torch.Tensor",
        expected_values: "torch.Tensor",
    ) -> "torch.Tensor":
        """Compute the metric for the given values.

        Parameters
        ----------
        predicted_values
            The values produced by the model.
        expected_values
            The reference values to compare against.

        Returns
        -------
        torch.Tensor
            The value of the metric.

        Raises
        ------
        NotImplementedError
            Always, in this base class; subclasses must override this method.
        """
        raise NotImplementedError

    def __call__(
        self,
        predicted_values: "torch.Tensor",
        expected_values: "torch.Tensor",
    ) -> "torch.Tensor":
        """Evaluate the metric; shorthand for :meth:`compute`."""
        return self.compute(predicted_values, expected_values)
class RMSEMetric(BaseMetric):
    """Root-mean-square error between predicted and expected values."""

    name: typing.Literal["rmse"] = "rmse"

    def compute(self, predicted_values, expected_values):
        """Return the square root of the mean squared error.

        Parameters
        ----------
        predicted_values
            The values produced by the model.
        expected_values
            The reference values to compare against.

        Returns
        -------
        torch.Tensor
            ``sqrt(mean((predicted_values - expected_values) ** 2))``.
        """
        # Functional form of ``torch.nn.MSELoss()`` with its default
        # ``reduction="mean"``.
        mean_squared_error = torch.nn.functional.mse_loss(
            predicted_values, expected_values
        )
        return torch.sqrt(mean_squared_error)
class MSEMetric(BaseMetric):
    """Mean squared error between predicted and expected values."""

    name: typing.Literal["mse"] = "mse"

    def compute(self, predicted_values, expected_values):
        """Return the mean squared error.

        Parameters
        ----------
        predicted_values
            The values produced by the model.
        expected_values
            The reference values to compare against.

        Returns
        -------
        torch.Tensor
            ``mean((predicted_values - expected_values) ** 2)``.
        """
        # Functional form of ``torch.nn.MSELoss()`` with its default
        # ``reduction="mean"``.
        return torch.nn.functional.mse_loss(predicted_values, expected_values)
class MAEMetric(BaseMetric):
    """Mean absolute error between predicted and expected values."""

    name: typing.Literal["mae"] = "mae"

    def compute(self, predicted_values, expected_values):
        """Return the mean absolute error.

        Parameters
        ----------
        predicted_values
            The values produced by the model.
        expected_values
            The reference values to compare against.

        Returns
        -------
        torch.Tensor
            ``mean(abs(predicted_values - expected_values))``.
        """
        # Functional form of ``torch.nn.L1Loss()`` with its default
        # ``reduction="mean"``.
        return torch.nn.functional.l1_loss(predicted_values, expected_values)
# Alternative formulation of ``MAEMetric.compute`` kept for reference:
# return torch.mean(torch.abs(predicted_values - expected_values))


# Union of every concrete metric class defined in this module.
MetricType = typing.Union[RMSEMetric, MSEMetric, MAEMetric]

# Keys are the lowercase ``name`` literals declared on each metric class.
METRICS = {
    "rmse": RMSEMetric,
    "mse": MSEMetric,
    "mae": MAEMetric,
}
"""
Mapping from metric names to the corresponding classes.
"""
def get_metric_type(metric):
    """Resolve a metric specification to a metric instance.

    Despite the name, this returns an *instance* of a metric, not a class.

    Parameters
    ----------
    metric
        Either an existing :class:`BaseMetric` instance, which is returned
        unchanged, or the name of a metric registered in ``METRICS``.
        Names are matched case-insensitively.

    Returns
    -------
    BaseMetric
        The given instance, or a new default-constructed instance of the
        metric class registered under the given name.

    Raises
    ------
    KeyError
        If ``metric`` is a string that does not name a registered metric.
    TypeError
        If ``metric`` is neither a :class:`BaseMetric` nor a string.
    """
    if isinstance(metric, BaseMetric):
        return metric

    if isinstance(metric, str):
        key = metric.lower()
        try:
            metric_class = METRICS[key]
        except KeyError:
            # Still a KeyError, so existing handlers keep working, but the
            # message now names the valid choices.
            available = ", ".join(sorted(METRICS))
            raise KeyError(
                f"Unknown metric {metric!r}. Available metrics: {available}"
            ) from None
        return metric_class()

    # Previously this fell through and implicitly returned None, which
    # deferred the failure to wherever the result was first used.
    raise TypeError(
        "metric must be a BaseMetric instance or the name of a metric, "
        f"got {type(metric).__name__}"
    )