# Source code for data_juicer_agents.tools.plan._shared.dataset_spec

# -*- coding: utf-8 -*-
"""Shared dataset-spec helpers for plan tools."""

from __future__ import annotations

from pathlib import Path
from typing import Any, Dict, List, Tuple

from .normalize import normalize_optional_text, normalize_params, normalize_string_list
from .schema import DatasetBindingSpec, DatasetSpec, _ALLOWED_MODALITIES


def normalize_dataset_spec(
    dataset_spec: DatasetSpec | Dict[str, Any],
) -> DatasetSpec:
    """Normalize dataset spec: strip strings, deduplicate lists.

    Performs type coercion on extra dataset fields (export_type,
    export_shard_size, etc.) via ``coerce_fields`` so that values
    serialise correctly in recipe YAML, consistent with how
    ``normalize_system_spec`` handles system fields.
    """
    if isinstance(dataset_spec, DatasetSpec):
        spec = dataset_spec
    elif isinstance(dataset_spec, dict):
        spec = DatasetSpec.from_dict(dataset_spec)
    else:
        raise ValueError("dataset_spec must be a dict object")

    io_spec = spec.io
    bind = spec.binding

    # Best-effort coercion of extra dataset fields; if the bridge cannot be
    # imported the raw values are kept untouched.
    extra = dict(io_spec._extra_fields)
    coercion_notes: list[str] = []
    try:
        from data_juicer_agents.utils.dj_config_bridge import coerce_fields

        extra, problems = coerce_fields(io_spec._extra_fields)
        if problems:
            coercion_notes = [f"[type coercion] {err}" for err in problems]
    except Exception:
        pass  # bridge unavailable — skip coercion

    # Keep previously recorded warnings, appending coercion notes without
    # introducing duplicates.
    merged_warnings = normalize_string_list(spec.warnings)
    for note in coercion_notes:
        if note not in merged_warnings:
            merged_warnings.append(note)

    io_payload: Dict[str, Any] = {
        "dataset_path": str(io_spec.dataset_path or "").strip(),
        "dataset": None if io_spec.dataset is None else io_spec.dataset.to_dict(),
        "generated_dataset_config": (
            None
            if io_spec.generated_dataset_config is None
            else io_spec.generated_dataset_config.to_dict()
        ),
        "export_path": str(io_spec.export_path or "").strip(),
        # Preserve extra dataset fields (export_type, export_shard_size, suffixes, etc.)
        **extra,
    }
    binding_payload: Dict[str, Any] = {
        "modality": str(bind.modality or "unknown").strip() or "unknown",
        "text_keys": normalize_string_list(bind.text_keys),
        "image_key": normalize_optional_text(bind.image_key),
        "audio_key": normalize_optional_text(bind.audio_key),
        "video_key": normalize_optional_text(bind.video_key),
        "image_bytes_key": normalize_optional_text(bind.image_bytes_key),
    }

    return DatasetSpec.from_dict(
        {
            "io": io_payload,
            "binding": binding_payload,
            "warnings": merged_warnings,
        }
    )


def infer_modality(binding: DatasetBindingSpec) -> str:
    """Infer the effective modality of *binding*.

    The declared ``binding.modality`` wins only when the corresponding
    key(s) are actually bound; otherwise the modality is derived from
    which keys are set (two or more bound modalities => "multimodal").
    """
    declared = str(binding.modality or "unknown").strip().lower()
    bound = {
        "text": bool(binding.text_keys),
        "image": bool(binding.image_key),
        "audio": bool(binding.audio_key),
        "video": bool(binding.video_key),
    }
    bound_count = sum(bound.values())

    # Honour the declared modality only when backed by bound keys.
    if declared in bound and bound[declared]:
        return declared
    if declared == "multimodal" and bound_count >= 2:
        return declared

    # Fall back to inference from the bound keys alone.
    if bound_count >= 2:
        return "multimodal"
    for name in ("video", "audio", "image", "text"):
        if bound[name]:
            return name
    return "unknown"


