Source code for pynenc.task

from __future__ import annotations

import importlib
from logging import Logger
import time
from collections.abc import Callable, Iterable
from functools import cached_property
from typing import TYPE_CHECKING, Any, Generic, overload

from pynenc import context
from pynenc.arguments import Arguments
from pynenc.call import Call, PreSerializedCall
from pynenc.conf.config_task import ConcurrencyControlType, ConfigTask
from pynenc.core_tasks import CoreTaskFunction
from pynenc.exceptions import InvalidTaskOptionsError, RetryError
from pynenc.invocation.base_invocation import BaseInvocation, BaseInvocationGroup
from pynenc.invocation.conc_invocation import (
    ConcurrentInvocation,
    ConcurrentInvocationGroup,
)
from pynenc.invocation.dist_invocation import (
    DistributedInvocationGroup,
)
from pynenc.identifiers.task_id import TaskId
from pynenc.types import Func, Params, Result
from pynenc.workflow.workflow_context import WorkflowContext

if TYPE_CHECKING:
    from pynenc.app import Pynenc


class Task(Generic[Params, Result]):
    """
    Represents a distributable function in the Pynenc system.

    :param Pynenc app: A reference to the Pynenc application.
    :param Callable func: The function to be run distributed.
    :param dict[str, Any] options: The task options to apply; every key must be
        a ``ConfigTask`` field (checked by :meth:`validate_options`).

    Calling the task returns a `BaseInvocation` whose concrete type depends on
    the execution environment (e.g. `DistributedInvocation` or
    `ConcurrentInvocation`).

    Tasks cannot be created from functions defined in ``__main__`` because
    workers cannot resolve ``__main__.task_name`` back to the originating
    module. Use the ``@app.task`` decorator to create and register tasks.

    :raises RuntimeError: If the function's module name contains ``__main__``.
    :raises InvalidTaskOptionsError: If ``options`` contains unknown keys.
    """

    def __init__(self, app: Pynenc, func: Func, options: dict[str, Any]) -> None:
        # Substring (not equality) check: any module name containing
        # "__main__" is rejected, since workers resolve tasks by module import.
        if "__main__" in func.__module__:
            raise RuntimeError(
                "Cannot create a task from a function in the __main__ module"
            )
        self.task_id = TaskId(func.__module__, func.__name__)
        self.app = app
        self.func = func
        self.options = options
        self.validate_options()

    def validate_options(self) -> None:
        """
        Validate that every option key is a known ``ConfigTask`` field.

        All unknown keys are collected and reported in a single exception.

        :raises InvalidTaskOptionsError: If any option key is not a ConfigTask field.
        """
        valid_fields = ConfigTask.config_fields()
        invalid_options = [opt for opt in self.options if opt not in valid_fields]
        if invalid_options:
            raise InvalidTaskOptionsError(
                self.task_id, f"Invalid options: {invalid_options}"
            )

    @cached_property
    def conf(self) -> ConfigTask:
        """Task configuration built from the app config plus this task's options."""
        return ConfigTask(
            task_id=self.task_id,
            config_values=self.app.config_values,
            config_filepath=self.app.config_filepath,
            task_options=self.options,
        )

    @property
    def logger(self) -> Logger:
        """The logger for the task (the app's logger)."""
        return self.app.logger

    @property
    def invocation(self) -> BaseInvocation:
        """
        The invocation currently running this task for this app.

        The distributed invocation context is checked first, then the
        synchronous invocation context storage.

        :raises RuntimeError: If no invocation context is active for this app.
        """
        if dist_inv := context.get_dist_invocation_context(self.app.app_id):
            return dist_inv
        if sync_inv := context._get_sync_inv_context_storage().get(self.app.app_id):
            return sync_inv
        raise RuntimeError("Task has not been invoked yet")

    @cached_property
    def wf(self) -> WorkflowContext:
        """
        Access workflow functionality for this task.

        Provides methods for workflow state management, deterministic operations,
        and durability features like pause/resume and continue-as-new.

        :return: A helper object with workflow functionality

        Example:
        ```python
        @app.task
        def main_wf_task(data: dict) -> str:
            # Save workflow state
            state = main_wf_task.wf.get_state({"step": 0})

            # Use deterministic random
            if main_wf_task.wf.random() > 0.5:
                state["path"] = "A"
            else:
                state["path"] = "B"

            # Save updated state
            main_wf_task.wf.save_state(state)

            # Conditionally pause workflow
            if needs_human_approval(data):
                main_wf_task.wf.pause("Waiting for approval")

            return f"Completed via path {state['path']}"
        ```
        """
        return WorkflowContext(self)

    def __getstate__(self) -> dict:
        # Pickle only the app, the task id key and the JSON-encoded options;
        # the function itself is re-resolved by import in __setstate__.
        return {
            "app": self.app,
            "task_id_key": self.task_id.key,
            "options_json": self.conf.options_to_json(),
        }

    def __setstate__(self, state: dict) -> None:
        # Restore instance attributes
        self.app = state["app"]
        self.task_id = TaskId.from_key(state["task_id_key"])
        self.options = ConfigTask.options_from_json(state["options_json"])
        function = Task._get_from_task_id(self.task_id)
        # The module-level object is either:
        # - a Task (from ``@app.task``) or a CoreTaskFunction (core task
        #   registry), both exposing the wrapped callable as ``.func``; or
        # - some other wrapper (e.g. from ``@direct_task``), expected here to
        #   expose the original callable as ``__inner_function__``.
        if isinstance(function, (Task, CoreTaskFunction)):
            self.func = function.func
        else:
            self.func = function.__inner_function__  # type: ignore

    @staticmethod
    def _get_from_task_id(task_id: TaskId) -> Task | Callable:
        """
        Import ``task_id.module`` and return its attribute ``task_id.func_name``.

        Import and attribute errors propagate unchanged.
        """
        module = importlib.import_module(task_id.module)
        return getattr(module, task_id.func_name)

    @classmethod
    def from_id(cls, app: Pynenc, task_id: TaskId) -> Task:
        """Resolve a Task instance by its TaskId, bound to the given app.

        Checks the app's registered tasks first, then falls back to importing
        the module. When a module-level Task is found, a new Task instance is
        created and bound to ``app`` so the caller never shares the
        module-level app reference. This prevents cross-contamination when
        multiple apps with different runners coexist in the same process.

        :param Pynenc app: The application instance with registered tasks
        :param TaskId task_id: The task identifier to resolve
        :return: The resolved Task instance bound to ``app``
        :raises ValueError: If the imported object cannot be resolved to a Task
            (import/attribute errors from the module lookup propagate as-is)
        """
        # Check app's registered tasks first — handles dynamically registered tasks
        # and tasks whose module-level name was replaced by a wrapper (e.g. direct_task)
        if task_id in app._tasks:
            return app._tasks[task_id]

        function = Task._get_from_task_id(task_id)

        # Importing the module may have run a decorator bound to this same app,
        # registering the task in app._tasks. Reuse that instance rather than
        # overwriting it with a copy below.
        if task_id in app._tasks:
            return app._tasks[task_id]

        if isinstance(function, Task):
            # Create an app-bound copy rather than returning the module-level
            # Task whose .app points to a possibly different app instance.
            task: Task = Task(app, function.func, function.options)
            app._tasks[task_id] = task
            return task

        if isinstance(function, CoreTaskFunction):
            # For CoreTaskFunction, just register all of them in the app
            # and return the one matching the task_id.
            app.register_core_tasks()
            return app._tasks[task_id]

        # Handle wrappers that hold a reference to the underlying Task.
        # Scan the object's attributes for any Task instance — this covers
        # direct_task's __pynenc_task__, Celery shims with pynenc_task, or
        # any other wrapper pattern without hardcoding attribute names.
        if source_task := _extract_task_from_wrapper(function):
            task = Task(app, source_task.func, source_task.options)
            app._tasks[task_id] = task
            return task

        app.logger.warning(
            f"_get_from_task_id returns a non-Task function {function} for task:{task_id}"
        )
        raise ValueError(f"Cannot resolve Task from task:{task_id}")

    @cached_property
    def retriable_exceptions(self) -> tuple[type[Exception], ...]:
        """
        Retrieve a tuple of exception types that should trigger a retry of the task.

        ``RetryError``, specific to the Pynenc system, is always included so
        that internal retry mechanisms are accounted for; it is appended only
        when ``conf.retry_for`` does not already contain it.

        :return: A tuple of retriable exceptions.
        """
        if not self.conf.retry_for:
            return (RetryError,)
        if RetryError in self.conf.retry_for:
            return self.conf.retry_for
        return self.conf.retry_for + (RetryError,)

    def __str__(self) -> str:
        return f"Task(func={self.func.__name__})"

    def __repr__(self) -> str:
        return self.__str__()

    def args(self, *args: Params.args, **kwargs: Params.kwargs) -> Arguments:
        """:return: an Arguments instance from the given args and kwargs"""
        return Arguments.from_call(self.func, *args, **kwargs)

    def __call__(
        self, *args: Params.args, **kwargs: Params.kwargs
    ) -> BaseInvocation[Params, Result]:
        """Handles a call to the task"""
        arguments = Arguments.from_call(self.func, *args, **kwargs)
        return self._call(arguments)

    def _call(self, arguments: Arguments) -> BaseInvocation[Params, Result]:
        """
        Route the call to the orchestrator, or run it synchronously in dev mode.

        When ``app.conf.dev_mode_force_sync_tasks`` is set, a
        ``ConcurrentInvocation`` is returned instead of routing the call.

        :return: the invocation
        """
        call = Call(self, arguments)
        if self.app.conf.dev_mode_force_sync_tasks:
            return ConcurrentInvocation(call=call)
        return self.app.orchestrator.route_call(call)

    @overload
    def parallelize(
        self,
        param_iter: Iterable[tuple | dict | Arguments],
        common_args: None = None,
    ) -> BaseInvocationGroup: ...

    @overload
    def parallelize(
        self,
        param_iter: Iterable[dict],
        common_args: dict,
    ) -> BaseInvocationGroup: ...

    def parallelize(
        self,
        param_iter: Iterable[tuple | dict | Arguments],
        common_args: dict | None = None,
    ) -> BaseInvocationGroup:
        """
        Parallelize the execution of a task with different sets of parameters.

        This method allows for concurrent execution of the same task with varying
        parameters. It accepts an iterable where each element represents a set of
        parameters for a separate task invocation. When `common_args` is provided,
        `param_iter` must be an iterable of dictionaries, and common arguments are
        pre-serialized for efficiency.

        ```{note}
        Without common_args, param_iter can be specified in different formats:
        - As a tuple: Interpreted as positional arguments for the task.
        - As a dictionary: Interpreted as keyword arguments for the task.
        - As an `Arguments` instance: Created using `task.args(*args, **kwargs)`.
        ```

        ```{important}
        common_args is intended to optimize the parallelization of large
        arguments that will be cached by the client data store. If the
        arguments are small or the client data store is disabled, it will not
        provide any major improvement. However, for large arguments, it can
        provide substantial time and memory improvements.
        ```

        :param Iterable[tuple | dict | Arguments] param_iter:
            An iterable of parameters for each call. Each element in the iterable
            is used to invoke the task separately.
        :param common_args:
            Optional dictionary of common arguments to pre-serialize and share across calls.
        :return:
            A group of task invocations, allowing the task to be run in parallel with
            different parameters. The type of group (synchronous or distributed)
            depends on the application's configuration.
        :raises ValueError: If common_args is given and param_iter contains
            anything other than dictionaries.

        ```{important}
        Depending on the configuration, this method creates a group of either
        synchronous or distributed invocations. In development mode, where
        `dev_mode_force_sync_tasks` is enabled, it creates synchronous invocations.
        Otherwise, it creates distributed invocations for parallel processing.
        ```

        ### Examples
        Parallelization with tuples, dicts and arguments:

        ```{code-block} python
        app = Pynenc()

        @app.task
        def add(x: int, y: int) -> int:
            return x + y

        # Example usage of parallelize
        invocation_group = add.parallelize([(1, 1), add.args(1, 2), {"x": 2, "y": 3}])
        print(list(invocation_group.results))  # prints [2, 3, 5]
        ```

        Parallelization with common_args:

        ```python
        @app.task(registration_concurrency=ConcurrencyControlType.DISABLED)
        def process(large_data: str, index: int) -> int:
            return len(large_data) + index

        # With common_args
        common = {"large_data": "huge_string"}
        params = [{"index": i} for i in range(3)]
        invocation_group = process.parallelize(params, common)
        print(list(invocation_group.results))  # [len("huge_string") + i for i in range(3)]
        ```
        """
        self.app.logger.info(f"parallelizing task:{self.task_id}")

        # Materialize once: the list is validated, measured and then consumed.
        param_list = list(param_iter)

        if common_args is not None and not all(
            isinstance(p, dict) for p in param_list
        ):
            raise ValueError(
                "When using common_args, param_iter must contain only dictionaries"
            )

        # Choose distribution strategy based on whether batch processing is available
        if can_batch_process(self, len(param_list)):
            return distribute_batch_calls(self, param_list, common_args)
        return distribute_calls(self, param_list, common_args)
