"Classes for handling featurized molecule data to train GNN models"
from collections import defaultdict
import functools
import glob
import hashlib
import io
import logging
import pickle
import tempfile
import typing
import tqdm
import torch
from openff.utilities import requires_package
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from openff.nagl._base.base import ImmutableModel
from openff.nagl.config.training import TrainingConfig
from openff.nagl.features.atoms import AtomFeature
from openff.nagl.features.bonds import BondFeature
from openff.nagl.molecule._dgl import DGLMolecule, DGLMoleculeBatch, DGLMoleculeOrBatch
from openff.nagl.utils._parallelization import get_mapper_to_processes
from openff.nagl.utils._hash import digest_file
import pathlib
import numpy as np
if typing.TYPE_CHECKING:
from openff.toolkit import Molecule
# Names exported by a star-import of this module.
# NOTE(review): DGLMoleculeDataLoader is defined in this module without a
# leading underscore but is not listed here -- confirm whether that is
# intentional.
__all__ = [
"DataHash",
"DGLMoleculeDataset",
"DGLMoleculeDatasetEntry",
]
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
class DataHash(ImmutableModel):
    """Fingerprint of a dataset together with its featurization options.

    The model stores the concatenated digests of the input files, the
    requested columns, and the atom and bond features. :meth:`to_hash`
    condenses all of these into a single hexadecimal string.
    """

    path_hash: str
    columns: typing.List[str]
    atom_features: typing.List[AtomFeature]
    bond_features: typing.List[BondFeature]

    @classmethod
    def from_file(
        cls,
        *paths: typing.Union[str, pathlib.Path],
        columns: typing.Optional[typing.List[str]] = None,
        atom_features: typing.Optional[typing.List[AtomFeature]] = None,
        bond_features: typing.Optional[typing.List[BondFeature]] = None,
    ):
        """
        Build a hash description from one or more files or directories.

        Parameters
        ----------
        *paths
            Files and/or directories to fingerprint. Directories are searched
            recursively and every regular file found is digested with
            ``digest_file``. A path that is neither a file nor a directory
            contributes its resolved path string instead.
        columns
            The column names to include. They are sorted before being stored.
            ``None`` is treated as an empty list.
        atom_features
            The atom features to include. ``None`` is treated as an empty list.
        bond_features
            The bond features to include. ``None`` is treated as an empty list.

        Returns
        -------
        DataHash
        """
        digests = []
        for path in paths:
            path = pathlib.Path(path)
            if path.is_dir():
                # Directory listings come back in filesystem order; sorting
                # keeps the hash of an unchanged directory reproducible.
                digests.extend(
                    digest_file(file)
                    for file in sorted(path.glob("**/*"))
                    if file.is_file()
                )
            elif path.is_file():
                digests.append(digest_file(path))
            else:
                digests.append(str(path.resolve()))

        return cls(
            path_hash="".join(digests),
            columns=[] if columns is None else sorted(columns),
            atom_features=[] if atom_features is None else atom_features,
            bond_features=[] if bond_features is None else bond_features,
        )

    def to_hash(self):
        """Return the SHA-256 hex digest of this model's JSON serialization."""
        json_str = self.json().encode("utf-8")
        return hashlib.sha256(json_str).hexdigest()
def _get_hashed_arrow_dataset_path(
    path: pathlib.Path,
    atom_features: typing.Optional[typing.List[AtomFeature]] = None,
    bond_features: typing.Optional[typing.List[BondFeature]] = None,
    columns: typing.Optional[typing.List[str]] = None,
    directory: typing.Optional[pathlib.Path] = None
) -> pathlib.Path:
    """
    Return a path named after the hash of a dataset and its options.

    Parameters
    ----------
    path
        The dataset file or directory to hash.
    atom_features
        The atom features included in the hash.
    bond_features
        The bond features included in the hash.
    columns
        The column names included in the hash.
    directory
        If given, the hashed name is placed inside this directory.

    Returns
    -------
    pathlib.Path
        ``directory / <hash>`` when ``directory`` is given, else ``<hash>``.
    """
    data_hash = DataHash.from_file(
        path,
        columns=columns,
        atom_features=atom_features,
        bond_features=bond_features,
    )
    hashed_name = pathlib.Path(f"{data_hash.to_hash()}")
    if directory is None:
        return hashed_name
    return pathlib.Path(directory) / hashed_name
