Source code for data_juicer.utils.fingerprint_utils

from typing import Any, Dict, List, Union

import dill
import xxhash
from datasets.fingerprint import (
    _CACHING_ENABLED,
    fingerprint_warnings,
    format_kwargs_for_fingerprint,
    format_transform_for_fingerprint,
    generate_random_fingerprint,
    validate_fingerprint,
)
from loguru import logger


class Hasher:
    """Incremental xxh64 hasher that can digest arbitrary python objects."""

    # Maps an exact type to a ``handler(cls, value) -> str`` used by ``hash``.
    dispatch: Dict = {}

    def __init__(self):
        """Start with an empty xxh64 state."""
        self.m = xxhash.xxh64()

    @classmethod
    def hash_bytes(cls, value: Union[bytes, List[bytes]]) -> str:
        """
        Return the xxh64 hex digest of one bytes object or of a sequence of
        bytes objects fed to the hasher in order.
        """
        chunks = [value] if isinstance(value, bytes) else value
        digest = xxhash.xxh64()
        for chunk in chunks:
            digest.update(chunk)
        return digest.hexdigest()

    @classmethod
    def _find_op_owner(cls, value):
        """
        Locate the object behind ``value`` that offers a callable
        ``_fingerprint_bytes``.

        ``value`` itself is inspected first (bound-method case), then up to
        ten layers of its ``__wrapped__`` chain, so stacked decorators around
        a bound method are handled as well.

        :return: ``(owner, func_name)`` for the first layer whose
            ``__self__`` qualifies, otherwise ``(None, None)``.
        """
        layer = value
        # Step 0 looks at ``value`` as given; steps 1..10 each peel one
        # ``__wrapped__`` layer. The fixed bound guards against cyclic chains.
        for step in range(11):
            if step:
                layer = getattr(layer, "__wrapped__", None)
                if layer is None:
                    break
            owner = getattr(layer, "__self__", None)
            if owner is not None and callable(getattr(owner, "_fingerprint_bytes", None)):
                func_name = getattr(layer, "__name__", getattr(layer, "__qualname__", ""))
                return owner, func_name
        return None, None

    @classmethod
    def hash_default(cls, value: Any) -> str:
        """
        Fallback hashing for values whose type has no ``dispatch`` entry.

        Three sources of bytes are tried in order:

        1. ``value._fingerprint_bytes()`` when the object provides it, so the
           object decides what goes into the cache key.
        2. For a bound method / wrapped function whose owner provides
           ``_fingerprint_bytes``: the owner's fingerprint bytes followed by
           the dill-dumped function name. This avoids dill-dumping the bound
           method, which would serialize the whole owner again.
        3. ``dill.dumps(value)`` for everything else.
        """
        own_fingerprint = getattr(value, "_fingerprint_bytes", None)
        if callable(own_fingerprint):
            payload = own_fingerprint()
        else:
            owner, method_name = cls._find_op_owner(value)
            if owner is None:
                payload = dill.dumps(value)
            else:
                payload = owner._fingerprint_bytes() + dill.dumps(method_name)
        return cls.hash_bytes(payload)

    @classmethod
    def hash(cls, value: Any) -> str:
        """
        Hash ``value`` with the handler registered for its exact type in
        ``dispatch``, or with ``hash_default`` when there is none.
        """
        value_type = type(value)
        if value_type not in cls.dispatch:
            return cls.hash_default(value)
        return cls.dispatch[value_type](cls, value)

    def update(self, value: Any) -> None:
        """
        Mix ``value`` into the running state: a header built from its type
        is fed first, then the hex digest returned by ``hash``.
        """
        type_header = f"=={type(value)}=="
        # Hash before touching the state so a failure leaves ``self.m`` as is.
        value_digest = self.hash(value)
        self.m.update(type_header.encode("utf8"))
        self.m.update(value_digest.encode("utf-8"))

    def hexdigest(self) -> str:
        """Return the hex digest of everything fed in so far."""
        return self.m.hexdigest()
