Train a GNN directly to an electric field

To execute this example fully, the following packages are required.

  • openff-nagl

  • openff-recharge

  • openff-qcsubmit

  • psi4

However, if you just want to follow along with the training portion without first creating the training datasets yourself, you can get away with only openff-nagl installed and simply load the training/validation data from the provided .parquet files. The commands to do so are provided, commented out, at the end of the “Generate and format training data” section.

import tqdm

from qcportal import PortalClient
from openff.units import unit

from openff.toolkit import Molecule
from openff.qcsubmit.results import BasicResultCollection
from openff.recharge.esp.storage import MoleculeESPRecord
from openff.recharge.esp.qcresults import from_qcportal_results
from openff.recharge.grids import MSKGridSettings
from openff.recharge.utilities.geometry import compute_vector_field

import pyarrow as pa
import pyarrow.parquet as pq
import numpy as np

import torch

Generate and format training data

Downloading from QCArchive

First, we will create the training data. For the purposes of this example, we'll download a relatively small dataset from QCArchive.

qc_client = PortalClient("https://api.qcarchive.molssi.org:443", cache_dir=".")

# download dataset from QCPortal
br_esps_collection = BasicResultCollection.from_server(
    client=qc_client,
    datasets="OpenFF multi-Br ESP Fragment Conformers v1.1",
    spec_name="HF/6-31G*",
)

records_and_molecules = br_esps_collection.to_records()

Converting to MoleculeESPRecords

Now we convert to OpenFF Recharge records and compute the ESPs and electric fields of each molecule.

# Create OpenFF Recharge MoleculeESPRecords
grid_settings = MSKGridSettings()

# this can take a while; here we only use the first 50 records.
# Slice further (e.g. records_and_molecules[:10]) to speed this up.
molecule_esp_records = [
    from_qcportal_results(
        qc_result=qcrecord,
        qc_molecule=qcrecord.molecule,
        qc_keyword_set=qcrecord.specification.keywords,
        grid_settings=grid_settings,
        compute_field=True
    )
    for qcrecord, _ in tqdm.tqdm(records_and_molecules[:50])
]
100%|██████████| 50/50 [06:04<00:00,  7.29s/it]
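
Before reformatting anything, a quick sanity check on the first record can be helpful. This is a minimal sketch that only touches the attributes used in the next step, and it assumes they come back as plain arrays with the units noted in the comments.

record = molecule_esp_records[0]
print(record.tagged_smiles)
print("conformer:", record.conformer.shape)            # expected (n_atoms, 3), in angstrom
print("grid:", record.grid_coordinates.shape)          # expected (n_grid_points, 3), in angstrom
print("electric field:", record.electric_field.shape)  # expected (n_grid_points, 3), in atomic units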

Convert to PyArrow dataset

NAGL reads in and trains on data from PyArrow tables. Below we convert each electric field into the form expected by a basic GeneralLinearFit target, which fits the equation Ax = b. To avoid carrying millions of grid points around, we do some postprocessing: we first flatten the vector-field matrix A from three dimensions (N x 3 x M) to two (3N x M), flatten the electric field into the corresponding 3N-vector b, and then premultiply both sides by the transpose of A.

\[\mathbf{A}\vec{x} = \vec{b}\]
\[\mathbf{A^{T}}\mathbf{A}\vec{x} = \mathbf{A^{T}}\vec{b}\]
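
Why is it safe to discard the full A and b? Because the least-squares solution of Ax = b is unchanged when both sides are premultiplied by A^T (these are just the normal equations). Here is a small self-contained numpy check of that equivalence on random stand-in data, not on the real records:

# Standalone check on random data: solving the normal equations
# (A^T A) x = A^T b recovers the same x as a direct least-squares fit of A x = b.
rng = np.random.default_rng(0)
A = rng.normal(size=(300, 10))   # stand-in for the 3N x M vector field matrix
b = rng.normal(size=300)         # stand-in for the flattened electric field

x_lstsq = np.linalg.lstsq(A, b, rcond=None)[0]
x_normal = np.linalg.solve(A.T @ A, A.T @ b)
assert np.allclose(x_lstsq, x_normal)

With that established, we now convert each record:
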
pyarrow_entries = []
for molecule_esp_record in tqdm.tqdm(molecule_esp_records):
    electric_field = molecule_esp_record.electric_field # in atomic units
    grid = molecule_esp_record.grid_coordinates * unit.angstrom
    conformer = molecule_esp_record.conformer * unit.angstrom

    vector_field = compute_vector_field(
        conformer.m_as(unit.bohr),  # shape: M x 3
        grid.m_as(unit.bohr),  # shape: N x 3
    ) # N x 3 x M

    # postprocess so we're not carrying around millions of floats
    # first flatten the vector field from N x 3 x M -> 3N x M
    # (and the electric field from N x 3 -> 3N)
    vector_field_2d = np.concatenate(vector_field, axis=0)
    electric_field_1d = np.concatenate(electric_field, axis=0)

    # now multiply by vector_field_2d's transpose
    new_precursor_matrix = vector_field_2d.T @ vector_field_2d
    new_field_vector = vector_field_2d.T @ electric_field_1d

    n_atoms = conformer.shape[0]
    assert new_precursor_matrix.shape == (n_atoms, n_atoms)
    assert new_field_vector.shape == (n_atoms,)

    # create entry. These columns are essential
    # mapped_smiles is essential for every target
    entry = {
        "mapped_smiles": molecule_esp_record.tagged_smiles,
        "precursor_matrix": new_precursor_matrix.flatten().tolist(),
        "prediction_vector": new_field_vector.tolist()
    }
    pyarrow_entries.append(entry)


# arbitrarily split into training and validation datasets
training_pyarrow_entries = pyarrow_entries[:-10]
validation_pyarrow_entries = pyarrow_entries[-10:]

training_table = pa.Table.from_pylist(training_pyarrow_entries)
validation_table = pa.Table.from_pylist(validation_pyarrow_entries)
training_table
100%|██████████| 50/50 [00:00<00:00, 315.72it/s]

pyarrow.Table
mapped_smiles: string
precursor_matrix: list<item: double>
  child 0, item: double
prediction_vector: list<item: double>
  child 0, item: double
----
mapped_smiles: [["[H:1][N:2]1[C:3]([Br:4])=[N:5][c:6]2[n:7][c:8]([Br:9])[c:10]([Br:11])[n:12][c:13]21","[H:1][N:2]([H:3])[C:4]1([C:5](=[O:6])[O-:7])[C:8]([H:9])([H:10])[C:11]2([H:12])[C:13]([Br:14])([Br:15])[C:16]2([H:17])[C:18]1([H:19])[H:20]","[H:1][C:2]1([H:3])[C:4]2([H:5])[C:6]([Br:7])([Br:8])[C:9]2([H:10])[C:11]([H:12])([H:13])[C:14]1([C:15](=[O:16])[O-:17])[N+:18]([H:19])([H:20])[H:21]","[H:1][N:2]1[C:3]([H:4])([H:5])[C:6]2([H:7])[C:8]([Br:9])([Br:10])[C:11]2([H:12])[C:13]1([H:14])[C:15](=[O:16])[O-:17]","[H:1][C:2]12[C:3]([Br:4])([Br:5])[C:6]1([H:7])[C:8]([H:9])([C:10](=[O:11])[O-:12])[N+:13]([H:14])([H:15])[C:16]2([H:17])[H:18]",...,"[H:1][c:2]1[c:3]([Br:4])[c:5]([H:6])[c:7]2[c:8]([c:9]1[Br:10])[C:11]([H:12])([H:13])[C:14]([H:15])([H:16])[C:17]2([H:18])[H:19]","[H:1][c:2]1[c:3]([H:4])[c:5]2[c:6]([c:7]([Br:8])[c:9]1[Br:10])[C:11]([H:12])([H:13])[C:14]([H:15])([H:16])[C:17]2([H:18])[N:19]([H:20])[H:21]","[H:1][c:2]1[c:3]([H:4])[c:5]2[c:6]([c:7]([Br:8])[c:9]1[Br:10])[C:11]([H:12])([H:13])[C:14]([H:15])([H:16])[C:17]2([H:18])[N+:19]([H:20])([H:21])[H:22]","[H:1][c:2]1[c:3]([H:4])[c:5]2[c:6]([c:7]([Br:8])[c:9]1[Br:10])[N:11]([H:12])[C:13]([H:14])([H:15])[C:16]([H:17])([H:18])[C:19]2([H:20])[H:21]","[H:1][c:2]1[c:3]([Br:4])[c:5]([H:6])[c:7]2[c:8]([c:9]1[Br:10])[C:11]([H:12])([H:13])[N:14]([H:15])[C:16]([H:17])([H:18])[C:19]2([H:20])[H:21]"]]
precursor_matrix: [[[0.5692106565580053,0.4085897652889512,0.284249916238166,0.20940517539253248,0.20686844494784193,...,0.11433558179284603,0.23263226801337633,0.1371485227635511,0.3023917581405904,0.32730801486131705],[0.5260442629899051,0.38089630545801045,0.35215718214227837,0.24910585545578626,0.17644186782494933,...,0.22243116268067484,0.19814033340715625,0.300593182954176,0.3088962938720413,0.4158481032091346],...,[0.5289925354233009,0.37414187483676303,0.28313431180338455,0.27734985432011455,0.17944395434503416,...,0.2880391596291975,0.2321472722241459,0.3857071059607491,0.3665630767875717,0.5677900784441218],[0.4983240011637843,0.3447984584130681,0.25468324548837756,0.20316511244549332,0.17595041719921767,...,0.27107542957261366,0.21722450631808865,0.3626687533961094,0.3447698917574723,0.5347596590639927]]]
prediction_vector: [[[0.11831342820663876,0.07052253574331227,0.03443373566615229,0.031423455167753835,-0.006535801800470039,...,-0.015353416908781552,0.0058843508732542345,-0.011419043977358915,0.015153163796349949,0.039054645977152046],[-0.1168859699903578,-0.16262446070843317,-0.17209952898209305,-0.16772575000042161,-0.23935707380099153,...,-0.11399722466179721,-0.09336325633851358,-0.12839560498881492,-0.12752534734377857,-0.10652701620291485],...,[0.01076983677321753,-0.005588997214517595,0.007231581443923954,0.03031856180479904,0.008379193273740145,...,0.04214775639020978,0.03202866265802594,0.0301970861903761,0.03768884431038124,0.03714789835823986],[0.019492995260861956,0.005150669698426461,0.000872755565492754,-0.01489613518860242,0.015886122871734236,...,0.03648331002928913,0.03218903273689575,0.04016640969171912,0.046849895898038244,0.04789045546505114]]]
pq.write_table(training_table, "training_dataset_table.parquet")
pq.write_table(validation_table, "validation_dataset_table.parquet")

# # to read back in -- note, the .parquet files provided with this example contain
# # the full dataset, not just the 50-record subset generated above
# training_table = pq.read_table("training_dataset_table.parquet")
# validation_table = pq.read_table("validation_dataset_table.parquet")

Set up for training a GNN

from openff.nagl.config import (
    TrainingConfig,
    OptimizerConfig,
    ModelConfig,
    DataConfig
)
from openff.nagl.config.model import (
    ConvolutionModule, ReadoutModule,
    ConvolutionLayer, ForwardLayer,
)
from openff.nagl.config.data import DatasetConfig
from openff.nagl.training.training import TrainingGNNModel
from openff.nagl.features.atoms import (
    AtomicElement,
    AtomConnectivity,
    AtomInRingOfSize,
    AtomAverageFormalCharge,
)

from openff.nagl.training.loss import GeneralLinearFitTarget

Defining the training config

Defining a ModelConfig

First we define a ModelConfig. This can be done in Python, but in practice it is probably easier to define the model in a YAML file and load it with ModelConfig.from_yaml.
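
For reference, that YAML workflow is a one-liner once the file has been written out; the file name below is a placeholder for a config you have written yourself. In this example we build the config directly in Python instead.

# # Hypothetical alternative: load a previously written YAML configuration
# model_config = ModelConfig.from_yaml("model_config.yaml")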

atom_features = [
    AtomicElement(categories=["H", "C", "N", "O", "F", "Br", "S", "P", "I"]),
    AtomConnectivity(categories=[1, 2, 3, 4, 5, 6]),
    AtomInRingOfSize(ring_size=3),
    AtomInRingOfSize(ring_size=4),
    AtomInRingOfSize(ring_size=5),
    AtomInRingOfSize(ring_size=6),
    AtomAverageFormalCharge(),
]

# define our convolution module
convolution_module = ConvolutionModule(
    architecture="SAGEConv",
    # construct 6 layers with dropout 0 (default),
    # hidden feature size 512, and ReLU activation function
    # these layers can also be individually specified,
    # but we just duplicate the layer 6 times for identical layers
    layers=[
        ConvolutionLayer(
            hidden_feature_size=512,
            activation_function="ReLU",
            aggregator_type="mean"
        )
    ] * 6,
)

# define our readout module/s
# multiple are allowed but let's focus on charges
readout_modules = {
    # key is the name of output property, any naming is allowed
    "charges": ReadoutModule(
        pooling="atoms",
        postprocess="compute_partial_charges",
        # 2 layers
        layers=[
            ForwardLayer(
                hidden_feature_size=512,
                activation_function="ReLU",
            )
        ] * 2,
    )
}

# bring it all together
model_config = ModelConfig(
    version="0.1",
    atom_features=atom_features,
    convolution=convolution_module,
    readouts=readout_modules,
)

Defining a DataConfig

We can then define our dataset configs. Here we also have to specify our training targets.

target = GeneralLinearFitTarget(
    # what we're using to evaluate loss
    target_label="prediction_vector",
    # the output of the GNN we use to evaluate loss
    prediction_label="charges",
    # the column in the table that contains the precursor matrix
    design_matrix_column="precursor_matrix",
    # how we want to evaluate loss, e.g. RMSE, MSE, ...
    metric="rmse",
    # how much to weight this target
    # helps with scaling in multi-target optimizations
    weight=1,
    denominator=1,
)

training_to_electric_field = DatasetConfig(
    sources=["training_dataset_table.parquet"],
    targets=[target],
    batch_size=100,
)
validating_to_electric_field = DatasetConfig(
    sources=["validation_dataset_table.parquet"],
    targets=[target],
    batch_size=100,
)

# bringing it together
data_config = DataConfig(
    training=training_to_electric_field,
    validation=validating_to_electric_field
)
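
As an aside, it can help to see what GeneralLinearFitTarget does with the columns we wrote out earlier. Conceptually, for each molecule it reshapes the flattened precursor_matrix back into an M x M matrix, applies it to the predicted charges, and compares the result to prediction_vector with the chosen metric. Below is a rough numpy sketch of that idea; it is an illustration only, not NAGL's actual implementation.

# Illustration only: the per-molecule loss this target conceptually represents.
# `flat_matrix` and `target_vector` are one row of the PyArrow table;
# `predicted_charges` stands in for the GNN's "charges" readout.
def sketch_general_linear_fit_rmse(flat_matrix, target_vector, predicted_charges):
    n_atoms = len(predicted_charges)
    design = np.asarray(flat_matrix).reshape(n_atoms, n_atoms)   # A^T A
    residual = design @ np.asarray(predicted_charges) - np.asarray(target_vector)
    return np.sqrt(np.mean(residual ** 2))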

Defining an OptimizerConfig

The optimizer config is relatively simple; the only moving part here currently is the learning rate.

optimizer_config = OptimizerConfig(optimizer="Adam", learning_rate=0.001)

Creating a TrainingConfig

training_config = TrainingConfig(
    model=model_config,
    data=data_config,
    optimizer=optimizer_config
)

Creating a TrainingGNNModel

Now we can create a TrainingGNNModel, which allows easy training of a GNNModel. The GNNModel can be accessed through TrainingGNNModel.model.

training_model = TrainingGNNModel(training_config)
training_model
TrainingGNNModel(
  (model): GNNModel(
    (convolution_module): ConvolutionModule(
      (gcn_layers): SAGEConvStack(
        (0): SAGEConv(
          (feat_drop): Dropout(p=0.0, inplace=False)
          (activation): ReLU()
          (fc_neigh): Linear(in_features=20, out_features=512, bias=False)
          (fc_self): Linear(in_features=20, out_features=512, bias=True)
        )
        (1-5): 5 x SAGEConv(
          (feat_drop): Dropout(p=0.0, inplace=False)
          (activation): ReLU()
          (fc_neigh): Linear(in_features=512, out_features=512, bias=False)
          (fc_self): Linear(in_features=512, out_features=512, bias=True)
        )
      )
    )
    (readout_modules): ModuleDict(
      (charges): ReadoutModule(
        (pooling_layer): PoolAtomFeatures()
        (readout_layers): SequentialLayers(
          (0): Linear(in_features=512, out_features=512, bias=True)
          (1): ReLU()
          (2): Dropout(p=0.0, inplace=False)
          (3): Linear(in_features=512, out_features=512, bias=True)
          (4): ReLU()
          (5): Dropout(p=0.0, inplace=False)
          (6): Linear(in_features=512, out_features=2, bias=True)
          (7): Identity()
          (8): Dropout(p=0.0, inplace=False)
        )
        (postprocess_layer): ComputePartialCharges()
      )
    )
  )
)

We can look at the initial capabilities of the model by comparing its charges to AM1-BCC charges. They’re pretty bad.

test_molecule = Molecule.from_smiles("CCCBr")
test_molecule.assign_partial_charges("am1bcc")
reference_charges = test_molecule.partial_charges.m

# switch to eval mode
training_model.model.eval()

with torch.no_grad():
    nagl_charges_1 = training_model.model.compute_properties(
        test_molecule,
        as_numpy=True
    )["charges"]

# switch back to training mode
training_model.model.train()

# compare charges
differences = reference_charges - nagl_charges_1
differences
array([-0.42479179, -0.40645674, -0.54374945,  0.20729126, -0.06090091,
       -0.06090091, -0.06090091,  0.24849565,  0.24849565,  0.42671014,
        0.42671014])

Training the model

We use PyTorch Lightning to train.

import pytorch_lightning as pl
from pytorch_lightning.callbacks import TQDMProgressBar
trainer = pl.Trainer(
    max_epochs=100,
    callbacks=[TQDMProgressBar()], # add progress bar
    accelerator="cpu"
)
INFO:pytorch_lightning.utilities.rank_zero:GPU available: False, used: False
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
datamodule = training_model.create_data_module(verbose=False)
trainer.fit(
    training_model,
    datamodule=datamodule,
)
Featurizing batch: 100%|██████████| 40/40 [00:00<00:00, 44.31it/s]
Featurizing dataset: 1it [00:00,  1.09it/s]
Featurizing batch: 100%|██████████| 10/10 [00:00<00:00, 45.35it/s]
Featurizing dataset: 1it [00:00,  4.45it/s]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name  | Type     | Params | Mode 
-------------------------------------------
0 | model | GNNModel | 3.2 M  | train
-------------------------------------------
3.2 M     Trainable params
0         Non-trainable params
3.2 M     Total params
12.685    Total estimated model params size (MB)
47        Modules in train mode
0         Modules in eval mode
/home/runner/micromamba/envs/openff-docs-examples/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.
/home/runner/micromamba/envs/openff-docs-examples/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.
/home/runner/micromamba/envs/openff-docs-examples/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py:298: The number of training batches (1) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=100` reached.

We can now check the charges again. They should have improved, especially if you used a larger dataset (results may vary with the small 50-record subset this notebook uses for the sake of speed).

# switch to eval mode
training_model.model.eval()

with torch.no_grad():
    nagl_charges_2 = training_model.model.compute_properties(
        test_molecule,
        as_numpy=True
    )["charges"]

differences_after_training = reference_charges - nagl_charges_2
differences_after_training
array([ 0.00945532, -0.02115238,  0.00968652, -0.07503167, -0.0402597 ,
       -0.0402597 , -0.0402597 ,  0.01262529,  0.01262529,  0.08628529,
        0.08628529])
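
Finally, a single-number summary of the improvement: the RMSE of each difference array against the AM1-BCC reference charges. If you want to keep the trained weights around, trainer.save_checkpoint is the standard PyTorch Lightning mechanism; the checkpoint file name below is arbitrary.

# summarize before/after RMSE against the AM1-BCC reference charges
rmse_before = np.sqrt(np.mean(differences ** 2))
rmse_after = np.sqrt(np.mean(differences_after_training ** 2))
print(f"RMSE vs AM1-BCC before training: {rmse_before:.3f} e")
print(f"RMSE vs AM1-BCC after training:  {rmse_after:.3f} e")

# optionally persist the trained weights as a Lightning checkpoint
trainer.save_checkpoint("trained_electric_field_gnn.ckpt")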