import json
import os
from copy import deepcopy
from typing import List
import wandb
import yaml
from data_juicer.config import merge_config, prepare_side_configs
from data_juicer.utils.constant import JobRequiredKeys
from jsonargparse import Namespace as JsonNamespace
from jsonargparse import dict_to_namespace, namespace_to_dict
from loguru import logger
from data_juicer_sandbox.context_infos import (
ContextInfos,
GlobalContextInfos,
PipelineInfos,
)
from data_juicer_sandbox.hooks import register_hook
from data_juicer_sandbox.utils import validate_hook_output
class SandBoxWatcher:
"""
    Basic watcher class that manages results of interest and manages the
    experiment within the sandbox based on the WandB UI and its utilities.
"""
def __init__(self, sandbox_cfg):
"""
        Initialize the watcher with the sandbox configuration.
"""
        # the web UI and experiment versioning are based on WandB
project_name = sandbox_cfg.project_name
experiment_name = sandbox_cfg.experiment_name
hpo_config = sandbox_cfg.hpo_config
self.sandbox_cfg = sandbox_cfg
if not os.path.exists(self.sandbox_cfg.work_dir):
os.makedirs(self.sandbox_cfg.work_dir, exist_ok=True)
self.wandb_run = wandb.init(project=project_name, name=experiment_name)
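        # remember the HPO objective metric name (if any) so that watch() can
        # log it as a float for WandB sweep comparisons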
if hpo_config is not None and "metric" in hpo_config and "name" in hpo_config["metric"]:
self.object_name_in_hpo = hpo_config["metric"]["name"]
else:
self.object_name_in_hpo = None
self.logged_res = {}
def query(self, meta_name: str):
"""
Query the result from the logged_res.
"""
return self.logged_res.get(meta_name)
def watch(self, res, meta_name: str = ""):
"""
        Flatten the result into dot-separated keys and log it into WandB.
"""
if isinstance(res, dict):
for key, value in res.items():
                # recurse until reaching the leaf nodes of the given res dictionary
if isinstance(value, dict):
self.watch(value, f"{meta_name}.{key}")
else:
self.logged_res[f"{meta_name}.{key}"] = value
if self.object_name_in_hpo == f"{meta_name}.{key}":
# Ensuring float results for HPO experiments
value = float(value)
self.wandb_run.log({f"{meta_name}.{key}": value})
else:
self.logged_res[meta_name] = res
if meta_name == self.object_name_in_hpo:
res = float(res)
self.wandb_run.log({meta_name: res})
def setup_sweep(self, hpo_config: dict = None, project_name: str = None):
"""
        Set up and start a new WandB sweep.
"""
if hpo_config is None:
hpo_config = self.sandbox_cfg.hpo_config
if project_name is None:
project_name = self.sandbox_cfg.project_name
sweep_id = wandb.sweep(sweep=hpo_config, project=project_name)
if hpo_config is not None and "metric" in hpo_config and "name" in hpo_config["metric"]:
self.object_name_in_hpo = hpo_config["metric"]["name"]
return sweep_id
def watch_cfgs(self, cfgs: List[tuple] = None):
"""
Watch the configuration of the experiment.
"""
merged_cfgs = {}
if cfgs is not None:
for cfg, cfg_prefix in cfgs:
# skip empty configs
if cfg is None:
continue
if isinstance(cfg, JsonNamespace):
converged_cfg = namespace_to_dict(cfg)
elif isinstance(cfg, dict):
converged_cfg = cfg
else:
raise ValueError(f"Expected dict or JsonNamespace, got {type(cfg)}")
for key, val in converged_cfg.items():
merged_cfgs[f"{cfg_prefix}.{key}"] = val
else:
merged_cfgs = namespace_to_dict(self.sandbox_cfg)
wandb.config.update(merged_cfgs)
class Target:
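    """
    A single iteration target parsed from a string like
    "<metric_key> <op> <value>", used to decide whether the iterative sandbox
    run can stop early.
    """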
SUPPORT_OPS = ["==", ">=", "<=", ">", "<"]
key: str
op: str
tgt_val: float
def __init__(self, iter_target_str: str = None, key: str = None, op: str = None, tgt_val: float = None):
if iter_target_str is not None:
self.parse_iter_targets(iter_target_str)
else:
self.key = key
self.op = op
self.tgt_val = tgt_val
def parse_iter_targets(self, iter_target_str):
for op in self.SUPPORT_OPS:
if op in iter_target_str:
target_key, target_value = [s.strip() for s in iter_target_str.split(op)]
# check if the target value is a number
try:
target_value = float(target_value)
except: # noqa: E722
logger.error(
f"Invalid iter_targets [{iter_target_str}]: The target value [{target_value}] "
"is not a valid number."
)
exit(1)
self.key = target_key
self.op = op
self.tgt_val = target_value
break
else:
logger.error(
f"Invalid iter_targets [{iter_target_str}]: No valid comparators are found."
f"Only support {self.SUPPORT_OPS}"
)
exit(1)
def check_target(self, context_infos: ContextInfos):
curr_val = context_infos[self.key]
try:
logger.debug(f"Checking {curr_val} {self.op} {self.tgt_val}")
ret = eval(f"{curr_val} {self.op} {self.tgt_val}")
except: # noqa: E722
logger.error(f"Invalid iter_targets [{str(self)}]: The target value [{curr_val}] " "is not a valid number.")
return False
return ret
def __str__(self):
return f"{self.key} {self.op} {self.tgt_val}"
class SandboxPipeline:
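    """
    A single sandbox pipeline composed of probe, recipe-refinement, execution
    and evaluation jobs, which are registered from the pipeline config and run
    in order within each iteration.
    """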
def __init__(self, pipeline_name="anonymous", pipeline_cfg=None, watcher=None):
"""
Initialization method.
"""
self.name = pipeline_name
self.cfg = pipeline_cfg
self.watcher = watcher
        # jobs for probing, recipe refinement, execution and evaluation of the
        # data and models of interest within the sandbox
self.probe_jobs = []
self.refine_recipe_jobs = []
self.execution_jobs = []
self.evaluation_jobs = []
self.register_jobs()
def register_jobs(self):
# register probe_jobs
for job_cfg in self.cfg.get("probe_job_configs", []):
self.probe_jobs.append(register_hook(job_cfg, self.watcher))
# register refine_recipe_jobs
for job_cfg in self.cfg.get("refine_recipe_job_configs", []):
self.refine_recipe_jobs.append(register_hook(job_cfg, self.watcher))
# register execution_jobs
for job_cfg in self.cfg.get("execution_job_configs", []):
self.execution_jobs.append(register_hook(job_cfg, self.watcher))
# register evaluation_jobs
for job_cfg in self.cfg.get("evaluation_job_configs", []):
self.evaluation_jobs.append(register_hook(job_cfg, self.watcher))
def run(self, context_infos: ContextInfos):
"""
        Run the sandbox pipeline once or in HPO style.
"""
if self.cfg.hpo_config is not None:
            # execute_hpo_wandb runs one_trial repeatedly under the HPO scheduler
return self.execute_hpo_wandb(context_infos)
else:
return self.one_trial(context_infos)
def one_trial(self, context_infos: ContextInfos):
"""
        Run the sandbox pipeline once.
        Users can flexibly conduct only some steps of the whole sandbox pipeline
        according to their own needs and configuration. The watcher
        automatically tracks the results for data, models and the specified
        evaluation metrics.
"""
        # TODO: clarify how the HPO works here
if self.watcher.object_name_in_hpo is not None:
# merge the new hyper-parameters produced by HPO scheduler
self.cfg = merge_config(self.cfg, wandb.config)
            self.watcher.watch_cfgs([(self.cfg, "after_hpo")])
if self.name in context_infos.pipeline_names:
raise ValueError(f"There are different pipelines with the same pipeline name {self.name}.")
pipeline_infos = PipelineInfos(self.name)
context_infos.record_pipeline_infos(pipeline_infos)
# ====== Data & model probe ======
for probe_hook in self.probe_jobs:
logger.info(
f"======= Iter [{context_infos.iter}] - Pipeline [{self.name}]: Start Probe Hook [{probe_hook.meta_name}] ======="
)
new_job_infos = probe_hook.run(context_infos)
context_infos[self.name].record_job_infos(new_job_infos)
logger.debug(f"Context Infos: {context_infos.to_dict()}")
# ====== Data-model recipes iteration based on probe results ======
for refine_hook in self.refine_recipe_jobs:
logger.info(
f"======= Iter [{context_infos.iter}] - Pipeline [{self.name}]: Start Refine Hook [{refine_hook.meta_name}] ======="
)
new_job_infos = refine_hook.run(context_infos)
context_infos[self.name].record_job_infos(new_job_infos)
logger.debug(f"Context Infos: {context_infos.to_dict()}")
# ====== Data processing & model training ======
for exec_hook in self.execution_jobs:
logger.info(
f"======= Iter [{context_infos.iter}] - Pipeline [{self.name}]: Start Execution Hook [{exec_hook.meta_name}] ======="
)
new_job_infos = exec_hook.run(context_infos)
context_infos[self.name].record_job_infos(new_job_infos)
logger.debug(f"Context Infos: {context_infos.to_dict()}")
# ====== Evaluation on processed data or trained model ======
for eval_hook in self.evaluation_jobs:
logger.info(
f"======= Iter [{context_infos.iter}] - Pipeline [{self.name}]: Start Evaluation Hook [{eval_hook.meta_name}] ======="
)
new_job_infos = eval_hook.run(context_infos)
context_infos[self.name].record_job_infos(new_job_infos)
logger.debug(f"Context Infos: {context_infos.to_dict()}")
return context_infos
def execute_hpo_wandb(self, context_infos):
"""
        Run the sandbox pipeline in HPO style.
        Users can flexibly conduct only some steps of the whole sandbox pipeline
        according to their own needs and configuration. The watcher
        automatically tracks the results for data, models and the specified
        evaluation metrics.
"""
with open(self.cfg.hpo_config) as file:
hpo_configuration = yaml.safe_load(file)
sweep_id = self.watcher.setup_sweep(hpo_configuration)
wandb.agent(
sweep_id,
function=self.one_trial,
count=hpo_configuration["sweep_max_count"] if "sweep_max_count" in hpo_configuration else None,
)
return None
class SandBoxExecutor:
"""
    This SandBoxExecutor class provides a sandbox environment for exploring
    data-model co-design in a one-stop manner, with fast feedback, tiny model
    sizes, small data sizes, and high efficiency.
    It acts as a middleware that maintains Data-Juicer's data executor,
    a model processor (training and inference), and an auto-evaluator,
    where the latter two are usually from third-party libraries.
"""
def __init__(
self,
cfg=None,
):
"""
Initialization method.
:param cfg: configuration of sandbox.
"""
self.cfg = cfg
self.watcher = SandBoxWatcher(self.cfg)
self.watcher.watch_cfgs([(cfg, "sandbox")])
self.pipelines = self.parse_pipelines(self.cfg)
self.resume = self.cfg.get("resume", False)
# iterative related
self.max_iter_num = self.cfg.get("max_iter_num", 1)
init_targets = self.cfg.get("iter_targets", [])
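        # each iteration target is a string like "<metric_key> <op> <value>",
        # e.g. a (hypothetical) "pipeline_1.eval_job.mean_score >= 0.9"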
        # 'max_iter_num' must be 0 or a positive number
if self.max_iter_num < 0:
logger.error(f"Argument 'max_iter_num' must be 0 or a positive number. Got [{self.max_iter_num}].")
exit(1)
if not isinstance(init_targets, list):
init_targets = [init_targets]
if self.max_iter_num == 0 and len(init_targets) == 0:
logger.error(
"Either 'max_iter_num' must be > 0 or 'iter_targets' must be set. "
"If you want to run the pipeline without iterative, please leave both arguments at their default values"
" or set 'max_iter_num' to 1."
)
exit(1)
init_targets = [Target(iter_target_str=iter_target_str) for iter_target_str in init_targets]
self.iter_targets = []
for target in init_targets:
if not validate_hook_output(self.pipelines, target.key):
logger.error(
f"Invalid iter_targets [{str(target)}]: "
f"The target metric key [{target.key}] can not found in the pipelines."
)
self.iter_targets.append(target)
self.iter_targets_mode = self.cfg.get("iter_targets_mode", "all")
# iterative updater for config arguments
self.iter_updater = self.cfg.get("iter_updater", {})
def parse_pipelines(self, cfg):
"""
Parse the pipeline configs.
:param cfg: the original config
        :return: a list of SandboxPipeline objects.
"""
pipelines = []
pipeline_keys = [
"pipelines",
"probe_job_configs",
"refine_recipe_job_configs",
"execution_job_configs",
"evaluation_job_configs",
]
global_cfgs = deepcopy(cfg)
for pipeline_key in pipeline_keys:
if pipeline_key in global_cfgs:
global_cfgs.pop(pipeline_key)
if cfg.pipelines:
# specify the pipelines
for pipeline in cfg.pipelines:
pipeline_name, pipeline_cfg = list(pipeline.items())[0]
pipeline_cfg.update(global_cfgs)
pipelines.append(SandboxPipeline(pipeline_name, self.specify_jobs_configs(pipeline_cfg), self.watcher))
else:
pipeline = SandboxPipeline(pipeline_cfg=self.specify_jobs_configs(cfg), watcher=self.watcher)
pipelines.append(pipeline)
return pipelines
def iterative_update_pipelines(self, current_pipelines: List[SandboxPipeline], last_context_infos: ContextInfos):
if current_pipelines is None:
return None
if last_context_infos is None or len(last_context_infos) == 0:
return current_pipelines
        # propagate values from the last iteration's context infos into the target hook configs
for from_key, target_key in self.iter_updater.items():
from_value = last_context_infos[from_key]
if from_value is not None:
cfg_levels = target_key.split(".")
if len(cfg_levels) < 4:
raise ValueError(
f"The target key [{target_key}] must be in the format of "
f"<pipeline_name>.<hook_meta_name>.[extra_configs|dj_configs].<hook_cfg_key1>[.<hook_cfg_keyn>]."
)
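                # e.g., a (hypothetical) target key
                # "pipeline_1.train_job.extra_configs.lr" splits into pipeline
                # "pipeline_1", hook "train_job" and local key "extra_configs.lr"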
tgt_pipeline_name = cfg_levels[0]
tgt_hook_meta_name = cfg_levels[1]
tgt_local_key = ".".join(cfg_levels[2:])
for i in range(len(current_pipelines)):
current_pipeline = current_pipelines[i]
if current_pipeline.name == tgt_pipeline_name:
all_hooks = (
current_pipeline.probe_jobs
+ current_pipeline.refine_recipe_jobs
+ current_pipeline.execution_jobs
+ current_pipeline.evaluation_jobs
)
for hook in all_hooks:
if hook.meta_name == tgt_hook_meta_name:
# put the updated configs key/values into the local settings
hook.local_settings[tgt_local_key] = from_value
current_pipelines[i] = current_pipeline
else:
logger.warning(f"The iter_updater [{from_key}] is not found in the last context infos.")
return current_pipelines
def specify_job_configs(self, ori_config):
config = prepare_side_configs(ori_config)
for key in JobRequiredKeys:
if key.value not in config:
logger.debug(f'The key "{key.value}" is not specified in {ori_config}')
return dict_to_namespace(config)
def specify_jobs_configs(self, cfg):
"""
Specify job configs by their dict objects or config file path strings.
:param cfg: the original config
        :return: the config with each job config converted into a namespace.
"""
def configs_to_job_list(cfgs):
job_cfgs = []
if cfgs:
job_cfgs = [self.specify_job_configs(job_cfg) for job_cfg in cfgs]
return job_cfgs
if isinstance(cfg, dict):
cfg = dict_to_namespace(cfg)
if "probe_job_configs" in cfg:
cfg.probe_job_configs = configs_to_job_list(cfg.probe_job_configs)
if "refine_recipe_job_configs" in cfg:
cfg.refine_recipe_job_configs = configs_to_job_list(cfg.refine_recipe_job_configs)
if "execution_job_configs" in cfg:
cfg.execution_job_configs = configs_to_job_list(cfg.execution_job_configs)
if "evaluation_job_configs" in cfg:
cfg.evaluation_job_configs = configs_to_job_list(cfg.evaluation_job_configs)
return cfg
def run(self):
context_infos_path = os.path.join(self.cfg.work_dir, "context_infos.json")
num_pipeline_skip = 0
last_context_infos = ContextInfos(iter=0)
if self.resume and os.path.exists(context_infos_path):
# load context infos from the existing one
            with open(context_infos_path, "r") as fin:
                context_infos_list = json.load(fin)
context_infos_list = GlobalContextInfos.from_list(context_infos_list)
current_iter = len(context_infos_list)
if current_iter == 0:
logger.info("The context infos file is empty. Start from the first iter.")
else:
logger.info(f"Continue from the iter {current_iter}.")
current_iter -= 1
last_context_infos = context_infos_list[-1]
context_infos_list = context_infos_list[:-1]
# find those finished pipelines
finished_pipelines = set(last_context_infos.pipeline_names)
for pipeline in self.pipelines:
                # check if the pipeline already exists in the context infos
if pipeline.name in finished_pipelines:
# check if the number of job infos is the same as the number of all kinds of jobs,
# which means all jobs are finished
num_job_infos = len(last_context_infos[pipeline.name])
num_jobs = (
len(pipeline.probe_jobs)
+ len(pipeline.refine_recipe_jobs)
+ len(pipeline.execution_jobs)
+ len(pipeline.evaluation_jobs)
)
if num_job_infos == num_jobs:
logger.info(
f"Pipeline {pipeline.name} is finished and loaded from the existing context infos. Skip it!"
)
num_pipeline_skip += 1
continue
else:
context_infos_list = GlobalContextInfos()
current_iter = 0
try:
current_pipelines = deepcopy(self.pipelines)
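            # iterate over the pipelines until the max iteration number or the
            # iteration targets are reached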
while True:
current_iter += 1
logger.info(f"============== Starting the iter {current_iter} ==============")
if num_pipeline_skip > 0:
context_infos = last_context_infos
else:
context_infos = ContextInfos(iter=current_iter)
for pipeline in current_pipelines:
if num_pipeline_skip > 0:
num_pipeline_skip -= 1
continue
context_infos = pipeline.run(context_infos)
context_infos_list.record_context_infos(context_infos)
# check if the pipelines reach the max number of iterations
if 0 < self.max_iter_num <= current_iter:
break
                # check if the run meets the iteration targets
if len(self.iter_targets) > 0:
curr_target_results = [iter_target.check_target(context_infos) for iter_target in self.iter_targets]
if self.iter_targets_mode == "all":
if all(curr_target_results):
logger.info("All targets are satisfied.")
break
elif self.iter_targets_mode == "any":
if any(curr_target_results):
satisfied_idxes = [
idx for idx, curr_target_result in enumerate(curr_target_results) if curr_target_result
]
satisfied_targets = [str(self.iter_targets[idx]) for idx in satisfied_idxes]
logger.info(f"Targets {satisfied_targets} are satisfied.")
break
# check if there are any arguments to be updated from the last iteration
if len(self.iter_updater) > 0:
logger.info("Updating arguments across iterations...")
current_pipelines = deepcopy(self.pipelines)
current_pipelines = self.iterative_update_pipelines(current_pipelines, context_infos)
finally:
# export context infos
with open(context_infos_path, "w") as fout:
json.dump(context_infos_list.to_list(), fout, indent=4)