import abc
import json
from collections import defaultdict
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import httpx
from openff.fragmenter.fragment import Fragment, FragmentationResult
from openff.toolkit.typing.engines.smirnoff import (
AngleHandler,
BondHandler,
ImproperTorsionHandler,
ParameterType,
ProperTorsionHandler,
vdWHandler,
)
from qcelemental.models import AtomicResult, OptimizationResult
from qcelemental.util import serialize
from qcengine.procedures.torsiondrive import TorsionDriveResult
from typing_extensions import Literal
from openff.bespokefit._pydantic import BaseModel, Field
from openff.bespokefit.executor.services import current_settings
from openff.bespokefit.executor.services.coordinator.utils import get_cached_parameters
from openff.bespokefit.executor.services.fragmenter.models import (
FragmenterGETResponse,
FragmenterPOSTBody,
FragmenterPOSTResponse,
)
from openff.bespokefit.executor.services.optimizer.models import (
OptimizerGETResponse,
OptimizerPOSTBody,
OptimizerPOSTResponse,
)
from openff.bespokefit.executor.services.qcgenerator.models import (
QCGeneratorGETPageResponse,
QCGeneratorPOSTBody,
QCGeneratorPOSTResponse,
)
from openff.bespokefit.executor.utilities.redis import (
connect_to_default_redis,
is_redis_available,
)
from openff.bespokefit.executor.utilities.typing import Status
from openff.bespokefit.schema.data import BespokeQCData, LocalQCData
from openff.bespokefit.schema.fitting import BespokeOptimizationSchema
from openff.bespokefit.schema.results import BespokeOptimizationResults
from openff.bespokefit.schema.smirnoff import (
AngleSMIRKS,
BondSMIRKS,
ImproperTorsionSMIRKS,
ProperTorsionSMIRKS,
VdWSMIRKS,
)
from openff.bespokefit.schema.targets import TargetSchema
from openff.bespokefit.schema.tasks import Torsion1DTask
from openff.bespokefit.utilities.smirks import (
ForceFieldEditor,
SMIRKSGenerator,
SMIRKSType,
get_cached_torsion_parameters,
)
if TYPE_CHECKING:
from openff.bespokefit.executor.services.coordinator.models import CoordinatorTask
class _Stage(BaseModel, abc.ABC):
type: Literal["base-stage"] = "base-stage"
status: Status = Field("waiting", description="The status of this stage.")
error: Optional[str] = Field(
None, description="The error raised, if any, while running this stage."
)
async def enter(self, task: "CoordinatorTask"):
try:
return await self._enter(task)
except BaseException as e: # lgtm [py/catch-base-exception]
self.status = "errored"
self.error = json.dumps(f"{e.__class__.__name__}: {str(e)}")
async def update(self):
try:
return await self._update()
except BaseException as e: # lgtm [py/catch-base-exception]
self.status = "errored"
self.error = json.dumps(f"{e.__class__.__name__}: {str(e)}")
@abc.abstractmethod
async def _enter(self, task: "CoordinatorTask"):
pass
@abc.abstractmethod
async def _update(self):
pass
[docs]class FragmentationStage(_Stage):
type: Literal["fragmentation"] = "fragmentation"
id: Optional[str] = Field(None, description="")
result: Optional[FragmentationResult] = Field(None, description="")
async def _enter(self, task: "CoordinatorTask"):
settings = current_settings()
async with httpx.AsyncClient() as client:
raw_response = await client.post(
f"http://127.0.0.1:"
f"{settings.BEFLOW_GATEWAY_PORT}"
f"{settings.BEFLOW_API_V1_STR}/"
f"{settings.BEFLOW_FRAGMENTER_PREFIX}",
data=FragmenterPOSTBody(
cmiles=task.input_schema.smiles,
fragmenter=task.input_schema.fragmentation_engine,
target_bond_smarts=task.input_schema.target_torsion_smirks,
).json(),
)
if raw_response.status_code != 200:
self.error = json.dumps(raw_response.text)
self.status = "errored"
return
contents = raw_response.text
post_response = FragmenterPOSTResponse.parse_raw(contents)
self.id = post_response.id
async def _update(self):
if self.status == "errored":
return
settings = current_settings()
async with httpx.AsyncClient() as client:
raw_response = await client.get(
f"http://127.0.0.1:"
f"{settings.BEFLOW_GATEWAY_PORT}"
f"{settings.BEFLOW_API_V1_STR}/"
f"{settings.BEFLOW_FRAGMENTER_PREFIX}/{self.id}"
)
if raw_response.status_code != 200:
self.error = json.dumps(raw_response.text)
self.status = "errored"
return
contents = raw_response.text
get_response = FragmenterGETResponse.parse_raw(contents)
self.result = get_response.result
if (
isinstance(self.result, FragmentationResult)
and len(self.result.fragments) == 0
):
self.error = json.dumps(
"No fragments could be generated for the parent molecule. This likely "
"means that the bespoke parameters that you have generated and are "
"trying to fit are invalid. Please raise an issue on the GitHub issue "
"tracker for further assistance."
)
self.status = "errored"
else:
self.error = get_response.error
self.status = get_response.status
[docs]class QCGenerationStage(_Stage):
type: Literal["qc-generation"] = "qc-generation"
ids: Optional[Dict[int, List[str]]] = Field(None, description="")
results: Optional[
Dict[str, Union[AtomicResult, OptimizationResult, TorsionDriveResult]]
] = Field(None, description="")
@staticmethod
def _generate_torsion_parameters(
fragmentation_result: FragmentationResult,
input_schema: BespokeOptimizationSchema,
) -> Tuple[List[ParameterType], List[Fragment]]:
"""
Generate torsion parameters for the fragments using any possible cached parameters.
Args:
fragmentation_result: The result of the fragmentation
input_schema: The input schema detailing the optimisation
Returns:
The list of generated smirks patterns including any cached values, and a list of fragments which require torsiondrives.
"""
settings = current_settings()
cached_torsions = None
if is_redis_available(
host=settings.BEFLOW_REDIS_ADDRESS, port=settings.BEFLOW_REDIS_PORT
):
redis_connection = connect_to_default_redis()
cached_force_field = get_cached_parameters(
fitting_schema=input_schema, redis_connection=redis_connection
)
if cached_force_field is not None:
cached_torsions = cached_force_field["ProperTorsions"].parameters
parent = fragmentation_result.parent_molecule
smirks_gen = SMIRKSGenerator(
initial_force_field=input_schema.initial_force_field,
generate_bespoke_terms=input_schema.smirk_settings.generate_bespoke_terms,
expand_torsion_terms=input_schema.smirk_settings.expand_torsion_terms,
target_smirks=[SMIRKSType.ProperTorsions],
)
new_smirks = []
fragments = []
for fragment_data in fragmentation_result.fragments:
central_bond = fragment_data.bond_indices
fragment_molecule = fragment_data.molecule
bespoke_smirks = smirks_gen.generate_smirks_from_fragment(
parent=parent,
fragment=fragment_molecule,
fragment_map_indices=central_bond,
)
if cached_torsions is not None:
smirks_to_add = []
for smirk in bespoke_smirks:
cached_smirk = get_cached_torsion_parameters(
molecule=fragment_molecule,
bespoke_parameter=smirk,
cached_parameters=cached_torsions,
)
if cached_smirk is not None:
smirks_to_add.append(cached_smirk)
if len(smirks_to_add) == len(bespoke_smirks):
# if we have the same number of parameters we are safe to transfer as they were fit together
new_smirks.extend(smirks_to_add)
else:
# the cached parameter was not fit with the correct number of other parameters
# so use new terms not cached
new_smirks.extend(bespoke_smirks)
fragments.append(fragment_data)
else:
new_smirks.extend(bespoke_smirks)
# only keep track of fragments which require QM calculations as they have no cached values
fragments.append(fragment_data)
return new_smirks, fragments
@staticmethod
async def _generate_parameters(
input_schema: BespokeOptimizationSchema,
fragmentation_result: Optional[FragmentationResult],
) -> List[Fragment]:
"""
Generate a list of parameters which are to be optimised, these are added to the input force field.
The parameters are also added to the parameter list in each stage corresponding to the stage where they will be fit.
"""
initial_force_field = ForceFieldEditor(input_schema.initial_force_field)
new_parameters = []
target_smirks = {*input_schema.target_smirks}
if SMIRKSType.ProperTorsions in target_smirks:
target_smirks.remove(SMIRKSType.ProperTorsions)
(
torsion_parameters,
fragment_jobs,
) = QCGenerationStage._generate_torsion_parameters(
fragmentation_result=fragmentation_result,
input_schema=input_schema,
)
new_parameters.extend(torsion_parameters)
else:
fragment_jobs = []
if len(target_smirks) > 0:
smirks_gen = SMIRKSGenerator(
initial_force_field=input_schema.initial_force_field,
generate_bespoke_terms=input_schema.smirk_settings.generate_bespoke_terms,
target_smirks=[*target_smirks],
smirks_layers=1,
)
parameters = smirks_gen.generate_smirks_from_molecule(
molecule=input_schema.molecule
)
new_parameters.extend(parameters)
# add all new terms to the input force field
initial_force_field.add_parameters(parameters=new_parameters)
parameter_to_type = {
vdWHandler.vdWType: VdWSMIRKS,
BondHandler.BondType: BondSMIRKS,
AngleHandler.AngleType: AngleSMIRKS,
ProperTorsionHandler.ProperTorsionType: ProperTorsionSMIRKS,
ImproperTorsionHandler.ImproperTorsionType: ImproperTorsionSMIRKS,
}
# convert all parameters to bespokefit types
parameters_to_fit = defaultdict(list)
for parameter in new_parameters:
bespoke_parameter = parameter_to_type[parameter.__class__].from_smirnoff(
parameter
)
# We only want to fit if it was not cached
if not bespoke_parameter.cached:
parameters_to_fit[bespoke_parameter.type].append(bespoke_parameter)
# set which parameters should be optimised in each stage
for stage in input_schema.stages:
for hyper_param in stage.parameter_hyperparameters:
stage.parameters.extend(parameters_to_fit[hyper_param.type])
input_schema.initial_force_field = initial_force_field.force_field.to_string()
return fragment_jobs
async def _enter(self, task: "CoordinatorTask"):
settings = current_settings()
fragment_stage = next(
iter(
stage
for stage in task.completed_stages
if stage.type == "fragmentation"
),
None,
)
input_schema = task.input_schema
# TODO: Move these methods onto the celery worker.
try:
fragments = await self._generate_parameters(
fragmentation_result=fragment_stage.result,
input_schema=input_schema,
)
except BaseException as e: # lgtm [py/catch-base-exception]
self.status = "errored"
self.error = json.dumps(
f"Failed to generate SMIRKS patterns that match both the parent and "
f"torsion fragments: {str(e)}"
)
return
target_qc_tasks = defaultdict(list)
targets = [
target for stage in task.input_schema.stages for target in stage.targets
]
for i, target in enumerate(targets):
if not isinstance(target.reference_data, BespokeQCData):
continue
if target.bespoke_task_type() == "torsion1d":
target_qc_tasks[i].extend(
Torsion1DTask(
smiles=fragment.smiles,
central_bond=fragment.bond_indices,
**target.calculation_specification.dict(),
)
for fragment in fragments
)
else:
raise NotImplementedError()
qc_calc_ids = defaultdict(set)
async with httpx.AsyncClient() as client:
for i, qc_tasks in target_qc_tasks.items():
for qc_task in qc_tasks:
raw_response = await client.post(
f"http://127.0.0.1:"
f"{settings.BEFLOW_GATEWAY_PORT}"
f"{settings.BEFLOW_API_V1_STR}/"
f"{settings.BEFLOW_QC_COMPUTE_PREFIX}",
data=QCGeneratorPOSTBody(input_schema=qc_task).json(),
)
if raw_response.status_code != 200:
self.error = json.dumps(raw_response.text)
self.status = "errored"
return
response = QCGeneratorPOSTResponse.parse_raw(raw_response.text)
qc_calc_ids[i].add(response.id)
self.ids = {i: sorted(ids) for i, ids in qc_calc_ids.items()}
async def _update(self):
settings = current_settings()
if self.status == "errored":
return
if (
len([qc_id for target_ids in self.ids.values() for qc_id in target_ids])
== 0
):
# Handle the case were there was no bespoke QC data to generate.
self.status = "success"
self.results = {}
return
async with httpx.AsyncClient() as client:
id_query = "&ids=".join(qc_id for i in self.ids for qc_id in self.ids[i])
raw_response = await client.get(
f"http://127.0.0.1:"
f"{settings.BEFLOW_GATEWAY_PORT}"
f"{settings.BEFLOW_API_V1_STR}/"
f"{settings.BEFLOW_QC_COMPUTE_PREFIX}?ids={id_query}"
)
contents = raw_response.text
if raw_response.status_code != 200:
self.error = json.dumps(raw_response.text)
self.status = "errored"
return
get_responses = QCGeneratorGETPageResponse.parse_raw(contents).contents
statuses = {get_response.status for get_response in get_responses}
errors = [
json.loads(get_response.error)
for get_response in get_responses
if get_response.error is not None
]
self.error = json.dumps(errors)
self.status = "running"
if "errored" in statuses:
self.status = "errored"
elif statuses == {"waiting"}:
self.status = "waiting"
elif statuses == {"success"}:
self.status = "success"
self.results = {
get_response.id: get_response.result for get_response in get_responses
}
[docs]class OptimizationStage(_Stage):
type: Literal["optimization"] = "optimization"
id: Optional[str] = Field(
None, description="The id of the optimization associated with this stage."
)
result: Optional[BespokeOptimizationResults] = Field(
None, description="The result of the optimization."
)
@staticmethod
async def _inject_bespoke_qc_data(
qc_generation_stage: QCGenerationStage,
input_schema: BespokeOptimizationSchema,
):
targets: List[TargetSchema] = [
target for stage in input_schema.stages for target in stage.targets
]
for i, target in enumerate(targets):
if not isinstance(target.reference_data, BespokeQCData):
continue
if i not in qc_generation_stage.ids:
continue
local_qc_data = LocalQCData(
qc_records=[
qc_generation_stage.results[result_id]
for result_id in qc_generation_stage.ids[i]
]
)
target.reference_data = local_qc_data
targets_missing_qc_data = [
target
for target in targets
if isinstance(target.reference_data, BespokeQCData)
]
n_targets_missing_qc_data = len(targets_missing_qc_data)
if n_targets_missing_qc_data > 0 and qc_generation_stage.results:
raise RuntimeError(
f"{n_targets_missing_qc_data} targets were missing QC data - this "
f"should likely never happen. Please raise an issue on the GitHub "
f"issue tracker."
)
async def _enter(self, task: "CoordinatorTask"):
settings = current_settings()
completed_stages = {stage.type: stage for stage in task.completed_stages}
input_schema = task.input_schema.copy(deep=True)
# Map the generated QC results into a local QC data class and update the schema
# to target these.
qc_generation_stage: QCGenerationStage = completed_stages["qc-generation"]
try:
await self._inject_bespoke_qc_data(qc_generation_stage, input_schema)
except BaseException as e: # lgtm [py/catch-base-exception]
self.status = "errored"
self.error = json.dumps(
f"Failed to inject the bespoke QC data into the optimization "
f"schema: {str(e)}"
)
return
async with httpx.AsyncClient() as client:
raw_response = await client.post(
f"http://127.0.0.1:"
f"{settings.BEFLOW_GATEWAY_PORT}"
f"{settings.BEFLOW_API_V1_STR}/"
f"{settings.BEFLOW_OPTIMIZER_PREFIX}",
data=serialize(
OptimizerPOSTBody(input_schema=input_schema), encoding="json"
),
)
if raw_response.status_code != 200:
self.error = json.dumps(raw_response.text)
self.status = "errored"
return
response = OptimizerPOSTResponse.parse_raw(raw_response.text)
self.id = response.id
async def _update(self):
settings = current_settings()
if self.status == "errored":
return
async with httpx.AsyncClient() as client:
raw_response = await client.get(
f"http://127.0.0.1:"
f"{settings.BEFLOW_GATEWAY_PORT}"
f"{settings.BEFLOW_API_V1_STR}/"
f"{settings.BEFLOW_OPTIMIZER_PREFIX}/{self.id}"
)
contents = raw_response.text
if raw_response.status_code != 200:
self.error = json.dumps(raw_response.text)
self.status = "errored"
return
get_response = OptimizerGETResponse.parse_raw(contents)
self.result = get_response.result
self.error = get_response.error
self.status = get_response.status
StageType = Union[FragmentationStage, QCGenerationStage, OptimizationStage]