Files
Agent-n8n/n8n-n8n-1.109.2/packages/@n8n/task-runner-python/src/task_runner.py
2025-09-08 04:48:28 +08:00

296 lines
9.5 KiB
Python
Executable File

import asyncio
from dataclasses import dataclass
import logging
import time
from typing import Dict, Optional, Any
from urllib.parse import urlparse
import websockets
import random
from .errors import WebsocketConnectionError, TaskMissingError
from .message_types.broker import TaskSettings
from .nanoid import nanoid
from .constants import (
RUNNER_NAME,
TASK_REJECTED_REASON_AT_CAPACITY,
TASK_REJECTED_REASON_OFFER_EXPIRED,
TASK_TYPE_PYTHON,
OFFER_INTERVAL,
OFFER_VALIDITY,
OFFER_VALIDITY_MAX_JITTER,
OFFER_VALIDITY_LATENCY_BUFFER,
TASK_BROKER_WS_PATH,
)
from .message_types import (
BrokerMessage,
RunnerMessage,
BrokerInfoRequest,
BrokerRunnerRegistered,
BrokerTaskOfferAccept,
BrokerTaskSettings,
BrokerTaskCancel,
RunnerInfo,
RunnerTaskOffer,
RunnerTaskAccepted,
RunnerTaskRejected,
RunnerTaskDone,
RunnerTaskError,
)
from .message_serde import MessageSerde
from .task_state import TaskState, TaskStatus
from .task_executor import TaskExecutor
class TaskOffer:
    """An offer sent to the broker, together with the time it stops being valid."""

    def __init__(self, offer_id: str, valid_until: float):
        # `valid_until` is compared against `time.time()`, i.e. it is an
        # absolute wall-clock timestamp in seconds since the epoch.
        self.offer_id = offer_id
        self.valid_until = valid_until

    @property
    def has_expired(self) -> bool:
        """Return True once the current time is strictly past `valid_until`."""
        now = time.time()
        return self.valid_until < now
@dataclass
class TaskRunnerOpts:
    """Settings handed to `TaskRunner` at construction time."""

    # Sent to the broker as a `Bearer` token in the `Authorization` header.
    grant_token: str
    # Broker address; only its network location (host[:port]) is used to
    # build the websocket URL.
    task_broker_uri: str
    # Upper bound on running tasks plus open offers.
    max_concurrency: int
    # Forwarded as `max_size` when opening the websocket connection.
    max_payload_size: int
    # Forwarded to the task executor when a task is executed.
    task_timeout: int