class DGLMoleculeDatasetEntry(typing.NamedTuple):
    """A named tuple pairing a featurized molecule graph with a dictionary
    of label tensors.
    """

    molecule: DGLMolecule
    labels: typing.Dict[str, torch.Tensor]

    @classmethod
    def from_openff(
        cls,
        openff_molecule: "Molecule",
        labels: typing.Dict[str, typing.Any],
        atom_features: typing.List[AtomFeature],
        bond_features: typing.List[BondFeature],
        atom_feature_tensor: typing.Optional[torch.Tensor] = None,
        bond_feature_tensor: typing.Optional[torch.Tensor] = None,
    ):
        """
        Create a dataset entry from an OpenFF molecule.

        Parameters
        ----------
        openff_molecule
            The molecule to featurize.
        labels
            The labels for the dataset entry. Each value is converted with
            ``numpy.asarray`` and ``torch.from_numpy``; float64 tensors are
            cast with ``Tensor.float()``. Labels whose value is ``None``
            are dropped.
        atom_features
            The atom features passed on to ``DGLMolecule.from_openff``.
        bond_features
            The bond features passed on to ``DGLMolecule.from_openff``.
        atom_feature_tensor
            An optional atom feature tensor passed on to
            ``DGLMolecule.from_openff``.
        bond_feature_tensor
            An optional bond feature tensor passed on to
            ``DGLMolecule.from_openff``.
        """
        dglmol = DGLMolecule.from_openff(
            openff_molecule,
            atom_features=atom_features,
            bond_features=bond_features,
            atom_feature_tensor=atom_feature_tensor,
            bond_feature_tensor=bond_feature_tensor,
        )

        labels_ = {}
        for key, value in labels.items():
            if value is None:
                continue
            tensor = torch.from_numpy(np.asarray(value))
            if tensor.dtype == torch.float64:
                tensor = tensor.float()
            labels_[key] = tensor
        return cls(dglmol, labels_)

    def to(self, device: str):
        """Return a new entry with the molecule and all labels moved to ``device``."""
        return type(self)(
            self.molecule.to(device),
            {k: v.to(device) for k, v in self.labels.items()},
        )

    @classmethod
    def from_mapped_smiles(
        cls,
        mapped_smiles: str,
        labels: typing.Dict[str, typing.Any],
        atom_features: typing.List[AtomFeature],
        bond_features: typing.List[BondFeature],
        atom_feature_tensor: typing.Optional[torch.Tensor] = None,
        bond_feature_tensor: typing.Optional[torch.Tensor] = None,
    ):
        """
        Create a dataset entry from a mapped SMILES string.

        Parameters
        ----------
        mapped_smiles
            The mapped SMILES string.
        labels
            The labels for the dataset entry.
            These will be converted to Pytorch tensors.
        atom_features
            The atom features to use.
            If this is provided, an atom_feature_tensor
            should not be provided as it will be generated
            during featurization.
        bond_features
            The bond features to use.
            If this is provided, a bond_feature_tensor
            should not be provided as it will be generated
            during featurization.
        atom_feature_tensor
            The atom feature tensor to use.
            If this is provided, atom_features should not
            be provided as it will be ignored.
        bond_feature_tensor
            The bond feature tensor to use.
            If this is provided, bond_features should not
            be provided as it will be ignored.
        """
        from openff.toolkit import Molecule

        molecule = Molecule.from_mapped_smiles(
            mapped_smiles,
            allow_undefined_stereo=True,
        )
        return cls.from_openff(
            molecule,
            labels,
            atom_features,
            bond_features,
            atom_feature_tensor,
            bond_feature_tensor,
        )

    @classmethod
    def _from_unfeaturized_pyarrow_row(
        cls,
        row: typing.Dict[str, typing.Any],
        atom_features: typing.List[AtomFeature],
        bond_features: typing.List[BondFeature],
        smiles_column: str = "mapped_smiles",
    ):
        """
        Create an entry from a row that holds a SMILES string and labels.

        The value under ``smiles_column`` is used as the mapped SMILES;
        every other item in the row is treated as a label.
        """
        labels = dict(row)
        mapped_smiles = labels.pop(smiles_column)
        return cls.from_mapped_smiles(
            mapped_smiles,
            labels,
            atom_features,
            bond_features,
        )

    @classmethod
    def _from_featurized_pyarrow_row(
        cls,
        row: typing.Dict[str, typing.Any],
        atom_feature_column: typing.Optional[str],
        bond_feature_column: typing.Optional[str],
        smiles_column: str = "mapped_smiles",
    ):
        """
        Create an entry from a row that already holds flattened features.

        Parameters
        ----------
        row
            The row as a dictionary. Items other than the SMILES and
            feature columns are treated as labels.
        atom_feature_column
            The name of the column with the flattened atom features, or
            ``None`` if the row has no such column.
        bond_feature_column
            The name of the column with the flattened bond features, or
            ``None`` if the row has no such column.
        smiles_column
            The name of the column with the mapped SMILES string.
        """
        from openff.toolkit import Molecule

        labels = dict(row)
        mapped_smiles = labels.pop(smiles_column)
        # A feature column name of None means "no such column": skip the
        # lookup instead of raising KeyError on ``labels.pop(None)``.
        atom_features = (
            labels.pop(atom_feature_column)
            if atom_feature_column is not None
            else None
        )
        bond_features = (
            labels.pop(bond_feature_column)
            if bond_feature_column is not None
            else None
        )

        molecule = Molecule.from_mapped_smiles(
            mapped_smiles,
            allow_undefined_stereo=True,
        )

        # The stored features are flat; restore one row per atom / bond.
        if atom_features is not None:
            atom_features = torch.tensor(atom_features).float()
            atom_features = atom_features.reshape(len(molecule.atoms), -1)

        if bond_features is not None:
            bond_features = torch.tensor(bond_features).float()
            bond_features = bond_features.reshape(len(molecule.bonds), -1)

        # Reuse the molecule parsed above rather than parsing the SMILES a
        # second time through ``from_mapped_smiles``.
        return cls.from_openff(
            molecule,
            labels,
            atom_features=[],
            bond_features=[],
            atom_feature_tensor=atom_features,
            bond_feature_tensor=bond_features,
        )
