import json
from dataclasses import dataclass, field
from functools import cached_property
from typing import TYPE_CHECKING, Any, Generic
from pynenc.arguments import Arguments
from pynenc.conf.config_task import ConcurrencyControlType
from pynenc.types import Params, Result
if TYPE_CHECKING:
from pynenc.app import Pynenc
from pynenc.task import Task
@dataclass
class Call(Generic[Params, Result]):
    """
    Base class for task calls with common functionality.

    A call binds a task to a concrete set of arguments. Its identity
    (:attr:`call_id`, ``__eq__`` and ``__hash__``) is derived from the task id
    and the arguments' ``args_id``.

    :param Task[Params, Result] task: The task associated with the call.
    """

    task: "Task[Params, Result]"
    _arguments: "Arguments" = field(default_factory=Arguments, repr=False)
    # Optional already-serialized form of the arguments (set by ``from_json`` and
    # on unpickling). When present it is returned by ``serialized_arguments``
    # instead of re-serializing ``_arguments`` through the app's arg cache.
    _serialized_arguments: dict[str, str] | None = None

    @property
    def app(self) -> "Pynenc":
        """
        Gets the Pynenc application instance associated with the task.

        :return: The Pynenc application instance.
        """
        return self.task.app

    @property
    def arguments(self) -> "Arguments":
        """
        Get the arguments for this call.

        This property allows subclasses to override argument handling.

        :return: Arguments object containing call arguments
        """
        return self._arguments

    @property
    def call_id(self) -> str:
        """
        Generates a unique identifier for the call based on the task ID and the arguments.

        :return: A string representing the unique identifier of the call.
        """
        return (
            "#task_id#" + self.task.task_id + "#args_id#" + self.arguments.args_id
        )

    @cached_property
    def serialized_arguments(self) -> dict[str, str]:
        """
        Serializes the call arguments into strings.

        If a serialized form was supplied (``_serialized_arguments``) it is
        returned unchanged. Otherwise each keyword argument is serialized with
        ``app.arg_cache.serialize``. Caching is disabled for arguments listed in
        the task's ``disable_cache_args``, or for all of them when that list
        contains ``"*"``.

        :return: A dictionary of serialized argument strings.
        """
        # ``is not None`` so an explicitly supplied empty mapping is honoured
        # instead of being treated as "not provided".
        if self._serialized_arguments is not None:
            return self._serialized_arguments
        disabled_args = self.task.conf.disable_cache_args
        disable_all = "*" in disabled_args
        return {
            key: self.app.arg_cache.serialize(
                value, disable_all or key in disabled_args
            )
            for key, value in self.arguments.kwargs.items()
        }

    @cached_property
    def serialized_args_for_concurrency_check(self) -> dict[str, str] | None:
        """
        Determines the call arguments required for the task concurrency check.

        - ``ARGUMENTS``: all serialized arguments.
        - ``KEYS``: only the serialized arguments named in ``key_arguments``.
        - ``DISABLED`` / ``TASK`` (or any other mode): ``None``.

        :return: A dictionary of serialized argument strings required for concurrency control,
            or None if the configured mode does not depend on argument values.
        """
        mode = self.task.conf.registration_concurrency
        if mode == ConcurrencyControlType.ARGUMENTS:
            return self.serialized_arguments
        if mode == ConcurrencyControlType.KEYS:
            serialized = self.serialized_arguments
            return {key: serialized[key] for key in self.task.conf.key_arguments}
        # DISABLED and TASK-level control do not need argument values.
        return None

    def deserialize_arguments(self, serialized_arguments: dict[str, str]) -> Arguments:
        """
        Deserializes the given serialized arguments.

        :param dict[str, str] serialized_arguments: The serialized arguments to deserialize.
        :return: An Arguments object representing the deserialized arguments.
        """
        return Arguments(
            {
                k: self.app.arg_cache.deserialize(v)
                for k, v in serialized_arguments.items()
            }
        )

    def to_json(self) -> str:
        """
        Serializes the call into a JSON string.

        :return: A JSON string with the serialized task under ``"task"`` and the
            serialized arguments under ``"arguments"``.
        """
        return json.dumps(
            {"task": self.task.to_json(), "arguments": self.serialized_arguments}
        )

    def __getstate__(self) -> dict:
        """
        Gets the state of the Call object for serialization purposes.

        Arguments are stored in serialized form.

        :return: A dictionary representing the state of the Call object.
        """
        return {"task": self.task, "arguments": self.serialized_arguments}

    def __setstate__(self, state: dict) -> None:
        """
        Sets the state of the Call object from the provided dictionary.

        :param dict state: A dictionary as produced by ``__getstate__``.
        """
        # ``task`` must be restored first: deserialization goes through self.app.
        object.__setattr__(self, "task", state["task"])
        serialized = state["arguments"]
        object.__setattr__(self, "_arguments", self.deserialize_arguments(serialized))
        # Keep the received serialized form (as ``from_json`` does) so it is not
        # re-serialized through the arg cache on the next access.
        object.__setattr__(self, "_serialized_arguments", serialized)

    @classmethod
    def from_json(cls, app: "Pynenc", serialized: str) -> "Call":
        """
        Creates a Call object from a serialized JSON string.

        :param Pynenc app: The Pynenc application instance.
        :param str serialized: The serialized JSON string representing the call.
        :return: A Call object created from the serialized data.
        """
        from pynenc.task import Task  # local import avoids a circular import

        call_dict = json.loads(serialized)
        serialized_arguments = call_dict["arguments"]
        return cls(
            task=Task.from_json(app, call_dict["task"]),
            _arguments=Arguments(
                {
                    k: app.arg_cache.deserialize(v)
                    for k, v in serialized_arguments.items()
                }
            ),
            _serialized_arguments=serialized_arguments,
        )

    def __str__(self) -> str:
        return f"Call(call_id={self.call_id}, task={self.task}, arguments={self.arguments})"

    def __repr__(self) -> str:
        return self.__str__()

    def __hash__(self) -> int:
        return hash(self.call_id)

    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, Call):
            return False
        return self.call_id == other.call_id
