import pickle
import threading
from collections import OrderedDict, defaultdict, deque
from collections.abc import Iterator
from time import time
from typing import TYPE_CHECKING, Any, Generic, Optional
from pynenc.exceptions import CycleDetectedError, PendingInvocationLockError
from pynenc.invocation.status import InvocationStatus
from pynenc.orchestrator.base_orchestrator import (
BaseBlockingControl,
BaseCycleControl,
BaseOrchestrator,
)
from pynenc.types import Params, Result
if TYPE_CHECKING:
from pynenc.app import Pynenc
from pynenc.call import Call
from pynenc.invocation.dist_invocation import DistributedInvocation
from pynenc.task import Task
[docs]
class MemCycleControl(BaseCycleControl):
    """
    An implementation of cycle control using a directed acyclic graph (DAG) to represent call dependencies.

    Nodes of the graph are call ids and an edge ``caller -> callee`` records that the caller's
    call depends on the callee's call. Several invocations may share the same call id, so the
    invocations of every call are tracked separately in ``call_to_invocation``.

    :param Pynenc app: The Pynenc application instance.
    """

    def __init__(self, app: "Pynenc") -> None:
        self.app = app
        # invocation_id -> invocation, for every invocation registered in the graph
        self.invocations: dict[str, "DistributedInvocation"] = {}
        # call_id -> Call, used to translate a cycle of ids into Call objects
        self.calls: dict[str, "Call"] = {}
        # call_id -> (invocation_id -> invocation), insertion ordered
        self.call_to_invocation: dict[
            str, OrderedDict[str, "DistributedInvocation"]
        ] = defaultdict(OrderedDict)
        # caller call_id -> set of callee call_ids
        self.edges: dict[str, set[str]] = defaultdict(set)

    def add_call_and_check_cycles(
        self, caller: "DistributedInvocation", callee: "DistributedInvocation"
    ) -> None:
        """
        Adds a new invocation to the graph, representing a dependency where the caller is dependent on the callee.

        Nothing is registered when a cycle is detected.

        :param DistributedInvocation caller: The invocation making the call.
        :param DistributedInvocation callee: The invocation being called.
        :raises CycleDetectedError: If adding the invocation causes a cycle.
        """
        if caller.call_id == callee.call_id:
            raise CycleDetectedError.from_cycle([caller.call])
        if cycle := self.find_cycle_caused_by_new_invocation(caller, callee):
            raise CycleDetectedError.from_cycle(cycle)
        self.invocations[caller.invocation_id] = caller
        self.invocations[callee.invocation_id] = callee
        self.calls[caller.call_id] = caller.call
        self.calls[callee.call_id] = callee.call
        self.call_to_invocation[caller.call_id][caller.invocation_id] = caller
        self.call_to_invocation[callee.call_id][callee.invocation_id] = callee
        self.edges[caller.call_id].add(callee.call_id)

    def clean_up_invocation_cycles(self, invocation: "DistributedInvocation") -> None:
        """
        Removes an invocation from the graph.

        The edges from and to the invocation's call are only dropped once the last
        invocation registered for that call has been removed.

        :param DistributedInvocation invocation: The invocation to be removed from the graph.
        """
        call_id = invocation.call_id
        # Release the invocation reference so the map does not grow without bound.
        self.invocations.pop(invocation.invocation_id, None)
        if call_id not in self.call_to_invocation:
            return
        remaining = self.call_to_invocation[call_id]
        remaining.pop(invocation.invocation_id, None)
        if remaining:
            # Other invocations of the same call still rely on its edges.
            return
        del self.call_to_invocation[call_id]
        self.edges.pop(call_id, None)
        for targets in self.edges.values():
            targets.discard(call_id)
        # The call no longer touches any edge, so it cannot be part of a reported cycle.
        self.calls.pop(call_id, None)

    def find_cycle_caused_by_new_invocation(
        self, caller: "DistributedInvocation", callee: "DistributedInvocation"
    ) -> list["Call"]:
        """
        Determines if adding a new edge from the caller to the callee would create a cycle in the graph.

        The graph is left exactly as it was found, even when the search raises.

        :param DistributedInvocation caller: The invocation making the call.
        :param DistributedInvocation callee: The invocation being called.
        :return: A list of Calls that would form a cycle after adding the new invocation, else an empty list.
        :rtype: list[Call]
        """
        caller_id = caller.call_id
        callee_id = callee.call_id
        if caller_id == callee_id:
            # A self edge is the smallest cycle; the call may not be registered
            # in ``self.calls`` yet, so answer without consulting the graph.
            return [caller.call]
        had_source = caller_id in self.edges
        had_edge = had_source and callee_id in self.edges[caller_id]
        # Temporarily add the edge to check if it would cause a cycle
        self.edges[caller_id].add(callee_id)
        try:
            return self._is_cyclic_util(caller_id, set(), [])
        finally:
            # Only undo what this method added: an edge that was already part
            # of the graph must survive the check.
            if not had_edge:
                self.edges[caller_id].discard(callee_id)
            if not had_source:
                self.edges.pop(caller_id, None)

    def _is_cyclic_util(
        self,
        current_call_id: str,
        visited: set[str],
        path: list[str],
    ) -> list["Call"]:
        """
        Utility function for cycle detection in the graph (recursive depth-first search).

        :param str current_call_id: The current call ID being checked for cycles.
        :param set[str] visited: Set of already visited call IDs (mutated in place).
        :param list[str] path: Current path of call IDs being traversed (mutated in place).
        :return: A list of Calls that form a cycle, if one is detected, else an empty list.
        :rtype: list[Call]
        """
        visited.add(current_call_id)
        path.append(current_call_id)
        for neighbour_call_id in self.edges.get(current_call_id, ()):
            if neighbour_call_id not in visited:
                cycle = self._is_cyclic_util(neighbour_call_id, visited, path)
                if cycle:
                    return cycle
            elif neighbour_call_id in path:
                # Back edge: the cycle is the tail of the path starting at the neighbour.
                cycle_start_index = path.index(neighbour_call_id)
                return [self.calls[_id] for _id in path[cycle_start_index:]]
        path.pop()
        return []
