Source code for data_juicer.core.data.ray_dataset

from __future__ import annotations

import os
from functools import partial
from typing import Any, Dict, List, Literal, Optional, Union

import pyarrow
import ray
from jsonargparse import Namespace
from loguru import logger
from ray.data._internal.util import get_compute_strategy

from data_juicer.core.data import DJDataset
from data_juicer.core.data.schema import Schema
from data_juicer.core.tracer import should_trace_op
from data_juicer.ops import Deduplicator, Filter, Mapper, Pipeline
from data_juicer.ops.base_op import DEFAULT_BATCH_SIZE, TAGGING_OPS
from data_juicer.utils.constant import Fields
from data_juicer.utils.file_utils import is_remote_path
from data_juicer.utils.resource_utils import cuda_device_count
from data_juicer.utils.webdataset_utils import _custom_default_decoder


[docs] def get_abs_path(path, dataset_dir): if is_remote_path(path): return path full_path = os.path.abspath(os.path.join(dataset_dir, path)) if os.path.exists(full_path): return full_path else: return path
[docs] def convert_to_absolute_paths(samples: pyarrow.Table, dataset_dir, path_keys): for key in path_keys: col_idx = samples.schema.get_field_index(key) cols = samples.column(col_idx) def _process_paths(): for col in cols: path = col.as_py() if isinstance(path, str): yield get_abs_path(path, dataset_dir) elif isinstance(path, list): yield [get_abs_path(p, dataset_dir) for p in path] else: yield path samples = samples.set_column(col_idx, key, pyarrow.array(_process_paths())) return samples
# TODO: check path for nestdataset
[docs] def set_dataset_to_absolute_path(dataset, dataset_path, cfg): """ Set all the path in input data to absolute path. Checks dataset_dir and project_dir for valid paths. """ path_keys = [] columns = dataset.columns() for key in [ cfg.get("video_key", "videos"), cfg.get("image_key", "images"), cfg.get("audio_key", "audios"), ]: if key in columns: path_keys.append(key) if len(path_keys) > 0: dataset_dir = os.path.dirname(dataset_path) logger.info(f"dataset_dir: {dataset_dir}") dataset = dataset.map_batches( partial(convert_to_absolute_paths, dataset_dir=dataset_dir, path_keys=path_keys), batch_format="pyarrow", zero_copy_batch=True, batch_size=DEFAULT_BATCH_SIZE, ) return dataset
[docs] def preprocess_dataset(dataset: ray.data.Dataset, dataset_path, cfg) -> ray.data.Dataset: if dataset_path: dataset = set_dataset_to_absolute_path(dataset, dataset_path, cfg) return dataset
[docs] def get_num_gpus(op, op_proc): if not op.use_cuda(): return 0 proc_per_gpu = op_proc / cuda_device_count() return 1.0 / proc_per_gpu
[docs] def filter_batch(batch, filter_func): mask = pyarrow.array(filter_func(batch.to_pydict())) return batch.filter(mask)
[docs] class RayDataset(DJDataset):
[docs] def __init__( self, dataset: ray.data.Dataset, dataset_path: str = None, cfg: Optional[Namespace] = None, auto_op_parallelism=True, ) -> None: self.data = preprocess_dataset(dataset, dataset_path, cfg) # if auto_op_parallelism is set in both args and cfg, cfg takes precedence if cfg and cfg.get("auto_op_parallelism") is not None: self._auto_proc = cfg.get("auto_op_parallelism") else: self._auto_proc = auto_op_parallelism
[docs] def schema(self) -> Schema: """Get dataset schema. Returns: Schema: Dataset schema containing column names and types """ if self.data is None or self.data.columns() is None: raise ValueError("Dataset is empty or not initialized") return Schema.from_ray_schema(self.data.schema())
[docs] def get(self, k: int) -> List[Dict[str, Any]]: """Get k rows from the dataset.""" if k < 0: raise ValueError(f"k must be non-negative, got {k}") if k == 0: return [] k = min(k, self.data.count()) return list(self.data.limit(k).take())
[docs] def get_column(self, column: str, k: Optional[int] = None) -> List[Any]: """Get column values from Ray dataset. Args: column: Name of the column to retrieve k: Optional number of rows to return. If None, returns all rows Returns: List of values from the specified column Raises: KeyError: If column doesn't exist ValueError: If k is negative """ if self.data is None or self.data.columns() is None or column not in self.data.columns(): raise KeyError(f"Column '{column}' not found in dataset") if k is not None: if k < 0: raise ValueError(f"k must be non-negative, got {k}") if k == 0: return [] k = min(k, self.data.count()) return [row[column] for row in self.data.limit(k).take()] return [row[column] for row in self.data.take()]
[docs] def process(self, operators, *, exporter=None, checkpointer=None, tracer=None) -> DJDataset: if operators is None: return self if not isinstance(operators, list): operators = [operators] from data_juicer.utils.process_utils import calculate_ray_np if self._auto_proc: calculate_ray_np(operators) # Check if dataset is empty - Ray returns None for columns() on empty datasets # with unknown schema. If empty, skip processing as there's nothing to process. try: row_count = self.data.count() except Exception: row_count = 0 if row_count == 0: from loguru import logger logger.warning("Dataset is empty (0 rows), skipping operator processing") return self # Cache columns once at start to avoid breaking pipeline with repeated columns() calls # Ray's columns() internally does limit(1) which forces execution and breaks streaming columns_result = self.data.columns() # Handle empty dataset case where columns() returns None if columns_result is None: from loguru import logger logger.warning("Dataset has unknown schema (likely empty), skipping operator processing") return self cached_columns = set(columns_result) for op in operators: cached_columns = self._run_single_op(op, cached_columns, tracer=tracer) return self
def _run_single_op(self, op, cached_columns=None, tracer=None): # Use cached columns to avoid calling self.data.columns() which breaks pipeline if cached_columns is None: cached_columns = set(self.data.columns()) if op._name in TAGGING_OPS.modules and Fields.meta not in cached_columns: def process_batch_arrow(table: pyarrow.Table): new_column_data = [{} for _ in range(len(table))] new_table = table.append_column(Fields.meta, [new_column_data]) return new_table self.data = self.data.map_batches( process_batch_arrow, batch_format="pyarrow", batch_size=DEFAULT_BATCH_SIZE ) cached_columns.add(Fields.meta) try: batch_size = getattr(op, "batch_size", 1) if op.is_batched_op() else 1 if isinstance(op, Mapper): # Wrap process method with tracer for sample-level collection original_process = None if tracer and should_trace_op(tracer, op._name): from data_juicer.ops.base_op import wrap_mapper_with_tracer original_process = op.process op.process = wrap_mapper_with_tracer(original_process, op._name, op.text_key, tracer, True) try: if op.use_ray_actor(): compute = get_compute_strategy(op.__class__, concurrency=op.num_proc) self.data = self.data.map_batches( op.__class__, fn_args=None, fn_kwargs=None, fn_constructor_args=op._init_args, fn_constructor_kwargs=op._init_kwargs, batch_size=batch_size, num_cpus=op.num_cpus, num_gpus=op.num_gpus, compute=compute, batch_format="pyarrow", runtime_env=op.runtime_env, ) else: compute = get_compute_strategy(op.process, concurrency=op.num_proc) self.data = self.data.map_batches( op.process, batch_size=batch_size, batch_format="pyarrow", num_cpus=op.num_cpus, num_gpus=op.num_gpus, compute=compute, runtime_env=op.runtime_env, ) finally: # Restore original process method if tracer and should_trace_op(tracer, op._name) and original_process: op.process = original_process elif isinstance(op, Filter): # Use cached_columns instead of self.data.columns() to avoid breaking pipeline if Fields.stats not in cached_columns: def process_batch_arrow(table: pyarrow.Table): new_column_data = [{} for _ in range(len(table))] new_talbe = table.append_column(Fields.stats, [new_column_data]) return new_talbe self.data = self.data.map_batches( process_batch_arrow, batch_format="pyarrow", batch_size=DEFAULT_BATCH_SIZE ) cached_columns.add(Fields.stats) if op.use_ray_actor(): compute = get_compute_strategy(op.__class__, concurrency=op.num_proc) self.data = self.data.map_batches( op.__class__, fn_args=None, fn_kwargs=None, fn_constructor_args=op._init_args, fn_constructor_kwargs=op._init_kwargs, batch_size=batch_size, num_cpus=op.num_cpus, num_gpus=op.num_gpus, compute=compute, batch_format="pyarrow", runtime_env=op.runtime_env, ) else: compute = get_compute_strategy(op.compute_stats, concurrency=op.num_proc) self.data = self.data.map_batches( op.compute_stats, batch_size=batch_size, batch_format="pyarrow", num_cpus=op.num_cpus, num_gpus=op.num_gpus, compute=compute, runtime_env=op.runtime_env, ) if op.stats_export_path is not None: self.data.write_json(op.stats_export_path, force_ascii=False) # Wrap process method with tracer for sample-level collection original_process = None if tracer and should_trace_op(tracer, op._name): from data_juicer.ops.base_op import wrap_filter_with_tracer original_process = op.process op.process = wrap_filter_with_tracer(original_process, op._name, tracer, op.is_batched_op()) try: if op.is_batched_op(): # The core computation have been done in compute_stats, # and the filter process only performs simple filtering. # cpu and parallelism are not set here self.data = self.data.map_batches( partial(filter_batch, filter_func=op.process), batch_format="pyarrow", zero_copy_batch=True, batch_size=DEFAULT_BATCH_SIZE, runtime_env=op.runtime_env, ) else: self.data = self.data.filter( op.process, runtime_env=op.runtime_env, ) finally: # Restore original process method if tracer and should_trace_op(tracer, op._name) and original_process: op.process = original_process elif isinstance(op, (Deduplicator, Pipeline)): self.data = op.run(self.data) else: logger.error("Ray executor only support Filter, Mapper, Deduplicator and Pipeline OPs for now") raise NotImplementedError except: # noqa: E722 logger.error(f"An error occurred during Op [{op._name}].") import traceback traceback.print_exc() exit(1) return cached_columns
[docs] def count(self) -> int: return self.data.count()
[docs] @classmethod def read(cls, data_format: str, paths: Union[str, List[str]]) -> RayDataset: if data_format in {"json", "jsonl"}: return RayDataset.read_json(paths) elif data_format == "webdataset": return RayDataset.read_webdataset(paths) elif data_format in { "parquet", "images", "parquet_bulk", "csv", "text", "avro", "numpy", "tfrecords", "binary_files", "lance", }: return getattr(ray.data, f"read_{data_format}")(paths)
[docs] @classmethod def read_json(cls, paths: Union[str, List[str]]) -> RayDataset: # Note: a temp solution for reading json stream # TODO: replace with ray.data.read_json_stream once it is available import pyarrow.json as js try: js.open_json return read_json_stream(paths) except AttributeError: return ray.data.read_json(paths)
[docs] @classmethod def read_webdataset(cls, paths: Union[str, List[str]]) -> RayDataset: return ray.data.read_webdataset(paths, decoder=partial(_custom_default_decoder, format="PIL"))
[docs] def to_list(self) -> list: return self.data.to_pandas().to_dict(orient="records")
[docs] class JSONStreamDatasource(ray.data.read_api.ArrowJSONDatasource): """ A temp Datasource for reading json stream. Note: Depends on a customized `pyarrow` with `open_json` method. """ def _read_stream(self, f: "pyarrow.NativeFile", path: str): # Check if open_json is available (PyArrow 20.0.0+) try: from pyarrow.json import open_json except ImportError: # Fall back to read_json for older PyArrow versions # This will read the entire file into memory, but works with older PyArrow import pyarrow.json as js try: # Read the entire file as a table table = js.read_json(f, **self.arrow_json_args) if table.num_rows > 0: yield table except Exception as e: raise ValueError(f"Failed to read JSON file: {path}. Error: {e}") from e return try: reader = open_json( f, read_options=self.read_options, **self.arrow_json_args, ) schema = None while True: try: batch = reader.read_next_batch() table = pyarrow.Table.from_batches([batch], schema=schema) if schema is None: schema = table.schema yield table except StopIteration: return except pyarrow.lib.ArrowInvalid as e: raise ValueError(f"Failed to read JSON file: {path}.") from e
[docs] def read_json_stream( paths: Union[str, List[str]], *, filesystem: Optional["pyarrow.fs.FileSystem"] = None, parallelism: int = -1, ray_remote_args: Dict[str, Any] = None, arrow_open_stream_args: Optional[Dict[str, Any]] = None, meta_provider=None, partition_filter=None, partitioning=ray.data.read_api.Partitioning("hive"), include_paths: bool = False, ignore_missing_paths: bool = False, shuffle: Union[Literal["files"], None] = None, file_extensions: Optional[List[str]] = ["json", "jsonl"], concurrency: Optional[int] = None, override_num_blocks: Optional[int] = None, **arrow_json_args, ) -> ray.data.Dataset: # Check if open_json is available (PyArrow 20.0.0+) # If not, fall back to ray.data.read_json which works with older PyArrow try: import pyarrow.json as js js.open_json # Check if attribute exists except (ImportError, AttributeError): # Fall back to standard ray.data.read_json for older PyArrow versions # This works with filesystem parameter for S3 return ray.data.read_json(paths, filesystem=filesystem) if meta_provider is None: meta_provider = ray.data.read_api.DefaultFileMetadataProvider() datasource = JSONStreamDatasource( paths, arrow_json_args=arrow_json_args, filesystem=filesystem, open_stream_args=arrow_open_stream_args, meta_provider=meta_provider, partition_filter=partition_filter, partitioning=partitioning, ignore_missing_paths=ignore_missing_paths, shuffle=shuffle, include_paths=include_paths, file_extensions=file_extensions, ) return ray.data.read_datasource( datasource, parallelism=parallelism, ray_remote_args=ray_remote_args, concurrency=concurrency, override_num_blocks=override_num_blocks, )