# Source code for data_juicer.ops.mapper.usage_counter_mapper

# Copyright 2025 The Data-Juicer Authors. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Extract token usage from agent response (choices, usage, response_metadata).
# Multi-format: OpenAI, Anthropic, generic usage objects.

from typing import Any, List, Set, Tuple

from data_juicer.ops.base_op import OPERATORS, TAGGING_OPS, Mapper
from data_juicer.utils.constant import Fields, MetaKeys

OP_NAME = "usage_counter_mapper"


def _usage_dedup_key(u: dict) -> Tuple[int, int, int]:
    """Build the hashable ``(prompt, completion, total)`` identity of a usage dict.

    Usage snapshots that produce the same key are treated as copies of one
    another (e.g. a top-level usage mirrored inside a choice). Missing, falsy
    or non-integer prompt/completion counts become 0; a missing or
    non-integer total falls back to ``prompt + completion``.
    """

    def _count(field: str) -> int:
        # ``or 0`` maps None / 0 / "" to 0 before the int() conversion.
        try:
            return int(u.get(field) or 0)
        except (TypeError, ValueError):
            return 0

    prompt = _count("prompt_tokens")
    completion = _count("completion_tokens")
    fallback_total = prompt + completion

    reported_total = u.get("total_tokens")
    if reported_total is None:
        return (prompt, completion, fallback_total)
    try:
        return (prompt, completion, int(reported_total))
    except (TypeError, ValueError):
        # An unparsable total is ignored in favour of the derived one.
        return (prompt, completion, fallback_total)


def _dedupe_usages_preserve_order(usages: List[dict]) -> List[dict]:
    """Return ``usages`` with later duplicates removed, keeping first-seen order.

    Two entries are duplicates when ``_usage_dedup_key`` yields the same tuple
    for both; the first occurrence is the one that is kept.
    """
    first_by_key = {}
    for usage in usages:
        # setdefault stores only the first usage seen for each key; dicts keep
        # insertion order, so the values come back in first-seen order.
        first_by_key.setdefault(_usage_dedup_key(usage), usage)
    return list(first_by_key.values())


def _get_usage_from_obj(obj: Any) -> dict:
    """Extract prompt/completion/total token counts from a response-like dict.

    Looks for the counts in a nested ``obj["usage"]`` / ``obj["usage_metadata"]``
    dict, or on ``obj`` itself when ``obj`` is the usage dict. Both the
    ``prompt_tokens`` / ``completion_tokens`` and the ``input_tokens`` /
    ``output_tokens`` spellings are accepted, nested or directly on ``obj``.

    Args:
        obj: Candidate object; anything that is not a dict yields ``{}``.

    Returns:
        ``{}`` when no token field is present, so callers can test the result
        for truthiness. Otherwise a dict with ``prompt_tokens`` and
        ``completion_tokens`` (0 when absent) and ``total_tokens`` (the raw
        reported value, or ``None`` when absent).
    """
    if not isinstance(obj, dict):
        return {}

    nested = obj.get("usage") or obj.get("usage_metadata") or {}
    if not isinstance(nested, dict):
        nested = {}

    def _has_any(candidate: dict, keys: tuple) -> bool:
        return any(candidate.get(k) is not None for k in keys)

    primary_keys = ("prompt_tokens", "completion_tokens")
    alias_keys = ("input_tokens", "output_tokens")

    # Precedence: prompt/completion fields win (nested first, then obj itself);
    # only then the input/output aliases are considered, again nested first.
    if _has_any(nested, primary_keys):
        usage = nested
    elif _has_any(obj, primary_keys):
        usage = obj
    elif _has_any(nested, alias_keys + ("total_tokens",)):
        usage = nested
    elif _has_any(obj, alias_keys):
        # obj itself is an input/output-style usage dict.
        usage = obj
    else:
        # Nothing token-related found: report "no usage" instead of zeros.
        return {}

    prompt = usage.get("prompt_tokens") or usage.get("input_tokens") or 0
    completion = usage.get("completion_tokens") or usage.get("output_tokens") or 0
    return {
        "prompt_tokens": prompt,
        "completion_tokens": completion,
        "total_tokens": usage.get("total_tokens"),
    }


