Module enrgdaq.cnc

Sub-modules

enrgdaq.cnc.base
enrgdaq.cnc.handlers
enrgdaq.cnc.log_util
enrgdaq.cnc.models
enrgdaq.cnc.rest

Functions

def start_supervisor_cnc(supervisor,
config: SupervisorCNCConfig) -> SupervisorCNC | None
Expand source code
def start_supervisor_cnc(
    supervisor, config: SupervisorCNCConfig
) -> Optional[SupervisorCNC]:
    """
    Build a SupervisorCNC for *supervisor* and start its event-loop thread.

    Any exception raised while constructing or starting the CNC is logged
    with its traceback and then swallowed.

    Returns:
        The started SupervisorCNC, or None if startup failed.
    """
    log = logging.getLogger(__name__)
    try:
        instance = SupervisorCNC(supervisor, config)
        instance.start()
    except Exception as exc:
        log.error(f"Failed to start SupervisorCNC: {exc}", exc_info=True)
        return None
    return instance

Classes

class CNCLogHandler (cnc_instance)
Expand source code
class CNCLogHandler(logging.Handler):
    """
    A logging handler that sends log records to the CNC server.

    Records are forwarded over CNC only when the owning CNC instance is a
    client (``is_server`` is false). Every record is then re-dispatched to the
    root logger so that local handlers still receive it.
    """

    def __init__(self, cnc_instance):
        """
        Args:
            cnc_instance: Owning CNC object. The handler reads its
                ``is_server`` flag and optional ``supervisor_info.supervisor_id``,
                and calls its ``_send_zmq_message`` method.
        """
        super().__init__()
        self.cnc_instance = cnc_instance

    def emit(self, record):
        """
        Emit a log record by sending it to the CNC server.

        Forwarding to the server is best-effort: any error while formatting or
        sending is swallowed. Propagation to the root logger happens in a
        separate step, so a failed send (for example, the socket is not yet
        created or already closed) no longer drops the record locally too.
        """
        try:
            # Format the log message
            # If no formatter is set, use the default formatting
            if self.formatter:
                log_message = self.format(record)
            else:
                # Use a simple default format if no formatter is set
                log_message = record.getMessage()
                if record.exc_info:
                    # Add exception info if present
                    import traceback

                    log_message += "\n" + "".join(
                        traceback.format_exception(*record.exc_info)
                    )

            # Create a CNC log message
            client_id = "unknown"
            if hasattr(self.cnc_instance, "supervisor_info") and hasattr(
                self.cnc_instance.supervisor_info, "supervisor_id"
            ):
                client_id = self.cnc_instance.supervisor_info.supervisor_id

            cnc_log_msg = CNCMessageLog(
                level=record.levelname,
                message=log_message,
                # Use the record's creation time, not the (possibly later) emit time.
                timestamp=datetime.fromtimestamp(record.created).isoformat(),
                module=record.name,
                client_id=client_id,
            )

            # Send the log message via the CNC system
            # For clients, we send the message directly to the server
            if not self.cnc_instance.is_server:
                self.cnc_instance._send_zmq_message(None, cnc_log_msg)
        except Exception:
            # Best-effort: a CNC transport problem must never break the
            # caller's logging call.
            pass

        try:
            # SupervisorCNC attaches this handler to a logger with
            # propagate=False, so hand the record to the root logger manually.
            logging.getLogger().handle(record)
        except Exception:
            pass

A logging handler that sends log records to the CNC server.

Initializes the handler (no formatter, empty filter list) and stores the CNC instance that log records are forwarded through.

Ancestors

  • logging.Handler
  • logging.Filterer

Methods

