# Source code for data_juicer_agents.tools.retrieve.retrieve_operators.logic

# -*- coding: utf-8 -*-
"""Structured operator retrieval service for DJX and session tools."""

from __future__ import annotations

import asyncio
import math
import os
import re
import threading
from typing import Any, Dict, Iterable, List, Tuple

from .operator_registry import (
    get_available_operator_names,
    resolve_operator_name,
)


_WORD_RE = re.compile(r"[a-zA-Z0-9_]+")
_OP_TYPES = {
    "mapper",
    "filter",
    "deduplicator",
    "selector",
    "grouper",
    "aggregator",
    "pipeline",
    "formatter",
}


def _load_op_retrieval_funcs():
    try:
        from .backend import (
            get_dj_func_info,
            init_dj_func_info,
            retrieve_ops,
            retrieve_ops_with_meta,
        )

        return get_dj_func_info, init_dj_func_info, retrieve_ops, retrieve_ops_with_meta
    except Exception:
        return None


def _tokenize(text: str) -> List[str]:
    """Return the lower-cased ``_WORD_RE`` matches in *text* (falsy input gives ``[]``)."""
    source = str(text or "")
    return list(map(str.lower, _WORD_RE.findall(source)))


def _op_type(name: str) -> str:
    """Infer an operator's category from its name.

    The last underscore-separated segment (lower-cased) is returned when it is
    one of ``_OP_TYPES``. Otherwise a name containing ``"dedup"`` is treated as
    a deduplicator, and anything else is ``"unknown"``.
    """
    text = str(name or "")
    # rpartition always yields three parts; element 2 is the text after the
    # last "_" (the whole string when there is no underscore), so no
    # "empty parts" guard is needed.
    suffix = text.rpartition("_")[2].lower()
    if suffix in _OP_TYPES:
        return suffix
    if "dedup" in text.lower():
        return "deduplicator"
    return "unknown"


def _to_float_score(value: float) -> float:
    if value < 0:
        return 0.0
    if value > 100:
        return 100.0
    return round(value, 2)


def _keyword_score(intent: str, operator_name: str, description: str) -> float:
    """Score how well *intent* matches an operator by keyword overlap.

    Args:
        intent: Free-text user request.
        operator_name: Operator identifier, typically snake_case.
        description: Operator description text.

    Returns:
        A score clamped to ``[0, 100]`` by ``_to_float_score``. It is ``0.0``
        when *intent* contains no tokens.
    """
    intent_tokens = set(_tokenize(intent))
    if not intent_tokens:
        return 0.0

    # _tokenize keeps underscores inside tokens, so "text_length_filter" is a
    # single token and only an intent quoting the full name could earn name
    # credit. Add the underscore-separated parts as well. The full-name token
    # is kept, so exact-name intents still score as before.
    whole_name_tokens = _tokenize(operator_name)
    name_tokens = set(whole_name_tokens)
    for token in whole_name_tokens:
        name_tokens.update(part for part in token.split("_") if part)
    desc_tokens = set(_tokenize(description))

    name_overlap = len(intent_tokens & name_tokens)
    desc_overlap = len(intent_tokens & desc_tokens)
    lowered_name = str(operator_name or "").lower()
    contains_bonus = 1.0 if any(tok in lowered_name for tok in intent_tokens) else 0.0

    # Weighted to prefer operator-name matches over description matches.
    raw = name_overlap * 16.0 + desc_overlap * 4.0 + contains_bonus * 8.0
    return _to_float_score(raw)


def _trace_entry(backend: str, status: str, error: str = "", reason: str = "") -> Dict[str, str]:
    payload = {
        "backend": str(backend or "").strip(),
        "status": str(status or "").strip(),
    }
    error_text = str(error or "").strip()
    reason_text = str(reason or "").strip()
    if error_text:
        payload["error"] = error_text
    if reason_text:
        payload["reason"] = reason_text
    return payload


