from __future__ import annotations
import importlib
from logging import Logger
import time
from collections.abc import Callable, Iterable
from functools import cached_property
from typing import TYPE_CHECKING, Any, Generic, overload
from pynenc import context
from pynenc.arguments import Arguments
from pynenc.call import Call, PreSerializedCall
from pynenc.conf.config_task import ConcurrencyControlType, ConfigTask
from pynenc.core_tasks import CoreTaskFunction
from pynenc.exceptions import InvalidTaskOptionsError, RetryError
from pynenc.invocation.base_invocation import BaseInvocation, BaseInvocationGroup
from pynenc.invocation.conc_invocation import (
ConcurrentInvocation,
ConcurrentInvocationGroup,
)
from pynenc.invocation.dist_invocation import (
DistributedInvocation,
DistributedInvocationGroup,
)
from pynenc.identifiers.task_id import TaskId
from pynenc.types import Func, Params, Result
from pynenc.workflow.workflow_context import WorkflowContext
if TYPE_CHECKING:
from pynenc.app import Pynenc
class Task(Generic[Params, Result]):
    """
    Represents a distributable function in the Pynenc system.

    :param Pynenc app: A reference to the Pynenc application.
    :param Callable func: The function to be run distributed.
    :param dict[str, Any] options: The task options to apply; every key must be
        a field listed by ``ConfigTask.config_fields()``.

    Calling the task returns a `BaseInvocation` whose concrete type depends on
    the execution environment (e.g. `DistributedInvocation` or `ConcurrentInvocation`).

    Tasks cannot be created from functions defined in ``__main__`` because
    workers cannot resolve ``__main__.task_name`` back to the originating module.
    Use the ``@app.task`` decorator to create and register tasks.

    :raises RuntimeError: If the function is defined in the ``__main__`` module.
    :raises InvalidTaskOptionsError: If ``options`` contains unknown keys.
    """

    def __init__(self, app: Pynenc, func: Func, options: dict[str, Any]) -> None:
        # Substring match: any module name containing "__main__" is rejected,
        # because workers could not re-import the function from it.
        if "__main__" in func.__module__:
            raise RuntimeError(
                "Cannot create a task from a function in the __main__ module"
            )
        self.task_id = TaskId(func.__module__, func.__name__)
        self.app = app
        self.func = func
        self.options = options
        self.validate_options()

    def validate_options(self) -> None:
        """
        Validate that every key in ``self.options`` is a known task config field.

        All unknown options are collected and reported together in one error.

        :raises InvalidTaskOptionsError: If at least one option is not listed in
            ``ConfigTask.config_fields()``.
        """
        # Fetch the allowed field names once instead of once per option.
        valid_fields = set(ConfigTask.config_fields())
        invalid_options = [opt for opt in self.options if opt not in valid_fields]
        if invalid_options:
            raise InvalidTaskOptionsError(
                self.task_id, f"Invalid options: {invalid_options}"
            )

    @cached_property
    def conf(self) -> ConfigTask:
        """
        The task configuration, built once per instance from the app's
        ``config_values`` and ``config_filepath`` plus this task's options.
        """
        return ConfigTask(
            task_id=self.task_id,
            config_values=self.app.config_values,
            config_filepath=self.app.config_filepath,
            task_options=self.options,
        )

    @property
    def logger(self) -> Logger:
        """The logger for the task (the app's logger)."""
        return self.app.logger

    @property
    def invocation(self) -> BaseInvocation:
        """
        The invocation currently running this task for this app.

        The distributed invocation context is checked first, then the
        synchronous invocation context storage.

        :raises RuntimeError: If neither context holds an invocation for this app.
        """
        if dist_inv := context.get_dist_invocation_context(self.app.app_id):
            return dist_inv
        if sync_inv := context._get_sync_inv_context_storage().get(self.app.app_id):
            return sync_inv
        raise RuntimeError("Task has not been invoked yet")

    @cached_property
    def wf(self) -> WorkflowContext:
        """
        Access workflow functionality for this task.

        Provides methods for workflow state management, deterministic operations,
        and durability features like pause/resume and continue-as-new.

        :return: A helper object with workflow functionality

        Example:
            ```python
            @app.task
            def main_wf_task(data: dict) -> str:
                # Save workflow state
                state = main_wf_task.wf.get_state({"step": 0})

                # Use deterministic random
                if main_wf_task.wf.random() > 0.5:
                    state["path"] = "A"
                else:
                    state["path"] = "B"

                # Save updated state
                main_wf_task.wf.save_state(state)

                # Conditionally pause workflow
                if needs_human_approval(data):
                    main_wf_task.wf.pause("Waiting for approval")

                return f"Completed via path {state['path']}"
            ```
        """
        return WorkflowContext(self)

    def __getstate__(self) -> dict:
        # Pickle only what is needed to rebuild the task: the app, the task id
        # key and the JSON-encoded options. ``func`` is not pickled; it is
        # re-imported from the task id in ``__setstate__``.
        return {
            "app": self.app,
            "task_id_key": self.task_id.key,
            "options_json": self.conf.options_to_json(),
        }

    def __setstate__(self, state: dict) -> None:
        # Restore instance attributes
        self.app = state["app"]
        self.task_id = TaskId.from_key(state["task_id_key"])
        self.options = ConfigTask.options_from_json(state["options_json"])
        function = Task._get_from_task_id(self.task_id)
        # The module attribute may be:
        # - a Task (from @task) or a CoreTaskFunction (from core_tasks_registry),
        #   both of which expose the original callable as ``.func``
        # - anything else (e.g. the wrapper produced by @direct_task), which is
        #   expected to expose it as ``__inner_function__``
        if isinstance(function, (Task, CoreTaskFunction)):
            self.func = function.func
        else:
            self.func = function.__inner_function__  # type: ignore

    @staticmethod
    def _get_from_task_id(task_id: TaskId) -> Task | Callable:
        """
        Import ``task_id.module`` and return its ``task_id.func_name`` attribute.

        The returned object may be a Task, a CoreTaskFunction or a wrapper;
        callers are responsible for unwrapping it.
        """
        module = importlib.import_module(task_id.module)
        return getattr(module, task_id.func_name)

    @classmethod
    def from_id(cls, app: Pynenc, task_id: TaskId) -> Task:
        """Resolve a Task instance by its TaskId, bound to the given app.

        Checks the app's registered tasks first, then falls back to importing
        the module. When a module-level Task is found, a new Task instance is
        created and bound to ``app`` so the caller never shares the module-level
        app reference. This prevents cross-contamination when multiple apps
        with different runners coexist in the same process. Newly created
        Task instances are stored in ``app._tasks``.

        :param Pynenc app: The application instance with registered tasks
        :param TaskId task_id: The task identifier to resolve
        :return: The resolved Task instance bound to ``app``
        :raises ValueError: If the task_id cannot be resolved to a Task
        """
        # Check app's registered tasks first — handles dynamically registered tasks
        # and tasks whose module-level name was replaced by a wrapper (e.g. direct_task)
        if task_id in app._tasks:
            return app._tasks[task_id]
        function = Task._get_from_task_id(task_id)
        if isinstance(function, Task):
            # Create an app-bound copy rather than returning the module-level
            # Task whose .app points to a possibly different app instance.
            task: Task = Task(app, function.func, function.options)
            app._tasks[task_id] = task
            return task
        if isinstance(function, CoreTaskFunction):
            # For CoreTaskFunction, just register all of them in the app
            # and return the one matching the task_id.
            app.register_core_tasks()
            return app._tasks[task_id]
        # Handle wrappers that hold a reference to the underlying Task.
        # Scan the object's attributes for any Task instance — this covers
        # direct_task's __pynenc_task__, Celery shims with pynenc_task, or
        # any other wrapper pattern without hardcoding attribute names.
        if source_task := _extract_task_from_wrapper(function):
            task = Task(app, source_task.func, source_task.options)
            app._tasks[task_id] = task
            return task
        # After importing the module, the decorator may have registered the
        # task in app._tasks (when the module shares the same app instance).
        if task_id in app._tasks:
            return app._tasks[task_id]
        app.logger.warning(
            f"_get_from_task_id returns a non-Task function {function} for task:{task_id}"
        )
        raise ValueError(f"Cannot resolve Task from task:{task_id}")

    @cached_property
    def retriable_exceptions(self) -> tuple[type[Exception], ...]:
        """
        Tuple of exception types that should trigger a retry of the task.

        It is built from ``conf.retry_for``. The Pynenc-specific `RetryError` is
        always included, so internal retry mechanisms keep working even when
        ``retry_for`` is empty. The value is computed once per instance.

        :return: A tuple of retriable exceptions.
        """
        if not self.conf.retry_for:
            return (RetryError,)
        if RetryError in self.conf.retry_for:
            return self.conf.retry_for
        return self.conf.retry_for + (RetryError,)

    def __str__(self) -> str:
        return f"Task(func={self.func.__name__})"

    def __repr__(self) -> str:
        return self.__str__()

    def args(self, *args: Params.args, **kwargs: Params.kwargs) -> Arguments:
        """:return: an Arguments instance from the given args and kwargs"""
        return Arguments.from_call(self.func, *args, **kwargs)

    def __call__(
        self, *args: Params.args, **kwargs: Params.kwargs
    ) -> BaseInvocation[Params, Result]:
        """Handles a call to the task: bind the arguments and route the call."""
        arguments = Arguments.from_call(self.func, *args, **kwargs)
        return self._call(arguments)

    def _call(self, arguments: Arguments) -> BaseInvocation[Params, Result]:
        """
        Route the call to the orchestrator if not in dev mode, otherwise run synchronously.

        :return: the invocation
        """
        force_sync = self.app.conf.dev_mode_force_sync_tasks
        call = Call(self, arguments)
        if force_sync:
            return ConcurrentInvocation(call=call)
        return self.app.orchestrator.route_call(call)

    @overload
    def parallelize(
        self,
        param_iter: Iterable[tuple | dict | Arguments],
        common_args: None = None,
    ) -> BaseInvocationGroup: ...

    @overload
    def parallelize(
        self,
        param_iter: Iterable[dict],
        common_args: dict,
    ) -> BaseInvocationGroup: ...

    def parallelize(
        self,
        param_iter: Iterable[tuple | dict | Arguments],
        common_args: dict | None = None,
    ) -> BaseInvocationGroup:
        """
        Parallelize the execution of a task with different sets of parameters.

        This method allows for concurrent execution of the same task with varying parameters.
        It accepts an iterable where each element represents a set of parameters for a separate task invocation.
        When `common_args` is provided, `param_iter` must be an iterable of dictionaries, and common arguments
        are pre-serialized for efficiency.

        ```{note}
        Without common_args, param_iter can be specified in different formats:
        - As a tuple: Interpreted as positional arguments for the task.
        - As a dictionary: Interpreted as keyword arguments for the task.
        - As an `Arguments` instance: Created using `task.args(*args, **kwargs)`.
        ```

        ```{important}
        common_args is intended to optimize the parallelization of huge arguments, which are cached by the client data store.
        If the arguments are small or the client data store is disabled, it will not provide any major improvement.
        However, for big arguments, it will provide massive time and memory improvements.
        ```

        :param Iterable[tuple | dict | Arguments] param_iter: An iterable of parameters for each call.
            Each element in the iterable is used to invoke the task separately.
        :param common_args: Optional dictionary of common arguments to pre-serialize and share across calls.
        :return: A group of task invocations, allowing the task to be run in parallel with different parameters.
            The type of group (synchronous or distributed) depends on the application's configuration.
        :raises ValueError: If `common_args` is given and `param_iter` contains a non-dict element.

        ```{important}
        Depending on the configuration, this method creates a group of either synchronous or distributed invocations.
        In development mode, where `dev_mode_force_sync_tasks` is enabled, it creates synchronous invocations.
        Otherwise, it creates distributed invocations for parallel processing.
        ```

        ### Examples
        Parallelization with tuples, dicts and arguments:

        ```{code-block} python
        app = Pynenc()

        @app.task
        def add(x: int, y: int) -> int:
            return x + y

        # Example usage of parallelize
        invocation_group = add.parallelize([(1, 1), add.args(1, 2), {"x": 2, "y": 3}])
        print(list(invocation_group.results))
        # prints [2, 3, 5]
        ```

        Parallelization with common_args:

        ```python
        @app.task(registration_concurrency=ConcurrencyControlType.DISABLED)
        def process(large_data: str, index: int) -> int:
            return len(large_data) + index

        # With common_args
        common = {"large_data": "huge_string"}
        params = [{"index": i} for i in range(3)]
        invocation_group = process.parallelize(params, common)
        print(list(invocation_group.results))  # [len("huge_string") + i for i in range(3)]
        ```
        """
        self.app.logger.info(f"parallelizing task:{self.task_id}")
        # Materialize param_iter so it can be iterated more than once and measured.
        param_list = list(param_iter)
        if common_args is not None and not all(
            isinstance(p, dict) for p in param_list
        ):
            raise ValueError(
                "When using common_args, param_iter must contain only dictionaries"
            )
        # Choose distribution strategy based on whether batch processing is available
        if can_batch_process(self, len(param_list)):
            return distribute_batch_calls(self, param_list, common_args)
        return distribute_calls(self, param_list, common_args)