def emit(self, record)
Expand source code
def emit(self, record):
    """
    Emit a log record by sending it to the CNC server.

    Forwarding to the server is best-effort: any error while formatting or
    sending is swallowed. Propagation to the root logger happens in a
    separate step, so a failed send (for example, the socket is not yet
    created or already closed) no longer drops the record locally too.
    """
    try:
        # Format the log message
        # If no formatter is set, use the default formatting
        if self.formatter:
            log_message = self.format(record)
        else:
            # Use a simple default format if no formatter is set
            log_message = record.getMessage()
            if record.exc_info:
                # Add exception info if present
                import traceback

                log_message += "\n" + "".join(
                    traceback.format_exception(*record.exc_info)
                )

        # Create a CNC log message
        client_id = "unknown"
        if hasattr(self.cnc_instance, "supervisor_info") and hasattr(
            self.cnc_instance.supervisor_info, "supervisor_id"
        ):
            client_id = self.cnc_instance.supervisor_info.supervisor_id

        cnc_log_msg = CNCMessageLog(
            level=record.levelname,
            message=log_message,
            # Use the record's creation time, not the (possibly later) emit time.
            timestamp=datetime.fromtimestamp(record.created).isoformat(),
            module=record.name,
            client_id=client_id,
        )

        # Send the log message via the CNC system
        # For clients, we send the message directly to the server
        if not self.cnc_instance.is_server:
            self.cnc_instance._send_zmq_message(None, cnc_log_msg)
    except Exception:
        # Best-effort: a CNC transport problem must never break the
        # caller's logging call.
        pass

    try:
        # SupervisorCNC attaches this handler to a logger with
        # propagate=False, so hand the record to the root logger manually.
        logging.getLogger().handle(record)
    except Exception:
        pass

Emit a log record by sending it to the CNC server.

class CNCMessageLog (level: str,
message: str,
timestamp: str,
module: str,
client_id: str,
*,
req_id: str | None = None)
Expand source code
class CNCMessageLog(CNCMessage):
    """
    Log message from client to server.

    CNCLogHandler.emit builds one of these from a ``logging.LogRecord``.
    The keyword-only ``req_id`` field is inherited from ``CNCMessage``.
    Field order is part of the constructor/encoding interface; keep it stable.
    """

    level: str  # e.g. 'INFO', 'WARNING', 'ERROR' (CNCLogHandler uses LogRecord.levelname)
    message: str  # log text; CNCLogHandler may append a formatted traceback
    timestamp: str  # ISO-8601 string as produced by datetime.isoformat()
    module: str  # module or component that generated the log (CNCLogHandler uses the logger name)
    client_id: str  # ID of the client sending the log (CNCLogHandler falls back to "unknown")

Log message from client to server.

Ancestors

  • CNCMessage
  • msgspec.Struct
  • msgspec._core._StructMixin

Instance variables

var client_id : str
Expand source code
class CNCMessageLog(CNCMessage):
    """
    Log message from client to server.

    CNCLogHandler.emit builds one of these from a ``logging.LogRecord``.
    The keyword-only ``req_id`` field is inherited from ``CNCMessage``.
    Field order is part of the constructor/encoding interface; keep it stable.
    """

    level: str  # e.g. 'INFO', 'WARNING', 'ERROR' (CNCLogHandler uses LogRecord.levelname)
    message: str  # log text; CNCLogHandler may append a formatted traceback
    timestamp: str  # ISO-8601 string as produced by datetime.isoformat()
    module: str  # module or component that generated the log (CNCLogHandler uses the logger name)
    client_id: str  # ID of the client sending the log (CNCLogHandler falls back to "unknown")
var level : str
Expand source code
class CNCMessageLog(CNCMessage):
    """
    Log message from client to server.

    CNCLogHandler.emit builds one of these from a ``logging.LogRecord``.
    The keyword-only ``req_id`` field is inherited from ``CNCMessage``.
    Field order is part of the constructor/encoding interface; keep it stable.
    """

    level: str  # e.g. 'INFO', 'WARNING', 'ERROR' (CNCLogHandler uses LogRecord.levelname)
    message: str  # log text; CNCLogHandler may append a formatted traceback
    timestamp: str  # ISO-8601 string as produced by datetime.isoformat()
    module: str  # module or component that generated the log (CNCLogHandler uses the logger name)
    client_id: str  # ID of the client sending the log (CNCLogHandler falls back to "unknown")