def _safe_async_retrieve(intent: str, top_k: int, mode: str) -> Dict[str, Any]:
    """Run the async operator-retrieval backend without raising.

    Args:
        intent: Free-text retrieval query.
        top_k: Forwarded to the backend as ``limit``.
        mode: Forwarded to ``retrieve_ops_with_meta``. It is also used as the
            ``backend`` label of the failure trace entry.

    Returns:
        A dict with ``names`` (list of str), ``source`` (str), ``trace`` (list
        of trace dicts) and ``items`` (list). ``names`` is empty when neither
        ``DASHSCOPE_API_KEY`` nor ``MODELSCOPE_API_TOKEN`` is set, when the
        backend cannot be imported, or when the backend call fails. In those
        cases ``trace`` records why.
    """
    api_key = os.environ.get("DASHSCOPE_API_KEY") or os.environ.get("MODELSCOPE_API_TOKEN")
    if not api_key:
        return {
            "names": [],
            "source": "lexical",
            "trace": [_trace_entry("lexical", "selected", reason="missing_api_key")],
            "items": [],
        }

    funcs = _load_op_retrieval_funcs()
    if funcs is None:
        return {
            "names": [],
            "source": "lexical",
            "trace": [_trace_entry("lexical", "selected", reason="retrieval_backend_unavailable")],
            "items": [],
        }
    # Only the meta-returning entry point is needed here.
    retrieve_ops_with_meta = funcs[3]

    def _normalize_names(names: Any) -> List[str]:
        if not isinstance(names, list):
            return []
        return [str(item) for item in names if str(item).strip()]

    def _normalize_meta(payload: Any) -> Dict[str, Any]:
        # A dict payload carries names/source/trace/items. Anything else is
        # treated as a bare list of names.
        if isinstance(payload, dict):
            trace = payload.get("trace")
            items = payload.get("items")
            return {
                "names": _normalize_names(payload.get("names")),
                "source": str(payload.get("source", "")).strip(),
                "trace": list(trace) if isinstance(trace, list) else [],
                "items": list(items) if isinstance(items, list) else [],
            }
        return {"names": _normalize_names(payload), "source": "", "trace": [], "items": []}

    def _run_in_new_thread() -> Any:
        # asyncio.run() cannot be used while an event loop is already running
        # in this thread. Drive the coroutine on a fresh loop in a worker thread
        # instead, and block until it finishes.
        outcome: Dict[str, Any] = {}

        def _worker() -> None:
            loop = asyncio.new_event_loop()
            try:
                outcome["meta"] = loop.run_until_complete(
                    retrieve_ops_with_meta(intent, limit=top_k, mode=mode)
                )
            except Exception as exc:
                outcome["error"] = exc
            finally:
                loop.close()

        thread = threading.Thread(target=_worker, daemon=True)
        thread.start()
        thread.join()
        if "error" in outcome:
            raise outcome["error"]
        return outcome.get("meta")

    # Probe for a running loop separately, and keep the backend call out of
    # the probe's ``except`` handler. An exception raised inside a handler is
    # not caught by a sibling ``except`` clause, so backend failures would
    # otherwise escape this wrapper.
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        loop_running = False
    else:
        loop_running = True

    try:
        if loop_running:
            raw = _run_in_new_thread()
        else:
            raw = asyncio.run(retrieve_ops_with_meta(intent, limit=top_k, mode=mode))
        return _normalize_meta(raw)
    except Exception as exc:
        return {
            "names": [],
            "source": "",
            "trace": [_trace_entry(mode, "failed", str(exc))],
            "items": [],
        }


def _lexical_fallback(intent: str, info_rows: List[Dict[str, Any]], top_k: int) -> List[str]:
    """Rank operators by keyword overlap with *intent*, without remote calls.

    Args:
        intent: Free-text user request.
        info_rows: Operator info dicts. ``class_name`` and ``class_desc`` are read.
        top_k: Maximum number of names to return.

    Returns:
        Up to *top_k* distinct operator names with a positive score. They are
        ordered best first, alphabetically among equal scores. If no name
        scores above zero, the first *top_k* names of the same ordering are
        returned, so the caller still gets a deterministic list.
    """
    best: Dict[str, float] = {}
    for row in info_rows:
        name = str(row.get("class_name", "")).strip()
        if not name:
            continue
        # ``or ""`` stops a null description from being scored as the text "None".
        score = _keyword_score(intent, name, str(row.get("class_desc") or ""))
        # Duplicate rows for one operator keep their best score, instead of
        # occupying several of the top_k slots.
        if score > best.get(name, -1.0):
            best[name] = score

    # Highest score first, with ties broken alphabetically (A->Z).
    scored = sorted(best.items(), key=lambda item: (-item[1], item[0]))
    positive = [name for name, score in scored if score > 0]
    if positive:
        return positive[:top_k]
    return [name for name, _ in scored[:top_k]]


