Source code for openff.qcsubmit.utils.utils

import logging
import os
from contextlib import contextmanager
from typing import (
    Callable,
    Dict,
    Generator,
    Iterable,
    List,
    Optional,
    Sequence,
    Tuple,
    Union,
)

from openff.toolkit import topology as off
from openff.toolkit.utils.toolkits import (
    RDKitToolkitWrapper,
    UndefinedStereochemistryError,
)
from qcportal import PortalClient
from qcportal.cache import RecordCache, get_records_with_cache
from qcportal.optimization.record_models import OptimizationRecord
from qcportal.singlepoint.record_models import SinglepointRecord
from qcportal.torsiondrive.record_models import TorsiondriveRecord

logger = logging.getLogger(__name__)


class _CachedPortalClient(PortalClient):
    """A cached version of a `qcportal.PortalClient <https://molssi.github.io/QCFractal/user_guide/qcportal_reference.html#qcportal.client.PortalClient>`_.

    The ``get_optimizations``, ``get_singlepoints`` and ``get_torsiondrives``
    methods are overridden to go through ``get_records_with_cache`` using a
    ``RecordCache`` stored in ``cache.sqlite`` inside the client's cache
    directory.
    """

    def __init__(
        self,
        address: str,
        cache_dir: str,
        username: Optional[str] = None,
        password: Optional[str] = None,
        verify: bool = True,
        show_motd: bool = True,
        *,
        cache_max_size: int = 0,
        memory_cache_key: Optional[str] = None,
    ):
        """Parameters
        ----------
        address
            The host or IP address of the FractalServer instance, including
            protocol and port if necessary
            ("https://ml.qcarchive.molssi.org", "http://192.168.1.10:8888")
        cache_dir
            Directory to store an internal cache of records and other data.
            Unlike a normal ``PortalClient``, this argument is required.
        username
            The username to authenticate with.
        password
            The password to authenticate with.
        verify
            Verifies the SSL connection with a third party server. This may
            be False if a FractalServer was not provided an SSL certificate
            and defaults back to self-signed SSL keys.
        show_motd
            If a Message-of-the-Day is available, display it
        cache_max_size
            Maximum size of the cache directory
        memory_cache_key
            Passed through unchanged to ``PortalClient.__init__``.
        """
        super().__init__(
            address,
            username=username,
            password=password,
            verify=verify,
            show_motd=show_motd,
            cache_dir=cache_dir,
            cache_max_size=cache_max_size,
            memory_cache_key=memory_cache_key,
        )
        # The record cache lives next to whatever else the parent client keeps
        # in its cache directory and is opened writable so fetched records can
        # be stored.
        self.record_cache = RecordCache(
            os.path.join(self.cache.cache_dir, "cache.sqlite"), read_only=False
        )

    def __repr__(self) -> str:
        """A short representation of the current PortalClient.

        Returns
        -------
        str
            The desired representation.
        """
        return (
            f"CachedPortalClient(server_name='{self.server_name}', "
            f"address='{self.address}', username='{self.username}', "
            f"cache_dir='{self.cache.cache_dir}')"
        )

    def _cached_record_lookup(
        self,
        record_type,
        record_ids: Union[int, Sequence[int]],
        missing_ok: bool,
        include: Optional[Iterable[str]],
    ):
        """Shared implementation of the ``get_*`` record accessors.

        Parameters
        ----------
        record_type
            The record class handed to ``get_records_with_cache``.
        record_ids
            Single ID or sequence/list of records to obtain
        missing_ok
            Ignored apart from emitting a warning when it is True.
        include
            Additional fields to include in the returned record

        Returns
        -------
        :
            The single record if ``record_ids`` was not a ``Sequence``,
            otherwise the list returned by ``get_records_with_cache``.
        """
        if missing_ok:
            logger.warning(
                "missing_ok was set to True, but CachedPortalClient"
                " doesn't actually support this so it's being set to False"
            )

        # A bare ID is wrapped in a list for the lookup and unwrapped again
        # before returning.
        if unpack := not isinstance(record_ids, Sequence):
            record_ids = [record_ids]

        res = get_records_with_cache(
            client=self,
            record_cache=self.record_cache,
            record_type=record_type,
            record_ids=record_ids,
            include=include,
            force_fetch=False,
        )

        return res[0] if unpack else res

    def get_optimizations(
        self,
        record_ids: Union[int, Sequence[int]],
        missing_ok: bool = False,
        *,
        include: Optional[Iterable[str]] = None,
    ) -> Union[Optional[OptimizationRecord], List[Optional[OptimizationRecord]]]:
        """Obtain optimization records with the specified IDs.

        Records will be returned in the same order as the record ids.

        Parameters
        ----------
        record_ids
            Single ID or sequence/list of records to obtain
        missing_ok
            Unlike a ``PortalClient``, this argument is ignored. If set to
            True, a warning will be printed. Any missing records will cause a
            ``RuntimeError`` to be raised.
        include
            Additional fields to include in the returned record

        Returns
        -------
        :
            If a single ID was specified, returns just that record. Otherwise,
            returns a list of records.
        """
        return self._cached_record_lookup(
            OptimizationRecord, record_ids, missing_ok, include
        )

    def get_singlepoints(
        self,
        record_ids: Union[int, Sequence[int]],
        missing_ok: bool = False,
        *,
        include: Optional[Iterable[str]] = None,
    ) -> Union[Optional[SinglepointRecord], List[Optional[SinglepointRecord]]]:
        """Obtain singlepoint records with the specified IDs.

        Records will be returned in the same order as the record ids.

        Parameters
        ----------
        record_ids
            Single ID or sequence/list of records to obtain
        missing_ok
            Unlike a ``PortalClient``, this argument is ignored. If set to
            True, a warning will be printed. Any missing records will cause a
            ``RuntimeError`` to be raised.
        include
            Additional fields to include in the returned record

        Returns
        -------
        :
            If a single ID was specified, returns just that record. Otherwise,
            returns a list of records.
        """
        return self._cached_record_lookup(
            SinglepointRecord, record_ids, missing_ok, include
        )

    def get_torsiondrives(
        self,
        record_ids: Union[int, Sequence[int]],
        missing_ok: bool = False,
        *,
        include: Optional[Iterable[str]] = None,
    ) -> Union[Optional[TorsiondriveRecord], List[Optional[TorsiondriveRecord]]]:
        """Obtain torsiondrive records with the specified IDs.

        Records will be returned in the same order as the record ids.

        Parameters
        ----------
        record_ids
            Single ID or sequence/list of records to obtain
        missing_ok
            Unlike a ``PortalClient``, this argument is ignored. If set to
            True, a warning will be printed. Any missing records will cause a
            ``RuntimeError`` to be raised.
        include
            Additional fields to include in the returned record

        Returns
        -------
        :
            If a single ID was specified, returns just that record. Otherwise,
            returns a list of records.
        """
        return self._cached_record_lookup(
            TorsiondriveRecord, record_ids, missing_ok, include
        )

    @contextmanager
    def _no_session(self):
        """This is a context manager to prevent the ``_CachedPortalClient``
        from accessing the internet.

        ``PortalClient`` creates a ``requests.Session`` on initialization that
        can be reused without accessing ``socket.socket`` again, so this is a
        more reliable way to ensure that the local cache is used than
        overriding ``socket.socket``. Attempting to access the session will
        raise an ``AttributeError`` with the message ``'NoneType' object has no
        attribute 'prepare_request'``.
        """
        saved_session = self._req_session
        self._req_session = None
        try:
            yield
        finally:
            # Always put the real session back, even if the body raised.
            self._req_session = saved_session
