data_juicer_agents.tui.controller 源代码

# -*- coding: utf-8 -*-
"""Controller for running DJSessionAgent with asynchronous TUI updates."""

from __future__ import annotations

import queue
import threading
from typing import Any, Dict, List, Optional

from data_juicer_agents.capabilities.session.orchestrator import DJSessionAgent
from data_juicer_agents.capabilities.session.orchestrator import SessionReply
from data_juicer_agents.tui.noise_filter import suppress_tui_noise_stderr


[文档] class SessionController:
[文档] def __init__( self, *, dataset_path: Optional[str], export_path: Optional[str], verbose: bool, ) -> None: self._dataset_path = dataset_path self._export_path = export_path self._verbose = bool(verbose) self._event_queue: queue.Queue[Dict[str, Any]] = queue.Queue() self._agent: Optional[DJSessionAgent] = None self._lock = threading.RLock() self._turn_running = False self._turn_reply: Optional[SessionReply] = None self._turn_error: Optional[Exception] = None self._turn_thread: Optional[threading.Thread] = None
def _on_agent_event(self, event: Dict[str, Any]) -> None: self._event_queue.put(dict(event))
[文档] def start(self) -> None: with self._lock: if self._agent is not None: return self._agent = DJSessionAgent( use_llm_router=True, dataset_path=self._dataset_path, export_path=self._export_path, verbose=self._verbose, event_callback=self._on_agent_event, )
[文档] def submit_turn(self, message: str) -> None: with self._lock: if self._agent is None: raise RuntimeError("SessionController has not been started") if self._turn_running: raise RuntimeError("A turn is already running") self._turn_running = True self._turn_reply = None self._turn_error = None def _worker() -> None: try: with suppress_tui_noise_stderr(): reply = self._agent.handle_message(message) with self._lock: self._turn_reply = reply except Exception as exc: # pragma: no cover - defensive path with self._lock: self._turn_error = exc finally: with self._lock: self._turn_running = False thread = threading.Thread(target=_worker, daemon=True) self._turn_thread = thread thread.start()
[文档] def is_turn_running(self) -> bool: with self._lock: return self._turn_running
[文档] def request_interrupt(self) -> bool: with self._lock: agent = self._agent if agent is None: return False try: return bool(agent.request_interrupt()) except Exception: # pragma: no cover - defensive path return False
[文档] def drain_events(self) -> List[Dict[str, Any]]: rows: List[Dict[str, Any]] = [] while True: try: rows.append(self._event_queue.get_nowait()) except queue.Empty: return rows
[文档] def consume_turn_result(self) -> SessionReply: with self._lock: running = self._turn_running reply = self._turn_reply error = self._turn_error thread = self._turn_thread if running: raise RuntimeError("Turn is still running") if thread is not None: thread.join(timeout=0.1) if reply is not None: with self._lock: self._turn_reply = None self._turn_error = None self._turn_thread = None return reply if error is not None: with self._lock: self._turn_reply = None self._turn_error = None self._turn_thread = None return SessionReply( text=( "Unhandled session error, exiting session.\n" f"error: {error}" ), stop=True, ) return SessionReply(text="No response generated.")