def _build_candidate_row(
    rank: int,
    name: str,
    intent: str,
    info_map: Dict[str, Dict[str, Any]],
    llm_item: Dict[str, Any] | None = None,
) -> Dict[str, Any]:
    """Assemble one candidate entry of the retrieval payload.

    Args:
        rank: Position of the candidate in the list.
        name: Operator name, used to look up *info_map* and to infer the type.
        intent: Original user request, used for keyword scoring.
        info_map: Operator info rows keyed by class name.
        llm_item: Optional LLM retrieval item. Its ``description``,
            ``relevance_score`` and ``key_match`` are preferred when usable.

    Returns:
        A dict with rank, operator name and inferred type, description,
        relevance score plus its source (``"llm"`` or ``"keyword"``), key
        matches, and up to four non-empty argument lines.
    """
    row = info_map.get(name) or {}
    item = llm_item or {}

    # ``or ""`` keeps a JSON null from rendering as the literal text "None".
    # For the LLM description, that text would also shadow the real one.
    desc = str(row.get("class_desc") or "").strip()
    args_text = str(row.get("arguments") or "").strip()
    args_lines = [line.strip() for line in args_text.splitlines() if line.strip()]
    llm_desc = str(item.get("description") or "").strip()

    llm_score = item.get("relevance_score")
    key_match = item.get("key_match")
    if not isinstance(key_match, list):
        key_match = []

    # bool is a subclass of int, but True/False is not a meaningful score.
    if isinstance(llm_score, (int, float)) and not isinstance(llm_score, bool):
        relevance_score = _to_float_score(float(llm_score))
        score_source = "llm"
    else:
        relevance_score = _keyword_score(intent, name, desc)
        score_source = "keyword"

    return {
        "rank": rank,
        "operator_name": name,
        "operator_type": _op_type(name),
        "description": llm_desc or desc,
        "relevance_score": relevance_score,
        "score_source": score_source,
        "key_match": [str(entry).strip() for entry in key_match if str(entry).strip()],
        "arguments_preview": args_lines[:4],
    }


def _normalize_top_k(top_k: Any, default: int = 10, upper: int = 200) -> int:
    """Coerce *top_k* to an int in ``[1, upper]``, using *default* when invalid."""
    # isdecimal (not isdigit) accepts exactly the digit strings int() parses,
    # so e.g. "²" falls back to the default instead of raising ValueError.
    if isinstance(top_k, int) or str(top_k).isdecimal():
        value = int(top_k)
    else:
        value = default
    if value <= 0:
        value = default
    return min(value, upper)


def _load_info_rows() -> List[Dict[str, Any]]:
    """Return backend operator info rows that have a non-empty ``class_name`` (``[]`` on failure)."""
    funcs = _load_op_retrieval_funcs()
    if funcs is None:
        return []
    get_dj_func_info, init_dj_func_info = funcs[0], funcs[1]
    try:
        init_dj_func_info()
        return [
            item
            for item in get_dj_func_info()
            if isinstance(item, dict) and str(item.get("class_name", "")).strip()
        ]
    except Exception:
        # Best effort: without operator metadata, candidates simply get no
        # registry description or argument preview.
        return []


def _resolve_unique(names: Iterable[Any], available_ops: Any) -> List[str]:
    """Resolve each name with ``resolve_operator_name``, dropping empty and duplicate results in order."""
    resolved: List[str] = []
    seen = set()
    for raw_name in names:
        name = resolve_operator_name(raw_name, available_ops=available_ops)
        if name and name not in seen:
            seen.add(name)
            resolved.append(name)
    return resolved


def _index_llm_items(items: Iterable[Any], available_ops: Any) -> Dict[str, Dict[str, Any]]:
    """Map LLM retrieval items by ``tool_name`` and by resolved operator name."""
    by_name: Dict[str, Dict[str, Any]] = {}
    for item in items:
        if not isinstance(item, dict):
            continue
        tool_name = str(item.get("tool_name", "")).strip()
        if tool_name:
            by_name[tool_name] = item
    # Candidates are looked up by *resolved* name. Also index each item under
    # that name, without overriding an exact entry, so that an alias or case
    # variant keeps its LLM description and score.
    for tool_name, item in list(by_name.items()):
        resolved = resolve_operator_name(tool_name, available_ops=available_ops)
        if resolved:
            by_name.setdefault(resolved, item)
    return by_name


