Source code for data_juicer_agents.core.tool.contracts
# -*- coding: utf-8 -*-
"""Runtime-agnostic tool contracts."""
from __future__ import annotations
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Tuple, Type
from pydantic import BaseModel
# Side-effect category a tool declares (used as the type of ToolSpec.effects).
ToolEffect = Literal[
    "read",
    "write",
    "execute",
    "external",
]
# Confirmation level a tool declares (used as the type of ToolSpec.confirmation).
ToolConfirmation = Literal[
    "none",
    "recommended",
    "required",
]
@dataclass(frozen=True)
class ToolContext:
    """Execution context shared by all tool runtimes.

    Attributes:
        working_dir: Working directory for tools (default ``"./.djx"``); also
            the fallback artifacts location when ``artifacts_dir`` is unset.
        env: String-to-string mapping of environment variables for the tool
            runtime (how executors consume it is not defined in this module).
        artifacts_dir: Optional explicit directory for produced artifacts.
        runtime_values: Arbitrary extra values made available to executors.
    """

    working_dir: str = "./.djx"
    env: Dict[str, str] = field(default_factory=dict)
    artifacts_dir: Optional[str] = None
    runtime_values: Dict[str, Any] = field(default_factory=dict)

    def resolve_artifacts_dir(self) -> Path:
        """Return the directory where artifacts should be written.

        The first non-blank value among ``artifacts_dir`` and ``working_dir``
        is used (surrounding whitespace stripped). If both are blank,
        ``./.djx`` is used. ``~`` is expanded. The path is neither created
        nor made absolute.
        """
        for candidate in (self.artifacts_dir, self.working_dir):
            # A whitespace-only value counts as unset, so fall through to the
            # next candidate instead of jumping straight to the default.
            text = str(candidate or "").strip()
            if text:
                return Path(text).expanduser()
        return Path("./.djx").expanduser()
@dataclass(frozen=True)
class ToolArtifact:
    """Named artifact produced by a tool.

    Attributes:
        path: Location of the artifact, stored exactly as given.
        description: Optional human-readable description.
        kind: Free-form category string; defaults to ``"file"``.
        label: Optional short label.
    """

    path: str
    description: str = ""
    kind: str = "file"
    label: str = ""

    def to_dict(self) -> Dict[str, Any]:
        """Return the four fields as a plain dict keyed by field name."""
        keys = ("path", "description", "kind", "label")
        return {key: getattr(self, key) for key in keys}
