"""Utilities for storing data in a SQLite database"""
import abc
import math
from typing import TypeVar
from sqlalchemy import (
Boolean,
Column,
ForeignKey,
Integer,
PickleType,
String,
UniqueConstraint,
)
from sqlalchemy.orm import Query, Session, relationship, declarative_base
from openff.recharge.esp import ESPSettings, PCMSettings
from openff.recharge.grids import GridSettingsType, LatticeGridSettings, MSKGridSettings
# Declarative base class shared by every ORM model defined in this module.
DBBase = declarative_base()
# Placeholder types used in the `_UniqueMixin` signatures: the in-memory object
# being stored, and the database record which represents it, respectively.
_InstanceType = TypeVar("_InstanceType")
_DBInstanceType = TypeVar("_DBInstanceType")
# NOTE(review): presumably the schema version recorded with the data (the
# provenance tables reference ``db_info.version``) - confirm where it is checked.
DB_VERSION = 2
# Scale factor applied when floats are stored as integers by `_float_to_db_int`
# and removed again by `_db_int_to_float`.
_DB_FLOAT_PRECISION = 100000.0
def _float_to_db_int(value: float) -> int:
    """Convert a float into the fixed-point integer stored in the database.

    The value is multiplied by ``_DB_FLOAT_PRECISION`` and the product is then
    rounded down (towards negative infinity) rather than to the nearest integer.
    """
    scaled = value * _DB_FLOAT_PRECISION
    return int(math.floor(scaled))
def _db_int_to_float(value: int) -> float:
    """Convert a fixed-point database integer back into a float.

    This divides by ``_DB_FLOAT_PRECISION``, undoing the scaling applied by
    ``_float_to_db_int``.
    """
    scale = _DB_FLOAT_PRECISION
    return value / scale
class _UniqueMixin:
    """A base class for records which should be unique in the database.

    Subclasses provide ``_hash``, ``_query`` and ``_instance_to_db``; ``unique``
    combines them to return a single database record per distinct instance.
    """

    @classmethod
    @abc.abstractmethod
    def _hash(cls, instance: _InstanceType) -> int:
        """Returns the hash of the instance that this record represents."""
        raise NotImplementedError()

    @classmethod
    @abc.abstractmethod
    def _query(cls, db: Session, instance: _InstanceType) -> Query:
        """Returns a query which should find existing copies of an instance."""
        raise NotImplementedError()

    @classmethod
    @abc.abstractmethod
    def _instance_to_db(cls, instance: _InstanceType) -> _DBInstanceType:
        """Map an instance into a database version of itself."""
        raise NotImplementedError()

    @classmethod
    def unique(cls, db: Session, instance: _InstanceType) -> _DBInstanceType:
        """Creates a new database object from the specified instance if it
        does not already exist on the database, otherwise the existing
        instance is returned.

        Records are memoised in a ``_unique_cache`` dictionary stored on the
        session object. The cache key is ``(cls, cls._hash(instance))`` only,
        so two instances whose hashes collide share the same cached record.

        Parameters
        ----------
        db
            The session to look the record up in, and to add it to if missing.
        instance
            The object whose database record is required.

        Returns
        -------
            The cached, existing, or newly added database record.
        """
        cache = getattr(db, "_unique_cache", None)

        if cache is None:
            # First use with this session: attach an empty cache to it.
            db._unique_cache = cache = {}

        key = (cls, cls._hash(instance))

        if key in cache:
            return cache[key]

        # Run the lookup without autoflushing pending objects first.
        with db.no_autoflush:
            db_instance = cls._query(db, instance).first()

            # ``first()`` signals "no row" with ``None``; compare by identity so
            # that a record which merely evaluates as falsy is never duplicated.
            if db_instance is None:
                db_instance = cls._instance_to_db(instance)
                db.add(db_instance)

        cache[key] = db_instance
        return db_instance
