Source code for openff.interchange.types

"""Custom models for dealing with unit-bearing quantities in a Pydantic-compatible manner."""
import json
from typing import TYPE_CHECKING, Any, Dict

import numpy as np
from openff.units import unit
from openff.utilities.utilities import has_package, requires_package
from openmm import unit as openmm_unit

from openff.interchange.exceptions import (
    MissingUnitError,
    UnitValidationError,
    UnsupportedExportError,
)

if TYPE_CHECKING or has_package("unyt"):
    import unyt


class _FloatQuantityMeta(type):
    def __getitem__(self, t):
        return type("FloatQuantity", (FloatQuantity,), {"__unit__": t})


[docs]class FloatQuantity(float, metaclass=_FloatQuantityMeta): """A model for unit-bearing floats.""" @classmethod def __get_validators__(cls): yield cls.validate_type
[docs] @classmethod def validate_type(cls, val): """Process a value tagged with units into one tagged with "OpenFF" style units.""" unit_ = getattr(cls, "__unit__", Any) if unit_ is Any: if isinstance(val, (float, int)): # TODO: Can this exception be raised with knowledge of the field it's in? raise MissingUnitError(f"Value {val} needs to be tagged with a unit") elif isinstance(val, unit.Quantity): return unit.Quantity(val) elif isinstance(val, openmm_unit.Quantity): return _from_omm_quantity(val) else: raise UnitValidationError( f"Could not validate data of type {type(val)}" ) else: unit_ = unit(unit_) if isinstance(val, unit.Quantity): # some custom behavior could go here assert unit_.dimensionality == val.dimensionality # return through converting to some intended default units (taken from the class) return val.to(unit_) # could return here, without converting # (could be inconsistent with data model - heteregenous but compatible units) # return val if isinstance(val, openmm_unit.Quantity): return _from_omm_quantity(val).to(unit_) if has_package("unyt"): if isinstance(val, unyt.unyt_quantity): return _from_unyt_quantity(val).to(unit_) if isinstance(val, (float, int)) and not isinstance(val, bool): return val * unit_ if isinstance(val, str): # could do custom deserialization here? return unit.Quantity(val).to(unit_) raise UnitValidationError(f"Could not validate data of type {type(val)}")
def _from_omm_quantity(val: openmm_unit.Quantity): """ Convert float or array quantities tagged with SimTK/OpenMM units to a Pint-compatible quantity. """ unit_ = val.unit val_ = val.value_in_unit(unit_) if type(val_) in {float, int}: unit_ = val.unit return val_ * unit.Unit(str(unit_)) elif type(val_) in {tuple, list, np.ndarray}: array = np.asarray(val_) return array * unit.Unit(str(unit_)) elif isinstance(val_, (float, int)) and type(val_).__module__ == "numpy": return val_ * unit.Unit(str(unit_)) else: raise UnitValidationError( "Found a openmm.unit.Unit wrapped around something other than a float-like " f"or np.ndarray-like. Found a unit wrapped around type {type(val_)}." ) @requires_package("unyt") def _from_unyt_quantity(val: "unyt.unyt_array"): """Convert unyt arrays to Pint quantities.""" quantity = val.to_pint() # Ensure a float-like quantity is a float, not a scalar array if isinstance(val, unyt.unyt_quantity): quantity = float(quantity.magnitude) * quantity.units return quantity
[docs]class QuantityEncoder(json.JSONEncoder): """ JSON encoder for unit-wrapped floats and NumPy arrays. This is intended to operate on FloatQuantity and ArrayQuantity objects. """
[docs] def default(self, obj): # noqa if isinstance(obj, unit.Quantity): if isinstance(obj.magnitude, (float, int)): data = obj.magnitude elif isinstance(obj.magnitude, np.ndarray): data = obj.magnitude.tolist() else: # This shouldn't ever be hit if our object models # behave in ways we expect? raise UnsupportedExportError( f"trying to serialize unsupported type {type(obj.magnitude)}" ) return { "val": data, "unit": str(obj.units), }
[docs]def custom_quantity_encoder(v): """Wrap json.dump to use QuantityEncoder.""" return json.dumps(v, cls=QuantityEncoder)
[docs]def json_loader(data: str) -> dict: """Load JSON containing custom unit-tagged quantities.""" # TODO: recursively call this function for nested models out: Dict = json.loads(data) for key, val in out.items(): try: # Directly look for an encoded FloatQuantity/ArrayQuantity, # which is itself a dict v = json.loads(val) except json.JSONDecodeError: # Handles some cases of the val being a primitive type continue # TODO: More gracefully parse non-FloatQuantity/ArrayQuantity dicts unit_ = unit(v["unit"]) val = v["val"] out[key] = unit_ * val return out
class _ArrayQuantityMeta(type): def __getitem__(self, t): return type("ArrayQuantity", (ArrayQuantity,), {"__unit__": t}) if TYPE_CHECKING: ArrayQuantity = np.ndarray else:
[docs] class ArrayQuantity(float, metaclass=_ArrayQuantityMeta): """A model for unit-bearing arrays.""" @classmethod def __get_validators__(cls): yield cls.validate_type
[docs] @classmethod def validate_type(cls, val): """Process an array tagged with units into one tagged with "OpenFF" style units.""" unit_ = getattr(cls, "__unit__", Any) if unit_ is Any: if isinstance(val, (list, np.ndarray)): # TODO: Can this exception be raised with knowledge of the field it's in? raise MissingUnitError( f"Value {val} needs to be tagged with a unit" ) elif isinstance(val, unit.Quantity): # Redundant cast? Maybe this handles pint vs openff.interchange.unit? return unit.Quantity(val) elif isinstance(val, openmm_unit.Quantity): return _from_omm_quantity(val) else: raise UnitValidationError( f"Could not validate data of type {type(val)}" ) else: unit_ = unit(unit_) if isinstance(val, unit.Quantity): assert unit_.dimensionality == val.dimensionality return val.to(unit_) if isinstance(val, openmm_unit.Quantity): return _from_omm_quantity(val).to(unit_) if isinstance(val, (np.ndarray, list)): if has_package("unyt"): # Must check for unyt_array, not unyt_quantity, which is a subclass if isinstance(val, unyt.unyt_array): return _from_unyt_quantity(val).to(unit_) else: return val * unit_ else: return val * unit_ if isinstance(val, bytes): # Define outside loop dt = np.dtype(int) dt.newbyteorder("<") return np.frombuffer(val, dtype=dt) * unit_ if isinstance(val, str): # could do custom deserialization here? raise NotImplementedError # return unit.Quantity(val).to(unit_) raise UnitValidationError( f"Could not validate data of type {type(val)}" )