Source code for pynenc.runner.multi_thread_runner

import time
from enum import Enum
from functools import cached_property
from multiprocessing import Manager, Process, cpu_count
from typing import TYPE_CHECKING, Any, NamedTuple

from pynenc import context
from pynenc.conf.config_runner import ConfigMultiThreadRunner
from pynenc.runner.base_runner import BaseRunner
from pynenc.runner.runner_context import RunnerContext
from pynenc.runner.shutdown_diagnostics import log_runner_shutdown
from pynenc.runner.thread_runner import ThreadRunner
from pynenc.util.multiprocessing_utils import warn_missing_main_guard


if TYPE_CHECKING:
    from pynenc.app import Pynenc
    from pynenc.identifiers.invocation_id import InvocationId


class ProcessState(Enum):
    """Activity state that a ThreadRunner worker process publishes about itself."""

    # The worker has at least one live task thread.
    ACTIVE = "active"
    # The worker has no live task threads.
    IDLE = "idle"
class ProcessStatus(NamedTuple):
    """Status snapshot of one worker process, exchanged through the shared status dict."""

    # Timestamp (time.time()) recorded when this snapshot was written.
    last_update: float
    # Number of task threads that were alive when the snapshot was written.
    active_count: int
    # Whether the worker considered itself ACTIVE or IDLE at that moment.
    state: ProcessState

    def is_idle(self, now: float, idle_timeout: float) -> bool:
        """
        Tell whether this snapshot describes a worker that has idled for too long.

        :param float now: Current timestamp, on the same clock as ``last_update``
        :param float idle_timeout: Seconds that must have elapsed since ``last_update``
        :return: True only if more than ``idle_timeout`` seconds have passed since
            ``last_update`` and the recorded state is ``ProcessState.IDLE``
        """
        elapsed = now - self.last_update
        if not elapsed > idle_timeout:
            return False
        return self.state == ProcessState.IDLE
def thread_runner_process_main(
    app: "Pynenc",
    *,
    parent_ctx_json: str,
    child_runner_id: str,
    runner_cache: dict,
    shared_status: dict[str, ProcessStatus],
) -> None:
    """
    Run a ThreadRunner inside a worker process spawned by MultiThreadRunner.

    The parent generates ``child_runner_id`` before spawning the process, which
    enables parent-based health reporting: the parent reports heartbeats for
    alive children from its own main loop.

    :param Pynenc app: The application instance this worker serves
    :param str parent_ctx_json: JSON form of the parent's RunnerContext
    :param str child_runner_id: Runner id assigned to this worker by the parent
    :param dict runner_cache: Cache object handed to the ThreadRunner
    :param dict[str, ProcessStatus] shared_status: Mapping in which this worker
        publishes its ProcessStatus under ``child_runner_id``
    """
    import signal as _signal

    # Derive this worker's context from the parent's serialized context.
    worker_ctx = RunnerContext.from_json(parent_ctx_json).new_child_context(
        ThreadRunner.__name__, runner_id=child_runner_id
    )
    app.runner._register_new_child_runner_context(worker_ctx)
    runner = ThreadRunner(app, runner_cache, runner_context=worker_ctx)
    # From here on this process acts as a ThreadRunner, not a MultiThreadRunner.
    context.set_runner_context(app.app_id, worker_ctx)
    runner._on_start()
    app.logger.info(f"ThreadRunner process worker:{child_runner_id} started")

    def _handle_signal(signum: int, frame: Any) -> None:
        # Another SIGTERM may land while shutdown diagnostics are still being
        # gathered; ignoring repeats keeps logging from being re-entered from
        # inside imports or stdlib helpers.
        _signal.signal(_signal.SIGTERM, _signal.SIG_IGN)
        runner._log_shutdown(signum)
        # Unwind the loop below so the finally clause runs _on_stop.
        raise KeyboardInterrupt

    _signal.signal(_signal.SIGTERM, _handle_signal)
    try:
        while True:
            # Forget threads that have already finished.
            still_running = {
                key: entry
                for key, entry in runner.threads.items()
                if entry.thread.is_alive()
            }
            runner.threads = still_running
            active_count = len(runner.threads)
            if active_count > 0:
                state = ProcessState.ACTIVE
            else:
                state = ProcessState.IDLE
            # NOTE(review): the timestamp is refreshed on every iteration, idle
            # or not - confirm this matches what ProcessStatus.is_idle expects.
            shared_status[child_runner_id] = ProcessStatus(
                time.time(), active_count, state
            )
            app.logger.debug(
                f"worker:{child_runner_id}: {active_count} active threads, state:{state}"
            )
            runner.runner_loop_iteration()
    except KeyboardInterrupt:
        app.logger.warning(
            f"ThreadRunner process worker:{child_runner_id} interrupted, shutting down"
        )
    finally:
        # Keep a second SIGTERM (e.g. from the parent terminating all its
        # processes) from interrupting the cleanup done by _on_stop.
        _signal.signal(_signal.SIGTERM, _signal.SIG_IGN)
        runner._on_stop()
