import io
import json
from typing import Any, Dict, Optional, Union
import numpy as np
import PIL.Image
from data_juicer.utils.mm_utils import load_audio
# File extensions recognised by the custom webdataset encoder/decoder below.
# A key whose extension is one of these followed by a trailing "s"
# (e.g. "jpgs", "mp3s", "mp4s") holds a pickled list of such items.
# NOTE: the "EXTENTIONS" spelling is kept because other modules may import
# these names.
_VIDEO_EXTENTIONS = ["mp4", "mov", "avi", "mkv"]
_AUDIO_EXTENTIONS = ["flac", "mp3", "sox", "wav", "m4a", "ogg", "wma"]
# "jpeg" is listed so that .jpeg / .jpegs keys are handled like .jpg / .jpgs.
_IMAGE_EXTENTIONS = ["jpg", "jpeg", "png", "ppm", "pgm", "pbm", "pnm"]
def read_file_as_bytes(file_path):
    """Read a whole file in binary mode and return its contents.

    Args:
        file_path: path of the file to read (anything accepted by ``open``).

    Returns:
        The file's raw contents as ``bytes``.
    """
    with open(file_path, "rb") as f:
        return f.read()
def _load_image(value, format="PIL"):
    """Decode encoded image bytes.

    Args:
        value: encoded image bytes.
        format: ``"PIL"`` returns the opened ``PIL.Image.Image``; any other
            value returns ``np.asarray`` of the opened image.

    Returns:
        A PIL image or a numpy array, depending on ``format``.
    """
    import PIL.Image
    import numpy as np

    decoded = PIL.Image.open(io.BytesIO(value))
    if format != "PIL":
        decoded = np.asarray(decoded)
    return decoded
def _custom_default_decoder(sample: Dict[str, Any], format: Optional[Union[bool, str]] = True):
    """A custom decoder for webdataset that also decodes lists of images, audios and videos.

    Each key is dispatched on its extension (the text after the last ".").
    Keys starting with "__" (e.g. ``__key__``) and keys with unrecognised
    extensions are left unchanged. Handled extensions:

    - .txt, .text: UTF-8 decoded ``str``
    - .cls, .cls2: ``int``
    - image extensions (.jpg, .png, ...): one image
    - image extension + "s" (.jpgs, .pngs, ...): pickled list of encoded images
    - .json, .npy, .mp (msgpack), .pt/.pth (torch), .pickle/.pkl
    - audio extensions (.mp3, .wav, ...): decoded with ``load_audio``
    - audio extension + "s" (.mp3s, ...): pickled list of encoded audios
    - video extensions (.mp4, ...): pickled list of encoded frames
    - video extension + "s" (.mp4s, ...): pickled list of videos, each a
      list of encoded frames (decoded to a list of lists)

    For other extensions, users can provide their own decoder.

    Security note: several branches call ``pickle.loads`` / ``torch.load``,
    which can execute arbitrary code; only decode trusted datasets.

    Args:
        sample: mapping of keys to encoded bytes. It is NOT modified; a
            decoded shallow copy is returned.
        format: passed to ``_load_image`` for every image and video frame.
            ``"PIL"`` yields PIL images; any other value (including the
            default ``True``) yields numpy arrays.

    Returns:
        A new dict with the decoded values.
    """
    import pickle

    sample = dict(sample)
    for key, value in sample.items():
        if key.startswith("__"):
            continue
        extension = key.split(".")[-1]
        # A trailing "s" marks a pickled list of items of the base extension,
        # e.g. "jpgs" -> list of "jpg" images, "mp4s" -> list of videos.
        # None never matches any extension list below.
        base_ext = extension[:-1] if extension.endswith("s") else None
        if extension in ["txt", "text"]:
            sample[key] = value.decode("utf-8")
        elif extension in ["cls", "cls2"]:
            sample[key] = int(value.decode("utf-8"))
        elif extension in _IMAGE_EXTENTIONS:
            sample[key] = _load_image(value, format)
        elif base_ext in _IMAGE_EXTENTIONS:
            sample[key] = [_load_image(v, format) for v in pickle.loads(value)]
        elif extension == "json":
            sample[key] = json.loads(value)
        elif extension == "npy":
            sample[key] = np.load(io.BytesIO(value))
        elif extension == "mp":
            import msgpack

            sample[key] = msgpack.unpackb(value, raw=False)
        elif extension in ["pt", "pth"]:
            import torch

            sample[key] = torch.load(io.BytesIO(value))
        elif extension in ["pickle", "pkl"]:
            sample[key] = pickle.loads(value)
        elif extension in _AUDIO_EXTENTIONS:
            sample[key] = load_audio(value)
        elif base_ext in _AUDIO_EXTENTIONS:
            sample[key] = [load_audio(v) for v in pickle.loads(value)]
        elif extension in _VIDEO_EXTENTIONS:
            sample[key] = [_load_image(frame, format) for frame in pickle.loads(value)]
        elif base_ext in _VIDEO_EXTENTIONS:
            # list of videos, each a list of frames; frames honour ``format``
            # exactly like the single-video branch above
            sample[key] = [
                [_load_image(frame, format) for frame in video_frames]
                for video_frames in pickle.loads(value)
            ]
    return sample
