# Source code for openff.nagl.training.reporting

import base64
import pathlib
import typing

import jinja2
import torch
import numpy as np

from openff.utilities import requires_package
import openff.nagl.training

if typing.TYPE_CHECKING:
    from openff.nagl.molecule._dgl import DGLMolecule
    from openff.nagl.training.metrics import MetricType
    from openff.toolkit import Molecule


def _encode_image(image):
    """
    Wrap SVG text in a base64 ``data:`` URI usable as an ``<img src=...>``.

    Parameters
    ----------
    image : str
        The SVG document, as text.

    Returns
    -------
    str
        A ``data:image/svg+xml;base64,...`` URI embedding ``image``.
    """
    payload = base64.b64encode(image.encode())
    return "data:image/svg+xml;base64," + payload.decode()

@requires_package("rdkit")
def _draw_molecule_with_atom_labels(
    molecule: typing.Union["Molecule", "DGLMolecule"],
    predicted_labels: torch.Tensor,
    reference_labels: torch.Tensor,
    highlight_outliers: bool = False,
    outlier_threshold: float = 1.0,
) -> str:
    """
    Draw a molecule with predicted and reference atom labels.

    The molecule is drawn twice in a two-column grid: the first panel
    (legend "prediction") annotates each atom with its predicted label, the
    second (legend "reference") with its reference label. Labels are shown
    to three decimal places.

    Parameters
    ----------
    molecule : Molecule or DGLMolecule
        The molecule to draw. A ``DGLMolecule`` is first converted to an
        OpenFF ``Molecule``.
    predicted_labels : torch.Tensor
        The predicted atom labels. The tensor is flattened, so it is paired
        with atoms in flattened order.
    reference_labels : torch.Tensor
        The reference atom labels, flattened in the same way.
    highlight_outliers : bool, optional
        Whether to highlight atoms with predicted labels that are more than
        ``outlier_threshold`` away from the reference labels.
    outlier_threshold : float, optional
        The threshold for highlighting outliers.

    Returns
    -------
    str
        The SVG image of the molecule, as text
    """
    from openff.nagl.molecule._dgl import DGLMolecule
    from rdkit.Chem import Draw

    if isinstance(molecule, DGLMolecule):
        molecule = molecule.to_openff()

    # ``Tensor.numpy()`` only works on CPU tensors; ``.cpu()`` is a no-op for
    # tensors already on the CPU and lets GPU-resident labels be drawn too.
    predicted_labels = predicted_labels.detach().cpu().numpy().flatten()
    reference_labels = reference_labels.detach().cpu().numpy().flatten()

    highlight_atoms = None
    if highlight_outliers:
        diff = np.abs(predicted_labels - reference_labels)
        # RDKit's Boost.Python bindings expect built-in ``int`` atom indices;
        # NumPy integer scalars are not converted automatically.
        highlight_atoms = [
            int(index) for index in np.where(diff > outlier_threshold)[0]
        ]

    def _annotated_rdkit_molecule(labels):
        # A fresh RDKit copy per panel so each carries its own atom notes.
        rdkit_molecule = molecule.to_rdkit()
        for atom, label in zip(rdkit_molecule.GetAtoms(), labels):
            atom.SetProp("atomNote", f"{float(label):.3f}")
        return rdkit_molecule

    predicted_molecule = _annotated_rdkit_molecule(predicted_labels)
    reference_molecule = _annotated_rdkit_molecule(reference_labels)

    # NOTE: the previous explicit ``Draw.PrepareMolForDrawing`` calls were
    # removed: that function returns a prepared *copy*, and the results were
    # discarded, so the calls had no effect on the drawing.

    draw_options = Draw.MolDrawOptions()
    draw_options.legendFontSize = 25

    image = Draw.MolsToGridImage(
        [predicted_molecule, reference_molecule],
        legends=["prediction", "reference"],
        molsPerRow=2,
        subImgSize=(400, 400),
        useSVG=True,
        drawOptions=draw_options,
        highlightAtomLists=[highlight_atoms, highlight_atoms],
    )
    return image