def _default_portal_client(client_address) -> PortalClient:
    """Build a plain ``PortalClient`` for the given server address.

    This module-level name is what ``portal_client_manager`` temporarily
    rebinds to a user-supplied factory.

    Parameters
    ----------
    client_address
        The address passed as the only argument to ``PortalClient``.

    Returns
    -------
    PortalClient
        A client constructed with no additional keyword arguments.
    """
    client = PortalClient(client_address)
    return client
@contextmanager
def portal_client_manager(portal_client_fn: Callable[[str], PortalClient]):
    """A context manager that temporarily changes the default
    ``qcportal.PortalClient`` constructor used internally in functions like
    ``BasicResultCollection.to_records`` and many of the ``ResultFilter``
    classes.

    This can be especially useful if you need to provide additional keyword
    arguments to the ``PortalClient``, such as ``verify=False`` or a
    ``cache_dir``.

    .. warning::
        It is not safe to share the same client across threads or to construct
        multiple clients accessing the same cache database.

    Parameters
    ----------
    portal_client_fn:
        A function returning a PortalClient

    Examples
    --------
    Assuming you already have a dataset defined as ``ds``, call ``to_records``
    and use an existing cache in the current working directory if present or
    create a new one automatically:

    >>> from openff.qcsubmit.utils import portal_client_manager
    >>> from qcportal import PortalClient
    >>> def my_portal_client(client_address):
    ...     return PortalClient(client_address, cache_dir=".")
    >>> with portal_client_manager(my_portal_client):
    ...     records_and_molecules = ds.to_records()
    """
    global _default_portal_client

    # Remember the factory currently in place and install the replacement.
    previous_client_fn, _default_portal_client = (
        _default_portal_client,
        portal_client_fn,
    )
    try:
        yield
    finally:
        # Restore the previous factory even if the managed block raised.
        _default_portal_client = previous_client_fn
