from __future__ import annotations
import importlib
import json
import time
from collections.abc import Iterable
from functools import cached_property
from typing import TYPE_CHECKING, Any, Generic, overload
from pynenc import context
from pynenc.arguments import Arguments
from pynenc.call import Call, PreSerializedCall
from pynenc.conf.config_task import ConcurrencyControlType, ConfigTask
from pynenc.exceptions import InvalidTaskOptionsError, RetryError
from pynenc.invocation.base_invocation import BaseInvocation, BaseInvocationGroup
from pynenc.invocation.conc_invocation import (
ConcurrentInvocation,
ConcurrentInvocationGroup,
)
from pynenc.invocation.dist_invocation import DistributedInvocationGroup
from pynenc.types import Func, Params, Result
from pynenc.util.log import TaskLoggerAdapter
if TYPE_CHECKING:
from pynenc.app import Pynenc
[docs]
class Task(Generic[Params, Result]):
"""
**A task in the Pynenc library that represents a function that can be distributed.**
:param Pynenc app: A reference to the Pynenc application.
:param Callable func: The function to be run distributed.
:param dict[str, Any] **options: The options to apply.
The `BaseTask` can be called normally and will return an instance of `BaseResult`.
The result will be an `AsyncResult` when running normally but can be `SyncResult`
when running eagerly in development with the `pynenc` app's `dev_mode_force_sync_tasks`
option set to `True` (or the 'PYNENC_DEV_MODE_FORCE_SYNC_TASK' environment variable set).
The option `dev_mode_force_sync_tasks` should only be used in development.
```{hint}
Although it is possible to create a `BaseTask` instance directly,
it is recommended to use the decorator provided in the `pynenc` application, i.e., `@app.task(options...)`.
This is the expected way of instantiating a class and registering it in the app.
```
### Limitations
```{attention}
This implementation does not support the creation of tasks
from functions defined in modules intended to run as standalone scripts.
```
This applies to any module executed directly, where its `__name__` attribute becomes `"__main__"`.
This is not exclusive to modules with `if __name__ == "__main__"`
sections but includes any module run as the main program.
In such situations, `func.__module__` being `"__main__"`
poses a challenge for task instantiation and serialization.
When a task is executed in the initiator script, it is identified as `__main__.task_name`.
However, in a Pynenc worker's distributed environment, `__main__` refers to the worker itself.
As a result, the task identified as `__main__.task_name` cannot be found,
since the worker's `__main__` differs from that of the initiator script.
**To ensure simplicity and robustness in task management,
tasks defined in modules run as the main program are not supported.**
### Examples
```{code-block} python
@app.task(options)
def func():
pass
result = func()
```
### Raises
- **RuntimeError:** If an attempt is made to create a task from a function in the `__main__` module.
"""
def __init__(self, app: Pynenc, func: Func, options: dict[str, Any]) -> None:
if "__main__" in func.__module__:
raise RuntimeError(
"Cannot create a task from a function in the __main__ module"
)
self.task_id = f"{func.__module__}.{func.__name__}"
self.app = app
self.logger = TaskLoggerAdapter(self.app.logger, self.task_id)
self.func = func
self.options = options
self.validate_options()
[docs]
def validate_options(self) -> None:
"""
validate that all the option fields exists in the config_fields
it will raise an exception with all the invalid options
"""
invalid_options = []
for option in self.options:
if option not in ConfigTask.config_fields():
invalid_options.append(option)
if invalid_options:
raise InvalidTaskOptionsError(
self.task_id, f"Invalid options: {invalid_options}"
)
@cached_property
def conf(self) -> ConfigTask:
return ConfigTask(
task_id=self.task_id,
config_values=self.app.config_values,
config_filepath=self.app.config_filepath,
task_options=self.options,
)
@property
def invocation(self) -> BaseInvocation:
"""The invocation of the task"""
if dist_inv := context.get_dist_invocation_context(self.app.app_id):
return dist_inv
if sync_inv := context.sync_inv_context.get(self.app.app_id):
return sync_inv
raise RuntimeError("Task has not been invoked yet")
[docs]
def to_json(self) -> str:
""":return: The serialized task"""
return json.dumps(
{"task_id": self.task_id, "options": self.conf.options_to_json()}
)
[docs]
def __getstate__(self) -> dict:
# Return state as a dictionary and a secondary value as a tuple
return {"app": self.app, "task_json": self.to_json()}
[docs]
def __setstate__(self, state: dict) -> None:
# Restore instance attributes
self.app = state["app"]
serialized = state["task_json"]
task_id, func, options = Task._from_json(serialized)
# Restore the cached property
self.task_id = task_id
self.app = self.app
self.func = func
self.options = options
self.logger = TaskLoggerAdapter(self.app.logger, self.task_id)
[docs]
@staticmethod
def _from_json(serialized: str) -> tuple[str, Func, dict[str, Any]]:
""":return: a function and options from a serialized task"""
task_dict = json.loads(serialized)
task_id = task_dict["task_id"]
module_name, function_name = task_id.rsplit(".", 1)
module = importlib.import_module(module_name)
function = getattr(module, function_name)
options = ConfigTask.options_from_json(task_dict["options"])
# Check if the function is a Task (from @task) or a plain function (from @direct_task)
if isinstance(function, Task):
return task_id, function.func, options
# For direct_task, return the function itself
return task_id, function.__inner_function__, options # type: ignore
[docs]
@classmethod
def from_json(cls, app: Pynenc, serialized: str) -> Task:
""":return: a new task from a serialized task"""
_, func, options = cls._from_json(serialized)
return cls(app, func, options)
@cached_property
def retriable_exceptions(self) -> tuple[type[Exception], ...]:
"""
Retrieve a tuple of exception types that should trigger a retry of the task.
This method provides a list of exception types, indicating which exceptions will cause the task to be retried.
The `RetryError` exception type, specific to the Pynenc system, is always included to ensure that internal retry
mechanisms are accounted for.
:return: A tuple of retriable exceptions.
"""
if not self.conf.retry_for:
return (RetryError,)
if RetryError in self.conf.retry_for:
return self.conf.retry_for
return self.conf.retry_for + (RetryError,)
[docs]
def __str__(self) -> str:
return f"Task(func={self.func.__name__})"
[docs]
def __repr__(self) -> str:
return self.__str__()
[docs]
def args(self, *args: Params.args, **kwargs: Params.kwargs) -> Arguments:
""":return: an Arguments instance from the given args and kwargs"""
return Arguments.from_call(self.func, *args, **kwargs)
[docs]
def __call__(
self, *args: Params.args, **kwargs: Params.kwargs
) -> BaseInvocation[Params, Result]:
"""Handles a call to the task"""
arguments = Arguments.from_call(self.func, *args, **kwargs)
return self._call(arguments)
[docs]
def _call(self, arguments: Arguments) -> BaseInvocation[Params, Result]:
"""
Route the call to the orchestrator if not in dev mode, otherwise run synchronously
:return: the invocation
"""
if self.app.conf.dev_mode_force_sync_tasks:
return ConcurrentInvocation(call=Call(self, arguments))
call = Call(self, arguments)
return self.app.orchestrator.route_call(call)
@overload
def parallelize(
self,
param_iter: Iterable[tuple | dict | Arguments],
common_args: None = None,
) -> BaseInvocationGroup:
...
@overload
def parallelize(
self,
param_iter: Iterable[dict],
common_args: dict,
) -> BaseInvocationGroup:
...
[docs]
def parallelize(
self,
param_iter: Iterable[tuple | dict | Arguments],
common_args: dict | None = None,
) -> BaseInvocationGroup:
"""
Parallelize the execution of a task with different sets of parameters.
This method allows for concurrent execution of the same task with varying parameters.
It accepts an iterable where each element represents a set of parameters for a separate task invocation.
When `common_args` is provided, `param_iter` must be an iterable of dictionaries, and common arguments
are pre-serialized for efficiency.
```{note}
Without common_args param_iter can be specified in different formats:
- As a tuple: Interpreted as positional arguments for the task.
- As a dictionary: Interpreted as keyword arguments for the task.
- As an `Arguments` instance: Created using `task.args(*args, **kwargs)`.
```
```{important}
common_args is intended for optimize parallelization of huge arguments that will be cached by the arg_cache.
if the arguments are small or the arg_cache is disabled, it will not provide any major improvement.
However, for big arguments, it will provide massive time and memory improvements.
```
:param Iterable[tuple | dict | Arguments] param_iter: An iterable of parameters for each call.
Each element in the iterable is used to invoke the task separately.
:param common_args: Optional dictionary of common arguments to pre-serialize and share across calls.
:return: A group of task invocations, allowing the task to be run in parallel with different parameters.
The type of group (synchronous or distributed) depends on the application's configuration.
```{important}
Depending on the configuration, this method creates a group of either synchronous or distributed invocations.
In development mode, where `dev_mode_force_sync_tasks` is enabled, it creates synchronous invocations.
Otherwise, it creates distributed invocations for parallel processing.
```
### Examples
Parallelization with tuples, dicts and arguments:
```{code-block} python
app = Pynenc()
@app.task
def add(x: int, y: int) -> int:
return x + y
# Example usage of parallelize
invocation_group = add.parallelize([(1, 1), add.args(1, 2), {"x": 2, "y": 3}])
print(list(invocation_group.results))
# prints [2, 3, 5]
```
Parallelization with common_args:
```python
@app.task(registration_concurrency=ConcurrencyControlType.DISABLED)
def process(large_data: str, index: int) -> int:
return len(large_data) + index
# With common_args
common = {"large_data": "huge_string"}
params = [{"index": i} for i in range(3)]
invocation_group = process.parallelize(params, common)
print(list(invocation_group.results)) # [len("huge_string") + i for i in range(3)]
```
"""
self.logger.info(f"parallelizing {self.task_id}")
# Convert param_iter to a list to allow multiple iterations and length checking
param_list = list(param_iter)
if common_args is not None and not all(
isinstance(p, dict) for p in param_list or {}
):
raise ValueError(
"When using common_args, param_iter must contain only dictionaries"
)
# Choose distribution strategy based on whether batch processing is available
if can_batch_process(self, len(param_list)):
return distribute_batch_calls(self, param_list, common_args)
return distribute_calls(self, param_list, common_args)
[docs]
def can_batch_process(task: Task, num_calls: int) -> bool:
"""
Determine if a task can be processed in batches.
A task can be batch processed when:
1. The application is not in development mode
2. Registration concurrency is disabled
3. There are multiple calls to process
4. The task has a valid parallel_batch_size
:param Task task: The task to check
:param int num_calls: The number of calls to process
:return: True if the task can be batch processed, False otherwise
"""
return (
not task.app.conf.dev_mode_force_sync_tasks
and task.conf.registration_concurrency == ConcurrencyControlType.DISABLED
and num_calls > 1
and task.conf.parallel_batch_size > 0
)
[docs]
def prepare_arguments(
task: Task,
param_iter: Iterable[tuple | dict | Arguments],
common_args: dict | None = None,
) -> list[Arguments]:
"""
Convert various parameter formats to a list of Arguments objects,
merging with common_args when provided.
:param Task task: The task to prepare arguments for
:param Iterable[tuple | dict | Arguments] param_iter: Iterable of parameters
:param dict | None common_args: Optional common arguments to merge with each parameter
:return: A list of Arguments objects with common_args merged in
"""
result = []
for params in param_iter:
if common_args and not isinstance(params, dict):
raise ValueError(
"common_args can only be used with an iterable of dictionaries"
)
if isinstance(params, tuple):
args_obj = task.args(*params)
elif isinstance(params, dict):
if common_args:
# Create a new dict with common_args as base, then update with specific params
merged_kwargs = common_args.copy()
merged_kwargs.update(params)
args_obj = task.args(**merged_kwargs)
else:
args_obj = task.args(**params)
else:
args_obj = params
result.append(args_obj)
return result
[docs]
def distribute_calls(
task: Task,
param_list: list[tuple | dict | Arguments],
common_args: dict | None = None,
) -> BaseInvocationGroup:
"""
Distribute calls individually without batch processing.
:param Task task: The task to process
:param list[tuple | dict | Arguments] param_list: List of parameters
:param dict | None common_args: Optional common arguments to merge with each parameter
:return: A list of created invocations
"""
# Prepare arguments with common_args merged in
all_args = prepare_arguments(task, param_list, common_args)
# Standard processing - distribute calls normally
invocations = []
for args in all_args:
invocation = task._call(args)
if invocation:
invocations.append(invocation)
group_cls: type[BaseInvocationGroup]
if task.app.conf.dev_mode_force_sync_tasks:
group_cls = ConcurrentInvocationGroup
else:
group_cls = DistributedInvocationGroup
return group_cls(task, invocations)
[docs]
def distribute_batch_calls(
task: Task[Params, Result],
param_list: list[tuple | dict | Arguments],
common_args: dict | None = None,
) -> DistributedInvocationGroup:
"""
Process a list of parameters in batches using PreSerializedCall.
Handles pre-serialization of common arguments for efficient distribution.
:param Task task: The task to process
:param list[tuple | dict | Arguments] param_list: The arguments to process
:param dict | None common_args: Optional common arguments to be pre-serialized once
:return: An invocation group for the distributed calls
"""
# Pre-serialize common arguments if provided
pre_serialized_args = {}
other_args: list[dict[str, Any]] = []
if common_args:
task.logger.info("Pre-serializing common arguments for batch parallelization")
pre_serialized_args = {
k: task.app.arg_cache.serialize(v) for k, v in common_args.items()
}
task.logger.debug(f"Pre-serialized {len(pre_serialized_args)} common arguments")
other_args = param_list # type: ignore
else:
other_args = [a.kwargs for a in prepare_arguments(task, param_list)]
invocations = []
batch_size = task.conf.parallel_batch_size
task.logger.info(f"Processing {len(other_args)} calls in batches of {batch_size}")
start_time = time.time()
# Process in batches
for i in range(0, len(other_args), batch_size):
batch_args = other_args[i : i + batch_size]
batch_calls = [
PreSerializedCall(
task, other_args=args, pre_serialized_args=pre_serialized_args
)
for args in batch_args
]
batch_invocations = task.app.orchestrator.route_calls(batch_calls)
invocations.extend(batch_invocations)
if i + batch_size < len(other_args):
task.logger.info(
f"Processed batch {i//batch_size + 1}, "
f"{len(invocations)}/{len(other_args)} invocations"
)
elapsed = time.time() - start_time
task.logger.info(f"Batch processed {len(invocations)} calls in {elapsed:.2f}s")
return DistributedInvocationGroup(task, invocations)