# Source code for data_juicer_agents.capabilities.session.runtime
# -*- coding: utf-8 -*-
"""Runtime primitives shared by session tools."""
from __future__ import annotations
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional
from uuid import uuid4
import yaml
from data_juicer_agents.tools.plan import PlanModel
from data_juicer_agents.utils.runtime_helpers import (
normalize_line_idx,
parse_line_ranges,
run_interruptible_subprocess,
short_log,
to_bool,
to_event_result_preview,
to_int,
to_string_list,
to_text_response,
truncate_text,
)
# [docs]
@dataclass
class SessionState:
    """Plain mutable record of everything a session tracks between tool calls.

    Every field has a default, so ``SessionState()`` yields an empty session.
    Container fields use ``default_factory`` so instances never share the same
    list or dict object.
    """

    # --- Locations -------------------------------------------------------
    dataset_path: Optional[str] = None
    export_path: Optional[str] = None
    # Root directory for session artefacts; ``SessionToolRuntime.storage_root``
    # falls back to "./.djx" when this is empty.
    working_dir: str = "./.djx"

    # --- Saved plan ------------------------------------------------------
    plan_path: Optional[str] = None
    plan_intent: Optional[str] = None
    custom_operator_paths: List[str] = field(default_factory=list)

    # --- Specs (free-form mappings, or None when not yet set) ------------
    dataset_spec: Optional[Dict[str, Any]] = None
    process_spec: Optional[Dict[str, Any]] = None
    system_spec: Optional[Dict[str, Any]] = None

    # --- Draft plan ------------------------------------------------------
    draft_plan: Optional[Dict[str, Any]] = None
    draft_plan_path_hint: Optional[str] = None

    # --- Most recent retrieval / inspection results ----------------------
    last_retrieval: Dict[str, Any] = field(default_factory=dict)
    last_inspected_dataset: Optional[str] = None
    last_dataset_profile: Dict[str, Any] = field(default_factory=dict)

    # --- Conversation log ------------------------------------------------
    history: List[Dict[str, str]] = field(default_factory=list)
