# -*- coding: utf-8 -*-
"""Lightweight dataset probing utilities for planning-time schema inference."""
from __future__ import annotations
import csv
import json
from pathlib import Path
from typing import Any, Dict, List, Tuple
_IMAGE_SUFFIXES = (
".jpg",
".jpeg",
".png",
".webp",
".gif",
".bmp",
".tiff",
".svg",
)
def _looks_like_image_value(value: str) -> bool:
lower = value.strip().lower()
if lower.startswith(("http://", "https://")):
return any(lower.split("?")[0].endswith(suf) for suf in _IMAGE_SUFFIXES)
if "/" in lower or "\\" in lower:
return any(lower.endswith(suf) for suf in _IMAGE_SUFFIXES)
return any(lower.endswith(suf) for suf in _IMAGE_SUFFIXES)
def _value_kind(value: Any) -> str:
if value is None:
return "null"
if isinstance(value, bool):
return "bool"
if isinstance(value, (int, float)):
return "number"
if isinstance(value, str):
if _looks_like_image_value(value):
return "image_ref"
return "text"
if isinstance(value, list):
return "array"
if isinstance(value, dict):
return "object"
return "other"
def _load_jsonl_records(path: Path, sample_size: int) -> Tuple[List[Dict[str, Any]], int]:
rows: List[Dict[str, Any]] = []
scanned = 0
with open(path, "r", encoding="utf-8") as f:
for line in f:
if len(rows) >= sample_size:
break
line = line.strip()
if not line:
continue
scanned += 1
try:
obj = json.loads(line)
except Exception:
continue
if isinstance(obj, dict):
rows.append(obj)
return rows, scanned
def _load_json_records(path: Path, sample_size: int) -> Tuple[List[Dict[str, Any]], int]:
with open(path, "r", encoding="utf-8") as f:
content = json.load(f)
if isinstance(content, list):
dict_rows = [item for item in content if isinstance(item, dict)]
return dict_rows[:sample_size], min(len(dict_rows), sample_size)
if isinstance(content, dict):
return [content], 1
return [], 0
def _load_csv_records(
path: Path, sample_size: int, delimiter: str = ","
) -> Tuple[List[Dict[str, Any]], int]:
rows: List[Dict[str, Any]] = []
try:
with open(path, "r", encoding="utf-8", newline="") as f:
reader = csv.DictReader(f, delimiter=delimiter)
# Check if headers are present
if reader.fieldnames is None:
return [], 0
for row in reader:
if len(rows) >= sample_size:
break
# Filter out None values that occur when rows have fewer fields than headers
cleaned_row = {k: v for k, v in row.items() if v is not None}
rows.append(cleaned_row)
except (UnicodeDecodeError, csv.Error, IOError) as e:
return [], 0
return rows, len(rows)
def _load_parquet_records(path: Path, sample_size: int) -> Tuple[List[Dict[str, Any]], int]:
try:
import pyarrow.parquet as pq
except ImportError:
return [], 0
parquet_file = pq.ParquetFile(str(path))
# Read only the first batch of up to sample_size rows instead of
# materializing the entire file before slicing.
batch_iter = parquet_file.iter_batches(batch_size=sample_size)
try:
first_batch = next(batch_iter)
except StopIteration:
return [], 0
rows = first_batch.to_pydict()
num_records = len(next(iter(rows.values()))) if rows else 0
records: List[Dict[str, Any]] = [
{col: values[i] for col, values in rows.items()}
for i in range(num_records)
]
return records, num_records
_UNSUPPORTED_PREFIXES = (
"hf://",
"huggingface://",
"s3://",
"gs://",
"az://",
"hdfs://",
"http://",
"https://",
)
def _looks_like_unsupported_source(dataset_path: str) -> bool:
lower = dataset_path.strip().lower()
return any(lower.startswith(prefix) for prefix in _UNSUPPORTED_PREFIXES)
def _dataset_path_to_dataset(dataset_path: str) -> Dict[str, Any]:
"""Convert a plain dataset_path string to the standard dataset config format."""
return {
"configs": [
{"type": "local", "path": dataset_path},
],
}
def _resolve_dataset_config(
dataset_path: str,
dataset: Dict[str, Any] | None,
) -> Dict[str, Any]:
"""Convert a single dataset source into the standard dataset config format.
Exactly one of *dataset_path* or *dataset* must be provided; passing both
is an error because the caller should have already enforced the single-source
constraint before reaching this point.
Raises:
ValueError: When both dataset_path and dataset are provided simultaneously.
"""
if dataset and isinstance(dataset, dict) and dataset_path:
raise ValueError(
"Only one dataset source can be specified at a time: "
"pass either dataset_path or dataset, not both."
)
if dataset and isinstance(dataset, dict):
return dict(dataset)
if dataset_path:
return _dataset_path_to_dataset(dataset_path)
return {"configs": []}
def _pick_inspectable_path(dataset_config: Dict[str, Any]) -> str | None:
"""Return the first local file path from a dataset config that can be inspected."""
configs = dataset_config.get("configs", [])
if not isinstance(configs, list):
return None
for cfg in configs:
if not isinstance(cfg, dict):
continue
source_type = str(cfg.get("type", "local")).strip().lower()
path_value = str(cfg.get("path", "")).strip()
if not path_value:
continue
if source_type in {"local", ""}:
if not _looks_like_unsupported_source(path_value):
return path_value
return None
[docs]
def inspect_dataset_schema(
dataset_source=None,
sample_size: int = 20,
) -> Dict[str, Any]:
"""Inspect a small sample of a dataset and infer keys/modality for planning.
Accepts a DatasetSource object that encapsulates the dataset path and config.
When dataset_source is None, returns a friendly error dict instead of raising.
"""
if dataset_source is None:
return {
"ok": False,
"error_type": "missing_dataset_source",
"error": "No dataset source provided. Pass a DatasetSource with path or config.",
"message": "No dataset source provided.",
}
legacy = dataset_source.to_legacy_args()
dataset_path = legacy["dataset_path"]
dataset = legacy["dataset"]
generated_dataset_config = legacy["generated_dataset_config"]
if generated_dataset_config:
return {
"ok": False,
"error_type": "unsupported_input_source",
"error": (
"inspect_dataset does not support generated dataset sources. "
"Generated datasets are created dynamically at runtime and "
"cannot be inspected ahead of time."
),
"message": (
"Cannot inspect a generated dataset source. "
"Schema probing is only supported for path or config sources."
),
}
resolved_config = _resolve_dataset_config(dataset_path, dataset)
inspectable_path = _pick_inspectable_path(resolved_config)
if inspectable_path is None:
# No local path available to inspect
configs = resolved_config.get("configs", [])
if configs:
return {
"ok": False,
"error_type": "unsupported_input_source",
"error": (
"inspect_dataset could not find a local file path to inspect "
"in the provided dataset config. Only local file-based sources "
"are supported for schema inspection."
),
"message": (
"No inspectable local path found in dataset config. "
"Ensure at least one source has type='local' with a valid file path."
),
"dataset": resolved_config,
}
return {
"ok": False,
"error_type": "missing_dataset_source",
"error": "No dataset source provided. Pass dataset_path or dataset config.",
"message": "No dataset source provided.",
}
path = Path(inspectable_path)
if not path.exists():
return {
"ok": False,
"error_type": "dataset_path_not_found",
"error": f"dataset_path does not exist: {inspectable_path}",
"message": f"dataset_path does not exist: {inspectable_path}",
"dataset": resolved_config,
}
if path.is_dir():
# Directory paths cannot be sampled directly; skip schema probing.
# Data-Juicer will auto-detect the format at runtime.
# TODO: support directory sampling by picking a representative file from the directory
# (e.g. first file matching a known suffix) so that modality and key hints
# can still be inferred when dataset_path points to a directory.
return {
"ok": True,
"message": "dataset is a directory; schema probing skipped",
"dataset": resolved_config,
"inspected_path": inspectable_path,
"sampled_records": 0,
"scanned_lines": 0,
"modality": "unknown",
"keys": [],
"candidate_text_keys": [],
"candidate_image_keys": [],
"key_stats": {},
"sample_preview": [],
}
if sample_size <= 0:
sample_size = 20
rows: List[Dict[str, Any]]
scanned: int
suffix = path.suffix.lower()
if suffix == ".json":
rows, scanned = _load_json_records(path, sample_size=sample_size)
elif suffix == ".csv":
rows, scanned = _load_csv_records(path, sample_size=sample_size)
elif suffix == ".tsv":
rows, scanned = _load_csv_records(path, sample_size=sample_size, delimiter="\t")
elif suffix == ".parquet":
rows, scanned = _load_parquet_records(path, sample_size=sample_size)
else:
rows, scanned = _load_jsonl_records(path, sample_size=sample_size)
if not rows:
return {
"ok": False,
"error_type": "inspect_failed",
"error": (
f"Failed to load any valid records from the dataset. "
f"This could be due to: (1) file format/encoding issues, "
f"(2) malformed data structure, or (3) empty dataset. "
f"Scanned {scanned} lines but found no valid dict records."
),
"message": (
f"Could not extract valid records from '{inspectable_path}'. "
f"Try inspecting the file manually with shell commands like: "
f"'head -n 3 {inspectable_path}' or 'cat {inspectable_path} | head -n 3' "
f"to verify the file format and content structure."
),
"dataset": resolved_config,
"sampled_records": 0,
"scanned_lines": scanned,
}
key_stats: Dict[str, Dict[str, Any]] = {}
for row in rows:
for key, value in row.items():
stat = key_stats.setdefault(
key,
{
"count": 0,
"kinds": {},
"avg_text_len": 0.0,
},
)
stat["count"] += 1
kind = _value_kind(value)
stat["kinds"][kind] = int(stat["kinds"].get(kind, 0)) + 1
if kind == "text":
prev_avg = float(stat["avg_text_len"])
text_count = int(stat["kinds"]["text"])
new_len = len(str(value))
stat["avg_text_len"] = prev_avg + (new_len - prev_avg) / text_count
def text_score(item: Tuple[str, Dict[str, Any]]) -> float:
key, stat = item
kinds = stat["kinds"]
text_cnt = int(kinds.get("text", 0))
if text_cnt <= 0:
return -1.0
key_bonus = 0.0
if any(h in key.lower() for h in ["text", "content", "doc", "sentence", "chunk"]):
key_bonus += 1.0
return text_cnt + min(float(stat.get("avg_text_len", 0.0)) / 80.0, 2.0) + key_bonus
def image_score(item: Tuple[str, Dict[str, Any]]) -> float:
key, stat = item
kinds = stat["kinds"]
image_cnt = int(kinds.get("image_ref", 0))
if image_cnt <= 0:
return -1.0
key_bonus = 0.0
if any(h in key.lower() for h in ["image", "img", "picture", "photo", "vision"]):
key_bonus += 1.0
return image_cnt + key_bonus
ranked_text = sorted(key_stats.items(), key=text_score, reverse=True)
ranked_image = sorted(key_stats.items(), key=image_score, reverse=True)
candidate_text_keys = [k for k, v in ranked_text if text_score((k, v)) > 0][:3]
candidate_image_keys = [k for k, v in ranked_image if image_score((k, v)) > 0][:3]
if candidate_text_keys and candidate_image_keys:
modality = "multimodal"
elif candidate_image_keys:
modality = "image"
elif candidate_text_keys:
modality = "text"
else:
modality = "unknown"
# Keep sample preview short and safe.
preview: List[Dict[str, Any]] = []
for row in rows[:2]:
one: Dict[str, Any] = {}
for k, v in row.items():
if isinstance(v, str) and len(v) > 120:
one[k] = v[:117] + "..."
else:
one[k] = v
preview.append(one)
return {
"ok": True,
"message": "dataset inspected",
"dataset": resolved_config,
"inspected_path": inspectable_path,
"sampled_records": len(rows),
"scanned_lines": scanned,
"modality": modality,
"keys": sorted(key_stats.keys()),
"candidate_text_keys": candidate_text_keys,
"candidate_image_keys": candidate_image_keys,
"key_stats": key_stats,
"sample_preview": preview,
}