class TaskRunner:
def __init__(
self,
opts: TaskRunnerOpts,
):
self.runner_id = nanoid()
self.name = RUNNER_NAME
self.grant_token = opts.grant_token
self.opts = opts
self.websocket_connection: Optional[Any] = None
self.can_send_offers = False
self.open_offers: Dict[str, TaskOffer] = {}
self.running_tasks: Dict[str, TaskState] = {}
self.offers_coroutine: Optional[asyncio.Task] = None
self.serde = MessageSerde()
self.executor = TaskExecutor()
self.logger = logging.getLogger(__name__)
self.task_broker_uri = opts.task_broker_uri
websocket_host = urlparse(opts.task_broker_uri).netloc
self.websocket_url = (
f"ws://{websocket_host}{TASK_BROKER_WS_PATH}?id={self.runner_id}"
)
async def start(self) -> None:
headers = {"Authorization": f"Bearer {self.grant_token}"}
try:
self.websocket_connection = await websockets.connect(
self.websocket_url,
additional_headers=headers,
max_size=self.opts.max_payload_size,
)
self.logger.info("Connected to broker")
await self._listen_for_messages()
except Exception:
raise WebsocketConnectionError(self.task_broker_uri)
async def stop(self) -> None:
if self.offers_coroutine:
self.offers_coroutine.cancel()
if self.websocket_connection:
await self.websocket_connection.close()
self.logger.info("Disconnected from broker")
# ========== Messages ==========
async def _listen_for_messages(self) -> None:
if self.websocket_connection is None:
raise WebsocketConnectionError(self.task_broker_uri)
async for raw_message in self.websocket_connection:
try:
message = self.serde.deserialize_broker_message(raw_message)
await self._handle_message(message)
except Exception as e:
self.logger.error(f"Error handling message: {e}")
async def _handle_message(self, message: BrokerMessage) -> None:
match message:
case BrokerInfoRequest():
await self._handle_info_request()
case BrokerRunnerRegistered():
await self._handle_runner_registered()
case BrokerTaskOfferAccept():
await self._handle_task_offer_accept(message)
case BrokerTaskSettings():
await self._handle_task_settings(message)
case BrokerTaskCancel():
await self._handle_task_cancel(message)
case _:
self.logger.warning(f"Unhandled message type: {type(message)}")
async def _handle_info_request(self) -> None:
response = RunnerInfo(name=self.name, types=[TASK_TYPE_PYTHON])
await self._send_message(response)
async def _handle_runner_registered(self) -> None:
self.can_send_offers = True
self.offers_coroutine = asyncio.create_task(self._send_offers_loop())
self.logger.info("Registered with broker")
async def _handle_task_offer_accept(self, message: BrokerTaskOfferAccept) -> None:
offer = self.open_offers.get(message.offer_id)
if offer is None or offer.has_expired:
response = RunnerTaskRejected(
task_id=message.task_id,
reason=TASK_REJECTED_REASON_OFFER_EXPIRED,
)
await self._send_message(response)
return
if len(self.running_tasks) >= self.opts.max_concurrency:
response = RunnerTaskRejected(
task_id=message.task_id,
reason=TASK_REJECTED_REASON_AT_CAPACITY,
)
await self._send_message(response)
return
del self.open_offers[message.offer_id]
task_state = TaskState(message.task_id)
self.running_tasks[message.task_id] = task_state
response = RunnerTaskAccepted(task_id=message.task_id)
await self._send_message(response)
self.logger.info(f"Accepted task {message.task_id}")
async def _handle_task_settings(self, message: BrokerTaskSettings) -> None:
task_state = self.running_tasks.get(message.task_id)
if task_state is None:
raise TaskMissingError(message.task_id)
if task_state.status != TaskStatus.WAITING_FOR_SETTINGS:
self.logger.warning(
f"Received settings for task but it is already {task_state.status}. Discarding message."
)
return
task_state.status = TaskStatus.RUNNING
asyncio.create_task(self._execute_task(message.task_id, message.settings))
self.logger.info(f"Received task {message.task_id}")
async def _execute_task(self, task_id: str, task_settings: TaskSettings) -> None:
try:
task_state = self.running_tasks.get(task_id)
if task_state is None:
raise TaskMissingError(task_id)
process, queue = self.executor.create_process(
task_settings.code, task_settings.node_mode, task_settings.items
)
task_state.process = process
result = self.executor.execute_process(
process, queue, self.opts.task_timeout, task_settings.continue_on_fail
)
response = RunnerTaskDone(task_id=task_id, data={"result": result})
await self._send_message(response)
self.logger.info(f"Completed task {task_id}")
except Exception as e:
response = RunnerTaskError(task_id=task_id, error={"message": str(e)})
await self._send_message(response)
finally:
self.running_tasks.pop(task_id, None)
async def _handle_task_cancel(self, message: BrokerTaskCancel) -> None:
task_state = self.running_tasks.get(message.task_id)
if task_state is None:
self.logger.warning(
f"Received cancel for unknown task: {message.task_id}. Discarding message."
)
return
if task_state.status == TaskStatus.WAITING_FOR_SETTINGS:
self.running_tasks.pop(message.task_id, None)
await self._send_offers()
return
if task_state.status == TaskStatus.RUNNING:
task_state.status = TaskStatus.ABORTING
self.executor.stop_process(task_state.process)
async def _send_message(self, message: RunnerMessage) -> None:
if self.websocket_connection is None:
raise WebsocketConnectionError(self.task_broker_uri)
serialized = self.serde.serialize_runner_message(message)
await self.websocket_connection.send(serialized)
# ========== Offers ==========
async def _send_offers_loop(self) -> None:
while self.can_send_offers:
try:
await self._send_offers()
await asyncio.sleep(OFFER_INTERVAL)
except asyncio.CancelledError:
break
except Exception as e:
self.logger.error(f"Error sending offers: {e}")
async def _send_offers(self) -> None:
if not self.can_send_offers:
return
expired_offer_ids = [
offer_id
for offer_id, offer in self.open_offers.items()
if offer.has_expired
]
for offer_id in expired_offer_ids:
self.open_offers.pop(offer_id, None)
offers_to_send = self.opts.max_concurrency - (
len(self.open_offers) + len(self.running_tasks)
)
for _ in range(offers_to_send):
offer_id = nanoid()
valid_for_ms = OFFER_VALIDITY + random.randint(0, OFFER_VALIDITY_MAX_JITTER)
valid_until = (
time.time() + (valid_for_ms / 1000) + OFFER_VALIDITY_LATENCY_BUFFER
)
self.open_offers[offer_id] = TaskOffer(offer_id, valid_until)
message = RunnerTaskOffer(
offer_id=offer_id, task_type=TASK_TYPE_PYTHON, valid_for=valid_for_ms
)
await self._send_message(message)