@dataclass
class ToolResult:
    """Normalized tool execution result.

    Attributes:
        ok: Whether the tool call succeeded.
        summary: Short human-readable outcome; exported as ``message`` by
            :meth:`to_payload`.
        data: Tool-specific payload; its keys take precedence over the
            standard fields in :meth:`to_payload`.
        artifacts: Artifacts produced by the call.
        error_type: Failure category (empty for successful results).
        error_message: Failure details (empty for successful results).
        next_actions: Suggested follow-up steps.
    """

    ok: bool
    summary: str = ""
    data: Dict[str, Any] = field(default_factory=dict)
    artifacts: List[ToolArtifact] = field(default_factory=list)
    error_type: str = ""
    error_message: str = ""
    next_actions: List[str] = field(default_factory=list)

    @classmethod
    def success(
        cls,
        *,
        summary: str = "",
        data: Optional[Dict[str, Any]] = None,
        artifacts: Optional[Iterable[ToolArtifact]] = None,
        next_actions: Optional[Iterable[str]] = None,
    ) -> "ToolResult":
        """Build a successful result.

        ``data`` is shallow-copied; ``artifacts`` and ``next_actions`` are
        materialized into new lists, so the caller's containers are not
        shared with the result.
        """
        return cls(
            ok=True,
            summary=summary,
            data=dict(data or {}),
            artifacts=list(artifacts or []),
            next_actions=list(next_actions or []),
        )

    @classmethod
    def failure(
        cls,
        *,
        summary: str,
        error_type: str,
        error_message: str = "",
        data: Optional[Dict[str, Any]] = None,
        next_actions: Optional[Iterable[str]] = None,
        artifacts: Optional[Iterable[ToolArtifact]] = None,
    ) -> "ToolResult":
        """Build a failed result.

        An empty/falsy ``error_type`` falls back to ``"tool_failed"``, and
        ``error_message`` is stripped of surrounding whitespace.
        ``artifacts`` lets a failed call still report whatever it produced
        (e.g. partial outputs).
        """
        return cls(
            ok=False,
            summary=summary,
            data=dict(data or {}),
            artifacts=list(artifacts or []),
            error_type=str(error_type or "tool_failed"),
            error_message=str(error_message or "").strip(),
            next_actions=list(next_actions or []),
        )

    def to_payload(self, *, action: str | None = None) -> Dict[str, Any]:
        """Flatten the result into a single dict.

        Starts from a shallow copy of ``data``. ``ok`` is always present.
        Every other standard field is added only when it is non-empty and its
        key is not already in ``data``: ``action``, ``message`` (from
        ``summary``), ``error_type``, ``error_message``, ``next_actions``, and
        ``artifacts`` (each converted with ``to_dict()``).
        """
        payload = dict(self.data)
        # Keys supplied by the tool in ``data`` always win over these defaults.
        payload.setdefault("ok", bool(self.ok))
        if action and "action" not in payload:
            payload["action"] = action
        if self.summary and "message" not in payload:
            payload["message"] = self.summary
        if self.error_type and "error_type" not in payload:
            payload["error_type"] = self.error_type
        if self.error_message and "error_message" not in payload:
            payload["error_message"] = self.error_message
        if self.next_actions and "next_actions" not in payload:
            payload["next_actions"] = list(self.next_actions)
        if self.artifacts and "artifacts" not in payload:
            payload["artifacts"] = [item.to_dict() for item in self.artifacts]
        return payload
# Signature of a tool implementation: it receives the shared ToolContext and
# the input already validated by ToolSpec.execute, and returns a ToolResult.
ToolExecutor = Callable[
    [ToolContext, BaseModel],
    ToolResult,
]
@dataclass(frozen=True)
class ToolSpec:
    """Definition of one atomic tool.

    Attributes:
        name: Tool identifier.
        description: Human-readable description of the tool.
        input_model: Model class used by :meth:`execute` to validate input.
        output_model: Optional model class describing the output (not used
            by this class itself).
        executor: Callable invoked with the context and the validated input.
        tags: Free-form tags.
        effects: Declared side-effect category; defaults to ``"read"``.
        confirmation: Declared confirmation level; defaults to ``"none"``.
    """

    name: str
    description: str
    input_model: Type[BaseModel]
    output_model: Type[BaseModel] | None
    executor: ToolExecutor
    tags: Tuple[str, ...] = ()
    effects: ToolEffect = "read"
    confirmation: ToolConfirmation = "none"

    def execute(
        self, ctx: ToolContext, raw_input: BaseModel | Dict[str, Any]
    ) -> ToolResult:
        """Validate ``raw_input`` against ``input_model`` and run the executor.

        An instance of ``input_model`` (or a subclass) is passed through
        unchanged. A model of a different class is first converted to a dict
        with ``model_dump()`` and then validated. Anything else (normally a
        dict) goes straight to ``input_model.model_validate``. Validation
        errors propagate to the caller. Returns whatever the executor returns.
        """
        if isinstance(raw_input, self.input_model):
            parsed = raw_input
        else:
            payload: Any = raw_input
            # Pydantic v2 ``model_validate`` rejects instances of *other* model
            # classes unless ``from_attributes`` is enabled, so re-express them
            # as plain field values first (duck-typed on the v2 ``model_dump``).
            if callable(getattr(payload, "model_dump", None)):
                payload = payload.model_dump()
            parsed = self.input_model.model_validate(payload)
        return self.executor(ctx, parsed)
# Public names exported by ``from <this module> import *`` (alphabetical).
__all__ = [
    "ToolArtifact",
    "ToolConfirmation",
    "ToolContext",
    "ToolEffect",
    "ToolExecutor",
    "ToolResult",
    "ToolSpec",
]