# Source listing for module: data_juicer.tools.quality_classifier.predict
# This tool is used for predicting a document score for text samples using
# quality classifier models we provided, including:
# - gpt3: A GPT3 quality classifier reproduced from scratch by us based on
# PySpark. It's trained over CC as negative samples and Wikipedia-en,
# Books, OpenWebText as positive samples.
# - chinese: A quality classifier for Chinese. It's trained over Chinese
# texts sampled from CC as negative samples and Wudao, Wikipedia-zh as
# positive samples.
# - code: A quality classifier for codes. It's trained over code samples that
# have stars >= 1372 as positive samples and random samples from left
# data as negative samples. Stars count 1372 splits a nearly 700w subset
# with most stars.
# All these 3 classifiers are trained using the same training pipeline as GPT3
# based on PySpark but with different tokenizers and keeping methods:
# - gpt3: standard Tokenizer from spark & GPT3 keeping method based on pareto
# - chinese: sentencepiece tokenizer for Chinese & label
# - code: sentencepiece tokenizer for code & label
#
# This tool needs several arguments:
# - dataset_path: the path to the dataset you want to predict doc_scores for.
# - result_path: the path to store the predicted result dataset.
# - model: quality classifier name to apply. It's "gpt3" in default. You can
# use one of ["gpt3", "chinese", "code"] we provided, or you can set it
# to the path to your own model trained using the train.py tool.
# - tokenizer: what tokenizer to use to tokenize texts. It's None in default,
# which means using the standard Tokenizer of PySpark. You can use one of
# ["zh.sp.model", "code.sp.model"] we provided, or you can set it to the
# path to your own sentencepiece model.
# - keep_method: the method to label should_keep field for each sample. It's
# "gpt3" in default. Should be one of ["gpt3", "label"].
# - text_key: the field key name to hold texts to be classified. It's "text"
# in default.
# - overall_stats: whether to output an overall stats report on predicted
# document scores. It's False in default.
#
# Recommended arguments for provided trained models:
# - gpt3:
# - model: gpt3
# - tokenizer: None
# - keep_method: gpt3
# - chinese:
# - model: chinese
# - tokenizer: zh.sp.model
# - keep_method: label
# - code:
# - model: code
# - tokenizer: code.sp.model
# - keep_method: label
#
# Notice:
# 1. The configs of SparkSession in function init_spark can be modified to be
# more suitable for your own machine. See function init_spark in
# qc_utils.py.
# 2. Random factors are involved in "gpt3" model. So you might get different
# should_keep label in different running processes. But you should get
# same doc_score predictions in different running processes.
import os
import fire
from loguru import logger
from data_juicer.tools.quality_classifier.qc_utils import (
export_result,
init_spark,
load_dataset,
predict,
prepare_model,
)
[文档]
@logger.catch(reraise=True)
def predict_score(
    dataset_path, result_path, model="gpt3", tokenizer=None, keep_method="gpt3", text_key="text", overall_stats=False
):
    """
    Use a specific quality classifier to predict document scores on a dataset.

    :param dataset_path: the path to the dataset you want to predict for
    :param result_path: the path to store the predicted result dataset
    :param model: quality classifier name to apply. It's "gpt3" in default.
        You can use one of ["gpt3", "chinese", "code"] we provided, or you
        can set it to the path to your own model trained using the train.py
        tool
    :param tokenizer: what tokenizer to use to tokenize texts. It's None in
        default, which means using the standard Tokenizer of PySpark. You can
        use one of ["zh.sp.model", "code.sp.model"] we provided, or you can
        set it to the path to your own sentencepiece model. NOTE: this value
        is overridden when ``model`` is one of the built-in names (see below)
    :param keep_method: the method to label should_keep field for each sample.
        It's "gpt3" in default. Should be one of ["gpt3", "label"]. Also
        overridden for built-in model names
    :param text_key: the field key name to hold texts to be classified. It's
        "text" in default
    :param overall_stats: whether to output an overall stats report on
        predicted document scores. It's False in default
    :return:
        None if overall_stats is False;
        a pandas DataFrame of descriptive statistics (count/mean/std/
        quantiles, as produced by ``DataFrame.describe``) over the predicted
        doc_score column if overall_stats is True
    """
    # Built-in models were trained with fixed tokenizer/keep_method
    # pairings, so for those names we force the matching settings and
    # ignore whatever the caller passed — predictions would be wrong
    # with a mismatched tokenizer.
    builtin_settings = {
        "gpt3": (None, "gpt3"),
        "chinese": ("zh.sp.model", "label"),
        "code": ("code.sp.model", "label"),
    }
    if model in builtin_settings:
        tokenizer, keep_method = builtin_settings[model]
    # Spark can fail to bind its driver service when the JVM is forced to
    # prefer IPv6; flip that flag off before creating the session.
    if "_JAVA_OPTIONS" in os.environ and "-Djava.net.preferIPv6Addresses=true" in os.environ["_JAVA_OPTIONS"]:
        os.environ["_JAVA_OPTIONS"] = os.environ["_JAVA_OPTIONS"].replace(
            "-Djava.net.preferIPv6Addresses=true", "-Djava.net.preferIPv6Addresses=false"
        )
    # initialize a spark session
    spark = init_spark()
    # load the quality classifier model
    model = prepare_model(model_name=model)
    # load dataset
    ds = load_dataset(spark, dataset_path, text_key=text_key)
    # start to predict
    pred = predict(model, ds, tokenizer=tokenizer, keep_method=keep_method)
    # export prediction result to specific path
    export_result(pred, result_path)
    if overall_stats:
        # generate overall statistics on doc scores
        overall = pred.select("doc_score").toPandas().describe(include="all")
        # export the stats report both as CSV and as a markdown table
        overall.to_csv(os.path.join(result_path, "overall.csv"))
        overall.to_markdown(os.path.join(result_path, "overall.md"))
        return overall
# CLI entry point: python-fire turns each predict_score parameter into a
# command-line flag (e.g. --model, --tokenizer, --overall_stats).
if __name__ == "__main__":
    fire.Fire(predict_score)