def _extract_task_from_wrapper(wrapper: object) -> Task | None:
    """Return the first Task stored among ``wrapper``'s instance attributes.

    Works for arbitrary wrapper shapes (direct_task's ``__pynenc_task__``,
    Celery migration shims exposing ``pynenc_task``, ...) because it inspects
    every attribute value instead of looking up fixed names.

    :param object wrapper: The object to inspect
    :return: The first Task instance found, or None when there is none or when
        ``wrapper`` has no ``__dict__`` (e.g. slots-only objects)
    """
    try:
        attribute_values = vars(wrapper).values()
        return next(
            (candidate for candidate in attribute_values if isinstance(candidate, Task)),
            None,
        )
    except TypeError:
        # vars() refuses objects that lack a __dict__
        return None
def can_batch_process(task: Task, num_calls: int) -> bool:
    """
    Decide whether a parallelized task should be dispatched in batches.

    Batching is used only when all of the following hold:
    1. dev mode (``dev_mode_force_sync_tasks``) is off;
    2. the task's registration concurrency is ``DISABLED``;
    3. more than one call is being dispatched;
    4. the task's ``parallel_batch_size`` is positive.

    :param Task task: The task to check
    :param int num_calls: The number of calls to process
    :return: True if the task can be batch processed, False otherwise
    """
    if task.app.conf.dev_mode_force_sync_tasks:
        return False
    if not task.conf.registration_concurrency == ConcurrencyControlType.DISABLED:
        return False
    if not num_calls > 1:
        return False
    return task.conf.parallel_batch_size > 0