def get_data(relative_path):
    """
    Get the file path to some data in the qcsubmit package.

    Parameters:
        relative_path: The relative path to the data, joined onto the
            package's ``data`` directory.

    Returns:
        The resolved file path.

    Raises:
        ValueError: If the resolved path does not exist.
    """
    import os

    # NOTE(review): setuptools documents pkg_resources as deprecated in favour
    # of importlib.resources; kept here so the returned value is unchanged.
    from pkg_resources import resource_filename

    fn = resource_filename("openff.qcsubmit", os.path.join("data", relative_path))

    if not os.path.exists(fn):
        raise ValueError(
            f"Sorry! {fn} does not exist. If you just added it, you'll have to re-install"
        )

    return fn


def check_missing_stereo(molecule: off.Molecule) -> bool:
    """
    Get if the given molecule has missing stereo by round trip and catching
    stereo errors.

    Here we use the RDKit backend explicitly for this check as this avoids
    nitrogen stereochemistry issues with the toolkit.

    Parameters
    ----------
    molecule: off.Molecule
        The molecule which should be checked for stereo issues.

    Returns
    -------
    bool
        `True` if some stereochemistry is missing else `False`.
    """
    try:
        # Only the side effect matters: the parse raises when stereo is undefined.
        off.Molecule.from_smiles(
            smiles=molecule.to_smiles(isomeric=True, explicit_hydrogens=True),
            hydrogens_are_explicit=True,
            allow_undefined_stereo=False,
            toolkit_registry=RDKitToolkitWrapper(),
        )
    except UndefinedStereochemistryError:
        return True
    return False


def clean_strings(string_list: List[str]) -> List[str]:
    """
    Clean up a list of strings ready to be cast to numbers.

    Each string has surrounding whitespace stripped first and then any leading
    or trailing commas removed.
    """
    return [string.strip().strip(",") for string in string_list]


def remap_list(target_list: List[int], mapping: Dict[int, int]) -> List[int]:
    """
    Take a list of atom indices and remap them using the given mapping.

    Raises a ``KeyError`` if an index in ``target_list`` is not in ``mapping``.
    """
    return [mapping[x] for x in target_list]