@dataclass
class PreSerializedCall(Call[Params, Result]):
    """
    Represents a call optimized for parallel routing with pre-serialized arguments.

    This call type is used for batch processing tasks with disabled concurrency control,
    where some arguments are pre-serialized (e.g., large shared data).

    :param Task[Params, Result] task: The task associated with the call.
    :param dict[str, Any] other_args: Unique arguments for this specific call.
    :param dict[str, str] pre_serialized_args: Pre-serialized common arguments.
    """

    other_args: dict[str, Any] = field(default_factory=dict)
    pre_serialized_args: dict[str, str] = field(default_factory=dict)
    # Lazily built equivalent regular Call (see the ``call`` property).
    _call: Call[Params, Result] | None = field(default=None, repr=False)
    # Memoized deserialized arguments (see the ``arguments`` property).
    _cached_arguments: Arguments | None = field(default=None, repr=False)
    # Memoized merged serialized arguments (see ``serialized_arguments``).
    _serialized_arguments: dict[str, str] | None = field(default=None, repr=False)
    _args_hash: str | None = field(default=None, repr=False)

    @property
    def call(self) -> "Call[Params, Result]":
        """
        Lazily build and memoize an equivalent regular :class:`Call`.

        A warning is logged the first time, because this is inefficient.

        :return: A Call for the same task and (deserialized) arguments.
        """
        if self._call is None:
            self.app.logger.warning(
                "Generating a regular Call object from a PreSerializedCall "
                "is inefficient and should be avoided if possible. "
            )
            # Reuse the memoized deserialized arguments and the serialized form
            # instead of deserializing again and later re-serializing the
            # (potentially large) pre-serialized data.
            self._call = Call(
                task=self.task,
                _arguments=self.arguments,
                _serialized_arguments=self.serialized_arguments,
            )
        return self._call

    @property
    def arguments(self) -> "Arguments":
        """
        Deserialize (once) and return the full arguments of this call.

        :return: Arguments built from ``serialized_arguments``.
        """
        # ``is None`` rather than truthiness: the Arguments object's boolean
        # value is not what decides whether it has been computed.
        if self._cached_arguments is None:
            self._cached_arguments = self.deserialize_arguments(
                self.serialized_arguments
            )
        return self._cached_arguments

    @property
    def call_id(self) -> str:
        """
        Get the call_id from the underlying Call object.

        :return: A string representing the unique identifier of the call.
        """
        return self.call.call_id

    @property
    def serialized_arguments(self) -> dict[str, str]:
        """
        Serializes the call arguments into strings.

        ``other_args`` not already present in ``pre_serialized_args`` are
        serialized through the app's arg cache. Caching is disabled for
        arguments in the task's ``disable_cache_args``, or for all of them when
        it contains ``"*"``. The result is then merged with
        ``pre_serialized_args``; on a key clash the pre-serialized value wins.
        The merged mapping is memoized.

        :return: A dictionary of serialized argument strings.
        """
        if self._serialized_arguments is None:
            disabled_args = self.task.conf.disable_cache_args
            disable_all = "*" in disabled_args
            serialized = {
                k: self.app.arg_cache.serialize(
                    v, disable_all or k in disabled_args
                )
                for k, v in self.other_args.items()
                if k not in self.pre_serialized_args
            }
            self._serialized_arguments = {**serialized, **self.pre_serialized_args}
        return self._serialized_arguments

    @property
    def serialized_args_for_concurrency_check(self) -> dict[str, str] | None:
        """
        Not supported: this call type is meant for batch routing with
        concurrency control disabled.

        :raises NotImplementedError: Always.
        """
        raise NotImplementedError(
            "PreSerializedCall does not support serialized_args_for_concurrency_check "
            "(intended for batch routing only)"
        )

    def __getstate__(self) -> dict:
        """
        Gets the picklable state of the call.

        ``other_args`` are stored in serialized form. ``pre_serialized_args``
        are stored as-is.

        :return: A dictionary representing the state of the call.
        """
        return {
            "task": self.task,
            "other_args": {
                k: v
                for k, v in self.serialized_arguments.items()
                if k in self.other_args
            },
            "pre_serialized_args": self.pre_serialized_args,
        }

    def __setstate__(self, state: dict) -> None:
        """
        Restores the call from a dictionary produced by ``__getstate__``.

        :param dict state: The state to restore.
        """
        # ``task`` must be restored first: deserialization goes through self.app.
        object.__setattr__(self, "task", state["task"])
        other_args = self.deserialize_arguments(state["other_args"]).kwargs
        object.__setattr__(self, "other_args", other_args)
        object.__setattr__(self, "pre_serialized_args", state["pre_serialized_args"])

    @classmethod
    def from_json(cls, app: "Pynenc", serialized: str) -> "Call":
        """
        PreSerializedCall doesn't support from_json as it's meant for batch routing.

        :raises NotImplementedError: This method is not implemented for PreSerializedCall.
        """
        raise NotImplementedError(
            "PreSerializedCall does not support from_json method "
            "(Use a regular Call object for serialization)"
        )

    def __str__(self) -> str:
        # Only the keys of the pre-serialized args are shown: their values may be large.
        return (
            f"PreSerializedCall(task={self.task}, other_args={self.other_args}, "
            f"pre_serialized_args={list(self.pre_serialized_args.keys())})"
        )

    def __repr__(self) -> str:
        # Must be defined explicitly: @dataclass would otherwise generate a
        # __repr__ for this subclass that prints the full pre-serialized payloads.
        return self.__str__()

    def __hash__(self) -> int:
        raise NotImplementedError(
            "PreSerializedCall does not support __hash__ (not intended for sets/dicts)"
        )

    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, Call):
            return False
        if isinstance(other, PreSerializedCall):
            return self.call == other.call
        return self.call == other