var message : str
Expand source code
class CNCMessageLog(CNCMessage):
    """
    Log message from client to server.

    CNCLogHandler.emit builds one of these from a ``logging.LogRecord``.
    The keyword-only ``req_id`` field is inherited from ``CNCMessage``.
    Field order is part of the constructor/encoding interface; keep it stable.
    """

    level: str  # e.g. 'INFO', 'WARNING', 'ERROR' (CNCLogHandler uses LogRecord.levelname)
    message: str  # log text; CNCLogHandler may append a formatted traceback
    timestamp: str  # ISO-8601 string as produced by datetime.isoformat()
    module: str  # module or component that generated the log (CNCLogHandler uses the logger name)
    client_id: str  # ID of the client sending the log (CNCLogHandler falls back to "unknown")
var module : str
Expand source code
class CNCMessageLog(CNCMessage):
    """
    Log message from client to server.

    CNCLogHandler.emit builds one of these from a ``logging.LogRecord``.
    The keyword-only ``req_id`` field is inherited from ``CNCMessage``.
    Field order is part of the constructor/encoding interface; keep it stable.
    """

    level: str  # e.g. 'INFO', 'WARNING', 'ERROR' (CNCLogHandler uses LogRecord.levelname)
    message: str  # log text; CNCLogHandler may append a formatted traceback
    timestamp: str  # ISO-8601 string as produced by datetime.isoformat()
    module: str  # module or component that generated the log (CNCLogHandler uses the logger name)
    client_id: str  # ID of the client sending the log (CNCLogHandler falls back to "unknown")
var timestamp : str
Expand source code
class CNCMessageLog(CNCMessage):
    """
    Log message from client to server.

    CNCLogHandler.emit builds one of these from a ``logging.LogRecord``.
    The keyword-only ``req_id`` field is inherited from ``CNCMessage``.
    Field order is part of the constructor/encoding interface; keep it stable.
    """

    level: str  # e.g. 'INFO', 'WARNING', 'ERROR' (CNCLogHandler uses LogRecord.levelname)
    message: str  # log text; CNCLogHandler may append a formatted traceback
    timestamp: str  # ISO-8601 string as produced by datetime.isoformat()
    module: str  # module or component that generated the log (CNCLogHandler uses the logger name)
    client_id: str  # ID of the client sending the log (CNCLogHandler falls back to "unknown")
