Source code for data_juicer_agents.tools.plan._shared.system_spec

# -*- coding: utf-8 -*-
"""Shared system-spec helpers for plan tools."""
from __future__ import annotations

import os
from typing import Any, Dict, Iterable, List, Tuple

from .normalize import normalize_string_list
from .schema import SystemSpec


def normalize_system_spec(
    system_spec: SystemSpec | Dict[str, Any] | None,
    *,
    custom_operator_paths: Iterable[Any] | None = None,
) -> SystemSpec:
    """Normalize system spec, preserving all dynamic fields from Data-Juicer.

    Accepts a ``SystemSpec`` (normalized in place and returned), a plain dict
    (converted with ``SystemSpec.from_dict``) or ``None`` (a default
    ``SystemSpec()``).

    Processing steps:

    1. If ``custom_operator_paths`` is given, it replaces
       ``spec.custom_operator_paths`` (normalized with
       ``normalize_string_list``).
    2. Best-effort type coercion of extra fields and of the core fields
       ``np`` / ``executor_type`` via ``coerce_fields``, so that values
       serialise correctly in recipe YAML. Coercion is skipped silently when
       the bridge cannot be imported. If coercion itself raises, nothing is
       applied and a ``[type coercion]`` warning is recorded. Per-field
       coercion errors reported by the bridge are also recorded as
       ``[type coercion]`` warnings.
    3. Auto-corrections mirroring Data-Juicer's ``init_setup_from_cfg``,
       each recorded as an ``[auto-corrected]`` warning:

       - cap ``np`` at ``os.cpu_count()``;
       - clear ``cache_compress`` when caching is off or checkpointing is on;
       - disable ``use_checkpoint`` when ``op_fusion`` is enabled.

    Args:
        system_spec: Spec to normalize (``SystemSpec``, dict or ``None``).
        custom_operator_paths: Optional override for
            ``spec.custom_operator_paths``.

    Returns:
        The normalized ``SystemSpec``.

    Raises:
        ValueError: If ``system_spec`` is not a ``SystemSpec``, dict or None.
    """
    if isinstance(system_spec, SystemSpec):
        spec = system_spec
    elif isinstance(system_spec, dict):
        spec = SystemSpec.from_dict(system_spec)
    elif system_spec is None:
        spec = SystemSpec()
    else:
        raise ValueError(
            "system_spec must be a dict object, a SystemSpec, or None"
        )

    # Override custom_operator_paths if provided externally
    if custom_operator_paths is not None:
        spec.custom_operator_paths = normalize_string_list(custom_operator_paths)

    # --- Type coercion for YAML serialization (best effort) ---------------
    try:
        from data_juicer_agents.utils.dj_config_bridge import coerce_fields
    except Exception:
        coerce_fields = None  # bridge unavailable — skip coercion

    if coerce_fields is not None:
        try:
            # Compute both coercions before touching ``spec`` so that a
            # failure part-way through does not leave it half-coerced.
            coerced_extra, extra_errors = coerce_fields(spec._extra_fields)
            # Core fields (np might be a string from an LLM)
            core_dict = {"np": spec.np, "executor_type": spec.executor_type}
            coerced_core, core_errors = coerce_fields(core_dict)
        except Exception as exc:
            spec.warnings.append(f"[type coercion] skipped: {exc}")
        else:
            spec._extra_fields = coerced_extra
            spec.np = coerced_core.get("np", spec.np)
            spec.executor_type = coerced_core.get(
                "executor_type", spec.executor_type
            )
            spec.warnings.extend(
                f"[type coercion] {err}" for err in [*extra_errors, *core_errors]
            )

    # --- Auto-corrections (mirrors DJ init_setup_from_cfg) ----------------

    # np may still be a string (e.g. "4") when coercion was skipped or
    # failed; convert integer strings so the cap below compares numbers.
    if isinstance(spec.np, str):
        try:
            spec.np = int(spec.np.strip())
        except ValueError:
            pass  # non-numeric; left unchanged for downstream validation

    # np cap: ensure np does not exceed available CPU cores
    cpu_count = os.cpu_count() or 1
    try:
        np_exceeds_cpus = spec.np > cpu_count
    except TypeError:
        # None or other non-comparable np: nothing to cap here
        np_exceeds_cpus = False
    if np_exceeds_cpus:
        spec.warnings.append(
            f"[auto-corrected] np={spec.np} exceeds CPU count "
            f"({cpu_count}); capped to {cpu_count}"
        )
        spec.np = cpu_count

    # cache / checkpoint mutual exclusion:
    # disabling cache or enabling checkpoint makes cache_compress meaningless
    use_cache = spec.get("use_cache", True)
    use_checkpoint = spec.get("use_checkpoint", False)
    cache_compress = spec.get("cache_compress", None)
    if (not use_cache or use_checkpoint) and cache_compress:
        spec.warnings.append(
            "[auto-corrected] cache_compress disabled because "
            "cache is off or checkpoint is on"
        )
        spec.set("cache_compress", None)

    # op_fusion / checkpoint mutual exclusion:
    # op fusion is not compatible with checkpoint mode
    if spec.get("op_fusion", False) and use_checkpoint:
        spec.warnings.append(
            "[auto-corrected] use_checkpoint disabled because op_fusion is enabled"
        )
        spec.set("use_checkpoint", False)

    return spec