def condense_molecules(molecules: List[off.Molecule]) -> off.Molecule:
    """
    Take a list of identical molecules in different conformers and collapse
    them making sure that they are in the same order.

    The last molecule in the list is used as the reference: every other
    molecule is remapped onto its atom ordering and its conformers are added to
    the reference, which is then returned. The input list itself is left
    unmodified (the reference molecule object does gain the extra conformers).

    Raises an ``IndexError`` if ``molecules`` is empty.
    """
    # Index instead of ``pop`` so the caller's list is not shortened.
    molecule = molecules[-1]
    for conformer in molecules[:-1]:
        _, atom_map = off.Molecule.are_isomorphic(
            conformer, molecule, return_atom_map=True
        )
        mapped_mol = conformer.remap(atom_map)
        for geometry in mapped_mol.conformers:
            molecule.add_conformer(geometry)
    return molecule


def chunk_generator(iterable: List, chunk_size: int) -> Generator[List, None, None]:
    """
    Take a sliceable sequence and yield consecutive slices of the specified
    size; the final slice may be shorter.

    Parameters:
        iterable: An object supporting ``len`` and slicing, like a list
        chunk_size: The size of each chunk
    """
    for i in range(0, len(iterable), chunk_size):
        yield iterable[i : i + chunk_size]


def get_torsion(bond: off.Bond) -> Tuple[int, int, int, int]:
    """
    Create a torsion tuple which will be restrained in the torsiondrive.

    Parameters:
        bond: The central bond; its ``atom1`` and ``atom2`` form the middle of
            the torsion.

    Returns:
        The tuple of the four atom indices which should be restrained.

    Raises:
        KeyError: If either central atom has no bonded atom other than its
            partner in the bond.

    Note:
        If there is more than one possible combination of atoms the heaviest
        set are selected to be restrained. On a tie in atomic number the
        neighbour encountered first is kept.
    """
    atoms: List[off.Atom] = [bond.atom1, bond.atom2]
    terminal_atoms: Dict[off.Atom, off.Atom] = {}

    for atom in atoms:
        for neighbour in atom.bonded_atoms:
            # The other central atom can never be a terminal atom.
            if neighbour in atoms:
                continue
            current = terminal_atoms.get(atom)
            # Take the first candidate seen, then only replace it with a
            # strictly heavier neighbour.
            if current is None or neighbour.atomic_number > current.atomic_number:
                terminal_atoms[atom] = neighbour

    # build out the torsion
    return (
        terminal_atoms[atoms[0]].molecule_atom_index,
        atoms[0].molecule_atom_index,
        atoms[1].molecule_atom_index,
        terminal_atoms[atoms[1]].molecule_atom_index,
    )


def get_symmetry_classes(molecule: off.Molecule) -> List[int]:
    """Calculate the symmetry classes of each atom in the molecule using the
    backend toolkits.

    RDKit is tried first; if that raises an ``ImportError`` the OpenEye toolkit
    is used instead. The returned list holds one class per atom.
    """
    try:
        from rdkit import Chem

        rd_mol = molecule.to_rdkit()
        symmetry_classes = list(Chem.CanonicalRankAtoms(rd_mol, breakTies=False))

    # ModuleNotFoundError is a subclass of ImportError, so one name suffices.
    except ImportError:
        from openeye import oechem

        oe_mol = molecule.to_openeye()
        oechem.OEPerceiveSymmetry(oe_mol)

        symmetry_classes_by_index = {
            a.GetIdx(): a.GetSymmetryClass() for a in oe_mol.GetAtoms()
        }
        symmetry_classes = [
            symmetry_classes_by_index[i] for i in range(molecule.n_atoms)
        ]

    return symmetry_classes


def get_symmetry_group(
    atom_group: Tuple[int, ...], symmetry_classes: List[int]
) -> Tuple[int, ...]:
    """
    Look up the symmetry class of every atom index in ``atom_group`` and return
    them as a tuple in the same order.
    """
    return tuple(symmetry_classes[atom] for atom in atom_group)