Source code for data_juicer_agents.tools.plan._shared.dataset_spec
# -*- coding: utf-8 -*-
"""Shared dataset-spec helpers for plan tools."""
from __future__ import annotations
from pathlib import Path
from typing import Any, Dict, Iterable, List, Tuple
from .._shared.schema import DatasetBindingSpec, DatasetSpec, _ALLOWED_MODALITIES
def _normalize_string_list(values: Iterable[Any] | None) -> List[str]:
normalized: List[str] = []
seen = set()
for item in values or []:
text = str(item or "").strip()
if not text or text in seen:
continue
normalized.append(text)
seen.add(text)
return normalized
def infer_modality(binding: DatasetBindingSpec) -> str:
candidate = str(binding.modality or "unknown").strip().lower()
has_text = bool(binding.text_keys)
has_image = bool(binding.image_key)
has_audio = bool(binding.audio_key)
has_video = bool(binding.video_key)
if candidate in {"text", "image", "audio", "video", "multimodal"}:
if candidate == "text" and has_text:
return candidate
if candidate == "image" and has_image:
return candidate
if candidate == "audio" and has_audio:
return candidate
if candidate == "video" and has_video:
return candidate
if candidate == "multimodal" and sum([has_text, has_image, has_audio, has_video]) >= 2:
return candidate
active_modalities = sum([has_text, has_image, has_audio, has_video])
if active_modalities >= 2:
return "multimodal"
if has_video:
return "video"
if has_audio:
return "audio"
if has_image:
return "image"
if has_text:
return "text"
return "unknown"
def _dataset_source_priority_warning(source_count: int) -> str | None:
if source_count <= 1:
return None
return (
"multiple dataset sources are present; current implementation is local-path-first and will "
"follow Data-Juicer source priority generated_dataset_config > dataset_path > dataset"
)
[docs]
def validate_dataset_spec_payload(
dataset_spec: DatasetSpec | Dict[str, Any],
*,
dataset_profile: Dict[str, Any] | None = None,
) -> Tuple[List[str], List[str]]:
if isinstance(dataset_spec, dict):
dataset_spec = DatasetSpec.from_dict(dataset_spec)
errors: List[str] = []
warnings: List[str] = list(dataset_spec.warnings)
io = dataset_spec.io
binding = dataset_spec.binding
source_count = int(bool(io.generated_dataset_config)) + int(bool(io.dataset_path)) + int(bool(io.dataset))
source_warning = _dataset_source_priority_warning(source_count)
if source_warning and source_warning not in warnings:
warnings.append(source_warning)
if io.dataset and not io.dataset_path:
errors.append("dataset source objects are reserved for a later iteration; use dataset_path for now")
if io.generated_dataset_config and not io.dataset_path:
errors.append(
"generated_dataset_config is reserved for a later iteration; use dataset_path for now"
)
if not io.dataset_path:
errors.append("dataset_path is required in this iteration")
if not io.export_path:
errors.append("export_path is required")
if io.dataset_path:
dataset_path = Path(io.dataset_path).expanduser()
if not dataset_path.exists():
errors.append(f"dataset_path does not exist: {io.dataset_path}")
if io.export_path:
export_parent = Path(io.export_path).expanduser().resolve().parent
if not export_parent.exists():
errors.append(f"export parent directory does not exist: {export_parent}")
if binding.modality not in _ALLOWED_MODALITIES:
errors.append("modality must be one of text/image/audio/video/multimodal/unknown")
if binding.modality == "text" and not binding.text_keys:
errors.append("text modality requires text_keys")
if binding.modality == "image" and not binding.image_key:
errors.append("image modality requires image_key")
if binding.modality == "audio" and not binding.audio_key:
errors.append("audio modality requires audio_key")
if binding.modality == "video" and not binding.video_key:
errors.append("video modality requires video_key")
if binding.modality == "multimodal":
active = sum([bool(binding.text_keys), bool(binding.image_key), bool(binding.audio_key), bool(binding.video_key)])
if active < 2:
errors.append("multimodal modality requires at least two bound modalities")
if isinstance(dataset_profile, dict) and dataset_profile.get("ok"):
known_keys = set(dataset_profile.get("keys", [])) if isinstance(dataset_profile.get("keys", []), list) else set()
for key in binding.text_keys:
if known_keys and key not in known_keys:
errors.append(f"text key not found in inspected dataset profile: {key}")
for field_name, value in {
"image_key": binding.image_key,
"audio_key": binding.audio_key,
"video_key": binding.video_key,
"image_bytes_key": binding.image_bytes_key,
}.items():
if value and known_keys and value not in known_keys:
errors.append(f"{field_name} not found in inspected dataset profile: {value}")
dataset_cfg = io.dataset or {}
if isinstance(dataset_cfg, dict) and isinstance(dataset_cfg.get("configs"), list):
types = [item.get("type") for item in dataset_cfg.get("configs", []) if isinstance(item, dict)]
normalized_types = {str(item).strip() for item in types if str(item).strip()}
if len(normalized_types) > 1:
errors.append("mixture of different dataset source types is not supported")
if normalized_types == {"remote"} and len(dataset_cfg.get("configs", [])) > 1:
errors.append("multiple remote datasets are not supported")
return errors, warnings
__all__ = ["infer_modality", "validate_dataset_spec_payload"]