class _LazyDGLMoleculeDataset(Dataset):
    """Dataset backed by an Arrow IPC file of pickled entries.

    The file is memory-mapped on construction and each entry is unpickled
    only when it is indexed.
    """

    version = 0.1

    @staticmethod
    def _get_schema():
        """Return the single-column (``pickled``: binary) Arrow schema."""
        import pyarrow as pa

        return pa.schema([pa.field("pickled", pa.binary())])

    @property
    def schema(self):
        """The Arrow schema of the backing file."""
        return self._get_schema()

    def __len__(self):
        return self.n_entries

    def __getitem__(self, index):
        row = self.table.slice(index, length=1).to_pydict()["pickled"][0]
        # NOTE: unpickling can execute arbitrary code, so ``source`` should
        # only ever point at a trusted file.
        return pickle.loads(row)

    @requires_package("pyarrow")
    def __init__(
        self,
        source: str,
    ):
        """
        Parameters
        ----------
        source
            Path to an Arrow IPC file with a ``pickled`` column.
        """
        import pyarrow as pa

        # Set by ``from_arrow_dataset`` when the backing file lives in a
        # temporary directory that must outlive this object.
        self._tempdir = None

        self.source = str(source)
        with pa.memory_map(self.source, "rb") as src:
            reader = pa.ipc.open_file(src)
            self.table = reader.read_all()

        self.n_entries = self.table.num_rows
        self.n_atom_features = (
            self[0].molecule.atom_features.shape[1]
            if len(self)
            else 0
        )

    @classmethod
    @requires_package("pyarrow")
    def from_arrow_dataset(
        cls,
        path: pathlib.Path,
        format: str = "parquet",
        atom_features: typing.Optional[typing.List[AtomFeature]] = None,
        bond_features: typing.Optional[typing.List[BondFeature]] = None,
        atom_feature_column: typing.Optional[str] = None,
        bond_feature_column: typing.Optional[str] = None,
        smiles_column: str = "mapped_smiles",
        columns: typing.Optional[typing.List[str]] = None,
        cache_directory: typing.Optional[pathlib.Path] = None,
        use_cached_data: bool = True,
        n_processes: int = 0,
    ):
        """
        Featurize an Arrow dataset into a file of pickled entries and open it.

        Parameters
        ----------
        path
            The input dataset, passed to ``pyarrow.dataset.dataset``.
        format
            The format of the input dataset.
        atom_features
            The atom features to compute. Only used when neither feature
            column is given.
        bond_features
            The bond features to compute. Only used when neither feature
            column is given.
        atom_feature_column
            The name of a column holding pre-computed atom features.
        bond_feature_column
            The name of a column holding pre-computed bond features.
        smiles_column
            The name of the column holding the mapped SMILES strings.
        columns
            The columns to read. The SMILES column and any named feature
            columns are added if missing. ``None`` reads all columns.
        cache_directory
            The directory in which the output file is looked up and written.
        use_cached_data
            If True and the output file already exists, it is opened
            without re-featurizing. If False, the output is written to a
            new temporary directory instead.
        n_processes
            Passed to ``get_mapper_to_processes`` for each batch.
        """
        import pyarrow as pa
        import pyarrow.dataset as ds

        if columns is not None:
            columns = list(columns)
            if smiles_column not in columns:
                columns.append(smiles_column)

        # The file name is computed before feature columns are appended to
        # ``columns`` below, so those columns do not enter the hash.
        file_path = _get_hashed_arrow_dataset_path(
            path,
            atom_features,
            bond_features,
            columns,
        ).with_suffix(".arrow")

        if cache_directory is not None:
            output_path = pathlib.Path(cache_directory) / file_path
        else:
            output_path = file_path

        tempdir = None
        if use_cached_data:
            if output_path.exists():
                return cls(output_path)
        else:
            tempdir = tempfile.TemporaryDirectory()
            output_path = pathlib.Path(tempdir.name) / file_path

        # Make sure the destination directory exists before writing.
        output_path.parent.mkdir(parents=True, exist_ok=True)

        logger.info(f"Featurizing dataset to {output_path}")

        if atom_feature_column is None and bond_feature_column is None:
            # Rows hold only SMILES and labels: featurize them here.
            converter = functools.partial(
                cls._pickle_entry_from_unfeaturized_row,
                atom_features=atom_features,
                bond_features=bond_features,
                smiles_column=smiles_column,
            )
        else:
            # Rows already hold features: read them from the named columns.
            converter = functools.partial(
                cls._pickle_entry_from_featurized_row,
                atom_feature_column=atom_feature_column,
                bond_feature_column=bond_feature_column,
                smiles_column=smiles_column,
            )
            if columns is not None:
                # Only one of the two columns may be named; never request
                # a column called None.
                for feature_column in (atom_feature_column, bond_feature_column):
                    if feature_column is not None and feature_column not in columns:
                        columns.append(feature_column)

        # ``cls.schema`` would be the property object itself, not a schema,
        # so the schema is built through the static helper.
        schema = cls._get_schema()

        input_dataset = ds.dataset(path, format=format)
        with pa.OSFile(str(output_path), "wb") as sink:
            with pa.ipc.new_file(sink, schema) as writer:
                for input_batch in input_dataset.to_batches(columns=columns):
                    with get_mapper_to_processes(n_processes=n_processes) as mapper:
                        pickled = list(mapper(converter, input_batch.to_pylist()))

                    # The explicit type keeps an empty batch from being
                    # inferred as a null-typed array.
                    output_batch = pa.RecordBatch.from_arrays(
                        [pa.array(pickled, type=pa.binary())],
                        schema=schema,
                    )
                    writer.write_batch(output_batch)

        dataset = cls(output_path)
        # Hold on to the TemporaryDirectory: it removes the directory when
        # it is finalized, which must not happen while the dataset is used.
        dataset._tempdir = tempdir
        return dataset

    @staticmethod
    def _pickle_entry_from_unfeaturized_row(
        row,
        atom_features=None,
        bond_features=None,
        smiles_column="mapped_smiles",
    ):
        """Featurize a row of SMILES and labels and return the pickled entry."""
        entry = DGLMoleculeDatasetEntry._from_unfeaturized_pyarrow_row(
            row,
            atom_features=atom_features,
            bond_features=bond_features,
            smiles_column=smiles_column,
        )
        return pickle.dumps(entry)

    @staticmethod
    def _pickle_entry_from_featurized_row(
        row,
        atom_feature_column: str = "atom_features",
        bond_feature_column: str = "bond_features",
        smiles_column: str = "mapped_smiles",
    ):
        """Build an entry from a pre-featurized row and return it pickled."""
        entry = DGLMoleculeDatasetEntry._from_featurized_pyarrow_row(
            row,
            atom_feature_column=atom_feature_column,
            bond_feature_column=bond_feature_column,
            smiles_column=smiles_column,
        )
        return pickle.dumps(entry)