# [docs]
class SessionToolRuntime:
    """Mutable runtime shared by all session tools.

    Holds the session state object, a verbosity flag and an optional event
    callback, and offers helpers for event emission, tool-invocation
    bookkeeping and plan-file lookup under the session working directory.
    """

    def __init__(
        self,
        *,
        state: "SessionState",
        verbose: bool = False,
        event_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
    ) -> None:
        """Bind the runtime to a session state.

        Args:
            state: Session state object, exposed as ``self.state``.
            verbose: When truthy, :meth:`debug` prints its messages.
            event_callback: Optional sink that receives every emitted event
                dict; when ``None``, :meth:`emit_event` does nothing.
        """
        self.state = state
        self.verbose = bool(verbose)
        self._event_callback = event_callback

    @staticmethod
    def _clean_text(value: Any) -> str:
        """Return ``value`` as stripped text, mapping ``None`` to ``""``.

        A bare ``str(value)`` would turn an explicit ``None`` into the truthy
        text ``"None"``, which then masks every fallback that follows it.
        """
        if value is None:
            return ""
        return str(value).strip()

    def debug(self, message: str) -> None:
        """Print ``message`` with the debug prefix when verbose mode is on."""
        if not self.verbose:
            return
        print(f"[dj-agents][debug] {message}")

    def emit_event(self, event_type: str, **payload: Any) -> None:
        """Send an event dict to the registered callback, if any.

        The event carries ``type`` and a UTC ``timestamp`` (ISO-8601 with
        millisecond precision and a trailing ``Z``); ``payload`` keys are
        merged on top and may override both.

        Delivery is best effort: an exception raised by the callback is not
        propagated. It is reported through :meth:`debug` instead.
        """
        if self._event_callback is None:
            return
        # Local import: only ``datetime`` itself is imported at module level.
        from datetime import timezone

        # ``datetime.utcnow()`` is deprecated since Python 3.12. Take an aware
        # UTC "now" and drop tzinfo so the rendered text keeps its
        # ``YYYY-MM-DDTHH:MM:SS.mmmZ`` shape.
        now_utc = datetime.now(timezone.utc).replace(tzinfo=None)
        event: Dict[str, Any] = {
            "type": event_type,
            "timestamp": now_utc.isoformat(timespec="milliseconds") + "Z",
        }
        event.update(payload)
        try:
            self._event_callback(event)
        except Exception as exc:
            # A misbehaving listener must never break the tool that emitted
            # the event, but leave a trace for verbose runs.
            self.debug(f"event callback failed for {event_type!r}: {exc}")

    def invoke_tool(
        self,
        tool_name: str,
        args: Dict[str, Any],
        fn: Callable[[], Dict[str, Any]],
    ) -> Dict[str, Any]:
        """Run ``fn`` bracketed by ``tool_start`` / ``tool_end`` events.

        Args:
            tool_name: Name reported in both events.
            args: Arguments reported in the ``tool_start`` event.
            fn: Zero-argument callable producing the tool payload.

        Returns:
            Whatever ``fn`` returned, unchanged.

        Raises:
            Exception: Anything ``fn`` raises is re-raised after a failing
                ``tool_end`` event has been emitted.
        """
        call_id = f"tool_{uuid4().hex[:10]}"
        self.emit_event(
            "tool_start",
            tool=tool_name,
            call_id=call_id,
            args=args,
        )
        try:
            payload = fn()
        except Exception as exc:
            self.emit_event(
                "tool_end",
                tool=tool_name,
                call_id=call_id,
                ok=False,
                error_type="exception",
                summary=str(exc),
            )
            raise

        clean = self._clean_text
        ok = True
        error_type = None
        summary = ""
        failure_preview = ""
        result_preview = to_event_result_preview(payload)
        if isinstance(payload, dict):
            ok = bool(payload.get("ok", True))
            error_type = clean(payload.get("error_type")) or None
            summary = clean(payload.get("message"))
            if not ok:
                failure_preview = self._build_failure_preview(payload, max_chars=320)
                if not summary:
                    # No explicit message: fall back to captured process output.
                    summary = clean(payload.get("stderr")) or clean(payload.get("stdout"))
        summary = summary[:240]
        self.emit_event(
            "tool_end",
            tool=tool_name,
            call_id=call_id,
            ok=ok,
            error_type=error_type,
            summary=summary,
            failure_preview=failure_preview,
            result_preview=result_preview,
        )
        return payload

    @staticmethod
    def _build_failure_preview(payload: Dict[str, Any], *, max_chars: int = 320) -> str:
        """Pick the most specific failure text available in ``payload``.

        Sources are tried in this order, and the first non-empty one wins:
        ``failure_preview``, ``validation_errors`` (first three entries),
        ``error_message``, ``stderr``, ``stdout``, ``message``, ``error_type``.
        The chosen text is passed through ``truncate_text`` with
        ``limit=max_chars``. Returns ``""`` when nothing is available.
        Keys that are absent or explicitly ``None`` are treated as empty.
        """
        clean = SessionToolRuntime._clean_text

        direct = clean(payload.get("failure_preview"))
        if direct:
            return truncate_text(direct, limit=max_chars)

        validation_errors = payload.get("validation_errors")
        if isinstance(validation_errors, list):
            details = [text for text in map(clean, validation_errors) if text]
            if details:
                return truncate_text(
                    "validation_errors: " + "; ".join(details[:3]),
                    limit=max_chars,
                )

        # (payload key, prefix placed in front of its text), in priority order.
        fallbacks = (
            ("error_message", ""),
            ("stderr", "stderr: "),
            ("stdout", "stdout: "),
            ("message", ""),
            ("error_type", ""),
        )
        for key, prefix in fallbacks:
            text = clean(payload.get(key))
            if text:
                return truncate_text(f"{prefix}{text}", limit=max_chars)
        return ""

    def invoke_text_tool(
        self,
        tool_name: str,
        args: Dict[str, Any],
        fn: Callable[[], Dict[str, Any]],
    ):
        """Run :meth:`invoke_tool` and pass its result to ``to_text_response``."""
        return to_text_response(self.invoke_tool(tool_name, args, fn))

    def context_payload(self) -> Dict[str, Any]:
        """Summarise the session state as a flat dict.

        State fields that should be dicts but are not are treated as empty.
        Text-like summary values are stripped, and empty or ``None`` values
        are reported as ``None``.
        """
        state = self.state
        clean = self._clean_text

        def as_dict(value: Any) -> Dict[str, Any]:
            return value if isinstance(value, dict) else {}

        draft = as_dict(state.draft_plan)
        retrieval = as_dict(state.last_retrieval)
        dataset_spec = as_dict(state.dataset_spec)
        process_spec = as_dict(state.process_spec)
        system_spec = as_dict(state.system_spec)
        # A non-dict ``binding`` is treated as absent instead of raising.
        binding = as_dict(dataset_spec.get("binding"))

        retrieval_candidates = retrieval.get("candidate_names", [])
        if not isinstance(retrieval_candidates, list):
            retrieval_candidates = []
        process_operators = process_spec.get("operators", [])
        draft_operators = draft.get("operators")

        return {
            "dataset_path": state.dataset_path,
            "export_path": state.export_path,
            "plan_path": state.plan_path,
            "plan_intent": state.plan_intent,
            "custom_operator_paths": list(state.custom_operator_paths),
            "has_dataset_spec": bool(dataset_spec),
            "dataset_spec_modality": clean(binding.get("modality")) or None,
            "has_process_spec": bool(process_spec),
            "process_operator_count": len(process_operators) if isinstance(process_operators, list) else 0,
            "has_system_spec": bool(system_spec),
            "draft_plan_id": clean(draft.get("plan_id")) or None,
            "draft_modality": clean(draft.get("modality")) or None,
            "draft_operator_count": len(draft_operators) if isinstance(draft_operators, list) else 0,
            "draft_plan_path_hint": state.draft_plan_path_hint,
            "last_retrieval_intent": clean(retrieval.get("intent")) or None,
            "last_retrieval_candidate_count": len(retrieval_candidates),
            "last_inspected_dataset": state.last_inspected_dataset,
            "has_dataset_profile": bool(state.last_dataset_profile),
        }

    def storage_root(self) -> Path:
        """Return the session working directory, defaulting to ``./.djx``."""
        root = str(self.state.working_dir or "./.djx").strip() or "./.djx"
        return Path(root).expanduser()

    def next_session_plan_path(self) -> str:
        """Return a fresh ``session_plans/session_plan_<timestamp>.yaml`` path.

        The ``session_plans`` directory is created if needed. The timestamp
        has one-second resolution, so when the plain name already exists on
        disk a numeric suffix (``_1``, ``_2``, ...) is appended rather than
        handing out a path that would overwrite an earlier plan.
        """
        session_dir = self.storage_root() / "session_plans"
        session_dir.mkdir(parents=True, exist_ok=True)
        ts = datetime.now().strftime("%Y%m%d%H%M%S")
        candidate = session_dir / f"session_plan_{ts}.yaml"
        attempt = 1
        while candidate.exists():
            candidate = session_dir / f"session_plan_{ts}_{attempt}.yaml"
            attempt += 1
        return str(candidate)

    def load_plan_dict(self, plan_path: str) -> Optional[Dict[str, Any]]:
        """Read ``plan_path`` as YAML and return it if it is a mapping.

        Returns ``None`` when the file cannot be read or parsed, or when its
        top-level value is not a dict.
        """
        try:
            data = yaml.safe_load(Path(plan_path).expanduser().read_text(encoding="utf-8"))
        except Exception:
            # Deliberate best effort: callers treat None as "no usable plan".
            return None
        return data if isinstance(data, dict) else None

    def load_plan_model(self, plan_path: str) -> "Optional[PlanModel]":
        """Load ``plan_path`` and build a ``PlanModel`` from it.

        Returns ``None`` when the file yields no dict or when
        ``PlanModel.from_dict`` raises.
        """
        data = self.load_plan_dict(plan_path)
        if not isinstance(data, dict):
            return None
        try:
            return PlanModel.from_dict(data)
        except Exception:
            # Deliberate best effort: an invalid plan is reported as None.
            return None

    @staticmethod
    def looks_like_plan_id(value: str) -> bool:
        """Tell whether ``value`` is a ``plan_``-prefixed token, not a path.

        Empty values and anything containing ``/`` or ``\\`` are rejected.
        """
        token = str(value or "").strip()
        if not token:
            return False
        if "/" in token or "\\" in token:
            return False
        return token.startswith("plan_")

    def find_saved_plan_path_by_plan_id(self, plan_id: str) -> Optional[str]:
        """Locate the saved plan file whose ``plan_id`` equals ``plan_id``.

        Candidates are checked in this order: the state's ``plan_path`` (if
        set), then the sorted ``*.yaml`` files in ``session_plans`` and
        ``recipes`` under :meth:`storage_root`. Files that do not load as a
        plan model are skipped.

        Returns:
            The first matching path as a string, or ``None`` when ``plan_id``
            is empty or nothing matches.
        """
        token = str(plan_id or "").strip()
        if not token:
            return None
        root = self.storage_root()
        candidates: List[Path] = []
        if self.state.plan_path:
            candidates.append(Path(self.state.plan_path).expanduser())
        for base_dir in (root / "session_plans", root / "recipes"):
            if base_dir.exists():
                candidates.extend(sorted(base_dir.glob("*.yaml")))
        seen: set[str] = set()
        for path in candidates:
            path_str = str(path)
            if path_str in seen:
                continue
            seen.add(path_str)
            model = self.load_plan_model(path_str)
            if model is None:
                continue
            # A missing (None) plan_id must not match the literal token "None".
            if self._clean_text(model.plan_id) == token:
                return path_str
        return None

    def current_draft_plan_model(self) -> "Optional[PlanModel]":
        """Build a ``PlanModel`` from the state's draft plan.

        Returns ``None`` when the draft is not a dict or when
        ``PlanModel.from_dict`` raises.
        """
        payload = self.state.draft_plan
        if not isinstance(payload, dict):
            return None
        try:
            return PlanModel.from_dict(payload)
        except Exception:
            return None
# Public API of this module: the two session classes defined here plus the
# runtime helper functions re-exported from
# data_juicer_agents.utils.runtime_helpers.
# NOTE(review): to_event_result_preview is imported above but not listed
# here -- confirm that leaving it out of the public API is intentional.
__all__ = [
"SessionState",
"SessionToolRuntime",
"normalize_line_idx",
"parse_line_ranges",
"run_interruptible_subprocess",
"short_log",
"to_bool",
"to_int",
"to_string_list",
"to_text_response",
"truncate_text",
]