class DBGridSettings(_UniqueMixin, DBBase):
    """Database record holding one unique set of grid settings.

    ``LatticeGridSettings`` populate the ``lattice_*`` columns while
    ``MSKGridSettings`` populate ``msk_density``; the columns belonging to the
    other kind are left unset. Floating point values are stored as the integers
    produced by ``_float_to_db_int``.
    """

    __tablename__ = "grid_settings"

    id = Column(Integer, primary_key=True, index=True)

    type = Column(String, nullable=False)

    lattice_spacing = Column(Integer, nullable=True)
    lattice_inner_vdw_scale = Column(Integer, nullable=True)
    lattice_outer_vdw_scale = Column(Integer, nullable=True)

    msk_density = Column(Integer, nullable=True)

    @classmethod
    def _hash(cls, instance: GridSettingsType) -> int:
        """Hash the settings using the same integer values that are stored.

        Raises
        ------
        NotImplementedError
            If ``instance`` is neither lattice nor MSK grid settings.
        """
        if isinstance(instance, LatticeGridSettings):
            return hash(
                (
                    instance.type,
                    _float_to_db_int(instance.spacing),
                    _float_to_db_int(instance.inner_vdw_scale),
                    _float_to_db_int(instance.outer_vdw_scale),
                )
            )
        elif isinstance(instance, MSKGridSettings):
            return hash((instance.type, _float_to_db_int(instance.density)))
        else:
            raise NotImplementedError(
                f"grid settings of type {type(instance).__name__!r} are not supported"
            )

    @classmethod
    def _query(cls, db: Session, instance: GridSettingsType) -> Query:
        """Build a query matching stored rows equal to ``instance``.

        Raises
        ------
        NotImplementedError
            If ``instance`` is neither lattice nor MSK grid settings.
        """
        if isinstance(instance, LatticeGridSettings):
            spacing = _float_to_db_int(instance.spacing)
            inner_vdw_scale = _float_to_db_int(instance.inner_vdw_scale)
            outer_vdw_scale = _float_to_db_int(instance.outer_vdw_scale)

            return (
                db.query(DBGridSettings)
                .filter(DBGridSettings.type == instance.type)
                .filter(DBGridSettings.lattice_spacing == spacing)
                .filter(DBGridSettings.lattice_inner_vdw_scale == inner_vdw_scale)
                .filter(DBGridSettings.lattice_outer_vdw_scale == outer_vdw_scale)
            )
        elif isinstance(instance, MSKGridSettings):
            density = _float_to_db_int(instance.density)

            return (
                db.query(DBGridSettings)
                .filter(DBGridSettings.type == instance.type)
                .filter(DBGridSettings.msk_density == density)
            )
        else:
            raise NotImplementedError(
                f"grid settings of type {type(instance).__name__!r} are not supported"
            )

    @classmethod
    def _instance_to_db(cls, instance: GridSettingsType) -> "DBGridSettings":
        """Create a new (not yet added) record from ``instance``.

        Raises
        ------
        NotImplementedError
            If ``instance`` is neither lattice nor MSK grid settings.
        """
        if isinstance(instance, LatticeGridSettings):
            return DBGridSettings(
                type=instance.type,
                lattice_spacing=_float_to_db_int(instance.spacing),
                lattice_inner_vdw_scale=_float_to_db_int(instance.inner_vdw_scale),
                lattice_outer_vdw_scale=_float_to_db_int(instance.outer_vdw_scale),
            )
        elif isinstance(instance, MSKGridSettings):
            return DBGridSettings(
                type=instance.type,
                msk_density=_float_to_db_int(instance.density),
            )
        else:
            raise NotImplementedError(
                f"grid settings of type {type(instance).__name__!r} are not supported"
            )

    @classmethod
    def db_to_instance(cls, db_instance: "DBGridSettings") -> GridSettingsType:
        """Rebuild the settings object which a stored record represents.

        The ``type`` column selects the settings class: ``"fcc"`` yields
        ``LatticeGridSettings`` and ``"msk"`` yields ``MSKGridSettings``.

        Raises
        ------
        NotImplementedError
            If the stored ``type`` is not one of the values above.
        """
        if db_instance.type in ("fcc",):
            # noinspection PyTypeChecker
            return LatticeGridSettings(
                type=db_instance.type,
                spacing=_db_int_to_float(db_instance.lattice_spacing),
                inner_vdw_scale=_db_int_to_float(db_instance.lattice_inner_vdw_scale),
                outer_vdw_scale=_db_int_to_float(db_instance.lattice_outer_vdw_scale),
            )
        elif db_instance.type == "msk":
            # noinspection PyTypeChecker
            return MSKGridSettings(
                type=db_instance.type, density=_db_int_to_float(db_instance.msk_density)
            )
        else:
            raise NotImplementedError(
                f"stored grid type {db_instance.type!r} is not supported"
            )