class DGLMoleculeDataset(Dataset):
    """An in-memory dataset of :class:`DGLMoleculeDatasetEntry` objects."""

    def __len__(self):
        return len(self.entries)

    def __getitem__(self, index_or_slice):
        return self.entries[index_or_slice]

    def __init__(self, entries: typing.Tuple[DGLMoleculeDatasetEntry, ...] = tuple()):
        """
        Parameters
        ----------
        entries
            The entries of the dataset. The sequence is stored as given.
        """
        self.entries = entries

    @property
    def n_atom_features(self) -> int:
        """The second dimension of the first entry's atom feature tensor,
        or 0 if the dataset is empty."""
        if not len(self):
            return 0
        return self[0].molecule.atom_features.shape[1]

    @classmethod
    @requires_package("pyarrow")
    def from_arrow_dataset(
        cls,
        path: pathlib.Path,
        format: str = "parquet",
        atom_features: typing.Optional[typing.List[AtomFeature]] = None,
        bond_features: typing.Optional[typing.List[BondFeature]] = None,
        atom_feature_column: typing.Optional[str] = None,
        bond_feature_column: typing.Optional[str] = None,
        smiles_column: str = "mapped_smiles",
        columns: typing.Optional[typing.List[str]] = None,
        n_processes: int = 0,
    ):
        """
        Create a dataset from an Arrow dataset on disk.

        Parameters
        ----------
        path
            The input dataset, passed to ``pyarrow.dataset.dataset``.
        format
            The format of the input dataset.
        atom_features
            The atom features to compute. Only used when neither feature
            column is given.
        bond_features
            The bond features to compute. Only used when neither feature
            column is given.
        atom_feature_column
            The name of a column holding pre-computed atom features.
        bond_feature_column
            The name of a column holding pre-computed bond features.
        smiles_column
            The name of the column holding the mapped SMILES strings.
        columns
            The columns to read. The SMILES column and any named feature
            columns are added if missing. ``None`` reads all columns.
        n_processes
            Currently unused: rows are converted one at a time in this
            process.
        """
        import pyarrow.dataset as ds

        if columns is not None:
            columns = list(columns)
            if smiles_column not in columns:
                columns.append(smiles_column)

        if atom_feature_column is None and bond_feature_column is None:
            # Rows hold only SMILES and labels: featurize them here.
            converter = functools.partial(
                DGLMoleculeDatasetEntry._from_unfeaturized_pyarrow_row,
                atom_features=atom_features,
                bond_features=bond_features,
                smiles_column=smiles_column,
            )
        else:
            # Rows already hold features: read them from the named columns.
            converter = functools.partial(
                DGLMoleculeDatasetEntry._from_featurized_pyarrow_row,
                atom_feature_column=atom_feature_column,
                bond_feature_column=bond_feature_column,
                smiles_column=smiles_column,
            )
            if columns is not None:
                # Only one of the two columns may be named; never request
                # a column called None.
                for feature_column in (atom_feature_column, bond_feature_column):
                    if feature_column is not None and feature_column not in columns:
                        columns.append(feature_column)

        input_dataset = ds.dataset(path, format=format)
        entries = []
        for input_batch in tqdm.tqdm(
            input_dataset.to_batches(columns=columns),
            desc="Featurizing dataset",
        ):
            for row in tqdm.tqdm(input_batch.to_pylist(), desc="Featurizing batch"):
                entries.append(converter(row))
        return cls(entries)

    @classmethod
    def from_openff(
        cls,
        molecules: typing.Iterable["Molecule"],
        atom_features: typing.Optional[typing.List[AtomFeature]] = None,
        bond_features: typing.Optional[typing.List[BondFeature]] = None,
        atom_feature_tensors: typing.Optional[typing.List[torch.Tensor]] = None,
        bond_feature_tensors: typing.Optional[typing.List[torch.Tensor]] = None,
        labels: typing.Optional[typing.List[typing.Dict[str, typing.Any]]] = None,
        label_function: typing.Optional[
            typing.Callable[["Molecule"], typing.Dict[str, typing.Any]]
        ] = None,
    ):
        """
        Create a dataset from OpenFF molecules.

        Parameters
        ----------
        molecules
            The molecules to featurize. Any iterable is accepted.
        atom_features
            The atom features passed to each entry.
        bond_features
            The bond features passed to each entry.
        atom_feature_tensors
            Optional atom feature tensors, one per molecule.
        bond_feature_tensors
            Optional bond feature tensors, one per molecule.
        labels
            Optional label dictionaries, one per molecule. They are copied
            and the originals are not modified.
        label_function
            An optional function called on each molecule; the dictionary it
            returns is merged into that molecule's labels.

        Raises
        ------
        ValueError
            If ``labels``, ``atom_feature_tensors`` or
            ``bond_feature_tensors`` differ in length from ``molecules``.
        """
        # ``molecules`` may be a one-shot iterable; materialize it once so
        # it can be measured and iterated over several times.
        molecules = list(molecules)
        n_molecules = len(molecules)

        if labels is None:
            labels = [{} for _ in molecules]
        else:
            labels = [dict(label) for label in labels]
        if len(labels) != n_molecules:
            raise ValueError(
                f"The number of labels ({len(labels)}) must match the number of "
                f"molecules ({n_molecules})."
            )

        if atom_feature_tensors is not None:
            if len(atom_feature_tensors) != n_molecules:
                raise ValueError(
                    f"The number of atom feature tensors ({len(atom_feature_tensors)}) "
                    f"must match the number of molecules ({n_molecules})."
                )
        else:
            atom_feature_tensors = [None] * n_molecules

        if bond_feature_tensors is not None:
            if len(bond_feature_tensors) != n_molecules:
                raise ValueError(
                    f"The number of bond feature tensors ({len(bond_feature_tensors)}) "
                    f"must match the number of molecules ({n_molecules})."
                )
        else:
            bond_feature_tensors = [None] * n_molecules

        if label_function is not None:
            for molecule, label in zip(molecules, labels):
                label.update(label_function(molecule))

        entries = [
            DGLMoleculeDatasetEntry.from_openff(
                molecule,
                label,
                atom_features=atom_features,
                bond_features=bond_features,
                atom_feature_tensor=atom_tensor,
                bond_feature_tensor=bond_tensor,
            )
            for molecule, atom_tensor, bond_tensor, label in zip(
                molecules, atom_feature_tensors, bond_feature_tensors, labels
            )
        ]
        return cls(entries)

    @requires_package("pyarrow")
    def to_pyarrow(
        self,
        atom_feature_column: str = "atom_features",
        bond_feature_column: str = "bond_features",
        smiles_column: str = "mapped_smiles",
    ):
        """
        Convert the dataset to a Pyarrow table.

        This will contain at minimum the smiles, atom features,
        and bond features, using the column names specified as
        arguments. It will also contain any labels that are in the
        entries. Feature tensors are stored flattened.

        Parameters
        ----------
        atom_feature_column
            The name of the column to use for the atom features.
        bond_feature_column
            The name of the column to use for the bond features.
        smiles_column
            The name of the column to use for the SMILES strings.

        Returns
        -------
        table

        Raises
        ------
        ValueError
            If an entry's label names differ from those of the first entry.
        """
        import pyarrow as pa

        required_columns = [smiles_column, atom_feature_column, bond_feature_column]
        # The first entry defines which label columns are written.
        label_columns = list(self.entries[0].labels.keys()) if len(self) else []
        label_set = set(label_columns)

        rows = []
        for dglmol, labels in self.entries:
            atom_features = None
            bond_features = None
            # ``.cpu()`` lets entries that were moved to another device be
            # converted; it is a no-op for tensors already on the CPU.
            if dglmol.atom_features is not None:
                atom_features = dglmol.atom_features.detach().cpu().numpy()
                atom_features = atom_features.astype(float).flatten()
            if dglmol.bond_features is not None:
                bond_features = dglmol.bond_features.detach().cpu().numpy()
                bond_features = bond_features.astype(float).flatten()

            mol_label_set = set(labels.keys())
            if label_set != mol_label_set:
                raise ValueError(
                    f"The label sets are not consistent. "
                    f"Expected {label_set}, got {mol_label_set}."
                )

            row = [dglmol.mapped_smiles, atom_features, bond_features]
            for label in label_columns:
                row.append(labels[label].detach().cpu().numpy().tolist())
            rows.append(row)

        names = required_columns + label_columns
        if rows:
            # Transpose the per-entry rows into per-column sequences.
            column_data = [*zip(*rows)]
        else:
            # Transposing nothing yields no columns at all, which would not
            # match ``names``; pass one empty column per name instead.
            column_data = [[] for _ in names]

        table = pa.table(column_data, names=names)
        return table