@requires_package("rdkit")
def _draw_molecule(
    molecule: typing.Union["Molecule", "DGLMolecule"],
) -> str:
    """
    Render a single molecule, with no atom annotations, as SVG text.

    Parameters
    ----------
    molecule : typing.Union[Molecule, "DGLMolecule"]
        The molecule to draw. A ``DGLMolecule`` is first converted to an
        OpenFF ``Molecule``.

    Returns
    -------
    str
        The SVG image of the molecule (400 x 400), as text
    """
    from rdkit.Chem import Draw

    from openff.nagl.molecule._dgl import DGLMolecule

    openff_molecule = (
        molecule.to_openff() if isinstance(molecule, DGLMolecule) else molecule
    )
    rdkit_molecule = openff_molecule.to_rdkit()

    canvas = Draw.rdMolDraw2D.MolDraw2DSVG(400, 400)
    Draw.PrepareAndDrawMolecule(canvas, rdkit_molecule)
    canvas.FinishDrawing()
    svg_text = canvas.GetDrawingText()
    return svg_text


def _generate_jinja_dicts_per_atom(
    molecules: typing.List["Molecule"],
    predicted_labels: typing.List[torch.Tensor],
    reference_labels: typing.List[torch.Tensor],
    metrics: typing.List["MetricType"],
    highlight_outliers: bool = False,
    outlier_threshold: float = 1.0,
) -> typing.List[typing.Dict[str, str]]:
    """
    Build one template-context dict per molecule for an atom-label report.

    Each dict has an ``"img"`` key (a base64 SVG data URI of the molecule
    drawn with predicted and reference atom labels) and a ``"metrics"`` key
    mapping each upper-cased metric name to its value formatted to four
    decimal places.

    Raises
    ------
    ValueError
        If ``molecules``, ``predicted_labels`` and ``reference_labels`` do
        not all have the same length.
    """
    from openff.nagl.training.metrics import get_metric_type

    resolved_metrics = [get_metric_type(metric) for metric in metrics]

    expected_length = len(molecules)
    if len(predicted_labels) != expected_length:
        raise ValueError(
            "The number of molecules and predicted labels must match."
        )
    if len(reference_labels) != expected_length:
        raise ValueError(
            "The number of molecules and reference labels must match."
        )

    def _context_for(molecule, predicted, reference):
        # Metrics are evaluated before the molecule is drawn.
        scores = {}
        for metric in resolved_metrics:
            value = metric.compute(predicted, reference)
            scores[metric.name.upper()] = f"{value:.4f}"
        svg = _draw_molecule_with_atom_labels(
            molecule,
            predicted,
            reference,
            highlight_outliers=highlight_outliers,
            outlier_threshold=outlier_threshold,
        )
        return {"img": _encode_image(svg), "metrics": scores}

    return [
        _context_for(molecule, predicted, reference)
        for molecule, predicted, reference in zip(
            molecules, predicted_labels, reference_labels
        )
    ]


def _generate_jinja_dicts_per_molecule(
    molecules: typing.List["Molecule"],
    metrics: typing.List[torch.Tensor],
    metric_name: str
) -> typing.List[typing.Dict[str, str]]:
    """
    Build one template-context dict per molecule for a molecule-level report.

    Parameters
    ----------
    molecules : list of Molecule
        The molecules to draw (without atom labels).
    metrics : list of torch.Tensor
        One scalar metric value per molecule; each must be convertible with
        ``float()``.
    metric_name : str
        Name of the metric; used upper-cased as the key in each entry's
        ``"metrics"`` dict.

    Returns
    -------
    list of dict
        Each dict has ``"img"`` (a base64 SVG data URI) and ``"metrics"``
        (``{METRIC_NAME: value formatted to four decimal places}``).

    Raises
    ------
    ValueError
        If ``molecules`` and ``metrics`` have different lengths.
    """
    # Explicit check rather than ``assert``, which is stripped under ``-O``.
    if len(metrics) != len(molecules):
        raise ValueError(
            "The number of molecules and metrics must match."
        )

    metric_key = metric_name.upper()
    jinja_dicts = []
    for molecule, metric in zip(molecules, metrics):
        image = _draw_molecule(molecule)
        jinja_dicts.append(
            {
                "img": _encode_image(image),
                "metrics": {metric_key: f"{float(metric):.4f}"},
            }
        )
    return jinja_dicts
    

