# Source code for data_juicer_agents.core.tool.registry
# -*- coding: utf-8 -*-
"""Registry for runtime-agnostic tool definitions."""
from __future__ import annotations
from dataclasses import dataclass, field
from functools import lru_cache
from typing import Dict, List, Sequence, Tuple
from .contracts import ToolSpec
from .profiles import groups_for_tool_profile, tool_is_excluded_from_profile
@dataclass
class ToolRegistry:
    """Container of tool definitions, keyed by tool name.

    Specs are kept in the order in which they were registered.
    """

    # Maps each registered spec's ``name`` to the spec object itself.
    _tools: Dict[str, ToolSpec] = field(default_factory=dict)

    def register(self, spec: ToolSpec) -> None:
        """Store ``spec`` under ``spec.name``.

        Raises:
            ValueError: If a spec with the same name is already registered.
        """
        if spec.name in self._tools:
            raise ValueError(f"tool already registered: {spec.name}")
        self._tools[spec.name] = spec

    def get(self, name: str) -> ToolSpec:
        """Return the spec registered under ``name``.

        ``name`` is converted to ``str`` and stripped of surrounding
        whitespace before the lookup.

        Raises:
            KeyError: If nothing is registered under the normalized name.
        """
        spec = self._tools.get(str(name).strip())
        if spec is None:
            raise KeyError(f"tool not found: {name}")
        return spec

    def list(self, *, tags: Sequence[str] | None = None) -> List[ToolSpec]:
        """Return registered specs, optionally filtered by tag.

        Args:
            tags: Tags to filter on. A spec is kept when it carries at least
                one of them. Blank entries are ignored; when no usable tag
                remains (or ``tags`` is empty/``None``) every spec is
                returned. A single ``str`` is treated as one tag.

        Returns:
            A new list of matching specs in registration order.
        """
        # ``list`` here is the builtin: a method body does not see names
        # bound in the enclosing class scope.
        specs = list(self._tools.values())
        if not tags:
            return specs
        # A bare string is itself a sequence of characters; wrap it so the
        # filter matches the whole tag rather than its individual letters.
        candidates = (tags,) if isinstance(tags, str) else tags
        normalized = (str(tag).strip() for tag in candidates)
        expected = {text for text in normalized if text}
        if not expected:
            return specs
        return [spec for spec in specs if expected.intersection(spec.tags)]

    def list_tools(self, *, tags: Sequence[str] | None = None) -> List[ToolSpec]:
        """Alias of :meth:`list`; forwards ``tags`` unchanged."""
        return self.list(tags=tags)
def _registry_cache_key(
    *,
    profile: str | None = None,
    groups: Sequence[str] | None = None,
) -> Tuple[str, ...] | None:
    """Return the key that identifies a registry for ``profile``/``groups``.

    Args:
        profile: Profile name, consulted only when ``groups`` is ``None``.
        groups: Explicit group names. When given (even if empty) they take
            precedence over ``profile``. A single ``str`` is treated as one
            group name.

    Returns:
        For explicit ``groups``: a tuple of the stripped, non-empty group
        names in their original order (falsy entries are dropped).
        Otherwise whatever ``groups_for_tool_profile(profile)`` returns.
    """
    if groups is None:
        return groups_for_tool_profile(profile)
    # A bare string would otherwise be split into one "group" per character.
    names = (groups,) if isinstance(groups, str) else groups
    # Normalize each entry once, then drop the ones that end up empty.
    stripped = (str(item or "").strip() for item in names)
    return tuple(name for name in stripped if name)
@lru_cache(maxsize=None)
def _build_registry_cached(group_names: Tuple[str, ...] | None) -> ToolRegistry:
    """Build a ``ToolRegistry`` from the catalog specs for ``group_names``.

    The result is memoized without a size bound, so calls with an equal
    ``group_names`` value return the same registry object.
    """
    # Imported at call time rather than at module import
    # (presumably to avoid a circular import -- confirm against the catalog).
    from data_juicer_agents.core.tool.catalog import load_tool_specs

    built = ToolRegistry()
    for tool_spec in load_tool_specs(group_names):
        # ``register`` raises ValueError on a duplicate tool name.
        built.register(tool_spec)
    return built
def build_default_tool_registry(
    *,
    profile: str | None = None,
    groups: Sequence[str] | None = None,
) -> ToolRegistry:
    """Return the registry for the given ``profile`` or explicit ``groups``.

    The arguments are reduced to a cache key and the registry is obtained
    from the memoized builder, so the returned object may be shared with
    other callers that resolve to the same key.
    """
    cache_key = _registry_cache_key(profile=profile, groups=groups)
    return _build_registry_cached(cache_key)
def get_tool_spec(name: str, *, profile: str | None = None) -> ToolSpec:
    """Look up one tool spec by name in the registry for ``profile``.

    Raises:
        KeyError: If the registry has no such tool, or if the tool is
            excluded from ``profile``.
    """
    registry = build_default_tool_registry(profile=profile)
    found = registry.get(name)
    if not tool_is_excluded_from_profile(found.name, profile):
        return found
    # An excluded tool is reported exactly like an unknown one.
    raise KeyError(f"tool not found: {name}")
def list_tool_specs(
    *,
    tags: Sequence[str] | None = None,
    profile: str | None = None,
) -> List[ToolSpec]:
    """List the tool specs of the registry for ``profile``.

    ``tags`` is forwarded to the registry's ``list`` method; specs that are
    excluded from ``profile`` are then removed from the result.
    """
    registry = build_default_tool_registry(profile=profile)
    visible = []
    for candidate in registry.list(tags=tags):
        if tool_is_excluded_from_profile(candidate.name, profile):
            continue
        visible.append(candidate)
    return visible
# Public names of this module (what ``from ... import *`` exports).
__all__ = [
"ToolRegistry",
"build_default_tool_registry",
"get_tool_spec",
"list_tool_specs",
]