class SupervisorCNC (supervisor,
config: SupervisorCNCConfig)
Expand source code
class SupervisorCNC:
    """
    Simplified Command and Control.
    Handles ZMQ communication and provides a direct interface for the REST API.

    Server mode binds a ROUTER socket on ``config.server_port`` and registers
    the local supervisor as its own client. Client mode connects a DEALER
    socket to ``config.server_host:config.server_port``, and the event loop in
    :meth:`run` sends periodic heartbeats.

    :meth:`send_command_sync` never touches the socket directly. It enqueues
    the command for the event loop and waits on a ``Future``, which is
    resolved when a reply carrying the same ``req_id`` is received.
    """

    def __init__(self, supervisor, config: SupervisorCNCConfig):
        """
        Create the ZMQ socket, the message-handler table and the loop thread.

        The thread is created here but only started by :meth:`start`.

        Args:
            supervisor: Owning ``Supervisor``; ``supervisor.config.info`` is
                used as this node's supervisor info.
            config: CNC settings (server/client role, host and port,
                verbosity, REST API toggle).
        """
        from enrgdaq.supervisor import Supervisor

        self.supervisor: Supervisor = supervisor
        self.config = config
        self.is_server = config.is_server
        self.supervisor_info = supervisor.config.info

        self.context = zmq.Context()
        self.poller = zmq.Poller()
        self.socket: Optional[zmq.Socket] = None

        # Registry of known clients, keyed by supervisor ID.
        self.clients: Dict[str, CNCClientInfo] = {}
        # req_id -> Future waiting for the matching reply (see send_command_sync).
        self._pending_responses: Dict[str, Future] = {}
        # (target_client_id, msg) pairs waiting to be sent by the loop thread.
        self._command_queue: queue.Queue = queue.Queue()
        # Per-client log buffers, capped by add_client_log.
        self.client_logs: defaultdict[str, deque[CNCMessageLog]] = defaultdict(deque)
        self._server_client_id: Optional[str] = None

        self._stop_event = threading.Event()
        self._thread = threading.Thread(target=self.run, daemon=True)

        # Set up CNC logging to automatically capture and send logs
        self._logger = logging.getLogger(f"{__name__}.{self.__class__.__name__}")
        self._logger.setLevel(config.verbosity.to_logging_level())

        if not self.is_server:
            from .log_util import CNCLogHandler

            cnc_handler = CNCLogHandler(self)
            cnc_handler.setLevel(config.verbosity.to_logging_level())

            self._logger.addHandler(cnc_handler)
            # CNCLogHandler hands records to the root logger itself, so normal
            # propagation would deliver them to root twice.
            self._logger.propagate = False

        self.message_handlers = {
            CNCMessageHeartbeat: HeartbeatHandler(self),
            CNCMessageLog: ReqLogHandler(self),
            CNCMessageReqPing: ReqPingHandler(self),
            CNCMessageResPing: ResPingHandler(self),
            CNCMessageReqStatus: ReqStatusHandler(self),
            CNCMessageResStatus: ResStatusHandler(self),
            CNCMessageReqListClients: ReqListClientsHandler(self),
            CNCMessageReqRestartDAQ: ReqRestartHandler(self),
            CNCMessageReqStopDAQJobs: ReqStopDAQJobsHandler(self),
            CNCMessageReqRunCustomDAQJob: ReqRunCustomDAQJobHandler(self),
            CNCMessageReqStopDAQJob: ReqStopDAQJobHandler(self),
            CNCMessageReqSendMessage: ReqSendMessageHandler(self),
        }

        if self.is_server:
            self.socket = self.context.socket(zmq.ROUTER)
            self.socket.bind(f"tcp://*:{config.server_port}")
            self._logger.info(f"C&C Server started on port {config.server_port}")

            # Register the server as its own client
            self._server_client_id = self.supervisor_info.supervisor_id
            self.clients[self._server_client_id] = CNCClientInfo(
                identity=None,  # No ZMQ identity for local processing
                last_seen=datetime.now().isoformat(),
                info=self.supervisor_info,
            )

            if config.rest_api_enabled:
                start_rest_api(self)
        else:
            self.socket = self.context.socket(zmq.DEALER)
            self.socket.setsockopt_string(
                zmq.IDENTITY, self.supervisor_info.supervisor_id
            )
            self.socket.connect(f"tcp://{config.server_host}:{config.server_port}")
            self._logger.info(
                f"C&C Client connected to {config.server_host}:{config.server_port}"
            )

        self.poller.register(self.socket, zmq.POLLIN)

    def start(self):
        """Start the background thread running :meth:`run`."""
        self._thread.start()

    def stop(self):
        """
        Stop the event loop and release the ZMQ socket and context.

        The loop thread is joined (up to 2 s) *before* the socket is closed.
        ZMQ sockets are not thread-safe, and ``run()`` may still be polling or
        sending on this one.

        Any ``CNCLogHandler`` bound to this instance is detached from the
        logger. ``logging.getLogger`` returns the same object for the same
        name, so the logger outlives this instance, and a stale handler would
        keep sending on a closed socket.
        """
        self._stop_event.set()
        if self._thread.is_alive():
            self._thread.join(timeout=2)

        removed = False
        for handler in list(self._logger.handlers):
            if getattr(handler, "cnc_instance", None) is self:
                self._logger.removeHandler(handler)
                removed = True
        if removed and not self._logger.handlers:
            # __init__ disabled propagation only because it attached the CNC handler.
            self._logger.propagate = True

        if self.socket:
            # Equivalent to setting zmq.LINGER to 0 and then closing.
            self.socket.close(linger=0)
        self.context.term()

    def send_command_sync(
        self, target_client_id: str, msg: CNCMessage, timeout: int = 2
    ) -> CNCMessage:
        """
        Thread-safe method called by REST API to send a command and wait for a reply.

        Args:
            target_client_id: Registry key of the destination client. The
                server's own ID is processed locally without ZMQ.
            msg: Command to send. Its ``req_id`` is overwritten with a new UUID.
            timeout: Seconds to wait for the reply.

        Returns:
            The reply message whose ``req_id`` matches the request.

        Raises:
            RuntimeError: If called on a client, or for a self-targeted command
                that has no handler or produces no response.
            TimeoutError: (``concurrent.futures.TimeoutError``) if no reply
                arrives in time.
        """
        if not self.is_server:
            raise RuntimeError("Only server can send commands to clients.")

        # Handle self-targeting: process locally without ZMQ
        if target_client_id == self._server_client_id:
            return self._process_local_command(msg)

        req_id = str(uuid.uuid4())
        msg.req_id = req_id

        future: Future = Future()
        # Register before queueing so even an immediate reply finds its Future.
        self._pending_responses[req_id] = future

        # The loop thread owns the socket; hand the send over to it.
        self._command_queue.put((target_client_id, msg))

        try:
            return future.result(timeout=timeout)
        finally:
            # Always drop the pending entry. This covers timeouts, interrupts
            # and success; on success the receive path has normally popped it
            # already. Abandoned requests therefore cannot accumulate.
            self._pending_responses.pop(req_id, None)

    def _process_local_command(self, msg: CNCMessage) -> CNCMessage:
        """
        Process a command locally (for self-targeting).
        This allows the server to send commands to itself.

        Raises:
            RuntimeError: If no handler is registered for ``type(msg)``, or
                the handler returned no response.
        """
        handler = self.message_handlers.get(type(msg))
        if not handler:
            raise RuntimeError(
                f"No handler found for message type: {type(msg).__name__}"
            )

        # The server's own ID stands in for a ZMQ identity during local processing.
        result = handler.handle(self._server_client_id.encode("utf-8"), msg)
        if not result:
            raise RuntimeError(
                f"Handler for message type {type(msg).__name__} produced no response"
            )
        response_msg, _ = result
        return response_msg

    def run(self):
        """
        Event loop executed on the background thread until :meth:`stop`.

        Each iteration runs three steps:

        1. Drain queued outgoing commands.
        2. Poll the socket (50 ms) and handle one incoming message if ready.
        3. Do periodic work every ``CNC_HEARTBEAT_INTERVAL_SECONDS``. A client
           sends a heartbeat; the server refreshes its own registry entry.
        """
        last_heartbeat = 0
        while not self._stop_event.is_set():
            try:
                # 1. Process outgoing commands from REST API
                while not self._command_queue.empty():
                    target_id, msg = self._command_queue.get()
                    self._send_zmq_message(target_id.encode("utf-8"), msg)

                # 2. Poll for incoming ZMQ messages
                socks = dict(self.poller.poll(timeout=50))
                if self.socket in socks:
                    self.handle_incoming_message()

                # 3. Client Heartbeat / Server self-update
                if not self.is_server:
                    if time.time() - last_heartbeat >= CNC_HEARTBEAT_INTERVAL_SECONDS:
                        self._send_zmq_message(
                            None,
                            CNCMessageHeartbeat(supervisor_info=self.supervisor_info),
                        )
                        last_heartbeat = time.time()
                else:
                    # Update the server's own client entry periodically
                    if time.time() - last_heartbeat >= CNC_HEARTBEAT_INTERVAL_SECONDS:
                        self.clients[self._server_client_id] = CNCClientInfo(
                            identity=None,
                            last_seen=datetime.now().isoformat(),
                            info=self.supervisor_info,
                        )
                        last_heartbeat = time.time()

            except zmq.ZMQError as e:
                if self._stop_event.is_set():
                    break
                # Surface unexpected socket errors instead of retrying
                # invisibly. The short wait keeps a persistent error from
                # spinning this loop, and returns early if stop() is called.
                self._logger.error(f"ZMQ error in CNC loop: {e}")
                self._stop_event.wait(0.05)
            except Exception as e:
                self._logger.error(f"Error in CNC loop: {e}", exc_info=True)

    def handle_incoming_message(self):
        """
        Receive one message from the socket and dispatch it.

        On the server, this also refreshes the sender's registry entry and
        resolves any pending ``send_command_sync`` Future with a matching
        ``req_id``. Malformed frame sets are dropped; decode and handler
        errors are logged.
        """
        assert self.socket is not None
        try:
            frames = self.socket.recv_multipart()
        except zmq.ZMQError:
            return

        if not frames:
            return

        # Router: [SenderID, Body]
        # Dealer: [Body]
        if self.is_server:
            if len(frames) != 2:
                return
            sender_id_bytes, body = frames[0], frames[1]
            sender_id_str = sender_id_bytes.decode("utf-8", errors="ignore")
        else:
            if len(frames) != 1:
                return
            sender_id_bytes = None
            sender_id_str = "server"
            body = frames[0]

        try:
            msg = msgspec.msgpack.decode(body, type=CNCMessageType)

            if self.is_server and sender_id_str:
                # Update Registry
                existing_info = self.clients.get(sender_id_str)
                self.clients[sender_id_str] = CNCClientInfo(
                    identity=sender_id_bytes,
                    last_seen=datetime.now().isoformat(),
                    info=msg.supervisor_info
                    if hasattr(msg, "supervisor_info")
                    else existing_info.info
                    if existing_info
                    else None,
                )

                # Only if the message actually has a req_id and it's in our pending map
                if (
                    hasattr(msg, "req_id")
                    and msg.req_id
                    and msg.req_id in self._pending_responses
                ):
                    future = self._pending_responses.pop(msg.req_id)
                    if not future.done():
                        future.set_result(msg)

            self._process_payload(msg, sender_id_bytes)

        except Exception as e:
            self._logger.error(f"Error processing payload: {e}", exc_info=True)

    def _process_payload(self, msg: CNCMessage, sender_id_bytes: Optional[bytes]):
        """Run the handler registered for ``type(msg)`` and send back any response it returns."""
        handler = self.message_handlers.get(type(msg))
        if handler:
            identity_arg = sender_id_bytes if self.is_server else b"server"
            result = handler.handle(identity_arg, msg)
            if result:
                response_msg, _ = result

                # Copy the request id from Request to Response
                # This ensures the Server knows which request this response belongs to.
                if hasattr(msg, "req_id") and hasattr(response_msg, "req_id"):
                    response_msg.req_id = msg.req_id

                self._send_zmq_message(sender_id_bytes, response_msg)

    def _send_zmq_message(self, identity: Optional[bytes], msg: CNCMessage):
        """
        Encode *msg* with msgpack and send it.

        On the server, the message is sent to the ROUTER peer *identity* and
        silently dropped when *identity* is falsy. On a client, it is sent to
        the connected server and *identity* is ignored.
        """
        assert self.socket is not None
        packed = msgspec.msgpack.encode(msg)
        if self.is_server:
            if identity:
                self.socket.send_multipart([identity, packed])
        else:
            self.socket.send(packed)

    def add_client_log(self, client_id: str, msg: CNCMessageLog):
        """Append *msg* to *client_id*'s log buffer, dropping the oldest entry past ``CNC_MAX_CLIENT_LOGS``."""
        self.client_logs[client_id].append(msg)
        if len(self.client_logs[client_id]) > CNC_MAX_CLIENT_LOGS:
            self.client_logs[client_id].popleft()

