import json
from functools import cached_property
from typing import TYPE_CHECKING, Any, Iterator
import redis
from pynenc import exceptions
from pynenc.conf.config_redis import ConfigRedis
from pynenc.conf.config_state_backend import ConfigStateBackendRedis
from pynenc.invocation.dist_invocation import DistributedInvocation
from pynenc.state_backend.base_state_backend import BaseStateBackend, InvocationHistory
from pynenc.util.redis_client import get_redis_client
from pynenc.util.redis_keys import Key
from pynenc.workflow import WorkflowIdentity
if TYPE_CHECKING:
from pynenc.app import AppInfo, Pynenc
from pynenc.types import Params, Result
[docs]
class RedisStateBackend(BaseStateBackend):
    """
    A Redis-based implementation of the state backend.

    This backend uses Redis to store and retrieve the state of invocations, including
    their data, history, results, and exceptions. It's suitable for distributed systems
    where shared state management is required.

    Keys are built through a :class:`Key` helper created from ``app.app_id`` and the
    ``"state_backend"`` namespace; :meth:`purge` delegates to that helper to clear the
    data for the current app.
    """

    def __init__(self, app: "Pynenc") -> None:
        super().__init__(app)
        # The Redis connection is not opened here; see the ``client`` property.
        self._client: redis.Redis | None = None
        self.key = Key(app.app_id, "state_backend")

    @cached_property
    def conf(self) -> ConfigStateBackendRedis:
        """Redis state-backend configuration, built once from the app's config values/file."""
        return ConfigStateBackendRedis(
            config_values=self.app.config_values,
            config_filepath=self.app.config_filepath,
        )

    @property
    def client(self) -> redis.Redis:
        """Lazy initialization of Redis client: created on first access, then reused."""
        if self._client is None:
            self._client = get_redis_client(self.conf)
        return self._client

    def purge(self) -> None:
        """Clears all data from the Redis backend for the current `app.app_id`."""
        self.key.purge(self.client)

    def _upsert_invocation(self, invocation: "DistributedInvocation") -> None:
        """
        Inserts or updates an invocation in Redis.

        The invocation is stored as its JSON representation under a key derived from
        its ``invocation_id``; storing the same id again overwrites the previous value.

        :param DistributedInvocation invocation: The invocation object to upsert.
        """
        self.client.set(
            self.key.invocation(invocation.invocation_id), invocation.to_json()
        )

    def _get_invocation(
        self, invocation_id: str
    ) -> "DistributedInvocation[Params, Result]":
        """
        Retrieves an invocation from Redis by its ID.

        :param str invocation_id: The ID of the invocation to retrieve.
        :return: The retrieved invocation object.
        :raises KeyError: If no invocation is stored under ``invocation_id``.
        """
        inv = self.client.get(self.key.invocation(invocation_id))
        # Redis returns None for a missing key; test for that explicitly instead of
        # truthiness so an empty stored value is not mistaken for "missing".
        if inv is None:
            raise KeyError(f"Invocation {invocation_id} not found")
        return DistributedInvocation.from_json(self.app, inv.decode())

    def _add_history(
        self,
        invocation: "DistributedInvocation",
        invocation_history: "InvocationHistory",
    ) -> None:
        """
        Adds a history record to an invocation in Redis.

        Records are appended (RPUSH) to a per-invocation Redis list as JSON.

        :param DistributedInvocation invocation: The invocation to add history for.
        :param InvocationHistory invocation_history: The history record to add.
        """
        self.client.rpush(
            self.key.history(invocation.invocation_id),
            invocation_history.to_json(),
        )

    def _get_history(
        self, invocation: "DistributedInvocation[Params, Result]"
    ) -> list[InvocationHistory]:
        """
        Retrieves the history of an invocation from Redis.

        Because records are appended with RPUSH and read with ``LRANGE 0 -1``, they are
        returned in the order they were added.

        :param DistributedInvocation invocation: The invocation to get the history for.
        :return: A list of invocation history records (empty if none were stored).
        """
        return [
            InvocationHistory.from_json(h.decode())
            for h in self.client.lrange(
                self.key.history(invocation.invocation_id), 0, -1
            )
        ]

    def _set_result(
        self, invocation: "DistributedInvocation[Params, Result]", result: "Result"
    ) -> None:
        """
        Sets the result for an invocation in Redis.

        :param DistributedInvocation invocation: The invocation to set the result for.
        :param Result result: The result of the invocation, encoded with the app serializer.
        """
        self.client.set(
            self.key.result(invocation.invocation_id),
            self.app.serializer.serialize(result),
        )

    def _get_result(
        self, invocation: "DistributedInvocation[Params, Result]"
    ) -> "Result":
        """
        Retrieves the result of an invocation from Redis.

        :param DistributedInvocation invocation: The invocation to get the result for.
        :return: The result of the invocation, decoded with the app serializer.
        :raises KeyError: If no result has been stored for the invocation.
        """
        res = self.client.get(self.key.result(invocation.invocation_id))
        # Explicit None check: a stored-but-empty serialized value is still a result.
        if res is None:
            raise KeyError(
                f"Result for invocation {invocation.invocation_id} not found"
            )
        return self.app.serializer.deserialize(res.decode())

    def _set_exception(
        self,
        invocation: "DistributedInvocation[Params, Result]",
        exception: "Exception",
    ) -> None:
        """
        Sets the exception for an invocation in Redis.

        The exception is wrapped in a JSON envelope with ``error_name``,
        ``pynenc_error`` and ``error_data`` fields so :meth:`_get_exception` can pick
        the matching deserialization path.

        :param DistributedInvocation invocation: The invocation to set the exception for.
        :param Exception exception: The exception to set.
        """
        serialized_exception: dict[str, str | bool] = {
            "error_name": exception.__class__.__name__
        }
        if isinstance(exception, exceptions.PynencError):
            # Pynenc errors provide their own JSON form and are rebuilt from
            # ``error_name`` + ``error_data`` by ``PynencError.from_json``.
            serialized_exception["pynenc_error"] = True
            serialized_exception["error_data"] = exception.to_json()
        else:
            # Any other exception type goes through the app's configured serializer.
            serialized_exception["pynenc_error"] = False
            serialized_exception["error_data"] = self.app.serializer.serialize(
                exception
            )
        self.client.set(
            self.key.exception(invocation.invocation_id),
            json.dumps(serialized_exception),
        )

    def _get_exception(
        self, invocation: "DistributedInvocation[Params, Result]"
    ) -> Exception:
        """
        Retrieves the exception of an invocation from Redis.

        :param DistributedInvocation invocation: The invocation to get the exception for.
        :return: The exception of the invocation.
        :raises KeyError: If no exception has been stored for the invocation.
        """
        exc = self.client.get(self.key.exception(invocation.invocation_id))
        if exc is None:
            raise KeyError(
                f"Exception for invocation {invocation.invocation_id} not found"
            )
        serialized_exception = json.loads(exc.decode())
        if serialized_exception["pynenc_error"]:
            return exceptions.PynencError.from_json(
                serialized_exception["error_name"],
                serialized_exception["error_data"],
            )
        return self.app.serializer.deserialize(serialized_exception["error_data"])

    def get_workflow_data(
        self, workflow_identity: "WorkflowIdentity", key: str, default: Any = None
    ) -> Any:
        """
        Get a value from workflow data.

        :param workflow_identity: Workflow identity
        :param key: Data key to retrieve
        :param default: Default value if key doesn't exist
        :return: Stored value (decoded with the app serializer) or ``default``
        """
        data_key = self.key.workflow_data_value(workflow_identity.workflow_id, key)
        serialized_value = self.client.get(data_key)
        if serialized_value is None:
            return default
        return self.app.serializer.deserialize(serialized_value.decode())

    def set_workflow_data(
        self, workflow_identity: "WorkflowIdentity", key: str, value: Any
    ) -> None:
        """
        Set a value in workflow data, overwriting any previous value for ``key``.

        :param workflow_identity: Workflow identity
        :param key: Data key to set
        :param value: Value to store (encoded with the app serializer)
        """
        data_key = self.key.workflow_data_value(workflow_identity.workflow_id, key)
        serialized_value = self.app.serializer.serialize(value)
        self.client.set(data_key, serialized_value)

    def get_workflow_deterministic_value(
        self, workflow: "WorkflowIdentity", key: str
    ) -> Any:
        """
        Retrieve a deterministic value for workflow operations.

        :param workflow: The workflow identity
        :param key: Key identifying the deterministic value
        :return: The stored value or None if not found
        """
        deterministic_key = self.key.workflow_deterministic_value(
            workflow.workflow_id, key
        )
        value = self.client.get(deterministic_key)
        if value is None:
            return None
        return self.app.serializer.deserialize(value.decode())

    def set_workflow_deterministic_value(
        self, workflow: "WorkflowIdentity", key: str, value: Any
    ) -> None:
        """
        Store a deterministic value for workflow operations.

        :param workflow: The workflow identity
        :param key: Key identifying the deterministic value
        :param value: The value to store (must be serializable by the app serializer)
        """
        deterministic_key = self.key.workflow_deterministic_value(
            workflow.workflow_id, key
        )
        serialized_value = self.app.serializer.serialize(value)
        self.client.set(deterministic_key, serialized_value)

    def store_app_info(self, app_info: "AppInfo") -> None:
        """
        Register this app's information in the state backend for discovery.

        The info is stored as JSON under a key derived from ``app_info.app_id``.

        :param app_info: The app information to store
        """
        self.client.set(self.key.all_apps_info_key(app_info.app_id), app_info.to_json())

    def get_app_info(self) -> "AppInfo":
        """
        Retrieve information of the current app.

        :return: The app information
        :raises ValueError: If app info is not found
        """
        # AppInfo is imported only under TYPE_CHECKING at module level; it must be
        # imported here to be available at runtime (mirrors get_all_app_infos).
        from pynenc.app import AppInfo

        app_info_data = self.client.get(self.key.all_apps_info_key(self.app.app_id))
        if not app_info_data:
            raise ValueError(f"No app info found for app_id '{self.app.app_id}'")
        return AppInfo.from_json(app_info_data.decode())

    @staticmethod
    def get_all_app_infos() -> dict[str, "AppInfo"]:
        """
        Retrieve all app information registered in this state backend.

        Being a staticmethod, this connects with a default ``ConfigRedis()`` rather than
        any particular app's configuration.
        NOTE(review): confirm the default Redis settings are the intended target here.

        :return: Dictionary mapping app_id to app information
        """
        from pynenc.app import AppInfo

        redis_client = get_redis_client(ConfigRedis())
        pattern = Key.all_apps_info_key("*")
        result: dict[str, "AppInfo"] = {}
        # SCAN iterates incrementally instead of blocking the server like KEYS does.
        for key in redis_client.scan_iter(match=pattern):
            key_str = key.decode() if isinstance(key, bytes) else key
            app_info_data = redis_client.get(key_str)
            # The key may have been deleted between SCAN and GET.
            if not app_info_data:
                continue
            app_info = AppInfo.from_json(app_info_data.decode())
            # store_app_info keys the entry by app_info.app_id, so use it directly;
            # splitting the Redis key on ":" would truncate ids containing ":".
            result[app_info.app_id] = app_info
        return result

    def store_workflow_run(self, workflow_identity: "WorkflowIdentity") -> None:
        """
        Store a workflow run for tracking and monitoring.

        Maintains a set of workflow types (``workflow_task_id``) and, per type, a list
        of run identities. Runs are pushed to the head of the list (LPUSH), so the most
        recent run comes first when the list is read.

        :param workflow_identity: The workflow identity to store
        """
        # Both keys come from the Key helper, so purge() also covers them.
        workflow_types_key = self.key.workflow_types()
        self.client.sadd(workflow_types_key, workflow_identity.workflow_task_id)
        workflow_runs_key = self.key.workflow_runs(workflow_identity.workflow_task_id)
        self.client.lpush(workflow_runs_key, workflow_identity.to_json())

    def get_all_workflows(self) -> Iterator[str]:
        """
        Retrieve all workflow types (workflow_task_ids) stored in this Redis state backend.

        The types are read from a Redis set, so their order is not specified.

        :return: Iterator of workflow task IDs representing different workflow types (task_ids)
        """
        workflow_types_key = self.key.workflow_types()
        workflow_types = self.client.smembers(workflow_types_key)
        return (wt.decode() for wt in workflow_types)

    def get_all_workflows_runs(self) -> Iterator["WorkflowIdentity"]:
        """
        Retrieve workflow run identities from this Redis state backend.

        :return: Iterator of workflow identities for runs of every known workflow type
        """
        for workflow_types_task_id in self.get_all_workflows():
            yield from self.get_workflow_runs(workflow_types_task_id)

    def get_workflow_runs(self, workflow_task_id: str) -> Iterator["WorkflowIdentity"]:
        """
        Retrieve workflow run identities from this Redis state backend with pagination.

        Reads the runs list in ``LRANGE`` windows of ``conf.pagination_batch_size``
        entries, so the full list is never loaded at once. Runs come newest-first
        because they are stored with LPUSH.

        NOTE: runs pushed while this generator is being consumed shift list indices,
        so an entry may be yielded more than once during concurrent writes.

        :param workflow_task_id: Filter for specific workflow type
        :return: Iterator of workflow identities for runs
        """
        workflow_runs_key = self.key.workflow_runs(workflow_task_id)
        # A batch size <= 0 would make the end index -1 (whole list) while ``start``
        # never advances, looping forever; clamp to at least one entry per page.
        batch_size = max(1, self.conf.pagination_batch_size)
        start = 0
        while batch_data := self.client.lrange(
            workflow_runs_key, start, start + batch_size - 1
        ):
            for run_data in batch_data:
                yield WorkflowIdentity.from_json(run_data.decode())
            start += batch_size

    def store_workflow_sub_invocation(
        self, parent_workflow_id: str, sub_invocation_id: str
    ) -> None:
        """
        Store a sub-invocation ID that runs inside a parent workflow.

        IDs are kept in a Redis set, so storing the same ID twice has no extra effect.

        :param parent_workflow_id: The workflow ID that contains the sub-invocation
        :param sub_invocation_id: The invocation ID of the task/sub-workflow running inside
        """
        sub_invocations_key = self.key.workflow_sub_invocations(parent_workflow_id)
        self.client.sadd(sub_invocations_key, sub_invocation_id)

    def get_workflow_sub_invocations(self, workflow_id: str) -> Iterator[str]:
        """
        Retrieve all sub-invocation IDs that run inside a specific workflow.

        :param workflow_id: The workflow ID to get sub-invocations for
        :return: Iterator of invocation IDs that run inside the workflow (unordered)
        """
        sub_invocations_key = self.key.workflow_sub_invocations(workflow_id)
        sub_invocation_ids = self.client.smembers(sub_invocations_key)
        return (sid.decode() for sid in sub_invocation_ids)