def _log_unhashable(subject):
    """
    Report that ``subject`` could not be hashed and a random hash is used.

    The verbose warning is emitted once per process; the flag is kept in
    ``fingerprint_warnings`` under
    ``"update_fingerprint_transform_hash_failed"``. Later failures, and all
    failures while caching is disabled, are logged at info level.

    :param subject: Text that starts the message, naming the transform or
        parameter that failed.
    """
    # Ask ``datasets`` for the current state instead of relying on a
    # ``_CACHING_ENABLED`` value copied at import time, which goes stale once
    # ``disable_caching()`` / ``enable_caching()`` is called.
    from datasets.fingerprint import is_caching_enabled

    warned_key = "update_fingerprint_transform_hash_failed"
    # Messages are assembled from adjacent literals; backslash continuations
    # inside a string would embed the source indentation in the log output.
    headline = f"{subject} couldn't be hashed properly, a random hash was used instead."

    if not is_caching_enabled():
        logger.info(headline + " This doesn't affect caching since it's disabled.")
        return

    if fingerprint_warnings.get(warned_key, False):
        logger.info(headline)
        return

    logger.warning(
        headline + " Make sure your transforms and parameters are "
        "serializable with pickle or dill for the dataset fingerprinting "
        "and caching to work. If you reuse this transform, the caching "
        "mechanism will consider it to be different from the previous "
        "calls and recompute everything. This warning is only showed "
        "once. Subsequent hashing failures won't be showed."
    )
    fingerprint_warnings[warned_key] = True


def update_fingerprint(fingerprint, transform, transform_args):
    """
    Combining various objects to update the fingerprint.

    The previous fingerprint, the transform, and every entry of
    ``transform_args`` (visited in sorted key order, so the result does not
    depend on dict ordering) are fed into a ``Hasher``.

    :param fingerprint: Fingerprint of the dataset before the transform.
    :param transform: Object identifying the transform being applied.
    :param transform_args: Mapping of argument name to argument value.
    :return: The hex digest of the combined hash, or the result of
        ``generate_random_fingerprint()`` if the transform or any argument
        value fails to hash.
    """
    hasher = Hasher()
    hasher.update(fingerprint)

    # ``except Exception`` rather than a bare ``except``: pickle/dill can
    # raise many error types, but KeyboardInterrupt / SystemExit must still
    # propagate instead of being turned into a random fingerprint.
    try:
        hasher.update(transform)
    except Exception:
        _log_unhashable(f"Transform {transform}")
        return generate_random_fingerprint()

    for key in sorted(transform_args):
        hasher.update(key)
        try:
            hasher.update(transform_args[key])
        except Exception:
            _log_unhashable(f"Parameter '{key}'={transform_args[key]} of the transform {transform}")
            return generate_random_fingerprint()

    return hasher.hexdigest()
def generate_fingerprint(ds, *args, **kwargs):
    """
    Generate new fingerprints by using various kwargs of the dataset.

    :param ds: Dataset object; its ``_map_single`` and ``_fingerprint``
        attributes are read.
    :param args: Optional positional arguments. Only the first one is used;
        it is recorded under the ``"function"`` key.
    :param kwargs: Extra keyword arguments merged into the hashed mapping,
        overriding ``"shard"`` / ``"function"`` on a key clash.
    :return: The new fingerprint, after it passed ``validate_fingerprint``.
    """
    # The dataset itself always takes part; the mapped function only when
    # one was passed positionally.
    hashed_kwargs = {"shard": ds}
    if args:
        hashed_kwargs["function"] = args[0]
    hashed_kwargs.update(kwargs)

    # we create a unique hash from the function,
    # current dataset file and the mapping args
    map_single = ds._map_single
    transform = format_transform_for_fingerprint(map_single)
    fingerprint_kwargs = format_kwargs_for_fingerprint(map_single, (), hashed_kwargs)
    fingerprint_kwargs["fingerprint_name"] = "new_fingerprint"

    result = update_fingerprint(ds._fingerprint, transform, fingerprint_kwargs)
    validate_fingerprint(result)
    return result