Source code for data_juicer.ops.deduplicator.document_line_deduplicator
import hashlib
import re
import string
from collections import defaultdict
from typing import Dict, Set
from data_juicer.utils.constant import HashKeys
from ..base_op import OPERATORS, Deduplicator
BRACKETS_ONLY = frozenset("{}[]();")
LATEX_ENV_RE = re.compile(r"^\s*\\(begin|end)\{")
HTML_TAG_RE = re.compile(r"^\s*</?[a-z][a-z0-9]*[^>]*>?\s*$", re.IGNORECASE)
SPECIAL_CHAR_RE = re.compile(rf"\s+|\d+|[{re.escape(string.punctuation)}]")
[docs]
@OPERATORS.register_module("document_line_deduplicator")
class DocumentLineDeduplicator(Deduplicator):
"""Deduplicates at the line level across documents.
This operator identifies lines that appear in many documents (boilerplate
text, copyright notices, navigation bars, etc.) and removes them. It works
in two phases:
1. **compute_hash** â splits each document into lines, applies configurable
skip rules, and computes an MD5 hash for every non-skipped line.
2. **process** â counts in how many *distinct* documents each line hash
appears. Lines whose document frequency exceeds
``frequency_threshold`` are removed from every document.
"""
[docs]
def __init__(
self,
frequency_threshold: int = 6,
lowercase: bool = False,
ignore_special_character: bool = False,
min_line_length: int = 2,
skip_brackets: bool = True,
skip_markdown_headers: bool = True,
skip_latex_env: bool = True,
skip_html_tags: bool = True,
*args,
**kwargs,
):
"""
Initialization method.
:param frequency_threshold: document-frequency threshold. Lines
appearing in **more than** this many documents are removed.
:param lowercase: whether to lower-case a line before hashing.
:param ignore_special_character: whether to strip whitespace, digits,
and punctuation before hashing.
:param min_line_length: lines whose stripped length is below this
value are skipped (never considered for dedup).
:param skip_brackets: skip lines consisting solely of bracket /
semicolon characters such as ``{ } [ ] ( ) ;``.
:param skip_markdown_headers: skip lines that start with ``#``
(Markdown headings).
:param skip_latex_env: skip LaTeX ``\\begin{âĻ}`` / ``\\end{âĻ}``
environment declarations.
:param skip_html_tags: skip lines that are pure HTML / XML tags.
:param args: extra args
:param kwargs: extra args
"""
super().__init__(*args, **kwargs)
self.frequency_threshold = frequency_threshold
self.lowercase = lowercase
self.ignore_special_character = ignore_special_character
self.min_line_length = min_line_length
self.skip_brackets = skip_brackets
self.skip_markdown_headers = skip_markdown_headers
self.skip_latex_env = skip_latex_env
self.skip_html_tags = skip_html_tags
# ------------------------------------------------------------------
# Skip-rule helpers
# ------------------------------------------------------------------
def _should_skip_line(self, line: str) -> bool:
"""Return True if *line* should be exempt from deduplication."""
stripped = line.strip()
if len(stripped) < self.min_line_length:
return True
if self.skip_brackets and all(ch in BRACKETS_ONLY for ch in stripped):
return True
if self.skip_markdown_headers and stripped.startswith("#"):
return True
if self.skip_latex_env and LATEX_ENV_RE.match(stripped):
return True
if self.skip_html_tags and HTML_TAG_RE.match(stripped):
return True
return False
# ------------------------------------------------------------------
# Phase 1 â per-document hash computation
# ------------------------------------------------------------------
[docs]
def compute_hash(self, sample):
"""
Compute per-line MD5 hashes for a single document.
Skipped lines receive an empty-string hash so that the list of
hashes stays aligned with the original lines.
:param sample: input sample
:return: sample with ``HashKeys.line_hashes`` populated.
"""
if HashKeys.line_hashes in sample:
return sample
text = sample[self.text_key]
lines = text.split("\n")
line_hashes = []
for line in lines:
if self._should_skip_line(line):
line_hashes.append("")
else:
norm = line
if self.lowercase:
norm = norm.lower()
if self.ignore_special_character:
norm = SPECIAL_CHAR_RE.sub("", norm)
md5 = hashlib.md5(norm.strip().encode("utf-8")).hexdigest()
line_hashes.append(md5)
sample[HashKeys.line_hashes] = line_hashes
return sample
# ------------------------------------------------------------------
# Phase 2 â global frequency counting & line removal
# ------------------------------------------------------------------
[docs]
def process(self, dataset, show_num=0):
"""
Remove high-frequency lines from the dataset.
:param dataset: input dataset (already hash-annotated).
:param show_num: number of traced duplicate pairs for inspection.
:return: (dataset, dup_pairs) where *dup_pairs* maps a line hash
to sample texts that contained it.
"""
if len(dataset) <= 1:
return dataset, {}
# --- count document frequency for each line hash ---------------
doc_freq: Dict[str, int] = defaultdict(int)
for row in dataset:
unique_hashes = {h for h in row[HashKeys.line_hashes] if h}
for h in unique_hashes:
doc_freq[h] += 1
high_freq: Set[str] = {h for h, cnt in doc_freq.items() if cnt > self.frequency_threshold}
if not high_freq:
# nothing to remove â drop the temporary column and return
dataset = dataset.remove_columns([HashKeys.line_hashes])
return dataset, {}
# --- optionally collect duplicate pairs for tracing ------------
dup_pairs: dict = {}
if show_num > 0:
sorted_hashes = sorted(
[(h, doc_freq[h]) for h in high_freq],
key=lambda x: x[1],
reverse=True,
)[:show_num]
dup_pairs = {h: [] for h, _ in sorted_hashes}
# --- remove high-frequency lines from every document -----------
def _remove_high_freq_lines(sample):
line_hashes = sample[HashKeys.line_hashes]
text = sample[self.text_key]
lines = text.split("\n")
kept_lines = []
for line, h in zip(lines, line_hashes):
if h and h in high_freq:
# collect tracing info
if show_num > 0 and h in dup_pairs:
if len(dup_pairs[h]) < 2:
dup_pairs[h].append(sample[self.text_key])
else:
kept_lines.append(line)
sample[self.text_key] = "\n".join(kept_lines)
return sample
dataset = dataset.map(
_remove_high_freq_lines,
num_proc=self.runtime_np() if show_num == 0 else 1,
load_from_cache_file=False if show_num > 0 else True,
)
# clean up temporary column
dataset = dataset.remove_columns([HashKeys.line_hashes])
return dataset, dup_pairs