Source code for data_juicer.core.tracer.tracer

import os
from collections import defaultdict
from multiprocessing import Lock

import pandas as pd
from datasets import Dataset
from loguru import logger

from data_juicer.ops import OPERATORS
from data_juicer.utils.common_utils import deprecated
from data_juicer.utils.constant import Fields


[docs] class Tracer: """ The tracer to trace the sample changes before and after an operator process. The comparison results will be stored in the work directory. Now supports sample-level tracing for better efficiency and accuracy. """
[docs] def __init__(self, work_dir, op_list_to_trace=None, show_num=10, trace_keys=None, lock=None): """ Initialization method. :param work_dir: the work directory to store the comparison results :param op_list_to_trace: the OP list to be traced. :param show_num: the maximum number of samples to show in the comparison result files. :param trace_keys: list of field names to include in trace output. If set, the specified fields' values will be included in each trace entry. """ self.work_dir = os.path.join(work_dir, "trace") if not os.path.exists(self.work_dir): os.makedirs(self.work_dir) # clear existing trace files in the work_dir for f in os.listdir(self.work_dir): os.remove(os.path.join(self.work_dir, f)) self.op_list_to_trace = op_list_to_trace if not op_list_to_trace: logger.info("Trace for all ops.") self.op_list_to_trace = set(OPERATORS.modules.keys()) else: self.op_list_to_trace = set(op_list_to_trace) self.show_num = show_num self.trace_keys = trace_keys or [] # Sample-level tracing storage: op_name -> list of trace entries self._sample_traces = defaultdict(list) # Thread lock for thread-safe sample collection (only used in non-Ray mode) # In Ray mode, each worker will have its own tracer instance self._lock = lock or Lock() # Counter for each op to track how many samples have been collected self._collected_counts = defaultdict(int)
[docs] def should_trace_op(self, op_name: str) -> bool: """ Check if an operator should be traced. :param op_name: the operator name :return: True if the operator should be traced """ return op_name in self.op_list_to_trace
[docs] def is_collection_complete(self, op_name: str) -> bool: """ Check if enough samples have been collected for an operator. :param op_name: the operator name :return: True if enough samples have been collected """ with self._lock: return self._collected_counts[op_name] >= self.show_num
[docs] def collect_mapper_sample(self, op_name: str, original_sample: dict, processed_sample: dict, text_key: str): """ Collect a sample-level change for a Mapper operator. This method is thread-safe and will only collect up to show_num samples. :param op_name: the operator name :param original_sample: the original sample before processing :param processed_sample: the processed sample after processing :param text_key: the key name of the text field to compare :return: True if the sample was collected, False if collection is complete """ if not self.should_trace_op(op_name): return False # Check if sample has changed original_text = original_sample.get(text_key, "") processed_text = processed_sample.get(text_key, "") if original_text == processed_text: return False with self._lock: # Double-check after acquiring lock if self._collected_counts[op_name] >= self.show_num: return False entry = {} # Add specified fields first (appears at start of output) for key in self.trace_keys: entry[key] = original_sample.get(key) # Add trace data entry["original_text"] = original_text entry["processed_text"] = processed_text logger.debug(f"Trace the entry in mapper [{op_name}]: {entry}") self._collected_counts[op_name] += 1 with open(self.get_trace_file_path(op_name), "a") as f: entry_str = pd.DataFrame([entry]).to_json(orient="records", lines=True, force_ascii=False) f.write(entry_str) f.flush() return True
[docs] def collect_filter_sample(self, op_name: str, sample: dict, should_keep: bool): """ Collect a sample-level change for a Filter operator. This method is thread-safe and will only collect up to show_num samples. Only collects samples that are filtered out (should_keep=False). :param op_name: the operator name :param sample: the sample being filtered :param should_keep: True if the sample should be kept, False if filtered :return: True if the sample was collected, False if collection is complete """ if not self.should_trace_op(op_name): return False # Only collect filtered samples (should_keep=False) if should_keep: return False with self._lock: # Double-check after acquiring lock if self._collected_counts[op_name] >= self.show_num: return False logger.debug(f"Trace the sample in filter [{op_name}]: {sample}") self._collected_counts[op_name] += 1 with open(self.get_trace_file_path(op_name), "a") as f: entry_str = pd.DataFrame([sample]).to_json(orient="records", lines=True, force_ascii=False) f.write(entry_str) f.flush() return True
[docs] def get_trace_file_path(self, op_name: str) -> str: """ Get the file path for a trace file. :param op_name: the operator name :return: the file path """ return os.path.join(self.work_dir, f"sample_trace-{op_name}.jsonl")
[docs] @deprecated("This method will be deprecated in the future. Please apply the sample-level tracing method instead.") def trace_mapper(self, op_name: str, previous_ds: Dataset, processed_ds: Dataset, text_key: str): """ Compare datasets before and after a Mapper. This will mainly show the different sample pairs due to the modification by the Mapper :param op_name: the op name of mapper :param previous_ds: dataset before the mapper process :param processed_ds: dataset processed by the mapper :param text_key: which text_key to trace :return: """ if op_name not in self.op_list_to_trace: return assert len(previous_ds) == len(processed_ds) dif_dict = [] num = 0 # Find different samples orderly between previous and processed # datasets until the total number of found sample pairs is enough. for i in range(len(previous_ds)): previous_sample = previous_ds[i][text_key] processed_sample = processed_ds[i][text_key] if previous_sample != processed_sample: entry = {} # Add specified fields first (appears at start of output) for key in self.trace_keys: entry[key] = previous_ds[i].get(key) # Add trace data (these take precedence over trace_keys) entry["original_text"] = previous_sample entry["processed_text"] = processed_sample dif_dict.append(entry) num += 1 if num >= self.show_num: break if len(dif_dict) == 0: logger.warning( f"Datasets before and after op [{op_name}] are all " f"the same. Thus no comparison results would be " f"generated." ) return elif len(dif_dict) < self.show_num: logger.warning( f"There are {len(dif_dict)} different samples " f"before and after op [{op_name}] -- less than " f"expected {self.show_num} samples." ) # export the tracer results. res_name = f"mapper-{op_name}.jsonl" dif_df = pd.DataFrame(dif_dict) dif_df.to_json(os.path.join(self.work_dir, res_name), orient="records", lines=True, force_ascii=False)
[docs] @deprecated("This method will be deprecated in the future. Please apply the sample-level tracing method instead.") def trace_batch_mapper(self, op_name: str, previous_ds: Dataset, processed_ds: Dataset, text_key: str): """ Compare datasets before and after a BatchMapper. This will mainly show the new samples augmented by the BatchMapper :param op_name: the op name of mapper :param previous_ds: dataset before the mapper process :param processed_ds: dataset processed by the mapper :param text_key: which text_key to trace :return: """ if op_name not in self.op_list_to_trace: return assert previous_ds[0][text_key] == processed_ds[0][text_key] aug_dict = [] # Get the first samples for i in range(len(processed_ds)): processed_sample = processed_ds[i] aug_dict.append(processed_sample) if i + 1 >= self.show_num: break if len(aug_dict) < self.show_num: logger.warning(f"There are only {len(aug_dict)} samples -- less " f"than expected {self.show_num} samples.") # export the tracer results. res_name = f"mapper-{op_name}.jsonl" dif_df = pd.DataFrame(aug_dict) dif_df.to_json(os.path.join(self.work_dir, res_name), orient="records", lines=True, force_ascii=False)
[docs] @deprecated("This method will be deprecated in the future. Please apply the sample-level tracing method instead.") def trace_filter(self, op_name: str, previous_ds: Dataset, processed_ds: Dataset): """ Compare datasets before and after a Filter. This will mainly show the filtered samples by the Filter :param op_name: the op name of filter :param previous_ds: dataset before the filter process :param processed_ds: dataset processed by the filter :return: """ if op_name not in self.op_list_to_trace: return if len(previous_ds) == len(processed_ds): logger.warning( f"Datasets before and after op [{op_name}] are all " f"the same. Thus no comparison results would be " f"generated." ) return # get the number of filtered samples. total_dif_num = len(previous_ds) - len(processed_ds) # index of the current sample in the previous dataset i = 0 filter_dict = [] # number of found filtered samples. It's the offset between two # datasets as well. num = 0 previous_ds_no_stats = ( previous_ds.remove_columns(Fields.stats) if Fields.stats in previous_ds.column_names else previous_ds ) processed_ds_no_stats = ( processed_ds.remove_columns(Fields.stats) if Fields.stats in processed_ds.column_names else processed_ds ) while i < len(previous_ds): if i - num >= len(processed_ds) or previous_ds_no_stats[i] != processed_ds_no_stats[i - num]: # 1. If all samples in processed dataset are checked but there # still some samples left in the previous dataset, all of these # left samples are filtered. # 2. If the corresponding samples in previous and processed # datasets are different, samples in the previous dataset are # filtered. num += 1 filter_dict.append(previous_ds[i]) if num >= self.show_num or num >= total_dif_num: # If the total number of found filtered samples is enough or we # have found all filtered samples, just stop. break i += 1 if len(filter_dict) < self.show_num: logger.warning( f"There are {len(filter_dict)} filtered samples " f"before and after op [{op_name}] -- less than " f"expected {self.show_num} samples." ) # export the tracer results. res_name = f"filter-{op_name}.jsonl" filter_df = pd.DataFrame(filter_dict) filter_df.to_json(os.path.join(self.work_dir, res_name), orient="records", lines=True, force_ascii=False)
[docs] def trace_deduplicator(self, op_name: str, dup_pairs: dict): """ Compare datasets before and after a Deduplicator. This will mainly show the near-duplicate sample pairs extracted by the Deduplicator. Different from the other two trace methods, the trace process for deduplicator is embedded into the process method of deduplicator, but the other two trace methods are independent of the process method of mapper and filter operators :param op_name: the op name of deduplicator :param dup_pairs: duplicate sample pairs obtained from deduplicator :return: """ if op_name not in self.op_list_to_trace: return if dup_pairs is None: logger.warning( f"Op [{op_name}] does not generate dup_pairs " f"correctly, thus no comparison results can be " f"obtained from this op." ) return if len(dup_pairs) == 0: logger.warning( f"Datasets before and after op [{op_name}] are all " f"the same. Thus no comparison results would be " f"generated." ) return elif len(dup_pairs) < self.show_num: logger.warning( f"There are {len(dup_pairs)} filtered samples " f"before and after op [{op_name}] -- less than " f"expected {self.show_num} samples." ) # reorganize the duplicate pairs dup_dict = [] for key in dup_pairs: dup_dict.append( { "dup1": dup_pairs[key][0], "dup2": dup_pairs[key][1], } ) # export the tracer result. res_name = f"duplicate-{op_name}.jsonl" dup_df = pd.DataFrame(dup_dict) dup_df.to_json(os.path.join(self.work_dir, res_name), orient="records", lines=True, force_ascii=False)