# Source code for openff.evaluator.protocols.utils

"""
A set of utilities for setting up property estimation workflows.
"""
from dataclasses import astuple, dataclass, fields
from typing import Generic, Optional, Tuple, TypeVar

from openff.evaluator import unit
from openff.evaluator.attributes import PlaceholderValue
from openff.evaluator.datasets import PropertyPhase
from openff.evaluator.protocols import (
    analysis,
    coordinates,
    forcefield,
    gradients,
    groups,
    miscellaneous,
    openmm,
    reweighting,
    storage,
)
from openff.evaluator.protocols.groups import ConditionalGroup
from openff.evaluator.storage.data import StoredSimulationData
from openff.evaluator.thermodynamics import Ensemble
from openff.evaluator.utils.observables import ObservableType
from openff.evaluator.workflow import ProtocolGroup
from openff.evaluator.workflow.schemas import ProtocolReplicator
from openff.evaluator.workflow.utils import ProtocolPath, ReplicatorValue

# The type of protocol used to extract an average observable (and its
# statistical inefficiency) from simulation data.
S = TypeVar("S", bound=analysis.BaseAverageObservable)
# The type of MBAR protocol used to reweight an observable to a target state.
T = TypeVar("T", bound=reweighting.BaseMBARProtocol)


@dataclass
class SimulationProtocols(Generic[S]):
    """The common set of protocols which would be required to estimate an
    observable by running a new molecule simulation.

    Iterating over an instance yields the stored protocol objects themselves
    (not copies), in field declaration order.
    """

    build_coordinates: coordinates.BuildCoordinatesPackmol
    assign_parameters: forcefield.BaseBuildSystem
    energy_minimisation: openmm.OpenMMEnergyMinimisation
    equilibration_simulation: openmm.OpenMMSimulation
    production_simulation: openmm.OpenMMSimulation
    analysis_protocol: S
    converge_uncertainty: ProtocolGroup
    decorrelate_trajectory: analysis.DecorrelateTrajectory
    decorrelate_observables: analysis.DecorrelateObservables

    def __iter__(self):
        # ``dataclasses.astuple`` deep-copies every non-dataclass, non-container
        # field value, so iterating over it would hand out copies of the
        # protocols rather than the protocols themselves. Yield the stored
        # objects directly instead.
        for protocol_field in fields(self):
            yield getattr(self, protocol_field.name)
@dataclass
class ReweightingProtocols(Generic[S, T]):
    """The common set of protocols which would be required to re-weight an
    observable from cached simulation data.

    Iterating over an instance yields the stored protocol objects themselves
    (not copies), in field declaration order. ``zero_gradients`` may be
    ``None``, in which case ``None`` is yielded in its place.
    """

    unpack_stored_data: storage.UnpackStoredSimulationData
    join_trajectories: reweighting.ConcatenateTrajectories
    join_observables: reweighting.ConcatenateObservables
    build_reference_system: forcefield.BaseBuildSystem
    evaluate_reference_potential: reweighting.BaseEvaluateEnergies
    build_target_system: forcefield.BaseBuildSystem
    evaluate_target_potential: reweighting.BaseEvaluateEnergies
    statistical_inefficiency: S
    replicate_statistics: miscellaneous.DummyProtocol
    decorrelate_reference_potential: analysis.DecorrelateObservables
    decorrelate_target_potential: analysis.DecorrelateObservables
    decorrelate_observable: analysis.DecorrelateObservables
    zero_gradients: Optional[gradients.ZeroGradients]
    reweight_observable: T

    def __iter__(self):
        # ``dataclasses.astuple`` deep-copies every non-dataclass, non-container
        # field value, so iterating over it would hand out copies of the
        # protocols rather than the protocols themselves. Yield the stored
        # objects directly instead.
        for protocol_field in fields(self):
            yield getattr(self, protocol_field.name)