Simplified Command and Control. Handles ZMQ communication and provides a direct interface for the REST API.

Methods

def add_client_log(self,
client_id: str,
msg: CNCMessageLog)
Expand source code
def add_client_log(self, client_id: str, msg: CNCMessageLog):
    """
    Record *msg* in the log buffer for *client_id*.

    When the buffer grows beyond ``CNC_MAX_CLIENT_LOGS`` entries, the single
    oldest entry is discarded.
    """
    buffer = self.client_logs[client_id]
    buffer.append(msg)
    over_capacity = len(buffer) > CNC_MAX_CLIENT_LOGS
    if over_capacity:
        buffer.popleft()
def handle_incoming_message(self)
Expand source code
def handle_incoming_message(self):
    """
    Receive one multipart message from the socket and dispatch it.

    A ROUTER (server) socket delivers ``[sender_identity, body]``; a DEALER
    (client) socket delivers ``[body]``. Any other frame count is dropped.
    On the server, this also refreshes the sender's registry entry and
    resolves a pending request Future whose ``req_id`` matches.
    """
    assert self.socket is not None
    try:
        parts = self.socket.recv_multipart()
    except zmq.ZMQError:
        return

    if not parts:
        return

    expected_frames = 2 if self.is_server else 1
    if len(parts) != expected_frames:
        return

    if self.is_server:
        sender_id_bytes, body = parts
        sender_id_str = sender_id_bytes.decode("utf-8", errors="ignore")
    else:
        sender_id_bytes = None
        sender_id_str = "server"
        (body,) = parts

    try:
        msg = msgspec.msgpack.decode(body, type=CNCMessageType)

        if self.is_server and sender_id_str:
            previous = self.clients.get(sender_id_str)
            # Prefer fresh supervisor info from the message; otherwise keep
            # what the registry already had.
            if hasattr(msg, "supervisor_info"):
                known_info = msg.supervisor_info
            elif previous:
                known_info = previous.info
            else:
                known_info = None
            self.clients[sender_id_str] = CNCClientInfo(
                identity=sender_id_bytes,
                last_seen=datetime.now().isoformat(),
                info=known_info,
            )

            request_id = getattr(msg, "req_id", None)
            if request_id and request_id in self._pending_responses:
                waiter = self._pending_responses.pop(request_id)
                if not waiter.done():
                    waiter.set_result(msg)

        self._process_payload(msg, sender_id_bytes)

    except Exception as e:
        self._logger.error(f"Error processing payload: {e}", exc_info=True)
