Source code for data_juicer_agents.tools.retrieve.retrieve_operators.backend

# -*- coding: utf-8 -*-
import json
import hashlib
import logging
import os
import os.path as osp
import time
from typing import Optional

from langchain_community.vectorstores import FAISS

VECTOR_INDEX_CACHE_PATH = osp.join(osp.dirname(__file__), "vector_index_cache")

# Global variable to cache the vector store
_cached_vector_store: Optional[FAISS] = None
_cached_tools_info: Optional[list] = None
_cached_content_hash: Optional[str] = None

# Global variable for agent lifecycle management
_global_dj_func_info: Optional[list] = None

RETRIEVAL_PROMPT = """You are a professional tool retrieval assistant responsible for filtering the top {limit} most relevant tools from a large tool library based on user requirements. Execute the following steps:

# Requirement Analysis
    Carefully read the user's [requirement description], extract core keywords, functional objectives, usage scenarios, and technical requirements (such as real-time performance, data types, industry domains, etc.).

# Tool Matching
    Perform multi-dimensional matching based on the following tool attributes:
    - Tool name and functional description
    - Supported input/output formats
    - Applicable industry or scenario tags
    - Technical implementation principles (API, local deployment, AI model types)
    - Relevance ranking

# Use weighted scoring mechanism (example weights):
    - Functional match (40%)
    - Scenario compatibility (30%)
    - Technical compatibility (20%)
    - User rating/usage rate (10%)

# Deduplication and Optimization
    Exclude the following low-quality results:
    - Tools with duplicate functionality (keep only the best one)
    - Tools that cannot meet basic requirements
    - Tools missing critical parameter descriptions

# Constraints
    - Strictly control output to a maximum of {limit} tools
    - Refuse to speculate on unknown tool attributes
    - Maintain accuracy of domain expertise

# Output Format
    Return a JSON format TOP{limit} tool list containing:
    [
        {{
            "rank": 1,
            "tool_name": "Tool Name",
            "description": "Core functionality summary",
            "relevance_score": 98.7,
            "key_match": ["Matching keywords/features"]
        }}
    ]
    Output strictly in JSON array format, and only output the JSON array format tool list.
"""

def _get_content_hash(dj_func_info: list) -> str:
    """Get content hash of dj_func_info using SHA256"""
    try:
        # Convert to JSON string with sorted keys for consistent hashing
        content_str = json.dumps(dj_func_info, sort_keys=True, ensure_ascii=False)
        return hashlib.sha256(content_str.encode("utf-8")).hexdigest()
    except Exception as e:
        logging.warning(f"Failed to compute content hash: {e}")
        return ""

def _load_cached_index() -> bool:
    """Load cached vector index from disk"""
    global _cached_vector_store, _cached_tools_info, _cached_content_hash

    try:
        # Get current dj_func_info
        dj_func_info = get_dj_func_info()
        current_hash = _get_content_hash(dj_func_info)
        
        if not current_hash:
            return False

        # Ensure cache directory exists
        os.makedirs(VECTOR_INDEX_CACHE_PATH, exist_ok=True)

        index_path = osp.join(VECTOR_INDEX_CACHE_PATH, "faiss_index")
        metadata_path = osp.join(VECTOR_INDEX_CACHE_PATH, "metadata.json")

        if not all(os.path.exists(p) for p in [index_path, metadata_path]):
            return False

        # Check if cached index matches current tools info content
        with open(metadata_path, "r") as f:
            metadata = json.load(f)

        cached_hash = metadata.get("content_hash", "")
        
        if current_hash != cached_hash:
            logging.info("Content hash mismatch, need to rebuild index")
            return False

        # Load cached data
        from langchain_community.embeddings import DashScopeEmbeddings

        embeddings = DashScopeEmbeddings(
            dashscope_api_key=os.environ.get("DASHSCOPE_API_KEY"),
            model="text-embedding-v3",
        )

        _cached_vector_store = FAISS.load_local(
            index_path,
            embeddings,
            allow_dangerous_deserialization=True,
        )
        
        _cached_tools_info = dj_func_info
        _cached_content_hash = cached_hash

        logging.info("Successfully loaded cached vector index")
        return True

    except Exception as e:
        logging.warning(f"Failed to load cached index: {e}")
        return False

def _save_cached_index():
    """Save vector index to disk cache"""
    global _cached_vector_store, _cached_content_hash

    try:
        # Ensure cache directory exists
        os.makedirs(VECTOR_INDEX_CACHE_PATH, exist_ok=True)

        index_path = osp.join(VECTOR_INDEX_CACHE_PATH, "faiss_index")
        metadata_path = osp.join(VECTOR_INDEX_CACHE_PATH, "metadata.json")

        # Save vector store
        if _cached_vector_store:
            _cached_vector_store.save_local(index_path)

        # Save metadata
        metadata = {
            "content_hash": _cached_content_hash,
            "created_at": time.time(),
        }
        with open(metadata_path, "w") as f:
            json.dump(metadata, f)

        logging.info("Successfully saved vector index to cache")

    except Exception as e:
        logging.error(f"Failed to save cached index: {e}")