def _encode_image(value, extension):
    """Encode one image to bytes.

    Args:
        value: one of
            - ``bytes``: assumed already encoded; returned unchanged.
            - ``str``: a file path; the file's raw bytes are returned as-is
              (no re-encoding).
            - ``np.ndarray``: converted with ``PIL.Image.fromarray`` and encoded.
            - ``PIL.Image.Image``: encoded.
        extension: file extension (e.g. "jpg", "png"). The PIL save format is
            looked up in Ray's ``extension_to_format`` table by the lower-cased
            extension, falling back to ``extension`` itself.

    Returns:
        The encoded image bytes.

    Raises:
        AssertionError: if ``value`` is not one of the supported types.
    """
    # Pass-through cases first: they need neither PIL nor Ray internals.
    if isinstance(value, bytes):
        return value
    if isinstance(value, str):
        return read_file_as_bytes(value)

    from ray.data._internal.datasource.webdataset_datasource import extension_to_format

    if isinstance(value, np.ndarray):
        value = PIL.Image.fromarray(value)
    assert isinstance(value, PIL.Image.Image), f"unsupported image type: {type(value)}"
    stream = io.BytesIO()
    value.save(stream, format=extension_to_format.get(extension.lower(), extension))
    return stream.getvalue()
def _encode_audio(value):
    """Return the encoded bytes of one audio item.

    Args:
        value: a file path (``str``), whose raw bytes are read and returned,
            or already-encoded ``bytes``, returned unchanged.

    Returns:
        The audio bytes.

    Raises:
        AssertionError: if ``value`` is neither ``str`` nor ``bytes``.
    """
    if isinstance(value, str):
        return read_file_as_bytes(value)
    # Single check covers both the bytes pass-through and the type error.
    assert isinstance(value, bytes), f"value should be a str path or bytes, got {type(value)}"
    return value
def _custom_default_encoder(sample: Dict[str, Any], format: Optional[Union[str, bool]] = True):
    """A custom encoder for webdataset.

    In addition to the common webdataset encodings, it supports lists of
    images, byte/path images, audios and (multi-)video frame lists. Each key
    is dispatched on its extension (the text after the last "."); keys
    starting with "__" and keys with unrecognised extensions are left
    unchanged. Handled extensions:

    - .txt, .text: ``str`` (or list of ``str``) -> UTF-8 bytes (or list of bytes)
    - .cls, .cls2: ``str(value)`` as UTF-8 bytes
    - image extensions (.jpg, .png, ...): one image -> encoded bytes
    - image extension + "s" (.jpgs, .pngs, ...): list of images -> pickled
      list of encoded bytes
    - .json, .npy, .mp (msgpack), .pt/.pth (torch), .pickle/.pkl
    - audio extensions (.mp3, .wav, ...): path or bytes -> bytes
    - audio extension + "s" (.mp3s, ...): list of audios -> pickled list of bytes
    - video extensions (.mp4, ...): list of frames -> pickled list of frame bytes
    - video extension + "s" (.mp4s, ...): list of videos, each a list of
      frames -> pickled list of lists of frame bytes::

        [
            [video1_frame1, video1_frame2, ...],
            [video2_frame1, video2_frame2, ...],
            ...
        ]

    Images and video frames may be ``bytes`` (kept as-is), a ``str`` file
    path (file bytes read as-is), a numpy array or a ``PIL.Image.Image``
    (encoded by ``_encode_image``; video frames use the "jpg" format).

    For other extensions, users can provide their own encoder.

    Args:
        sample (Dict[str, Any]): the sample to encode. It is NOT modified;
            an encoded shallow copy is returned.
        format: unused; accepted for signature symmetry with the decoder.

    Returns:
        A new dict with the encoded values.
    """
    import pickle

    def encode_frame(frame):
        # Already-encoded frames are kept verbatim; paths, arrays and PIL
        # images go through _encode_image as JPEG.
        return frame if isinstance(frame, bytes) else _encode_image(frame, "jpg")

    sample = dict(sample)
    for key, value in sample.items():
        if key.startswith("__"):
            continue
        extension = key.split(".")[-1]
        # A trailing "s" marks a list of items of the base extension, e.g.
        # "jpgs" -> list of "jpg" images. None never matches any list below.
        base_ext = extension[:-1] if extension.endswith("s") else None
        if extension in ["txt", "text"]:
            if isinstance(value, list):
                sample[key] = [v.encode("utf-8") for v in value]
            else:
                sample[key] = value.encode("utf-8")
        elif extension in ["cls", "cls2"]:
            sample[key] = str(value).encode("utf-8")
        elif extension in _IMAGE_EXTENTIONS:
            sample[key] = _encode_image(value, extension)
        elif base_ext in _IMAGE_EXTENTIONS:
            sample[key] = pickle.dumps([_encode_image(v, base_ext) for v in value])
        elif extension == "json":
            sample[key] = json.dumps(value).encode("utf-8")
        elif extension == "npy":
            stream = io.BytesIO()
            np.save(stream, value)
            sample[key] = stream.getvalue()
        elif extension == "mp":
            import msgpack

            sample[key] = msgpack.dumps(value)
        elif extension in ["pt", "pth"]:
            import torch

            stream = io.BytesIO()
            torch.save(value, stream)
            sample[key] = stream.getvalue()
        elif extension in ["pickle", "pkl"]:
            stream = io.BytesIO()
            pickle.dump(value, stream)
            sample[key] = stream.getvalue()
        elif extension in _AUDIO_EXTENTIONS:
            sample[key] = _encode_audio(value)
        elif base_ext in _AUDIO_EXTENTIONS:
            sample[key] = pickle.dumps([_encode_audio(v) for v in value])
        elif extension in _VIDEO_EXTENTIONS:
            sample[key] = pickle.dumps([encode_frame(frame) for frame in value])
        elif base_ext in _VIDEO_EXTENTIONS:
            # list in list: one inner list of encoded frames per video; every
            # frame type accepted by the single-video branch is accepted here.
            sample[key] = pickle.dumps(
                [[encode_frame(frame) for frame in video_frames] for video_frames in value]
            )
    return sample