def retrieve_operator_candidates(
    intent: str,
    top_k: int = 10,
    mode: str = "auto",
    dataset_path: str | None = None,
) -> Dict[str, Any]:
    """Retrieve operators and return a structured payload for CLI/agent usage.

    Remote retrieval is tried first. When it yields no names, or none of its
    names resolve to a known operator, candidates come from a lexical ranking
    over the operator info rows, and ``retrieval_source``/``retrieval_trace``
    record that fallback.

    Args:
        intent: Free-text description of what the user wants to do.
        top_k: Maximum number of candidates. Invalid or non-positive values
            become 10, and the value is capped at 200.
        mode: Retrieval mode forwarded to the remote backend.
        dataset_path: Accepted for interface compatibility; not used here.

    Returns:
        A dict with ``ok``, ``intent``, ``top_k``, ``mode``,
        ``retrieval_source``, ``retrieval_trace``, ``candidate_count``,
        ``gap_detected``, ``candidates`` and ``notes``.
    """
    top_k = _normalize_top_k(top_k)

    info_rows = _load_info_rows()
    info_map = {str(item.get("class_name", "")).strip(): item for item in info_rows}

    retrieve_meta = _safe_async_retrieve(intent, top_k=top_k, mode=mode)
    retrieved_names = list(retrieve_meta.get("names", []))
    retrieval_source = str(retrieve_meta.get("source", "")).strip()
    retrieval_trace = list(retrieve_meta.get("trace", []))

    available_ops = get_available_operator_names()
    llm_item_map: Dict[str, Dict[str, Any]] = {}
    if retrieval_source == "llm":
        llm_item_map = _index_llm_items(retrieve_meta.get("items") or [], available_ops)

    if not retrieved_names:
        retrieved_names = _lexical_fallback(intent, info_rows=info_rows, top_k=top_k)
        retrieval_source = "lexical"
        retrieval_trace.append(
            _trace_entry("lexical", "selected", reason="fallback_after_remote_empty_or_failed")
        )
        llm_item_map = {}

    normalized_names = _resolve_unique(retrieved_names, available_ops)

    if not normalized_names and info_rows:
        # Names came back but none resolved to a known operator. Fall back to
        # lexical ranking and report it, so neither retrieval_source nor
        # score_source claims results the remote backend did not produce.
        normalized_names = list(
            dict.fromkeys(_lexical_fallback(intent, info_rows=info_rows, top_k=top_k))
        )
        if retrieval_source != "lexical":
            retrieval_source = "lexical"
            retrieval_trace.append(
                _trace_entry("lexical", "selected", reason="fallback_after_unresolved_names")
            )
        llm_item_map = {}

    candidates = [
        _build_candidate_row(
            idx,
            name,
            intent=intent,
            info_map=info_map,
            llm_item=llm_item_map.get(name),
        )
        for idx, name in enumerate(normalized_names[:top_k], start=1)
    ]
    notes: List[str] = []
    if not candidates:
        notes.append("No operator candidates were found from retrieval.")
    return {
        "ok": True,
        "intent": intent,
        "top_k": top_k,
        "mode": mode,
        "retrieval_source": retrieval_source,
        "retrieval_trace": retrieval_trace,
        "candidate_count": len(candidates),
        "gap_detected": len(candidates) == 0,
        "candidates": candidates,
        "notes": notes,
    }
def extract_candidate_names(payload: Dict[str, Any]) -> List[str]:
    """Return the non-empty ``operator_name`` of each candidate in *payload*.

    *payload* is expected to be the dict returned by
    ``retrieve_operator_candidates``. The result is empty for a non-dict
    payload or a missing/``None`` ``candidates`` value. Non-dict entries and
    blank names are skipped; names are stripped.
    """
    if not isinstance(payload, dict):
        return []
    names: List[str] = []
    # ``or []`` also covers an explicit ``"candidates": None``.
    for item in payload.get("candidates") or []:
        if not isinstance(item, dict):
            continue
        name = str(item.get("operator_name", "")).strip()
        if name:
            names.append(name)
    return names