Source code for openff.bespokefit.executor.services.coordinator.storage

import pickle
from enum import Enum
from typing import List, Optional, Set, Union

from openff.bespokefit.executor.services.coordinator.models import CoordinatorTask
from openff.bespokefit.executor.services.coordinator.stages import (
    FragmentationStage,
    OptimizationStage,
    QCGenerationStage,
)
from openff.bespokefit.executor.utilities.redis import connect_to_default_redis
from openff.bespokefit.schema.fitting import BespokeOptimizationSchema


class TaskStatus(str, Enum):
    waiting = "waiting"
    running = "running"
    complete = "complete"


_QUEUE_NAMES = {
    TaskStatus.waiting: "coordinator:tasks:waiting",
    TaskStatus.running: "coordinator:tasks:running",
    TaskStatus.complete: "coordinator:tasks:complete",
}


def _task_id_to_key(task_id: Union[str, int]) -> str:
    return f"coordinator:task:{task_id}"


def get_task(task_id: Union[str, int]) -> CoordinatorTask:
    """Retrieve a task from Redis by its ID, raising an ``IndexError`` if no
    task with that ID has been stored."""

    connection = connect_to_default_redis()
    task_pickle = connection.get(_task_id_to_key(task_id))

    if task_pickle is None:
        raise IndexError(f"{task_id} was not found")

    return CoordinatorTask.parse_obj(pickle.loads(task_pickle))


def get_task_ids(
    skip: int = 0,
    limit: Optional[int] = None,
    status: Optional[Union[TaskStatus, Set[TaskStatus]]] = None,
) -> List[int]:
    """Return the IDs of stored tasks, optionally filtered by status and
    paginated with ``skip`` and ``limit``."""

    connection = connect_to_default_redis()

    if status is None:
        status = {TaskStatus.waiting, TaskStatus.running, TaskStatus.complete}
    elif isinstance(status, TaskStatus):
        status = {status}

    # Always report IDs in waiting -> running -> complete order, regardless
    # of the order in which the statuses were requested.
    ordered_status = [
        value
        for value in (TaskStatus.waiting, TaskStatus.running, TaskStatus.complete)
        if value in status
    ]

    task_ids = [
        int(task_id)
        for task_status in ordered_status
        for task_id in connection.lrange(_QUEUE_NAMES[task_status], 0, -1)
    ]

    return task_ids[skip : (skip + limit if limit is not None else None)]


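For illustration, a sketch of how the filtering and pagination compose; the
queue contents and task IDs in the comments are hypothetical:

# Suppose the queues currently hold waiting = [4, 5, 6], running = [2] and
# complete = [1]. IDs are concatenated in waiting -> running -> complete
# order before the ``skip``/``limit`` slice is applied.
get_task_ids()                           # -> [4, 5, 6, 2, 1]
get_task_ids(skip=1, limit=2)            # -> [5, 6]
get_task_ids(status=TaskStatus.running)  # -> [2]

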
def create_task(
    input_schema: BespokeOptimizationSchema,
    stages: Optional[
        List[Union[FragmentationStage, QCGenerationStage, OptimizationStage]]
    ] = None,
) -> int:
    """Create a new task, persist it in Redis, and append it to the tail of
    the waiting queue, returning the newly assigned task ID."""

    connection = connect_to_default_redis()

    # Atomically assign the next task ID from a Redis counter.
    task_id = connection.incr("coordinator:id-counter")

    stages = (
        stages
        if stages is not None
        else [FragmentationStage(), QCGenerationStage(), OptimizationStage()]
    )

    task = CoordinatorTask(
        id=str(task_id),
        input_schema=input_schema,
        pending_stages=stages,
    )
    task.input_schema.id = task_id

    task_key = _task_id_to_key(task_id)

    connection.set(task_key, pickle.dumps(task.dict()))
    connection.rpush(_QUEUE_NAMES[TaskStatus.waiting], task_id)

    return task_id


def get_n_tasks(status: Optional[TaskStatus] = None) -> int:
    """Return the total number of tasks, optionally restricted to a single
    status."""

    connection = connect_to_default_redis()

    return sum(
        connection.llen(queue)
        for queue_name, queue in _QUEUE_NAMES.items()
        if status is None or queue_name == status
    )


def peek_task_status(status: TaskStatus) -> Optional[int]:
    """Return the ID of the oldest task with a given status without removing
    it from its queue, or ``None`` if the queue is empty."""

    connection = connect_to_default_redis()
    task_id = connection.lrange(_QUEUE_NAMES[status], 0, 0)

    return None if len(task_id) == 0 else int(task_id[0])


def pop_task_status(status: TaskStatus) -> Optional[int]:
    """Remove and return the ID of the oldest task with a given status, or
    ``None`` if the queue is empty."""

    assert status != TaskStatus.complete, "complete tasks cannot be modified"

    connection = connect_to_default_redis()
    task_id = connection.lpop(_QUEUE_NAMES[status])

    return None if task_id is None else int(task_id)


def push_task_status(task_id: int, status: TaskStatus):
    """Append a task ID to the tail of the queue for a given status."""

    connection = connect_to_default_redis()
    return connection.rpush(_QUEUE_NAMES[status], task_id)


def save_task(task: CoordinatorTask):
    """Serialize a task and overwrite its stored record in Redis."""

    connection = connect_to_default_redis()
    connection.set(_task_id_to_key(int(task.id)), pickle.dumps(task.dict()))
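
A minimal end-to-end sketch of the queue lifecycle these helpers implement,
assuming a default Redis instance is reachable and that ``schema`` is an
already-built ``BespokeOptimizationSchema`` (its construction is out of scope
here):

# ``schema``: a pre-built BespokeOptimizationSchema (assumed to exist).
task_id = create_task(schema)  # enqueued at the tail of the waiting queue

# A worker claims the oldest waiting task and marks it as running.
claimed_id = pop_task_status(TaskStatus.waiting)
push_task_status(claimed_id, TaskStatus.running)

# Retrieve the full record, e.g. to inspect its pending stages, and persist
# any changes back to Redis.
task = get_task(claimed_id)
save_task(task)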