def prepare_arguments(
    task: Task,
    param_iter: Iterable[tuple | dict | Arguments],
    common_args: dict | None = None,
) -> list[Arguments]:
    """
    Normalize heterogeneous call parameters into ``Arguments`` objects.

    Each element is converted independently, in iteration order:

    - ``tuple``: unpacked positionally into ``task.args``;
    - ``dict``: unpacked as keywords into ``task.args``, layered on top of
      ``common_args`` (per-call keys win) when ``common_args`` is non-empty;
    - anything else: returned unchanged (expected to already be ``Arguments``).

    :param Task task: The task whose ``args`` builds each Arguments object
    :param Iterable[tuple | dict | Arguments] param_iter: One entry per call
    :param dict | None common_args: Keyword arguments shared by every call
    :return: One Arguments object per element of ``param_iter``
    :raises ValueError: If ``common_args`` is non-empty and an element is not a dict
    """

    def _convert(params: tuple | dict | Arguments) -> Arguments:
        # Shared kwargs can only be merged into keyword-style parameters
        if common_args and not isinstance(params, dict):
            raise ValueError(
                "common_args can only be used with an iterable of dictionaries"
            )
        if isinstance(params, tuple):
            return task.args(*params)
        if not isinstance(params, dict):
            return params
        if not common_args:
            return task.args(**params)
        # Per-call params override the shared values on key collisions
        return task.args(**{**common_args, **params})

    return [_convert(params) for params in param_iter]
