Source code for pynenc.app

import asyncio
from collections.abc import Callable, Iterable
from collections import defaultdict
from functools import wraps
from logging import Logger
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union, overload

from pynenc import context
from pynenc.app_info import AppInfo
from pynenc.broker.base_broker import BaseBroker
from pynenc.client_data_store.base_client_data_store import BaseClientDataStore
from pynenc.conf.config_pynenc import ConfigPynenc
from pynenc.conf.config_task import ConcurrencyControlType
from pynenc.core_tasks import core_tasks_registry
from pynenc.orchestrator.base_orchestrator import BaseOrchestrator
from pynenc.plugin_loader import load_all_plugins
from pynenc.runner.base_runner import BaseRunner
from pynenc.serializer.base_serializer import BaseSerializer
from pynenc.state_backend.base_state_backend import BaseStateBackend
from pynenc.task import Task
from pynenc.trigger.base_trigger import BaseTrigger
from pynenc.trigger.trigger_builder import on_cron
from pynenc.util.log import create_logger
from pynenc.util.subclasses import get_subclass

if TYPE_CHECKING:
    from pynenc.arguments import Arguments
    from pynenc.core_tasks import CoreTaskDefinition
    from pynenc.identifiers.task_id import TaskId
    from pynenc.invocation import BaseInvocationGroup
    from pynenc.trigger import TriggerBuilder
    from pynenc.types import Args, Func, Params, Result

    # A ``parallel_func`` may return one of two shapes:
    #   * a plain iterable of per-task arguments, each a tuple, dict or Arguments;
    _IterableParallelArgs = Iterable[tuple | dict | Arguments]
    #   * a ``(common_args, param_iter)`` pair: ``common_args`` is shared by every
    #     task and pre-serialised only once (cheaper for large shared data), while
    #     ``param_iter`` yields the task-specific keyword-argument dicts.
    _SharedParallelArgs = tuple[dict[str, Any], Iterable[dict]]

    ParallelFuncReturn = Union[  # noqa: UP007 # Use `X | Y` for type annotations
        _IterableParallelArgs, _SharedParallelArgs
    ]
    # Receives the call's arguments and produces the arguments for parallel runs.
    ParallelFunc = Callable[[Args], ParallelFuncReturn]

    # Folds the results of the parallel invocations into a single result.
    AggregateFunc = Callable[[Iterable[Result]], Result]


def _new_pynenc(
    cls: "type[Pynenc]",
    config_values: "dict[str, Any] | None",
    config_filepath: "str | None",
) -> "Pynenc":
    """Rebuild a pickled Pynenc app, routing its config through ``Pynenc.__new__``.

    ``Pynenc.__reduce__`` hands this function to pickle as the reconstructor.
    Left to itself, pickle would call ``__new__`` with no arguments, so the
    multiton lookup would resolve the default app_id ('pynenc') instead of the
    pickled app's own one.

    A freshly created instance receives the minimal attributes that
    ``Pynenc.__init__`` would set, and is registered in ``cls._instances``
    right away, *before* pickle restores its state. Restoring that state may
    unpickle Task objects whose ``__setstate__`` re-imports their module; any
    module-level ``Pynenc(...)`` call made during that import then reuses this
    instance instead of creating a second one (with, e.g., a freshly generated
    temp DB path).

    :param cls: The Pynenc class (or subclass) to instantiate.
    :param config_values: Configuration values captured at pickling time.
    :param config_filepath: Configuration file path captured at pickling time.
    :return: The instance pickle will subsequently populate via ``__setstate__``.
    """
    app = cls.__new__(cls, config_values, config_filepath)
    if hasattr(app, "_initialised"):
        # __new__ handed back an already set-up multiton: leave it untouched.
        return app

    # Same attribute initialisation as Pynenc.__init__ (minus plugin loading),
    # so __init__ is skipped on re-encounter and lazy properties don't raise
    # AttributeError.
    app._initialised = True
    app._config_values = config_values
    app._config_filepath = config_filepath
    app.reporting = None
    app._runner_instance = None
    app._tasks = {}
    app._deferred_trigger_tasks = defaultdict(list)
    app._reset_cached_components()

    # ConfigPynenc resolves app_id consistently (config_values, config_filepath
    # or defaults); register before pickle deserialises the state dict.
    app_id = ConfigPynenc(
        config_values=config_values, config_filepath=config_filepath
    ).app_id
    cls._instances[app_id] = app
    return app