def can_batch_process(task: Task, num_calls: int) -> bool:
    """
    Decide whether ``parallelize`` may route the calls in batches.

    Batching is allowed only when all of the following hold (checked in order):

    1. ``dev_mode_force_sync_tasks`` is disabled for the app
    2. the task's ``registration_concurrency`` is ``ConcurrencyControlType.DISABLED``
    3. more than one call is being made
    4. the task's ``parallel_batch_size`` is positive

    :param Task task: The task to check
    :param int num_calls: The number of calls to process
    :return: True if the task can be batch processed, False otherwise
    """
    if task.app.conf.dev_mode_force_sync_tasks:
        return False
    if task.conf.registration_concurrency != ConcurrencyControlType.DISABLED:
        return False
    if num_calls <= 1:
        return False
    return task.conf.parallel_batch_size > 0
def prepare_arguments(
    task: Task,
    param_iter: Iterable[tuple | dict | Arguments],
    common_args: dict | None = None,
) -> list[Arguments]:
    """
    Normalize every element of ``param_iter`` into an ``Arguments`` object.

    Tuples are passed positionally to ``task.args`` and dicts as keyword
    arguments. Any other element is assumed to already be an ``Arguments``
    instance and is kept unchanged. When ``common_args`` is non-empty, each dict
    is layered on top of it, so per-call keys override the common ones.

    :param Task task: The task whose ``args`` method builds the Arguments
    :param Iterable[tuple | dict | Arguments] param_iter: Parameters for each call
    :param dict | None common_args: Optional shared keyword arguments merged into each dict
    :return: One Arguments object per element of ``param_iter``, in order
    :raises ValueError: If ``common_args`` is non-empty and an element is not a dict
    """

    def _to_arguments(params: tuple | dict | Arguments) -> Arguments:
        if isinstance(params, dict):
            if common_args:
                # Common values first, then per-call values win on conflicts.
                return task.args(**{**common_args, **params})
            return task.args(**params)
        if common_args:
            raise ValueError(
                "common_args can only be used with an iterable of dictionaries"
            )
        if isinstance(params, tuple):
            return task.args(*params)
        return params

    return [_to_arguments(params) for params in param_iter]
