# Source code for data_juicer.ops.mapper.agent_dialog_normalize_mapper

# Copyright 2025 The Data-Juicer Authors. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Normalize agent interaction format (messages + choices) to DJ fields for
# dialog/text ops. Supports multi-platform, multi-agent tool formats.

import re
from typing import Any, List, Optional, Sequence, Tuple

from data_juicer.ops.base_op import OPERATORS, TAGGING_OPS, Mapper
from data_juicer.utils.constant import Fields, MetaKeys

# Name under which the operator class in this module is registered
# (passed to ``register_module`` on both OPERATORS and TAGGING_OPS).
OP_NAME = "agent_dialog_normalize_mapper"

# Default speaker labels used as line prefixes when the dialog history is
# flattened to text. Override them for i18n, e.g. with the Chinese words for
# "user" / "assistant".
DEFAULT_USER_LABEL = "User"
DEFAULT_ASSISTANT_LABEL = "Assistant"


def _coerce_content_fragment(val: Any) -> str:
    """Turn a content block's text-ish field into a single flat string (no .strip on dict)."""
    if val is None:
        return ""
    if isinstance(val, str):
        return val.strip()
    if isinstance(val, dict):
        for k in ("value", "text", "content"):
            if k in val and val[k] not in (None, ""):
                return _coerce_content_fragment(val[k])
        return ""
    if isinstance(val, list):
        return "\n".join(_coerce_content_fragment(x) for x in val if x not in (None, "")).strip()
    return str(val).strip()


def _content_to_text(content: Any) -> str:
    """Extract plain text from message.content.
    Supports: str, list of {type, text} (OpenAI multimodal), list of str.
    ``text`` may be nested (dict/list); Qwen-style ``thinking`` blocks are included.
    """  # noqa: E501
    if content is None:
        return ""
    if isinstance(content, str):
        return content.strip()
    if isinstance(content, list):
        parts = []
        for block in content:
            if isinstance(block, dict):
                btype = block.get("type")
                if btype == "text":
                    parts.append(_coerce_content_fragment(block.get("text")))
                elif btype == "input_text" or (btype not in ("thinking", "reasoning") and "text" in block):
                    parts.append(_coerce_content_fragment(block.get("text")))
                elif btype in ("thinking", "reasoning") or "thinking" in block or "reasoning_content" in block:
                    # Qwen / DeepSeek / DashScope style reasoning (may omit type)
                    parts.append(
                        _coerce_content_fragment(
                            block.get("thinking") or block.get("reasoning") or block.get("reasoning_content")
                        )
                    )
            elif isinstance(block, str):
                parts.append(block.strip())
        return "\n".join(p for p in parts if p).strip()
    if isinstance(content, dict):
        return _coerce_content_fragment(content)
    return str(content).strip()


def _get_tool_name_from_call(tc: dict) -> Optional[str]:
    """Get tool/function name from one tool_call item.
    Supports OpenAI (function.name), Anthropic (name), generic.
    """
    if not isinstance(tc, dict):
        return None
    fn = tc.get("function") or tc.get("function_call")
    if isinstance(fn, dict) and fn.get("name"):
        return fn["name"]
    if tc.get("name"):
        return tc["name"]
    return None


def _tool_calls_summary(
    tool_calls: Any,
    max_names: int = 10,
) -> str:
    """Summarize tool_calls for history; multi-format."""
    if not tool_calls or not isinstance(tool_calls, list):
        return ""
    names = []
    for tc in tool_calls:
        n = _get_tool_name_from_call(tc)
        if n and n not in names:
            names.append(n)
    if not names:
        return ""
    display = names[:max_names]
    if len(names) > max_names:
        display.append(f"...+{len(names) - max_names}")
    return "[Tool calls: " + ", ".join(display) + "]"


def _list_str_for_hf_meta(values: List[str]) -> List[str]:
    """Return ``list[str]`` with Arrow-friendly element type for HuggingFace ``datasets``.

    An empty ``[]`` is often inferred as ``list<null>`` while non-empty lists use
    ``list<string>``, which breaks multi-process ``map`` when rows differ; storing a
    single empty string keeps dtype stable. Downstream should drop falsy entries.
    """
    out = [str(x).strip() for x in values if x is not None and str(x).strip()]
    return out if out else [""]