def _aggregate_usage(usages: List[dict]) -> tuple:
    """Sum token counts over a list of usage dicts.

    Prompt and completion counts are summed over all entries. The total is
    summed the same way: each entry contributes its own ``total_tokens`` when
    that is present and integer-like, otherwise its ``prompt + completion``.
    This keeps the total consistent with the summed prompt/completion counts
    when more than one usage is aggregated.

    Counts are coerced with ``int()``; missing, falsy or unparsable
    prompt/completion values count as 0.

    Args:
        usages: Usage dicts with optional ``prompt_tokens``,
            ``completion_tokens`` and ``total_tokens`` keys.

    Returns:
        ``(prompt, completion, total)``. ``total`` is ``None`` when no entry
        reported a usable total and both sums are 0 (including the empty list).
    """

    def _count(value: Any) -> int:
        try:
            return int(value or 0)
        except (TypeError, ValueError):
            return 0

    def _reported_total(value: Any):
        # None means "no usable total reported for this entry".
        if value is None:
            return None
        try:
            return int(value)
        except (TypeError, ValueError):
            return None

    prompt_sum = 0
    completion_sum = 0
    total_sum = 0
    saw_reported_total = False
    for usage in usages:
        prompt = _count(usage.get("prompt_tokens"))
        completion = _count(usage.get("completion_tokens"))
        prompt_sum += prompt
        completion_sum += completion

        reported = _reported_total(usage.get("total_tokens"))
        if reported is None:
            total_sum += prompt + completion
        else:
            total_sum += reported
            saw_reported_total = True

    if not saw_reported_total and not (prompt_sum or completion_sum):
        # Nothing was reported at all: leave the total unknown rather than 0.
        return prompt_sum, completion_sum, None
    return prompt_sum, completion_sum, total_sum


@TAGGING_OPS.register_module(OP_NAME)
@OPERATORS.register_module(OP_NAME)
class UsageCounterMapper(Mapper):
    """Write token usage to meta from choices/usage (OpenAI/Anthropic-style).

    Collects every non-empty usage dict found (top-level ``usage_key``,
    ``response_metadata``, each ``choices[]`` entry, nested message usage).

    By default, **deduplicates** identical usage snapshots before summing: same
    ``(prompt_tokens, completion_tokens, total_tokens or prompt+completion)``
    only counts once (typical when ``response_usage`` mirrors
    ``choices[0].usage``). Set ``dedupe_identical_usage: false`` to restore
    legacy double-counting.
    """

    def __init__(
        self,
        choices_key: str = "choices",
        usage_key: str = "usage",
        response_metadata_key: str = "response_metadata",
        dedupe_identical_usage: bool = True,
        **kwargs,
    ):
        """Configure which sample fields are scanned for usage information.

        Args:
            choices_key: Sample key holding the list of choices.
            usage_key: Sample key holding the top-level usage object.
            response_metadata_key: Sample key holding response metadata that
                may carry a nested usage object.
            dedupe_identical_usage: When true, usage snapshots with identical
                token counts are counted only once.
            **kwargs: Forwarded to the ``Mapper`` base class.
        """
        super().__init__(**kwargs)
        self.choices_key = choices_key
        self.usage_key = usage_key
        self.response_metadata_key = response_metadata_key
        self.dedupe_identical_usage = bool(dedupe_identical_usage)

    @staticmethod
    def _has_token_info(usage: Any) -> bool:
        """Return True when ``usage`` carries at least one real token figure.

        An extracted usage whose prompt and completion counts are both falsy
        and whose total is ``None`` contributes nothing to the aggregate, so
        it is treated as "no usage found".
        """
        if not isinstance(usage, dict) or not usage:
            return False
        return bool(
            usage.get("prompt_tokens")
            or usage.get("completion_tokens")
            or usage.get("total_tokens") is not None
        )

    def process_single(self, sample):
        """Collect usage from ``sample`` and write the token sums to its meta.

        Writes ``MetaKeys.prompt_tokens``, ``MetaKeys.completion_tokens`` and
        ``MetaKeys.total_tokens`` into ``sample[Fields.meta]`` (created when
        missing) and returns the same sample.
        """
        usages = []

        def _collect(candidate: Any) -> bool:
            # Append the usage extracted from ``candidate`` if it is informative.
            usage = _get_usage_from_obj(candidate)
            if self._has_token_info(usage):
                usages.append(usage)
                return True
            return False

        # Top-level usage (e.g. usage / response_usage). If the value under
        # usage_key yields nothing useful, fall back to scanning the sample
        # itself (nested "usage"/"usage_metadata" or top-level token fields).
        if self.usage_key in sample:
            if not _collect(sample.get(self.usage_key)):
                _collect(sample)

        # response_metadata.usage
        response_meta = sample.get(self.response_metadata_key) or {}
        if isinstance(response_meta, dict):
            _collect(response_meta)

        # choices[].usage or choices[].message.usage (or choices[].delta.usage)
        choices = sample.get(self.choices_key) or []
        if isinstance(choices, list):
            for choice in choices:
                if not isinstance(choice, dict):
                    continue
                _collect(choice)
                message = choice.get("message") or choice.get("delta")
                if isinstance(message, dict):
                    _collect(message)

        if self.dedupe_identical_usage:
            usages = _dedupe_usages_preserve_order(usages)

        if Fields.meta not in sample:
            sample[Fields.meta] = {}
        sample_meta = sample[Fields.meta]
        prompt_tokens, completion_tokens, total_tokens = _aggregate_usage(usages)
        sample_meta[MetaKeys.prompt_tokens] = prompt_tokens
        sample_meta[MetaKeys.completion_tokens] = completion_tokens
        sample_meta[MetaKeys.total_tokens] = total_tokens
        return sample