class Pynenc:
    """
    The main class of the Pynenc library that creates an application object.

    :param dict[str, Any] | None config_values:
        A dictionary of configuration values.
        Use ``{"app_id": "my_app"}`` to set the application identifier.
    :param str | None config_filepath:
        A path to a configuration file.

    ```{note}
    The component base classes (broker, orchestrator, state backend, ...) are
    abstract and cannot be used directly. If none is specified, they will
    default to `MemTaskBroker`, `MemStateBackend`, etc. These default classes
    do not actually distribute the code but are helpers for tests or for
    running an application on your localhost. They may help to parallelize to
    some degree but cannot be used in a production system.
    ```
    """

    # Multiton registry: at most one instance per app_id in a process.
    # Populated when an app is unpickled (by the module-level ``_new_pynenc``
    # reconstructor and by ``__setstate__``), and consulted by ``__new__`` so
    # that module-level re-construction returns the same, already-configured
    # instance.
    _instances: ClassVar[dict[str, "Pynenc"]] = {}

    def __new__(
        cls,
        config_values: dict[str, Any] | None = None,
        config_filepath: str | None = None,
    ) -> "Pynenc":
        """Return the registered instance for this app_id, or a new instance.

        The config is only resolved when the registry is non-empty, so plain
        construction in a process that never unpickled an app pays no extra cost.
        """
        if cls._instances:
            resolved = ConfigPynenc(
                config_values=config_values, config_filepath=config_filepath
            )
            if existing := cls._instances.get(resolved.app_id):
                return existing
        return super().__new__(cls)

    def __init__(
        self,
        config_values: dict[str, Any] | None = None,
        config_filepath: str | None = None,
    ) -> None:
        # Skip re-initialisation when __new__ returned a cached instance
        if hasattr(self, "_initialised"):
            return
        self._initialised = True
        self._config_values = config_values
        self._config_filepath = config_filepath
        self.reporting = None
        self._runner_instance: BaseRunner | None = None
        self._tasks: dict[TaskId, Task] = {}
        self._deferred_trigger_tasks: dict[TaskId, list[TriggerBuilder]] = (
            defaultdict(list)
        )
        self._reset_cached_components()
        load_all_plugins()

    @classmethod
    def from_info(cls, app_info: AppInfo) -> "Pynenc":
        """
        Create a Pynenc app instance from AppInfo.

        :param app_info: The AppInfo object containing app metadata
        :return: The app produced by ``create_app_from_info`` if it yields one,
            otherwise a new instance built from the info's config and app_id.
        """
        from pynenc.util.import_app import create_app_from_info

        if app := create_app_from_info(app_info):
            return app
        config_values = dict(app_info.config_values) if app_info.config_values else {}
        config_values["app_id"] = app_info.app_id
        return cls(
            config_values=config_values,
            config_filepath=app_info.config_filepath,
        )

    def register_core_tasks(self) -> None:
        """Register all core tasks defined in the core_tasks_registry."""
        for task_def in core_tasks_registry.definitions:
            self.register_core_task(task_def)

    def register_core_task(self, task_def: "CoreTaskDefinition") -> None:
        """Register a single core task definition as a task of this app.

        When ``task_def.config_cron`` names a config field, that field's value
        is read from ``self.conf`` and attached as an ``on_cron`` trigger.

        :param CoreTaskDefinition task_def: The core task definition to register.
        """
        # Copy so the registry's shared options dict is never mutated.
        options = task_def.options.copy()
        if task_def.config_cron:
            cron_val = getattr(self.conf, task_def.config_cron)
            options["triggers"] = [on_cron(cron_val)]
        self.task(task_def.func, **options)

    def _store_deferred_trigger(
        self,
        task: "Task",
        triggers: Union["TriggerBuilder", list["TriggerBuilder"], None],
    ) -> None:
        """Store triggers to be registered later by the runner.

        A later call for the same task replaces the previously stored triggers.

        :param Task task: The Task instance that declared triggers
        :param triggers: TriggerBuilder or list of TriggerBuilder; falsy values
            (None or an empty list) store nothing.
        """
        if not triggers:
            return
        # Copy lists so later mutation of the caller's list has no effect here.
        trigger_list = list(triggers) if isinstance(triggers, list) else [triggers]
        self._deferred_trigger_tasks[task.task_id] = trigger_list

    def register_deferred_triggers(self) -> None:
        """Register all deferred task triggers that were collected during task decoration.

        This is intended to be called by a runner during startup after relevant
        modules have been imported so trigger backends are available. The
        pending triggers are cleared once all of them have been registered.
        """
        if not self._deferred_trigger_tasks:
            return
        self.logger.info("Registering deferred task trigger(s):")
        for task_id, triggers in self._deferred_trigger_tasks.items():
            task = self.get_task(task_id)
            self.trigger.register_task_triggers(task, triggers)
            self.logger.info(f"  - Registered task:{task_id.key} triggers")
        self._deferred_trigger_tasks.clear()

    @property
    def config_values(self) -> dict[str, Any] | None:
        """Configuration values this app was created with."""
        return self._config_values

    @config_values.setter
    def config_values(self, value: dict[str, Any] | None) -> None:
        self._config_values = value

    @property
    def config_filepath(self) -> str | None:
        """Configuration file path this app was created with."""
        return self._config_filepath

    @config_filepath.setter
    def config_filepath(self, value: str | None) -> None:
        self._config_filepath = value

    @property
    def app_id(self) -> str:
        """Application identifier, as resolved by the app configuration."""
        return self.conf.app_id

    @property
    def tasks(self) -> dict["TaskId", Task]:
        """
        Get the dictionary of registered tasks.

        :return: A dictionary mapping task_id to Task instances.
        """
        return self._tasks

    def get_task(self, task_id: "TaskId") -> Task:
        """
        Get a task by its ID.

        If the task is not registered yet, it is rebuilt with ``Task.from_id``
        and cached in ``self.tasks``; this method never returns None.

        Warning: per the original author's note, a task rebuilt from its ID may
        overwrite the options it was declared with.

        :param task_id: The ID of the task to retrieve.
        :return: The registered (or newly rebuilt) Task instance.
        """
        if task_id not in self._tasks:
            self._tasks[task_id] = Task.from_id(self, task_id)
        return self._tasks[task_id]

    def __reduce__(self) -> tuple:
        """Control pickle reconstruction so __new__ receives config_values.

        Without this, pickle calls ``Pynenc.__new__(Pynenc)`` with no args,
        causing the multiton to fall back to the default app_id ('pynenc') and
        return a stale instance for same-process roundtrips.
        """
        return (
            _new_pynenc,
            (type(self), self.config_values, self.config_filepath),
            self.__getstate__(),
        )

    def __getstate__(self) -> dict:
        """Return the serializable state of the app for pickling or multiprocessing."""
        return {
            "config_values": self.config_values,
            "config_filepath": self.config_filepath,
            "reporting": self.reporting,
            "tasks": self._tasks,
        }

    def __setstate__(self, state: dict) -> None:
        """Restore the app state and register in the multiton.

        When a child process (spawned via ``multiprocessing``) unpickles the
        parent's app, this registers the instance so that subsequent
        ``Pynenc(config_values={"app_id": ...})`` calls (e.g. during module
        re-import) return the same already-configured instance.
        """
        self._initialised = True
        self._config_values = state["config_values"]
        self._config_filepath = state["config_filepath"]
        self.reporting = state["reporting"]
        self._tasks = state.get("tasks", {})
        self._deferred_trigger_tasks = defaultdict(list)
        self._runner_instance = None
        self._reset_cached_components()
        load_all_plugins()
        # Register so module-level Pynenc() with the same app_id reuses this instance.
        # Use ConfigPynenc to resolve app_id consistently (may come from
        # config_values, config_filepath, or environment variables).
        resolved = ConfigPynenc(
            config_values=self._config_values,
            config_filepath=self._config_filepath,
        )
        type(self)._instances[resolved.app_id] = self

    def _reset_cached_components(self) -> None:
        """Reset all lazily-initialised components so they are re-created on next access."""
        self._conf: ConfigPynenc | None = None
        self._logger: Logger | None = None
        self._orchestrator: BaseOrchestrator | None = None
        self._trigger: BaseTrigger | None = None
        self._broker: BaseBroker | None = None
        self._state_backend: BaseStateBackend | None = None
        self._serializer: BaseSerializer | None = None
        self._client_data_store: BaseClientDataStore | None = None

    @property
    def conf(self) -> ConfigPynenc:
        """App configuration, built lazily from config_values/config_filepath."""
        if self._conf is None:
            self._conf = ConfigPynenc(
                config_values=self.config_values, config_filepath=self.config_filepath
            )
        return self._conf

    @property
    def logger(self) -> Logger:
        """App logger, created lazily."""
        if self._logger is None:
            self._logger = create_logger(self)
        return self._logger

    @property
    def orchestrator(self) -> BaseOrchestrator:
        """Orchestrator instance of the class named by ``conf.orchestrator_cls``."""
        if self._orchestrator is None:
            self._orchestrator = get_subclass(
                BaseOrchestrator,  # type: ignore[type-abstract]
                self.conf.orchestrator_cls,
            )(self)
        return self._orchestrator

    @property
    def trigger(self) -> BaseTrigger:
        """Trigger instance of the class named by ``conf.trigger_cls``."""
        if self._trigger is None:
            self._trigger = get_subclass(BaseTrigger, self.conf.trigger_cls)(self)  # type: ignore # mypy issue #4717
        return self._trigger

    @property
    def broker(self) -> BaseBroker:
        """Broker instance of the class named by ``conf.broker_cls``."""
        if self._broker is None:
            self._broker = get_subclass(BaseBroker, self.conf.broker_cls)(self)  # type: ignore # mypy issue #4717
        return self._broker

    @property
    def state_backend(self) -> BaseStateBackend:
        """State backend of the class named by ``conf.state_backend_cls``.

        On first access this app's AppInfo is also stored in the backend.
        """
        if self._state_backend is None:
            self._state_backend = get_subclass(
                BaseStateBackend,  # type: ignore[type-abstract]
                self.conf.state_backend_cls,
            )(self)
            self._state_backend.store_app_info(AppInfo.from_app(self))
        return self._state_backend

    @property
    def serializer(self) -> BaseSerializer:
        """Serializer instance of the class named by ``conf.serializer_cls``."""
        if self._serializer is None:
            self._serializer = get_subclass(BaseSerializer, self.conf.serializer_cls)()  # type: ignore # mypy issue #4717
        return self._serializer

    @property
    def client_data_store(self) -> BaseClientDataStore:
        """Client data store of the class named by ``conf.client_data_store_cls``."""
        if self._client_data_store is None:
            self._client_data_store = get_subclass(
                BaseClientDataStore,  # type: ignore[type-abstract]
                self.conf.client_data_store_cls,
            )(self)
        return self._client_data_store

    @property
    def runner(self) -> BaseRunner:
        """
        Get the runner for this app, prioritizing thread/process-specific context.

        First, it checks the thread-local context for a runner (via get_current_runner).
        This is crucial in the MultiThreadRunner, where each process runs a ThreadRunner
        and needs to use its own runner instance rather than the app's default.

        If no context runner exists, it falls back to the instance-level runner.
        This mechanism ensures correct runner isolation across threads and processes.

        :return: The runner instance for the current context or the app instance.
        """
        # Check if there's a runner in the context
        if context_runner := context.get_current_runner(self.app_id):
            return context_runner

        # Fall back to instance-level runner
        if self._runner_instance is None:
            self._runner_instance = get_subclass(BaseRunner, self.conf.runner_cls)(self)  # type: ignore
        return self._runner_instance

    @runner.setter
    def runner(self, runner_instance: BaseRunner) -> None:
        self._runner_instance = runner_instance

    def purge(self) -> None:
        """Purge all data from the broker, orchestrator, state backend, client data store and trigger."""
        self.broker.purge()
        self.orchestrator.purge()
        self.state_backend.purge()
        self.client_data_store.purge()
        self.trigger.purge()

    @overload
    def task(
        self,
        func: "Func",
        *,
        parallel_batch_size: int | None = None,
        retry_for: tuple[type[Exception], ...] | None = None,
        max_retries: int | None = None,
        running_concurrency: ConcurrencyControlType | None = None,
        registration_concurrency: ConcurrencyControlType | None = None,
        key_arguments: tuple[str, ...] | None = None,
        on_diff_non_key_args_raise: bool | None = None,
        call_result_cache: bool | None = None,
        disable_cache_args: tuple[str, ...] | None = None,
        triggers: Union["TriggerBuilder", list["TriggerBuilder"]] | None = None,
        force_new_workflow: bool | None = None,
        reroute_on_concurrency_control: bool | None = None,
    ) -> "Task": ...

    @overload
    def task(
        self,
        func: None = None,
        *,
        parallel_batch_size: int | None = None,
        retry_for: tuple[type[Exception], ...] | None = None,
        max_retries: int | None = None,
        running_concurrency: ConcurrencyControlType | None = None,
        registration_concurrency: ConcurrencyControlType | None = None,
        key_arguments: tuple[str, ...] | None = None,
        on_diff_non_key_args_raise: bool | None = None,
        call_result_cache: bool | None = None,
        disable_cache_args: tuple[str, ...] | None = None,
        triggers: Union["TriggerBuilder", list["TriggerBuilder"]] | None = None,
        force_new_workflow: bool | None = None,
        reroute_on_concurrency_control: bool | None = None,
    ) -> Callable[["Func"], "Task"]: ...

    def task(
        self,
        func: Optional["Func"] = None,
        *,
        parallel_batch_size: int | None = None,
        retry_for: tuple[type[Exception], ...] | None = None,
        max_retries: int | None = None,
        running_concurrency: ConcurrencyControlType | None = None,
        registration_concurrency: ConcurrencyControlType | None = None,
        key_arguments: tuple[str, ...] | None = None,
        on_diff_non_key_args_raise: bool | None = None,
        call_result_cache: bool | None = None,
        disable_cache_args: tuple[str, ...] | None = None,
        triggers: Union["TriggerBuilder", list["TriggerBuilder"]] | None = None,
        force_new_workflow: bool | None = None,
        reroute_on_concurrency_control: bool | None = None,
    ) -> "Task | Callable[[Func], Task]":
        """
        The task decorator converts the function into an instance of a BaseTask. It accepts any kind of options,
        however these options will be validated with the options class assigned to the class.

        Options left as ``None`` are not forwarded, so the task keeps its configured defaults.

        :param Optional[Callable] func:
            The function to be converted into a Task instance. It must be defined
            at module level, otherwise a ValueError is raised.
        :param int | None parallel_batch_size:
            If set to 0, auto parallelization is disabled. If greater than 0, tasks with iterable
            arguments are automatically split into chunks.
        :param Optional[Tuple[Exception, ...]] retry_for:
            Exceptions for which the task should be retried.
        :param int | None max_retries:
            The maximum number of retries for a task.
        :param ConcurrencyControlType | None running_concurrency:
            Controls the concurrency behavior of the task.
        :param ConcurrencyControlType | None registration_concurrency:
            Manages task registration concurrency.
        :param Optional[Tuple[str, ...]] key_arguments:
            Key arguments for concurrency control.
        :param bool | None on_diff_non_key_args_raise:
            If True, raises an exception for task invocations with matching key arguments but
            different non-key arguments.
        :param bool | None call_result_cache:
            If True, it will return the latest result of a Task with the same arguments if available,
            otherwise it will trigger a new invocation as expected.
        :param tuple[str, ...] | None disable_cache_args:
            Arguments to exclude from caching, it will accept "*" to disable caching for all arguments.
        :param Union[TriggerBuilder, list[TriggerBuilder]] | None triggers:
            Trigger definitions that determine when this task should execute automatically.
            Can be a single TriggerBuilder or a list of builders for multiple trigger conditions.
            They are stored and registered later by the runner (see ``register_deferred_triggers``).
        :param bool | None force_new_workflow:
            If True, this task will always create a new workflow when invoked.
            Even when called from within another workflow, it creates a subworkflow that
            maintains a reference to its parent workflow.
        :param bool | None reroute_on_concurrency_control:
            Forwarded as the ``reroute_on_concurrency_control`` task option; see the task
            configuration for its semantics.
        :return: A Task instance or a callable that when called returns a Task instance.
        :raises ValueError: If the decorated function is not defined at module level.

        :example:
        ```python
        # Basic task with no triggers
        @app.task(max_retries=3)
        def simple_task(x: int, y: int) -> int:
            return x + y

        # Task with a single trigger using a cron schedule
        from pynenc.trigger import on_cron

        @app.task(triggers=on_cron("0 0 * * *"))  # Run daily at midnight
        def daily_report() -> None:
            # Generate daily report
            pass

        # Task with multiple triggers using different conditions
        from pynenc.trigger import on_event, on_status

        @app.task(
            triggers=[
                on_event("payment.completed", filters={"amount": {"$gt": 1000}}),
                on_status("validate_data", statuses=["SUCCESS"])
            ]
        )
        def process_important_payment(payment_id: str) -> None:
            # Process high-value payment after validation
            pass

        # Task with complex trigger condition using a builder
        from pynenc.trigger import TriggerBuilder
        from pynenc.trigger.conditions import CompositeLogic

        trigger = (
            TriggerBuilder()
            .on_event("payment.received")
            .on_status("validate_payment")
            .with_logic(CompositeLogic.AND)  # Both conditions must be met
            .with_arguments(lambda ctx: {"payment_id": ctx["event"].payload["id"]})
        )

        @app.task(triggers=trigger)
        def process_payment(payment_id: str) -> None:
            # Process payment that has been received and validated
            pass
        ```
        """
        options = {
            "parallel_batch_size": parallel_batch_size,
            "retry_for": retry_for,
            "max_retries": max_retries,
            "running_concurrency": running_concurrency,
            "registration_concurrency": registration_concurrency,
            "key_arguments": key_arguments,
            "on_diff_non_key_args_raise": on_diff_non_key_args_raise,
            "call_result_cache": call_result_cache,
            "disable_cache_args": disable_cache_args,
            "force_new_workflow": force_new_workflow,
            "reroute_on_concurrency_control": reroute_on_concurrency_control,
        }
        options = {k: v for k, v in options.items() if v is not None}

        def init_task(_func: "Func") -> Task["Params", "Result"]:
            # Nested functions and methods have a qualname that differs from
            # their name; only module-level functions are accepted.
            if _func.__qualname__ != _func.__name__:
                raise ValueError(
                    "Decorated function must be defined at the module level."
                )
            task: Task = Task(self, _func, options)
            self._tasks[task.task_id] = task
            self._store_deferred_trigger(task, triggers)
            return task

        if func is None:
            return init_task
        return init_task(func)

    @overload
    def direct_task(
        self,
        func: "Func",
        *,
        parallel_func: "ParallelFunc | None" = None,
        aggregate_func: "AggregateFunc | None" = None,
        parallel_batch_size: int | None = None,
        retry_for: tuple[type[Exception], ...] | None = None,
        max_retries: int | None = None,
        running_concurrency: ConcurrencyControlType | None = None,
        registration_concurrency: ConcurrencyControlType | None = None,
        key_arguments: tuple[str, ...] | None = None,
        on_diff_non_key_args_raise: bool | None = None,
        call_result_cache: bool | None = None,
        disable_cache_args: tuple[str, ...] | None = None,
        force_new_workflow: bool | None = None,
        reroute_on_concurrency_control: bool | None = None,
    ) -> "Func": ...

    @overload
    def direct_task(
        self,
        func: "Func[Params, Result]",
        *,
        parallel_func: "ParallelFunc | None" = None,
        aggregate_func: "AggregateFunc | None" = None,
        parallel_batch_size: int | None = None,
        retry_for: tuple[type[Exception], ...] | None = None,
        max_retries: int | None = None,
        running_concurrency: ConcurrencyControlType | None = None,
        registration_concurrency: ConcurrencyControlType | None = None,
        key_arguments: tuple[str, ...] | None = None,
        on_diff_non_key_args_raise: bool | None = None,
        call_result_cache: bool | None = None,
        disable_cache_args: tuple[str, ...] | None = None,
        force_new_workflow: bool | None = None,
        reroute_on_concurrency_control: bool | None = None,
    ) -> "Func": ...

    @overload
    def direct_task(
        self,
        func: None = None,
        *,
        parallel_func: "ParallelFunc | None" = None,
        aggregate_func: "AggregateFunc | None" = None,
        parallel_batch_size: int | None = None,
        retry_for: tuple[type[Exception], ...] | None = None,
        max_retries: int | None = None,
        running_concurrency: ConcurrencyControlType | None = None,
        registration_concurrency: ConcurrencyControlType | None = None,
        key_arguments: tuple[str, ...] | None = None,
        on_diff_non_key_args_raise: bool | None = None,
        call_result_cache: bool | None = None,
        disable_cache_args: tuple[str, ...] | None = None,
        force_new_workflow: bool | None = None,
        reroute_on_concurrency_control: bool | None = None,
    ) -> Callable[["Func[Params, Result]"], "Func[Params, Result]"]: ...

    def direct_task(
        self,
        func: Optional["Func[Params, Result]"] = None,
        *,
        parallel_func: "ParallelFunc | None" = None,
        aggregate_func: "AggregateFunc | None" = None,
        parallel_batch_size: int | None = None,
        retry_for: tuple[type[Exception], ...] | None = None,
        max_retries: int | None = None,
        running_concurrency: ConcurrencyControlType | None = None,
        registration_concurrency: ConcurrencyControlType | None = None,
        key_arguments: tuple[str, ...] | None = None,
        on_diff_non_key_args_raise: bool | None = None,
        call_result_cache: bool | None = None,
        disable_cache_args: tuple[str, ...] | None = None,
        force_new_workflow: bool | None = None,
        reroute_on_concurrency_control: bool | None = None,
    ) -> (
        "Func[Params, Result] | Callable[[Func[Params, Result]], Func[Params, Result]]"
    ):
        """
        Create a task that directly returns its result rather than returning an invocation.

        This decorator maintains the original function's behavior:
        - For synchronous functions, it waits for the result and returns it directly
        - For async functions, it returns an awaitable that resolves to the result

        It also supports parallel execution via the parallel_func parameter, which takes
        a function that generates arguments for parallel processing, and aggregate_func,
        which combines the results.

        :param Optional[Func] func:
            The function to be converted into a Task instance that returns results directly.
        :param Optional[ParallelFunc] parallel_func:
            Function that receives a dict with the call's arguments (keyed by parameter
            name) and returns either:

            1. An iterable of parameters for parallel execution (can be tuples, dicts,
               or Arguments)
               ```python
               # Example returning just parameters
               lambda args: [(i, i+1) for i in range(5)]  # Returns tuples
               lambda args: [{"x": i, "y": i+1} for i in range(5)]  # Returns dicts
               ```

            2. Shared arguments plus per-task arguments, for efficient handling of
               large shared data, either as a ``(common_args, param_iter)`` tuple or
               as a dict with exactly the keys ``"common_args"`` and ``"param_iter"``:
               - common_args: Dictionary of arguments shared by all parallel tasks
               - param_iter: Iterable of dictionaries with task-specific arguments
               ```python
               # Tuple form
               lambda args: (
                   {"large_data": args["large_data"]},  # Shared data (serialized once)
                   [{"index": i} for i in range(10)],  # Task-specific args
               )
               # Equivalent dict form
               lambda args: {
                   "common_args": {"large_data": args["large_data"]},
                   "param_iter": [{"index": i} for i in range(10)],
               }
               ```
               A 2-tuple is only interpreted as ``(common_args, param_iter)`` when its
               first item is a dict and its second is a non-dict iterable; any other
               2-tuple (e.g. ``((1, 2), (3, 4))``) is treated as two argument sets.

               This second approach provides major performance benefits when dealing with
               large shared arguments (20MB+) as they're serialized only once instead of
               for each parallel task.
        :param Optional[AggregateFunc] aggregate_func:
            Function that takes the results and aggregates them into a single result.
            Required whenever parallel_func is set; without it a ValueError is raised
            when the function is called, before any parallel invocation is dispatched.
        :param int | None parallel_batch_size:
            If set to 0, auto parallelization is disabled. If greater than 0, tasks with iterable
            arguments are automatically split into chunks.
        :param Optional[Tuple[Exception, ...]] retry_for:
            Exceptions for which the task should be retried.
        :param int | None max_retries:
            The maximum number of retries for a task.
        :param ConcurrencyControlType | None running_concurrency:
            Controls the concurrency behavior of the task.
        :param ConcurrencyControlType | None registration_concurrency:
            Manages task registration concurrency.
        :param Optional[Tuple[str, ...]] key_arguments:
            Key arguments for concurrency control.
        :param bool | None on_diff_non_key_args_raise:
            If True, raises an exception for task invocations with matching key arguments but
            different non-key arguments.
        :param bool | None call_result_cache:
            If True, it will return the latest result of a Task with the same arguments if available,
            otherwise it will trigger a new invocation as expected.
        :param tuple[str, ...] | None disable_cache_args:
            Arguments to exclude from caching, it will accept "*" to disable caching for all arguments.
        :param bool | None force_new_workflow:
            If True, this task will always create a new workflow when invoked.
            Even when called from within another workflow, it creates a subworkflow that
            maintains a reference to its parent workflow.
        :param bool | None reroute_on_concurrency_control:
            Forwarded as the ``reroute_on_concurrency_control`` task option; see the task
            configuration for its semantics.
        :return: A function that behaves like the original but is backed by a distributed task system.
        :raises ValueError: When called with parallel_func set but no aggregate_func.

        :note: A direct task does not have triggers, it is always executed when called.

        :example:
        ```python
        @app.direct_task(max_retries=3)
        def my_func(x, y):
            return x + y

        # This will return the result directly
        result = my_func(1, 2)  # Returns 3

        # With parallel execution
        @app.direct_task(
            parallel_func=lambda _: [(i, i+1) for i in range(5)],
            aggregate_func=sum
        )
        def add_parallel(x, y):
            return x + y

        result = add_parallel(0, 0)  # Returns sum of all parallel results

        # With optimized pre-serialization of large shared data
        @app.direct_task(
            parallel_func=lambda args: (
                {"large_data": args["large_data"]},
                [{"index": i} for i in range(100)],
            ),
            aggregate_func=lambda results: sum(r[0] for r in results)
        )
        def process_data(large_data: str, index: int = 0) -> tuple[int, int]:
            # Process large data with multiple parallel tasks
            return (len(large_data) + index, index)

        # Calling with 20MB of data
        huge_data = "x" * (20 * 1024 * 1024)
        result = process_data(huge_data)  # Pre-serializes huge_data only once
        ```
        """

        def _shared_args_form(
            parallel_result: Any,
        ) -> "tuple[dict[str, Any], Iterable[Any]] | None":
            """Return ``(common_args, param_iter)`` for the shared-args form, else None."""
            if isinstance(parallel_result, dict):
                # Mapping form: {"common_args": ..., "param_iter": ...}
                if set(parallel_result) == {"common_args", "param_iter"}:
                    return parallel_result["common_args"], parallel_result["param_iter"]
                return None
            if isinstance(parallel_result, tuple) and len(parallel_result) == 2:
                common_args, param_iter = parallel_result
                # Only a kwargs dict followed by a non-mapping iterable is the
                # shared-args form; e.g. ((1, 2), (3, 4)) or ({...}, {...}) are
                # plain iterables of two argument sets.
                if (
                    isinstance(common_args, dict)
                    and isinstance(param_iter, Iterable)
                    and not isinstance(param_iter, (dict, str, bytes))
                ):
                    return common_args, param_iter
            return None

        def _parallelize(
            task: Task, *args: "Params.args", **kwargs: "Params.kwargs"
        ) -> "BaseInvocationGroup":
            # Fail before dispatching anything: without an aggregation function
            # the parallel results could never be combined and returned.
            if aggregate_func is None:
                raise ValueError("Aggregation function required for parallel execution")
            parsed_args = task.args(*args, **kwargs).kwargs
            parallel_result = parallel_func(parsed_args)  # type: ignore
            if (shared := _shared_args_form(parallel_result)) is not None:
                common_args, param_iter = shared
                return task.parallelize(param_iter, common_args)
            return task.parallelize(parallel_result)

        def _aggregate_results(results: Iterable["Result"]) -> "Result":
            if aggregate_func is not None:
                return aggregate_func(results)
            # TODO try to infer aggregate function from the type
            # eg. list.concat, dict.update, etc...
            raise ValueError("Aggregation function required for parallel execution")

        def decorator(func: "Func[Params, Result]") -> "Func[Params, Result]":
            task_options = {
                "parallel_batch_size": parallel_batch_size,
                "retry_for": retry_for,
                "max_retries": max_retries,
                "running_concurrency": running_concurrency,
                "registration_concurrency": registration_concurrency,
                "key_arguments": key_arguments,
                "on_diff_non_key_args_raise": on_diff_non_key_args_raise,
                "call_result_cache": call_result_cache,
                "disable_cache_args": disable_cache_args,
                "force_new_workflow": force_new_workflow,
                "reroute_on_concurrency_control": reroute_on_concurrency_control,
            }
            task_options = {k: v for k, v in task_options.items() if v is not None}
            task = self.task(func, **task_options)  # type: ignore
            is_async = asyncio.iscoroutinefunction(func)

            if is_async:

                @wraps(func)
                async def async_wrapper(
                    *args: "Params.args", **kwargs: "Params.kwargs"
                ) -> "Result":
                    if parallel_func:
                        invocation_group = _parallelize(task, *args, **kwargs)
                        results = [
                            result async for result in invocation_group.async_results()
                        ]
                        return _aggregate_results(results)
                    return await task(*args, **kwargs).async_result()

                # Attach references for serialization and from_id resolution
                async_wrapper.__inner_function__ = func  # type: ignore
                async_wrapper.__pynenc_task__ = task  # type: ignore
                return async_wrapper  # type: ignore

            @wraps(func)
            def sync_wrapper(
                *args: "Params.args", **kwargs: "Params.kwargs"
            ) -> "Result":
                if parallel_func:
                    invocation_group = _parallelize(task, *args, **kwargs)
                    return _aggregate_results(invocation_group.results)
                return task(*args, **kwargs).result

            # Attach references for serialization and from_id resolution
            sync_wrapper.__inner_function__ = func  # type: ignore
            sync_wrapper.__pynenc_task__ = task  # type: ignore
            return sync_wrapper

        if func is None:
            return decorator
        return decorator(func)