def distribute_calls(
    task: Task,
    param_list: list[tuple | dict | Arguments],
    common_args: dict | None = None,
) -> BaseInvocationGroup:
    """
    Route each call individually (no batching) and group the resulting invocations.

    Every element of ``param_list`` is converted with ``prepare_arguments``
    (merging ``common_args`` when provided) and sent through ``task._call``.

    :param Task task: The task to invoke
    :param list[tuple | dict | Arguments] param_list: Parameters for each call
    :param dict | None common_args: Optional common arguments to merge with each parameter
    :return: A ``ConcurrentInvocationGroup`` when ``dev_mode_force_sync_tasks`` is
        enabled, otherwise a ``DistributedInvocationGroup``
    :raises TypeError: If ``task._call`` returns an invocation of the wrong kind
        for the current mode
    """
    # Prepare arguments with common_args merged in
    all_args = prepare_arguments(task, param_list, common_args)

    if task.app.conf.dev_mode_force_sync_tasks:
        concurrent_invocations: list[ConcurrentInvocation] = []
        for args in all_args:
            sync_invocation = task._call(args)
            # Guard so the resulting group is homogeneous.
            if not isinstance(sync_invocation, ConcurrentInvocation):
                raise TypeError(
                    "dev_mode_force_sync_tasks must create ConcurrentInvocation"
                )
            concurrent_invocations.append(sync_invocation)
        return ConcurrentInvocationGroup(task, concurrent_invocations)

    distributed_invocations: list[DistributedInvocation] = []
    for args in all_args:
        dist_invocation = task._call(args)
        if not isinstance(dist_invocation, DistributedInvocation):
            raise TypeError("distributed mode must create DistributedInvocation")
        distributed_invocations.append(dist_invocation)
    return DistributedInvocationGroup(task, distributed_invocations)