def _compress_head_tail(text: str, max_chars: int, head_ratio: float = 0.62) -> str:
    """If ``text`` exceeds ``max_chars``, keep head + tail with an explicit middle marker.

    Designed for **write-back** to ``dialog_history`` / ``text`` so downstream sees the
    same bounded view as prompt-side caps, while preserving error prefixes and trailing
    stack/summary when possible. Middle is dropped (lossy); marker states that clearly.
    """
    if max_chars <= 0 or not text or len(text) <= max_chars:
        return text
    head_ratio = min(0.85, max(0.35, float(head_ratio)))
    marker_reserve = 128
    budget = max_chars - marker_reserve
    if budget < 400:
        cut = max(0, max_chars - 48)
        return text[:cut] + "\n…[truncated — agent_dialog_normalize_mapper]…"

    head_n = max(200, int(budget * head_ratio))
    tail_n = budget - head_n
    if tail_n < 200:
        tail_n = 200
        head_n = max(200, budget - tail_n)
    omitted = len(text) - head_n - tail_n
    if omitted <= 0:
        return text
    marker = (
        f"\n\n[··· {omitted} chars omitted from middle; " "head+tail preserved — agent_dialog_normalize_mapper]\n\n"
    )
    if head_n + len(marker) + tail_n > max_chars:
        over = head_n + len(marker) + tail_n - max_chars
        head_n = max(120, head_n - over)
    return text[:head_n] + marker + text[-tail_n:]


def _apply_char_cap(
    text: str,
    max_chars: int,
    head_ratio: float,
    compressed_ref: Optional[List[bool]],
) -> str:
    """Apply :func:`_compress_head_tail` when ``max_chars`` > 0 and text is too long."""
    if max_chars <= 0 or not text or len(text) <= max_chars:
        return text
    if compressed_ref is not None:
        compressed_ref[0] = True
    return _compress_head_tail(text, max_chars, head_ratio)


def _extract_tool_types(messages: List[dict]) -> List[str]:
    """Collect unique tool/function names from messages (multi-format)."""
    out = []
    seen = set()
    for m in messages:
        for tc in m.get("tool_calls") or m.get("tool_use") or []:
            name = _get_tool_name_from_call(tc)
            if name and name not in seen:
                seen.add(name)
                out.append(name)
    return out


# Skill patterns: only structured identifiers to avoid fragment noise.
# Every pattern captures the skill name in group 1; the name must start with
# an ASCII letter and continue with ASCII letters, digits or underscores.
# - Path: xxx/SKILL.md or xxx\SKILL.md -> captures the directory segment right
#   before SKILL.md (case-insensitive).
# - ## header: only ASCII identifiers, optionally with one dotted suffix
#   (e.g. "## cron"), followed by a newline or end of text; headings written
#   in non-ASCII scripts (e.g. a Chinese "memory" heading) do not match.
# - Check "<...>name/SKILL.md: a quoted path after the word "Check"
#   (case-insensitive); captures the segment right before /SKILL.md.
_SKILL_PATTERNS = [
    re.compile(r"[/\\]([a-zA-Z][a-zA-Z0-9_]*)[/\\]SKILL\.md", re.IGNORECASE),
    re.compile(r"##\s+([a-zA-Z][a-zA-Z0-9_]*(?:\.[a-zA-Z0-9_]+)?)\s*(?:\n|$)"),
    re.compile(
        r"Check\s+[\"'].*?([a-zA-Z][a-zA-Z0-9_]*)/SKILL\.md",
        re.IGNORECASE,
    ),
]


def _extract_skill_types(messages: List[dict]) -> List[str]:
    """Extract skill names (paths + ## English headers only)."""
    out = []
    seen = set()
    for m in messages:
        text = _content_to_text(m.get("content"))
        for pat in _SKILL_PATTERNS:
            for mo in pat.finditer(text):
                name = (mo.group(1) or "").strip()
                if name and name not in seen:
                    seen.add(name)
                    out.append(name)
    return out