def generate_base_reweighting_protocols(
    statistical_inefficiency: S,
    reweight_observable: T,
    replicator_id: str = "data_replicator",
    id_suffix: str = "",
) -> Tuple[ReweightingProtocols[S, T], ProtocolReplicator]:
    """Constructs a set of protocols which, when combined in a workflow schema,
    may be executed to reweight a set of cached simulation data to estimate the
    average value of an observable.

    Parameters
    ----------
    statistical_inefficiency
        The protocol which will be used to compute the statistical inefficiency
        and equilibration time of the observable of interest. This information
        will be used to decorrelate the cached data prior to reweighting. It
        must inherit from ``BaseAverageObservable``, and its id must contain
        the placeholder id of the data replicator because it is replicated once
        per piece of cached simulation data.
    reweight_observable
        The MBAR reweighting protocol to use to reweight the observable to the
        target state. This method will automatically set the reference and
        target reduced potentials, the observable and the frame counts on the
        object. Its id must not contain the placeholder id of the data
        replicator.
    replicator_id: str
        The id to use for the cached data replicator.
    id_suffix: str
        A string suffix to append to each of the protocol ids.

    Returns
    -------
    The protocols to add to the workflow, and the replicator which will clone
    the workflow for each piece of cached simulation data.

    Raises
    ------
    NotImplementedError
        If ``statistical_inefficiency`` does not inherit from
        ``BaseAverageObservable``.
    ValueError
        If the id of ``statistical_inefficiency`` does not contain the
        replicator placeholder id, or the id of ``reweight_observable`` does.
    """
    # Create the replicator which will apply these protocols once for each
    # piece of cached simulation data.
    data_replicator = ProtocolReplicator(replicator_id=replicator_id)
    data_replicator.template_values = ProtocolPath("full_system_data", "global")

    # Validate the inputs. Explicit exceptions are raised rather than using
    # ``assert`` so that the checks are not stripped when running under -O.
    if not isinstance(statistical_inefficiency, analysis.BaseAverageObservable):
        raise NotImplementedError(
            "Only statistical inefficiency protocols which inherit from "
            "BaseAverageObservable are currently supported."
        )
    if data_replicator.placeholder_id not in statistical_inefficiency.id:
        raise ValueError(
            f"The id of the statistical inefficiency protocol "
            f"({statistical_inefficiency.id}) must contain the replicator "
            f"placeholder id ({data_replicator.placeholder_id})."
        )
    if data_replicator.placeholder_id in reweight_observable.id:
        raise ValueError(
            f"The id of the reweighting protocol ({reweight_observable.id}) must "
            f"not contain the replicator placeholder id "
            f"({data_replicator.placeholder_id})."
        )

    # Suffix for the ids of protocols which are cloned once per piece of data.
    replicator_suffix = f"_{data_replicator.placeholder_id}{id_suffix}"

    # Unpack all of the stored data.
    unpack_stored_data = storage.UnpackStoredSimulationData(
        f"unpack_data{replicator_suffix}"
    )
    unpack_stored_data.simulation_data_path = ReplicatorValue(replicator_id)

    # Join the individual trajectories together. These protocols are not
    # replicated, so they gather the outputs of every replicated unpacker.
    join_trajectories = reweighting.ConcatenateTrajectories(
        f"join_trajectories{id_suffix}"
    )
    join_trajectories.input_coordinate_paths = ProtocolPath(
        "coordinate_file_path", unpack_stored_data.id
    )
    join_trajectories.input_trajectory_paths = ProtocolPath(
        "trajectory_file_path", unpack_stored_data.id
    )
    join_observables = reweighting.ConcatenateObservables(
        f"join_observables{id_suffix}"
    )
    join_observables.input_observables = ProtocolPath(
        "observables", unpack_stored_data.id
    )

    # Calculate the reduced potentials for each of the reference states.
    build_reference_system = forcefield.BaseBuildSystem(
        f"build_system{replicator_suffix}"
    )
    build_reference_system.force_field_path = ProtocolPath(
        "force_field_path", unpack_stored_data.id
    )
    build_reference_system.coordinate_file_path = ProtocolPath(
        "coordinate_file_path", unpack_stored_data.id
    )
    build_reference_system.substance = ProtocolPath("substance", unpack_stored_data.id)

    reduced_reference_potential = openmm.OpenMMEvaluateEnergies(
        f"reduced_potential{replicator_suffix}"
    )
    reduced_reference_potential.parameterized_system = ProtocolPath(
        "parameterized_system", build_reference_system.id
    )
    reduced_reference_potential.thermodynamic_state = ProtocolPath(
        "thermodynamic_state", unpack_stored_data.id
    )
    reduced_reference_potential.coordinate_file_path = ProtocolPath(
        "coordinate_file_path", unpack_stored_data.id
    )
    reduced_reference_potential.trajectory_file_path = ProtocolPath(
        "output_trajectory_path", join_trajectories.id
    )

    # Calculate the reduced potential of the target state.
    build_target_system = forcefield.BaseBuildSystem(f"build_system_target{id_suffix}")
    build_target_system.force_field_path = ProtocolPath("force_field_path", "global")
    build_target_system.substance = ProtocolPath("substance", "global")
    build_target_system.coordinate_file_path = ProtocolPath(
        "output_coordinate_path", join_trajectories.id
    )

    reduced_target_potential = openmm.OpenMMEvaluateEnergies(
        f"reduced_potential_target{id_suffix}"
    )
    reduced_target_potential.thermodynamic_state = ProtocolPath(
        "thermodynamic_state", "global"
    )
    reduced_target_potential.parameterized_system = ProtocolPath(
        "parameterized_system", build_target_system.id
    )
    reduced_target_potential.coordinate_file_path = ProtocolPath(
        "output_coordinate_path", join_trajectories.id
    )
    reduced_target_potential.trajectory_file_path = ProtocolPath(
        "output_trajectory_path", join_trajectories.id
    )
    reduced_target_potential.gradient_parameters = ProtocolPath(
        "parameter_gradient_keys", "global"
    )

    # Compute the observable gradients.
    zero_gradients = gradients.ZeroGradients(f"zero_gradients{id_suffix}")
    zero_gradients.force_field_path = ProtocolPath("force_field_path", "global")
    zero_gradients.gradient_parameters = ProtocolPath(
        "parameter_gradient_keys", "global"
    )

    # Decorrelate the target potentials and observables.
    decorrelate_target_potential = analysis.DecorrelateObservables(
        f"decorrelate_target_potential{id_suffix}"
    )
    decorrelate_target_potential.time_series_statistics = ProtocolPath(
        "time_series_statistics", statistical_inefficiency.id
    )
    decorrelate_target_potential.input_observables = ProtocolPath(
        "output_observables", reduced_target_potential.id
    )

    decorrelate_observable = analysis.DecorrelateObservables(
        f"decorrelate_observable{id_suffix}"
    )
    decorrelate_observable.time_series_statistics = ProtocolPath(
        "time_series_statistics", statistical_inefficiency.id
    )
    decorrelate_observable.input_observables = ProtocolPath(
        "output_observables", zero_gradients.id
    )

    # Decorrelate the reference potentials. Due to a quirk of how workflow
    # replicators work, the time series statistics need to be passed via a dummy
    # protocol first.
    #
    # Because the `statistical_inefficiency` and `decorrelate_reference_potential`
    # protocols are replicated by the same replicator, the
    # `time_series_statistics` input of `decorrelate_reference_potential_X` would
    # take its value from the `time_series_statistics` output of
    # `statistical_inefficiency_X` rather than as a list of
    # [statistical_inefficiency_0.time_series_statistics...
    # statistical_inefficiency_N.time_series_statistics]. Passing the statistics
    # via an un-replicated intermediate resolves this.
    replicate_statistics = miscellaneous.DummyProtocol(
        f"replicated_statistics{id_suffix}"
    )
    replicate_statistics.input_value = ProtocolPath(
        "time_series_statistics", statistical_inefficiency.id
    )

    decorrelate_reference_potential = analysis.DecorrelateObservables(
        f"decorrelate_reference_potential{replicator_suffix}"
    )
    decorrelate_reference_potential.time_series_statistics = ProtocolPath(
        "output_value", replicate_statistics.id
    )
    decorrelate_reference_potential.input_observables = ProtocolPath(
        "output_observables", reduced_reference_potential.id
    )

    # Finally, apply MBAR to get the reweighted value.
    reweight_observable.reference_reduced_potentials = ProtocolPath(
        "output_observables[ReducedPotential]", decorrelate_reference_potential.id
    )
    reweight_observable.target_reduced_potentials = ProtocolPath(
        "output_observables[ReducedPotential]", decorrelate_target_potential.id
    )
    reweight_observable.observable = ProtocolPath(
        "output_observables", decorrelate_observable.id
    )
    reweight_observable.frame_counts = ProtocolPath(
        "time_series_statistics.n_uncorrelated_points", statistical_inefficiency.id
    )

    protocols = ReweightingProtocols(
        unpack_stored_data=unpack_stored_data,
        join_trajectories=join_trajectories,
        join_observables=join_observables,
        build_reference_system=build_reference_system,
        evaluate_reference_potential=reduced_reference_potential,
        build_target_system=build_target_system,
        evaluate_target_potential=reduced_target_potential,
        statistical_inefficiency=statistical_inefficiency,
        replicate_statistics=replicate_statistics,
        decorrelate_reference_potential=decorrelate_reference_potential,
        decorrelate_target_potential=decorrelate_target_potential,
        decorrelate_observable=decorrelate_observable,
        zero_gradients=zero_gradients,
        reweight_observable=reweight_observable,
    )

    return protocols, data_replicator