def _write_jinja_report(
    output_path: pathlib.Path,
    top_n_structures: typing.List[typing.Dict[str, str]],
    bottom_n_structures: typing.List[typing.Dict[str, str]],
):
    """
    Render the ``jinja_report.html`` template and write it to ``output_path``.

    Parameters
    ----------
    output_path : pathlib.Path or str
        Where to write the rendered HTML; converted with ``pathlib.Path``.
    top_n_structures : list of dict
        Template context entries passed as ``top_n_structures``.
    bottom_n_structures : list of dict
        Template context entries passed as ``bottom_n_structures``.
    """
    output_path = pathlib.Path(output_path)
    env = jinja2.Environment(
        loader=jinja2.PackageLoader("openff.nagl.training"),
    )
    template = env.get_template("jinja_report.html")
    rendered = template.render(
        top_n_structures=top_n_structures,
        bottom_n_structures=bottom_n_structures,
    )
    # PackageLoader decodes templates as UTF-8 by default; write with the
    # same encoding instead of the platform-dependent default.
    output_path.write_text(rendered, encoding="utf-8")



def create_atom_label_report(
    molecules: typing.List["Molecule"],
    predicted_labels: typing.List[torch.Tensor],
    reference_labels: typing.List[torch.Tensor],
    metrics: typing.List["MetricType"],
    rank_by: "MetricType",
    output_path: pathlib.Path,
    top_n_entries: int = 100,
    bottom_n_entries: int = 100,
    highlight_outliers: bool = False,
    outlier_threshold: float = 1.0,
):
    """
    Write an HTML report comparing predicted and reference atom labels.

    Each molecule is scored with ``rank_by`` and the molecules are sorted in
    descending order of that score. The first ``top_n_entries`` (highest
    scores) and the last ``bottom_n_entries`` (lowest scores) are drawn with
    their atom labels and every metric in ``metrics``.

    Parameters
    ----------
    molecules : list of Molecule
        The molecules to report on.
    predicted_labels : list of torch.Tensor
        Predicted atom labels, one tensor per molecule.
    reference_labels : list of torch.Tensor
        Reference atom labels, one tensor per molecule.
    metrics : list of MetricType
        Metrics shown for each reported molecule.
    rank_by : MetricType
        Metric used to rank the molecules.
    output_path : pathlib.Path
        Where to write the HTML report.
    top_n_entries : int, optional
        Number of highest-ranked molecules to include.
    bottom_n_entries : int, optional
        Number of lowest-ranked molecules to include; ``0`` includes none.
    highlight_outliers : bool, optional
        Whether to highlight atoms whose prediction differs from the
        reference by more than ``outlier_threshold``.
    outlier_threshold : float, optional
        The threshold for highlighting outliers.

    Raises
    ------
    ValueError
        If the three input lists do not have the same length.
    """
    from openff.nagl.training.metrics import get_metric_type

    ranker = get_metric_type(rank_by)
    metrics = [get_metric_type(metric) for metric in metrics]

    n_molecules = len(molecules)
    if n_molecules != len(predicted_labels):
        raise ValueError(
            "The number of molecules and predicted labels must match."
        )
    if n_molecules != len(reference_labels):
        raise ValueError(
            "The number of molecules and reference labels must match."
        )

    ranked = sorted(
        (
            (molecule, predicted, reference, ranker.compute(predicted, reference))
            for molecule, predicted, reference in zip(
                molecules, predicted_labels, reference_labels
            )
        ),
        key=lambda entry: entry[-1],
        reverse=True,
    )
    # Plain lists (rather than ``zip(*ranked)``) so empty input is handled.
    ranked_molecules = [entry[0] for entry in ranked]
    ranked_predicted = [entry[1] for entry in ranked]
    ranked_reference = [entry[2] for entry in ranked]

    # ``seq[-0:]`` would select the whole sequence, so compute where the
    # tail starts explicitly; non-positive counts select nothing.
    bottom_start = max(n_molecules - max(bottom_n_entries, 0), 0)

    top_jinja_dicts = _generate_jinja_dicts_per_atom(
        ranked_molecules[:top_n_entries],
        ranked_predicted[:top_n_entries],
        ranked_reference[:top_n_entries],
        metrics,
        highlight_outliers=highlight_outliers,
        outlier_threshold=outlier_threshold,
    )
    bottom_jinja_dicts = _generate_jinja_dicts_per_atom(
        ranked_molecules[bottom_start:],
        ranked_predicted[bottom_start:],
        ranked_reference[bottom_start:],
        metrics,
        highlight_outliers=highlight_outliers,
        outlier_threshold=outlier_threshold,
    )
    _write_jinja_report(
        output_path,
        top_jinja_dicts,
        bottom_jinja_dicts,
    )