class MultiThreadRunner(BaseRunner):
    """
    MultiThreadRunner spawns separate processes, each running a ThreadRunner.

    It scales processes based on pending invocations and terminates those
    that remain idle. Child processes are tracked by a pre-generated runner id
    in ``child_runner_ids`` and publish their status in ``shared_status``.
    """

    WAITING_FOR_RESULTS_WARNING = (
        "waiting_for_results called on MultiThreadRunner from within a task. "
        "This should be handled by the ThreadRunner instance in the process."
    )

    def __init__(
        self,
        app: "Pynenc",
        runner_cache: dict | None = None,
        runner_context: RunnerContext | None = None,
    ) -> None:
        """
        Set up empty process tracking state and delegate to BaseRunner.

        :param Pynenc app: The application instance this runner serves
        :param dict | None runner_cache: Optional cache forwarded to BaseRunner
        :param RunnerContext | None runner_context: Optional context forwarded to BaseRunner
        """
        # child runner id -> Process handle of the spawned worker
        self.child_runner_ids: dict[str, Process] = {}
        # child runner id -> latest ProcessStatus; replaced by a Manager dict in _on_start
        self.shared_status: dict[str, ProcessStatus] = {}
        # Upper bound on worker processes; resolved from config in _on_start
        self.max_processes: int = 0
        self.manager: Manager | None = None  # type: ignore
        super().__init__(app, runner_cache, runner_context)

    @cached_property
    def conf(self) -> ConfigMultiThreadRunner:
        """Runner configuration built from the app's config values and config file."""
        return ConfigMultiThreadRunner(
            config_values=self.app.config_values,
            config_filepath=self.app.config_filepath,
        )

    @staticmethod
    def mem_compatible() -> bool:
        """
        Indicates if the runner is compatible with in-memory components.

        :return: False, as each thread runs in a separate process with independent memory
        """
        return False

    @property
    def max_parallel_slots(self) -> int:
        """
        The maximum number of parallel tasks that the runner can handle.

        :return: The larger of the configured ``min_processes`` and ``max_processes``
        """
        return max(self.conf.min_processes, self.max_processes)

    def get_active_child_runner_ids(self) -> list[str]:
        """Return runner_ids of child processes that are still alive."""
        return [
            runner_id
            for runner_id, proc in self.child_runner_ids.items()
            if proc.is_alive()
        ]

    def _log_shutdown(self, signum: int | None) -> None:
        """
        Log shutdown diagnostics for this runner and its tracked child processes.

        :param int | None signum: Signal number that triggered the shutdown, if any
        """
        log_runner_shutdown(
            self.app.logger,
            self.__class__.__name__,
            self.runner_id,
            signum,
            processes={
                rid: (proc, None) for rid, proc in self.child_runner_ids.items()
            },
        )

    def _on_start(self) -> None:
        """
        Initialize multiprocessing infrastructure for spawning worker processes.

        Validates that multiprocessing is being used safely before creating
        the Manager and spawning the initial ``min_processes`` workers.
        """
        self.logger.info("Starting MultiThreadRunner")
        warn_missing_main_guard()
        self.manager = Manager()
        self.shared_status = self.manager.dict()  # type: ignore
        self.runner_cache = self._runner_cache or self.manager.dict()  # type: ignore
        self.child_runner_ids = {}
        # A falsy max_processes in the config falls back to the CPU count.
        self.max_processes = self.conf.max_processes or cpu_count()
        for _ in range(self.conf.min_processes):
            self._spawn_thread_runner_process()

    def _spawn_thread_runner_process(self) -> None:
        """Spawn a new ThreadRunner worker process with pre-generated runner_id."""
        import uuid

        child_runner_id = str(uuid.uuid4())
        args = {
            "app": self.app,
            "parent_ctx_json": self.runner_context.to_json(),
            "child_runner_id": child_runner_id,
            "runner_cache": self.runner_cache,
            "shared_status": self.shared_status,
        }
        p = Process(target=thread_runner_process_main, kwargs=args, daemon=True)
        p.start()
        self.child_runner_ids[child_runner_id] = p
        # Initialize shared_status with current time, 0 active threads, state IDLE.
        self.shared_status[child_runner_id] = ProcessStatus(
            time.time(), 0, ProcessState.IDLE
        )
        self.logger.info(
            f"Spawned ThreadRunner process worker:{child_runner_id} pid:{p.pid}"
        )

    def _terminate_process(
        self, proc: Process, runner_id: str, timeout: float = 10.0
    ) -> None:
        """
        Terminate a child process with SIGKILL fallback.

        Sends SIGTERM via proc.terminate(), waits up to timeout seconds,
        then sends SIGKILL if the process is still alive.

        :param proc: The child process to terminate
        :param runner_id: The runner ID for logging
        :param timeout: Seconds to wait before escalating to SIGKILL
        """
        try:
            proc.terminate()
        except OSError:
            self.logger.debug(f"worker:{runner_id} already exited before terminate")
            return
        proc.join(timeout=timeout)
        try:
            if proc.is_alive():
                self.logger.warning(
                    f"worker:{runner_id} pid:{proc.pid} did not exit after {timeout}s, sending SIGKILL"
                )
                proc.kill()
                proc.join(timeout=5)
        except OSError:
            self.logger.debug(f"worker:{runner_id} exited before SIGKILL")

    def _on_stop(self) -> None:
        """Stop all worker processes and shutdown the manager."""
        self.logger.info("Stopping MultiThreadRunner")
        try:
            for runner_id, proc in self.child_runner_ids.items():
                if proc.is_alive():
                    self._terminate_process(proc, runner_id)
                    self.logger.info(f"Terminated process worker:{runner_id}")
        finally:
            # Shut the manager down even if terminating a child raised, so its
            # server process is not left behind.
            if self.manager is not None:
                self.manager.shutdown()  # type: ignore
        self.logger.info("MultiThreadRunner stopped")

    def _safe_remove_shared_state(self, key: str) -> None:
        """
        Safely remove a process's shared state, handling manager shutdown cases.

        :param str key: The process key to remove from shared state
        """
        try:
            self.shared_status.pop(key, None)
        except (EOFError, BrokenPipeError):
            self.logger.debug(
                f"Manager already stopped while removing state for worker:{key}"
            )

    def _on_stop_runner_loop(self) -> None:
        """Internal method called after receiving a signal to stop the runner loop."""
        self.logger.info("Stopping MultiThreadRunner loop")
        # Iterate over a copy: entries are removed while looping.
        for runner_id, proc in list(self.child_runner_ids.items()):
            try:
                if proc.is_alive():
                    self._terminate_process(proc, runner_id)
                    self.child_runner_ids.pop(runner_id, None)
                    self._safe_remove_shared_state(runner_id)
                    self.logger.info(
                        f"Terminated process worker:{runner_id} during loop stop"
                    )
            except AssertionError:
                # multiprocessing asserts that only the parent may query a child.
                self.logger.info(
                    f"Skipping process worker:{runner_id} termination - not a child process"
                )
        self.logger.info("MultiThreadRunner loop stopped")

    def _cleanup_dead_processes(self) -> None:
        """Remove processes that are no longer alive from tracking dictionaries."""
        dead_ids = [
            rid for rid, proc in self.child_runner_ids.items() if not proc.is_alive()
        ]
        if dead_ids:
            self.logger.warning(f"Found {len(dead_ids)} dead processes to clean up")
        for runner_id in dead_ids:
            self.child_runner_ids.pop(runner_id, None)
            self._safe_remove_shared_state(runner_id)
            self.logger.info(f"Cleaned up dead process worker:{runner_id}")

    def _scale_up_processes(self) -> None:
        """Spawn new processes based on enforce_max_processes setting and pending tasks."""
        current_processes = len(self.child_runner_ids)
        if self.conf.enforce_max_processes:
            # Always keep the pool filled up to max_processes.
            while current_processes < self.max_processes:
                self._spawn_thread_runner_process()
                current_processes = len(self.child_runner_ids)
            return
        queued_invocations = self.app.broker.count_invocations()
        if (
            queued_invocations > current_processes
            and current_processes < self.max_processes
        ):
            # Cover the backlog, but never exceed the process limit.
            to_spawn = min(
                queued_invocations - current_processes,
                self.max_processes - current_processes,
            )
            for _ in range(to_spawn):
                self._spawn_thread_runner_process()

    def _terminate_idle_processes(self) -> None:
        """
        Terminate processes that are idle longer than the configured timeout.

        Never reduces the pool below ``min_processes``. When
        ``enforce_max_processes`` is set and the pool does not exceed
        ``max_processes``, nothing is terminated: ``_scale_up_processes``
        would respawn any reaped worker on the next iteration.
        """
        if (
            self.conf.enforce_max_processes
            and len(self.child_runner_ids) <= self.max_processes
        ):
            return
        now = time.time()
        ids_to_remove: list[str] = []
        for runner_id, proc in self.child_runner_ids.items():
            remaining = len(self.child_runner_ids) - len(ids_to_remove)
            if remaining <= self.conf.min_processes:
                break
            if (status := self.shared_status.get(runner_id)) is None:
                continue
            if status.is_idle(now, self.conf.idle_timeout_process_sec):
                idle_time = now - status.last_update
                self.logger.info(
                    f"worker:{runner_id} idle_time:{idle_time} sec, terminating."
                )
                self._terminate_process(proc, runner_id)
                ids_to_remove.append(runner_id)
        # Untrack after the loop so the dict is not mutated while iterating.
        for runner_id in ids_to_remove:
            self.child_runner_ids.pop(runner_id, None)
            self._safe_remove_shared_state(runner_id)

    def runner_loop_iteration(self) -> None:
        """
        Execute one iteration of the runner loop.

        Drops dead workers from tracking first so they do not count towards
        capacity, then scales up if needed, and finally reaps idle workers.
        """
        # NOTE(review): no pacing/sleep happens here - presumably the caller
        # in BaseRunner paces the loop; confirm.
        self._cleanup_dead_processes()
        self._scale_up_processes()
        self._terminate_idle_processes()

    def _waiting_for_results(
        self,
        running_invocation_id: "InvocationId",
        result_invocation_ids: list["InvocationId"],
        runner_args: dict[str, Any] | None = None,
    ) -> None:
        """
        Handle waiting for results when called outside a process context.

        This method warns if called directly on MultiThreadRunner, as result waiting
        should occur within a ThreadRunner process using the context-set runner.

        :param InvocationId running_invocation_id: ID of the invocation waiting for results
        :param list[InvocationId] result_invocation_ids: IDs of invocations being awaited
        :param dict[str, Any] | None runner_args: Additional runner-specific arguments
        """
        # The arguments are intentionally unused; only the warning is emitted.
        del running_invocation_id, result_invocation_ids, runner_args
        self.logger.warning(self.WAITING_FOR_RESULTS_WARNING)