data_juicer_agents.capabilities.plan.service 源代码

# -*- coding: utf-8 -*-
"""Minimal hard orchestration for `djx plan`."""

from __future__ import annotations

import os
from typing import Any, Dict, Iterable

from data_juicer_agents.tools.context import inspect_dataset_schema
from data_juicer_agents.tools.retrieve import (
    extract_candidate_names,
    retrieve_operator_candidates,
)
from data_juicer_agents.tools.plan import (
    PlanModel,
    assemble_plan,
    build_dataset_spec,
    build_process_spec,
    build_system_spec,
    plan_validate,
)

from .generator import ProcessOperatorGenerator


PLANNER_MODEL_NAME = os.environ.get("DJA_PLANNER_MODEL", "qwen3-max-2026-01-23")


def _normalize_candidate_payload(raw_candidates: Any) -> Dict[str, Any] | None:
    if not isinstance(raw_candidates, dict):
        return None
    if not isinstance(raw_candidates.get("candidates", []), list):
        return None
    return raw_candidates


[文档] class PlanOrchestrator: """Fixed orchestration for CLI plan generation."""
[文档] def __init__( self, *, planner_model_name: str | None = None, llm_api_key: str | None = None, llm_base_url: str | None = None, llm_thinking: bool | None = None, ): self.generator = ProcessOperatorGenerator( model_name=str(planner_model_name or PLANNER_MODEL_NAME).strip() or PLANNER_MODEL_NAME, api_key=llm_api_key, base_url=llm_base_url, thinking=llm_thinking, )
def _resolve_retrieval( self, *, user_intent: str, dataset_path: str, top_k: int = 5, mode: str = "auto", retrieved_candidates: Dict[str, Any] | None = None, ) -> Dict[str, Any]: provided = _normalize_candidate_payload(retrieved_candidates) if provided is not None: return dict(provided) return retrieve_operator_candidates( intent=user_intent, top_k=top_k, mode=mode, dataset_path=dataset_path, )
[文档] def generate_plan( self, *, user_intent: str, dataset_path: str, export_path: str, custom_operator_paths: Iterable[Any] | None = None, retrieved_candidates: Dict[str, Any] | None = None, retrieval_top_k: int = 5, retrieval_mode: str = "auto", ) -> Dict[str, Any]: retrieval = self._resolve_retrieval( user_intent=user_intent, dataset_path=dataset_path, top_k=retrieval_top_k, mode=retrieval_mode, retrieved_candidates=retrieved_candidates, ) dataset_profile = inspect_dataset_schema(dataset_path, sample_size=20) dataset_result = build_dataset_spec( user_intent=user_intent, dataset_path=dataset_path, export_path=export_path, dataset_profile=dataset_profile, ) if not dataset_result.get("ok"): raise ValueError("dataset spec build failed: " + "; ".join(dataset_result.get("validation_errors", []) or [str(dataset_result.get("message", "unknown error"))])) operator_payload = self.generator.generate( user_intent=user_intent, retrieval_payload=retrieval, dataset_spec=dataset_result["dataset_spec"], dataset_profile=dataset_profile, ) process_result = build_process_spec( operators=operator_payload.get("operators", []), ) if not process_result.get("ok"): raise ValueError("process spec build failed: " + "; ".join(process_result.get("validation_errors", []) or [str(process_result.get("message", "unknown error"))])) system_result = build_system_spec( custom_operator_paths=custom_operator_paths, ) if not system_result.get("ok"): raise ValueError("system spec build failed: " + "; ".join(system_result.get("validation_errors", []) or [str(system_result.get("message", "unknown error"))])) assembled = assemble_plan( user_intent=user_intent, dataset_spec=dataset_result["dataset_spec"], process_spec=process_result["process_spec"], system_spec=system_result["system_spec"], approval_required=True, ) if not assembled.get("ok"): raise ValueError("assemble_plan failed: " + str(assembled.get("message", "unknown error"))) validation = plan_validate(plan_payload=assembled["plan"]) if not validation.get("ok"): raise ValueError("plan validation failed: " + "; ".join(validation.get("validation_errors", []) or [str(validation.get("message", "unknown error"))])) plan = PlanModel.from_dict(assembled["plan"]) return { "plan": plan, "dataset_spec": dataset_result["dataset_spec"], "process_spec": process_result["process_spec"], "system_spec": system_result["system_spec"], "retrieval": retrieval, "planning_meta": { "planner_model": self.generator.model_name, "retrieval_source": str(retrieval.get("retrieval_source", "")).strip() or "unknown", "retrieval_candidate_count": str(len(extract_candidate_names(retrieval))), }, "validation": validation, }
__all__ = ["PlanOrchestrator"]