def run(self)
Expand source code
def run(self):
    """
    Event loop executed on the background thread until stop() is called.

    Each iteration runs three steps:

    1. Drain queued outgoing commands.
    2. Poll the socket (50 ms) and handle one incoming message if ready.
    3. Do periodic work every ``CNC_HEARTBEAT_INTERVAL_SECONDS``. A client
       sends a heartbeat; the server refreshes its own registry entry.
    """
    last_heartbeat = 0
    while not self._stop_event.is_set():
        try:
            # 1. Process outgoing commands from REST API
            while not self._command_queue.empty():
                target_id, msg = self._command_queue.get()
                self._send_zmq_message(target_id.encode("utf-8"), msg)

            # 2. Poll for incoming ZMQ messages
            socks = dict(self.poller.poll(timeout=50))
            if self.socket in socks:
                self.handle_incoming_message()

            # 3. Client Heartbeat / Server self-update
            if not self.is_server:
                if time.time() - last_heartbeat >= CNC_HEARTBEAT_INTERVAL_SECONDS:
                    self._send_zmq_message(
                        None,
                        CNCMessageHeartbeat(supervisor_info=self.supervisor_info),
                    )
                    last_heartbeat = time.time()
            else:
                # Update the server's own client entry periodically
                if time.time() - last_heartbeat >= CNC_HEARTBEAT_INTERVAL_SECONDS:
                    self.clients[self._server_client_id] = CNCClientInfo(
                        identity=None,
                        last_seen=datetime.now().isoformat(),
                        info=self.supervisor_info,
                    )
                    last_heartbeat = time.time()

        except zmq.ZMQError as e:
            if self._stop_event.is_set():
                break
            # Surface unexpected socket errors instead of retrying invisibly.
            # The short wait keeps a persistent error from spinning this
            # loop, and returns early if stop() is called.
            self._logger.error(f"ZMQ error in CNC loop: {e}")
            self._stop_event.wait(0.05)
        except Exception as e:
            self._logger.error(f"Error in CNC loop: {e}", exc_info=True)