def distribute_calls(
    task: Task,
    param_list: list[tuple | dict | Arguments],
    common_args: dict | None = None,
) -> BaseInvocationGroup:
    """
    Dispatch each call individually through ``task._call`` (no batching).

    All parameters are converted to Arguments up front (merging ``common_args``
    into dict entries), then routed one by one; only truthy invocations are kept.

    :param Task task: The task to process
    :param list[tuple | dict | Arguments] param_list: List of parameters
    :param dict | None common_args: Optional common arguments to merge with each parameter
    :return: A ConcurrentInvocationGroup when ``dev_mode_force_sync_tasks`` is
        enabled, otherwise a DistributedInvocationGroup, wrapping the invocations
    """
    prepared = prepare_arguments(task, param_list, common_args)
    routed = [task._call(arguments) for arguments in prepared]
    invocations = [inv for inv in routed if inv]

    group_cls: type[BaseInvocationGroup] = (
        ConcurrentInvocationGroup
        if task.app.conf.dev_mode_force_sync_tasks
        else DistributedInvocationGroup
    )
    return group_cls(task, invocations)
def distribute_batch_calls(
    task: Task[Params, Result],
    param_list: list[tuple | dict | Arguments],
    common_args: dict | None = None,
) -> DistributedInvocationGroup:
    """
    Route calls to the orchestrator in chunks of ``PreSerializedCall`` objects.

    When ``common_args`` is non-empty it is serialized once through the client
    data store and shared by every call; ``param_list`` entries are then used
    directly as the per-call kwargs. Otherwise each entry is converted with
    ``prepare_arguments`` and its ``kwargs`` are used.

    :param Task task: The task to process
    :param list[tuple | dict | Arguments] param_list: The arguments to process
    :param dict | None common_args: Optional common arguments to be pre-serialized once
    :return: An invocation group for the distributed calls
    """
    shared_serialized: dict = {}
    if common_args:
        task.logger.info("Pre-serializing common arguments for batch parallelization")
        shared_serialized = task.app.client_data_store.serialize_arguments(
            common_args, task.conf.disable_cache_args
        )
        task.logger.debug(f"Pre-serialized {len(shared_serialized)} common arguments")
        # parallelize() has already checked these entries are dicts
        per_call_kwargs: list[dict[str, Any]] = param_list  # type: ignore
    else:
        per_call_kwargs = [
            prepared.kwargs for prepared in prepare_arguments(task, param_list)
        ]

    chunk_size = task.conf.parallel_batch_size
    total = len(per_call_kwargs)
    task.logger.info(f"Processing {total} calls in batches of {chunk_size}")
    started_at = time.time()

    invocations = []
    for offset in range(0, total, chunk_size):
        chunk_calls = [
            PreSerializedCall(
                task,
                other_args=kwargs,
                common_serialized_args=shared_serialized,
                common_args=common_args,
            )
            for kwargs in per_call_kwargs[offset : offset + chunk_size]
        ]
        invocations.extend(task.app.orchestrator.route_calls(chunk_calls))
        # Progress is logged for every chunk except the final one
        if offset + chunk_size < total:
            task.logger.info(
                f"Processed batch {offset // chunk_size + 1}, "
                f"{len(invocations)}/{total} invocations"
            )

    elapsed = time.time() - started_at
    task.logger.info(f"Batch processed {len(invocations)} calls in {elapsed:.2f}s")
    return DistributedInvocationGroup(task, invocations)