Source code for openff.bespokefit.executor.services.qcgenerator.cache
import hashlib
from typing import TypeVar, Union
import redis
from openff.toolkit.topology import Molecule
from openff.bespokefit.executor.services.qcgenerator import worker
from openff.bespokefit.schema.tasks import HessianTask, OptimizationTask, Torsion1DTask
from openff.bespokefit.utilities.molecule import canonical_order_atoms
# Type variable constrained to the QC task schemas handled in this module; it lets
# ``_canonicalize_task`` be annotated as returning the same task type it was given.
_T = TypeVar("_T", HessianTask, OptimizationTask, Torsion1DTask)
def _canonicalize_task(task: _T) -> _T:
    """Return a deep copy of ``task`` whose SMILES has been canonicalized.

    The input task is never modified. The molecule described by ``task.smiles``
    is re-ordered with ``canonical_order_atoms`` and written back out as an
    isomeric SMILES with explicit hydrogens so that equivalent tasks serialize
    to the same text, which helps ensure cache hits.

    For a ``Torsion1DTask`` the emitted SMILES is mapped: only the two atoms of
    the central bond carry map labels (``1`` and ``2``, assigned in order of
    increasing atom index in the canonical molecule) and ``central_bond`` is
    reset to ``(1, 2)``. For every other task type the SMILES is unmapped.

    Args:
        task: The task to canonicalize.

    Returns:
        A canonicalized deep copy of ``task``.
    """
    canonical_task = task.copy(deep=True)

    molecule = canonical_order_atoms(
        Molecule.from_smiles(canonical_task.smiles, allow_undefined_stereo=True)
    )

    # Guard clause: anything that is not a torsion scan needs no atom map.
    if not isinstance(canonical_task, Torsion1DTask):
        canonical_task.smiles = molecule.to_smiles(
            isomeric=True, explicit_hydrogens=True, mapped=False
        )
        return canonical_task

    # Invert the molecule's atom map so that the labels stored in
    # ``central_bond`` can be looked up as atom indices.
    index_by_label = {
        label: index for index, label in molecule.properties["atom_map"].items()
    }
    bond = canonical_task.central_bond
    lower_index, upper_index = sorted(
        (index_by_label[bond[0]], index_by_label[bond[1]])
    )

    # Discard all other labels and tag only the central bond atoms as 1 and 2.
    molecule.properties["atom_map"] = {lower_index: 1, upper_index: 2}

    mapped_smiles = molecule.to_smiles(
        isomeric=True, explicit_hydrogens=True, mapped=True
    )

    canonical_task.central_bond = (1, 2)
    canonical_task.smiles = mapped_smiles

    return canonical_task
def cached_compute_task(
    task: Union[HessianTask, OptimizationTask, Torsion1DTask],
    redis_connection: redis.Redis,
) -> str:
    """Checks to see if a QC task has already been executed and if not send it to a
    worker.

    The task is canonicalized and its JSON representation is hashed with
    SHA-512. If the ``"qcgenerator:task-ids"`` redis hash already holds an id
    for that digest, that id is returned and nothing is submitted. Otherwise
    the task is submitted through the matching ``worker`` compute function, and
    both its type and the digest -> id mapping are recorded in redis.

    Args:
        task: The QC task to compute.
        redis_connection: The redis connection used to look up and store task
            ids.

    Returns:
        The id of the previously submitted or newly submitted task.

    Raises:
        NotImplementedError: If ``task`` is not one of the supported task types.
    """
    if isinstance(task, Torsion1DTask):
        compute = worker.compute_torsion_drive
    elif isinstance(task, OptimizationTask):
        compute = worker.compute_optimization
    elif isinstance(task, HessianTask):
        compute = worker.compute_hessian
    else:
        raise NotImplementedError()

    # Canonicalize the task to improve the cache hit rate.
    task = _canonicalize_task(task)

    # Serialize once so the hashed text and the submitted text cannot differ.
    task_json = task.json()
    task_hash = hashlib.sha512(task_json.encode()).hexdigest()

    task_id = redis_connection.hget("qcgenerator:task-ids", task_hash)

    if task_id is not None:
        # ``hget`` yields ``bytes`` by default but ``str`` when the client was
        # created with ``decode_responses=True``; only decode when required.
        return task_id.decode() if isinstance(task_id, bytes) else task_id

    task_id = compute.delay(task_json=task_json).id

    redis_connection.hset("qcgenerator:types", task_id, task.type)
    # Make sure to only set the hash after the type is set in case the connection
    # goes down before this information is entered and subsequently discarded.
    redis_connection.hset("qcgenerator:task-ids", task_hash, task_id)

    return task_id