Training a NAGL model
Preparing a dataset
Preparing a dataset for training is just as important as designing the model itself. The dataset must, at a minimum:
Cover the entire space of molecules whose properties you want to predict
Contain enough examples that the model can learn the desired function
Treat the property you want to model consistently
Once you have your overall dataset, you should divide it into training, validation and test datasets to evaluate the quality of fit of your model as it trains. It’s important that all three buckets share coverage of the entire space:
training: About 80% of your data. Model parameters are fit directly to these data.
validation: About 10% of your data. Used to validate the model as it is iteratively fitted.
test: About 10% of your data. Used to test the final model against data it has never “seen” before.
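The 80/10/10 split described above could be sketched as follows. This is a minimal illustration rather than part of NAGL: `split_dataset`, its parameters, and the placeholder SMILES list are all hypothetical names introduced here. It assumes that a seeded random shuffle is an acceptable way to spread the chemical space across the three buckets; a dataset with rare chemistries may need a stratified split instead.

```python
import random


def split_dataset(items, fractions=(0.8, 0.1, 0.1), seed=0):
    """Shuffle items and split them into training, validation and test lists.

    Hypothetical helper, not part of NAGL.
    """
    if abs(sum(fractions) - 1.0) > 1e-9:
        raise ValueError("fractions must sum to 1")

    # Copy before shuffling so the caller's list is left untouched
    shuffled = list(items)
    random.Random(seed).shuffle(shuffled)

    n_train = round(fractions[0] * len(shuffled))
    n_val = round(fractions[1] * len(shuffled))

    # Whatever remains after training and validation goes to the test set
    return {
        "training": shuffled[:n_train],
        "validation": shuffled[n_train:n_train + n_val],
        "test": shuffled[n_train + n_val:],
    }


# Placeholder data: linear alkanes with 1 to 20 carbons
smiles = ["C" * n for n in range(1, 21)]
splits = split_dataset(smiles)
```

Fixing the seed keeps the split reproducible, so the test set stays unseen across repeated training runs.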
Datasets are constructed by producing OpenFF Toolkit Molecule objects with the appropriate charges, loading them into MoleculeRecord objects, and storing them in a SQLite database:
from openff.nagl.storage.record import MoleculeRecord
from openff.nagl.storage import MoleculeStore
from openff.toolkit import Molecule

records = []
for smiles in [
    "C",
    "CC",
    "CCC",
    ...
]:
    # Create the Molecule
    molecule = Molecule.from_smiles(smiles)

    # Compute example partial charges
    molecule.generate_conformers(n_conformers=1)
    molecule.assign_partial_charges("am1bcc")

    # Record results
    records.append(
        MoleculeRecord.from_precomputed_openff(
            molecule,
            partial_charge_method="am1bcc"
        )
    )

# Save the dataset
MoleculeStore("dataset.sqlite").store(records)
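The loading step expects separate files for the training, validation and test sets, so each bucket of records needs its own database. A minimal sketch of that step follows, using only the `MoleculeStore(...).store(...)` call shown above. `split_path` and `store_splits` are hypothetical helpers, and the `<name>_data.sqlite` naming is an assumption chosen to match the paths used when loading.

```python
def split_path(name):
    """Return the SQLite filename for a split, e.g. 'training_data.sqlite'."""
    return f"{name}_data.sqlite"


def store_splits(splits):
    """Write each list of MoleculeRecord objects to its own SQLite file.

    `splits` maps a split name ("training", "validation", "test") to a
    list of records. Hypothetical helper, not part of NAGL.
    """
    # Imported here so split_path can be used without NAGL installed
    from openff.nagl.storage import MoleculeStore

    paths = {}
    for name, records in splits.items():
        paths[name] = split_path(name)
        MoleculeStore(paths[name]).store(records)
    return paths
```

Calling `store_splits({"training": ..., "validation": ..., "test": ...})` with lists of records would then produce `training_data.sqlite`, `validation_data.sqlite` and `test_data.sqlite`.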
Loading a dataset
Datasets are loaded for training with the DGLMoleculeLightningDataModule class. Objects of this class require the featurization schema used in NAGL models, as well as paths to the datasets and batch sizes for loading the data.
from openff.nagl.nn.dataset import DGLMoleculeLightningDataModule

# atom_features and bond_features are the featurization schema used by the model
data_module = DGLMoleculeLightningDataModule(
    atom_features=atom_features,
    bond_features=bond_features,
    partial_charge_method="am1bcc",
    training_set_paths=["training_data.sqlite"],
    validation_set_paths=["validation_data.sqlite"],
    test_set_paths=["test_data.sqlite"],
    training_batch_size=1000,
    validation_batch_size=1000,
    test_batch_size=1000,
)
Training
The PyTorch Lightning Trainer class greatly simplifies training a model. It’s a simple procedure to construct the Trainer, fit the model to the data, test it, and save the fitted model:
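from pytorch_lightning import Trainer

trainer = Trainer(max_epochs=200)

# Track the validation loss when saving checkpoints
trainer.checkpoint_callback.monitor = "val_loss"

# Fit the model to the training data
trainer.fit(
    model,
    datamodule=data_module,
)

# Evaluate on the test set, then save the fitted model
trainer.test(model, data_module)
model.save("model.pt")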