[docs]
class MemBlockingControl(BaseBlockingControl):
    """
    An implementation of blocking control that tracks which invocations wait for which.

    Two mirrored indexes are kept: ``waiting_for`` (waiter id -> ids it waits for) and
    ``waited_by`` (waited id -> ids of its waiters, in first-registration order).

    :param Pynenc app: The Pynenc application instance.
    """

    def __init__(self, app: "Pynenc") -> None:
        self.app = app
        # waited invocation_id -> invocation (only invocations somebody waits for)
        self.invocations: dict[str, "DistributedInvocation"] = {}
        self.waiting_for: dict[str, set[str]] = defaultdict(set)
        # OrderedDict so that older blocking invocations are reported first
        self.waited_by: dict[str, set[str]] = OrderedDict()

    def waiting_for_results(
        self,
        caller_invocation: "DistributedInvocation[Params, Result]",
        result_invocations: list["DistributedInvocation[Params, Result]"],
    ) -> None:
        """
        Registers that an invocation (waiter) is waiting for the results of other invocations (waited).

        :param DistributedInvocation[Params, Result] caller_invocation: The invocation waiting for results.
        :param list[DistributedInvocation[Params, Result]] result_invocations: The invocations whose results are being waited for.
        """
        waiter_id = caller_invocation.invocation_id
        for waited in result_invocations:
            waited_id = waited.invocation_id
            self.invocations[waited_id] = waited
            self.waiting_for[waiter_id].add(waited_id)
            self.waited_by.setdefault(waited_id, set()).add(waiter_id)

    def release_waiters(self, invocation: "DistributedInvocation") -> None:
        """
        Removes an invocation from the graph, along with any dependencies related to it.

        Both directions are cleaned: invocations that waited for it stop waiting, and
        invocations it was waiting for no longer list it as a waiter.

        :param DistributedInvocation invocation: The invocation that has finished and will no longer block other invocations.
        """
        finished_id = invocation.invocation_id
        # Nobody has to wait for the finished invocation any longer.
        for waiter_id in self.waited_by.pop(finished_id, ()):
            pending = self.waiting_for.get(waiter_id)
            if pending is None:
                continue
            pending.discard(finished_id)
            if not pending:
                del self.waiting_for[waiter_id]
        # The finished invocation does not wait for anything any more; without
        # this its id would linger in ``waited_by`` and keep others flagged as blocking.
        for waited_id in self.waiting_for.pop(finished_id, ()):
            waiters = self.waited_by.get(waited_id)
            if waiters is None:
                continue
            waiters.discard(finished_id)
            if not waiters:
                del self.waited_by[waited_id]
                self.invocations.pop(waited_id, None)
        # Drop the stored reference so the map does not grow without bound.
        self.invocations.pop(finished_id, None)

    def get_blocking_invocations(
        self, max_num_invocations: int
    ) -> Iterator["DistributedInvocation[Params, Result]"]:
        """
        Retrieves invocations that are blocking others but are not themselves waiting for any results.

        Only invocations whose orchestrator status reports ``is_available_for_run()`` are yielded.

        :param int max_num_invocations: The maximum number of blocking invocations to retrieve;
            a value of zero or less yields nothing.
        :return: An iterator over invocations that are blocking others (older firsts).
        :rtype: Iterator[DistributedInvocation[Params, Result]]
        """
        if max_num_invocations <= 0:
            return
        remaining = max_num_invocations
        # Create a snapshot of the keys to avoid mutation during iteration
        for inv_id in list(self.waited_by):
            if inv_id in self.waiting_for:
                continue
            # The snapshot may be stale: the invocation could have been released
            # while this generator was suspended at a previous ``yield``.
            candidate = self.invocations.get(inv_id)
            if candidate is None or inv_id not in self.waited_by:
                continue
            status = self.app.orchestrator.get_invocation_status(candidate)
            if not status.is_available_for_run():
                continue
            yield candidate
            remaining -= 1
            if remaining == 0:
                return