class DGLMoleculeDataLoader(DataLoader):
    """A data loader that collates dataset entries into a molecule batch
    and a dictionary of stacked label tensors.
    """

    def __init__(
        self,
        dataset: typing.Union[DGLMoleculeDataset, _LazyDGLMoleculeDataset, ConcatDataset],
        batch_size: typing.Optional[int] = 1,
        **kwargs,
    ):
        """
        Parameters
        ----------
        dataset
            The dataset to load entries from.
        batch_size
            The number of entries per batch.
        **kwargs
            Further keyword arguments forwarded to the parent data loader.
        """
        super().__init__(
            dataset=dataset,
            batch_size=batch_size,
            collate_fn=self._collate,
            **kwargs,
        )

    @staticmethod
    def _collate(graph_entries: typing.List[DGLMoleculeDatasetEntry]):
        """
        Merge entries into ``(batched_molecules, batched_labels)``.

        Every label tensor is reshaped to a column and the columns that
        share a label name are stacked vertically.
        """
        # If the first element is itself a molecule, a single entry was
        # passed rather than a list of entries: wrap it.
        first_item = graph_entries[0]
        if isinstance(first_item, DGLMolecule):
            graph_entries = [graph_entries]

        molecules, labels = zip(*graph_entries)
        batched_molecules = DGLMoleculeBatch.from_dgl_molecules(molecules)

        columns_by_name = {}
        for entry_labels in labels:
            for name, tensor in entry_labels.items():
                columns_by_name.setdefault(name, []).append(tensor.reshape(-1, 1))

        batched_labels = {
            name: torch.vstack(columns)
            for name, columns in columns_by_name.items()
        }
        return batched_molecules, batched_labels