112 lines
3.2 KiB
Python
Executable File
112 lines
3.2 KiB
Python
Executable File
import json
|
|
from dataclasses import asdict
|
|
from typing import cast
|
|
|
|
from .message_types.broker import NodeMode, TaskSettings
|
|
from .constants import (
|
|
BROKER_INFO_REQUEST,
|
|
BROKER_RUNNER_REGISTERED,
|
|
BROKER_TASK_CANCEL,
|
|
BROKER_TASK_OFFER_ACCEPT,
|
|
BROKER_TASK_SETTINGS,
|
|
)
|
|
from .message_types import (
|
|
BrokerMessage,
|
|
RunnerMessage,
|
|
BrokerInfoRequest,
|
|
BrokerRunnerRegistered,
|
|
BrokerTaskOfferAccept,
|
|
BrokerTaskSettings,
|
|
BrokerTaskCancel,
|
|
)
|
|
|
|
|
|
NODE_MODE_MAP = {
|
|
"runOnceForAllItems": "all_items",
|
|
"runOnceForEachItem": "per_item",
|
|
}
|
|
|
|
|
|
def _get_node_mode(node_mode_str: str) -> NodeMode:
|
|
if node_mode_str not in NODE_MODE_MAP:
|
|
raise ValueError(f"Unknown nodeMode: {node_mode_str}")
|
|
return cast(NodeMode, NODE_MODE_MAP[node_mode_str])
|
|
|
|
|
|
def _parse_task_settings(d: dict) -> BrokerTaskSettings:
|
|
try:
|
|
task_id = d["taskId"]
|
|
settings_dict = d["settings"]
|
|
code = settings_dict["code"]
|
|
node_mode = _get_node_mode(settings_dict["nodeMode"])
|
|
continue_on_fail = settings_dict.get("continueOnFail", False)
|
|
items = settings_dict["items"]
|
|
except KeyError as e:
|
|
raise ValueError(f"Missing field in task settings message: {e}")
|
|
|
|
return BrokerTaskSettings(
|
|
task_id=task_id,
|
|
settings=TaskSettings(
|
|
code=code,
|
|
node_mode=node_mode,
|
|
continue_on_fail=continue_on_fail,
|
|
items=items,
|
|
),
|
|
)
|
|
|
|
|
|
def _parse_task_offer_accept(d: dict) -> BrokerTaskOfferAccept:
|
|
try:
|
|
task_id = d["taskId"]
|
|
offer_id = d["offerId"]
|
|
except KeyError as e:
|
|
raise ValueError(f"Missing field in task offer acceptance message: {e}")
|
|
|
|
return BrokerTaskOfferAccept(task_id=task_id, offer_id=offer_id)
|
|
|
|
|
|
def _parse_task_cancel(d: dict) -> BrokerTaskCancel:
|
|
try:
|
|
task_id = d["taskId"]
|
|
reason = d["reason"]
|
|
except KeyError as e:
|
|
raise ValueError(f"Missing field in task cancel message: {e}")
|
|
|
|
return BrokerTaskCancel(task_id=task_id, reason=reason)
|
|
|
|
|
|
MESSAGE_TYPE_MAP = {
|
|
BROKER_INFO_REQUEST: lambda _: BrokerInfoRequest(),
|
|
BROKER_RUNNER_REGISTERED: lambda _: BrokerRunnerRegistered(),
|
|
BROKER_TASK_OFFER_ACCEPT: _parse_task_offer_accept,
|
|
BROKER_TASK_SETTINGS: _parse_task_settings,
|
|
BROKER_TASK_CANCEL: _parse_task_cancel,
|
|
}
|
|
|
|
|
|
class MessageSerde:
|
|
"""Responsible for deserializing incoming messages and serializing outgoing messages."""
|
|
|
|
@staticmethod
|
|
def deserialize_broker_message(data: str) -> BrokerMessage:
|
|
message_dict = json.loads(data)
|
|
message_type = message_dict.get("type")
|
|
|
|
if message_type not in MESSAGE_TYPE_MAP:
|
|
raise ValueError(f"Unknown message type: {message_type}")
|
|
|
|
return MESSAGE_TYPE_MAP[message_type](message_dict)
|
|
|
|
@staticmethod
|
|
def serialize_runner_message(message: RunnerMessage) -> str:
|
|
data = asdict(message)
|
|
camel_case_data = {
|
|
MessageSerde._snake_to_camel_case(k): v for k, v in data.items()
|
|
}
|
|
return json.dumps(camel_case_data)
|
|
|
|
@staticmethod
|
|
def _snake_to_camel_case(snake_case_str: str) -> str:
|
|
parts = snake_case_str.split("_")
|
|
return parts[0] + "".join(word.capitalize() for word in parts[1:])
|