Prepare a NAGL dataset for training

Training a GCN requires a collection of examples that the GCN should reproduce and interpolate between. This notebook describes how to prepare such a dataset for predicting partial charges.

Imports

from pathlib import Path

from tqdm import tqdm

from openff.toolkit.topology import Molecule

from openff.nagl.label.dataset import LabelledDataset
from openff.nagl.label.labels import LabelCharges

Choosing our molecules

The simplest way to specify the molecules in our dataset is with SMILES, though anything you can load into an OpenFF Molecule is fair game. For instance, with the Molecule.from_file() method you could load partial charges from SDF files. But for this example, we’ll have NAGL generate our charges, so we can just provide the SMILES themselves:

alkanes_smiles = Path("alkanes.smi").read_text().splitlines()
alkanes_smiles
['C',
 'CC',
 'CCC',
 'CCCC',
 'CC(C)C',
 'CCCCC',
 'CC(C)CC',
 'CCCCCC',
 'CC(C)CCC',
 'CC(CC)CC']
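`Path.read_text().splitlines()` works when the file holds exactly one bare SMILES per line. SMILES files often also carry a name after each SMILES, plus blank or comment lines. Below is a minimal sketch of a more tolerant reader. It assumes the common convention of whitespace-separated `SMILES name` columns and `#` comment lines. `read_smiles_file` is a hypothetical helper and is not part of NAGL:

```python
from pathlib import Path


def read_smiles_file(path):
    """Return the SMILES column of a .smi file (hypothetical helper).

    Assumes one molecule per line, with the SMILES first and optionally
    followed by whitespace and a name. Blank lines and lines starting
    with '#' are skipped.
    """
    smiles = []
    for line in Path(path).read_text().splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            continue  # skip blank lines and comments
        # Keep only the first whitespace-delimited token: the SMILES itself
        smiles.append(line.split()[0])
    return smiles
```

With this helper, `alkanes_smiles = read_smiles_file("alkanes.smi")` would give the same list as above for a file of bare SMILES.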

Generating a LabelledDataset

A LabelledDataset is a convenience wrapper around an Apache Arrow Dataset that makes it easy to generate data. When we train GNN models, the data is read directly as an Arrow dataset, so using a LabelledDataset to generate your data is optional: it is purely a convenience. Here we demonstrate what it offers.

dataset = LabelledDataset.from_smiles(
    "labelled_alkanes",  # path to save to
    alkanes_smiles,
    mapped=False,
    overwrite_existing=True,
)
dataset.to_pandas()
mapped_smiles
0 [C:1]([H:2])([H:3])([H:4])[H:5]
1 [C:1]([C:2]([H:6])([H:7])[H:8])([H:3])([H:4])[...
2 [C:1]([C:2]([C:3]([H:9])([H:10])[H:11])([H:7])...
3 [C:1]([C:2]([C:3]([C:4]([H:12])([H:13])[H:14])...
4 [C:1]([C:2]([C:3]([H:9])([H:10])[H:11])([C:4](...
5 [C:1]([C:2]([C:3]([C:4]([C:5]([H:15])([H:16])[...
6 [C:1]([C:2]([C:3]([H:10])([H:11])[H:12])([C:4]...
7 [C:1]([C:2]([C:3]([C:4]([C:5]([C:6]([H:18])([H...
8 [C:1]([C:2]([C:3]([H:11])([H:12])[H:13])([C:4]...
9 [C:1]([C:2]([C:3]([C:4]([H:13])([H:14])[H:15])...
# path to directory containing parquet files of dataset
dataset.source
'labelled_alkanes'
# actual files of the dataset
dataset.dataset.files
['labelled_alkanes/part-0.parquet']
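Because the dataset on disk is just a directory of Parquet files, its files can also be found without going through LabelledDataset. Below is a minimal sketch. It assumes every `*.parquet` file directly inside `dataset.source` belongs to the dataset, as with the single `part-0.parquet` above. `find_parquet_files` is a hypothetical helper, and its result could then be handed to any Parquet or Arrow reader:

```python
from pathlib import Path


def find_parquet_files(source):
    """List the Parquet files in a dataset directory (hypothetical helper).

    Assumes all *.parquet files directly inside `source` are part of the
    dataset. Paths are returned as strings in lexicographic order, so
    e.g. 'part-10.parquet' sorts before 'part-2.parquet'.
    """
    return sorted(str(path) for path in Path(source).glob("*.parquet"))
```

For the directory above, `find_parquet_files(dataset.source)` should list the same single file as `dataset.dataset.files`.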

Generating charges

NAGL can generate AM1-BCC and AM1-Mulliken charges automatically using the OpenFF Toolkit. The exist_ok argument controls whether an error is raised if charge_column is already present in the dataset. Normally we want this to be False, but here it is set to True so that the cell can safely be run multiple times.

am1bcc_labeller = LabelCharges(
    charge_method="am1bcc",
    charge_column="am1bcc_charges",
    exist_ok=True,
)
am1_labeller = LabelCharges(
    charge_method="am1-mulliken",
    charge_column="am1_charges",
    exist_ok=True,
)
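The exist_ok check described above boils down to a small piece of logic. Below is a minimal sketch over a plain dictionary of columns, not NAGL's actual implementation. `add_label_column` and its arguments are hypothetical, and the sketch assumes that with `exist_ok=True` an existing column is simply overwritten:

```python
def add_label_column(columns, name, values, exist_ok=False):
    """Add or replace a label column in a dict of columns (hypothetical helper).

    Raises ValueError if `name` is already present and `exist_ok` is False.
    Assumes that with `exist_ok=True` the existing column is overwritten.
    """
    if name in columns and not exist_ok:
        raise ValueError(f"Column {name!r} already exists")
    updated = dict(columns)  # copy so the caller's dict is left untouched
    updated[name] = list(values)
    return updated
```

Re-running a labelling step with `exist_ok=False` would therefore fail on the second run, which is why the labellers above set it to True.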

dataset.apply_labellers(
    [am1_labeller, am1bcc_labeller],
    verbose=True,
)
dataset.to_pandas()

mapped_smiles am1_charges am1bcc_charges
0 [C:1]([H:2])([H:3])([H:4])[H:5] [-0.2656, 0.0664, 0.0664, 0.0664, 0.0664] [-0.1084, 0.0271, 0.0271, 0.0271, 0.0271]
1 [C:1]([C:2]([H:6])([H:7])[H:8])([H:3])([H:4])[... [-0.21225, -0.21225, 0.07075, 0.07075, 0.07075... [-0.09435, -0.09435, 0.03145, 0.03145, 0.03145...
2 [C:1]([C:2]([C:3]([H:9])([H:10])[H:11])([H:7])... [-0.211, -0.16, -0.211, 0.072, 0.072, 0.071, 0... [-0.09310018181818182, -0.08140018181818182, -...
3 [C:1]([C:2]([C:3]([C:4]([H:12])([H:13])[H:14])... [-0.21028571428571427, -0.15928571428571428, -... [-0.09238585714285714, -0.08068585714285714, -...
4 [C:1]([C:2]([C:3]([H:9])([H:10])[H:11])([C:4](... [-0.20735714285714285, -0.10935714285714285, -... [-0.08945692857142856, -0.07005692857142856, -...
5 [C:1]([C:2]([C:3]([C:4]([C:5]([H:15])([H:16])[... [-0.21, -0.159, -0.158, -0.159, -0.21, 0.07200... [-0.09210011764705882, -0.08040011764705882, -...
6 [C:1]([C:2]([C:3]([H:10])([H:11])[H:12])([C:4]... [-0.207, -0.10599999999999998, -0.207, -0.153,... [-0.08909988235294117, -0.06669988235294116, -...
7 [C:1]([C:2]([C:3]([C:4]([C:5]([C:6]([H:18])([H... [-0.21009999999999998, -0.1591, -0.1581, -0.15... [-0.0922, -0.0805, -0.0795, -0.0795, -0.0805, ...
8 [C:1]([C:2]([C:3]([H:11])([H:12])[H:13])([C:4]... [-0.20694999999999997, -0.10694999999999998, -... [-0.08955004999999999, -0.06765004999999999, -...
9 [C:1]([C:2]([C:3]([C:4]([H:13])([H:14])[H:15])... [-0.20794999999999997, -0.10594999999999999, -... [-0.09005, -0.06665, -0.07785, -0.09205, -0.07...
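A quick sanity check on any charge column is that each molecule's partial charges sum to its total formal charge, which is zero for these neutral alkanes. The sketch below applies this check to the methane row, the only row printed in full above. It assumes the charge list follows atom-map order (C first, then the four H). `net_charge` is a hypothetical helper, and the 1e-6 tolerance is an arbitrary choice to absorb floating-point rounding:

```python
import math

# Methane charges copied from the table above (atom-map order: C, then 4 H)
methane_am1 = [-0.2656, 0.0664, 0.0664, 0.0664, 0.0664]
methane_am1bcc = [-0.1084, 0.0271, 0.0271, 0.0271, 0.0271]


def net_charge(charges):
    """Sum per-atom partial charges (elementary charge units) accurately."""
    return math.fsum(charges)


for name, charges in [("am1", methane_am1), ("am1bcc", methane_am1bcc)]:
    total = net_charge(charges)
    # Neutral molecule: total should be zero up to floating-point rounding
    assert math.isclose(total, 0.0, abs_tol=1e-6), (name, total)
    print(f"{name}: net charge {total:+.6f}")
```

By hand: -0.2656 + 4 × 0.0664 = 0 and -0.1084 + 4 × 0.0271 = 0.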

If you have your own charges to add, use the LabelledDataset.append_columns method. Warning: this method does not check whether the charges are valid, for example their length or type!

dataset.append_columns(
    columns={
        "custom_charges": [
            [i]
            for i in range(len(alkanes_smiles))
        ]
    }
)
dataset.to_pandas()
mapped_smiles am1_charges am1bcc_charges custom_charges
0 [C:1]([H:2])([H:3])([H:4])[H:5] [-0.2656, 0.0664, 0.0664, 0.0664, 0.0664] [-0.1084, 0.0271, 0.0271, 0.0271, 0.0271] [0]
1 [C:1]([C:2]([H:6])([H:7])[H:8])([H:3])([H:4])[... [-0.21225, -0.21225, 0.07075, 0.07075, 0.07075... [-0.09435, -0.09435, 0.03145, 0.03145, 0.03145... [1]
2 [C:1]([C:2]([C:3]([H:9])([H:10])[H:11])([H:7])... [-0.211, -0.16, -0.211, 0.072, 0.072, 0.071, 0... [-0.09310018181818182, -0.08140018181818182, -... [2]
3 [C:1]([C:2]([C:3]([C:4]([H:12])([H:13])[H:14])... [-0.21028571428571427, -0.15928571428571428, -... [-0.09238585714285714, -0.08068585714285714, -... [3]
4 [C:1]([C:2]([C:3]([H:9])([H:10])[H:11])([C:4](... [-0.20735714285714285, -0.10935714285714285, -... [-0.08945692857142856, -0.07005692857142856, -... [4]
5 [C:1]([C:2]([C:3]([C:4]([C:5]([H:15])([H:16])[... [-0.21, -0.159, -0.158, -0.159, -0.21, 0.07200... [-0.09210011764705882, -0.08040011764705882, -... [5]
6 [C:1]([C:2]([C:3]([H:10])([H:11])[H:12])([C:4]... [-0.207, -0.10599999999999998, -0.207, -0.153,... [-0.08909988235294117, -0.06669988235294116, -... [6]
7 [C:1]([C:2]([C:3]([C:4]([C:5]([C:6]([H:18])([H... [-0.21009999999999998, -0.1591, -0.1581, -0.15... [-0.0922, -0.0805, -0.0795, -0.0795, -0.0805, ... [7]
8 [C:1]([C:2]([C:3]([H:11])([H:12])[H:13])([C:4]... [-0.20694999999999997, -0.10694999999999998, -... [-0.08955004999999999, -0.06765004999999999, -... [8]
9 [C:1]([C:2]([C:3]([C:4]([H:13])([H:14])[H:15])... [-0.20794999999999997, -0.10594999999999999, -... [-0.09005, -0.06665, -0.07785, -0.09205, -0.07... [9]
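Since append_columns does no validation, it can help to check the lengths yourself before training. Below is a minimal sketch that counts atoms in each mapped SMILES by its atom-map numbers and flags rows whose charge list has a different length. `count_mapped_atoms` and `find_bad_charge_rows` are hypothetical helpers. They assume every atom is written as a bracket atom carrying a map number, as in the mapped_smiles column above:

```python
import re

# Matches a bracket atom ending in an atom-map number, e.g. "[H:12]" -> "12"
_ATOM_MAP = re.compile(r"\[[^\]]*:(\d+)\]")


def count_mapped_atoms(mapped_smiles):
    """Count atoms in a fully mapped SMILES via its atom-map numbers."""
    return len(_ATOM_MAP.findall(mapped_smiles))


def find_bad_charge_rows(mapped_smiles, charges):
    """Return indices of rows whose charge count differs from the atom count."""
    bad = []
    for row, (smiles, row_charges) in enumerate(zip(mapped_smiles, charges)):
        if len(row_charges) != count_mapped_atoms(smiles):
            bad.append(row)
    return bad
```

Applied to `df = dataset.to_pandas()` as `find_bad_charge_rows(df["mapped_smiles"], df["custom_charges"])`, this would be expected to flag every row, because each placeholder custom_charges entry holds a single value while even methane has five atoms.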