Source code for data_juicer.ops.mapper.agent_bad_case_signal_mapper

# Copyright 2025 The Data-Juicer Authors. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Fuse deterministic + optional LLM-eval signals into a conservative bad-case
# triage for human review (precision-oriented by default).

from __future__ import annotations

import json
import logging
import os
from typing import Any, Dict, List, Optional, Tuple

import numpy as np

from data_juicer.ops.base_op import OPERATORS, TAGGING_OPS, Mapper
from data_juicer.utils.constant import Fields, MetaKeys, StatsKeys

OP_NAME = "agent_bad_case_signal_mapper"
logger = logging.getLogger(__name__)

# Lightweight 1–5 turn/trace judges (dialog_* / agent_* quality LLM mappers).
_DIALOG_QUALITY_SCORE_META_KEYS = (
    MetaKeys.dialog_memory_consistency,
    MetaKeys.dialog_coreference,
    MetaKeys.dialog_topic_shift,
    MetaKeys.dialog_error_recovery,
    MetaKeys.dialog_clarification_quality,
    MetaKeys.dialog_proactivity,
    MetaKeys.dialog_non_repetition,
    MetaKeys.agent_trace_coherence,
    MetaKeys.agent_tool_relevance,
)
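
# Each key above is expected to hold a small per-axis record in ``meta``. The
# shape below is an illustrative sketch inferred from how process_single reads
# it (only the "score", "skipped" and "error" fields are consulted; anything
# else is ignored):
#
#     {"score": 2, "skipped": False, "error": None}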

_calibration_missing_path_warned: Optional[str] = None


def _load_calibration_json(path: str) -> Optional[Dict[str, Any]]:
    if not path or not str(path).strip():
        return None
    ap = os.path.abspath(os.path.expanduser(str(path).strip()))
    if not os.path.isfile(ap):
        global _calibration_missing_path_warned
        if _calibration_missing_path_warned != ap:
            logger.warning(
                "agent_bad_case_signal_mapper: calibration_json_path is not a file "
                "(%s); auto thresholds disabled for this run.",
                ap,
            )
            _calibration_missing_path_warned = ap
        return None
    try:
        with open(ap, encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, json.JSONDecodeError) as e:
        logger.warning(
            "agent_bad_case_signal_mapper: failed to load calibration JSON %s: %s",
            ap,
            e,
        )
        return None
    if not isinstance(data, dict):
        return None
    return data
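
# Illustrative shape of the calibration file consumed above (a sketch, not a
# schema contract; the key names mirror the lookups in _resolve_calibration_row
# and the _effective_* helpers below, while the model name and numbers are
# invented):
#
#     {
#       "percentile": 95,
#       "default": {
#         "max_total_tokens": 12000,
#         "max_latency_ms": 90000,
#         "perplexity_high_threshold": 650.0
#       },
#       "by_request_model": {
#         "qwen-max": {"max_total_tokens": 16000}
#       }
#     }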


def _normalize_recommendation(record: Any) -> str:
    # record may be a JSON string (serialized by _normalize_record) or a dict
    if isinstance(record, str):
        if not record:
            return ""
        try:
            # module-level ``json`` import is sufficient; no local re-import needed
            record = json.loads(record)
        except Exception:
            return ""
    if not isinstance(record, dict):
        return ""
    r = record.get("recommendation")
    if isinstance(r, (list, tuple, np.ndarray)) and len(r) > 0:
        r = r[0]
    if hasattr(r, "item"):
        try:
            r = r.item()
        except Exception:
            pass
    if r is None:
        return ""
    return str(r).strip().strip('"').lower()
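
# Example behaviour (inputs are illustrative; see the normalization steps above):
#   '{"recommendation": "DISCARD"}'  -> "discard"
#   {"recommendation": ["Keep"]}     -> "keep"
#   {"recommendation": None}         -> ""
#   "not valid json"                 -> ""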


@TAGGING_OPS.register_module(OP_NAME)
@OPERATORS.register_module(OP_NAME)
class AgentBadCaseSignalMapper(Mapper):
    """Attach structured bad-case *signals* and a conservative *tier* to each sample.

    Design goal: **precision over recall** for the ``high_precision`` tier.

    **Upstream coverage** (when present in the pipeline):

    - ``meta``: ``tool_*``, ``usage`` tokens, ``primary_tool_type``,
      ``dominant_tool_types``, ``dialog_intent_labels``, ``dialog_topic_labels``,
      ``dialog_sentiment_labels``, ``agent_turn_count``, lineage keys.
    - ``stats``: ``llm_analysis_*``, ``llm_quality_*``, ``llm_difficulty_*``,
      ``text_len``, ``num_words``, ``perplexity``, ``lang_score``.
    - ``meta``: optional ``dialog_*`` / ``agent_trace_coherence`` /
      ``agent_tool_relevance`` records (1–5 scores from lightweight LLM mappers).

    Each signal group can be toggled via constructor flags. ``high``-weight
    signals feed the ``high_precision`` tier (subject to config); ``medium``
    signals feed ``watchlist`` only.

    **Tool-heavy agent runs:** use ``min_tool_fail_count_for_signal`` to avoid
    treating a single exploratory tool error (common before recovery) as strong
    bad-case evidence.

    **P-percentile calibration** (optional): set ``auto_calibrate_thresholds``
    and ``calibration_json_path`` to a JSON file produced by
    ``demos/agent/scripts/compute_percentile_thresholds.py --write-calibration``.
    Per-sample thresholds merge ``default`` with ``by_request_model`` using
    ``meta.agent_request_model``. When ``calibration_manual_overrides_auto`` is
    true (the default), explicit ``max_total_tokens`` / ``max_latency_ms`` /
    perplexity settings in YAML override the file; set it to false to prefer the
    calibration values.
    """
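
    # Illustrative YAML wiring for this op (a sketch: the surrounding recipe
    # structure and the calibration file path are assumptions; the argument
    # names match ``__init__`` below):
    #
    #   process:
    #     - agent_bad_case_signal_mapper:
    #         min_tool_fail_count_for_signal: 2
    #         auto_calibrate_thresholds: true
    #         calibration_json_path: ./calibration.json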
    def __init__(
        self,
        query_key: str = "query",
        response_key: str = "response",
        # --- tool path ---
        signal_on_tool_fail: bool = True,
        # Agent runs often include exploratory tool errors (e.g. wrong path)
        # before recovery; require this many pattern-matched **error** tool
        # messages before emitting ``tool_message_error_pattern``.
        min_tool_fail_count_for_signal: int = 1,
        signal_on_low_tool_success_ratio: bool = True,
        tool_success_ratio_max_for_signal: float = 0.499,
        min_tool_rounds_for_ratio_signal: int = 2,
        # --- empty response heuristic ---
        signal_on_suspect_empty_response: bool = True,
        min_query_len_for_empty_check: int = 80,
        max_response_len_for_empty_check: int = 20,
        # --- cost / latency (optional absolute thresholds) ---
        max_total_tokens: Optional[int] = None,
        max_latency_ms: Optional[int] = None,
        # --- optional P-percentile calibration
        #     (see demos/agent/scripts/compute_percentile_thresholds.py) ---
        calibration_json_path: Optional[str] = None,
        auto_calibrate_thresholds: bool = False,
        # When True (default): explicit max_* / perplexity_high_threshold in YAML win over JSON.
        calibration_manual_overrides_auto: bool = True,
        # If the JSON row has perplexity_high_threshold, enable that signal even when
        # signal_on_high_perplexity is False (unless a manual override supplies the threshold).
        auto_enable_perplexity_from_calibration: bool = True,
        # --- llm_analysis_filter (agent scenario eval) ---
        signal_on_llm_analysis_low: bool = True,
        llm_analysis_score_max_for_bad: float = 0.28,
        llm_analysis_discard_must_be_strict: bool = True,
        high_precision_llm_analysis_discard_threshold: float = 0.24,
        # --- llm_quality_score_filter (reply quality dims) ---
        signal_on_llm_text_quality_low: bool = True,
        llm_text_quality_score_max_for_bad: float = 0.28,
        llm_text_quality_discard_must_be_strict: bool = True,
        high_precision_llm_text_quality_discard_threshold: float = 0.24,
        # --- dialog tags (weak signals → medium only) ---
        signal_on_negative_sentiment_hint: bool = False,
        negative_sentiment_substrings: Optional[List[str]] = None,
        # --- text stats from filters ---
        signal_on_high_perplexity: bool = False,
        perplexity_high_threshold: float = 800.0,
        # --- difficulty × quality conjunction (off by default) ---
        signal_hard_query_poor_reply: bool = False,
        hard_query_difficulty_min: float = 0.72,
        poor_reply_quality_max: float = 0.36,
        # --- tier composition ---
        high_precision_on_tool_fail_alone: bool = True,
        min_medium_signals_for_watchlist: int = 2,
        # --- dialog / trace 1–5 meta (no extra LLM; reads upstream mapper output) ---
        signal_on_low_dialog_quality_meta: bool = True,
        dialog_quality_low_score_threshold: float = 2.0,
        min_dialog_quality_low_axes_for_signal: int = 1,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.query_key = query_key
        self.response_key = response_key

        self.signal_on_tool_fail = signal_on_tool_fail
        self.min_tool_fail_count_for_signal = max(1, int(min_tool_fail_count_for_signal))
        self.signal_on_low_tool_success_ratio = signal_on_low_tool_success_ratio
        self.tool_success_ratio_max_for_signal = tool_success_ratio_max_for_signal
        self.min_tool_rounds_for_ratio_signal = min_tool_rounds_for_ratio_signal

        self.signal_on_suspect_empty_response = signal_on_suspect_empty_response
        self.min_query_len_for_empty_check = min_query_len_for_empty_check
        self.max_response_len_for_empty_check = max_response_len_for_empty_check

        self.max_total_tokens = max_total_tokens
        self.max_latency_ms = max_latency_ms

        self.calibration_json_path = calibration_json_path
        self.auto_calibrate_thresholds = bool(auto_calibrate_thresholds)
        self.calibration_manual_overrides_auto = bool(calibration_manual_overrides_auto)
        self.auto_enable_perplexity_from_calibration = bool(auto_enable_perplexity_from_calibration)
        self._calibration: Optional[Dict[str, Any]] = None
        if self.auto_calibrate_thresholds and self.calibration_json_path:
            self._calibration = _load_calibration_json(self.calibration_json_path)
            if self._calibration is not None:
                logger.info(
                    "agent_bad_case_signal_mapper: loaded calibration (percentile=%s) from %s",
                    self._calibration.get("percentile", "n/a"),
                    os.path.abspath(os.path.expanduser(str(self.calibration_json_path))),
                )

        self.signal_on_llm_analysis_low = signal_on_llm_analysis_low
        self.llm_analysis_score_max_for_bad = llm_analysis_score_max_for_bad
        self.llm_analysis_discard_must_be_strict = llm_analysis_discard_must_be_strict
        self.high_precision_llm_analysis_discard_threshold = high_precision_llm_analysis_discard_threshold

        self.signal_on_llm_text_quality_low = signal_on_llm_text_quality_low
        self.llm_text_quality_score_max_for_bad = llm_text_quality_score_max_for_bad
        self.llm_text_quality_discard_must_be_strict = llm_text_quality_discard_must_be_strict
        self.high_precision_llm_text_quality_discard_threshold = high_precision_llm_text_quality_discard_threshold

        self.signal_on_negative_sentiment_hint = signal_on_negative_sentiment_hint
        self.negative_sentiment_substrings = negative_sentiment_substrings or [
            "负面",
            "negative",
            "angry",
            "沮丧",
            "不满",
            "frustrated",
            "disappointed",
            "unhappy",
        ]

        self.signal_on_high_perplexity = signal_on_high_perplexity
        self.perplexity_high_threshold = perplexity_high_threshold

        self.signal_hard_query_poor_reply = signal_hard_query_poor_reply
        self.hard_query_difficulty_min = hard_query_difficulty_min
        self.poor_reply_quality_max = poor_reply_quality_max

        self.high_precision_on_tool_fail_alone = high_precision_on_tool_fail_alone
        self.min_medium_signals_for_watchlist = min_medium_signals_for_watchlist

        self.signal_on_low_dialog_quality_meta = bool(signal_on_low_dialog_quality_meta)
        self.dialog_quality_low_score_threshold = float(dialog_quality_low_score_threshold)
        self.min_dialog_quality_low_axes_for_signal = max(1, int(min_dialog_quality_low_axes_for_signal))
    def _resolve_calibration_row(self, meta: dict) -> Dict[str, Any]:
        if not self.auto_calibrate_thresholds or not self._calibration:
            return {}
        default = self._calibration.get("default")
        if not isinstance(default, dict):
            default = {}
        bym = self._calibration.get("by_request_model")
        if not isinstance(bym, dict):
            bym = {}
        model = str(meta.get(MetaKeys.agent_request_model) or "").strip()
        row = dict(default)
        if model and model in bym and isinstance(bym[model], dict):
            row.update(bym[model])
        return row

    def _effective_max_total_tokens(self, meta: dict) -> Optional[int]:
        cal_v: Optional[int] = None
        if self.auto_calibrate_thresholds and self._calibration:
            raw = self._resolve_calibration_row(meta).get("max_total_tokens")
            if raw is not None:
                try:
                    cal_v = int(raw)
                except (TypeError, ValueError):
                    pass
        if self.calibration_manual_overrides_auto:
            if self.max_total_tokens is not None:
                return self.max_total_tokens
            return cal_v
        if cal_v is not None:
            return cal_v
        return self.max_total_tokens

    def _effective_max_latency_ms(self, meta: dict) -> Optional[int]:
        cal_v: Optional[int] = None
        if self.auto_calibrate_thresholds and self._calibration:
            raw = self._resolve_calibration_row(meta).get("max_latency_ms")
            if raw is not None:
                try:
                    cal_v = int(raw)
                except (TypeError, ValueError):
                    pass
        if self.calibration_manual_overrides_auto:
            if self.max_latency_ms is not None:
                return self.max_latency_ms
            return cal_v
        if cal_v is not None:
            return cal_v
        return self.max_latency_ms

    def _effective_perplexity(self, meta: dict) -> Tuple[bool, float]:
        """Return (signal_on, threshold)."""
        row = self._resolve_calibration_row(meta) if self.auto_calibrate_thresholds and self._calibration else {}
        cal_th: Optional[float] = None
        raw = row.get("perplexity_high_threshold")
        if raw is not None:
            try:
                cal_th = float(raw)
            except (TypeError, ValueError):
                pass
        want_signal = self.signal_on_high_perplexity or (
            cal_th is not None
            and self.auto_enable_perplexity_from_calibration
            and self.auto_calibrate_thresholds
            and self._calibration is not None
        )
        if not want_signal:
            return False, float(self.perplexity_high_threshold)
        if self.calibration_manual_overrides_auto and self.signal_on_high_perplexity:
            return True, float(self.perplexity_high_threshold)
        if not self.calibration_manual_overrides_auto and cal_th is not None:
            return True, cal_th
        if cal_th is not None and (self.signal_on_high_perplexity or self.auto_enable_perplexity_from_calibration):
            return True, cal_th
        if self.signal_on_high_perplexity:
            return True, float(self.perplexity_high_threshold)
        return False, float(self.perplexity_high_threshold)

    def _append(
        self,
        signals: List[dict],
        code: str,
        detail: str,
        weight: str,
    ) -> None:
        signals.append({"code": code, "detail": detail, "weight": weight})

    def _llm_eval_signal(
        self,
        stats: dict,
        signals: List[dict],
        score_key: str,
        record_key: str,
        score_max: float,
        discard_strict: bool,
        high_thresh: float,
        code: str,
    ) -> None:
        score = stats.get(score_key)
        record = stats.get(record_key)
        rec_norm = _normalize_recommendation(record)
        if score is None:
            return
        if float(score) > score_max:
            return
        strict = (not discard_strict) or (rec_norm == "discard")
        if not strict:
            return
        w = "high" if float(score) <= high_thresh and rec_norm == "discard" else "medium"
        self._append(
            signals,
            code,
            f"score={score}, recommendation={rec_norm or 'n/a'}",
            w,
        )
    def process_single(self, sample: dict) -> dict:
        meta = sample.setdefault(Fields.meta, {})
        stats = sample.get(Fields.stats) or {}
        signals: List[dict] = []

        fail_count = int(meta.get(MetaKeys.tool_fail_count) or 0)
        if self.signal_on_tool_fail and fail_count >= self.min_tool_fail_count_for_signal:
            self._append(
                signals,
                "tool_message_error_pattern",
                f"tool_fail_count={fail_count}",
                "high",
            )

        succ = int(meta.get(MetaKeys.tool_success_count) or 0)
        rounds = succ + fail_count
        ratio = meta.get(MetaKeys.tool_success_ratio)
        if (
            self.signal_on_low_tool_success_ratio
            and rounds >= self.min_tool_rounds_for_ratio_signal
            and ratio is not None
            and float(ratio) <= self.tool_success_ratio_max_for_signal
        ):
            self._append(
                signals,
                "low_tool_success_ratio",
                f"ratio={ratio}, success={succ}, fail={fail_count}",
                "medium",
            )

        q = (sample.get(self.query_key) or "").strip()
        r = (sample.get(self.response_key) or "").strip()
        if (
            self.signal_on_suspect_empty_response
            and len(q) >= self.min_query_len_for_empty_check
            and len(r) <= self.max_response_len_for_empty_check
        ):
            self._append(
                signals,
                "suspect_empty_or_trivial_final_response",
                f"query_len={len(q)}, response_len={len(r)}",
                "medium",
            )

        eff_max_tok = self._effective_max_total_tokens(meta)
        total_tokens = meta.get(MetaKeys.total_tokens)
        if eff_max_tok is not None and total_tokens is not None and int(total_tokens) > eff_max_tok:
            self._append(
                signals,
                "high_token_usage",
                f"total_tokens={total_tokens}",
                "medium",
            )

        eff_max_lat = self._effective_max_latency_ms(meta)
        latency = meta.get(MetaKeys.agent_total_cost_time_ms)
        if eff_max_lat is not None and latency is not None and int(latency) > eff_max_lat:
            self._append(
                signals,
                "high_latency_ms",
                f"total_cost_time_ms={latency}",
                "medium",
            )

        if self.signal_on_llm_analysis_low:
            self._llm_eval_signal(
                stats,
                signals,
                StatsKeys.llm_analysis_score,
                StatsKeys.llm_analysis_record,
                self.llm_analysis_score_max_for_bad,
                self.llm_analysis_discard_must_be_strict,
                self.high_precision_llm_analysis_discard_threshold,
                "llm_agent_analysis_eval_low",
            )

        if self.signal_on_llm_text_quality_low:
            self._llm_eval_signal(
                stats,
                signals,
                StatsKeys.llm_quality_score,
                StatsKeys.llm_quality_record,
                self.llm_text_quality_score_max_for_bad,
                self.llm_text_quality_discard_must_be_strict,
                self.high_precision_llm_text_quality_discard_threshold,
                "llm_reply_quality_eval_low",
            )

        if self.signal_on_negative_sentiment_hint:
            labels = meta.get(MetaKeys.dialog_sentiment_labels)
            if isinstance(labels, list) and labels:
                blob = " ".join(str(x).lower() for x in labels)
                if any(s.lower() in blob for s in self.negative_sentiment_substrings):
                    self._append(
                        signals,
                        "negative_sentiment_label_hint",
                        f"labels={labels[:6]}",
                        "medium",
                    )

        ppl_on, ppl_th = self._effective_perplexity(meta)
        if ppl_on:
            ppl = stats.get(StatsKeys.perplexity)
            if ppl is not None and float(ppl) >= ppl_th:
                self._append(
                    signals,
                    "high_perplexity",
                    f"perplexity={ppl}",
                    "medium",
                )

        if self.signal_hard_query_poor_reply:
            d = stats.get(StatsKeys.llm_difficulty_score)
            qs = stats.get(StatsKeys.llm_quality_score)
            if (
                d is not None
                and qs is not None
                and float(d) >= self.hard_query_difficulty_min
                and float(qs) <= self.poor_reply_quality_max
            ):
                self._append(
                    signals,
                    "hard_query_low_reply_quality_conjunction",
                    f"difficulty={d}, llm_quality_score={qs}",
                    "medium",
                )

        if self.signal_on_low_dialog_quality_meta:
            lows: List[str] = []
            th = float(self.dialog_quality_low_score_threshold)
            for k in _DIALOG_QUALITY_SCORE_META_KEYS:
                rec = meta.get(k)
                if not isinstance(rec, dict):
                    continue
                if rec.get("skipped") or rec.get("error"):
                    continue
                sc = rec.get("score")
                if sc is None:
                    continue
                try:
                    fv = float(sc)
                except (TypeError, ValueError):
                    continue
                if fv <= th:
                    lows.append(k)
            if len(lows) >= self.min_dialog_quality_low_axes_for_signal:
                self._append(
                    signals,
                    "dialog_turn_quality_meta_low",
                    f"axes={lows}, threshold={th}",
                    "medium",
                )

        meta[MetaKeys.agent_bad_case_signals] = signals

        mediums = [s for s in signals if s.get("weight") == "medium"]
        tool_fail_high = any(
            s.get("code") == "tool_message_error_pattern" and s.get("weight") == "high"
            for s in signals
        )
        llm_high = any(
            s.get("code")
            in (
                "llm_agent_analysis_eval_low",
                "llm_reply_quality_eval_low",
            )
            and s.get("weight") == "high"
            for s in signals
        )

        tier = "none"
        if tool_fail_high:
            tier = "high_precision" if self.high_precision_on_tool_fail_alone else "watchlist"
        elif llm_high:
            tier = "high_precision"
        elif len(mediums) >= self.min_medium_signals_for_watchlist:
            tier = "watchlist"
        elif len(signals) == 1 and signals[0].get("weight") == "medium":
            tier = "watchlist"

        meta[MetaKeys.agent_bad_case_tier] = tier
        return sample