class DBPCMSettings(_UniqueMixin, DBBase):
    """Database record holding one unique set of PCM settings.

    ``cavity_area`` is stored as the integer produced by ``_float_to_db_int``;
    the remaining columns are copied from the settings object unchanged.
    """

    __tablename__ = "pcm_settings"

    id = Column(Integer, primary_key=True, index=True)

    solver = Column(String(6), nullable=False)
    solvent = Column(String(20), nullable=False)

    radii_model = Column(String(8), nullable=False)
    radii_scaling = Column(Boolean, nullable=False)

    cavity_area = Column(Integer)

    @classmethod
    def _hash(cls, instance: PCMSettings) -> int:
        """Hash the settings using the same values that are stored."""
        return hash(
            (
                instance.solver,
                instance.solvent,
                instance.radii_model,
                instance.radii_scaling,
                _float_to_db_int(instance.cavity_area),
            )
        )

    @classmethod
    def _query(cls, db: Session, instance: PCMSettings) -> Query:
        """Build a query matching stored rows equal to ``instance``."""
        cavity_area = _float_to_db_int(instance.cavity_area)

        return (
            db.query(DBPCMSettings)
            .filter(DBPCMSettings.solver == instance.solver)
            .filter(DBPCMSettings.solvent == instance.solvent)
            .filter(DBPCMSettings.radii_model == instance.radii_model)
            .filter(DBPCMSettings.radii_scaling == instance.radii_scaling)
            .filter(DBPCMSettings.cavity_area == cavity_area)
        )

    @classmethod
    def _instance_to_db(cls, instance: PCMSettings) -> "DBPCMSettings":
        """Create a new (not yet added) record from ``instance``."""
        return DBPCMSettings(
            solver=instance.solver,
            solvent=instance.solvent,
            radii_model=instance.radii_model,
            radii_scaling=instance.radii_scaling,
            cavity_area=_float_to_db_int(instance.cavity_area),
        )

    @classmethod
    def db_to_instance(cls, db_instance: "DBPCMSettings") -> PCMSettings:
        """Rebuild the ``PCMSettings`` which a stored record represents."""
        # noinspection PyTypeChecker
        return PCMSettings(
            solver=db_instance.solver,
            solvent=db_instance.solvent,
            radii_model=db_instance.radii_model,
            radii_scaling=db_instance.radii_scaling,
            cavity_area=_db_int_to_float(db_instance.cavity_area),
        )
class DBESPSettings(_UniqueMixin, DBBase):
    """Database record holding one unique set of ESP calculation settings.

    A record is identified by its basis, method and Psi4 DFT grid setting; the
    latter is stored as the ``.value`` of the settings' enum-like attribute.
    """

    __tablename__ = "esp_settings"
    # The constraint must cover every column that ``_hash`` and ``_query`` use
    # to tell settings apart. With only ("basis", "method"), settings differing
    # solely in their DFT grid would miss the lookup and then fail on insert.
    # NOTE: tables created with the narrower constraint keep it until migrated.
    __table_args__ = (UniqueConstraint("basis", "method", "psi4_dft_grid_settings"),)

    id = Column(Integer, primary_key=True, index=True)

    basis = Column(String, index=True, nullable=False)
    method = Column(String, index=True, nullable=False)

    psi4_dft_grid_settings = Column(String, nullable=False)

    @classmethod
    def _hash(cls, instance: ESPSettings) -> int:
        """Hash the basis, method and DFT grid value of ``instance``."""
        return hash(
            (instance.basis, instance.method, instance.psi4_dft_grid_settings.value)
        )

    @classmethod
    def _query(cls, db: Session, instance: ESPSettings) -> Query:
        """Build a query matching stored rows equal to ``instance``."""
        return (
            db.query(DBESPSettings)
            .filter(DBESPSettings.basis == instance.basis)
            .filter(DBESPSettings.method == instance.method)
            .filter(
                DBESPSettings.psi4_dft_grid_settings
                == instance.psi4_dft_grid_settings.value
            )
        )

    @classmethod
    def _instance_to_db(cls, instance: ESPSettings) -> "DBESPSettings":
        """Create a new (not yet added) record from ``instance``.

        The grid and PCM settings are excluded here, and the DFT grid setting
        is passed separately so that its ``.value`` is what gets stored.
        """
        return DBESPSettings(
            **instance.dict(
                exclude={"grid_settings", "pcm_settings", "psi4_dft_grid_settings"}
            ),
            psi4_dft_grid_settings=instance.psi4_dft_grid_settings.value,
        )
class DBMoleculeRecord(DBBase):
    """Database record for a molecule, keyed by its SMILES string.

    ``conformers`` exposes the related ``DBConformerRecord`` rows.
    """

    __tablename__ = "molecules"

    smiles = Column(String, primary_key=True, index=True)

    conformers = relationship("DBConformerRecord")
class DBGeneralProvenance(DBBase):
    """A key / value pair stored in the ``general_provenance`` table.

    Each entry references ``db_info.version`` through ``parent_id``.
    """

    __tablename__ = "general_provenance"

    key = Column(String, primary_key=True, index=True, unique=True)
    value = Column(String, nullable=False)

    parent_id = Column(Integer, ForeignKey("db_info.version"))
class DBSoftwareProvenance(DBBase):
    """A key / value pair stored in the ``software_provenance`` table.

    Each entry references ``db_info.version`` through ``parent_id``.
    """

    __tablename__ = "software_provenance"

    key = Column(String, primary_key=True, index=True, unique=True)
    value = Column(String, nullable=False)

    parent_id = Column(Integer, ForeignKey("db_info.version"))