def _messages_to_history(
    messages: List[dict],
    include_system_in_first_user: bool = False,
    *,
    history_tool_result_max_chars: int = 10_000,
    history_max_assistant_trace_chars: int = 0,
    history_max_user_chars: int = 0,
    history_compress_head_ratio: float = 0.62,
    compressed_ref: Optional[List[bool]] = None,
) -> List[Tuple[str, str]]:
    """Convert messages to [(query, response), ...]. User/assistant only.

    Agent/tool note: within one user turn, the model may emit **multiple**
    ``assistant`` messages (tool calls → ``tool`` → assistant again). Earlier
    implementations **replaced** the assistant side each time, dropping
    intermediate reasoning and tool traces. We **accumulate** consecutive
    assistant segments (and still append each ``tool`` result onto the same
    pair) so ``query`` / ``response`` / ``text`` match multi-step agent runs.

    Optional caps: ``history_tool_result_max_chars`` per tool payload (default
    ``10000``, same order as the old hard-coded slice; use ``0`` for unlimited);
    ``history_max_assistant_trace_chars`` on the **whole** assistant side after
    each update (``0`` = off); ``history_max_user_chars`` on user text.
    When a cap applies, middle is omitted with an explicit marker
    (head+tail). Set ``compressed_ref`` to a one-element list ``[False]`` to
    record whether any compression ran.
    """
    history = []
    pending_system = []
    ratio = history_compress_head_ratio

    for m in messages:
        role = (m.get("role") or "").lower()
        content = _content_to_text(m.get("content"))
        tool_calls = m.get("tool_calls") or m.get("tool_use") or []

        if role == "system":
            if include_system_in_first_user and content:
                pending_system.append(content)
            continue
        if role == "user":
            if pending_system:
                content = "\n\n".join(pending_system + [content]).strip()
                pending_system = []
            content = _apply_char_cap(content, history_max_user_chars, ratio, compressed_ref)
            history.append((content, ""))
            continue
        if role == "assistant":
            if not content and tool_calls:
                content = _tool_calls_summary(tool_calls)
            piece = content or ""
            if history:
                prev_q, prev_r = history[-1]
                if prev_r and piece:
                    new_r = prev_r + "\n\n" + piece
                elif piece:
                    new_r = piece
                else:
                    new_r = prev_r
                new_r = _apply_char_cap(
                    new_r,
                    history_max_assistant_trace_chars,
                    ratio,
                    compressed_ref,
                )
                history[-1] = (prev_q, new_r)
            else:
                piece_capped = _apply_char_cap(
                    piece,
                    history_max_assistant_trace_chars,
                    ratio,
                    compressed_ref,
                )
                history.append(("", piece_capped))
            continue
        if role == "tool":
            body = _apply_char_cap(
                content,
                history_tool_result_max_chars,
                ratio,
                compressed_ref,
            )
            # Append tool result to last assistant response for context
            if history and history[-1][1]:
                new_r = history[-1][1] + "\n[Tool result]\n" + body
                new_r = _apply_char_cap(
                    new_r,
                    history_max_assistant_trace_chars,
                    ratio,
                    compressed_ref,
                )
                history[-1] = (history[-1][0], new_r)
            elif history:
                lone = "[Tool result]\n" + body
                lone = _apply_char_cap(
                    lone,
                    history_max_assistant_trace_chars,
                    ratio,
                    compressed_ref,
                )
                history[-1] = (history[-1][0], lone)
            continue
    return history


def _last_user_assistant_msg_indices(
    messages: List[dict],
) -> Tuple[Optional[int], Optional[int]]:
    """0-based indices in ``messages`` of the last user / assistant turns."""
    last_u: Optional[int] = None
    last_a: Optional[int] = None
    for i, m in enumerate(messages):
        if not isinstance(m, dict):
            continue
        role = (m.get("role") or "").lower()
        if role == "user":
            last_u = i
        elif role == "assistant":
            last_a = i
    return last_u, last_a


def _first_non_empty_str(sample: dict, keys: Sequence[str]) -> Optional[str]:
    for k in keys:
        if k not in sample:
            continue
        v = sample.get(k)
        if v is None:
            continue
        s = str(v).strip()
        if s:
            return s
    return None


def _choices_to_text(choices: Any) -> str:
    """Extract reply text from choices (OpenAI / Anthropic / generic)."""
    if not choices or not isinstance(choices, list):
        return ""
    for c in choices:
        if not isinstance(c, dict):
            continue
        msg = c.get("message") or c.get("delta") or c
        text = msg.get("content")
        if text is None:
            continue
        if isinstance(text, str) and text.strip():
            return text.strip()
        if isinstance(text, list):
            t = _content_to_text(text)
            if t:
                return t
    return ""