def send_command_sync(self,
target_client_id: str,
msg: CNCMessage,
timeout: int = 2) -> CNCMessage
Expand source code
def send_command_sync(
    self, target_client_id: str, msg: CNCMessage, timeout: int = 2
) -> CNCMessage:
    """
    Thread-safe method called by REST API to send a command and wait for a reply.

    Args:
        target_client_id: Registry key of the destination client. The
            server's own ID is processed locally without ZMQ.
        msg: Command to send. Its ``req_id`` is overwritten with a new UUID.
        timeout: Seconds to wait for the reply.

    Returns:
        The reply message whose ``req_id`` matches the request.

    Raises:
        RuntimeError: If called on a client.
        TimeoutError: (``concurrent.futures.TimeoutError``) if no reply
            arrives in time.
    """
    if not self.is_server:
        raise RuntimeError("Only server can send commands to clients.")

    # Handle self-targeting: process locally without ZMQ
    if target_client_id == self._server_client_id:
        return self._process_local_command(msg)

    req_id = str(uuid.uuid4())
    msg.req_id = req_id

    future: Future = Future()
    # Register before queueing so even an immediate reply finds its Future.
    self._pending_responses[req_id] = future

    # The loop thread owns the socket; hand the send over to it.
    self._command_queue.put((target_client_id, msg))

    try:
        return future.result(timeout=timeout)
    finally:
        # Always drop the pending entry. This covers timeouts, interrupts and
        # success; on success the receive path has normally popped it already.
        # Abandoned requests therefore cannot accumulate.
        self._pending_responses.pop(req_id, None)

Thread-safe method called by REST API to send a command and wait for a reply.

def start(self)
Expand source code
def start(self):
    """Launch the background thread that runs the CNC event loop."""
    loop_thread = self._thread
    loop_thread.start()
def stop(self)
Expand source code
def stop(self):
    """
    Stop the event loop and release the ZMQ socket and context.

    The loop thread is joined (up to 2 s) *before* the socket is closed.
    ZMQ sockets are not thread-safe, and ``run()`` may still be polling or
    sending on this one.

    Any ``CNCLogHandler`` bound to this instance is detached from the logger.
    ``logging.getLogger`` returns the same object for the same name, so the
    logger outlives this instance, and a stale handler would keep sending on
    a closed socket.
    """
    self._stop_event.set()
    if self._thread.is_alive():
        self._thread.join(timeout=2)

    removed = False
    for handler in list(self._logger.handlers):
        if getattr(handler, "cnc_instance", None) is self:
            self._logger.removeHandler(handler)
            removed = True
    if removed and not self._logger.handlers:
        # __init__ disabled propagation only because it attached the CNC handler.
        self._logger.propagate = True

    if self.socket:
        # Equivalent to setting zmq.LINGER to 0 and then closing.
        self.socket.close(linger=0)
    self.context.term()