def create_molecule_label_report(
    molecules: typing.List["Molecule"],
    losses: typing.List[torch.Tensor],
    metric_name: str,
    output_path: pathlib.Path,
    top_n_entries: int = 100,
    bottom_n_entries: int = 100,
):
    """
    Write an HTML report of per-molecule losses.

    Molecules are sorted in ascending order of ``losses``. The first
    ``top_n_entries`` (lowest losses) and the last ``bottom_n_entries``
    (highest losses) of that ordering are drawn with their loss value.

    Parameters
    ----------
    molecules : list of Molecule
        The molecules to report on.
    losses : list of torch.Tensor
        One scalar loss per molecule; each must be convertible with
        ``float()``.
    metric_name : str
        Label under which the loss is shown (upper-cased).
    output_path : pathlib.Path
        Where to write the HTML report.
    top_n_entries : int, optional
        Number of entries taken from the start of the sorted order.
    bottom_n_entries : int, optional
        Number of entries taken from the end of the sorted order; ``0``
        includes none.

    Raises
    ------
    ValueError
        If ``molecules`` and ``losses`` have different lengths.
    """
    # NOTE(review): this sorts ascending while ``create_atom_label_report``
    # sorts descending, so "top" means opposite things in the two reports —
    # confirm which is intended.
    n_molecules = len(molecules)
    # Explicit check rather than ``assert``, which is stripped under ``-O``.
    if n_molecules != len(losses):
        raise ValueError("The number of molecules and losses must match.")

    ranked = sorted(zip(molecules, losses), key=lambda entry: entry[1])
    # Plain lists (rather than ``zip(*ranked)``) so empty input is handled.
    ranked_molecules = [molecule for molecule, _ in ranked]
    ranked_losses = [loss for _, loss in ranked]

    # ``seq[-0:]`` would select the whole sequence, so compute where the
    # tail starts explicitly; non-positive counts select nothing.
    bottom_start = max(n_molecules - max(bottom_n_entries, 0), 0)

    top_jinja_dicts = _generate_jinja_dicts_per_molecule(
        ranked_molecules[:top_n_entries],
        ranked_losses[:top_n_entries],
        metric_name,
    )
    bottom_jinja_dicts = _generate_jinja_dicts_per_molecule(
        ranked_molecules[bottom_start:],
        ranked_losses[bottom_start:],
        metric_name,
    )
    _write_jinja_report(
        output_path,
        top_jinja_dicts,
        bottom_jinja_dicts,
    )