def _flatten_history_to_text(
    history: List[Tuple[str, str]],
    user_label: str = DEFAULT_USER_LABEL,
    assistant_label: str = DEFAULT_ASSISTANT_LABEL,
) -> str:
    """Render ``(query, response)`` pairs as one labeled text for text-based ops.

    Every non-empty query becomes ``"<user_label>: <query>"`` and every
    non-empty response ``"<assistant_label>: <response>"``; the resulting
    segments are joined with blank lines in dialog order.
    """
    segments = []
    for user_turn, assistant_turn in history:
        labeled = ((user_label, user_turn), (assistant_label, assistant_turn))
        segments.extend(f"{label}: {body}" for label, body in labeled if body)
    return "\n\n".join(segments)


@TAGGING_OPS.register_module(OP_NAME)
@OPERATORS.register_module(OP_NAME)
class AgentDialogNormalizeMapper(Mapper):
    """Normalize agent format (messages + choices) to DJ fields.

    Outputs: text, dialog_history, query, response. Meta always receives
    ``agent_turn_count`` and ``agent_dialog_history_compressed``; the tags
    ``agent_tool_types`` and ``agent_skill_types`` are added when
    ``extract_tool_skill_tags`` is True.

    When ``copy_lineage_fields`` is True, also copies request_model, pt and
    total_cost_time from the sample root into meta for cohort analysis. When
    ``copy_request_id`` is True, the first non-empty id among
    ``request_id_keys`` is copied into meta for stable drill-down links.
    Always records last user/assistant message indices (in the raw
    ``messages`` list) when present.

    Supports multi-format tool_calls (e.g. tool_calls[].function.name as in
    OpenAI / demos/local/demo-agent-data-content.json) and configurable
    user/assistant labels.

    Optional ``history_*_max_chars`` caps keep head+tail with an explicit
    middle-omitted marker so ``dialog_history``, flattened ``text``, and last
    ``query`` / ``response`` stay aligned;
    ``meta.agent_dialog_history_compressed`` is set when any cap fires.
    """

    def __init__(
        self,
        messages_key: str = "messages",
        choices_key: str = "choices",
        text_key: str = "text",
        history_key: str = "dialog_history",
        query_key: str = "query",
        response_key: str = "response",
        extract_tool_skill_tags: bool = True,
        include_system_in_first_user: bool = False,
        user_label: str = DEFAULT_USER_LABEL,
        assistant_label: str = DEFAULT_ASSISTANT_LABEL,
        copy_lineage_fields: bool = True,
        copy_request_id: bool = True,
        request_id_keys: List[str] = [
            "request_id",
            "trace_id",
            "id",
        ],
        history_tool_result_max_chars: int = 10_000,
        history_max_assistant_trace_chars: int = 0,
        history_max_user_chars: int = 0,
        history_compress_head_ratio: float = 0.62,
        **kwargs,
    ):
        """Initialization method.

        :param messages_key: sample field holding the raw message list.
        :param choices_key: sample field holding the model ``choices``.
        :param text_key: output field for the flattened dialog text.
        :param history_key: output field for the ``(query, response)`` pairs.
        :param query_key: output field for the last query.
        :param response_key: output field for the last response.
        :param extract_tool_skill_tags: whether to write tool / skill name
            lists into meta.
        :param include_system_in_first_user: whether system messages are
            prepended to the next user message.
        :param user_label: prefix for user lines in the flattened text.
        :param assistant_label: prefix for assistant lines in the flattened
            text.
        :param copy_lineage_fields: whether to copy ``request_model``, ``pt``
            and ``total_cost_time`` from the sample root into meta.
        :param copy_request_id: whether to copy a request id into meta.
        :param request_id_keys: sample keys probed, in order, for the request
            id. A single string is treated as one key.
        :param history_tool_result_max_chars: cap per tool result
            (``0`` = unlimited).
        :param history_max_assistant_trace_chars: cap on the accumulated
            response of a turn (``0`` = off).
        :param history_max_user_chars: cap on each user query (``0`` = off).
        :param history_compress_head_ratio: head share used when a cap
            compresses text.
        :param kwargs: extra arguments forwarded to the base ``Mapper``.
        """
        super().__init__(text_key=text_key, **kwargs)
        self.messages_key = messages_key
        self.choices_key = choices_key
        self.history_key = history_key
        self.query_key = query_key
        self.response_key = response_key
        self.extract_tool_skill_tags = extract_tool_skill_tags
        self.include_system_in_first_user = include_system_in_first_user
        self.user_label = user_label
        self.assistant_label = assistant_label
        self.copy_lineage_fields = copy_lineage_fields
        self.copy_request_id = copy_request_id
        # Store a private copy: the default list object is shared by every
        # call of __init__, so keeping a reference would let a mutation on one
        # instance leak into all others (and into later defaults).
        if request_id_keys is None:
            self.request_id_keys = []
        elif isinstance(request_id_keys, str):
            # A bare string would otherwise be iterated character by character.
            self.request_id_keys = [request_id_keys]
        else:
            self.request_id_keys = list(request_id_keys)
        self.history_tool_result_max_chars = history_tool_result_max_chars
        self.history_max_assistant_trace_chars = history_max_assistant_trace_chars
        self.history_max_user_chars = history_max_user_chars
        self.history_compress_head_ratio = history_compress_head_ratio

    def process_single(self, sample):
        """Normalize one sample in place and return it.

        Builds the dialog history from the sample's messages, writes the
        flattened text, the history, and the last query / response, then
        fills the meta fields described on the class.

        :param sample: sample dict; modified in place.
        :return: the same ``sample``.
        """
        messages = sample.get(self.messages_key) or []
        choices = sample.get(self.choices_key) or []
        if not isinstance(messages, list):
            messages = []

        # One-element list used as an out-parameter: set to True when any cap fires.
        compressed_ref = [False]
        history = _messages_to_history(
            messages,
            include_system_in_first_user=self.include_system_in_first_user,
            history_tool_result_max_chars=self.history_tool_result_max_chars,
            history_max_assistant_trace_chars=self.history_max_assistant_trace_chars,
            history_max_user_chars=self.history_max_user_chars,
            history_compress_head_ratio=self.history_compress_head_ratio,
            compressed_ref=compressed_ref,
        )
        flat_text = _flatten_history_to_text(
            history,
            user_label=self.user_label,
            assistant_label=self.assistant_label,
        )

        last_query = history[-1][0] if history else ""
        last_response = history[-1][1] if history else ""
        # Fall back to the reply carried in ``choices`` when the messages
        # yield no response for the last turn.
        if not last_response and choices:
            last_response = _choices_to_text(choices)

        sample[self.text_key] = flat_text
        sample[self.history_key] = history
        sample[self.query_key] = last_query
        sample[self.response_key] = last_response

        if Fields.meta not in sample:
            sample[Fields.meta] = {}
        meta = sample[Fields.meta]
        # Always set bool so HF Arrow struct schema matches across rows (num_proc>1)
        meta[MetaKeys.agent_dialog_history_compressed] = bool(compressed_ref[0])
        meta[MetaKeys.agent_turn_count] = len(history)
        if self.extract_tool_skill_tags:
            meta[MetaKeys.agent_tool_types] = _list_str_for_hf_meta(_extract_tool_types(messages))
            meta[MetaKeys.agent_skill_types] = _list_str_for_hf_meta(_extract_skill_types(messages))

        last_u_idx, last_a_idx = _last_user_assistant_msg_indices(messages)
        if last_u_idx is not None:
            meta[MetaKeys.agent_last_user_msg_idx] = last_u_idx
        if last_a_idx is not None:
            meta[MetaKeys.agent_last_assistant_msg_idx] = last_a_idx

        if self.copy_request_id:
            rid = _first_non_empty_str(sample, self.request_id_keys)
            if rid is not None:
                meta[MetaKeys.agent_request_id] = rid

        # Cohort fields for bad-case / A-B analysis (request_model, date bucket, latency)
        if self.copy_lineage_fields:
            if sample.get("request_model") is not None:
                meta[MetaKeys.agent_request_model] = sample["request_model"]
            if sample.get("pt") is not None:
                meta[MetaKeys.agent_pt] = sample["pt"]
            if sample.get("total_cost_time") is not None:
                meta[MetaKeys.agent_total_cost_time_ms] = sample["total_cost_time"]
        return sample