Source code for data_juicer.ops.mapper.s3_download_file_mapper

import asyncio
import copy
import os
import os.path as osp
from typing import List, Union

from loguru import logger

from data_juicer.ops.base_op import OPERATORS, Mapper
from data_juicer.utils.lazy_loader import LazyLoader
from data_juicer.utils.s3_utils import get_aws_credentials

boto3 = LazyLoader("boto3")
botocore_exceptions = LazyLoader("botocore.exceptions")

OP_NAME = "s3_download_file_mapper"


@OPERATORS.register_module(OP_NAME)
class S3DownloadFileMapper(Mapper):
    """Mapper to download files from S3 to local files or load them into memory.

    This operator downloads files from S3 URLs (s3://...) or handles local files.
    It supports:

    - Downloading multiple files concurrently
    - Saving files to a specified directory or loading content into memory
    - Resume download functionality
    - S3 authentication with access keys
    - Custom S3 endpoints (for S3-compatible services like MinIO)

    The operator processes nested lists of URLs/paths, maintaining the original
    structure in the output.
    """

    _batched_op = True
    def __init__(
        self,
        download_field: str = None,
        save_dir: str = None,
        save_field: str = None,
        resume_download: bool = False,
        timeout: int = 30,
        max_concurrent: int = 10,
        # S3 credentials
        aws_access_key_id: str = None,
        aws_secret_access_key: str = None,
        aws_session_token: str = None,
        aws_region: str = None,
        endpoint_url: str = None,
        *args,
        **kwargs,
    ):
        """
        Initialization method.

        :param download_field: The field name to get the URL/path to download.
        :param save_dir: The directory to save downloaded files.
        :param save_field: The field name to save the downloaded file content.
        :param resume_download: Whether to resume download. If True, skip the
            sample if it exists.
        :param timeout: (Deprecated) Kept for backward compatibility, not used
            for S3 downloads.
        :param max_concurrent: Maximum concurrent downloads.
        :param aws_access_key_id: AWS access key ID for S3.
        :param aws_secret_access_key: AWS secret access key for S3.
        :param aws_session_token: AWS session token for S3 (optional).
        :param aws_region: AWS region for S3.
        :param endpoint_url: Custom S3 endpoint URL (for S3-compatible services).
        :param args: extra args
        :param kwargs: extra args
        """
        super().__init__(*args, **kwargs)
        self._init_parameters = self.remove_extra_parameters(locals())

        self.download_field = download_field
        self.save_dir = save_dir
        self.save_field = save_field
        self.resume_download = resume_download

        if not (self.save_dir or self.save_field):
            logger.warning(
                "Both `save_dir` and `save_field` are not specified. Use the default `image_bytes` key to "
                "save the downloaded contents."
            )
            self.save_field = self.image_bytes_key

        if self.save_dir:
            os.makedirs(self.save_dir, exist_ok=True)

        self.timeout = timeout
        self.max_concurrent = max_concurrent

        # Prepare config dict for get_aws_credentials
        ds_config = {}
        if aws_access_key_id:
            ds_config["aws_access_key_id"] = aws_access_key_id
        if aws_secret_access_key:
            ds_config["aws_secret_access_key"] = aws_secret_access_key
        if aws_session_token:
            ds_config["aws_session_token"] = aws_session_token
        if aws_region:
            ds_config["aws_region"] = aws_region
        if endpoint_url:
            ds_config["endpoint_url"] = endpoint_url

        # Get credentials with priority: environment variables > operator parameters
        (
            resolved_access_key_id,
            resolved_secret_access_key,
            resolved_session_token,
            resolved_region,
        ) = get_aws_credentials(ds_config)

        # Store S3 configuration (don't create client here to avoid serialization issues)
        self.s3_config = None
        self._s3_client = None

        if resolved_access_key_id and resolved_secret_access_key:
            self.s3_config = {
                "aws_access_key_id": resolved_access_key_id,
                "aws_secret_access_key": resolved_secret_access_key,
            }
            if resolved_session_token:
                self.s3_config["aws_session_token"] = resolved_session_token
            if resolved_region:
                self.s3_config["region_name"] = resolved_region
            if endpoint_url:
                self.s3_config["endpoint_url"] = endpoint_url
            logger.info(f"S3 configuration stored with endpoint: {endpoint_url or 'default'}")
        else:
            logger.info("No S3 credentials provided. S3 URLs will not be supported.")
    @property
    def s3_client(self):
        """Lazy initialization of S3 client to avoid serialization issues with Ray."""
        if self._s3_client is None and self.s3_config is not None:
            self._s3_client = boto3.client("s3", **self.s3_config)
            logger.debug("S3 client initialized (lazy)")
        return self._s3_client

    def _is_s3_url(self, url: str) -> bool:
        """Check if the URL is an S3 URL."""
        return url.startswith("s3://")

    def _parse_s3_url(self, s3_url: str):
        """Parse S3 URL into bucket and key.

        Example: s3://bucket-name/path/to/file.mp4 -> ('bucket-name', 'path/to/file.mp4')
        """
        if not s3_url.startswith("s3://"):
            raise ValueError(f"Invalid S3 URL: {s3_url}")

        parts = s3_url[5:].split("/", 1)
        bucket = parts[0]
        key = parts[1] if len(parts) > 1 else ""
        return bucket, key

    def _download_from_s3(self, s3_url: str, save_path: str = None, return_content: bool = False):
        """Download a file from S3.

        :param s3_url: S3 URL (s3://bucket/key)
        :param save_path: Local path to save the file
        :param return_content: Whether to return file content as bytes
        :return: (status, response, content, save_path)
        """
        if not self.s3_client:
            raise ValueError("S3 client not initialized. Please provide AWS credentials.")

        try:
            bucket, key = self._parse_s3_url(s3_url)

            if save_path:
                # Ensure parent directory exists
                save_dir = os.path.dirname(save_path)
                if save_dir:
                    os.makedirs(save_dir, exist_ok=True)

                # Download to file
                self.s3_client.download_file(bucket, key, save_path)
                logger.debug(f"Downloaded S3 file: {s3_url} -> {save_path}")

                # Read content if needed
                content = None
                if return_content:
                    with open(save_path, "rb") as f:
                        content = f.read()

                return "success", None, content, save_path
            elif return_content:
                # Download to memory
                response = self.s3_client.get_object(Bucket=bucket, Key=key)
                content = response["Body"].read()
                logger.debug(f"Downloaded S3 file to memory: {s3_url}")
                return "success", None, content, None
            else:
                return "success", None, None, None

        except botocore_exceptions.ClientError as e:
            error_msg = f"S3 download failed: {e}"
            logger.error(error_msg)
            return "failed", error_msg, None, None
        except Exception as e:
            error_msg = f"S3 download error: {e}"
            logger.error(error_msg)
            return "failed", error_msg, None, None
    async def download_files_async(self, urls, return_contents, save_dir=None, **kwargs):
        """Download files asynchronously from S3."""

        async def _download_file(
            semaphore: asyncio.Semaphore,
            idx: int,
            url: str,
            save_dir=None,
            return_content=False,
            **kwargs,
        ) -> tuple:
            async with semaphore:
                try:
                    status, response, content, save_path = "success", None, None, None

                    # Handle S3 URLs (synchronous operation in async context)
                    if self._is_s3_url(url):
                        if save_dir:
                            filename = os.path.basename(self._parse_s3_url(url)[1])
                            save_path = osp.join(save_dir, filename)

                            # Check if file exists and resume is enabled
                            if os.path.exists(save_path) and self.resume_download:
                                if return_content:
                                    with open(save_path, "rb") as f:
                                        content = f.read()
                                return idx, save_path, status, response, content

                        # Download from S3 (run in executor to avoid blocking)
                        loop = asyncio.get_event_loop()
                        status, response, content, save_path = await loop.run_in_executor(
                            None, self._download_from_s3, url, save_path, return_content
                        )
                        return idx, save_path, status, response, content

                    # Check for HTTP/HTTPS URLs - not supported
                    if url.startswith("http://") or url.startswith("https://"):
                        raise ValueError(
                            f"HTTP/HTTPS URLs are not supported. This mapper only supports "
                            f"S3 URLs (s3://...) and local files. Got: {url}"
                        )

                    # Handle local files
                    if return_content:
                        with open(url, "rb") as f:
                            content = f.read()
                    if save_dir:
                        save_path = url

                    return idx, save_path, status, response, content
                except Exception as e:
                    status = "failed"
                    response = str(e)
                    save_path = None
                    content = None
                    return idx, save_path, status, response, content

        semaphore = asyncio.Semaphore(self.max_concurrent)
        tasks = [
            _download_file(semaphore, idx, url, save_dir, return_contents[idx], **kwargs)
            for idx, url in enumerate(urls)
        ]
        results = await asyncio.gather(*tasks)
        results.sort(key=lambda x: x[0])
        return results
    def _flat_urls(self, nested_urls):
        """Flatten nested URLs while preserving structure information."""
        flat_urls = []
        structure_info = []  # save as (original index, sub index)

        for idx, urls in enumerate(nested_urls):
            if isinstance(urls, list):
                for sub_idx, url in enumerate(urls):
                    flat_urls.append(url)
                    structure_info.append((idx, sub_idx))
            else:
                flat_urls.append(urls)
                structure_info.append((idx, -1))  # -1 means single str element

        return flat_urls, structure_info

    def _create_path_struct(self, nested_urls, keep_failed_url=True) -> List[Union[str, List[str]]]:
        """Create path structure for output."""
        if keep_failed_url:
            reconstructed = copy.deepcopy(nested_urls)
        else:
            reconstructed = []
            for item in nested_urls:
                if isinstance(item, list):
                    reconstructed.append([None] * len(item))
                else:
                    reconstructed.append(None)

        return reconstructed

    def _create_save_field_struct(self, nested_urls, save_field_contents=None) -> List[Union[bytes, List[bytes]]]:
        """Create save field structure for output."""
        if save_field_contents is None:
            save_field_contents = []
            for item in nested_urls:
                if isinstance(item, list):
                    save_field_contents.append([None] * len(item))
                else:
                    save_field_contents.append(None)
        else:
            # check whether the save_field_contents format is correct and correct it automatically
            for i, item in enumerate(nested_urls):
                if isinstance(item, list):
                    if not save_field_contents[i] or len(save_field_contents[i]) != len(item):
                        save_field_contents[i] = [None] * len(item)

        return save_field_contents
    async def download_nested_urls(
        self, nested_urls: List[Union[str, List[str]]], save_dir=None, save_field_contents=None
    ):
        """Download nested URLs with structure preservation."""
        flat_urls, structure_info = self._flat_urls(nested_urls)

        if save_field_contents is None:
            # contents are not saved, so disable return_contents for every URL
            return_contents = [False] * len(flat_urls)
        else:
            # if the original content is None, set the flag to True to fetch it;
            # otherwise set it to False to skip reloading it
            return_contents = []
            for item in save_field_contents:
                if isinstance(item, list):
                    return_contents.extend([not c for c in item])
                else:
                    return_contents.append(not item)

        download_results = await self.download_files_async(
            flat_urls,
            return_contents,
            save_dir,
        )

        if self.save_dir:
            reconstructed_path = self._create_path_struct(nested_urls)
        else:
            reconstructed_path = None

        failed_info = ""
        for i, (idx, save_path, status, response, content) in enumerate(download_results):
            orig_idx, sub_idx = structure_info[i]
            if status != "success":
                save_path = flat_urls[i]
                failed_info += "\n" + str(response)

            if save_field_contents is not None:
                if return_contents[i]:
                    if sub_idx == -1:
                        save_field_contents[orig_idx] = content
                    else:
                        save_field_contents[orig_idx][sub_idx] = content

            if self.save_dir:
                if sub_idx == -1:
                    reconstructed_path[orig_idx] = save_path
                else:
                    reconstructed_path[orig_idx][sub_idx] = save_path

        return save_field_contents, reconstructed_path, failed_info
    def process_batched(self, samples):
        """Process a batch of samples."""
        if self.download_field not in samples or not samples[self.download_field]:
            return samples

        batch_nested_urls = samples[self.download_field]

        if self.save_field:
            if not self.resume_download:
                if self.save_field in samples:
                    raise ValueError(
                        f"{self.save_field} is already in samples. "
                        f"If you want to resume download, please set `resume_download=True`"
                    )
                save_field_contents = self._create_save_field_struct(batch_nested_urls)
            else:
                if self.save_field not in samples:
                    save_field_contents = self._create_save_field_struct(batch_nested_urls)
                else:
                    save_field_contents = self._create_save_field_struct(batch_nested_urls, samples[self.save_field])
        else:
            save_field_contents = None

        save_field_contents, reconstructed_path, failed_info = asyncio.run(
            self.download_nested_urls(
                batch_nested_urls, save_dir=self.save_dir, save_field_contents=save_field_contents
            )
        )

        if self.save_dir:
            samples[self.download_field] = reconstructed_path
        if self.save_field:
            samples[self.save_field] = save_field_contents

        if len(failed_info):
            logger.error(f"Failed files:\n{failed_info}")

        return samples
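

# --- Usage sketch (illustrative, not part of the original module) -----------
# A minimal example of driving the operator directly on one batch. The field
# name "video_urls", the bucket "my-bucket", the local directory "./downloads",
# and the credentials/endpoint below are hypothetical placeholders; in practice
# they come from the caller or from environment variables, which take priority
# inside get_aws_credentials.
if __name__ == "__main__":
    mapper = S3DownloadFileMapper(
        download_field="video_urls",              # hypothetical sample field
        save_dir="./downloads",                   # downloaded files land here
        aws_access_key_id="YOUR_ACCESS_KEY",      # placeholder credential
        aws_secret_access_key="YOUR_SECRET_KEY",  # placeholder credential
        endpoint_url="http://localhost:9000",     # e.g. a MinIO endpoint
    )

    # One batch: the first sample holds a list of URLs, the second a single URL.
    batch = {
        "video_urls": [
            ["s3://my-bucket/videos/a.mp4", "s3://my-bucket/videos/b.mp4"],
            "s3://my-bucket/videos/c.mp4",
        ]
    }
    batch = mapper.process_batched(batch)

    # The download field keeps its original nesting, but now holds local paths
    # under ./downloads instead of s3:// URLs (failed URLs are kept as-is).
    print(batch["video_urls"])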