Source code for pynenc.state_backend.mem_state_backend

import itertools
import threading
from collections import defaultdict
from collections.abc import Iterator
from datetime import datetime
from typing import TYPE_CHECKING, Any

from pynenc.state_backend.base_state_backend import BaseStateBackend
from pynenc.identifiers.invocation_id import InvocationId as InvId
from pynenc.identifiers.task_id import TaskId
from pynenc.types import Params, Result

if TYPE_CHECKING:
    from pynenc.app import AppInfo, Pynenc
    from pynenc.identifiers.invocation_id import InvocationId
    from pynenc.invocation.dist_invocation import InvocationDTO
    from pynenc.models.call_dto import CallDTO
    from pynenc.runner.runner_context import RunnerContext
    from pynenc.state_backend.base_state_backend import InvocationHistory
    from pynenc.workflow import WorkflowIdentity


class MemStateBackend(BaseStateBackend[Params, Result]):
    """
    A memory-based implementation of the state backend.

    Stores invocation data, history, results, and exceptions in in-memory dictionaries.
    Useful for environments where persistence is not required or for testing purposes.

    ```{warning}
    The `MemStateBackend` class stores all data in the process's memory and is not
    suitable for production systems. Its use should be limited to testing or
    demonstration purposes only.
    ```
    """

    # Class-level registry shared across instances, protected by a lock.
    # Process-wide so that discover_app_infos (a @staticmethod) can find all apps.
    _app_info_registry: dict[str, "AppInfo"] = {}
    _registry_lock: threading.Lock = threading.Lock()

    def __init__(self, app: "Pynenc") -> None:
        """
        Create the empty in-memory stores and initialise the base backend.

        :param app: The application this backend belongs to.
        """
        # invocation_id -> (invocation DTO, call DTO)
        self._cache: dict[str, tuple[InvocationDTO, CallDTO]] = {}
        # str(parent invocation id) -> ordered list of str(child invocation id)
        self._parent_to_children: dict[str, list[str]] = defaultdict(list)
        # runner_id -> runner context
        self._runner_contexts: dict[str, RunnerContext] = {}
        self._history: dict[InvocationId, list] = defaultdict(list)
        self._results: dict[InvocationId, str] = {}
        self._exceptions: dict[InvocationId, str] = {}
        self._workflow_data: dict[InvocationId, dict[str, Any]] = defaultdict(dict)
        self._workflow_types: set[TaskId] = set()  # Stores workflow_task_ids
        self._workflow_runs: dict[TaskId, set[WorkflowIdentity]] = defaultdict(
            set
        )  # workflow_task_id -> runs
        self._workflow_sub_invocations: dict[InvocationId, set[InvocationId]] = (
            defaultdict(set)
        )  # workflow_id -> sub_invocation_ids
        # The stores are created before the base initialiser runs, so they
        # already exist if it touches the backend.
        super().__init__(app)

    def purge(self) -> None:
        """
        Clear all data stored by this backend instance.

        Every per-instance store is emptied, including workflow data and
        runner contexts. The class-level app-info registry is not touched.
        """
        self._cache.clear()
        self._parent_to_children.clear()
        self._runner_contexts.clear()
        self._history.clear()
        self._results.clear()
        self._exceptions.clear()
        self._workflow_data.clear()
        self._workflow_types.clear()
        self._workflow_runs.clear()
        self._workflow_sub_invocations.clear()

    def _upsert_invocations(
        self, entries: list[tuple["InvocationDTO", "CallDTO"]]
    ) -> None:
        """
        Store invocation and call DTO pairs in the memory cache.

        Also maintains the parent-to-children index for efficient family tree
        traversal without scanning the entire cache.

        :param entries: Pairs of invocation DTO and call DTO to insert or replace.
        """
        for inv_dto, call_dto in entries:
            self._cache[inv_dto.invocation_id] = (inv_dto, call_dto)
            if inv_dto.parent_invocation_id is None:
                continue
            parent_key = str(inv_dto.parent_invocation_id)
            child_key = str(inv_dto.invocation_id)
            children = self._parent_to_children[parent_key]
            # Re-upserting the same invocation must not duplicate the child.
            if child_key not in children:
                children.append(child_key)

    def _get_invocation(
        self, invocation_id: "InvocationId"
    ) -> tuple["InvocationDTO", "CallDTO"] | None:
        """
        Retrieve an invocation DTO pair from the memory cache.

        :param "InvocationId" invocation_id: The ID of the invocation to retrieve.
        :return: Paired DTOs if found, else None.
        """
        return self._cache.get(invocation_id)

    def get_child_invocations(
        self, parent_invocation_id: "InvocationId"
    ) -> Iterator["InvocationId"]:
        """
        Return IDs of invocations that name the given ID as their parent.

        Uses the parent-to-children index instead of scanning the entire cache.

        :param parent_invocation_id: The parent invocation ID to search for.
        :return: Iterator of child invocation IDs, in insertion order.
        """
        parent_key = str(parent_invocation_id)
        # Snapshot so children registered later do not affect this iterator.
        child_keys = tuple(self._parent_to_children.get(parent_key, ()))
        return (InvId(child_key) for child_key in child_keys)

    def _add_histories(
        self,
        invocation_ids: list["InvocationId"],
        invocation_history: "InvocationHistory",
    ) -> None:
        """
        Append the same history record to each of the given invocations.

        :param invocation_ids: Invocations that receive the record.
        :param invocation_history: The record to append.
        """
        for invocation_id in invocation_ids:
            self._history[invocation_id].append(invocation_history)

    def _get_history(self, invocation_id: "InvocationId") -> list["InvocationHistory"]:
        """
        Retrieve the history of an invocation ordered by timestamp.

        :param "InvocationId" invocation_id: The ID of the invocation to get the history from
        :return: List of InvocationHistory records (empty if none are stored)
        """
        # .get() avoids creating an empty defaultdict entry for unknown IDs.
        return sorted(
            self._history.get(invocation_id, []),
            key=lambda record: record.timestamp,
        )

    def _get_result(self, invocation_id: "InvocationId") -> str:
        """
        Retrieve the result of an invocation.

        :param "InvocationId" invocation_id: The ID of the invocation to get the result from
        :return: The serialized result string
        :raises KeyError: If no result is stored for the invocation
        """
        return self._results[invocation_id]

    def _set_result(
        self, invocation_id: "InvocationId", serialized_result: str
    ) -> None:
        """
        Set the result of an invocation.

        :param "InvocationId" invocation_id: The ID of the invocation to set
        :param str serialized_result: The serialized result string to set
        """
        self._results[invocation_id] = serialized_result

    def _get_exception(self, invocation_id: "InvocationId") -> str:
        """
        Retrieve the exception of an invocation.

        :param "InvocationId" invocation_id: The ID of the invocation to get the exception from
        :return: The serialized exception string
        :raises KeyError: If no exception is stored for the invocation
        """
        return self._exceptions[invocation_id]

    def _set_exception(
        self, invocation_id: "InvocationId", serialized_exception: str
    ) -> None:
        """
        Set the exception raised by an invocation.

        :param "InvocationId" invocation_id: The ID of the invocation to set
        :param str serialized_exception: The serialized exception string to set
        """
        self._exceptions[invocation_id] = serialized_exception

    def get_workflow_data(
        self, workflow_identity: "WorkflowIdentity", key: str, default: Any = None
    ) -> Any:
        """
        Get a value from workflow data.

        :param "WorkflowIdentity" workflow_identity: Workflow identity
        :param str key: Data key to retrieve
        :param Any default: Default value if key doesn't exist
        :return: Stored value or default
        """
        # .get() keeps reads from creating empty entries in the defaultdict.
        workflow_values = self._workflow_data.get(workflow_identity.workflow_id, {})
        return workflow_values.get(key, default)

    def set_workflow_data(
        self, workflow_identity: "WorkflowIdentity", key: str, value: Any
    ) -> None:
        """
        Set a value in workflow data.

        :param workflow_identity: Workflow identity
        :param key: Data key to set
        :param value: Value to store
        """
        # The defaultdict creates the per-workflow dict on first write.
        self._workflow_data[workflow_identity.workflow_id][key] = value

    def store_app_info(self, app_info: "AppInfo") -> None:
        """
        Register this app's information in the state backend for discovery.

        The entry goes into the class-level registry, keyed by ``app_info.app_id``.

        :param app_info: The app information to store
        """
        with MemStateBackend._registry_lock:
            MemStateBackend._app_info_registry[app_info.app_id] = app_info

    def get_app_info(self) -> "AppInfo":
        """
        Retrieve information of the current app.

        :return: The app information
        :raises ValueError: If app info is not found
        """
        app_id = self.app.app_id
        with MemStateBackend._registry_lock:
            try:
                return MemStateBackend._app_info_registry[app_id]
            except KeyError:
                raise ValueError(f"No app info found for app_id '{app_id}'") from None

    @staticmethod
    def discover_app_infos() -> dict[str, "AppInfo"]:
        """
        Return every app info registered in this process.

        :return: A shallow copy of the registry mapping app_id to app info, so
            callers can modify the returned dict without affecting the registry
        """
        with MemStateBackend._registry_lock:
            return dict(MemStateBackend._app_info_registry)

    def store_workflow_run(self, workflow_identity: "WorkflowIdentity") -> None:
        """
        Store a workflow run for tracking and monitoring.

        :param workflow_identity: The workflow identity to store
        """
        workflow_type = workflow_identity.workflow_type
        self._workflow_types.add(workflow_type)
        self._workflow_runs[workflow_type].add(workflow_identity)

    def get_all_workflow_types(self) -> Iterator["TaskId"]:
        """
        Retrieve all workflow types (workflow_task_ids) stored in this state backend.

        :return: Iterator of workflow task IDs representing different workflow types
        """
        # Snapshot: iterating the live set fails if a type is added meanwhile.
        return iter(tuple(self._workflow_types))

    def get_all_workflow_runs(self) -> Iterator["WorkflowIdentity"]:
        """
        Retrieve workflow run identities of every workflow type.

        :return: Iterator of workflow identities for runs
        """
        # Materialise first so runs stored during iteration cannot raise
        # "changed size during iteration" in the consumer.
        all_runs = tuple(itertools.chain.from_iterable(self._workflow_runs.values()))
        return iter(all_runs)

    def get_workflow_runs(
        self, workflow_type: "TaskId"
    ) -> Iterator["WorkflowIdentity"]:
        """
        Retrieve workflow run identities for one workflow type.

        :param workflow_type: Filter for specific workflow type
        :return: Iterator of workflow identities for runs
        """
        # .get() avoids creating an empty entry; tuple() snapshots the live set.
        return iter(tuple(self._workflow_runs.get(workflow_type, ())))

    def store_workflow_sub_invocation(
        self, parent_workflow_id: "InvocationId", sub_invocation_id: "InvocationId"
    ) -> None:
        """
        Store a sub-invocation ID that runs inside a parent workflow.

        :param parent_workflow_id: The workflow ID that contains the sub-invocation
        :param sub_invocation_id: The invocation ID of the task/sub-workflow running inside
        """
        self._workflow_sub_invocations[parent_workflow_id].add(sub_invocation_id)

    def get_workflow_sub_invocations(
        self, workflow_id: "InvocationId"
    ) -> Iterator["InvocationId"]:
        """
        Retrieve all sub-invocation IDs that run inside a specific workflow.

        :param workflow_id: The workflow ID to get sub-invocations for
        :return: Iterator of invocation IDs that run inside the workflow
        """
        # Snapshot so sub-invocations stored while the caller iterates do not
        # invalidate the iterator.
        return iter(tuple(self._workflow_sub_invocations.get(workflow_id, ())))

    def iter_invocations_in_timerange(
        self,
        start_time: datetime,
        end_time: datetime,
        batch_size: int = 100,
    ) -> Iterator[list["InvocationId"]]:
        """
        Iterate over invocation IDs that have history within a time range.

        An invocation matches when at least one of its history entries has a
        timestamp in ``[start_time, end_time]`` (both bounds inclusive).

        :param start_time: Lower bound of the range
        :param end_time: Upper bound of the range
        :param batch_size: Maximum number of IDs per yielded batch
        :return: Iterator of batches of sorted invocation IDs
        :raises ValueError: If batch_size is smaller than 1 (raised on first iteration)
        """
        if batch_size < 1:
            # A negative step would silently produce no batches at all.
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        matching_ids = sorted(
            invocation_id
            for invocation_id, history_list in list(self._history.items())
            if any(
                start_time <= history_entry.timestamp <= end_time
                for history_entry in history_list
            )
        )
        for start in range(0, len(matching_ids), batch_size):
            yield matching_ids[start : start + batch_size]

    def iter_history_in_timerange(
        self,
        start_time: datetime,
        end_time: datetime,
        batch_size: int = 100,
    ) -> Iterator[list["InvocationHistory"]]:
        """
        Iterate over history entries within a time range.

        Entries of all invocations whose timestamp lies in
        ``[start_time, end_time]`` (both bounds inclusive) are merged and
        ordered by timestamp.

        :param start_time: Lower bound of the range
        :param end_time: Upper bound of the range
        :param batch_size: Maximum number of entries per yielded batch
        :return: Iterator of batches of history entries
        :raises ValueError: If batch_size is smaller than 1 (raised on first iteration)
        """
        if batch_size < 1:
            # A negative step would silently produce no batches at all.
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        matching_entries = [
            history_entry
            for history_list in list(self._history.values())
            for history_entry in history_list
            if start_time <= history_entry.timestamp <= end_time
        ]
        # Stable sort: entries with equal timestamps keep collection order.
        matching_entries.sort(key=lambda history_entry: history_entry.timestamp)
        for start in range(0, len(matching_entries), batch_size):
            yield matching_entries[start : start + batch_size]

    def _store_runner_context(self, runner_context: "RunnerContext") -> None:
        """
        Store a runner context, keyed by its ``runner_id``.

        :param RunnerContext runner_context: The context to store
        """
        self._runner_contexts[runner_context.runner_id] = runner_context

    def _get_runner_context(self, runner_id: str) -> "RunnerContext | None":
        """
        Retrieve a runner context by runner_id.

        :param str runner_id: The runner's unique identifier
        :return: The stored RunnerContext or None if not found
        """
        return self._runner_contexts.get(runner_id)

    def _get_runner_contexts(self, runner_ids: list[str]) -> list["RunnerContext"]:
        """
        Retrieve multiple runner contexts by their IDs.

        Unknown IDs are skipped, so the result can be shorter than the input.

        :param list[str] runner_ids: List of runner unique identifiers
        :return: list["RunnerContext"] of the stored RunnerContexts
        """
        return [
            self._runner_contexts[runner_id]
            for runner_id in runner_ids
            if runner_id in self._runner_contexts
        ]

    def get_matching_runner_contexts(
        self, partial_id: str
    ) -> Iterator["RunnerContext"]:
        """
        Search runner contexts by partial ID match.

        :param partial_id: Substring that must occur in the runner ID
        :return: Iterator of contexts whose runner ID contains ``partial_id``
        """
        # Snapshot so contexts stored while the caller iterates do not raise
        # "dictionary changed size during iteration".
        for runner_id, context in list(self._runner_contexts.items()):
            if partial_id in runner_id:
                yield context

    def get_invocation_ids_by_workflow(
        self,
        workflow_id: str | None = None,
        workflow_type_key: str | None = None,
    ) -> Iterator["InvocationId"]:
        """
        Retrieve invocation IDs filtered by workflow criteria.

        A falsy filter (``None`` or an empty string) is ignored. With no filter
        set, every cached invocation ID is yielded.

        :param workflow_id: Keep only invocations whose workflow ID has this string form
        :param workflow_type_key: Keep only invocations whose workflow type has this key
        :return: Iterator of matching invocation IDs
        """
        # Snapshot so invocations upserted while the caller iterates do not
        # raise "dictionary changed size during iteration".
        for inv_dto, _call_dto in list(self._cache.values()):
            workflow = inv_dto.workflow
            if workflow_id and str(workflow.workflow_id) != workflow_id:
                continue
            if workflow_type_key and workflow.workflow_type.key != workflow_type_key:
                continue
            yield inv_dto.invocation_id