[docs] def init_dj_func_info(): """Initialize dj_func_info at agent startup""" global _global_dj_func_info try: logging.info("Initializing dj_func_info for agent lifecycle...") from .catalog import dj_func_info _global_dj_func_info = dj_func_info logging.info(f"Successfully initialized dj_func_info with {len(_global_dj_func_info)} operators") return True except Exception as e: logging.error(f"Failed to initialize dj_func_info: {e}") return False
[docs] def refresh_dj_func_info(): """Refresh dj_func_info during agent runtime (for manual updates)""" global _global_dj_func_info, _cached_vector_store, _cached_tools_info, _cached_content_hash try: logging.info("Refreshing dj_func_info...") # Clear existing cache to force rebuild _cached_vector_store = None _cached_tools_info = None _cached_content_hash = None # Reload dj_func_info import importlib from . import catalog as create_dj_func_info from data_juicer import ops importlib.reload(ops) importlib.reload(create_dj_func_info) dj_func_info = get_dj_func_info() _global_dj_func_info = dj_func_info logging.info(f"Successfully refreshed dj_func_info with {len(_global_dj_func_info)} operators") return True except Exception as e: import traceback traceback.print_exc() logging.error(f"Failed to refresh dj_func_info: {e}") return False
[docs] def get_dj_func_info(): """Get current dj_func_info (lifecycle-aware)""" global _global_dj_func_info if _global_dj_func_info is None: logging.warning("dj_func_info not initialized, initializing now...") if not init_dj_func_info(): # Fallback to direct import if initialization fails logging.warning("Falling back to direct import of dj_func_info") from .catalog import dj_func_info return dj_func_info return _global_dj_func_info
[docs] async def retrieve_ops_lm(user_query, limit=20): """Tool retrieval using language model - returns list of tool names""" items = await retrieve_ops_lm_items(user_query, limit=limit) return [ str(item.get("tool_name", "")).strip() for item in items if str(item.get("tool_name", "")).strip() ]
[docs] async def retrieve_ops_lm_items(user_query, limit=20): """Tool retrieval using language model - returns validated tool metadata.""" dj_func_info = get_dj_func_info() tool_descriptions = [ f"{t['class_name']}: {t['class_desc']}" for t in dj_func_info ] tools_string = "\n".join(tool_descriptions) from agentscope.model import DashScopeChatModel from agentscope.message import Msg from agentscope.formatter import DashScopeChatFormatter model = DashScopeChatModel( model_name="qwen-turbo", api_key=os.environ.get("DASHSCOPE_API_KEY"), stream=False, ) formatter = DashScopeChatFormatter() # Update retrieval prompt to use the specified limit retrieval_prompt_with_limit = RETRIEVAL_PROMPT.format(limit=limit) user_prompt = ( retrieval_prompt_with_limit + """ User requirement description: {user_query} Available tools: {tools_string} """.format( user_query=user_query, tools_string=tools_string, ) ) msgs = [ Msg(name="user", role="user", content=user_prompt), ] formatted_msgs = await formatter.format(msgs) response = await model(formatted_msgs) msg = Msg(name="assistant", role="assistant", content=response.content) retrieved_tools_text = msg.get_text_content() retrieved_tools = json.loads(retrieved_tools_text) # Extract tool names and validate they exist valid_tools = [] for tool_info in retrieved_tools: if not isinstance(tool_info, dict) or "tool_name" not in tool_info: logging.warning(f"Invalid tool info format: {tool_info}") continue tool_name = str(tool_info["tool_name"]).strip() if not tool_name: continue # Verify tool exists in dj_func_info tool_exists = any(t["class_name"] == tool_name for t in dj_func_info) if not tool_exists: logging.error(f"Tool not found: `{tool_name}`, skipping!") continue valid_tools.append( { "tool_name": tool_name, "description": str(tool_info.get("description", "")).strip(), "relevance_score": tool_info.get("relevance_score"), "key_match": ( [ str(item).strip() for item in tool_info.get("key_match", []) if str(item).strip() ] if isinstance(tool_info.get("key_match"), list) else [] ), } ) return valid_tools
def _build_vector_index(): """Build vector index using fresh dj_func_info""" global _cached_vector_store, _cached_tools_info, _cached_content_hash dj_func_info = get_dj_func_info() tool_descriptions = [ f"{t['class_name']}: {t['class_desc']}" for t in dj_func_info ] from langchain_community.embeddings import DashScopeEmbeddings embeddings = DashScopeEmbeddings( dashscope_api_key=os.environ.get("DASHSCOPE_API_KEY"), model="text-embedding-v3", ) metadatas = [{"index": i} for i in range(len(tool_descriptions))] vector_store = FAISS.from_texts( tool_descriptions, embeddings, metadatas=metadatas, ) # Cache the results _cached_vector_store = vector_store _cached_tools_info = dj_func_info _cached_content_hash = _get_content_hash(dj_func_info) # Save to disk cache _save_cached_index() logging.info("Successfully built and cached vector index")
[docs] def retrieve_ops_vector(user_query, limit=20): """Tool retrieval using vector search with smart caching - returns list of tool names""" global _cached_vector_store, _cached_tools_info # Try to load from cache first, only rebuild if content changed if not _load_cached_index(): logging.info("Building new vector index...") _build_vector_index() # Perform similarity search retrieved_tools = _cached_vector_store.similarity_search( user_query, k=limit, ) retrieved_indices = [doc.metadata["index"] for doc in retrieved_tools] # Extract tool names from retrieved indices using cached tools info tool_names = [] for raw_idx in retrieved_indices: tool_info = _cached_tools_info[raw_idx] tool_names.append(tool_info["class_name"]) return tool_names
def _trace_step(backend: str, status: str, error: str = "") -> dict: payload = { "backend": backend, "status": status, } error_text = str(error or "").strip() if error_text: payload["error"] = error_text return payload
[docs] async def retrieve_ops_with_meta( user_query: str, limit: int = 20, mode: str = "auto", ) -> dict: """Tool retrieval with source/trace metadata.""" if mode == "llm": try: items = await retrieve_ops_lm_items(user_query, limit=limit) names = [ str(item.get("tool_name", "")).strip() for item in items if str(item.get("tool_name", "")).strip() ] return { "names": names, "source": "llm" if names else "", "trace": [_trace_step("llm", "success" if names else "empty")], "items": items, } except Exception as exc: logging.error(f"LLM retrieval failed: {str(exc)}") return { "names": [], "source": "", "trace": [_trace_step("llm", "failed", str(exc))], "items": [], } if mode == "vector": try: names = retrieve_ops_vector(user_query, limit=limit) return { "names": names, "source": "vector" if names else "", "trace": [_trace_step("vector", "success" if names else "empty")], "items": [], } except Exception as exc: logging.error(f"Vector retrieval failed: {str(exc)}") return { "names": [], "source": "", "trace": [_trace_step("vector", "failed", str(exc))], "items": [], } if mode == "auto": trace = [] try: items = await retrieve_ops_lm_items(user_query, limit=limit) names = [ str(item.get("tool_name", "")).strip() for item in items if str(item.get("tool_name", "")).strip() ] trace.append(_trace_step("llm", "success" if names else "empty")) if names: return { "names": names, "source": "llm", "trace": trace, "items": items, } except Exception as exc: logging.warning( "LLM retrieval failed in auto mode (%s), falling back to vector retrieval.", str(exc), ) trace.append(_trace_step("llm", "failed", str(exc))) try: names = retrieve_ops_vector(user_query, limit=limit) trace.append(_trace_step("vector", "success" if names else "empty")) if names: return { "names": names, "source": "vector", "trace": trace, "items": [], } except Exception as fallback_exc: logging.error( "Tool retrieval failed in auto mode, vector fallback also failed: %s", str(fallback_exc), ) trace.append(_trace_step("vector", "failed", str(fallback_exc))) return { "names": [], "source": "", "trace": trace, "items": [], } raise ValueError( f"Invalid mode: {mode}. Must be 'llm', 'vector', or 'auto'", )
[docs] async def retrieve_ops( user_query: str, limit: int = 20, mode: str = "auto", ) -> list: """ Tool retrieval with configurable mode Args: user_query: User query string limit: Maximum number of tools to retrieve mode: Retrieval mode - "llm", "vector", or "auto" (default: "auto") - "llm": Use language model only - "vector": Use vector search only - "auto": Try LLM first, fallback to vector search on failure Returns: List of tool names """ meta = await retrieve_ops_with_meta( user_query=user_query, limit=limit, mode=mode, ) return list(meta.get("names", []))
if __name__ == "__main__": import asyncio user_query = ( "Clean special characters from text and filter samples with excessive length. Mask sensitive information and filter unsafe content including adult/terror-related terms." + "Additionally, filter out small images, perform image tagging, and remove duplicate images." ) # Test different modes print("=== Testing LLM mode ===") tool_names_llm = asyncio.run( retrieve_ops(user_query, limit=10, mode="llm"), ) print("Retrieved tool names (LLM):") print(tool_names_llm) print("\n=== Testing Vector mode ===") tool_names_vector = asyncio.run( retrieve_ops(user_query, limit=10, mode="vector"), ) print("Retrieved tool names (Vector):") print(tool_names_vector) print("\n=== Testing Auto mode (default) ===") tool_names_auto = asyncio.run( retrieve_ops(user_query, limit=10, mode="auto"), ) print("Retrieved tool names (Auto):") print(tool_names_auto)