[docs]
class ArgPair:
    """
    Helper to simulate a Memory cache for key:value pairs in Task Invocations.

    Instances are used as dictionary keys, so ``__eq__`` and ``__hash__`` are kept
    consistent: two pairs that compare equal always produce the same hash.

    :param str key: The argument name.
    :param Any value: The argument value (typically an already serialized string).
    """

    def __init__(self, key: str, value: Any) -> None:
        self.key = key
        self.value = value

    def __hash__(self) -> int:
        """
        Generate a hash that is consistent with ``__eq__``.

        :return: The hash of ``(key, value)`` for string values, the hash of ``key`` otherwise.
        :rtype: int
        """
        # For string values (like serialized JSON), use the string value directly
        if isinstance(self.value, str):
            return hash((self.key, self.value))
        # Non-string values may compare equal while their pickles differ
        # (1 == 1.0, dicts with a different insertion order) or may not be
        # picklable at all, so only the key can safely take part in the hash.
        return hash(self.key)

    def __eq__(self, other: Any) -> bool:
        """
        Equality check that works with serialized values.

        :param Any other: The object to compare with.
        :return: True when ``other`` is an ArgPair with the same key and an equivalent value.
        :rtype: bool
        """
        if not isinstance(other, ArgPair):
            return False
        # Basic key comparison
        if self.key != other.key:
            return False
        self_is_str = isinstance(self.value, str)
        other_is_str = isinstance(other.value, str)
        # For string values (likely serialized JSON), direct comparison
        if self_is_str and other_is_str:
            return self.value == other.value
        # A string never matches a non-string; this keeps equality aligned
        # with the two hashing strategies used in ``__hash__``.
        if self_is_str != other_is_str:
            return False
        # For other types, try direct comparison first
        if self.value == other.value:
            return True
        # Last resort: pickle comparison for complex objects
        try:
            return pickle.dumps(self.value) == pickle.dumps(other.value)
        except (pickle.PickleError, TypeError, AttributeError):
            # If pickling fails (local objects raise AttributeError on some
            # interpreter versions), they're not equal
            return False

    def __str__(self) -> str:
        return f"{self.key}:{self.value}"

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.__str__()})"
[docs]
class TaskInvocationCache(Generic[Result]):
    """
    A cache for storing and managing task invocations and their statuses.

    This class provides functionalities to track task invocations, including their arguments,
    statuses, retries, and auto-purge mechanisms.

    :param Pynenc app: The Pynenc application instance.
    """

    def __init__(self, app: "Pynenc") -> None:
        self.app = app
        self.invocations: dict[str, "DistributedInvocation"] = {}
        # (argument name, serialized value) -> ids of invocations called with it
        self.args_index: dict[ArgPair, set[str]] = defaultdict(set)
        self.status_index: dict[InvocationStatus, set[str]] = defaultdict(set)
        # invocation_id -> timestamp at which it entered PENDING
        self.pending_timer: dict[str, float] = {}
        # invocation_id -> status to fall back to when PENDING times out
        self.pre_pending_status: dict[str, InvocationStatus] = {}
        self.invocation_status: dict[str, InvocationStatus] = {}
        self.invocation_retries: dict[str, int] = {}
        # (registration timestamp, invocation_id), oldest first
        self.invocations_to_purge: deque[tuple[float, str]] = deque()
        self.locks: dict[str, threading.Lock] = {}

    @staticmethod
    def _discard_from_index(index: dict, key: Any, invocation_id: str) -> None:
        """Remove an id from ``index[key]`` and drop the entry once it is empty."""
        members = index.get(key)
        if members is None:
            return
        members.discard(invocation_id)
        if not members:
            del index[key]

    def filter_by_key_arguments(self, key_arguments: dict[str, str]) -> set[str]:
        """
        Filters invocations by key arguments, requiring ALL keys to match.

        :param dict[str, str] key_arguments: Key-value pairs to filter by
        :return: Set of invocation IDs that match ALL provided key-value pairs
            (a new set; empty when ``key_arguments`` is empty).
        """
        if not key_arguments:
            return set()
        candidate_sets: list[set[str]] = []
        for key, value in key_arguments.items():
            matching_ids = self.args_index.get(ArgPair(key, value))
            # A single pair without matches makes the whole intersection empty
            if not matching_ids:
                return set()
            candidate_sets.append(matching_ids)
        # ``set.intersection`` returns a new set, so the index is never aliased
        return set.intersection(*candidate_sets)

    def filter_by_statuses(self, statuses: list[InvocationStatus]) -> set[str]:
        """
        Collects the ids of the invocations currently in any of the given statuses.

        :param list[InvocationStatus] statuses: The statuses to look up.
        :return: A new set with the matching invocation IDs.
        """
        matched_ids: set[str] = set()
        for status in statuses:
            # ``get`` keeps a read-only query from inserting empty index entries
            matched_ids.update(self.status_index.get(status, ()))
        return matched_ids

    def get_invocation(self, invocation_id: str) -> Optional["DistributedInvocation"]:
        """Retrieves an invocation by its ID, or None when it is not cached."""
        return self.invocations.get(invocation_id)

    def get_invocations(
        self,
        key_arguments: dict[str, str] | None,
        statuses: list[InvocationStatus] | None,
    ) -> Iterator["DistributedInvocation"]:
        """
        Retrieves invocations based on provided key arguments and/or status.

        An empty or missing filter is ignored; with no filter at all every cached invocation is returned.

        :param dict[str, str] | None key_arguments: The key arguments to filter the invocations.
        :param list[InvocationStatus] | None statuses: The statuses to filter the invocations.
        :return: An iterator over the filtered invocations.
        :rtype: Iterator[DistributedInvocation]
        """
        if key_arguments and statuses:
            key_matches = self.filter_by_key_arguments(key_arguments)
            status_matches = self.filter_by_statuses(statuses)
            invocation_ids = key_matches.intersection(status_matches)
        elif key_arguments:
            invocation_ids = self.filter_by_key_arguments(key_arguments)
        elif statuses:
            invocation_ids = self.filter_by_statuses(statuses)
        else:
            invocation_ids = set(self.invocations)
        for invocation_id in invocation_ids:
            # The ids are a snapshot; skip anything cleaned up while this
            # generator was suspended instead of failing with KeyError.
            invocation = self.invocations.get(invocation_id)
            if invocation is not None:
                yield invocation

    def set_up_invocation_auto_purge(
        self, invocation: "DistributedInvocation[Params, Result]"
    ) -> None:
        """
        Sets up an invocation for automatic purging after a specified time.

        :param DistributedInvocation[Params, Result] invocation: The invocation to be set up for auto-purge.
        """
        self.invocations_to_purge.append((time(), invocation.invocation_id))

    def auto_purge(self) -> None:
        """
        Purges invocations registered for auto-purge longer ago than
        ``auto_final_invocation_purge_hours`` (read from the orchestrator configuration).
        """
        purge_hours = self.app.orchestrator.conf.auto_final_invocation_purge_hours
        end_time = time() - purge_hours * 3600
        # The deque is ordered by registration time, so stop at the first fresh entry
        while self.invocations_to_purge and self.invocations_to_purge[0][0] <= end_time:
            _, invocation_id = self.invocations_to_purge.popleft()
            self.clean_up_invocation(invocation_id)

    def clean_up_invocation(self, invocation_id: str) -> None:
        """
        Cleans up an invocation from the cache, including its index entries and lock.

        :param str invocation_id: The ID of the invocation to be cleaned up.
        """
        invocation = self.invocations.pop(invocation_id, None)
        if invocation is not None:
            for key, value in invocation.serialized_arguments.items():
                self._discard_from_index(
                    self.args_index, ArgPair(key, value), invocation_id
                )
            status = self.invocation_status.get(invocation_id)
            if status is not None:
                self._discard_from_index(self.status_index, status, invocation_id)
        self.invocation_status.pop(invocation_id, None)
        self.invocation_retries.pop(invocation_id, None)
        self.pending_timer.pop(invocation_id, None)
        self.pre_pending_status.pop(invocation_id, None)
        # Locks are created per invocation id; release them with the invocation
        self.locks.pop(invocation_id, None)

    def set_status(
        self,
        invocation: "DistributedInvocation[Params, Result]",
        status: InvocationStatus,
    ) -> None:
        """
        Sets the status of a specific invocation, registering the invocation on first use.

        :param DistributedInvocation[Params, Result] invocation: The invocation whose status is to be set.
        :param InvocationStatus status: The status to set for the invocation.
        """
        if status != InvocationStatus.PENDING:
            self.clean_pending_status(invocation)
        _id = invocation.invocation_id
        if _id not in self.invocations:
            self.invocations[_id] = invocation
            for key, value in invocation.serialized_arguments.items():
                self.args_index[ArgPair(key, value)].add(_id)
        else:
            # already exists, remove previous status
            previous = self.invocation_status.get(_id)
            if previous is not None:
                self._discard_from_index(self.status_index, previous, _id)
        self.status_index[status].add(_id)
        self.invocation_status[_id] = status

    def clean_pending_status(
        self, invocation: "DistributedInvocation[Params, Result]"
    ) -> None:
        """
        Drops the pending bookkeeping (timer and pre-pending status) of an invocation.

        :param DistributedInvocation[Params, Result] invocation: The invocation whose pending status is to be cleaned.
        """
        self.pending_timer.pop(invocation.invocation_id, None)
        self.pre_pending_status.pop(invocation.invocation_id, None)

    def set_pending_status(
        self, invocation: "DistributedInvocation[Params, Result]"
    ) -> None:
        """
        Sets the status of an invocation to pending, handling any potential locking issues.

        :param DistributedInvocation[Params, Result] invocation: The invocation to set to pending status.
        :raises PendingInvocationLockError: If the invocation is already in pending status or cannot acquire a lock.
        :raises KeyError: If no status has been recorded for the invocation yet.
        """
        invocation_id = invocation.invocation_id
        lock = self.locks.setdefault(invocation_id, threading.Lock())
        if not lock.acquire(blocking=False):
            raise PendingInvocationLockError(invocation_id)
        try:
            previous_status = self.invocation_status[invocation_id]
            if previous_status == InvocationStatus.PENDING:
                raise PendingInvocationLockError(invocation_id)
            # Start the timer only once the transition is certain; a rejected
            # attempt must not extend the timeout of the current pending period.
            self.pending_timer[invocation_id] = time()
            self.pre_pending_status[invocation_id] = previous_status
            self.set_status(invocation, InvocationStatus.PENDING)
        finally:
            lock.release()

    def get_status(
        self, invocation: "DistributedInvocation[Params, Result]"
    ) -> InvocationStatus:
        """
        Retrieves the current status of an invocation, accounting for pending timeout.

        A PENDING invocation older than ``app.conf.max_pending_seconds`` is reverted to the
        status it had before becoming pending, and that status is returned.

        :param DistributedInvocation[Params, Result] invocation: The invocation to get the status for.
        :return: The current status of the invocation.
        :rtype: InvocationStatus
        :raises KeyError: If no status has been recorded for the invocation.
        """
        invocation_id = invocation.invocation_id
        status = self.invocation_status[invocation_id]
        if status != InvocationStatus.PENDING:
            return status
        started = self.pending_timer.get(invocation_id)
        pre_pending_status = self.pre_pending_status.get(invocation_id)
        if started is None or pre_pending_status is None:
            # PENDING was set without the timer bookkeeping (plain ``set_status``);
            # there is nothing to time out against, so report it unchanged.
            return status
        if time() - started > self.app.conf.max_pending_seconds:
            self.set_status(invocation, pre_pending_status)
            return pre_pending_status
        return status

    def increase_retries(
        self, invocation: "DistributedInvocation[Params, Result]"
    ) -> None:
        """
        Increases the retry count for a given invocation.

        :param DistributedInvocation[Params, Result] invocation: The invocation for which the retry count is to be increased.
        """
        invocation_id = invocation.invocation_id
        self.invocation_retries[invocation_id] = (
            self.invocation_retries.get(invocation_id, 0) + 1
        )

    def get_retries(self, invocation: "DistributedInvocation[Params, Result]") -> int:
        """
        Retrieves the current number of retries for a given invocation.

        :param DistributedInvocation[Params, Result] invocation: The invocation to get the retry count for.
        :return: The number of retries for the invocation (0 when none were recorded).
        :rtype: int
        """
        return self.invocation_retries.get(invocation.invocation_id, 0)
