import pickle
import threading
from collections import OrderedDict, defaultdict, deque
from collections.abc import Iterator
from datetime import UTC, datetime
from time import time
from typing import TYPE_CHECKING, Any
from pynenc.invocation.status import (
InvocationStatus,
InvocationStatusRecord,
status_record_transition,
)
from pynenc.orchestrator.base_orchestrator import (
BaseBlockingControl,
BaseOrchestrator,
)
from pynenc.orchestrator.atomic_service import ActiveRunnerInfo
from pynenc.types import Params, Result
if TYPE_CHECKING:
from pynenc.app import Pynenc
from pynenc.identifiers.call_id import CallId
from pynenc.identifiers.invocation_id import InvocationId
from pynenc.invocation.dist_invocation import DistributedInvocation
from pynenc.task import Task, TaskId
[docs]
class MemBlockingControl(BaseBlockingControl):
    """
    An implementation of blocking control using a directed acyclic graph (DAG) to represent invocation dependencies.

    This class manages dependencies between task invocations, ensuring that invocations waiting for others are properly handled.
    The graph is stored as two adjacency maps guarded by a re-entrant lock:

    * ``waiting_for[a]``: the invocations that ``a`` is waiting on.
    * ``waited_by[b]``: the invocations that are waiting on ``b``.

    ``_ready`` caches the invocations that are waited on by someone but are not
    waiting on anything themselves, so they can be returned without scanning the graph.

    :param Pynenc app: The Pynenc application instance.
    """

    def __init__(self, app: "Pynenc") -> None:
        self.app = app
        self._lock = threading.RLock()
        self.waiting_for: dict[InvocationId, set[InvocationId]] = defaultdict(set)
        self.waited_by: dict[InvocationId, set[InvocationId]] = OrderedDict()
        # Maintained set of "ready" invocations: waited on by others but not waiting themselves
        self._ready: set[InvocationId] = set()

    def waiting_for_results(
        self,
        caller_invocation_id: "InvocationId",
        result_invocation_ids: list["InvocationId"],
    ) -> None:
        """
        Notifies the system that an invocation is waiting for the results of other invocations.

        :param InvocationId caller_invocation_id: The ID of the invocation that is waiting.
        :param list[InvocationId] result_invocation_ids: The IDs of the invocations being waited on.
            An empty list records nothing and leaves the caller's state untouched.
        """
        if not result_invocation_ids:
            # Nothing to wait on: the caller must keep its current "ready" state
            return
        waiter_id = caller_invocation_id
        with self._lock:
            waiter_deps = self.waiting_for[waiter_id]
            for waited_id in result_invocation_ids:
                waiter_deps.add(waited_id)
                self.waited_by.setdefault(waited_id, set()).add(waiter_id)
                # waited_id is ready only if it is not itself waiting on something
                if waited_id not in self.waiting_for:
                    self._ready.add(waited_id)
            # waiter_id is now waiting, so it can't be ready
            self._ready.discard(waiter_id)

    def release_waiters(self, waited_invocation_id: "InvocationId") -> None:
        """
        Removes an invocation from the graph, along with any dependencies related to it.

        Both directions are cleaned up:

        * Invocations waiting on it lose that dependency; those left with no
          dependencies become ready if someone waits on them.
        * Its own pending dependencies stop being waited on by it; a dependency
          nobody else waits on is dropped from the graph and from the ready set.

        :param InvocationId waited_invocation_id: The ID of the invocation that has finished and will no longer block other invocations.
        """
        with self._lock:
            # Unblock every invocation that was waiting on the released one
            for waiter_id in self.waited_by.pop(waited_invocation_id, ()):
                waiter_deps = self.waiting_for.get(waiter_id)
                if waiter_deps is None:
                    continue
                waiter_deps.discard(waited_invocation_id)
                if not waiter_deps:
                    del self.waiting_for[waiter_id]
                    # waiter_id no longer waiting; if it's waited on, it's now ready
                    if waiter_id in self.waited_by:
                        self._ready.add(waiter_id)
            # Drop the released invocation's outgoing edges so its dependencies do not
            # keep a stale "waited by" entry (which would keep them ready forever)
            for dep_id in self.waiting_for.pop(waited_invocation_id, ()):
                dep_waiters = self.waited_by.get(dep_id)
                if dep_waiters is None:
                    continue
                dep_waiters.discard(waited_invocation_id)
                if not dep_waiters:
                    del self.waited_by[dep_id]
                    self._ready.discard(dep_id)
            self._ready.discard(waited_invocation_id)

    def get_blocking_invocations(
        self, max_num_invocations: int
    ) -> Iterator["InvocationId"]:
        """
        Retrieves invocations that are blocking others but are not themselves waiting for any results.

        Uses a maintained ready set for O(1) lookup instead of scanning all keys.
        The ready set is snapshotted under the lock; the status checks run outside it.

        :param int max_num_invocations: The maximum number of blocking invocations to retrieve.
            Values ``<= 0`` yield nothing.
        :return: An iterator over invocations that are blocking others.
        :rtype: Iterator["InvocationId"]
        """
        if max_num_invocations <= 0:
            return
        with self._lock:
            candidates = list(self._ready)
        remaining = max_num_invocations
        for inv_id in candidates:
            status = self.app.orchestrator.get_invocation_status(inv_id)
            if not status.is_available_for_run():
                continue
            yield inv_id
            remaining -= 1
            if remaining == 0:
                return
