Source code for data_juicer_sandbox.hooks

# yapf: disable
import asyncio
from copy import deepcopy

from data_juicer.config import get_init_configs, prepare_side_configs
from data_juicer.core.data.dj_dataset import nested_query
from data_juicer.utils.constant import JobRequiredKeys
from jsonargparse import dict_to_namespace
from loguru import logger

from data_juicer_sandbox.context_infos import ContextInfos, JobInfos
from data_juicer_sandbox.factories import (
    data_analyzer_factory,
    data_evaluator_factory,
    data_executor_factory,
    data_pool_manipulator_factory,
    general_data_executor_factory,
    general_probe_factory,
    model_evaluator_factory,
    model_infer_evaluator_factory,
    model_infer_executor_factory,
    model_train_executor_factory,
)
from data_juicer_sandbox.utils import add_iter_subdir_to_paths


class BaseHook:

    def __init__(self, job_cfg, watcher, *args, **kwargs):
        self.job_cfg = job_cfg
        self.watcher = watcher
        self.hook_type = job_cfg.get(JobRequiredKeys.hook.value)
        # hook input mapping for config updates
        self.input_mapping = job_cfg.get(JobRequiredKeys.input.value, {})
        # names of the hook outputs for recording the results of the hook run
        self.output_keys = job_cfg.get(JobRequiredKeys.output.value, None)
        if self.output_keys and not isinstance(self.output_keys, list):
            self.output_keys = [self.output_keys]
        # hook-local settings for configs
        self.local_settings = job_cfg.get(JobRequiredKeys.local.value, {})
        # unique meta name for this job
        self.meta_name = job_cfg.get(JobRequiredKeys.meta_name.value, 'anonymous_meta_name')
        # Data-Juicer config for jobs based on Data-Juicer
        self.dj_cfg = job_cfg.get(JobRequiredKeys.dj_configs.value, None)
        self.inited_dj_cfg = None
        # extra config for other specific jobs
        self.extra_cfg = job_cfg.get(JobRequiredKeys.extra_configs.value, None)
        # the current iteration number
        self.curr_iter = None

    def run(self, context_infos: ContextInfos):
        # record the current iteration number
        self.curr_iter = context_infos.iter
        # update the args of this hook with input mappings and local settings
        self._input_updating_hook(context_infos)
        # update the export_path of the dj config. We assume that for non-dj
        # hooks, the output paths are always specified by the "export_path" field.
        if self.inited_dj_cfg and 'export_path' in self.inited_dj_cfg:
            self.inited_dj_cfg['export_path'] = add_iter_subdir_to_paths(
                [self.inited_dj_cfg['export_path']], self.curr_iter)[0]
        if self.extra_cfg and 'export_path' in self.extra_cfg:
            if isinstance(self.extra_cfg['export_path'], list):
                self.extra_cfg['export_path'] = add_iter_subdir_to_paths(
                    self.extra_cfg['export_path'], self.curr_iter)
            else:
                self.extra_cfg['export_path'] = add_iter_subdir_to_paths(
                    [self.extra_cfg['export_path']], self.curr_iter)[0]
        outputs = self.hook()
        if outputs:
            return self._output_recording_hook(outputs)
        else:
            return None

    def _input_updating_hook(self, context_infos: ContextInfos):
        self.specify_dj_and_extra_configs(allow_fail=True)
        prev_dj_cfg = deepcopy(self.dj_cfg) if self.dj_cfg else None
        prev_extra_cfg = deepcopy(self.extra_cfg) if self.extra_cfg else None
        logger.debug(f"Previous DJ Cfg: {prev_dj_cfg}")
        logger.debug(f"Previous Extra Cfg: {prev_extra_cfg}")
        logger.debug(f"Local Settings: {self.local_settings}")
        logger.debug(f"Input Mappings: {self.input_mapping}")
        # update configs according to local settings
        for key, value in self.local_settings.items():
            key_to_updated_parts = key.split('.')
            cfg_type = key_to_updated_parts[0]
            # check which config group to update
            if cfg_type == JobRequiredKeys.dj_configs.value:
                target_config = self.dj_cfg
            elif cfg_type == JobRequiredKeys.extra_configs.value:
                target_config = self.extra_cfg
            else:
                raise ValueError(f'The key {cfg_type} to update is not supported.')
            key_to_updated_parts = key_to_updated_parts[1:]
            # update the target key
            if len(key_to_updated_parts) > 0:
                if len(key_to_updated_parts) > 1:
                    target_config = nested_query(target_config, '.'.join(key_to_updated_parts[:-1]))
                target_config[key_to_updated_parts[-1]] = value
            else:
                if cfg_type == JobRequiredKeys.dj_configs.value:
                    self.dj_cfg = value
                elif cfg_type == JobRequiredKeys.extra_configs.value:
                    self.extra_cfg = value
        # update configs according to the input mapping and context_infos
        for key_to_updated, key_in_history in self.input_mapping.items():
            key_to_updated_parts = key_to_updated.split('.')
            key_in_history_parts = key_in_history.split('.')
            # check whether it refers to '-1', i.e., the latest job infos
            history_job_infos = context_infos
            if key_in_history_parts[0] == '-1':
                # get the latest infos
                if len(key_in_history_parts) <= 1:
                    raise ValueError(f'Need to specify the job result keys precisely for inputs to '
                                     f'find the target values. Only got [{key_in_history}].')
                # find the last non-empty job_infos
                if len(context_infos) == 0:
                    raise ValueError(f'Cannot find the previous non-empty job infos for [{key_in_history}].')
                history_job_infos = context_infos.get_the_last_job_infos()
                key_in_history_parts = key_in_history_parts[1:]
            # query the target value
            target_value = history_job_infos['.'.join(key_in_history_parts)]
            # check which config group to update
            cfg_type = key_to_updated_parts[0]
            if cfg_type == JobRequiredKeys.dj_configs.value:
                target_config = self.dj_cfg
            elif cfg_type == JobRequiredKeys.extra_configs.value:
                target_config = self.extra_cfg
            else:
                raise ValueError(f'The key {cfg_type} to update is not supported.')
            key_to_updated_parts = key_to_updated_parts[1:]
            # update the target key
            if len(key_to_updated_parts) > 0:
                if len(key_to_updated_parts) > 1:
                    target_config = nested_query(target_config, '.'.join(key_to_updated_parts[:-1]))
                target_config[key_to_updated_parts[-1]] = target_value
            else:
                if cfg_type == JobRequiredKeys.dj_configs.value:
                    self.dj_cfg = target_value
                elif cfg_type == JobRequiredKeys.extra_configs.value:
                    self.extra_cfg = target_value
        if self.dj_cfg != prev_dj_cfg or self.extra_cfg != prev_extra_cfg:
            logger.info('Configs are updated according to input and local settings. '
                        'Re-initializing them...')
            self.specify_dj_and_extra_configs()

    def _output_recording_hook(self, outputs):
        if isinstance(outputs, tuple):
            # multiple outputs
            outputs = list(outputs)
        else:
            # single output
            if isinstance(outputs, list) and len(outputs) == 1:
                # if there is only 1 element in the returned list, unpack it
                outputs = outputs[0]
            outputs = [outputs]
        if self.output_keys is None:
            self.output_keys = list(range(len(outputs)))
        if len(outputs) != len(self.output_keys):
            raise ValueError(
                f'## HOOK [{self.hook_type}]: The number of outputs does not match the number of output keys. '
                f'Expected {len(self.output_keys)} but got {len(outputs)}.')
        curr_job_infos = JobInfos(self.meta_name, self.output_keys, outputs)
        return curr_job_infos

    def hook(self, **kwargs):
        raise NotImplementedError

    def specify_dj_and_extra_configs(self, allow_fail=False):
        # prepare the Data-Juicer config
        if self.dj_cfg:
            prev_dj_cfg = self.dj_cfg
            prev_inited_dj_cfg = self.inited_dj_cfg
            try:
                logger.info('Parsing Data-Juicer configs in the job.')
                self.dj_cfg = prepare_side_configs(self.dj_cfg)
                # some jobs require the Data-Juicer data process,
                # so we need to init the Data-Juicer data process configs
                self.inited_dj_cfg = get_init_configs(self.dj_cfg)
                self.dj_cfg = dict_to_namespace(self.dj_cfg)
            except Exception as e:
                if allow_fail:
                    self.dj_cfg = prev_dj_cfg
                    self.inited_dj_cfg = prev_inited_dj_cfg
                else:
                    raise e
        # prepare extra configs
        if self.extra_cfg:
            prev_extra_cfg = self.extra_cfg
            try:
                logger.info('Parsing extra configs in the job.')
                self.extra_cfg = prepare_side_configs(self.extra_cfg)
                self.extra_cfg = dict_to_namespace(self.extra_cfg)
            except Exception as e:
                if allow_fail:
                    self.extra_cfg = prev_extra_cfg
                else:
                    raise e

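# For orientation, an illustrative job config consumed by BaseHook. The
# top-level field names follow JobRequiredKeys as read in __init__ above;
# the concrete values (hook name aside) are hypothetical:
#
#   example_job_cfg = {
#       'hook': 'ProbeViaAnalyzerHook',
#       'meta_name': 'analysis_job',                       # hypothetical name
#       'dj_configs': 'configs/analysis.yaml',             # hypothetical path
#       # pull the value recorded under 'export_path' in the most recent
#       # non-empty job infos ('-1') into dj_configs.dataset_path
#       'input': {'dj_configs.dataset_path': '-1.export_path'},
#       'output': 'analyzed_export_path',                  # hypothetical key
#       'local': {'dj_configs.np': 4},                     # hypothetical override
#   }
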

class ProbeViaAnalyzerHook(BaseHook):
    """
    The hook to probe the dataset via the Data-Juicer Analyzer.

    Input:
        - a Data-Juicer config.
    Output:
        - the path to export the analyzed dataset.
    """

    def __init__(self, job_cfg, watcher, *args, **kwargs):
        """
        Initialize the hook for probing the data via the Analyzer.

        :param job_cfg: the job configs
        :param watcher: for watching the result
        """
        super(ProbeViaAnalyzerHook, self).__init__(job_cfg, watcher, *args, **kwargs)

    def hook(self, **kwargs):
        analyzer = data_analyzer_factory(self.inited_dj_cfg)
        # probe the data via the Analyzer
        logger.info('Begin to analyze data')
        analyzer.run()
        analyzer_res = analyzer.overall_result
        # drop string rows to avoid unaligned dtypes
        string_rows = ['unique', 'top', 'freq']
        for row_name in string_rows:
            if row_name in analyzer_res.index:
                analyzer_res = analyzer_res.drop(row_name)
        self.watcher.watch(analyzer_res, self.meta_name)
        return analyzer.exporter.export_path

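# A side note on the row dropping in ProbeViaAnalyzerHook.hook above: pandas'
# `DataFrame.describe(include='all')` emits the 'unique'/'top'/'freq' rows
# only for non-numeric columns, which is what leaves unaligned dtypes in the
# overall result. A standalone sketch of the same cleanup (illustrative only,
# not part of this module's runtime path):
#
#   import pandas as pd
#   df = pd.DataFrame({'text_len': [3, 5, 8], 'lang': ['en', 'zh', 'en']})
#   res = df.describe(include='all')
#   res = res.drop([r for r in ('unique', 'top', 'freq') if r in res.index])
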

class ProbeViaModelInferHook(BaseHook):

    def __init__(self, job_cfg, watcher, *args, **kwargs):
        """
        Initialize the hook for probing the data via model inference.

        :param job_cfg: the job configs
        :param watcher: for watching the result
        """
        super(ProbeViaModelInferHook, self).__init__(job_cfg, watcher, *args, **kwargs)

    def hook(self, **kwargs):
        data_executor = data_executor_factory(self.inited_dj_cfg)
        model_infer_executor = model_infer_evaluator_factory(self.extra_cfg)  # TODO
        # probe the model (calling the inference sub-pipeline) on the
        # original data, so that we know which data are "hard" for the
        # model and how to process the data accordingly
        sampled_data = data_executor.sample_data(
            sample_ratio=self.inited_dj_cfg.data_probe_ratio,
            sample_algo=self.inited_dj_cfg.data_probe_algo,
        )
        res_type, infer_res = model_infer_executor.run(
            model_infer_executor.model_config['type'], sampled_data)
        self.watcher.watch(infer_res, self.meta_name)
        return infer_res

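# The two sampling knobs read above come from the initialized DJ config. An
# illustrative fragment of the corresponding DJ recipe (the field names match
# the attribute accesses above; the values are hypothetical):
#
#   data_probe_algo: 'uniform'    # hypothetical value
#   data_probe_ratio: 0.1         # hypothetical value
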

class GeneralProbeHook(BaseHook):

    def __init__(self, job_cfg, watcher, *args, **kwargs):
        super(GeneralProbeHook, self).__init__(job_cfg, watcher, *args, **kwargs)

    def hook(self, **kwargs):
        data_probe = general_probe_factory(self.extra_cfg)
        logger.info('Begin to probe data.')
        ret = data_probe.run()
        return ret


class RefineRecipeViaKSigmaHook(BaseHook):

    def __init__(self, job_cfg, watcher, *args, **kwargs):
        """
        Initialize the hook for refining the recipe via K-Sigma.

        :param job_cfg: the job configs
        :param watcher: for watching the result
        """
        super(RefineRecipeViaKSigmaHook, self).__init__(job_cfg, watcher, *args, **kwargs)

    def hook(self, **kwargs):
        path_k_sigma_recipe = self.extra_cfg.path_k_sigma_recipe
        # use the k-sigma strategy to modify the data recipe
        # TODO: tools cannot be imported correctly if DJ is installed from
        #  PyPI. Maybe we need other importing methods.
        from data_juicer.tools.hpo.execute_hpo_3sigma import modify_recipe_k_sigma
        modify_recipe_k_sigma(self.dj_cfg, self.watcher.query(self.meta_name), path_k_sigma_recipe)
        return path_k_sigma_recipe


class RefineRecipeViaModelFeedbackHook(BaseHook):

    def __init__(self, job_cfg, watcher, *args, **kwargs):
        """
        Initialize the hook for refining the recipe via model feedback.

        :param job_cfg: the job configs
        :param watcher: for watching the result
        """
        super(RefineRecipeViaModelFeedbackHook, self).__init__(job_cfg, watcher, *args, **kwargs)

    def hook(self, **kwargs):
        # TODO
        # use a model-feedback-based strategy to modify the data recipe,
        # e.g., apply more mappers to the "hard" or "sensitive" data that
        # were ranked by a user-interested measurement after model inference
        if self.sandbox_cfg.path_model_feedback_recipe is not None:
            # modify_recipe_model_feedback(
            #     self.sandbox_cfg,
            #     self.watcher.query("measure_on_infer_res"),
            #     self.sandbox_cfg.path_model_feedback_recipe)
            raise NotImplementedError('Not implemented yet.')
        return None


class ProcessDataHook(BaseHook):

    def __init__(self, job_cfg, watcher, *args, **kwargs):
        """
        Initialize the hook for processing the data via Data-Juicer.

        :param job_cfg: the job configs
        :param watcher: for watching the result
        """
        super(ProcessDataHook, self).__init__(job_cfg, watcher, *args, **kwargs)

    def hook(self, **kwargs):
        data_executor = data_executor_factory(self.inited_dj_cfg)
        # basic routine to process data; users can customize this freely
        logger.info('Begin to process the data with given dj recipe')
        data_executor.run()
        return self.inited_dj_cfg.export_path


class DataPoolManipulationHook(BaseHook):
    """
    The hook for data pool manipulation, including construction,
    combination, ranking, etc.
    """

    def __init__(self, job_cfg, watcher, *args, **kwargs):
        super(DataPoolManipulationHook, self).__init__(job_cfg, watcher, *args, **kwargs)

    def hook(self, **kwargs):
        data_pool_manipulator = data_pool_manipulator_factory(self.extra_cfg)
        logger.info('Begin to manipulate data pools.')
        ret = data_pool_manipulator.run()
        return ret


class GeneralDataExecutorHook(BaseHook):

    def __init__(self, job_cfg, watcher, *args, **kwargs):
        super(GeneralDataExecutorHook, self).__init__(job_cfg, watcher, *args, **kwargs)

    def hook(self, **kwargs):
        data_executor = general_data_executor_factory(self.extra_cfg)
        logger.info('Begin to execute the general data executor.')
        ret = data_executor.run()
        return ret


class TrainModelHook(BaseHook):

    def __init__(self, job_cfg, watcher, *args, **kwargs):
        """
        Initialize the hook for model training.

        :param job_cfg: the job configs
        :param watcher: for watching the result
        """
        super(TrainModelHook, self).__init__(job_cfg, watcher, *args, **kwargs)

    def hook(self, **kwargs):
        model_trainer = model_train_executor_factory(self.extra_cfg, watcher=self.watcher)
        # basic routine to train the model on the processed data;
        # users can customize this freely
        logger.info('Begin to train the model with given model config')
        ret = asyncio.run(
            model_trainer.run(model_trainer.model_config['type'], **kwargs))
        return ret


class InferModelHook(BaseHook):

    def __init__(self, job_cfg, watcher, *args, **kwargs):
        """
        Initialize the hook for model inference.

        :param job_cfg: the job configs
        :param watcher: for watching the result
        """
        super(InferModelHook, self).__init__(job_cfg, watcher, *args, **kwargs)

    def hook(self, **kwargs):
        model_infer = model_infer_executor_factory(self.extra_cfg, watcher=self.watcher)
        logger.info('Begin to infer the model with given model config')
        ret = asyncio.run(model_infer.run(model_infer.model_config['type'], **kwargs))
        return ret


class EvaluateDataHook(BaseHook):

    def __init__(self, job_cfg, watcher, *args, **kwargs):
        """
        Initialize the hook for data evaluation.

        :param job_cfg: the job configs
        :param watcher: for watching the result
        """
        super(EvaluateDataHook, self).__init__(job_cfg, watcher, *args, **kwargs)

    def hook(self, **kwargs):
        data_evaluator = data_evaluator_factory(self.extra_cfg)
        # basic routine to evaluate the given data;
        # users can customize this freely
        logger.info('Begin to evaluate the data with given evaluator config')
        eval_res = data_evaluator.run(eval_type='data', **kwargs)
        self.watcher.watch(eval_res, self.meta_name)
        return eval_res


class EvaluateModelHook(BaseHook):

    def __init__(self, job_cfg, watcher, *args, **kwargs):
        """
        Initialize the hook for model evaluation.

        :param job_cfg: the job configs
        :param watcher: for watching the result
        """
        super(EvaluateModelHook, self).__init__(job_cfg, watcher, *args, **kwargs)

    def hook(self, **kwargs):
        model_evaluator = model_evaluator_factory(self.extra_cfg)
        # basic routine to evaluate the given model;
        # users can customize this freely
        logger.info('Begin to evaluate the model with given evaluator config')
        eval_res = model_evaluator.run(kwargs)
        self.watcher.watch(eval_res, self.meta_name)
        return eval_res


HOOK_MAPPING = {
    # Data/Model Probe hooks
    'ProbeViaAnalyzerHook': ProbeViaAnalyzerHook,
    'ProbeViaModelInferHook': ProbeViaModelInferHook,
    'GeneralProbeHook': GeneralProbeHook,
    # Data-Recipe Refinement hooks
    'RefineRecipeViaKSigmaHook': RefineRecipeViaKSigmaHook,
    'RefineRecipeViaModelFeedbackHook': RefineRecipeViaModelFeedbackHook,
    # Data/Model Execution hooks
    'ProcessDataHook': ProcessDataHook,
    'DataPoolManipulationHook': DataPoolManipulationHook,
    'GeneralDataExecutorHook': GeneralDataExecutorHook,
    'TrainModelHook': TrainModelHook,
    'InferModelHook': InferModelHook,
    # Evaluation hooks
    'EvaluateDataHook': EvaluateDataHook,
    'EvaluateModelHook': EvaluateModelHook,
}

def register_hook(job_cfg, watcher):
    if job_cfg.hook not in HOOK_MAPPING:
        raise ValueError(f'Undefined hook: [{job_cfg.hook}].')
    return HOOK_MAPPING[job_cfg.hook](job_cfg, watcher)

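# A minimal usage sketch (assumes `job_cfg` behaves like the namespaces used
# above, i.e. it exposes `.hook` and supports `.get(...)`, and that `watcher`
# and `context_infos` come from the surrounding sandbox pipeline):
#
#   hook = register_hook(job_cfg, watcher)
#   job_infos = hook.run(context_infos)  # returns JobInfos or None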