def generate_reweighting_protocols(
    observable_type: ObservableType,
    replicator_id: str = "data_replicator",
    id_suffix: str = "",
) -> Tuple[
    ReweightingProtocols[analysis.AverageObservable, reweighting.ReweightObservable],
    ProtocolReplicator,
]:
    """Constructs a set of protocols which, when combined in a workflow schema,
    may be executed to reweight a set of cached simulation data to estimate the
    average value of a single observable stored in that data.

    This wraps ``generate_base_reweighting_protocols``. It uses an
    ``AverageObservable`` protocol (with a single bootstrap iteration) to
    compute the statistical inefficiency of the observable, and a
    ``ReweightObservable`` protocol to perform the reweighting.

    Parameters
    ----------
    observable_type
        The type of observable to reweight. Kinetic energies, total energies
        and enthalpies are not supported.
    replicator_id
        The id to use for the cached data replicator.
    id_suffix
        A string suffix to append to each of the protocol ids.

    Returns
    -------
    The protocols to add to the workflow, and the replicator which will clone
    the workflow for each piece of cached simulation data.

    Raises
    ------
    ValueError
        If ``observable_type`` is one of the unsupported observable types.
    """
    # NOTE(review): these observable types are rejected here; presumably because
    # they include kinetic contributions which are not re-evaluated at the
    # target state - confirm.
    unsupported_types = (
        ObservableType.KineticEnergy,
        ObservableType.TotalEnergy,
        ObservableType.Enthalpy,
    )

    # Raise explicitly rather than ``assert`` so the check survives -O.
    if observable_type in unsupported_types:
        raise ValueError(
            f"Reweighting {observable_type.value} observables is not supported."
        )

    statistical_inefficiency = analysis.AverageObservable(
        f"observable_inefficiency_$({replicator_id}){id_suffix}"
    )
    statistical_inefficiency.bootstrap_iterations = 1

    reweight_observable = reweighting.ReweightObservable(
        f"reweight_observable{id_suffix}"
    )

    protocols, data_replicator = generate_base_reweighting_protocols(
        statistical_inefficiency, reweight_observable, replicator_id, id_suffix
    )
    protocols.statistical_inefficiency.observable = ProtocolPath(
        f"observables[{observable_type.value}]", protocols.unpack_stored_data.id
    )

    # The remaining supported energy-like observables are read from the
    # decorrelated outputs of the target state energy evaluation.
    energy_types = (ObservableType.PotentialEnergy, ObservableType.ReducedPotential)

    if observable_type not in energy_types:
        # Take the observable from the joined cached data and pass it through
        # the zero gradients protocol. The base protocols already point the
        # reweighting protocol at the (whole) decorrelated output of this path.
        protocols.zero_gradients.input_observables = ProtocolPath(
            f"output_observables[{observable_type.value}]",
            protocols.join_observables.id,
        )
    else:
        # The zero gradients protocol is not needed. Instead, select the
        # observable from the decorrelated target state outputs.
        protocols.zero_gradients = None
        protocols.decorrelate_observable = protocols.decorrelate_target_potential
        protocols.reweight_observable.observable = ProtocolPath(
            f"output_observables[{observable_type.value}]",
            protocols.decorrelate_observable.id,
        )

    return protocols, data_replicator