def validate_system_spec_payload(
    system_spec: SystemSpec | Dict[str, Any],
) -> Tuple[List[str], List[str]]:
    """Validate system spec using Data-Juicer's native validation when possible.

    Checks, in order:

    1. Core fields: ``executor_type`` must be non-empty and ``np`` must be
       an integer >= 1.
    2. Data-Juicer parser validation through ``dj_config_bridge``. This is
       skipped when the bridge cannot be imported; if the bridge raises
       while validating, a ``[dj validation]`` warning is returned instead.
       The agent-side ``warnings`` field is stripped before the payload is
       handed to the bridge.
    3. Semantic checks mirroring Data-Juicer's ``init_setup_from_cfg``:

       - when ``op_fusion`` is on, ``fusion_strategy`` must be one of
         Data-Juicer's ``FUSION_STRATEGIES`` (skipped if Data-Juicer cannot
         be imported);
       - a ``{job_id}`` placeholder in ``work_dir`` must be the last path
         component.

    Args:
        system_spec: A ``SystemSpec`` or a dict (converted with
            ``SystemSpec.from_dict``).

    Returns:
        ``(errors, warnings)``. ``warnings`` holds any bridge diagnostic
        followed by the spec's own accumulated warnings, with empty entries
        dropped and duplicates removed (first occurrence kept).
    """
    if isinstance(system_spec, dict):
        system_spec = SystemSpec.from_dict(system_spec)

    errors: List[str] = []
    warnings: List[str] = []

    # --- Basic validation for core fields ---------------------------------
    if not system_spec.executor_type:
        errors.append("executor_type is required")
    try:
        np_value = int(system_spec.np or 0)
    except (TypeError, ValueError):
        # e.g. a non-numeric string: report it rather than crash validation
        errors.append(f"np must be an integer >= 1, got {system_spec.np!r}")
    else:
        if np_value <= 0:
            errors.append("np must be >= 1")

    # --- DJ parser validation (best effort) -------------------------------
    try:
        from data_juicer_agents.utils.dj_config_bridge import get_dj_config_bridge
    except Exception:
        get_dj_config_bridge = None  # bridge unavailable — skip
    if get_dj_config_bridge is not None:
        try:
            bridge = get_dj_config_bridge()
            # Remove non-DJ fields before validation
            dj_dict = {
                key: value
                for key, value in system_spec.to_dict().items()
                if key != "warnings"
            }
            is_valid, dj_errors = bridge.validate(dj_dict)
        except Exception as exc:
            warnings.append(f"[dj validation] skipped: {exc}")
        else:
            if not is_valid:
                errors.extend(dj_errors)

    # --- Semantic validation (mirrors DJ init_setup_from_cfg) -------------

    # fusion_strategy must be in FUSION_STRATEGIES when op_fusion is on
    if system_spec.get("op_fusion", False):
        fusion_strategy = (
            str(system_spec.get("fusion_strategy", "") or "").strip().lower()
        )
        if fusion_strategy:
            try:
                from data_juicer.ops.op_fusion import (
                    FUSION_STRATEGIES as known_strategies,
                )
            except Exception:
                known_strategies = None  # DJ unavailable — skip check
            if (
                known_strategies is not None
                and fusion_strategy not in known_strategies
            ):
                errors.append(
                    f"fusion_strategy '{fusion_strategy}' is not supported; "
                    f"must be one of {sorted(known_strategies)}"
                )

    # work_dir: {job_id} placeholder must be the last path component
    work_dir = str(system_spec.get("work_dir", "") or "").strip()
    if work_dir and "{job_id}" in work_dir:
        if not work_dir.rstrip("/").endswith("{job_id}"):
            errors.append(
                "work_dir: '{job_id}' placeholder must be the last "
                "component of the path"
            )

    # Merge the spec's accumulated warnings, dropping empties and duplicates
    # (including duplicates within system_spec.warnings itself).
    for item in system_spec.warnings:
        if item and item not in warnings:
            warnings.append(item)

    return errors, warnings
# Public API re-exported by ``from ... import *``.
__all__ = [
    "normalize_system_spec",
    "validate_system_spec_payload",
]