Source code for data_juicer.utils.webdataset_utils

import io
import json
from typing import Any, Dict, Optional, Union

import numpy as np
import PIL.Image

from data_juicer.utils.mm_utils import load_audio

_VIDEO_EXTENTIONS = ["mp4", "mov", "avi", "mkv"]
_AUDIO_EXTENTIONS = ["flac", "mp3", "sox", "wav", "m4a", "ogg", "wma"]
_IMAGE_EXTENTIONS = ["jpg", "png", "ppm", "pgm", "pbm", "pnm"]


[docs] def read_file_as_bytes(file_path): with open(file_path, "rb") as f: return f.read()
def _load_image(value, format="PIL"): import numpy as np import PIL.Image if format == "PIL": return PIL.Image.open(io.BytesIO(value)) else: return np.asarray(PIL.Image.open(io.BytesIO(value))) def _custom_default_decoder(sample: Dict[str, Any], format: Optional[Union[bool, str]] = True): """A custom decoder for webdataset. Support multiple images list decoding. This handles common file extensions: .txt, .cls, .cls2, .jpg, .png, .json, .npy, .mp, .pt, .pth, .pickle, .pkl. These are the most common extensions used in webdataset. For other extensions, users can provide their own decoder. Args: sample: sample, modified in place """ sample = dict(sample) for key, value in sample.items(): extension = key.split(".")[-1] if key.startswith("__"): continue elif extension in ["txt", "text"]: sample[key] = value.decode("utf-8") elif extension in ["cls", "cls2"]: sample[key] = int(value.decode("utf-8")) elif extension in _IMAGE_EXTENTIONS: sample[key] = _load_image(value, format) elif extension in [s + "s" for s in _IMAGE_EXTENTIONS]: import pickle value = pickle.loads(value) sample[key] = [_load_image(v, format) for v in value] elif extension == "json": sample[key] = json.loads(value) elif extension == "npy": import numpy as np sample[key] = np.load(io.BytesIO(value)) elif extension == "mp": import msgpack sample[key] = msgpack.unpackb(value, raw=False) elif extension in ["pt", "pth"]: import torch sample[key] = torch.load(io.BytesIO(value)) elif extension in ["pickle", "pkl"]: import pickle sample[key] = pickle.loads(value) elif extension in _AUDIO_EXTENTIONS: sample[key] = load_audio(value) elif extension in [s + "s" for s in _AUDIO_EXTENTIONS]: import pickle sample[key] = [load_audio(v) for v in pickle.loads(value)] elif extension in _VIDEO_EXTENTIONS: import pickle value = pickle.loads(value) sample[key] = [_load_image(frame, format) for frame in value] elif extension in [s + "s" for s in _VIDEO_EXTENTIONS]: import pickle videos_frames_list = pickle.loads(value) videos_frames_decode = [] for video_frames in videos_frames_list: videos_frames_decode.append([_load_image(frame) for frame in video_frames]) # list in list sample[key] = videos_frames_decode return sample def _encode_image(value, extension): from ray.data._internal.datasource.webdataset_datasource import extension_to_format if isinstance(value, np.ndarray): value = PIL.Image.fromarray(value) elif isinstance(value, bytes): return value elif isinstance(value, str): return read_file_as_bytes(value) assert isinstance(value, PIL.Image.Image) stream = io.BytesIO() value.save(stream, format=extension_to_format.get(extension.lower(), extension)) return stream.getvalue() def _encode_audio(value): if isinstance(value, str): return read_file_as_bytes(value) elif isinstance(value, bytes): return value assert isinstance(value, bytes), f"value should be a bytes, got {type(value)}" return value def _custom_default_encoder(sample: Dict[str, Any], format: Optional[Union[str, bool]] = True): """A custom encoder for webdataset. In addition to the original encoding, it also supports encode image lists and byte type images. This handles common file extensions: .txt, .cls, .cls2, .jpg, .png, .json, .npy, .mp, .pt, .pth, .pickle, .pkl, .jpgs (images list), .jpegs (images list), .pngs (images list) .mp3 (audio) .mp3s (audios list) .mp4 (video frames list) .mp4s (multi videos frames list) and so on. Please note that the .mp4s extension is used to encode multi videos frames list, the data format should be list of list of frames: [ [video1_frame1, video1_frame2, ...], # video1 frames path or bytes [video2_frame1, video2_frame2, ...], # video2 frames path or bytes ... ] These are the most common extensions used in webdataset. For other extensions, users can provide their own encoder. Args: sample (Dict[str, Any]): sample """ sample = dict(sample) for key, value in sample.items(): extension = key.split(".")[-1] if key.startswith("__"): continue elif extension in ["txt"]: if isinstance(value, list): sample[key] = [v.encode("utf-8") for v in value] else: sample[key] = value.encode("utf-8") elif extension in ["cls", "cls2"]: sample[key] = str(value).encode("utf-8") elif extension in _IMAGE_EXTENTIONS: sample[key] = _encode_image(value, extension) elif extension in [s + "s" for s in _IMAGE_EXTENTIONS]: import pickle extension = extension.rstrip("s") sample[key] = pickle.dumps([_encode_image(v, extension) for v in value]) elif extension == "json": sample[key] = json.dumps(value).encode("utf-8") elif extension == "npy": import numpy as np stream = io.BytesIO() np.save(stream, value) sample[key] = stream.getvalue() elif extension == "mp": import msgpack sample[key] = msgpack.dumps(value) elif extension in ["pt", "pth"]: import torch stream = io.BytesIO() torch.save(value, stream) sample[key] = stream.getvalue() elif extension in ["pickle", "pkl"]: import pickle stream = io.BytesIO() pickle.dump(value, stream) sample[key] = stream.getvalue() elif extension in _AUDIO_EXTENTIONS: sample[key] = _encode_audio(value) elif extension in [s + "s" for s in _AUDIO_EXTENTIONS]: import pickle extension = extension.rstrip("s") sample[key] = pickle.dumps([_encode_audio(v) for v in value]) elif extension in _VIDEO_EXTENTIONS: import pickle extension = "jpg" sample[key] = pickle.dumps([_encode_image(frame, extension) for frame in value]) elif extension in [s + "s" for s in _VIDEO_EXTENTIONS]: import pickle extension = "jpg" videos_frames_list = value videos_frames_decode = [] for video_frames in videos_frames_list: cur_decode_frames = [] for frame in video_frames: if isinstance(frame, str): frame = _encode_image(frame, extension) assert isinstance(frame, bytes), "frame should be string path or bytes" cur_decode_frames.append(frame) videos_frames_decode.append(cur_decode_frames) # list in list sample[key] = pickle.dumps(videos_frames_decode) return sample
[docs] def reconstruct_custom_webdataset_format(samples, field_mapping: Optional[Dict[str, str]] = None): """ Reconstruct the original dataset to the WebDataset format. For all keys, they can be specified by `field_mapping` argument, which is a dict mapping from the target field key in the result format to the source field key in the original samples. :param samples: the input samples batch to be reconstructed :param field_mapping: the field mapping to construct the left fields. """ if field_mapping is None: field_mapping = {} assert isinstance(field_mapping, dict) # not specified -- return the original samples if len(field_mapping) == 0: return samples # construct the left fields reconstructed_sample = {} for tgt_field, src_field in field_mapping.items(): assert isinstance(src_field, str) or isinstance(src_field, list) if isinstance(src_field, str): reconstructed_sample[tgt_field] = samples[src_field] elif isinstance(src_field, list): reconstructed_sample[tgt_field] = {src_field_item: samples[src_field_item] for src_field_item in src_field} return reconstructed_sample