def _dataset_source_priority_warning(source_count: int) -> str | None:
    if source_count <= 1:
        return None
    return (
        "multiple dataset sources are present; "
        "effective priority: generated_dataset_config > dataset (multi-source config) > dataset_path"
    )


def validate_dataset_spec_payload(
    dataset_spec: DatasetSpec | Dict[str, Any],
    *,
    dataset_profile: Dict[str, Any] | None = None,
) -> Tuple[List[str], List[str]]:
    """Validate dataset spec with our business rules + DJ parser.

    Checks, in order: at least one dataset source is configured; filesystem
    paths exist; modality/key bindings are consistent; bound keys appear in
    the inspected ``dataset_profile`` (when provided and ``ok``); multi-source
    ``dataset`` configs use implemented Data-Juicer load strategies; and
    finally the assembled dataset dict passes the DJ config-bridge parser.

    Args:
        dataset_spec: A ``DatasetSpec`` or its dict form (converted via
            ``DatasetSpec.from_dict``).
        dataset_profile: Optional inspection result; only consulted when it is
            a dict with a truthy ``"ok"`` and a list under ``"keys"``.

    Returns:
        ``(errors, warnings)`` lists of human-readable messages. Warnings
        start from ``dataset_spec.warnings`` and may gain the multi-source
        priority note.
    """
    if isinstance(dataset_spec, dict):
        dataset_spec = DatasetSpec.from_dict(dataset_spec)

    errors: List[str] = []
    warnings: List[str] = list(dataset_spec.warnings)
    io = dataset_spec.io
    binding = dataset_spec.binding

    # --- dataset source presence / priority -------------------------------
    source_count = (
        int(bool(io.generated_dataset_config))
        + int(bool(io.dataset_path))
        + int(bool(io.dataset))
    )
    source_warning = _dataset_source_priority_warning(source_count)
    if source_warning and source_warning not in warnings:
        warnings.append(source_warning)
    if source_count == 0:
        errors.append(
            "at least one dataset source is required: dataset_path, dataset, or generated_dataset_config"
        )

    # --- filesystem paths --------------------------------------------------
    if not io.export_path:
        errors.append("export_path is required")
    if io.dataset_path:
        dataset_path = Path(io.dataset_path).expanduser()
        if not dataset_path.exists():
            errors.append(f"dataset_path does not exist: {io.dataset_path}")
    if io.export_path:
        export_parent = Path(io.export_path).expanduser().resolve().parent
        if not export_parent.exists():
            errors.append(f"export parent directory does not exist: {export_parent}")

    # --- modality / key-binding consistency --------------------------------
    if binding.modality not in _ALLOWED_MODALITIES:
        errors.append("modality must be one of text/image/audio/video/multimodal/unknown")
    if binding.modality == "text" and not binding.text_keys:
        errors.append("text modality requires text_keys")
    if binding.modality == "image" and not binding.image_key:
        errors.append("image modality requires image_key")
    if binding.modality == "audio" and not binding.audio_key:
        errors.append("audio modality requires audio_key")
    if binding.modality == "video" and not binding.video_key:
        errors.append("video modality requires video_key")
    if binding.modality == "multimodal":
        active = sum(
            [
                bool(binding.text_keys),
                bool(binding.image_key),
                bool(binding.audio_key),
                bool(binding.video_key),
            ]
        )
        if active < 2:
            errors.append("multimodal modality requires at least two bound modalities")

    # --- bound keys must exist in the inspected dataset profile ------------
    if isinstance(dataset_profile, dict) and dataset_profile.get("ok"):
        known_keys = (
            set(dataset_profile.get("keys", []))
            if isinstance(dataset_profile.get("keys", []), list)
            else set()
        )
        for key in binding.text_keys:
            if known_keys and key not in known_keys:
                errors.append(f"text key not found in inspected dataset profile: {key}")
        for field_name, value in {
            "image_key": binding.image_key,
            "audio_key": binding.audio_key,
            "video_key": binding.video_key,
            "image_bytes_key": binding.image_bytes_key,
        }.items():
            if value and known_keys and value not in known_keys:
                errors.append(f"{field_name} not found in inspected dataset profile: {value}")

    # --- multi-source dataset config ---------------------------------------
    if io.dataset:
        if not io.dataset.configs:
            errors.append("dataset.configs must be a non-empty list")
        else:
            types = [c.type for c in io.dataset.configs]
            normalized_types = {str(t).strip() for t in types if str(t).strip()}
            if len(normalized_types) > 1:
                errors.append("mixture of different dataset source types is not supported")
            if normalized_types == {"remote"} and len(io.dataset.configs) > 1:
                errors.append("multiple remote datasets are not supported")

            # Validate against truly implemented strategies
            try:
                from data_juicer_agents.utils.dj_config_bridge import (
                    get_dj_config_bridge as _get_bridge,
                )

                _bridge = _get_bridge()
                _implemented = _bridge.get_implemented_load_strategies()
                if not _implemented:
                    # Strategy discovery returned nothing — Data-Juicer may be
                    # unavailable or registry introspection failed. Report a
                    # single clear error instead of misleading per-entry
                    # "not implemented" messages.
                    errors.append(
                        "Cannot validate dataset load strategies: Data-Juicer strategy "
                        "discovery returned no results. Ensure Data-Juicer is installed "
                        "and its load-strategy registry is accessible."
                    )
                else:
                    _valid_combos = {
                        (s["type"], s["source"]) for s in _implemented if s.get("source")
                    }
                    _valid_types = {s["type"] for s in _implemented}
                    for cfg in io.dataset.configs:
                        if cfg.source:
                            combo = (cfg.type, cfg.source)
                            if combo not in _valid_combos:
                                _available = [
                                    {"type": s["type"], "source": s["source"]}
                                    for s in _implemented
                                    if s.get("source")
                                ]
                                errors.append(
                                    f"Dataset source type='{cfg.type}', source='{cfg.source}' is not implemented "
                                    f"in the current Data-Juicer installation. "
                                    f"Call list_dataset_load_strategies to see available options: {_available}"
                                )
                        elif cfg.type not in _valid_types:
                            errors.append(
                                f"Dataset type='{cfg.type}' is not implemented in the current Data-Juicer installation. "
                                f"Available types: {sorted(_valid_types)}"
                            )
            except Exception:
                pass  # best-effort: strategy discovery is advisory only

    if io.generated_dataset_config:
        if not io.generated_dataset_config.type:
            errors.append('generated_dataset_config must have a non-empty "type" field')

    # DJ parser validation for dataset fields
    try:
        from data_juicer_agents.utils.dj_config_bridge import get_dj_config_bridge

        bridge = get_dj_config_bridge()
        dataset_dict: Dict[str, Any] = {}
        if io.dataset_path:
            dataset_dict["dataset_path"] = io.dataset_path
        if io.export_path:
            dataset_dict["export_path"] = io.export_path
        if binding.text_keys:
            dataset_dict["text_keys"] = list(binding.text_keys)
        if binding.image_key:
            dataset_dict["image_key"] = binding.image_key
        if binding.audio_key:
            dataset_dict["audio_key"] = binding.audio_key
        if binding.video_key:
            dataset_dict["video_key"] = binding.video_key
        if binding.image_bytes_key:
            dataset_dict["image_bytes_key"] = binding.image_bytes_key
        if io.dataset:
            dataset_dict["dataset"] = io.dataset.to_dict()
        if io.generated_dataset_config:
            dataset_dict["generated_dataset_config"] = io.generated_dataset_config.to_dict()
        # Merge extra dataset fields (export_type, export_shard_size, suffixes, etc.)
        # so that the DJ parser can validate them as well.
        if io._extra_fields:
            dataset_dict.update(io._extra_fields)
        if dataset_dict:
            is_valid, dj_errors = bridge.validate(dataset_dict)
            if not is_valid:
                errors.extend(dj_errors)
    except Exception:
        pass  # best-effort: parser validation is skipped when the bridge is unavailable

    return errors, warnings
# Public API of this helper module.
__all__ = [
    "infer_modality",
    "normalize_dataset_spec",
    "validate_dataset_spec_payload",
]