def distribute_batch_calls(
    task: Task[Params, Result],
    param_list: list[tuple | dict | Arguments],
    common_args: dict | None = None,
) -> DistributedInvocationGroup:
    """
    Route calls to the orchestrator in batches of ``task.conf.parallel_batch_size``.

    When ``common_args`` is non-empty, it is serialized once through the app's
    client data store and shared by every ``PreSerializedCall``. In that case
    ``param_list`` is used as-is, so it must contain only dicts. Otherwise each
    parameter is normalized with ``prepare_arguments`` and only its
    ``kwargs`` are forwarded.

    :param Task task: The task to process
    :param list[tuple | dict | Arguments] param_list: The arguments to process
    :param dict | None common_args: Optional common arguments to be pre-serialized once
    :return: An invocation group for the distributed calls
    """
    # Pre-serialize common arguments if provided
    pre_serialized_args = {}
    other_args: list[dict[str, Any]] = []
    if common_args:
        task.logger.info("Pre-serializing common arguments for batch parallelization")
        pre_serialized_args = task.app.client_data_store.serialize_arguments(
            common_args, task.conf.disable_cache_args
        )
        task.logger.debug(f"Pre-serialized {len(pre_serialized_args)} common arguments")
        # Callers (Task.parallelize) guarantee these are all dicts.
        other_args = param_list  # type: ignore
    else:
        other_args = [a.kwargs for a in prepare_arguments(task, param_list)]

    invocations = []
    batch_size = task.conf.parallel_batch_size
    total = len(other_args)
    task.logger.info(f"Processing {total} calls in batches of {batch_size}")
    # Monotonic clock: elapsed-time measurement must not be affected by
    # wall-clock adjustments.
    start_time = time.perf_counter()

    # Process in batches
    for i in range(0, total, batch_size):
        batch_calls = [
            PreSerializedCall(
                task,
                other_args=args,
                common_serialized_args=pre_serialized_args,
                common_args=common_args,
            )
            for args in other_args[i : i + batch_size]
        ]
        invocations.extend(task.app.orchestrator.route_calls(batch_calls))
        # Progress is logged for every batch except the last one.
        if i + batch_size < total:
            task.logger.info(
                f"Processed batch {i // batch_size + 1}, "
                f"{len(invocations)}/{total} invocations"
            )

    elapsed = time.perf_counter() - start_time
    task.logger.info(f"Batch processed {len(invocations)} calls in {elapsed:.2f}s")
    return DistributedInvocationGroup(task, invocations)