def generate_simulation_protocols(
    analysis_protocol: S,
    use_target_uncertainty: bool,
    id_suffix: str = "",
    conditional_group: Optional[ConditionalGroup] = None,
    n_molecules: int = 1000,
) -> Tuple[SimulationProtocols[S], ProtocolPath, StoredSimulationData]:
    """Constructs a set of protocols which, when combined in a workflow schema,
    may be executed to run a single simulation to estimate the average value of
    an observable.

    The protocols returned will:

        1) Build a set of liquid coordinates for the property substance using
           packmol.

        2) Assign a set of force field parameters to the system.

        3) Perform an energy minimisation on the system.

        4) Run a short NPT equilibration simulation for 100000 steps using a
           timestep of 2fs.

        5) Within a conditional group (by default up to a maximum of 100 times):

           5a) Run a longer NPT production simulation for 1000000 steps per
               iteration using a timestep of 2fs, extending the simulation on
               each iteration of the group.

           5b) Extract the average value of an observable and its uncertainty.

           5c) If ``use_target_uncertainty`` is true and the default group is
               used, check if the target uncertainty has been met. If not,
               repeat steps 5a), 5b) and 5c).

        6) Extract uncorrelated configurations from a generated production
           simulation.

        7) Extract uncorrelated statistics from a generated production
           simulation.

    Parameters
    ----------
    analysis_protocol
        The protocol which will extract the observable of interest from the
        generated simulation data. It must inherit from
        ``BaseAverageObservable``. Its ``thermodynamic_state`` and
        ``potential_energies`` inputs are set by this function.
    use_target_uncertainty
        Whether to run the simulation until the observable is estimated to
        within the target uncertainty. Only used when ``conditional_group`` is
        ``None``.
    id_suffix: str
        A string suffix to append to each of the protocol ids.
    conditional_group: ConditionalGroup, optional
        A custom group to wrap the main simulation / extraction protocols
        within. It is up to the caller of this method to manually add the
        convergence conditions to this group. If `None`, a default group
        (with at most 100 iterations and, when ``use_target_uncertainty`` is
        true, an uncertainty convergence condition) is automatically
        constructed.
    n_molecules: int
        The maximum number of molecules to build the coordinates with.

    Returns
    -------
    The protocols to add to the workflow, a reference (``ProtocolPath``) to
    the average value of the estimated observable, and an object which
    describes the default data from a simulation to store, such as the
    uncorrelated statistics and configurations.

    Raises
    ------
    ValueError
        If ``analysis_protocol`` does not inherit from ``BaseAverageObservable``.
    """
    # Validate up-front so that nothing (in particular a caller supplied
    # ``conditional_group``) is modified when the input is invalid.
    if not isinstance(analysis_protocol, analysis.BaseAverageObservable):
        raise ValueError(
            "The analysis protocol must inherit from the BaseAverageObservable "
            "protocol."
        )

    build_coordinates = coordinates.BuildCoordinatesPackmol(
        f"build_coordinates{id_suffix}"
    )
    build_coordinates.substance = ProtocolPath("substance", "global")
    build_coordinates.max_molecules = n_molecules

    assign_parameters = forcefield.BaseBuildSystem(f"assign_parameters{id_suffix}")
    assign_parameters.force_field_path = ProtocolPath("force_field_path", "global")
    assign_parameters.coordinate_file_path = ProtocolPath(
        "coordinate_file_path", build_coordinates.id
    )
    assign_parameters.substance = ProtocolPath("output_substance", build_coordinates.id)

    # Equilibration
    energy_minimisation = openmm.OpenMMEnergyMinimisation(
        f"energy_minimisation{id_suffix}"
    )
    energy_minimisation.input_coordinate_file = ProtocolPath(
        "coordinate_file_path", build_coordinates.id
    )
    energy_minimisation.parameterized_system = ProtocolPath(
        "parameterized_system", assign_parameters.id
    )

    equilibration_simulation = openmm.OpenMMSimulation(
        f"equilibration_simulation{id_suffix}"
    )
    equilibration_simulation.ensemble = Ensemble.NPT
    equilibration_simulation.steps_per_iteration = 100000
    equilibration_simulation.output_frequency = 5000
    equilibration_simulation.timestep = 2.0 * unit.femtosecond
    equilibration_simulation.thermodynamic_state = ProtocolPath(
        "thermodynamic_state", "global"
    )
    equilibration_simulation.input_coordinate_file = ProtocolPath(
        "output_coordinate_file", energy_minimisation.id
    )
    equilibration_simulation.parameterized_system = ProtocolPath(
        "parameterized_system", assign_parameters.id
    )

    # Production
    production_simulation = openmm.OpenMMSimulation(f"production_simulation{id_suffix}")
    production_simulation.ensemble = Ensemble.NPT
    production_simulation.steps_per_iteration = 1000000
    production_simulation.output_frequency = 2000
    production_simulation.timestep = 2.0 * unit.femtosecond
    production_simulation.thermodynamic_state = ProtocolPath(
        "thermodynamic_state", "global"
    )
    production_simulation.input_coordinate_file = ProtocolPath(
        "output_coordinate_file", equilibration_simulation.id
    )
    production_simulation.parameterized_system = ProtocolPath(
        "parameterized_system", assign_parameters.id
    )
    production_simulation.gradient_parameters = ProtocolPath(
        "parameter_gradient_keys", "global"
    )

    # Set up a conditional group to ensure convergence of uncertainty
    if conditional_group is None:
        conditional_group = groups.ConditionalGroup(f"conditional_group{id_suffix}")
        conditional_group.max_iterations = 100

        if use_target_uncertainty:
            condition = groups.ConditionalGroup.Condition()
            condition.right_hand_value = ProtocolPath("target_uncertainty", "global")
            condition.type = groups.ConditionalGroup.Condition.Type.LessThan
            condition.left_hand_value = ProtocolPath(
                "value.error", conditional_group.id, analysis_protocol.id
            )

            conditional_group.add_condition(condition)

    # Make sure the simulation gets extended after each iteration. This is
    # wired up for caller supplied groups too, since any looping group needs
    # the production simulation to grow between iterations.
    production_simulation.total_number_of_iterations = ProtocolPath(
        "current_iteration", conditional_group.id
    )

    conditional_group.add_protocols(production_simulation, analysis_protocol)

    # Point the analysis protocol to the correct data sources.
    analysis_protocol.thermodynamic_state = ProtocolPath(
        "thermodynamic_state", "global"
    )
    analysis_protocol.potential_energies = ProtocolPath(
        f"observables[{ObservableType.PotentialEnergy.value}]",
        production_simulation.id,
    )

    # Finally, extract uncorrelated data
    time_series_statistics = ProtocolPath(
        "time_series_statistics", conditional_group.id, analysis_protocol.id
    )
    coordinate_file = ProtocolPath(
        "output_coordinate_file", conditional_group.id, production_simulation.id
    )
    trajectory_path = ProtocolPath(
        "trajectory_file_path", conditional_group.id, production_simulation.id
    )
    observables = ProtocolPath(
        "observables", conditional_group.id, production_simulation.id
    )

    decorrelate_trajectory = analysis.DecorrelateTrajectory(
        f"decorrelate_trajectory{id_suffix}"
    )
    decorrelate_trajectory.time_series_statistics = time_series_statistics
    decorrelate_trajectory.input_coordinate_file = coordinate_file
    decorrelate_trajectory.input_trajectory_path = trajectory_path

    decorrelate_observables = analysis.DecorrelateObservables(
        f"decorrelate_observables{id_suffix}"
    )
    decorrelate_observables.time_series_statistics = time_series_statistics
    decorrelate_observables.input_observables = observables

    # Build the object which defines which pieces of simulation data to store.
    output_to_store = StoredSimulationData()

    output_to_store.thermodynamic_state = ProtocolPath("thermodynamic_state", "global")
    output_to_store.property_phase = PropertyPhase.Liquid

    output_to_store.force_field_id = PlaceholderValue()

    output_to_store.number_of_molecules = ProtocolPath(
        "output_number_of_molecules", build_coordinates.id
    )
    output_to_store.substance = ProtocolPath("output_substance", build_coordinates.id)
    output_to_store.statistical_inefficiency = ProtocolPath(
        "time_series_statistics.statistical_inefficiency",
        conditional_group.id,
        analysis_protocol.id,
    )
    output_to_store.observables = ProtocolPath(
        "output_observables", decorrelate_observables.id
    )
    output_to_store.trajectory_file_name = ProtocolPath(
        "output_trajectory_path", decorrelate_trajectory.id
    )
    output_to_store.coordinate_file_name = coordinate_file

    output_to_store.source_calculation_id = PlaceholderValue()

    # Define where the final values come from.
    final_value_source = ProtocolPath(
        "value", conditional_group.id, analysis_protocol.id
    )

    base_protocols = SimulationProtocols(
        build_coordinates,
        assign_parameters,
        energy_minimisation,
        equilibration_simulation,
        production_simulation,
        analysis_protocol,
        conditional_group,
        decorrelate_trajectory,
        decorrelate_observables,
    )

    return base_protocols, final_value_source, output_to_store