[docs]
class ArgPair:
"""Helper to simulate a Memory cache for key:value pairs in Task Invocations"""
def __init__(self, key: str, value: Any) -> None:
self.key = key
self.value = value
self._hash: int | None = None
[docs]
def __hash__(self) -> int:
"""Generate a hash that works with serialized values, cached after first call."""
if self._hash is None:
if isinstance(self.value, str):
self._hash = hash((self.key, self.value))
else:
self._hash = hash((self.key, pickle.dumps(self.value)))
return self._hash
[docs]
def __eq__(self, other: Any) -> bool:
"""Equality check that works with serialized values"""
if not isinstance(other, ArgPair):
return False
# Basic key comparison
if self.key != other.key:
return False
# For string values (likely serialized JSON), direct comparison
if isinstance(self.value, str) and isinstance(other.value, str):
return self.value == other.value
# For other types, try direct comparison first
if self.value == other.value:
return True
# Last resort: pickle comparison for complex objects
try:
return pickle.dumps(self.value) == pickle.dumps(other.value)
except (pickle.PickleError, TypeError):
# If pickling fails, they're not equal
return False
[docs]
def __str__(self) -> str:
return f"{self.key}:{self.value}"
[docs]
def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.__str__()})"
[docs]
class MemOrchestrator(BaseOrchestrator):
    """
    A memory-based implementation of the Orchestrator,
    managing task invocations and their lifecycle.

    This class provides an in-memory solution for orchestrating task invocations,
    including blocking controls, as well as caching of invocation statuses and retries.

    ```{warning}
    This orchestrator is not intended for production use.
    As it stores all invocations in the running process memory.
    ```

    :param Pynenc app: The Pynenc application instance.
    """

    def __init__(self, app: "Pynenc") -> None:
        self.app = app
        # Lookup indexes between tasks, calls and invocations
        self.task_id_to_inv_id: dict[TaskId, set[InvocationId]] = defaultdict(set)
        self.call_id_to_inv_id: dict[CallId, set[InvocationId]] = defaultdict(set)
        self.inv_id_to_call_id: dict[InvocationId, CallId] = {}
        # Argument indexes used for concurrency control (per invocation and reverse)
        self.invocation_args: dict[InvocationId, set[ArgPair]] = defaultdict(set)
        self.args_index: dict[ArgPair, set[InvocationId]] = defaultdict(set)
        # Current status record per invocation plus a reverse status -> invocations index
        self.status_index: dict[InvocationStatus, set[InvocationId]] = defaultdict(set)
        self.invocation_status_record: dict[InvocationId, InvocationStatusRecord] = {}
        self.invocation_retries: dict[InvocationId, int] = {}
        # (enqueue time, invocation id) appended in time order, consumed by auto_purge
        self.invocations_to_purge: deque[tuple[float, InvocationId]] = deque()
        # Per-invocation locks serializing status transitions
        self.locks: dict[InvocationId, threading.Lock] = {}
        # Runner heartbeat tracking
        self.runner_creation_time: dict[str, float] = {}
        self.runner_last_heartbeat: dict[str, float] = {}
        self.runner_last_service_start: dict[str, datetime] = {}
        self.runner_last_service_end: dict[str, datetime] = {}
        self.runner_atomic_service_eligible: dict[str, bool] = {}
        self._blocking_control: MemBlockingControl | None = None
        super().__init__(app)

    @property
    def blocking_control(self) -> "MemBlockingControl":
        """Lazily created blocking control shared by this orchestrator."""
        if self._blocking_control is None:
            self._blocking_control = MemBlockingControl(self.app)
        return self._blocking_control

    def _register_new_invocations(
        self,
        invocations: list["DistributedInvocation[Params, Result]"],
        runner_id: str | None = None,
    ) -> InvocationStatusRecord:
        """
        Registers new invocations and sets them to REGISTERED status.

        All invocations share the same status record instance.

        :param invocations: The invocations to register.
        :param runner_id: Optional runner recorded on the status record.
        :return: The REGISTERED status record assigned to every invocation.
        """
        status_record = InvocationStatusRecord(InvocationStatus.REGISTERED, runner_id)
        for invocation in invocations:
            inv_id = invocation.invocation_id
            call = invocation.call
            self._internal_atomic_status_transition(inv_id, None, status_record)
            self.task_id_to_inv_id[call.task.task_id].add(inv_id)
            self.call_id_to_inv_id[call.call_id].add(inv_id)
            self.inv_id_to_call_id[inv_id] = call.call_id
            self.invocation_retries[inv_id] = 0
        return status_record

    def filter_by_key_arguments(
        self, key_arguments: dict[str, str]
    ) -> set["InvocationId"]:
        """
        Filters invocations by key arguments, requiring ALL keys to match.

        :param dict[str, str] key_arguments: Key-value pairs to filter by
        :return: A new set of invocation IDs that match ALL provided key-value pairs
            (empty when ``key_arguments`` is empty)
        """
        if not key_arguments:
            return set()
        candidate_sets: list[set[InvocationId]] = []
        for key, value in key_arguments.items():
            matching_ids = self.args_index.get(ArgPair(key, value))
            # A single pair without matches makes the whole intersection empty
            if not matching_ids:
                return set()
            candidate_sets.append(matching_ids)
        # Intersect starting from the smallest set; intersection() returns a new set,
        # so the index sets are never mutated.
        candidate_sets.sort(key=len)
        return candidate_sets[0].intersection(*candidate_sets[1:])

    def filter_by_statuses(
        self, statuses: list[InvocationStatus]
    ) -> set["InvocationId"]:
        """
        Collects the invocations currently in any of the given statuses.

        :param list[InvocationStatus] statuses: The statuses to match.
        :return: A new set with the union of matching invocation IDs.
        """
        matched_ids: set[InvocationId] = set()
        for status in statuses:
            matched_ids.update(self.status_index.get(status, ()))
        return matched_ids

    def get_existing_invocations(
        self,
        task: "Task[Params, Result]",
        key_serialized_arguments: dict[str, str] | None = None,
        statuses: "list[InvocationStatus] | None" = None,
    ) -> Iterator["InvocationId"]:
        """
        Retrieves invocation ids based on provided key arguments and/or status.

        Results are computed from a snapshot, so the indexes may be modified while
        the returned iterator is being consumed.

        :param Task task: The task whose invocations are searched.
        :param dict[str, str] | None key_serialized_arguments: The key arguments to filter the invocations.
        :param list[InvocationStatus] | None statuses: The statuses to filter the invocations.
        :return: An iterator over the filtered invocations.
        :rtype: Iterator["InvocationId"]
        """
        candidates = set(self.task_id_to_inv_id.get(task.task_id, ()))
        if key_serialized_arguments:
            candidates &= self.filter_by_key_arguments(key_serialized_arguments)
        if statuses:
            candidates &= self.filter_by_statuses(statuses)
        yield from candidates

    def get_task_invocation_ids(self, task_id: "TaskId") -> Iterator["InvocationId"]:
        """
        Retrieves all invocation ids for a given task id.

        :param TaskId task_id: The task id to filter the invocations.
        :return: An iterator over a snapshot of the invocation ids for the specified task.
        :rtype: Iterator["InvocationId"]
        """
        yield from list(self.task_id_to_inv_id.get(task_id, ()))

    def _candidate_invocation_ids(
        self,
        task_id: "TaskId | None",
        statuses: list[InvocationStatus] | None,
    ) -> set["InvocationId"]:
        """Build a new set of invocation IDs matching the optional task and status filters."""
        if task_id:
            candidates = set(self.task_id_to_inv_id.get(task_id, ()))
        else:
            # Collect all invocation IDs across all tasks
            candidates = set().union(*self.task_id_to_inv_id.values())
        if statuses:
            candidates &= self.filter_by_statuses(statuses)
        return candidates

    def get_invocation_ids_paginated(
        self,
        task_id: "TaskId | None" = None,
        statuses: list[InvocationStatus] | None = None,
        limit: int = 100,
        offset: int = 0,
    ) -> list["InvocationId"]:
        """
        Retrieves invocation IDs with pagination support.

        Results are ordered by status timestamp, newest first.

        :param TaskId | None task_id: Optional task ID to filter by.
        :param list[InvocationStatus] | None statuses: Optional statuses to filter by.
        :param int limit: Maximum number of results to return (negative values act as 0).
        :param int offset: Number of results to skip (negative values act as 0).
        :return: List of matching invocation IDs.
        """
        candidates = self._candidate_invocation_ids(task_id, statuses)

        def status_timestamp(inv_id: "InvocationId"):  # type: ignore[no-untyped-def]
            record = self.invocation_status_record.get(inv_id)
            if record is None:
                record = InvocationStatusRecord(InvocationStatus.REGISTERED)
            return record.timestamp

        ordered = sorted(candidates, key=status_timestamp, reverse=True)
        start = max(offset, 0)
        return ordered[start : start + max(limit, 0)]

    def count_invocations(
        self,
        task_id: "TaskId | None" = None,
        statuses: list[InvocationStatus] | None = None,
    ) -> int:
        """
        Counts invocations matching the given filters.

        :param TaskId | None task_id: Optional task ID to filter by.
        :param list[InvocationStatus] | None statuses: Optional statuses to filter by.
        :return: The total count of matching invocations.
        """
        return len(self._candidate_invocation_ids(task_id, statuses))

    def get_call_invocation_ids(self, call_id: "CallId") -> Iterator["InvocationId"]:
        """Retrieves a snapshot of all invocation IDs associated with a specific call ID."""
        yield from list(self.call_id_to_inv_id.get(call_id, ()))

    def set_up_invocation_auto_purge(self, invocation_id: "InvocationId") -> None:
        """
        Sets up an invocation for automatic purging after a specified time.

        :param InvocationId invocation_id: The ID of the invocation to be set up for auto-purge.
        """
        self.invocations_to_purge.append((time(), invocation_id))

    def auto_purge(self) -> None:
        """
        Automatically purges invocations that have been in a final state for longer than a specified duration.

        The queue is filled in time order, so purging stops at the first entry that is still too recent.
        """
        retention_seconds = (
            self.app.orchestrator.conf.auto_final_invocation_purge_hours * 3600
        )
        end_time = time() - retention_seconds
        while self.invocations_to_purge and self.invocations_to_purge[0][0] <= end_time:
            _, invocation_id = self.invocations_to_purge.popleft()
            self.clean_up_invocation(invocation_id)

    @staticmethod
    def _discard_from_index(
        index: dict[Any, set["InvocationId"]],
        key: Any,
        invocation_id: "InvocationId",
    ) -> None:
        """Remove ``invocation_id`` from ``index[key]`` and drop the key once its set is empty."""
        ids = index.get(key)
        if ids is None:
            return
        ids.discard(invocation_id)
        if not ids:
            del index[key]

    def clean_up_invocation(self, invocation_id: "InvocationId") -> None:
        """
        Cleans up an invocation from the cache.

        Removes it from the blocking graph, all indexes, the status/retry maps and
        the per-invocation lock table. Index entries left empty are dropped so the
        maps do not grow without bound, and a missing status record is tolerated.

        :param InvocationId invocation_id: The ID of the invocation to be cleaned up.
        """
        self.release_waiters(invocation_id)
        invocation = self.app.state_backend.get_invocation(invocation_id)
        # Argument pairs from the call plus any recorded at indexing time
        arg_pairs: set[ArgPair] = {
            ArgPair(key, value)
            for key, value in invocation.call.serialized_arguments.items()
        }
        arg_pairs.update(self.invocation_args.pop(invocation_id, ()))
        for arg in arg_pairs:
            self._discard_from_index(self.args_index, arg, invocation_id)
        self._discard_from_index(
            self.task_id_to_inv_id, invocation.task.task_id, invocation_id
        )
        self._discard_from_index(
            self.call_id_to_inv_id, invocation.call.call_id, invocation_id
        )
        self.inv_id_to_call_id.pop(invocation_id, None)
        if (record := self.invocation_status_record.pop(invocation_id, None)) is not None:
            self.status_index[record.status].discard(invocation_id)
        self.invocation_retries.pop(invocation_id, None)
        self.locks.pop(invocation_id, None)

    def _get_invocation_lock(self, invocation_id: "InvocationId") -> threading.Lock:
        """Get or create a per-invocation lock for atomic transitions.

        :param invocation_id: The invocation to get the lock for.
        :return: A threading Lock for the given invocation.
        """
        lock = self.locks.get(invocation_id)
        if lock is None:
            # setdefault stores-and-returns in a single call instead of a separate
            # check-then-set, so racing threads end up sharing the lock stored first.
            lock = self.locks.setdefault(invocation_id, threading.Lock())
        return lock

    def _atomic_status_transition(
        self,
        invocation_id: "InvocationId",
        status: InvocationStatus,
        runner_id: str | None = None,
    ) -> InvocationStatusRecord:
        """Sets the status record of a specific invocation.

        Uses per-invocation locking to prevent TOCTOU race conditions where
        multiple threads could read the same current status, both validate
        their transition, and both write — producing duplicate history entries.
        """
        with self._get_invocation_lock(invocation_id):
            prev_status_record = self.invocation_status_record.get(invocation_id)
            new_record = status_record_transition(prev_status_record, status, runner_id)
            return self._internal_atomic_status_transition(
                invocation_id, prev_status_record, new_record
            )

    def _internal_atomic_status_transition(
        self,
        invocation_id: "InvocationId",
        prev_status_record: InvocationStatusRecord | None,
        new_record: InvocationStatusRecord,
    ) -> InvocationStatusRecord:
        """Stores ``new_record`` and moves the invocation between status index buckets (no locking)."""
        if prev_status_record:
            self.status_index[prev_status_record.status].discard(invocation_id)
        self.status_index[new_record.status].add(invocation_id)
        self.invocation_status_record[invocation_id] = new_record
        return new_record

    # Backward-compatible alias for the previous, misspelled method name
    _interanl_atomic_status_transition = _internal_atomic_status_transition

    def index_arguments_for_concurrency_control(
        self,
        invocation: "DistributedInvocation[Params, Result]",
    ) -> None:
        """
        Indexes the invocation's serialized arguments for key-argument lookups.

        :param invocation: The invocation whose ``call.serialized_arguments`` are indexed.
        """
        inv_id = invocation.invocation_id
        for key, value in invocation.call.serialized_arguments.items():
            pair = ArgPair(key, value)
            self.args_index[pair].add(inv_id)
            # Remember the indexed pairs so clean-up can remove them directly
            self.invocation_args[inv_id].add(pair)

    def get_invocation_status_record(
        self, invocation_id: "InvocationId"
    ) -> InvocationStatusRecord:
        """Retrieves the current status record of an invocation (KeyError if unknown)."""
        return self.invocation_status_record[invocation_id]

    def increment_invocation_retries(self, invocation_id: "InvocationId") -> None:
        """
        Increases the retry count for a given invocation.

        :param InvocationId invocation_id: The ID of the invocation for which the retry count is to be increased.
        """
        self.invocation_retries[invocation_id] = (
            self.invocation_retries.get(invocation_id, 0) + 1
        )

    def get_invocation_retries(self, invocation_id: "InvocationId") -> int:
        """
        Retrieves the current number of retries for a given invocation.

        :param InvocationId invocation_id: The ID of the invocation to get the retry count for.
        :return: The number of retries for the invocation (0 if unknown).
        :rtype: int
        """
        return self.invocation_retries.get(invocation_id, 0)

    def filter_by_status(
        self,
        invocation_ids: list["InvocationId"],
        status_filter: frozenset["InvocationStatus"],
    ) -> list["InvocationId"]:
        """
        Keeps the given invocations whose current status is in ``status_filter``.

        :return: The matching IDs in input order; empty if either argument is empty/None.
        """
        if not invocation_ids or status_filter is None:
            return []
        return [
            inv_id
            for inv_id in invocation_ids
            if self.get_invocation_status(inv_id) in status_filter
        ]

    def register_runner_heartbeats(
        self, runner_ids: list[str], can_run_atomic_service: bool = False
    ) -> None:
        """Register or update heartbeat timestamps for one or more runners."""
        current_time = time()
        for runner_id in runner_ids:
            # Creation time is recorded only the first time a runner is seen;
            # the heartbeat and eligibility are refreshed on every call.
            self.runner_creation_time.setdefault(runner_id, current_time)
            self.runner_last_heartbeat[runner_id] = current_time
            self.runner_atomic_service_eligible[runner_id] = can_run_atomic_service

    def _get_active_runners(
        self, timeout_seconds: float, can_run_atomic_service: bool | None = None
    ) -> list[ActiveRunnerInfo]:
        """Retrieve all active runners with heartbeat information, oldest first."""
        cutoff_time = time() - timeout_seconds
        active_runners = []
        for runner_id, last_heartbeat in self.runner_last_heartbeat.items():
            if last_heartbeat < cutoff_time:
                continue
            allow_to_run_atomic_service = self.runner_atomic_service_eligible[runner_id]
            if (
                can_run_atomic_service is not None
                and allow_to_run_atomic_service != can_run_atomic_service
            ):
                continue
            active_runners.append(
                ActiveRunnerInfo(
                    runner_id=runner_id,
                    creation_time=datetime.fromtimestamp(
                        self.runner_creation_time[runner_id], tz=UTC
                    ),
                    last_heartbeat=datetime.fromtimestamp(last_heartbeat, tz=UTC),
                    allow_to_run_atomic_service=allow_to_run_atomic_service,
                    last_service_start=self.runner_last_service_start.get(runner_id),
                    last_service_end=self.runner_last_service_end.get(runner_id),
                )
            )
        # Sort by creation time (oldest first)
        active_runners.sort(key=lambda info: info.creation_time)
        return active_runners

    def record_atomic_service_execution(
        self, runner_id: str, start_time: datetime, end_time: datetime
    ) -> None:
        """Record the latest atomic service execution window for a runner."""
        self.runner_last_service_start[runner_id] = start_time
        self.runner_last_service_end[runner_id] = end_time

    def get_pending_invocations_for_recovery(self) -> Iterator["InvocationId"]:
        """Retrieve invocation IDs stuck in PENDING status beyond the allowed time."""
        cutoff_time = time() - self.app.conf.max_pending_seconds
        # Create a snapshot to avoid RuntimeError when status changes during iteration
        pending_invocations = list(
            self.status_index.get(InvocationStatus.PENDING, set())
        )
        for invocation_id in pending_invocations:
            status_record = self.invocation_status_record.get(invocation_id)
            if status_record and status_record.timestamp.timestamp() <= cutoff_time:
                yield invocation_id

    def _get_running_invocations_for_recovery(
        self, timeout_seconds: float
    ) -> Iterator["InvocationId"]:
        """Retrieve RUNNING invocation IDs owned by inactive runners."""
        cutoff_time = time() - timeout_seconds
        # Runners with a recent heartbeat are considered alive
        active_runner_ids = {
            runner_id
            for runner_id, last_heartbeat in self.runner_last_heartbeat.items()
            if last_heartbeat >= cutoff_time
        }
        # Create a snapshot to avoid RuntimeError when status changes during iteration
        running_invocations = list(
            self.status_index.get(InvocationStatus.RUNNING, set())
        )
        for invocation_id in running_invocations:
            status_record = self.invocation_status_record.get(invocation_id)
            if (
                status_record
                and status_record.runner_id
                and status_record.runner_id not in active_runner_ids
            ):
                yield invocation_id

    def purge(self) -> None:
        """Drop every piece of in-memory orchestration state, including runner tracking."""
        self._blocking_control = None
        self.task_id_to_inv_id.clear()
        self.inv_id_to_call_id.clear()
        self.call_id_to_inv_id.clear()
        self.invocation_args.clear()
        self.args_index.clear()
        self.status_index.clear()
        self.invocation_status_record.clear()
        self.invocation_retries.clear()
        self.invocations_to_purge.clear()
        self.locks.clear()
        self.runner_creation_time.clear()
        self.runner_last_heartbeat.clear()
        self.runner_last_service_start.clear()
        self.runner_last_service_end.clear()
        self.runner_atomic_service_eligible.clear()