[docs]
class MemOrchestrator(BaseOrchestrator):
    """
    In-memory Orchestrator implementation.

    Keeps one :class:`TaskInvocationCache` per task id and creates the cycle and
    blocking controls on first access. Every operation is delegated to the cache
    of the task the invocation belongs to.

    ```{warning}
    This orchestrator is not intended for production use.
    As it stores all invocations in the running process memory.
    ```

    :param Pynenc app: The Pynenc application instance.
    """

    def __init__(self, app: "Pynenc") -> None:
        def new_task_cache() -> TaskInvocationCache:
            return TaskInvocationCache(app)

        # task_id -> cache, created on first access
        self.cache: dict[str, TaskInvocationCache] = defaultdict(new_task_cache)
        self._cycle_control: MemCycleControl | None = None
        self._blocking_control: MemBlockingControl | None = None
        super().__init__(app)

    @property
    def cycle_control(self) -> "MemCycleControl":
        """The cycle control, instantiated the first time it is requested."""
        control = self._cycle_control
        if control is None:
            control = MemCycleControl(self.app)
            self._cycle_control = control
        return control

    @property
    def blocking_control(self) -> "MemBlockingControl":
        """The blocking control, instantiated the first time it is requested."""
        control = self._blocking_control
        if control is None:
            control = MemBlockingControl(self.app)
            self._blocking_control = control
        return control

    def get_existing_invocations(
        self,
        task: "Task[Params, Result]",
        key_serialized_arguments: dict[str, str] | None = None,
        statuses: list[InvocationStatus] | None = None,
    ) -> Iterator["DistributedInvocation"]:
        """Returns the task's cached invocations filtered by arguments and/or statuses."""
        task_cache = self.cache[task.task_id]
        return task_cache.get_invocations(key_serialized_arguments, statuses)

    def get_invocation(self, invocation_id: str) -> Optional["DistributedInvocation"]:
        """Looks the invocation id up in every task cache; None when no cache holds it."""
        for task_cache in self.cache.values():
            found = task_cache.get_invocation(invocation_id)
            if found:
                return found
        return None

    def _set_invocation_status(
        self,
        invocation: "DistributedInvocation[Params, Result]",
        status: InvocationStatus,
    ) -> None:
        """Stores the status in the cache of the invocation's task."""
        task_cache = self.cache[invocation.task.task_id]
        task_cache.set_status(invocation, status)

    def _set_invocations_status(
        self, invocations: list["DistributedInvocation"], status: InvocationStatus
    ) -> None:
        """
        Applies one status to every invocation in the list.

        :param list[DistributedInvocation] invocations: The invocations to update.
        :param InvocationStatus status: The status to set.
        """
        for current in invocations:
            task_cache = self.cache[current.task.task_id]
            task_cache.set_status(current, status)

    def _set_invocation_pending_status(
        self, invocation: "DistributedInvocation[Params, Result]"
    ) -> None:
        """Moves the invocation to pending through the cache of its task."""
        task_cache = self.cache[invocation.task.task_id]
        task_cache.set_pending_status(invocation)

    def set_up_invocation_auto_purge(
        self, invocation: "DistributedInvocation[Params, Result]"
    ) -> None:
        """Registers the invocation for auto-purge in the cache of its task."""
        task_cache = self.cache[invocation.task.task_id]
        task_cache.set_up_invocation_auto_purge(invocation)

    def auto_purge(self) -> None:
        """Runs the auto-purge of every task cache."""
        for task_cache in self.cache.values():
            task_cache.auto_purge()

    def get_invocation_status(
        self, invocation: "DistributedInvocation[Params, Result]"
    ) -> InvocationStatus:
        """Reads the invocation status from the cache of its task."""
        task_cache = self.cache[invocation.task.task_id]
        return task_cache.get_status(invocation)

    def increment_invocation_retries(
        self, invocation: "DistributedInvocation[Params, Result]"
    ) -> None:
        """Adds one to the retry counter kept in the cache of the invocation's task."""
        task_cache = self.cache[invocation.task.task_id]
        task_cache.increase_retries(invocation)

    def get_invocation_retries(
        self, invocation: "DistributedInvocation[Params, Result]"
    ) -> int:
        """Reads the retry counter from the cache of the invocation's task."""
        task_cache = self.cache[invocation.task.task_id]
        return task_cache.get_retries(invocation)

    def filter_by_status(
        self,
        invocations: list["DistributedInvocation"],
        status_filter: set["InvocationStatus"] | None = None,
    ) -> list["DistributedInvocation"]:
        """
        Keeps the invocations whose current status belongs to ``status_filter``.

        An empty invocation list or a ``None`` filter yields an empty list.
        """
        if status_filter is None or not invocations:
            return []
        selected: list["DistributedInvocation"] = []
        for candidate in invocations:
            if self.get_invocation_status(candidate) in status_filter:
                selected.append(candidate)
        return selected

    def purge(self) -> None:
        """Drops every task cache and discards both controls."""
        self.cache.clear()
        self._cycle_control = None
        self._blocking_control = None