import json
import random
import re
from typing import Dict, Optional
from loguru import logger
from pydantic import PositiveInt
from data_juicer.utils.lazy_loader import LazyLoader
from data_juicer.utils.model_utils import (
get_model,
prepare_model,
update_sampling_params,
)
from data_juicer.utils.ray_utils import is_ray_mode
from ..base_op import OPERATORS, Mapper
torch = LazyLoader("torch")
vllm = LazyLoader("vllm")
rouge = LazyLoader("rouge")
OP_NAME = "generate_qa_from_examples_mapper"
# TODO: Extend LLM-based OPs into API-based implementation.
@OPERATORS.register_module(OP_NAME)
class GenerateQAFromExamplesMapper(Mapper):
"""Generates question and answer pairs from examples using a Hugging Face model.
This operator generates QA pairs based on provided seed examples. The number of
generated samples is determined by the length of the empty dataset configured in the
YAML file. The operator uses a Hugging Face model to generate new QA pairs, which are
then filtered based on their similarity to the seed examples. Samples with a similarity
score below the specified threshold are kept. The similarity is computed using the
ROUGE-L metric. The operator requires a seed file in chatml format, which provides the
initial QA examples. The generated QA pairs must follow specific formatting rules, such
as maintaining the same format as the input examples and ensuring that questions and
answers are paired correctly."""
    # The default prompts below are written in Chinese. In English, the system
    # prompt roughly instructs the model to: study the input/output of several
    # example data points, infer the underlying pattern, and write one new
    # 【问题】(question) / 【回答】(answer) pair that (1) differs from the input
    # pairs while keeping the same format, (2) need not stay on the same topic
    # but must be answered correctly, (3) may be multi-turn like the examples,
    # and (4) always places each question before its answer.
    DEFAULT_SYSTEM_PROMPT = (
        "请你仔细观察多个示例数据的输入和输出，按照你的理解，总结出相应规矩，然后写出一个新的【问题】和【回答】。"
        "注意，新生成的【问题】和【回答】需要满足如下要求：\n"
        "1. 生成的【问题】和【回答】不能与输入的【问题】和【回答】一致，但是需要保持格式相同。\n"
        "2. 生成的【问题】不一定要局限于输入【问题】的话题或领域，生成的【回答】需要正确回答生成的【问题】。\n"
        "3. 提供的【问题】和【回答】可能是多轮对话，生成的【问题】和【回答】也可以是多轮，但是需要保持格式相同。\n"
        "4. 生成的【问题】和【回答】必须成对出现，而且【问题】需要在【回答】之前。\n"
    )
    DEFAULT_INPUT_TEMPLATE = "{}"
    # "Here is one example data point:" followed by its formatted QA pairs.
    DEFAULT_EXAMPLE_TEMPLATE = "\n如下是一条示例数据：\n{}"
    # 【问题】/【回答】 mark the question / answer of each QA pair.
    DEFAULT_QA_PAIR_TEMPLATE = "【问题】\n{}\n【回答】\n{}\n"
    DEFAULT_OUTPUT_PATTERN = r"【问题】(.*?)【回答】(.*?)(?=【问题】|$)"
_accelerator = "cuda"
def __init__(
self,
hf_model: str = "Qwen/Qwen2.5-7B-Instruct",
*,
seed_file: str = "",
example_num: PositiveInt = 3,
similarity_threshold: float = 0.7,
system_prompt: Optional[str] = None,
input_template: Optional[str] = None,
example_template: Optional[str] = None,
qa_pair_template: Optional[str] = None,
output_pattern: Optional[str] = None,
enable_vllm: bool = False,
model_params: Optional[Dict] = None,
sampling_params: Optional[Dict] = None,
**kwargs,
):
"""
Initialization method.
:param hf_model: Huggingface model ID.
:param seed_file: Path to the seed file in chatml format.
        :param example_num: The number of selected examples. This many
            examples are randomly sampled from `seed_file` and inserted
            into the prompt as QA examples.
        :param similarity_threshold: The similarity score threshold between
            the generated samples and the seed examples. Ranges from 0 to 1.
            Generated samples with a similarity score below this threshold
            are kept.
:param system_prompt: System prompt for guiding the generation task.
:param input_template: Template for building the input prompt. It must
include one placeholder '{}', which will be replaced by
`example_num` formatted examples defined by `example_template`.
:param example_template: Template for formatting one QA example. It
must include one placeholder '{}', which will be replaced by one
formatted qa_pair.
:param qa_pair_template: Template for formatting a single QA pair
within each example. Must include two placeholders '{}' for the
question and answer.
:param output_pattern: Regular expression pattern to extract questions
and answers from model response.
:param enable_vllm: Whether to use vllm for inference acceleration.
:param model_params: Parameters for initializing the model.
:param sampling_params: Sampling parameters for text generation.
            e.g. {'temperature': 0.9, 'top_p': 0.95}
:param kwargs: Extra keyword arguments.
"""
super().__init__(**kwargs)
if not seed_file:
            raise ValueError(
                "Please provide `seed_file` in chatml format. "
                "Example: data-juicer/demos/data/demo-dataset-chatml.jsonl"
            )
self.seed_file = seed_file
self.example_num = example_num
self.similarity_threshold = similarity_threshold
self.similarity_type = "rouge_l"
self.system_prompt = system_prompt or self.DEFAULT_SYSTEM_PROMPT
self.input_template = input_template or self.DEFAULT_INPUT_TEMPLATE
self.example_template = example_template or self.DEFAULT_EXAMPLE_TEMPLATE # noqa: E501
self.qa_pair_template = qa_pair_template or self.DEFAULT_QA_PAIR_TEMPLATE
self.output_pattern = output_pattern or self.DEFAULT_OUTPUT_PATTERN
self.enable_vllm = enable_vllm
model_params = model_params or {}
sampling_params = sampling_params or {}
sampling_params = update_sampling_params(sampling_params, hf_model, self.enable_vllm)
if enable_vllm:
if not is_ray_mode():
# cannot initialize vllm replicas on different GPUs
self.num_proc = 1
self.model_key = prepare_model(model_type="vllm", pretrained_model_name_or_path=hf_model, **model_params)
self.sampling_params = vllm.SamplingParams(**sampling_params)
else:
self.model_key = prepare_model(
model_type="huggingface", pretrained_model_name_or_path=hf_model, return_pipe=True, **model_params
)
self.sampling_params = sampling_params
self.seed_qa_samples = self._load_seed_qa_samples()
if len(self.seed_qa_samples) == 0:
raise ValueError("No QA data was parsed from the seed file!")
def _load_seed_qa_samples(self):
"""Load QA pairs from chatml format file."""
qa_samples = []
with open(self.seed_file, encoding="utf-8") as f:
lines = f.readlines()
for line in lines:
line = line.strip()
qa_pairs = self._parse_chatml_str(line)
if len(qa_pairs) > 0:
qa_samples.append(qa_pairs)
return qa_samples
def _sample_to_str(self, qa_sample):
return "\n".join(["\n".join(qa_pair) for qa_pair in qa_sample]) + "\n"
def _max_rouge_l_score(self, hypothesis, references):
r = rouge.Rouge()
max_score = 0.0
hyp_str = self._sample_to_str(hypothesis)
for reference in references:
ref_str = self._sample_to_str(reference)
scores = r.get_scores(hyp_str, ref_str)
rouge_l_score = scores[0]["rouge-l"]["f"]
if rouge_l_score > max_score:
max_score = rouge_l_score
return max_score
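    # Note on the third-party `rouge` package used above (assumed behavior,
    # stated for clarity): `Rouge().get_scores(hyp, ref)` returns one dict per
    # hypothesis, e.g. [{"rouge-1": {...}, "rouge-2": {...},
    # "rouge-l": {"r": ..., "p": ..., "f": ...}}]; `_max_rouge_l_score` takes
    # the ROUGE-L F1 value and maximizes it over all references.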
def _parse_chatml_str(self, sample_str):
user_input = None
assistant_output = None
qa_pairs = []
data = json.loads(sample_str)
for message in data["messages"]:
role = message["role"]
content = message["content"]
if role == "user":
user_input = content
elif role == "assistant":
assistant_output = content
qa_pairs.append((user_input, assistant_output))
return qa_pairs
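    # Illustrative example of one chatml-format seed line accepted by
    # `_parse_chatml_str` (contents are made up):
    #
    #   {"messages": [{"role": "user", "content": "What is 1 + 1?"},
    #                 {"role": "assistant", "content": "1 + 1 equals 2."}]}
    #
    # This line is parsed into a single QA pair:
    # [("What is 1 + 1?", "1 + 1 equals 2.")].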
def parse_output(self, raw_output):
logger.debug(raw_output)
output_qa_pairs = []
matches = re.findall(self.output_pattern, raw_output, re.DOTALL)
for match in matches:
question, answer = match
output_qa_pairs.append((question.strip(), answer.strip()))
return output_qa_pairs
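    # Illustrative example of what `parse_output` extracts with the default
    # output pattern: a model response of the form
    #
    #   【问题】
    #   <question 1>
    #   【回答】
    #   <answer 1>
    #   【问题】
    #   <question 2>
    #   【回答】
    #   <answer 2>
    #
    # yields [("<question 1>", "<answer 1>"), ("<question 2>", "<answer 2>")].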
def process_single(self, sample, rank=None):
model, _ = get_model(self.model_key, rank, self.use_cuda())
random_qa_samples = random.sample(self.seed_qa_samples, self.example_num)
input_prompt = self.build_input(random_qa_samples)
messages = [{"role": "system", "content": self.system_prompt}, {"role": "user", "content": input_prompt}]
if self.enable_vllm:
response = model.chat(messages, self.sampling_params)
output = response[0].outputs[0].text
else:
# model is pipe
response = model(messages, return_full_text=False, **self.sampling_params)
output = response[0]["generated_text"]
output_qa_pairs = self.parse_output(output)
if len(output_qa_pairs) == 0:
logger.warning("Parse model response error! " "No data generated for the current response!")
sample.update({self.query_key: "", self.response_key: "", self.history_key: self.empty_history()})
return sample
if self.similarity_type == "rouge_l":
sim_score = self._max_rouge_l_score(output_qa_pairs, random_qa_samples)
else:
raise ValueError(f'Not support similarity type "{self.similarity_type}"!')
if sim_score <= self.similarity_threshold:
query, response = output_qa_pairs[-1]
history = output_qa_pairs[:-1]
if len(history) == 0:
history = self.empty_history()
else:
query = response = ""
history = self.empty_history()
logger.info("Filter this generated sample due to similarity.")
sample.update({self.query_key: query, self.response_key: response, self.history_key: history})
return sample
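# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the operator; the seed file path and the
# empty sample below are illustrative assumptions). In practice the operator
# is configured in a Data-Juicer YAML recipe, but it can also be run directly:
#
#   op = GenerateQAFromExamplesMapper(
#       hf_model="Qwen/Qwen2.5-7B-Instruct",
#       seed_file="demos/data/demo-dataset-chatml.jsonl",  # chatml JSONL
#       example_num=3,
#       similarity_threshold=0.7,
#   )
#   new_sample = op.process_single({})
#   print(new_sample[op.query_key], new_sample[op.response_key])
# ---------------------------------------------------------------------------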